diff --git a/scripts/benchmark_migration.py b/scripts/benchmark_migration.py index 27670b5d4d..baae8befec 100644 --- a/scripts/benchmark_migration.py +++ b/scripts/benchmark_migration.py @@ -102,7 +102,10 @@ def find_models(module: ModuleType) -> List[Type[Model]]: while tables: table = tables.pop() seen.add(table) - model = getattr(Base.classes, table) + try: + model = getattr(Base.classes, table) + except AttributeError: + continue model.__tablename__ = table models.append(model) diff --git a/setup.py b/setup.py index 81068e1f86..d2835abb7d 100644 --- a/setup.py +++ b/setup.py @@ -111,6 +111,7 @@ setup( "slackclient==2.5.0", # PINNED! slack changes file upload api in the future versions "sqlalchemy>=1.3.16, <1.4, !=1.3.21", "sqlalchemy-utils>=0.37.8, <0.38", + "sqloxide==0.1.15", "sqlparse==0.3.0", # PINNED! see https://github.com/andialbrecht/sqlparse/issues/562 "tabulate==0.8.9", # needed to support Literal (3.8) and TypeGuard (3.10) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 3f6fee0043..5b89919d0e 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -78,6 +78,7 @@ from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetr from superset.connectors.sqla.utils import ( get_physical_table_metadata, get_virtual_table_metadata, + load_or_create_tables, validate_adhoc_subquery, ) from superset.datasets.models import Dataset as NewDataset @@ -2242,7 +2243,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho if column.is_active is False: continue - extra_json = json.loads(column.extra or "{}") + try: + extra_json = json.loads(column.extra or "{}") + except json.decoder.JSONDecodeError: + extra_json = {} for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}: value = getattr(column, attr) if value: @@ -2269,7 +2273,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho # create metrics for metric in dataset.metrics: - extra_json = json.loads(metric.extra or "{}") + try: + extra_json = json.loads(metric.extra or "{}") + except json.decoder.JSONDecodeError: + extra_json = {} for attr in {"verbose_name", "metric_type", "d3format"}: value = getattr(metric, attr) if value: @@ -2300,8 +2307,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho ) # physical dataset - tables = [] - if dataset.sql is None: + if not dataset.sql: physical_columns = [column for column in columns if column.is_physical] # create table @@ -2314,7 +2320,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho is_managed_externally=dataset.is_managed_externally, external_url=dataset.external_url, ) - tables.append(table) + tables = [table] # virtual dataset else: @@ -2325,18 +2331,14 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho # find referenced tables parsed = ParsedQuery(dataset.sql) referenced_tables = parsed.tables - - # predicate for finding the referenced tables - predicate = or_( - *[ - and_( - NewTable.schema == (table.schema or dataset.schema), - NewTable.name == table.table, - ) - for table in referenced_tables - ] + tables = load_or_create_tables( + session, + dataset.database_id, + dataset.schema, + referenced_tables, + conditional_quote, + engine, ) - tables = session.query(NewTable).filter(predicate).all() # create the new dataset new_dataset = NewDataset( @@ -2345,7 +2347,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: 
disable=too-many-public-metho expression=dataset.sql or conditional_quote(dataset.table_name), tables=tables, columns=columns, - is_physical=dataset.sql is None, + is_physical=not dataset.sql, is_managed_externally=dataset.is_managed_externally, external_url=dataset.external_url, ) diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py index 389c5b9012..4fc11a4d1d 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -15,13 +15,17 @@ # specific language governing permissions and limitations # under the License. from contextlib import closing -from typing import Dict, List, Optional, TYPE_CHECKING +from typing import Callable, Dict, List, Optional, Set, TYPE_CHECKING import sqlparse from flask_babel import lazy_gettext as _ +from sqlalchemy import and_, inspect, or_ +from sqlalchemy.engine import Engine from sqlalchemy.exc import NoSuchTableError +from sqlalchemy.orm import Session from sqlalchemy.sql.type_api import TypeEngine +from superset.columns.models import Column as NewColumn from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( SupersetGenericDBErrorException, @@ -29,12 +33,16 @@ from superset.exceptions import ( ) from superset.models.core import Database from superset.result_set import SupersetResultSet -from superset.sql_parse import has_table_query, ParsedQuery +from superset.sql_parse import has_table_query, ParsedQuery, Table +from superset.tables.models import Table as NewTable if TYPE_CHECKING: from superset.connectors.sqla.models import SqlaTable +TEMPORAL_TYPES = {"DATETIME", "DATE", "TIME", "TIMEDELTA"} + + def get_physical_table_metadata( database: Database, table_name: str, @@ -151,3 +159,78 @@ def validate_adhoc_subquery(raw_sql: str) -> None: ) ) return + + +def load_or_create_tables( # pylint: disable=too-many-arguments + session: Session, + database_id: int, + default_schema: Optional[str], + tables: Set[Table], + conditional_quote: Callable[[str], str], + engine: Engine, +) -> List[NewTable]: + """ + Load or create new table model instances. 
+ """ + if not tables: + return [] + + # set the default schema in tables that don't have it + if default_schema: + fixed_tables = list(tables) + for i, table in enumerate(fixed_tables): + if table.schema is None: + fixed_tables[i] = Table(table.table, default_schema, table.catalog) + tables = set(fixed_tables) + + # load existing tables + predicate = or_( + *[ + and_( + NewTable.database_id == database_id, + NewTable.schema == table.schema, + NewTable.name == table.table, + ) + for table in tables + ] + ) + new_tables = session.query(NewTable).filter(predicate).all() + + # add missing tables + existing = {(table.schema, table.name) for table in new_tables} + for table in tables: + if (table.schema, table.table) not in existing: + try: + inspector = inspect(engine) + column_metadata = inspector.get_columns( + table.table, schema=table.schema + ) + except Exception: # pylint: disable=broad-except + continue + columns = [ + NewColumn( + name=column["name"], + type=str(column["type"]), + expression=conditional_quote(column["name"]), + is_temporal=column["type"].python_type.__name__.upper() + in TEMPORAL_TYPES, + is_aggregation=False, + is_physical=True, + is_spatial=False, + is_partition=False, + is_increase_desired=True, + ) + for column in column_metadata + ] + new_tables.append( + NewTable( + name=table.table, + schema=table.schema, + catalog=None, + database_id=database_id, + columns=columns, + ) + ) + existing.add((table.schema, table.table)) + + return new_tables diff --git a/superset/migrations/shared/utils.py b/superset/migrations/shared/utils.py index 0331718117..bff25e05d1 100644 --- a/superset/migrations/shared/utils.py +++ b/superset/migrations/shared/utils.py @@ -14,10 +14,39 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import logging +from typing import Any, Iterator, Optional, Set + from alembic import op from sqlalchemy import engine_from_config from sqlalchemy.engine import reflection from sqlalchemy.exc import NoSuchTableError +from sqloxide import parse_sql + +from superset.sql_parse import ParsedQuery, Table + +logger = logging.getLogger("alembic") + + +# mapping between sqloxide and SQLAlchemy dialects +sqloxide_dialects = { + "ansi": {"trino", "trinonative", "presto"}, + "hive": {"hive", "databricks"}, + "ms": {"mssql"}, + "mysql": {"mysql"}, + "postgres": { + "cockroachdb", + "hana", + "netezza", + "postgres", + "postgresql", + "redshift", + "vertica", + }, + "snowflake": {"snowflake"}, + "sqlite": {"sqlite", "gsheets", "shillelagh"}, + "clickhouse": {"clickhouse"}, +} def table_has_column(table: str, column: str) -> bool: @@ -38,3 +67,40 @@ def table_has_column(table: str, column: str) -> bool: return any(col["name"] == column for col in insp.get_columns(table)) except NoSuchTableError: return False + + +def find_nodes_by_key(element: Any, target: str) -> Iterator[Any]: + """ + Find all nodes in a SQL tree matching a given key. + """ + if isinstance(element, list): + for child in element: + yield from find_nodes_by_key(child, target) + elif isinstance(element, dict): + for key, value in element.items(): + if key == target: + yield value + else: + yield from find_nodes_by_key(value, target) + + +def extract_table_references(sql_text: str, sqla_dialect: str) -> Set[Table]: + """ + Return all the dependencies from a SQL sql_text. 
+ """ + dialect = "generic" + for dialect, sqla_dialects in sqloxide_dialects.items(): + if sqla_dialect in sqla_dialects: + break + try: + tree = parse_sql(sql_text, dialect=dialect) + except Exception: # pylint: disable=broad-except + logger.warning("Unable to parse query with sqloxide: %s", sql_text) + # fallback to sqlparse + parsed = ParsedQuery(sql_text) + return parsed.tables + + return { + Table(*[part["value"] for part in table["name"][::-1]]) + for table in find_nodes_by_key(tree, "Table") + } diff --git a/superset/migrations/versions/b8d3a24d9131_new_dataset_models.py b/superset/migrations/versions/b8d3a24d9131_new_dataset_models.py index 35419e0066..75f5293034 100644 --- a/superset/migrations/versions/b8d3a24d9131_new_dataset_models.py +++ b/superset/migrations/versions/b8d3a24d9131_new_dataset_models.py @@ -25,7 +25,7 @@ Create Date: 2021-11-11 16:41:53.266965 """ import json -from typing import List +from typing import Callable, List, Optional, Set from uuid import uuid4 import sqlalchemy as sa @@ -40,7 +40,9 @@ from sqlalchemy_utils import UUIDType from superset import app, db from superset.connectors.sqla.models import ADDITIVE_METRIC_TYPES from superset.extensions import encrypted_field_factory -from superset.sql_parse import ParsedQuery +from superset.migrations.shared.utils import extract_table_references +from superset.models.core import Database as OriginalDatabase +from superset.sql_parse import Table # revision identifiers, used by Alembic. revision = "b8d3a24d9131" @@ -228,6 +230,85 @@ class NewDataset(Base): external_url = sa.Column(sa.Text, nullable=True) +TEMPORAL_TYPES = {"DATETIME", "DATE", "TIME", "TIMEDELTA"} + + +def load_or_create_tables( + session: Session, + database_id: int, + default_schema: Optional[str], + tables: Set[Table], + conditional_quote: Callable[[str], str], +) -> List[NewTable]: + """ + Load or create new table model instances. 
+ """ + if not tables: + return [] + + # set the default schema in tables that don't have it + if default_schema: + tables = list(tables) + for i, table in enumerate(tables): + if table.schema is None: + tables[i] = Table(table.table, default_schema, table.catalog) + + # load existing tables + predicate = or_( + *[ + and_( + NewTable.database_id == database_id, + NewTable.schema == table.schema, + NewTable.name == table.table, + ) + for table in tables + ] + ) + new_tables = session.query(NewTable).filter(predicate).all() + + # use original database model to get the engine + engine = ( + session.query(OriginalDatabase) + .filter_by(id=database_id) + .one() + .get_sqla_engine(default_schema) + ) + inspector = inspect(engine) + + # add missing tables + existing = {(table.schema, table.name) for table in new_tables} + for table in tables: + if (table.schema, table.table) not in existing: + column_metadata = inspector.get_columns(table.table, schema=table.schema) + columns = [ + NewColumn( + name=column["name"], + type=str(column["type"]), + expression=conditional_quote(column["name"]), + is_temporal=column["type"].python_type.__name__.upper() + in TEMPORAL_TYPES, + is_aggregation=False, + is_physical=True, + is_spatial=False, + is_partition=False, + is_increase_desired=True, + ) + for column in column_metadata + ] + new_tables.append( + NewTable( + name=table.table, + schema=table.schema, + catalog=None, + database_id=database_id, + columns=columns, + ) + ) + existing.add((table.schema, table.table)) + + return new_tables + + def after_insert(target: SqlaTable) -> None: # pylint: disable=too-many-locals """ Copy old datasets to the new models. @@ -253,7 +334,10 @@ def after_insert(target: SqlaTable) -> None: # pylint: disable=too-many-locals if column.is_active is False: continue - extra_json = json.loads(column.extra or "{}") + try: + extra_json = json.loads(column.extra or "{}") + except json.decoder.JSONDecodeError: + extra_json = {} for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}: value = getattr(column, attr) if value: @@ -279,7 +363,10 @@ def after_insert(target: SqlaTable) -> None: # pylint: disable=too-many-locals # create metrics for metric in target.metrics: - extra_json = json.loads(metric.extra or "{}") + try: + extra_json = json.loads(metric.extra or "{}") + except json.decoder.JSONDecodeError: + extra_json = {} for attr in {"verbose_name", "metric_type", "d3format"}: value = getattr(metric, attr) if value: @@ -309,8 +396,7 @@ def after_insert(target: SqlaTable) -> None: # pylint: disable=too-many-locals ) # physical dataset - tables = [] - if target.sql is None: + if not target.sql: physical_columns = [column for column in columns if column.is_physical] # create table @@ -323,7 +409,7 @@ def after_insert(target: SqlaTable) -> None: # pylint: disable=too-many-locals is_managed_externally=target.is_managed_externally, external_url=target.external_url, ) - tables.append(table) + tables = [table] # virtual dataset else: @@ -332,20 +418,14 @@ def after_insert(target: SqlaTable) -> None: # pylint: disable=too-many-locals column.is_physical = False # find referenced tables - parsed = ParsedQuery(target.sql) - referenced_tables = parsed.tables - - # predicate for finding the referenced tables - predicate = or_( - *[ - and_( - NewTable.schema == (table.schema or target.schema), - NewTable.name == table.table, - ) - for table in referenced_tables - ] + referenced_tables = extract_table_references(target.sql, dialect_class.name) + tables = load_or_create_tables( 
+ session, + target.database_id, + target.schema, + referenced_tables, + conditional_quote, ) - tables = session.query(NewTable).filter(predicate).all() # create the new dataset dataset = NewDataset( @@ -354,7 +434,7 @@ def after_insert(target: SqlaTable) -> None: # pylint: disable=too-many-locals expression=target.sql or conditional_quote(target.table_name), tables=tables, columns=columns, - is_physical=target.sql is None, + is_physical=not target.sql, is_managed_externally=target.is_managed_externally, external_url=target.external_url, ) diff --git a/tests/unit_tests/migrations/__init__.py b/tests/unit_tests/migrations/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/tests/unit_tests/migrations/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/migrations/shared/__init__.py b/tests/unit_tests/migrations/shared/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/tests/unit_tests/migrations/shared/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/migrations/shared/utils_test.py b/tests/unit_tests/migrations/shared/utils_test.py new file mode 100644 index 0000000000..cb5b2cbd0e --- /dev/null +++ b/tests/unit_tests/migrations/shared/utils_test.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-outside-toplevel, unused-argument + +""" +Test the SIP-68 migration. +""" + +from pytest_mock import MockerFixture + +from superset.sql_parse import Table + + +def test_extract_table_references(mocker: MockerFixture, app_context: None) -> None: + """ + Test the ``extract_table_references`` helper function. + """ + from superset.migrations.shared.utils import extract_table_references + + assert extract_table_references("SELECT 1", "trino") == set() + assert extract_table_references("SELECT 1 FROM some_table", "trino") == { + Table(table="some_table", schema=None, catalog=None) + } + assert extract_table_references( + "SELECT 1 FROM some_catalog.some_schema.some_table", "trino" + ) == {Table(table="some_table", schema="some_schema", catalog="some_catalog")} + assert extract_table_references( + "SELECT * FROM some_table JOIN other_table ON some_table.id = other_table.id", + "trino", + ) == { + Table(table="some_table", schema=None, catalog=None), + Table(table="other_table", schema=None, catalog=None), + } + + # test falling back to sqlparse + logger = mocker.patch("superset.migrations.shared.utils.logger") + sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table" + assert extract_table_references( + sql, + "trino", + ) == {Table(table="other_table", schema=None, catalog=None)} + logger.warning.assert_called_with("Unable to parse query with sqloxide: %s", sql)
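
Note on the new table-extraction helpers: the sketch below is illustrative and is not part of the diff. Assuming sqloxide==0.1.15 (pinned in setup.py above) serializes table factors as {"Table": {"name": [{"value": ...}, ...]}}, it shows how the tree returned by parse_sql can be walked for "Table" nodes in the same way find_nodes_by_key and extract_table_references do; the helper name walk and the sample query are hypothetical.

from typing import Any, Iterator

from sqloxide import parse_sql


def walk(node: Any, target: str) -> Iterator[Any]:
    """Yield every value stored under ``target`` anywhere in the parse tree."""
    if isinstance(node, list):
        for child in node:
            yield from walk(child, target)
    elif isinstance(node, dict):
        for key, value in node.items():
            if key == target:
                yield value
            else:
                yield from walk(value, target)


# "ansi" is the sqloxide dialect the migration maps Trino/Presto to
tree = parse_sql("SELECT 1 FROM some_catalog.some_schema.some_table", dialect="ansi")
for table in walk(tree, "Table"):
    # identifier parts appear in catalog.schema.table order; reversing them lets
    # Table(*parts) fill (table, schema, catalog) positionally, as the migration
    # helper does
    print([part["value"] for part in table["name"][::-1]])

Under the same assumption about the serialized tree, this would print ['some_table', 'some_schema', 'some_catalog'], matching the Table namedtuple order used by extract_table_references.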