cleanup column_type_mappings (#17569)

Signed-off-by: Đặng Minh Dũng <dungdm93@live.com>
This commit is contained in:
Đặng Minh Dũng 2022-01-14 16:07:17 +07:00 committed by GitHub
parent 26dc600aff
commit 5a740901d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 105 additions and 170 deletions

View File

@ -53,7 +53,7 @@ from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy.sql import quoted_name, text from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause
from sqlalchemy.types import String, TypeEngine, UnicodeText from sqlalchemy.types import TypeEngine
from typing_extensions import TypedDict from typing_extensions import TypedDict
from superset import security_manager, sql_parse from superset import security_manager, sql_parse
@ -71,6 +71,12 @@ if TYPE_CHECKING:
from superset.connectors.sqla.models import TableColumn from superset.connectors.sqla.models import TableColumn
from superset.models.core import Database from superset.models.core import Database
ColumnTypeMapping = Tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
]
logger = logging.getLogger() logger = logging.getLogger()
@ -156,26 +162,37 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
engine = "base" # str as defined in sqlalchemy.engine.engine engine = "base" # str as defined in sqlalchemy.engine.engine
engine_aliases: Set[str] = set() engine_aliases: Set[str] = set()
engine_name: Optional[ engine_name: Optional[str] = None # for user messages, overridden in child classes
str
] = None # used for user messages, overridden in child classes
_date_trunc_functions: Dict[str, str] = {} _date_trunc_functions: Dict[str, str] = {}
_time_grain_expressions: Dict[Optional[str], str] = {} _time_grain_expressions: Dict[Optional[str], str] = {}
column_type_mappings: Tuple[ column_type_mappings: Tuple[ColumnTypeMapping, ...] = (
Tuple[ (
Pattern[str], re.compile(r"^string", re.IGNORECASE),
Union[TypeEngine, Callable[[Match[str]], TypeEngine]], types.String(),
GenericDataType, GenericDataType.STRING,
], ),
..., (
] = ( re.compile(r"^n((var)?char|text)", re.IGNORECASE),
types.UnicodeText(),
GenericDataType.STRING,
),
(
re.compile(r"^(var)?char", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r"^(tiny|medium|long)?text", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
( (
re.compile(r"^smallint", re.IGNORECASE), re.compile(r"^smallint", re.IGNORECASE),
types.SmallInteger(), types.SmallInteger(),
GenericDataType.NUMERIC, GenericDataType.NUMERIC,
), ),
( (
re.compile(r"^int.*", re.IGNORECASE), re.compile(r"^int(eger)?", re.IGNORECASE),
types.Integer(), types.Integer(),
GenericDataType.NUMERIC, GenericDataType.NUMERIC,
), ),
@ -184,6 +201,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
types.BigInteger(), types.BigInteger(),
GenericDataType.NUMERIC, GenericDataType.NUMERIC,
), ),
(re.compile(r"^long", re.IGNORECASE), types.Float(), GenericDataType.NUMERIC,),
( (
re.compile(r"^decimal", re.IGNORECASE), re.compile(r"^decimal", re.IGNORECASE),
types.Numeric(), types.Numeric(),
@ -222,26 +240,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
GenericDataType.NUMERIC, GenericDataType.NUMERIC,
), ),
( (
re.compile(r"^string", re.IGNORECASE), re.compile(r"^timestamp", re.IGNORECASE),
types.String(), types.TIMESTAMP(),
utils.GenericDataType.STRING, GenericDataType.TEMPORAL,
), ),
(
re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE),
UnicodeText(),
utils.GenericDataType.STRING,
),
(
re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE),
String(),
utils.GenericDataType.STRING,
),
(
re.compile(r"^((TINY|MEDIUM|LONG)?TEXT)", re.IGNORECASE),
String(),
utils.GenericDataType.STRING,
),
(re.compile(r"^LONG", re.IGNORECASE), types.Float(), GenericDataType.NUMERIC,),
( (
re.compile(r"^datetime", re.IGNORECASE), re.compile(r"^datetime", re.IGNORECASE),
types.DateTime(), types.DateTime(),
@ -252,19 +254,14 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
types.DateTime(), types.DateTime(),
GenericDataType.TEMPORAL, GenericDataType.TEMPORAL,
), ),
( (re.compile(r"^time", re.IGNORECASE), types.Time(), GenericDataType.TEMPORAL,),
re.compile(r"^timestamp", re.IGNORECASE),
types.TIMESTAMP(),
GenericDataType.TEMPORAL,
),
( (
re.compile(r"^interval", re.IGNORECASE), re.compile(r"^interval", re.IGNORECASE),
types.Interval(), types.Interval(),
GenericDataType.TEMPORAL, GenericDataType.TEMPORAL,
), ),
(re.compile(r"^time", re.IGNORECASE), types.Time(), GenericDataType.TEMPORAL,),
( (
re.compile(r"^bool.*", re.IGNORECASE), re.compile(r"^bool(ean)?", re.IGNORECASE),
types.Boolean(), types.Boolean(),
GenericDataType.BOOLEAN, GenericDataType.BOOLEAN,
), ),
@ -693,7 +690,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
to_sql_kwargs["name"] = table.table to_sql_kwargs["name"] = table.table
if table.schema: if table.schema:
# Only add schema when it is preset and non empty. # Only add schema when it is preset and non empty.
to_sql_kwargs["schema"] = table.schema to_sql_kwargs["schema"] = table.schema
@ -844,6 +840,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
""" """
Get all tables from schema Get all tables from schema
:param database: The database to get info
:param inspector: SqlAlchemy inspector :param inspector: SqlAlchemy inspector
:param schema: Schema to inspect. If omitted, uses default schema for database :param schema: Schema to inspect. If omitted, uses default schema for database
:return: All tables in schema :return: All tables in schema
@ -860,6 +857,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
""" """
Get all views from schema Get all views from schema
:param database: The database to get info
:param inspector: SqlAlchemy inspector :param inspector: SqlAlchemy inspector
:param schema: Schema name. If omitted, uses default schema for database :param schema: Schema name. If omitted, uses default schema for database
:return: All views in schema :return: All views in schema
@ -924,7 +922,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param database: Database instance :param database: Database instance
:param query: SqlAlchemy query :param query: SqlAlchemy query
:param columns: List of TableColumns :param columns: List of TableColumns
:return: SqlAlchemy query with additional where clause referencing latest :return: SqlAlchemy query with additional where clause referencing the latest
partition partition
""" """
# TODO: Fix circular import caused by importing Database, TableColumn # TODO: Fix circular import caused by importing Database, TableColumn
@ -954,12 +952,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param database: Database instance :param database: Database instance
:param table_name: Table name, unquoted :param table_name: Table name, unquoted
:param engine: SqlALchemy Engine instance :param engine: SqlAlchemy Engine instance
:param schema: Schema, unquoted :param schema: Schema, unquoted
:param limit: limit to impose on query :param limit: limit to impose on query
:param show_cols: Show columns in query; otherwise use "*" :param show_cols: Show columns in query; otherwise use "*"
:param indent: Add indentation to query :param indent: Add indentation to query
:param latest_partition: Only query latest partition :param latest_partition: Only query the latest partition
:param cols: Columns to include in query :param cols: Columns to include in query
:return: SQL query :return: SQL query
""" """
@ -993,7 +991,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return sql return sql
@classmethod @classmethod
def estimate_statement_cost(cls, statement: str, cursor: Any,) -> Dict[str, Any]: def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
""" """
Generate a SQL query that estimates the cost of a given statement. Generate a SQL query that estimates the cost of a given statement.
@ -1024,7 +1022,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param statement: A single SQL statement :param statement: A single SQL statement
:param database: Database instance :param database: Database instance
:param username: Effective username :param user_name: Effective username
:return: Dictionary with different costs :return: Dictionary with different costs
""" """
parsed_query = ParsedQuery(statement) parsed_query = ParsedQuery(statement)
@ -1089,7 +1087,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param connect_args: config to be updated :param connect_args: config to be updated
:param uri: URI :param uri: URI
:param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username :param username: Effective username
:return: None :return: None
""" """
@ -1122,8 +1119,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
Conditionally mutate and/or quote a sqlalchemy expression label. If Conditionally mutate and/or quote a sqlalchemy expression label. If
force_column_alias_quotes is set to True, return the label as a force_column_alias_quotes is set to True, return the label as a
sqlalchemy.sql.elements.quoted_name object to ensure that the select query sqlalchemy.sql.elements.quoted_name object to ensure that the select query
and query results have same case. Otherwise return the mutated label as a and query results have same case. Otherwise, return the mutated label as a
regular string. If maxmimum supported column name length is exceeded, regular string. If maximum supported column name length is exceeded,
generate a truncated label by calling truncate_label(). generate a truncated label by calling truncate_label().
:param label: expected expression label/alias :param label: expected expression label/alias
@ -1143,15 +1140,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def get_sqla_column_type( def get_sqla_column_type(
cls, cls,
column_type: Optional[str], column_type: Optional[str],
column_type_mappings: Tuple[ column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
Tuple[ ) -> Optional[Tuple[TypeEngine, GenericDataType]]:
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = column_type_mappings,
) -> Union[Tuple[TypeEngine, GenericDataType], None]:
""" """
Return a sqlalchemy native column type that corresponds to the column type Return a sqlalchemy native column type that corresponds to the column type
defined in the data source (return None to use default type inferred by defined in the data source (return None to use default type inferred by
@ -1159,16 +1149,18 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
(see MSSQL for example of NCHAR/NVARCHAR handling). (see MSSQL for example of NCHAR/NVARCHAR handling).
:param column_type: Column type returned by inspector :param column_type: Column type returned by inspector
:param column_type_mappings: Maps from string to SqlAlchemy TypeEngine
:return: SqlAlchemy column type :return: SqlAlchemy column type
""" """
if not column_type: if not column_type:
return None return None
for regex, sqla_type, generic_type in column_type_mappings: for regex, sqla_type, generic_type in column_type_mappings:
match = regex.match(column_type) match = regex.match(column_type)
if match: if not match:
if callable(sqla_type): continue
return sqla_type(match), generic_type if callable(sqla_type):
return sqla_type, generic_type return sqla_type(match), generic_type
return sqla_type, generic_type
return None return None
@staticmethod @staticmethod
@ -1192,7 +1184,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
""" """
In the case that a label exceeds the max length supported by the engine, In the case that a label exceeds the max length supported by the engine,
this method is used to construct a deterministic and unique label based on this method is used to construct a deterministic and unique label based on
the original label. By default this returns an md5 hash of the original label, the original label. By default, this returns a md5 hash of the original label,
conditionally truncated if the length of the hash exceeds the max column length conditionally truncated if the length of the hash exceeds the max column length
of the engine. of the engine.
@ -1211,8 +1203,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
) -> str: ) -> str:
""" """
Convert sqlalchemy column type to string representation. Convert sqlalchemy column type to string representation.
By default removes collation and character encoding info to avoid unnecessarily By default, removes collation and character encoding info to avoid
long datatypes. unnecessarily long datatypes.
:param sqla_column_type: SqlAlchemy column type :param sqla_column_type: SqlAlchemy column type
:param dialect: Sqlalchemy dialect :param dialect: Sqlalchemy dialect
@ -1304,20 +1296,14 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
native_type: Optional[str], native_type: Optional[str],
db_extra: Optional[Dict[str, Any]] = None, db_extra: Optional[Dict[str, Any]] = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
column_type_mappings: Tuple[ column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
Tuple[ ) -> Optional[ColumnSpec]:
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = column_type_mappings,
) -> Union[ColumnSpec, None]:
""" """
Converts native database type to sqlalchemy column type. Converts native database type to sqlalchemy column type.
:param native_type: Native database typee :param native_type: Native database type
:param source: Type coming from the database table or cursor description
:param db_extra: The database extra object :param db_extra: The database extra object
:param source: Type coming from the database table or cursor description
:param column_type_mappings: Maps from string to SqlAlchemy TypeEngine
:return: ColumnSpec object :return: ColumnSpec object
""" """
col_types = cls.get_sqla_column_type( col_types = cls.get_sqla_column_type(
@ -1417,7 +1403,6 @@ class BasicParametersType(TypedDict, total=False):
class BasicParametersMixin: class BasicParametersMixin:
""" """
Mixin for configuring DB engine specs via a dictionary. Mixin for configuring DB engine specs via a dictionary.

View File

@ -16,7 +16,7 @@
# under the License. # under the License.
import re import re
from datetime import datetime from datetime import datetime
from typing import Any, Callable, Dict, Match, Optional, Pattern, Tuple, Union from typing import Any, Dict, Optional, Pattern, Tuple
from urllib import parse from urllib import parse
from flask_babel import gettext as __ from flask_babel import gettext as __
@ -33,9 +33,12 @@ from sqlalchemy.dialects.mysql import (
TINYTEXT, TINYTEXT,
) )
from sqlalchemy.engine.url import URL from sqlalchemy.engine.url import URL
from sqlalchemy.types import TypeEngine
from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin from superset.db_engine_specs.base import (
BaseEngineSpec,
BasicParametersMixin,
ColumnTypeMapping,
)
from superset.errors import SupersetErrorType from superset.errors import SupersetErrorType
from superset.models.sql_lab import Query from superset.models.sql_lab import Query
from superset.utils import core as utils from superset.utils import core as utils
@ -70,14 +73,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
) )
encryption_parameters = {"ssl": "1"} encryption_parameters = {"ssl": "1"}
column_type_mappings: Tuple[ column_type_mappings = (
Tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = (
(re.compile(r"^int.*", re.IGNORECASE), INTEGER(), GenericDataType.NUMERIC,), (re.compile(r"^int.*", re.IGNORECASE), INTEGER(), GenericDataType.NUMERIC,),
(re.compile(r"^tinyint", re.IGNORECASE), TINYINT(), GenericDataType.NUMERIC,), (re.compile(r"^tinyint", re.IGNORECASE), TINYINT(), GenericDataType.NUMERIC,),
( (
@ -208,15 +204,8 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
native_type: Optional[str], native_type: Optional[str],
db_extra: Optional[Dict[str, Any]] = None, db_extra: Optional[Dict[str, Any]] = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
column_type_mappings: Tuple[ column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
Tuple[ ) -> Optional[ColumnSpec]:
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = column_type_mappings,
) -> Union[ColumnSpec, None]:
column_spec = super().get_column_spec(native_type) column_spec = super().get_column_spec(native_type)
if column_spec: if column_spec:

View File

@ -18,25 +18,18 @@ import json
import logging import logging
import re import re
from datetime import datetime from datetime import datetime
from typing import ( from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING
Any,
Callable,
Dict,
List,
Match,
Optional,
Pattern,
Tuple,
TYPE_CHECKING,
Union,
)
from flask_babel import gettext as __ from flask_babel import gettext as __
from sqlalchemy.dialects.postgresql import ARRAY, DOUBLE_PRECISION, ENUM, JSON from sqlalchemy.dialects.postgresql import ARRAY, DOUBLE_PRECISION, ENUM, JSON
from sqlalchemy.dialects.postgresql.base import PGInspector from sqlalchemy.dialects.postgresql.base import PGInspector
from sqlalchemy.types import String, TypeEngine from sqlalchemy.types import String
from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin from superset.db_engine_specs.base import (
BaseEngineSpec,
BasicParametersMixin,
ColumnTypeMapping,
)
from superset.errors import SupersetErrorType from superset.errors import SupersetErrorType
from superset.exceptions import SupersetException from superset.exceptions import SupersetException
from superset.models.sql_lab import Query from superset.models.sql_lab import Query
@ -193,10 +186,10 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
( (
re.compile(r"^array.*", re.IGNORECASE), re.compile(r"^array.*", re.IGNORECASE),
lambda match: ARRAY(int(match[2])) if match[2] else String(), lambda match: ARRAY(int(match[2])) if match[2] else String(),
utils.GenericDataType.STRING, GenericDataType.STRING,
), ),
(re.compile(r"^json.*", re.IGNORECASE), JSON(), utils.GenericDataType.STRING,), (re.compile(r"^json.*", re.IGNORECASE), JSON(), GenericDataType.STRING,),
(re.compile(r"^enum.*", re.IGNORECASE), ENUM(), utils.GenericDataType.STRING,), (re.compile(r"^enum.*", re.IGNORECASE), ENUM(), GenericDataType.STRING,),
) )
@classmethod @classmethod
@ -275,15 +268,8 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
native_type: Optional[str], native_type: Optional[str],
db_extra: Optional[Dict[str, Any]] = None, db_extra: Optional[Dict[str, Any]] = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
column_type_mappings: Tuple[ column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
Tuple[ ) -> Optional[ColumnSpec]:
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = column_type_mappings,
) -> Union[ColumnSpec, None]:
column_spec = super().get_column_spec(native_type) column_spec = super().get_column_spec(native_type)
if column_spec: if column_spec:

View File

@ -23,19 +23,7 @@ from collections import defaultdict, deque
from contextlib import closing from contextlib import closing
from datetime import datetime from datetime import datetime
from distutils.version import StrictVersion from distutils.version import StrictVersion
from typing import ( from typing import Any, cast, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING, Union
Any,
Callable,
cast,
Dict,
List,
Match,
Optional,
Pattern,
Tuple,
TYPE_CHECKING,
Union,
)
from urllib import parse from urllib import parse
import pandas as pd import pandas as pd
@ -49,11 +37,10 @@ from sqlalchemy.engine.result import RowProxy
from sqlalchemy.engine.url import make_url, URL from sqlalchemy.engine.url import make_url, URL
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select from sqlalchemy.sql.expression import ColumnClause, Select
from sqlalchemy.types import TypeEngine
from superset import cache_manager, is_feature_enabled from superset import cache_manager, is_feature_enabled
from superset.common.db_query_status import QueryStatus from superset.common.db_query_status import QueryStatus
from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.base import BaseEngineSpec, ColumnTypeMapping
from superset.errors import SupersetErrorType from superset.errors import SupersetErrorType
from superset.exceptions import SupersetTemplateException from superset.exceptions import SupersetTemplateException
from superset.models.sql_lab import Query from superset.models.sql_lab import Query
@ -95,7 +82,6 @@ CONNECTION_UNKNOWN_DATABASE_ERROR = re.compile(
r"line (?P<location>.+?): Catalog '(?P<catalog_name>.+?)' does not exist" r"line (?P<location>.+?): Catalog '(?P<catalog_name>.+?)' does not exist"
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -449,86 +435,82 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
( (
re.compile(r"^boolean.*", re.IGNORECASE), re.compile(r"^boolean.*", re.IGNORECASE),
types.BOOLEAN, types.BOOLEAN,
utils.GenericDataType.BOOLEAN, GenericDataType.BOOLEAN,
), ),
( (
re.compile(r"^tinyint.*", re.IGNORECASE), re.compile(r"^tinyint.*", re.IGNORECASE),
TinyInteger(), TinyInteger(),
utils.GenericDataType.NUMERIC, GenericDataType.NUMERIC,
), ),
( (
re.compile(r"^smallint.*", re.IGNORECASE), re.compile(r"^smallint.*", re.IGNORECASE),
types.SMALLINT(), types.SMALLINT(),
utils.GenericDataType.NUMERIC, GenericDataType.NUMERIC,
), ),
( (
re.compile(r"^integer.*", re.IGNORECASE), re.compile(r"^integer.*", re.IGNORECASE),
types.INTEGER(), types.INTEGER(),
utils.GenericDataType.NUMERIC, GenericDataType.NUMERIC,
), ),
( (
re.compile(r"^bigint.*", re.IGNORECASE), re.compile(r"^bigint.*", re.IGNORECASE),
types.BIGINT(), types.BIGINT(),
utils.GenericDataType.NUMERIC, GenericDataType.NUMERIC,
), ),
( (
re.compile(r"^real.*", re.IGNORECASE), re.compile(r"^real.*", re.IGNORECASE),
types.FLOAT(), types.FLOAT(),
utils.GenericDataType.NUMERIC, GenericDataType.NUMERIC,
), ),
( (
re.compile(r"^double.*", re.IGNORECASE), re.compile(r"^double.*", re.IGNORECASE),
types.FLOAT(), types.FLOAT(),
utils.GenericDataType.NUMERIC, GenericDataType.NUMERIC,
), ),
( (
re.compile(r"^decimal.*", re.IGNORECASE), re.compile(r"^decimal.*", re.IGNORECASE),
types.DECIMAL(), types.DECIMAL(),
utils.GenericDataType.NUMERIC, GenericDataType.NUMERIC,
), ),
( (
re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE), re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE),
lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(), lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(),
utils.GenericDataType.STRING, GenericDataType.STRING,
), ),
( (
re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE), re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE),
lambda match: types.CHAR(int(match[2])) if match[2] else types.CHAR(), lambda match: types.CHAR(int(match[2])) if match[2] else types.CHAR(),
utils.GenericDataType.STRING, GenericDataType.STRING,
), ),
( (
re.compile(r"^varbinary.*", re.IGNORECASE), re.compile(r"^varbinary.*", re.IGNORECASE),
types.VARBINARY(), types.VARBINARY(),
utils.GenericDataType.STRING, GenericDataType.STRING,
),
(
re.compile(r"^json.*", re.IGNORECASE),
types.JSON(),
utils.GenericDataType.STRING,
), ),
(re.compile(r"^json.*", re.IGNORECASE), types.JSON(), GenericDataType.STRING,),
( (
re.compile(r"^date.*", re.IGNORECASE), re.compile(r"^date.*", re.IGNORECASE),
types.DATETIME(), types.DATETIME(),
utils.GenericDataType.TEMPORAL, GenericDataType.TEMPORAL,
), ),
( (
re.compile(r"^timestamp.*", re.IGNORECASE), re.compile(r"^timestamp.*", re.IGNORECASE),
types.TIMESTAMP(), types.TIMESTAMP(),
utils.GenericDataType.TEMPORAL, GenericDataType.TEMPORAL,
), ),
( (
re.compile(r"^interval.*", re.IGNORECASE), re.compile(r"^interval.*", re.IGNORECASE),
Interval(), Interval(),
utils.GenericDataType.TEMPORAL, GenericDataType.TEMPORAL,
), ),
( (
re.compile(r"^time.*", re.IGNORECASE), re.compile(r"^time.*", re.IGNORECASE),
types.Time(), types.Time(),
utils.GenericDataType.TEMPORAL, GenericDataType.TEMPORAL,
), ),
(re.compile(r"^array.*", re.IGNORECASE), Array(), utils.GenericDataType.STRING), (re.compile(r"^array.*", re.IGNORECASE), Array(), GenericDataType.STRING),
(re.compile(r"^map.*", re.IGNORECASE), Map(), utils.GenericDataType.STRING), (re.compile(r"^map.*", re.IGNORECASE), Map(), GenericDataType.STRING),
(re.compile(r"^row.*", re.IGNORECASE), Row(), utils.GenericDataType.STRING), (re.compile(r"^row.*", re.IGNORECASE), Row(), GenericDataType.STRING),
) )
@classmethod @classmethod
@ -1217,15 +1199,8 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
native_type: Optional[str], native_type: Optional[str],
db_extra: Optional[Dict[str, Any]] = None, db_extra: Optional[Dict[str, Any]] = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
column_type_mappings: Tuple[ column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
Tuple[ ) -> Optional[ColumnSpec]:
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = column_type_mappings,
) -> Union[ColumnSpec, None]:
column_spec = super().get_column_spec( column_spec = super().get_column_spec(
native_type, column_type_mappings=column_type_mappings native_type, column_type_mappings=column_type_mappings