From 5a740901d6b068ba0855103f5b926865043a2aea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C4=90=E1=BA=B7ng=20Minh=20D=C5=A9ng?= Date: Fri, 14 Jan 2022 16:07:17 +0700 Subject: [PATCH] cleanup column_type_mappings (#17569) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Đặng Minh Dũng --- superset/db_engine_specs/base.py | 137 ++++++++++++--------------- superset/db_engine_specs/mysql.py | 29 ++---- superset/db_engine_specs/postgres.py | 38 +++----- superset/db_engine_specs/presto.py | 71 +++++--------- 4 files changed, 105 insertions(+), 170 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index cd416dc12a..f579fc2502 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -53,7 +53,7 @@ from sqlalchemy.ext.compiler import compiles from sqlalchemy.orm import Session from sqlalchemy.sql import quoted_name, text from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause -from sqlalchemy.types import String, TypeEngine, UnicodeText +from sqlalchemy.types import TypeEngine from typing_extensions import TypedDict from superset import security_manager, sql_parse @@ -71,6 +71,12 @@ if TYPE_CHECKING: from superset.connectors.sqla.models import TableColumn from superset.models.core import Database +ColumnTypeMapping = Tuple[ + Pattern[str], + Union[TypeEngine, Callable[[Match[str]], TypeEngine]], + GenericDataType, +] + logger = logging.getLogger() @@ -156,26 +162,37 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods engine = "base" # str as defined in sqlalchemy.engine.engine engine_aliases: Set[str] = set() - engine_name: Optional[ - str - ] = None # used for user messages, overridden in child classes + engine_name: Optional[str] = None # for user messages, overridden in child classes _date_trunc_functions: Dict[str, str] = {} _time_grain_expressions: Dict[Optional[str], str] = {} - column_type_mappings: Tuple[ - Tuple[ - Pattern[str], - Union[TypeEngine, Callable[[Match[str]], TypeEngine]], - GenericDataType, - ], - ..., - ] = ( + column_type_mappings: Tuple[ColumnTypeMapping, ...] = ( + ( + re.compile(r"^string", re.IGNORECASE), + types.String(), + GenericDataType.STRING, + ), + ( + re.compile(r"^n((var)?char|text)", re.IGNORECASE), + types.UnicodeText(), + GenericDataType.STRING, + ), + ( + re.compile(r"^(var)?char", re.IGNORECASE), + types.String(), + GenericDataType.STRING, + ), + ( + re.compile(r"^(tiny|medium|long)?text", re.IGNORECASE), + types.String(), + GenericDataType.STRING, + ), ( re.compile(r"^smallint", re.IGNORECASE), types.SmallInteger(), GenericDataType.NUMERIC, ), ( - re.compile(r"^int.*", re.IGNORECASE), + re.compile(r"^int(eger)?", re.IGNORECASE), types.Integer(), GenericDataType.NUMERIC, ), @@ -184,6 +201,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods types.BigInteger(), GenericDataType.NUMERIC, ), + (re.compile(r"^long", re.IGNORECASE), types.Float(), GenericDataType.NUMERIC,), ( re.compile(r"^decimal", re.IGNORECASE), types.Numeric(), @@ -222,26 +240,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods GenericDataType.NUMERIC, ), ( - re.compile(r"^string", re.IGNORECASE), - types.String(), - utils.GenericDataType.STRING, + re.compile(r"^timestamp", re.IGNORECASE), + types.TIMESTAMP(), + GenericDataType.TEMPORAL, ), - ( - re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE), - UnicodeText(), - utils.GenericDataType.STRING, - ), - ( - re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE), - String(), - utils.GenericDataType.STRING, - ), - ( - re.compile(r"^((TINY|MEDIUM|LONG)?TEXT)", re.IGNORECASE), - String(), - utils.GenericDataType.STRING, - ), - (re.compile(r"^LONG", re.IGNORECASE), types.Float(), GenericDataType.NUMERIC,), ( re.compile(r"^datetime", re.IGNORECASE), types.DateTime(), @@ -252,19 +254,14 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods types.DateTime(), GenericDataType.TEMPORAL, ), - ( - re.compile(r"^timestamp", re.IGNORECASE), - types.TIMESTAMP(), - GenericDataType.TEMPORAL, - ), + (re.compile(r"^time", re.IGNORECASE), types.Time(), GenericDataType.TEMPORAL,), ( re.compile(r"^interval", re.IGNORECASE), types.Interval(), GenericDataType.TEMPORAL, ), - (re.compile(r"^time", re.IGNORECASE), types.Time(), GenericDataType.TEMPORAL,), ( - re.compile(r"^bool.*", re.IGNORECASE), + re.compile(r"^bool(ean)?", re.IGNORECASE), types.Boolean(), GenericDataType.BOOLEAN, ), @@ -693,7 +690,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods to_sql_kwargs["name"] = table.table if table.schema: - # Only add schema when it is preset and non empty. to_sql_kwargs["schema"] = table.schema @@ -844,6 +840,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ Get all tables from schema + :param database: The database to get info :param inspector: SqlAlchemy inspector :param schema: Schema to inspect. If omitted, uses default schema for database :return: All tables in schema @@ -860,6 +857,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ Get all views from schema + :param database: The database to get info :param inspector: SqlAlchemy inspector :param schema: Schema name. If omitted, uses default schema for database :return: All views in schema @@ -924,7 +922,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :param database: Database instance :param query: SqlAlchemy query :param columns: List of TableColumns - :return: SqlAlchemy query with additional where clause referencing latest + :return: SqlAlchemy query with additional where clause referencing the latest partition """ # TODO: Fix circular import caused by importing Database, TableColumn @@ -954,12 +952,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :param database: Database instance :param table_name: Table name, unquoted - :param engine: SqlALchemy Engine instance + :param engine: SqlAlchemy Engine instance :param schema: Schema, unquoted :param limit: limit to impose on query :param show_cols: Show columns in query; otherwise use "*" :param indent: Add indentation to query - :param latest_partition: Only query latest partition + :param latest_partition: Only query the latest partition :param cols: Columns to include in query :return: SQL query """ @@ -993,7 +991,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return sql @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any,) -> Dict[str, Any]: + def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: """ Generate a SQL query that estimates the cost of a given statement. @@ -1024,7 +1022,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :param statement: A single SQL statement :param database: Database instance - :param username: Effective username + :param user_name: Effective username :return: Dictionary with different costs """ parsed_query = ParsedQuery(statement) @@ -1089,7 +1087,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :param connect_args: config to be updated :param uri: URI - :param impersonate_user: Flag indicating if impersonation is enabled :param username: Effective username :return: None """ @@ -1122,8 +1119,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods Conditionally mutate and/or quote a sqlalchemy expression label. If force_column_alias_quotes is set to True, return the label as a sqlalchemy.sql.elements.quoted_name object to ensure that the select query - and query results have same case. Otherwise return the mutated label as a - regular string. If maxmimum supported column name length is exceeded, + and query results have same case. Otherwise, return the mutated label as a + regular string. If maximum supported column name length is exceeded, generate a truncated label by calling truncate_label(). :param label: expected expression label/alias @@ -1143,15 +1140,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods def get_sqla_column_type( cls, column_type: Optional[str], - column_type_mappings: Tuple[ - Tuple[ - Pattern[str], - Union[TypeEngine, Callable[[Match[str]], TypeEngine]], - GenericDataType, - ], - ..., - ] = column_type_mappings, - ) -> Union[Tuple[TypeEngine, GenericDataType], None]: + column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings, + ) -> Optional[Tuple[TypeEngine, GenericDataType]]: """ Return a sqlalchemy native column type that corresponds to the column type defined in the data source (return None to use default type inferred by @@ -1159,16 +1149,18 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods (see MSSQL for example of NCHAR/NVARCHAR handling). :param column_type: Column type returned by inspector + :param column_type_mappings: Maps from string to SqlAlchemy TypeEngine :return: SqlAlchemy column type """ if not column_type: return None for regex, sqla_type, generic_type in column_type_mappings: match = regex.match(column_type) - if match: - if callable(sqla_type): - return sqla_type(match), generic_type - return sqla_type, generic_type + if not match: + continue + if callable(sqla_type): + return sqla_type(match), generic_type + return sqla_type, generic_type return None @staticmethod @@ -1192,7 +1184,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ In the case that a label exceeds the max length supported by the engine, this method is used to construct a deterministic and unique label based on - the original label. By default this returns an md5 hash of the original label, + the original label. By default, this returns a md5 hash of the original label, conditionally truncated if the length of the hash exceeds the max column length of the engine. @@ -1211,8 +1203,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods ) -> str: """ Convert sqlalchemy column type to string representation. - By default removes collation and character encoding info to avoid unnecessarily - long datatypes. + By default, removes collation and character encoding info to avoid + unnecessarily long datatypes. :param sqla_column_type: SqlAlchemy column type :param dialect: Sqlalchemy dialect @@ -1304,20 +1296,14 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods native_type: Optional[str], db_extra: Optional[Dict[str, Any]] = None, source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, - column_type_mappings: Tuple[ - Tuple[ - Pattern[str], - Union[TypeEngine, Callable[[Match[str]], TypeEngine]], - GenericDataType, - ], - ..., - ] = column_type_mappings, - ) -> Union[ColumnSpec, None]: + column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings, + ) -> Optional[ColumnSpec]: """ Converts native database type to sqlalchemy column type. - :param native_type: Native database typee - :param source: Type coming from the database table or cursor description + :param native_type: Native database type :param db_extra: The database extra object + :param source: Type coming from the database table or cursor description + :param column_type_mappings: Maps from string to SqlAlchemy TypeEngine :return: ColumnSpec object """ col_types = cls.get_sqla_column_type( @@ -1417,7 +1403,6 @@ class BasicParametersType(TypedDict, total=False): class BasicParametersMixin: - """ Mixin for configuring DB engine specs via a dictionary. diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index 0d7493c386..90cbe621fa 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -16,7 +16,7 @@ # under the License. import re from datetime import datetime -from typing import Any, Callable, Dict, Match, Optional, Pattern, Tuple, Union +from typing import Any, Dict, Optional, Pattern, Tuple from urllib import parse from flask_babel import gettext as __ @@ -33,9 +33,12 @@ from sqlalchemy.dialects.mysql import ( TINYTEXT, ) from sqlalchemy.engine.url import URL -from sqlalchemy.types import TypeEngine -from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin +from superset.db_engine_specs.base import ( + BaseEngineSpec, + BasicParametersMixin, + ColumnTypeMapping, +) from superset.errors import SupersetErrorType from superset.models.sql_lab import Query from superset.utils import core as utils @@ -70,14 +73,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): ) encryption_parameters = {"ssl": "1"} - column_type_mappings: Tuple[ - Tuple[ - Pattern[str], - Union[TypeEngine, Callable[[Match[str]], TypeEngine]], - GenericDataType, - ], - ..., - ] = ( + column_type_mappings = ( (re.compile(r"^int.*", re.IGNORECASE), INTEGER(), GenericDataType.NUMERIC,), (re.compile(r"^tinyint", re.IGNORECASE), TINYINT(), GenericDataType.NUMERIC,), ( @@ -208,15 +204,8 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): native_type: Optional[str], db_extra: Optional[Dict[str, Any]] = None, source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, - column_type_mappings: Tuple[ - Tuple[ - Pattern[str], - Union[TypeEngine, Callable[[Match[str]], TypeEngine]], - GenericDataType, - ], - ..., - ] = column_type_mappings, - ) -> Union[ColumnSpec, None]: + column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings, + ) -> Optional[ColumnSpec]: column_spec = super().get_column_spec(native_type) if column_spec: diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index c2951c394d..f6c6888ee9 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -18,25 +18,18 @@ import json import logging import re from datetime import datetime -from typing import ( - Any, - Callable, - Dict, - List, - Match, - Optional, - Pattern, - Tuple, - TYPE_CHECKING, - Union, -) +from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING from flask_babel import gettext as __ from sqlalchemy.dialects.postgresql import ARRAY, DOUBLE_PRECISION, ENUM, JSON from sqlalchemy.dialects.postgresql.base import PGInspector -from sqlalchemy.types import String, TypeEngine +from sqlalchemy.types import String -from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin +from superset.db_engine_specs.base import ( + BaseEngineSpec, + BasicParametersMixin, + ColumnTypeMapping, +) from superset.errors import SupersetErrorType from superset.exceptions import SupersetException from superset.models.sql_lab import Query @@ -193,10 +186,10 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin): ( re.compile(r"^array.*", re.IGNORECASE), lambda match: ARRAY(int(match[2])) if match[2] else String(), - utils.GenericDataType.STRING, + GenericDataType.STRING, ), - (re.compile(r"^json.*", re.IGNORECASE), JSON(), utils.GenericDataType.STRING,), - (re.compile(r"^enum.*", re.IGNORECASE), ENUM(), utils.GenericDataType.STRING,), + (re.compile(r"^json.*", re.IGNORECASE), JSON(), GenericDataType.STRING,), + (re.compile(r"^enum.*", re.IGNORECASE), ENUM(), GenericDataType.STRING,), ) @classmethod @@ -275,15 +268,8 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin): native_type: Optional[str], db_extra: Optional[Dict[str, Any]] = None, source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, - column_type_mappings: Tuple[ - Tuple[ - Pattern[str], - Union[TypeEngine, Callable[[Match[str]], TypeEngine]], - GenericDataType, - ], - ..., - ] = column_type_mappings, - ) -> Union[ColumnSpec, None]: + column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings, + ) -> Optional[ColumnSpec]: column_spec = super().get_column_spec(native_type) if column_spec: diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 7f8405220d..376151587c 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -23,19 +23,7 @@ from collections import defaultdict, deque from contextlib import closing from datetime import datetime from distutils.version import StrictVersion -from typing import ( - Any, - Callable, - cast, - Dict, - List, - Match, - Optional, - Pattern, - Tuple, - TYPE_CHECKING, - Union, -) +from typing import Any, cast, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING, Union from urllib import parse import pandas as pd @@ -49,11 +37,10 @@ from sqlalchemy.engine.result import RowProxy from sqlalchemy.engine.url import make_url, URL from sqlalchemy.orm import Session from sqlalchemy.sql.expression import ColumnClause, Select -from sqlalchemy.types import TypeEngine from superset import cache_manager, is_feature_enabled from superset.common.db_query_status import QueryStatus -from superset.db_engine_specs.base import BaseEngineSpec +from superset.db_engine_specs.base import BaseEngineSpec, ColumnTypeMapping from superset.errors import SupersetErrorType from superset.exceptions import SupersetTemplateException from superset.models.sql_lab import Query @@ -95,7 +82,6 @@ CONNECTION_UNKNOWN_DATABASE_ERROR = re.compile( r"line (?P.+?): Catalog '(?P.+?)' does not exist" ) - logger = logging.getLogger(__name__) @@ -449,86 +435,82 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho ( re.compile(r"^boolean.*", re.IGNORECASE), types.BOOLEAN, - utils.GenericDataType.BOOLEAN, + GenericDataType.BOOLEAN, ), ( re.compile(r"^tinyint.*", re.IGNORECASE), TinyInteger(), - utils.GenericDataType.NUMERIC, + GenericDataType.NUMERIC, ), ( re.compile(r"^smallint.*", re.IGNORECASE), types.SMALLINT(), - utils.GenericDataType.NUMERIC, + GenericDataType.NUMERIC, ), ( re.compile(r"^integer.*", re.IGNORECASE), types.INTEGER(), - utils.GenericDataType.NUMERIC, + GenericDataType.NUMERIC, ), ( re.compile(r"^bigint.*", re.IGNORECASE), types.BIGINT(), - utils.GenericDataType.NUMERIC, + GenericDataType.NUMERIC, ), ( re.compile(r"^real.*", re.IGNORECASE), types.FLOAT(), - utils.GenericDataType.NUMERIC, + GenericDataType.NUMERIC, ), ( re.compile(r"^double.*", re.IGNORECASE), types.FLOAT(), - utils.GenericDataType.NUMERIC, + GenericDataType.NUMERIC, ), ( re.compile(r"^decimal.*", re.IGNORECASE), types.DECIMAL(), - utils.GenericDataType.NUMERIC, + GenericDataType.NUMERIC, ), ( re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE), lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(), - utils.GenericDataType.STRING, + GenericDataType.STRING, ), ( re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE), lambda match: types.CHAR(int(match[2])) if match[2] else types.CHAR(), - utils.GenericDataType.STRING, + GenericDataType.STRING, ), ( re.compile(r"^varbinary.*", re.IGNORECASE), types.VARBINARY(), - utils.GenericDataType.STRING, - ), - ( - re.compile(r"^json.*", re.IGNORECASE), - types.JSON(), - utils.GenericDataType.STRING, + GenericDataType.STRING, ), + (re.compile(r"^json.*", re.IGNORECASE), types.JSON(), GenericDataType.STRING,), ( re.compile(r"^date.*", re.IGNORECASE), types.DATETIME(), - utils.GenericDataType.TEMPORAL, + GenericDataType.TEMPORAL, ), ( re.compile(r"^timestamp.*", re.IGNORECASE), types.TIMESTAMP(), - utils.GenericDataType.TEMPORAL, + GenericDataType.TEMPORAL, ), ( re.compile(r"^interval.*", re.IGNORECASE), Interval(), - utils.GenericDataType.TEMPORAL, + GenericDataType.TEMPORAL, ), ( re.compile(r"^time.*", re.IGNORECASE), types.Time(), - utils.GenericDataType.TEMPORAL, + GenericDataType.TEMPORAL, ), - (re.compile(r"^array.*", re.IGNORECASE), Array(), utils.GenericDataType.STRING), - (re.compile(r"^map.*", re.IGNORECASE), Map(), utils.GenericDataType.STRING), - (re.compile(r"^row.*", re.IGNORECASE), Row(), utils.GenericDataType.STRING), + (re.compile(r"^array.*", re.IGNORECASE), Array(), GenericDataType.STRING), + (re.compile(r"^map.*", re.IGNORECASE), Map(), GenericDataType.STRING), + (re.compile(r"^row.*", re.IGNORECASE), Row(), GenericDataType.STRING), ) @classmethod @@ -1217,15 +1199,8 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho native_type: Optional[str], db_extra: Optional[Dict[str, Any]] = None, source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, - column_type_mappings: Tuple[ - Tuple[ - Pattern[str], - Union[TypeEngine, Callable[[Match[str]], TypeEngine]], - GenericDataType, - ], - ..., - ] = column_type_mappings, - ) -> Union[ColumnSpec, None]: + column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings, + ) -> Optional[ColumnSpec]: column_spec = super().get_column_spec( native_type, column_type_mappings=column_type_mappings