diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index de9f4d1fd3..138b0e5d5c 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -900,7 +900,7 @@ class SqlaTable(Model, BaseDatasource): for col in table.columns: try: - datatype = col.type.compile(dialect=db_dialect).upper() + datatype = db_engine_spec.column_datatype_to_string(col.type, db_dialect) except Exception as e: datatype = 'UNKNOWN' logging.error( diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 5136bcac81..922e7cc796 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -529,6 +529,10 @@ class BaseEngineSpec(object): label = label[:cls.max_column_name_length] return label + @classmethod + def column_datatype_to_string(cls, sqla_column_type, dialect): + return sqla_column_type.compile(dialect=dialect).upper() + class PostgresBaseEngineSpec(BaseEngineSpec): """ Abstract class for Postgres 'like' databases """ @@ -864,6 +868,17 @@ class MySQLEngineSpec(BaseEngineSpec): pass return message + @classmethod + def column_datatype_to_string(cls, sqla_column_type, dialect): + datatype = super().column_datatype_to_string(sqla_column_type, dialect) + # MySQL dialect started returning long overflowing datatype + # as in 'VARCHAR(255) COLLATE UTF8MB4_GENERAL_CI' + # and we don't need the verbose collation type + str_cutoff = ' COLLATE ' + if str_cutoff in datatype: + datatype = datatype.split(str_cutoff)[0] + return datatype + class PrestoEngineSpec(BaseEngineSpec): engine = 'presto' diff --git a/tests/base_tests.py b/tests/base_tests.py index 6de082afc0..b24193c1cb 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -174,12 +174,15 @@ class SupersetTestCase(unittest.TestCase): perm.view_menu and table.perm in perm.view_menu.name): security_manager.del_permission_role(public_role, perm) + def get_main_database(self): + return get_main_database(db.session) + def run_sql(self, sql, 
client_id=None, user_name=None, raise_on_error=False, query_limit=None): if user_name: self.logout() self.login(username=(user_name if user_name else 'admin')) - dbid = get_main_database(db.session).id + dbid = self.get_main_database().id resp = self.get_json_resp( '/superset/sql_json/', raise_on_error=False, @@ -195,7 +198,7 @@ class SupersetTestCase(unittest.TestCase): if user_name: self.logout() self.login(username=(user_name if user_name else 'admin')) - dbid = get_main_database(db.session).id + dbid = self.get_main_database().id resp = self.get_json_resp( '/superset/validate_sql_json/', raise_on_error=False, diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py index 44919143d8..aae73807e9 100644 --- a/tests/db_engine_specs_test.py +++ b/tests/db_engine_specs_test.py @@ -814,3 +814,17 @@ class DbEngineSpecsTestCase(SupersetTestCase): expr = PinotEngineSpec.get_timestamp_expr(col, 'epoch_s', 'P1M') result = str(expr.compile()) self.assertEqual(result, 'DATETIMECONVERT(tstamp, "1:SECONDS:EPOCH", "1:SECONDS:EPOCH", "1:MONTHS")') # noqa + + def test_column_datatype_to_string(self): + main_db = self.get_main_database() + sqla_table = main_db.get_table('energy_usage') + dialect = main_db.get_dialect() + col_names = [ + main_db.db_engine_spec.column_datatype_to_string(c.type, dialect) + for c in sqla_table.columns + ] + if main_db.backend == 'postgresql': + expected = ['VARCHAR(255)', 'VARCHAR(255)', 'DOUBLE PRECISION'] + else: + expected = ['VARCHAR(255)', 'VARCHAR(255)', 'FLOAT'] + self.assertEqual(col_names, expected)