feat: support complex types and use get_columns implementation of starrocks python client (#24237)

This commit is contained in:
miomiocat 2023-06-09 01:08:37 +08:00 committed by GitHub
parent 69c2cd5f40
commit fd3effe712
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 62 additions and 67 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

After

Width:  |  Height:  |  Size: 6.5 KiB

View File

@ -22,9 +22,7 @@ from typing import Any, Optional
from urllib import parse
from flask_babel import gettext as __
from sqlalchemy import Integer, Numeric, types
from sqlalchemy.engine import Inspector
from sqlalchemy.engine.result import Row as ResultRow
from sqlalchemy import Float, Integer, Numeric, types
from sqlalchemy.engine.url import URL
from sqlalchemy.sql.type_api import TypeEngine
@ -45,10 +43,26 @@ class TINYINT(Integer):
__visit_name__ = "TINYINT"
class DOUBLE(Numeric):
# StarRocks LARGEINT column type, declared as a subtype of SQLAlchemy's Integer.
# The column-type mappings return an instance of it for native type names
# matching ^largeint (case-insensitive) and classify it as NUMERIC.
class LARGEINT(Integer):
__visit_name__ = "LARGEINT"
# StarRocks DOUBLE column type, declared as a subtype of SQLAlchemy's Float
# (this commit switches the base class from Numeric to Float).
class DOUBLE(Float):
__visit_name__ = "DOUBLE"
# StarRocks HLL column type, declared as a subtype of SQLAlchemy's Numeric.
# NOTE(review): the column-type mappings classify ^hll as GenericDataType.STRING
# even though the base class here is Numeric -- confirm this is intended.
class HLL(Numeric):
__visit_name__ = "HLL"
# StarRocks BITMAP column type, declared as a subtype of SQLAlchemy's Numeric.
# NOTE(review): the column-type mappings classify ^bitmap as
# GenericDataType.STRING even though the base class here is Numeric -- confirm
# this is intended.
class BITMAP(Numeric):
__visit_name__ = "BITMAP"
# StarRocks PERCENTILE column type, declared as a subtype of SQLAlchemy's
# Numeric.
# NOTE(review): the column-type mappings classify ^percentile as
# GenericDataType.STRING even though the base class here is Numeric -- confirm
# this is intended.
class PERCENTILE(Numeric):
__visit_name__ = "PERCENTILE"
class ARRAY(TypeEngine): # pylint: disable=no-init
__visit_name__ = "ARRAY"
@ -88,6 +102,11 @@ class StarRocksEngineSpec(MySQLEngineSpec):
TINYINT(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^largeint", re.IGNORECASE),
LARGEINT(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^decimal.*", re.IGNORECASE),
types.DECIMAL(),
@ -108,11 +127,23 @@ class StarRocksEngineSpec(MySQLEngineSpec):
types.CHAR(),
GenericDataType.STRING,
),
(
re.compile(r"^json", re.IGNORECASE),
types.JSON(),
GenericDataType.STRING,
),
(
re.compile(r"^binary.*", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r"^percentile", re.IGNORECASE),
PERCENTILE(),
GenericDataType.STRING,
),
(re.compile(r"^hll", re.IGNORECASE), HLL(), GenericDataType.STRING),
(re.compile(r"^bitmap", re.IGNORECASE), BITMAP(), GenericDataType.STRING),
(re.compile(r"^array.*", re.IGNORECASE), ARRAY(), GenericDataType.STRING),
(re.compile(r"^map.*", re.IGNORECASE), MAP(), GenericDataType.STRING),
(re.compile(r"^struct.*", re.IGNORECASE), STRUCT(), GenericDataType.STRING),
@ -145,62 +176,11 @@ class StarRocksEngineSpec(MySQLEngineSpec):
if "." in database:
database = database.split(".")[0] + "." + schema
else:
database += "." + schema
database = "default_catalog." + schema
uri = uri.set(database=database)
return uri, connect_args
@classmethod
def get_columns(
    cls, inspector: Inspector, table_name: str, schema: Optional[str]
) -> list[dict[str, Any]]:
    """
    Describe the columns of a StarRocks table

    Each row returned by ``_show_columns`` is turned into a column info
    dict holding the column name, its SQLAlchemy type rendered as a string,
    a boolean ``nullable`` flag and a ``default`` that is always ``None``.

    :param inspector: object that performs database schema inspection
    :param table_name: table name
    :param schema: schema name
    :return: list of column info dicts, one per column
    """
    result: list[dict[str, Any]] = []
    for column in cls._show_columns(inspector, table_name, schema):
        column_spec = cls.get_column_spec(column.Type)
        column_type = column_spec.sqla_type if column_spec else None
        if column_type is None:
            # Unknown native type: fall back to a generic string type so the
            # column is still listed instead of being dropped.
            column_type = types.String()
            logger.info(
                "Did not recognize starrocks type %s of column %s",
                str(column.Type),
                str(column.Field),
            )
        column_info = cls._create_column_info(column.Field, column_type)

        # SHOW COLUMNS reports nullability as the strings "YES"/"NO" (MySQL
        # convention). A bare "NO" is truthy, so it must be translated to a
        # real boolean; anything unrecognised keeps the permissive default.
        raw_nullable = getattr(column, "Null", True)
        if isinstance(raw_nullable, str):
            nullable = raw_nullable.strip().upper() != "NO"
        else:
            nullable = bool(raw_nullable)
        column_info["nullable"] = nullable

        column_info["default"] = None
        result.append(column_info)
    return result
@classmethod
def _show_columns(
    cls, inspector: Inspector, table_name: str, schema: Optional[str]
) -> list[ResultRow]:
    """
    Run ``SHOW COLUMNS`` against a table and return every resulting row

    The table name, and the schema when one is given, are quoted with the
    dialect's identifier preparer before being placed in the statement.

    :param inspector: object that performs database schema inspection
    :param table_name: table name
    :param schema: schema name; when falsy the table name is used unqualified
    :return: list of rows produced by the ``SHOW COLUMNS`` statement
    """
    quote_identifier = inspector.engine.dialect.identifier_preparer.quote_identifier

    qualified_name = quote_identifier(table_name)
    if schema:
        qualified_name = ".".join((quote_identifier(schema), qualified_name))

    statement = f"SHOW COLUMNS FROM {qualified_name}"
    cursor = inspector.bind.execute(statement)
    return cursor.fetchall()
@classmethod
def _create_column_info(
    cls, name: str, data_type: types.TypeEngine
) -> dict[str, Any]:
    """
    Build the base column info dict for a single column

    :param name: column name
    :param data_type: column data type; stored in its string form
    :return: dict with the keys ``name`` and ``type``
    """
    rendered_type = format(data_type)
    column_info: dict[str, Any] = {}
    column_info["name"] = name
    column_info["type"] = rendered_type
    return column_info
@classmethod
def get_schema_from_engine_params(
cls,

View File

@ -18,10 +18,20 @@
from typing import Any, Optional
import pytest
from sqlalchemy import types
from sqlalchemy import JSON, types
from sqlalchemy.engine.url import make_url
from superset.db_engine_specs.starrocks import ARRAY, DOUBLE, MAP, STRUCT, TINYINT
from superset.db_engine_specs.starrocks import (
ARRAY,
BITMAP,
DOUBLE,
HLL,
LARGEINT,
MAP,
PERCENTILE,
STRUCT,
TINYINT,
)
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import assert_column_spec
@ -30,17 +40,22 @@ from tests.unit_tests.db_engine_specs.utils import assert_column_spec
"native_type,sqla_type,attrs,generic_type,is_dttm",
[
# Numeric
("TINYINT", TINYINT, None, GenericDataType.NUMERIC, False),
("DECIMAL", types.DECIMAL, None, GenericDataType.NUMERIC, False),
("DOUBLE", DOUBLE, None, GenericDataType.NUMERIC, False),
("tinyint", TINYINT, None, GenericDataType.NUMERIC, False),
("largeint", LARGEINT, None, GenericDataType.NUMERIC, False),
("decimal(38,18)", types.DECIMAL, None, GenericDataType.NUMERIC, False),
("double", DOUBLE, None, GenericDataType.NUMERIC, False),
# String
("CHAR", types.CHAR, None, GenericDataType.STRING, False),
("VARCHAR", types.VARCHAR, None, GenericDataType.STRING, False),
("BINARY", types.String, None, GenericDataType.STRING, False),
("char(10)", types.CHAR, None, GenericDataType.STRING, False),
("varchar(65533)", types.VARCHAR, None, GenericDataType.STRING, False),
("binary", types.String, None, GenericDataType.STRING, False),
# Complex type
("ARRAY", ARRAY, None, GenericDataType.STRING, False),
("MAP", MAP, None, GenericDataType.STRING, False),
("STRUCT", STRUCT, None, GenericDataType.STRING, False),
("array<varchar(65533)>", ARRAY, None, GenericDataType.STRING, False),
("map<string,int>", MAP, None, GenericDataType.STRING, False),
("struct<int,string>", STRUCT, None, GenericDataType.STRING, False),
("json", JSON, None, GenericDataType.STRING, False),
("bitmap", BITMAP, None, GenericDataType.STRING, False),
("hll", HLL, None, GenericDataType.STRING, False),
("percentile", PERCENTILE, None, GenericDataType.STRING, False),
],
)
def test_get_column_spec(