feat: support complex types and use get_columns implementation of starrrocks python client (#24237)

This commit is contained in:
miomiocat 2023-06-09 01:08:37 +08:00 committed by GitHub
parent 69c2cd5f40
commit fd3effe712
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 62 additions and 67 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

After

Width:  |  Height:  |  Size: 6.5 KiB

View File

@ -22,9 +22,7 @@ from typing import Any, Optional
from urllib import parse from urllib import parse
from flask_babel import gettext as __ from flask_babel import gettext as __
from sqlalchemy import Integer, Numeric, types from sqlalchemy import Float, Integer, Numeric, types
from sqlalchemy.engine import Inspector
from sqlalchemy.engine.result import Row as ResultRow
from sqlalchemy.engine.url import URL from sqlalchemy.engine.url import URL
from sqlalchemy.sql.type_api import TypeEngine from sqlalchemy.sql.type_api import TypeEngine
@ -45,10 +43,26 @@ class TINYINT(Integer):
__visit_name__ = "TINYINT" __visit_name__ = "TINYINT"
class DOUBLE(Numeric): class LARGEINT(Integer):
__visit_name__ = "LARGEINT"
class DOUBLE(Float):
__visit_name__ = "DOUBLE" __visit_name__ = "DOUBLE"
class HLL(Numeric):
__visit_name__ = "HLL"
class BITMAP(Numeric):
__visit_name__ = "BITMAP"
class PERCENTILE(Numeric):
__visit_name__ = "PERCENTILE"
class ARRAY(TypeEngine): # pylint: disable=no-init class ARRAY(TypeEngine): # pylint: disable=no-init
__visit_name__ = "ARRAY" __visit_name__ = "ARRAY"
@ -88,6 +102,11 @@ class StarRocksEngineSpec(MySQLEngineSpec):
TINYINT(), TINYINT(),
GenericDataType.NUMERIC, GenericDataType.NUMERIC,
), ),
(
re.compile(r"^largeint", re.IGNORECASE),
LARGEINT(),
GenericDataType.NUMERIC,
),
( (
re.compile(r"^decimal.*", re.IGNORECASE), re.compile(r"^decimal.*", re.IGNORECASE),
types.DECIMAL(), types.DECIMAL(),
@ -108,11 +127,23 @@ class StarRocksEngineSpec(MySQLEngineSpec):
types.CHAR(), types.CHAR(),
GenericDataType.STRING, GenericDataType.STRING,
), ),
(
re.compile(r"^json", re.IGNORECASE),
types.JSON(),
GenericDataType.STRING,
),
( (
re.compile(r"^binary.*", re.IGNORECASE), re.compile(r"^binary.*", re.IGNORECASE),
types.String(), types.String(),
GenericDataType.STRING, GenericDataType.STRING,
), ),
(
re.compile(r"^percentile", re.IGNORECASE),
PERCENTILE(),
GenericDataType.STRING,
),
(re.compile(r"^hll", re.IGNORECASE), HLL(), GenericDataType.STRING),
(re.compile(r"^bitmap", re.IGNORECASE), BITMAP(), GenericDataType.STRING),
(re.compile(r"^array.*", re.IGNORECASE), ARRAY(), GenericDataType.STRING), (re.compile(r"^array.*", re.IGNORECASE), ARRAY(), GenericDataType.STRING),
(re.compile(r"^map.*", re.IGNORECASE), MAP(), GenericDataType.STRING), (re.compile(r"^map.*", re.IGNORECASE), MAP(), GenericDataType.STRING),
(re.compile(r"^struct.*", re.IGNORECASE), STRUCT(), GenericDataType.STRING), (re.compile(r"^struct.*", re.IGNORECASE), STRUCT(), GenericDataType.STRING),
@ -145,62 +176,11 @@ class StarRocksEngineSpec(MySQLEngineSpec):
if "." in database: if "." in database:
database = database.split(".")[0] + "." + schema database = database.split(".")[0] + "." + schema
else: else:
database += "." + schema database = "default_catalog." + schema
uri = uri.set(database=database) uri = uri.set(database=database)
return uri, connect_args return uri, connect_args
@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: Optional[str]
) -> list[dict[str, Any]]:
columns = cls._show_columns(inspector, table_name, schema)
result: list[dict[str, Any]] = []
for column in columns:
column_spec = cls.get_column_spec(column.Type)
column_type = column_spec.sqla_type if column_spec else None
if column_type is None:
column_type = types.String()
logger.info(
"Did not recognize starrocks type %s of column %s",
str(column.Type),
str(column.Field),
)
column_info = cls._create_column_info(column.Field, column_type)
column_info["nullable"] = getattr(column, "Null", True)
column_info["default"] = None
result.append(column_info)
return result
@classmethod
def _show_columns(
cls, inspector: Inspector, table_name: str, schema: Optional[str]
) -> list[ResultRow]:
"""
Show starrocks column names
:param inspector: object that performs database schema inspection
:param table_name: table name
:param schema: schema name
:return: list of column objects
"""
quote = inspector.engine.dialect.identifier_preparer.quote_identifier
full_table = quote(table_name)
if schema:
full_table = f"{quote(schema)}.{full_table}"
return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall()
@classmethod
def _create_column_info(
cls, name: str, data_type: types.TypeEngine
) -> dict[str, Any]:
"""
Create column info object
:param name: column name
:param data_type: column data type
:return: column info object
"""
return {"name": name, "type": f"{data_type}"}
@classmethod @classmethod
def get_schema_from_engine_params( def get_schema_from_engine_params(
cls, cls,

View File

@ -18,10 +18,20 @@
from typing import Any, Optional from typing import Any, Optional
import pytest import pytest
from sqlalchemy import types from sqlalchemy import JSON, types
from sqlalchemy.engine.url import make_url from sqlalchemy.engine.url import make_url
from superset.db_engine_specs.starrocks import ARRAY, DOUBLE, MAP, STRUCT, TINYINT from superset.db_engine_specs.starrocks import (
ARRAY,
BITMAP,
DOUBLE,
HLL,
LARGEINT,
MAP,
PERCENTILE,
STRUCT,
TINYINT,
)
from superset.utils.core import GenericDataType from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import assert_column_spec from tests.unit_tests.db_engine_specs.utils import assert_column_spec
@ -30,17 +40,22 @@ from tests.unit_tests.db_engine_specs.utils import assert_column_spec
"native_type,sqla_type,attrs,generic_type,is_dttm", "native_type,sqla_type,attrs,generic_type,is_dttm",
[ [
# Numeric # Numeric
("TINYINT", TINYINT, None, GenericDataType.NUMERIC, False), ("tinyint", TINYINT, None, GenericDataType.NUMERIC, False),
("DECIMAL", types.DECIMAL, None, GenericDataType.NUMERIC, False), ("largeint", LARGEINT, None, GenericDataType.NUMERIC, False),
("DOUBLE", DOUBLE, None, GenericDataType.NUMERIC, False), ("decimal(38,18)", types.DECIMAL, None, GenericDataType.NUMERIC, False),
("double", DOUBLE, None, GenericDataType.NUMERIC, False),
# String # String
("CHAR", types.CHAR, None, GenericDataType.STRING, False), ("char(10)", types.CHAR, None, GenericDataType.STRING, False),
("VARCHAR", types.VARCHAR, None, GenericDataType.STRING, False), ("varchar(65533)", types.VARCHAR, None, GenericDataType.STRING, False),
("BINARY", types.String, None, GenericDataType.STRING, False), ("binary", types.String, None, GenericDataType.STRING, False),
# Complex type # Complex type
("ARRAY", ARRAY, None, GenericDataType.STRING, False), ("array<varchar(65533)>", ARRAY, None, GenericDataType.STRING, False),
("MAP", MAP, None, GenericDataType.STRING, False), ("map<string,int>", MAP, None, GenericDataType.STRING, False),
("STRUCT", STRUCT, None, GenericDataType.STRING, False), ("struct<int,string>", STRUCT, None, GenericDataType.STRING, False),
("json", JSON, None, GenericDataType.STRING, False),
("bitmap", BITMAP, None, GenericDataType.STRING, False),
("hll", HLL, None, GenericDataType.STRING, False),
("percentile", PERCENTILE, None, GenericDataType.STRING, False),
], ],
) )
def test_get_column_spec( def test_get_column_spec(