cleanup column_type_mappings (#17569)

Signed-off-by: Đặng Minh Dũng <dungdm93@live.com>
This commit is contained in:
Đặng Minh Dũng 2022-01-14 16:07:17 +07:00 committed by GitHub
parent 26dc600aff
commit 5a740901d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 105 additions and 170 deletions

View File

@ -53,7 +53,7 @@ from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Session
from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause
from sqlalchemy.types import String, TypeEngine, UnicodeText
from sqlalchemy.types import TypeEngine
from typing_extensions import TypedDict
from superset import security_manager, sql_parse
@ -71,6 +71,12 @@ if TYPE_CHECKING:
from superset.connectors.sqla.models import TableColumn
from superset.models.core import Database
# One entry of an engine spec's ``column_type_mappings`` table, as a 3-tuple:
#   1. a compiled regex that is matched against the native column type string,
#   2. either a SqlAlchemy ``TypeEngine`` to return as-is, or a callable that
#      receives the regex match and builds the ``TypeEngine`` from it,
#   3. the ``GenericDataType`` returned alongside the SqlAlchemy type.
ColumnTypeMapping = Tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
]
logger = logging.getLogger()
@ -156,26 +162,37 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
engine = "base" # str as defined in sqlalchemy.engine.engine
engine_aliases: Set[str] = set()
engine_name: Optional[
str
] = None # used for user messages, overridden in child classes
engine_name: Optional[str] = None # for user messages, overridden in child classes
_date_trunc_functions: Dict[str, str] = {}
_time_grain_expressions: Dict[Optional[str], str] = {}
column_type_mappings: Tuple[
Tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = (
column_type_mappings: Tuple[ColumnTypeMapping, ...] = (
(
re.compile(r"^string", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r"^n((var)?char|text)", re.IGNORECASE),
types.UnicodeText(),
GenericDataType.STRING,
),
(
re.compile(r"^(var)?char", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r"^(tiny|medium|long)?text", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r"^smallint", re.IGNORECASE),
types.SmallInteger(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^int.*", re.IGNORECASE),
re.compile(r"^int(eger)?", re.IGNORECASE),
types.Integer(),
GenericDataType.NUMERIC,
),
@ -184,6 +201,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
types.BigInteger(),
GenericDataType.NUMERIC,
),
(re.compile(r"^long", re.IGNORECASE), types.Float(), GenericDataType.NUMERIC,),
(
re.compile(r"^decimal", re.IGNORECASE),
types.Numeric(),
@ -222,26 +240,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
GenericDataType.NUMERIC,
),
(
re.compile(r"^string", re.IGNORECASE),
types.String(),
utils.GenericDataType.STRING,
re.compile(r"^timestamp", re.IGNORECASE),
types.TIMESTAMP(),
GenericDataType.TEMPORAL,
),
(
re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE),
UnicodeText(),
utils.GenericDataType.STRING,
),
(
re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE),
String(),
utils.GenericDataType.STRING,
),
(
re.compile(r"^((TINY|MEDIUM|LONG)?TEXT)", re.IGNORECASE),
String(),
utils.GenericDataType.STRING,
),
(re.compile(r"^LONG", re.IGNORECASE), types.Float(), GenericDataType.NUMERIC,),
(
re.compile(r"^datetime", re.IGNORECASE),
types.DateTime(),
@ -252,19 +254,14 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
types.DateTime(),
GenericDataType.TEMPORAL,
),
(
re.compile(r"^timestamp", re.IGNORECASE),
types.TIMESTAMP(),
GenericDataType.TEMPORAL,
),
(re.compile(r"^time", re.IGNORECASE), types.Time(), GenericDataType.TEMPORAL,),
(
re.compile(r"^interval", re.IGNORECASE),
types.Interval(),
GenericDataType.TEMPORAL,
),
(re.compile(r"^time", re.IGNORECASE), types.Time(), GenericDataType.TEMPORAL,),
(
re.compile(r"^bool.*", re.IGNORECASE),
re.compile(r"^bool(ean)?", re.IGNORECASE),
types.Boolean(),
GenericDataType.BOOLEAN,
),
@ -693,7 +690,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
to_sql_kwargs["name"] = table.table
if table.schema:
# Only add schema when it is preset and non empty.
to_sql_kwargs["schema"] = table.schema
@ -844,6 +840,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
Get all tables from schema
:param database: The database to get info
:param inspector: SqlAlchemy inspector
:param schema: Schema to inspect. If omitted, uses default schema for database
:return: All tables in schema
@ -860,6 +857,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
Get all views from schema
:param database: The database to get info
:param inspector: SqlAlchemy inspector
:param schema: Schema name. If omitted, uses default schema for database
:return: All views in schema
@ -924,7 +922,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param database: Database instance
:param query: SqlAlchemy query
:param columns: List of TableColumns
:return: SqlAlchemy query with additional where clause referencing latest
:return: SqlAlchemy query with additional where clause referencing the latest
partition
"""
# TODO: Fix circular import caused by importing Database, TableColumn
@ -954,12 +952,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param database: Database instance
:param table_name: Table name, unquoted
:param engine: SqlALchemy Engine instance
:param engine: SqlAlchemy Engine instance
:param schema: Schema, unquoted
:param limit: limit to impose on query
:param show_cols: Show columns in query; otherwise use "*"
:param indent: Add indentation to query
:param latest_partition: Only query latest partition
:param latest_partition: Only query the latest partition
:param cols: Columns to include in query
:return: SQL query
"""
@ -993,7 +991,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return sql
@classmethod
def estimate_statement_cost(cls, statement: str, cursor: Any,) -> Dict[str, Any]:
def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
"""
Generate a SQL query that estimates the cost of a given statement.
@ -1024,7 +1022,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param statement: A single SQL statement
:param database: Database instance
:param username: Effective username
:param user_name: Effective username
:return: Dictionary with different costs
"""
parsed_query = ParsedQuery(statement)
@ -1089,7 +1087,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param connect_args: config to be updated
:param uri: URI
:param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username
:return: None
"""
@ -1122,8 +1119,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
Conditionally mutate and/or quote a sqlalchemy expression label. If
force_column_alias_quotes is set to True, return the label as a
sqlalchemy.sql.elements.quoted_name object to ensure that the select query
and query results have same case. Otherwise return the mutated label as a
regular string. If maxmimum supported column name length is exceeded,
and query results have same case. Otherwise, return the mutated label as a
regular string. If maximum supported column name length is exceeded,
generate a truncated label by calling truncate_label().
:param label: expected expression label/alias
@ -1143,15 +1140,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def get_sqla_column_type(
cls,
column_type: Optional[str],
column_type_mappings: Tuple[
Tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = column_type_mappings,
) -> Union[Tuple[TypeEngine, GenericDataType], None]:
column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
) -> Optional[Tuple[TypeEngine, GenericDataType]]:
"""
Return a sqlalchemy native column type that corresponds to the column type
defined in the data source (return None to use default type inferred by
@ -1159,16 +1149,18 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
(see MSSQL for example of NCHAR/NVARCHAR handling).
:param column_type: Column type returned by inspector
:param column_type_mappings: Maps from string to SqlAlchemy TypeEngine
:return: SqlAlchemy column type
"""
if not column_type:
return None
for regex, sqla_type, generic_type in column_type_mappings:
match = regex.match(column_type)
if match:
if callable(sqla_type):
return sqla_type(match), generic_type
return sqla_type, generic_type
if not match:
continue
if callable(sqla_type):
return sqla_type(match), generic_type
return sqla_type, generic_type
return None
@staticmethod
@ -1192,7 +1184,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
In the case that a label exceeds the max length supported by the engine,
this method is used to construct a deterministic and unique label based on
the original label. By default this returns an md5 hash of the original label,
the original label. By default, this returns an md5 hash of the original label,
conditionally truncated if the length of the hash exceeds the max column length
of the engine.
@ -1211,8 +1203,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
) -> str:
"""
Convert sqlalchemy column type to string representation.
By default removes collation and character encoding info to avoid unnecessarily
long datatypes.
By default, removes collation and character encoding info to avoid
unnecessarily long datatypes.
:param sqla_column_type: SqlAlchemy column type
:param dialect: SqlAlchemy dialect
@ -1304,20 +1296,14 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
native_type: Optional[str],
db_extra: Optional[Dict[str, Any]] = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
column_type_mappings: Tuple[
Tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = column_type_mappings,
) -> Union[ColumnSpec, None]:
column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
) -> Optional[ColumnSpec]:
"""
Converts native database type to sqlalchemy column type.
:param native_type: Native database typee
:param source: Type coming from the database table or cursor description
:param native_type: Native database type
:param db_extra: The database extra object
:param source: Type coming from the database table or cursor description
:param column_type_mappings: Maps from string to SqlAlchemy TypeEngine
:return: ColumnSpec object
"""
col_types = cls.get_sqla_column_type(
@ -1417,7 +1403,6 @@ class BasicParametersType(TypedDict, total=False):
class BasicParametersMixin:
"""
Mixin for configuring DB engine specs via a dictionary.

View File

@ -16,7 +16,7 @@
# under the License.
import re
from datetime import datetime
from typing import Any, Callable, Dict, Match, Optional, Pattern, Tuple, Union
from typing import Any, Dict, Optional, Pattern, Tuple
from urllib import parse
from flask_babel import gettext as __
@ -33,9 +33,12 @@ from sqlalchemy.dialects.mysql import (
TINYTEXT,
)
from sqlalchemy.engine.url import URL
from sqlalchemy.types import TypeEngine
from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin
from superset.db_engine_specs.base import (
BaseEngineSpec,
BasicParametersMixin,
ColumnTypeMapping,
)
from superset.errors import SupersetErrorType
from superset.models.sql_lab import Query
from superset.utils import core as utils
@ -70,14 +73,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
)
encryption_parameters = {"ssl": "1"}
column_type_mappings: Tuple[
Tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = (
column_type_mappings = (
(re.compile(r"^int.*", re.IGNORECASE), INTEGER(), GenericDataType.NUMERIC,),
(re.compile(r"^tinyint", re.IGNORECASE), TINYINT(), GenericDataType.NUMERIC,),
(
@ -208,15 +204,8 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
native_type: Optional[str],
db_extra: Optional[Dict[str, Any]] = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
column_type_mappings: Tuple[
Tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = column_type_mappings,
) -> Union[ColumnSpec, None]:
column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
) -> Optional[ColumnSpec]:
column_spec = super().get_column_spec(native_type)
if column_spec:

View File

@ -18,25 +18,18 @@ import json
import logging
import re
from datetime import datetime
from typing import (
Any,
Callable,
Dict,
List,
Match,
Optional,
Pattern,
Tuple,
TYPE_CHECKING,
Union,
)
from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING
from flask_babel import gettext as __
from sqlalchemy.dialects.postgresql import ARRAY, DOUBLE_PRECISION, ENUM, JSON
from sqlalchemy.dialects.postgresql.base import PGInspector
from sqlalchemy.types import String, TypeEngine
from sqlalchemy.types import String
from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin
from superset.db_engine_specs.base import (
BaseEngineSpec,
BasicParametersMixin,
ColumnTypeMapping,
)
from superset.errors import SupersetErrorType
from superset.exceptions import SupersetException
from superset.models.sql_lab import Query
@ -193,10 +186,10 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
(
re.compile(r"^array.*", re.IGNORECASE),
lambda match: ARRAY(int(match[2])) if match[2] else String(),
utils.GenericDataType.STRING,
GenericDataType.STRING,
),
(re.compile(r"^json.*", re.IGNORECASE), JSON(), utils.GenericDataType.STRING,),
(re.compile(r"^enum.*", re.IGNORECASE), ENUM(), utils.GenericDataType.STRING,),
(re.compile(r"^json.*", re.IGNORECASE), JSON(), GenericDataType.STRING,),
(re.compile(r"^enum.*", re.IGNORECASE), ENUM(), GenericDataType.STRING,),
)
@classmethod
@ -275,15 +268,8 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
native_type: Optional[str],
db_extra: Optional[Dict[str, Any]] = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
column_type_mappings: Tuple[
Tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = column_type_mappings,
) -> Union[ColumnSpec, None]:
column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
) -> Optional[ColumnSpec]:
column_spec = super().get_column_spec(native_type)
if column_spec:

View File

@ -23,19 +23,7 @@ from collections import defaultdict, deque
from contextlib import closing
from datetime import datetime
from distutils.version import StrictVersion
from typing import (
Any,
Callable,
cast,
Dict,
List,
Match,
Optional,
Pattern,
Tuple,
TYPE_CHECKING,
Union,
)
from typing import Any, cast, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING, Union
from urllib import parse
import pandas as pd
@ -49,11 +37,10 @@ from sqlalchemy.engine.result import RowProxy
from sqlalchemy.engine.url import make_url, URL
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select
from sqlalchemy.types import TypeEngine
from superset import cache_manager, is_feature_enabled
from superset.common.db_query_status import QueryStatus
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.base import BaseEngineSpec, ColumnTypeMapping
from superset.errors import SupersetErrorType
from superset.exceptions import SupersetTemplateException
from superset.models.sql_lab import Query
@ -95,7 +82,6 @@ CONNECTION_UNKNOWN_DATABASE_ERROR = re.compile(
r"line (?P<location>.+?): Catalog '(?P<catalog_name>.+?)' does not exist"
)
logger = logging.getLogger(__name__)
@ -449,86 +435,82 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
(
re.compile(r"^boolean.*", re.IGNORECASE),
types.BOOLEAN,
utils.GenericDataType.BOOLEAN,
GenericDataType.BOOLEAN,
),
(
re.compile(r"^tinyint.*", re.IGNORECASE),
TinyInteger(),
utils.GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
),
(
re.compile(r"^smallint.*", re.IGNORECASE),
types.SMALLINT(),
utils.GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
),
(
re.compile(r"^integer.*", re.IGNORECASE),
types.INTEGER(),
utils.GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
),
(
re.compile(r"^bigint.*", re.IGNORECASE),
types.BIGINT(),
utils.GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
),
(
re.compile(r"^real.*", re.IGNORECASE),
types.FLOAT(),
utils.GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
),
(
re.compile(r"^double.*", re.IGNORECASE),
types.FLOAT(),
utils.GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
),
(
re.compile(r"^decimal.*", re.IGNORECASE),
types.DECIMAL(),
utils.GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
),
(
re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE),
lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(),
utils.GenericDataType.STRING,
GenericDataType.STRING,
),
(
re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE),
lambda match: types.CHAR(int(match[2])) if match[2] else types.CHAR(),
utils.GenericDataType.STRING,
GenericDataType.STRING,
),
(
re.compile(r"^varbinary.*", re.IGNORECASE),
types.VARBINARY(),
utils.GenericDataType.STRING,
),
(
re.compile(r"^json.*", re.IGNORECASE),
types.JSON(),
utils.GenericDataType.STRING,
GenericDataType.STRING,
),
(re.compile(r"^json.*", re.IGNORECASE), types.JSON(), GenericDataType.STRING,),
(
re.compile(r"^date.*", re.IGNORECASE),
types.DATETIME(),
utils.GenericDataType.TEMPORAL,
GenericDataType.TEMPORAL,
),
(
re.compile(r"^timestamp.*", re.IGNORECASE),
types.TIMESTAMP(),
utils.GenericDataType.TEMPORAL,
GenericDataType.TEMPORAL,
),
(
re.compile(r"^interval.*", re.IGNORECASE),
Interval(),
utils.GenericDataType.TEMPORAL,
GenericDataType.TEMPORAL,
),
(
re.compile(r"^time.*", re.IGNORECASE),
types.Time(),
utils.GenericDataType.TEMPORAL,
GenericDataType.TEMPORAL,
),
(re.compile(r"^array.*", re.IGNORECASE), Array(), utils.GenericDataType.STRING),
(re.compile(r"^map.*", re.IGNORECASE), Map(), utils.GenericDataType.STRING),
(re.compile(r"^row.*", re.IGNORECASE), Row(), utils.GenericDataType.STRING),
(re.compile(r"^array.*", re.IGNORECASE), Array(), GenericDataType.STRING),
(re.compile(r"^map.*", re.IGNORECASE), Map(), GenericDataType.STRING),
(re.compile(r"^row.*", re.IGNORECASE), Row(), GenericDataType.STRING),
)
@classmethod
@ -1217,15 +1199,8 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
native_type: Optional[str],
db_extra: Optional[Dict[str, Any]] = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
column_type_mappings: Tuple[
Tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = column_type_mappings,
) -> Union[ColumnSpec, None]:
column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
) -> Optional[ColumnSpec]:
column_spec = super().get_column_spec(
native_type, column_type_mappings=column_type_mappings