From 231716cb50983b04178602b86c846b7673f9d8c3 Mon Sep 17 00:00:00 2001 From: Jesse Yang Date: Tue, 19 Apr 2022 18:58:18 -0700 Subject: [PATCH] perf: refactor SIP-68 db migrations with INSERT SELECT FROM (#19421) --- superset/columns/models.py | 82 +- superset/connectors/base/models.py | 6 +- superset/connectors/sqla/models.py | 731 ++++++------ superset/connectors/sqla/utils.py | 123 +- superset/datasets/models.py | 82 +- superset/examples/birth_names.py | 17 +- superset/migrations/shared/utils.py | 111 +- ...b176a0_add_import_mixing_to_saved_query.py | 6 +- superset/migrations/versions/9d8a8d575284_.py | 2 +- .../a9422eeaae74_new_dataset_models_take_2.py | 905 ++++++++++++++ ...0de1855_add_uuid_column_to_import_mixin.py | 49 +- .../b8d3a24d9131_new_dataset_models.py | 616 +--------- .../c501b7c653a3_add_missing_uuid_column.py | 4 +- ...95_migrate_native_filters_to_new_schema.py | 2 +- superset/models/core.py | 8 +- superset/models/helpers.py | 15 +- superset/sql_lab.py | 2 +- superset/sql_parse.py | 83 +- superset/tables/models.py | 136 ++- tests/integration_tests/commands_test.py | 20 +- .../fixtures/world_bank_dashboard.py | 3 +- tests/integration_tests/sqla_models_tests.py | 5 +- tests/integration_tests/utils_tests.py | 1 - tests/unit_tests/conftest.py | 34 +- tests/unit_tests/datasets/conftest.py | 118 ++ tests/unit_tests/datasets/test_models.py | 1048 +++++++---------- .../unit_tests/migrations/shared/__init__.py | 16 - .../migrations/shared/utils_test.py | 56 - tests/unit_tests/sql_parse_tests.py | 49 + .../{migrations/__init__.py => utils/db.py} | 14 + 30 files changed, 2356 insertions(+), 1988 deletions(-) create mode 100644 superset/migrations/versions/a9422eeaae74_new_dataset_models_take_2.py create mode 100644 tests/unit_tests/datasets/conftest.py delete mode 100644 tests/unit_tests/migrations/shared/__init__.py delete mode 100644 tests/unit_tests/migrations/shared/utils_test.py rename tests/unit_tests/{migrations/__init__.py => utils/db.py} (69%) diff --git a/superset/columns/models.py b/superset/columns/models.py index fbe045e3d3..bfee3de859 100644 --- a/superset/columns/models.py +++ b/superset/columns/models.py @@ -23,7 +23,6 @@ tables, metrics, and datasets were also introduced. These models are not fully implemented, and shouldn't be used yet. """ - import sqlalchemy as sa from flask_appbuilder import Model @@ -33,6 +32,8 @@ from superset.models.helpers import ( ImportExportMixin, ) +UNKOWN_TYPE = "UNKNOWN" + class Column( Model, @@ -52,51 +53,58 @@ class Column( id = sa.Column(sa.Integer, primary_key=True) - # We use ``sa.Text`` for these attributes because (1) in modern databases the - # performance is the same as ``VARCHAR``[1] and (2) because some table names can be - # **really** long (eg, Google Sheets URLs). - # - # [1] https://www.postgresql.org/docs/9.1/datatype-character.html - name = sa.Column(sa.Text) - type = sa.Column(sa.Text) - - # Columns are defined by expressions. For tables, these are the actual columns names, - # and should match the ``name`` attribute. For datasets, these can be any valid SQL - # expression. If the SQL expression is an aggregation the column is a metric, - # otherwise it's a computed column. - expression = sa.Column(sa.Text) - - # Does the expression point directly to a physical column? - is_physical = sa.Column(sa.Boolean, default=True) - - # Additional metadata describing the column. - description = sa.Column(sa.Text) - warning_text = sa.Column(sa.Text) - unit = sa.Column(sa.Text) - - # Is this a time column? 
Useful for plotting time series. - is_temporal = sa.Column(sa.Boolean, default=False) - - # Is this a spatial column? This could be leveraged in the future for spatial - # visualizations. - is_spatial = sa.Column(sa.Boolean, default=False) - - # Is this column a partition? Useful for scheduling queries and previewing the latest - # data. - is_partition = sa.Column(sa.Boolean, default=False) - - # Is this column an aggregation (metric)? - is_aggregation = sa.Column(sa.Boolean, default=False) - # Assuming the column is an aggregation, is it additive? Useful for determining which # aggregations can be done on the metric. Eg, ``COUNT(DISTINCT user_id)`` is not # additive, so it shouldn't be used in a ``SUM``. is_additive = sa.Column(sa.Boolean, default=False) + # Is this column an aggregation (metric)? + is_aggregation = sa.Column(sa.Boolean, default=False) + + is_filterable = sa.Column(sa.Boolean, nullable=False, default=True) + is_dimensional = sa.Column(sa.Boolean, nullable=False, default=False) + # Is an increase desired? Useful for displaying the results of A/B tests, or setting # up alerts. Eg, this is true for "revenue", but false for "latency". is_increase_desired = sa.Column(sa.Boolean, default=True) # Column is managed externally and should be read-only inside Superset is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False) + + # Is this column a partition? Useful for scheduling queries and previewing the latest + # data. + is_partition = sa.Column(sa.Boolean, default=False) + + # Does the expression point directly to a physical column? + is_physical = sa.Column(sa.Boolean, default=True) + + # Is this a spatial column? This could be leveraged in the future for spatial + # visualizations. + is_spatial = sa.Column(sa.Boolean, default=False) + + # Is this a time column? Useful for plotting time series. + is_temporal = sa.Column(sa.Boolean, default=False) + + # We use ``sa.Text`` for these attributes because (1) in modern databases the + # performance is the same as ``VARCHAR``[1] and (2) because some table names can be + # **really** long (eg, Google Sheets URLs). + # + # [1] https://www.postgresql.org/docs/9.1/datatype-character.html + name = sa.Column(sa.Text) + # Raw type as returned and used by db engine. + type = sa.Column(sa.Text, default=UNKOWN_TYPE) + + # Columns are defined by expressions. For tables, these are the actual columns names, + # and should match the ``name`` attribute. For datasets, these can be any valid SQL + # expression. If the SQL expression is an aggregation the column is a metric, + # otherwise it's a computed column. + expression = sa.Column(sa.Text) + unit = sa.Column(sa.Text) + + # Additional metadata describing the column. 
+ description = sa.Column(sa.Text) + warning_text = sa.Column(sa.Text) external_url = sa.Column(sa.Text, nullable=True) + + def __repr__(self) -> str: + return f"" diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py index 9aacb0dc8c..3d22857912 100644 --- a/superset/connectors/base/models.py +++ b/superset/connectors/base/models.py @@ -31,7 +31,7 @@ from superset.models.helpers import AuditMixinNullable, ImportExportMixin, Query from superset.models.slice import Slice from superset.superset_typing import FilterValue, FilterValues, QueryObjectDict from superset.utils import core as utils -from superset.utils.core import GenericDataType +from superset.utils.core import GenericDataType, MediumText METRIC_FORM_DATA_PARAMS = [ "metric", @@ -586,7 +586,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin): type = Column(Text) groupby = Column(Boolean, default=True) filterable = Column(Boolean, default=True) - description = Column(Text) + description = Column(MediumText()) is_dttm = None # [optional] Set this to support import/export functionality @@ -672,7 +672,7 @@ class BaseMetric(AuditMixinNullable, ImportExportMixin): metric_name = Column(String(255), nullable=False) verbose_name = Column(String(1024)) metric_type = Column(String(32)) - description = Column(Text) + description = Column(MediumText()) d3format = Column(String(128)) warning_text = Column(Text) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index d7d62db2a7..e0382c6595 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -24,6 +24,7 @@ from dataclasses import dataclass, field from datetime import datetime, timedelta from typing import ( Any, + Callable, cast, Dict, Hashable, @@ -34,6 +35,7 @@ from typing import ( Type, Union, ) +from uuid import uuid4 import dateutil.parser import numpy as np @@ -72,13 +74,13 @@ from sqlalchemy.sql.expression import Label, Select, TextAsFrom from sqlalchemy.sql.selectable import Alias, TableClause from superset import app, db, is_feature_enabled, security_manager -from superset.columns.models import Column as NewColumn +from superset.columns.models import Column as NewColumn, UNKOWN_TYPE from superset.common.db_query_status import QueryStatus from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric from superset.connectors.sqla.utils import ( + find_cached_objects_in_session, get_physical_table_metadata, get_virtual_table_metadata, - load_or_create_tables, validate_adhoc_subquery, ) from superset.datasets.models import Dataset as NewDataset @@ -100,7 +102,12 @@ from superset.models.helpers import ( clone_model, QueryResult, ) -from superset.sql_parse import ParsedQuery, sanitize_clause +from superset.sql_parse import ( + extract_table_references, + ParsedQuery, + sanitize_clause, + Table as TableName, +) from superset.superset_typing import ( AdhocColumn, AdhocMetric, @@ -114,6 +121,7 @@ from superset.utils.core import ( GenericDataType, get_column_name, is_adhoc_column, + MediumText, QueryObjectFilterClause, remove_duplicates, ) @@ -130,6 +138,7 @@ ADDITIVE_METRIC_TYPES = { "sum", "doubleSum", } +ADDITIVE_METRIC_TYPES_LOWER = {op.lower() for op in ADDITIVE_METRIC_TYPES} class SqlaQuery(NamedTuple): @@ -215,13 +224,13 @@ class TableColumn(Model, BaseColumn, CertificationMixin): __tablename__ = "table_columns" __table_args__ = (UniqueConstraint("table_id", "column_name"),) table_id = Column(Integer, ForeignKey("tables.id")) - table = relationship( + table: 
"SqlaTable" = relationship( "SqlaTable", backref=backref("columns", cascade="all, delete-orphan"), foreign_keys=[table_id], ) is_dttm = Column(Boolean, default=False) - expression = Column(Text) + expression = Column(MediumText()) python_date_format = Column(String(255)) extra = Column(Text) @@ -417,6 +426,59 @@ class TableColumn(Model, BaseColumn, CertificationMixin): return attr_dict + def to_sl_column( + self, known_columns: Optional[Dict[str, NewColumn]] = None + ) -> NewColumn: + """Convert a TableColumn to NewColumn""" + column = known_columns.get(self.uuid) if known_columns else None + if not column: + column = NewColumn() + + extra_json = self.get_extra_dict() + for attr in { + "verbose_name", + "python_date_format", + }: + value = getattr(self, attr) + if value: + extra_json[attr] = value + + column.uuid = self.uuid + column.created_on = self.created_on + column.changed_on = self.changed_on + column.created_by = self.created_by + column.changed_by = self.changed_by + column.name = self.column_name + column.type = self.type or UNKOWN_TYPE + column.expression = self.expression or self.table.quote_identifier( + self.column_name + ) + column.description = self.description + column.is_aggregation = False + column.is_dimensional = self.groupby + column.is_filterable = self.filterable + column.is_increase_desired = True + column.is_managed_externally = self.table.is_managed_externally + column.is_partition = False + column.is_physical = not self.expression + column.is_spatial = False + column.is_temporal = self.is_dttm + column.extra_json = json.dumps(extra_json) if extra_json else None + column.external_url = self.table.external_url + + return column + + @staticmethod + def after_delete( # pylint: disable=unused-argument + mapper: Mapper, + connection: Connection, + target: "TableColumn", + ) -> None: + session = inspect(target).session + column = session.query(NewColumn).filter_by(uuid=target.uuid).one_or_none() + if column: + session.delete(column) + class SqlMetric(Model, BaseMetric, CertificationMixin): @@ -430,7 +492,7 @@ class SqlMetric(Model, BaseMetric, CertificationMixin): backref=backref("metrics", cascade="all, delete-orphan"), foreign_keys=[table_id], ) - expression = Column(Text, nullable=False) + expression = Column(MediumText(), nullable=False) extra = Column(Text) export_fields = [ @@ -479,6 +541,58 @@ class SqlMetric(Model, BaseMetric, CertificationMixin): attr_dict.update(super().data) return attr_dict + def to_sl_column( + self, known_columns: Optional[Dict[str, NewColumn]] = None + ) -> NewColumn: + """Convert a SqlMetric to NewColumn. 
Find and update existing or + create a new one.""" + column = known_columns.get(self.uuid) if known_columns else None + if not column: + column = NewColumn() + + extra_json = self.get_extra_dict() + for attr in {"verbose_name", "metric_type", "d3format"}: + value = getattr(self, attr) + if value is not None: + extra_json[attr] = value + is_additive = ( + self.metric_type and self.metric_type.lower() in ADDITIVE_METRIC_TYPES_LOWER + ) + + column.uuid = self.uuid + column.name = self.metric_name + column.created_on = self.created_on + column.changed_on = self.changed_on + column.created_by = self.created_by + column.changed_by = self.changed_by + column.type = UNKOWN_TYPE + column.expression = self.expression + column.warning_text = self.warning_text + column.description = self.description + column.is_aggregation = True + column.is_additive = is_additive + column.is_filterable = False + column.is_increase_desired = True + column.is_managed_externally = self.table.is_managed_externally + column.is_partition = False + column.is_physical = False + column.is_spatial = False + column.extra_json = json.dumps(extra_json) if extra_json else None + column.external_url = self.table.external_url + + return column + + @staticmethod + def after_delete( # pylint: disable=unused-argument + mapper: Mapper, + connection: Connection, + target: "SqlMetric", + ) -> None: + session = inspect(target).session + column = session.query(NewColumn).filter_by(uuid=target.uuid).one_or_none() + if column: + session.delete(column) + sqlatable_user = Table( "sqlatable_user", @@ -544,7 +658,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho foreign_keys=[database_id], ) schema = Column(String(255)) - sql = Column(Text) + sql = Column(MediumText()) is_sqllab_view = Column(Boolean, default=False) template_params = Column(Text) extra = Column(Text) @@ -1731,7 +1845,19 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho metrics = [] any_date_col = None db_engine_spec = self.db_engine_spec - old_columns = db.session.query(TableColumn).filter(TableColumn.table == self) + + # If no `self.id`, then this is a new table, no need to fetch columns + # from db. Passing in `self.id` to query will actually automatically + # generate a new id, which can be tricky during certain transactions. 
+ old_columns = ( + ( + db.session.query(TableColumn) + .filter(TableColumn.table_id == self.id) + .all() + ) + if self.id + else self.columns + ) old_columns_by_name: Dict[str, TableColumn] = { col.column_name: col for col in old_columns @@ -1745,13 +1871,15 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho ) # clear old columns before adding modified columns back - self.columns = [] + columns = [] for col in new_columns: old_column = old_columns_by_name.pop(col["name"], None) if not old_column: results.added.append(col["name"]) new_column = TableColumn( - column_name=col["name"], type=col["type"], table=self + column_name=col["name"], + type=col["type"], + table=self, ) new_column.is_dttm = new_column.is_temporal db_engine_spec.alter_new_orm_column(new_column) @@ -1763,12 +1891,14 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho new_column.expression = "" new_column.groupby = True new_column.filterable = True - self.columns.append(new_column) + columns.append(new_column) if not any_date_col and new_column.is_temporal: any_date_col = col["name"] - self.columns.extend( - [col for col in old_columns_by_name.values() if col.expression] - ) + + # add back calculated (virtual) columns + columns.extend([col for col in old_columns if col.expression]) + self.columns = columns + metrics.append( SqlMetric( metric_name="count", @@ -1854,6 +1984,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho extra_cache_keys += sqla_query.extra_cache_keys return extra_cache_keys + @property + def quote_identifier(self) -> Callable[[str], str]: + return self.database.quote_identifier + @staticmethod def before_update( mapper: Mapper, # pylint: disable=unused-argument @@ -1895,14 +2029,44 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho ): raise Exception(get_dataset_exist_error_msg(target.full_name)) + def get_sl_columns(self) -> List[NewColumn]: + """ + Convert `SqlaTable.columns` and `SqlaTable.metrics` to the new Column model + """ + session: Session = inspect(self).session + + uuids = set() + for column_or_metric in self.columns + self.metrics: + # pre-assign uuid after new columns or metrics are inserted so + # the related `NewColumn` can have a deterministic uuid, too + if not column_or_metric.uuid: + column_or_metric.uuid = uuid4() + else: + uuids.add(column_or_metric.uuid) + + # load existing columns from cached session states first + existing_columns = set( + find_cached_objects_in_session(session, NewColumn, uuids=uuids) + ) + for column in existing_columns: + uuids.remove(column.uuid) + + if uuids: + # load those not found from db + existing_columns |= set( + session.query(NewColumn).filter(NewColumn.uuid.in_(uuids)) + ) + + known_columns = {column.uuid: column for column in existing_columns} + return [ + item.to_sl_column(known_columns) for item in self.columns + self.metrics + ] + @staticmethod def update_table( # pylint: disable=unused-argument mapper: Mapper, connection: Connection, target: Union[SqlMetric, TableColumn] ) -> None: """ - Forces an update to the table's changed_on value when a metric or column on the - table is updated. This busts the cache key for all charts that use the table. - :param mapper: Unused. :param connection: Unused. :param target: The metric or column that was updated. 
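A minimal, self-contained sketch (not part of the upstream diff) of the upsert-by-UUID pattern that `get_sl_columns` and `to_sl_column` above rely on: reuse a shadow column whose uuid is already known (cached in the session or loaded from the database), otherwise create a fresh one, and pre-assign uuids so both sides stay linked. The classes below are hypothetical stand-ins, not the real Superset models.

from dataclasses import dataclass
from typing import Dict, Optional
from uuid import UUID, uuid4

@dataclass
class ShadowColumn:  # stand-in for superset.columns.models.Column (NewColumn)
    uuid: UUID
    name: str = ""

@dataclass
class SourceColumn:  # stand-in for TableColumn / SqlMetric
    name: str
    uuid: Optional[UUID] = None

    def to_shadow(self, known: Dict[UUID, ShadowColumn]) -> ShadowColumn:
        # pre-assign a uuid so the source and shadow rows share the same key
        if self.uuid is None:
            self.uuid = uuid4()
        # reuse the known shadow column if one exists, otherwise create it
        column = known.get(self.uuid) or ShadowColumn(uuid=self.uuid)
        column.name = self.name  # copy (or refresh) the attributes
        return column

known: Dict[UUID, ShadowColumn] = {}  # in the patch: session cache plus a db query
src = SourceColumn("num_california")
shadow = src.to_shadow(known)
known[shadow.uuid] = shadow
assert src.to_shadow(known) is shadow  # a second sync reuses the same object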
@@ -1910,90 +2074,43 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho inspector = inspect(target) session = inspector.session - # get DB-specific conditional quoter for expressions that point to columns or - # table names - database = ( - target.table.database - or session.query(Database).filter_by(id=target.database_id).one() - ) - engine = database.get_sqla_engine(schema=target.table.schema) - conditional_quote = engine.dialect.identifier_preparer.quote - + # Forces an update to the table's changed_on value when a metric or column on the + # table is updated. This busts the cache key for all charts that use the table. session.execute(update(SqlaTable).where(SqlaTable.id == target.table.id)) - dataset = ( - session.query(NewDataset) - .filter_by(sqlatable_id=target.table.id) - .one_or_none() - ) - - if not dataset: - # if dataset is not found create a new copy - # of the dataset instead of updating the existing - - SqlaTable.write_shadow_dataset(target.table, database, session) - return - - # update ``Column`` model as well - if isinstance(target, TableColumn): - columns = [ - column - for column in dataset.columns - if column.name == target.column_name - ] - if not columns: + # if table itself has changed, shadow-writing will happen in `after_udpate` anyway + if target.table not in session.dirty: + dataset: NewDataset = ( + session.query(NewDataset) + .filter_by(uuid=target.table.uuid) + .one_or_none() + ) + # Update shadow dataset and columns + # did we find the dataset? + if not dataset: + # if dataset is not found create a new copy + target.table.write_shadow_dataset() return - column = columns[0] - extra_json = json.loads(target.extra or "{}") - for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}: - value = getattr(target, attr) - if value: - extra_json[attr] = value + # update changed_on timestamp + session.execute(update(NewDataset).where(NewDataset.id == dataset.id)) - column.name = target.column_name - column.type = target.type or "Unknown" - column.expression = target.expression or conditional_quote( - target.column_name + # update `Column` model as well + session.add( + target.to_sl_column( + { + target.uuid: session.query(NewColumn) + .filter_by(uuid=target.uuid) + .one_or_none() + } + ) ) - column.description = target.description - column.is_temporal = target.is_dttm - column.is_physical = target.expression is None - column.extra_json = json.dumps(extra_json) if extra_json else None - - else: # SqlMetric - columns = [ - column - for column in dataset.columns - if column.name == target.metric_name - ] - if not columns: - return - - column = columns[0] - extra_json = json.loads(target.extra or "{}") - for attr in {"verbose_name", "metric_type", "d3format"}: - value = getattr(target, attr) - if value: - extra_json[attr] = value - - is_additive = ( - target.metric_type - and target.metric_type.lower() in ADDITIVE_METRIC_TYPES - ) - - column.name = target.metric_name - column.expression = target.expression - column.warning_text = target.warning_text - column.description = target.description - column.is_additive = is_additive - column.extra_json = json.dumps(extra_json) if extra_json else None @staticmethod def after_insert( mapper: Mapper, connection: Connection, - target: "SqlaTable", + sqla_table: "SqlaTable", ) -> None: """ Shadow write the dataset to new models. 
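A stripped-down, runnable sketch (not part of the upstream diff) of the listener pattern used by `update_table` above: the method is registered as an SQLAlchemy `after_update` mapper event further down in this file via `sa.event.listen(SqlMetric, "after_update", SqlaTable.update_table)` (and the TableColumn equivalent), and its job is to touch the parent row whenever a child row changes. Toy models, assuming SQLAlchemy 1.4-style APIs:

import sqlalchemy as sa
from sqlalchemy.orm import declarative_base, Session

Base = declarative_base()

class Parent(Base):  # plays the role of SqlaTable / NewDataset
    __tablename__ = "parent"
    id = sa.Column(sa.Integer, primary_key=True)
    changed = sa.Column(sa.Integer, default=0)

class Child(Base):  # plays the role of TableColumn / SqlMetric
    __tablename__ = "child"
    id = sa.Column(sa.Integer, primary_key=True)
    parent_id = sa.Column(sa.ForeignKey("parent.id"))
    name = sa.Column(sa.String(50))

def bump_parent(mapper, connection, target):
    # mirrors update_table: bump the parent whenever a child row is updated
    connection.execute(
        sa.update(Parent)
        .where(Parent.id == target.parent_id)
        .values(changed=Parent.changed + 1)
    )

sa.event.listen(Child, "after_update", bump_parent)

engine = sa.create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add_all([Parent(id=1), Child(id=1, parent_id=1, name="a")])
    session.commit()
    session.get(Child, 1).name = "b"  # triggers after_update -> bump_parent
    session.commit()
    assert session.get(Parent, 1).changed == 1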
@@ -2007,24 +2124,14 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho For more context: https://github.com/apache/superset/issues/14909 """ - session = inspect(target).session - # set permissions - security_manager.set_perm(mapper, connection, target) - - # get DB-specific conditional quoter for expressions that point to columns or - # table names - database = ( - target.database - or session.query(Database).filter_by(id=target.database_id).one() - ) - - SqlaTable.write_shadow_dataset(target, database, session) + security_manager.set_perm(mapper, connection, sqla_table) + sqla_table.write_shadow_dataset() @staticmethod def after_delete( # pylint: disable=unused-argument mapper: Mapper, connection: Connection, - target: "SqlaTable", + sqla_table: "SqlaTable", ) -> None: """ Shadow write the dataset to new models. @@ -2038,18 +2145,18 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho For more context: https://github.com/apache/superset/issues/14909 """ - session = inspect(target).session + session = inspect(sqla_table).session dataset = ( - session.query(NewDataset).filter_by(sqlatable_id=target.id).one_or_none() + session.query(NewDataset).filter_by(uuid=sqla_table.uuid).one_or_none() ) if dataset: session.delete(dataset) @staticmethod - def after_update( # pylint: disable=too-many-branches, too-many-locals, too-many-statements + def after_update( mapper: Mapper, connection: Connection, - target: "SqlaTable", + sqla_table: "SqlaTable", ) -> None: """ Shadow write the dataset to new models. @@ -2063,172 +2170,76 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho For more context: https://github.com/apache/superset/issues/14909 """ - inspector = inspect(target) + # set permissions + security_manager.set_perm(mapper, connection, sqla_table) + + inspector = inspect(sqla_table) session = inspector.session # double-check that ``UPDATE``s are actually pending (this method is called even # for instances that have no net changes to their column-based attributes) - if not session.is_modified(target, include_collections=True): + if not session.is_modified(sqla_table, include_collections=True): return - # set permissions - security_manager.set_perm(mapper, connection, target) - - dataset = ( - session.query(NewDataset).filter_by(sqlatable_id=target.id).one_or_none() + # find the dataset from the known instance list first + # (it could be either from a previous query or newly created) + dataset = next( + find_cached_objects_in_session( + session, NewDataset, uuids=[sqla_table.uuid] + ), + None, ) + # if not found, pull from database if not dataset: + dataset = ( + session.query(NewDataset).filter_by(uuid=sqla_table.uuid).one_or_none() + ) + if not dataset: + sqla_table.write_shadow_dataset() return - # get DB-specific conditional quoter for expressions that point to columns or - # table names - database = ( - target.database - or session.query(Database).filter_by(id=target.database_id).one() - ) - engine = database.get_sqla_engine(schema=target.schema) - conditional_quote = engine.dialect.identifier_preparer.quote - - # update columns - if inspector.attrs.columns.history.has_changes(): - # handle deleted columns - if inspector.attrs.columns.history.deleted: - column_names = { - column.column_name - for column in inspector.attrs.columns.history.deleted - } - dataset.columns = [ - column - for column in dataset.columns - if column.name not in column_names - ] - - # handle inserted columns - for column in 
inspector.attrs.columns.history.added: - # ``is_active`` might be ``None``, but it defaults to ``True``. - if column.is_active is False: - continue - - extra_json = json.loads(column.extra or "{}") - for attr in { - "groupby", - "filterable", - "verbose_name", - "python_date_format", - }: - value = getattr(column, attr) - if value: - extra_json[attr] = value - - dataset.columns.append( - NewColumn( - name=column.column_name, - type=column.type or "Unknown", - expression=column.expression - or conditional_quote(column.column_name), - description=column.description, - is_temporal=column.is_dttm, - is_aggregation=False, - is_physical=column.expression is None, - is_spatial=False, - is_partition=False, - is_increase_desired=True, - extra_json=json.dumps(extra_json) if extra_json else None, - is_managed_externally=target.is_managed_externally, - external_url=target.external_url, - ) - ) - - # update metrics - if inspector.attrs.metrics.history.has_changes(): - # handle deleted metrics - if inspector.attrs.metrics.history.deleted: - column_names = { - metric.metric_name - for metric in inspector.attrs.metrics.history.deleted - } - dataset.columns = [ - column - for column in dataset.columns - if column.name not in column_names - ] - - # handle inserted metrics - for metric in inspector.attrs.metrics.history.added: - extra_json = json.loads(metric.extra or "{}") - for attr in {"verbose_name", "metric_type", "d3format"}: - value = getattr(metric, attr) - if value: - extra_json[attr] = value - - is_additive = ( - metric.metric_type - and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES - ) - - dataset.columns.append( - NewColumn( - name=metric.metric_name, - type="Unknown", - expression=metric.expression, - warning_text=metric.warning_text, - description=metric.description, - is_aggregation=True, - is_additive=is_additive, - is_physical=False, - is_spatial=False, - is_partition=False, - is_increase_desired=True, - extra_json=json.dumps(extra_json) if extra_json else None, - is_managed_externally=target.is_managed_externally, - external_url=target.external_url, - ) - ) + # sync column list and delete removed columns + if ( + inspector.attrs.columns.history.has_changes() + or inspector.attrs.metrics.history.has_changes() + ): + # add pending new columns to known columns list, too, so if calling + # `after_update` twice before changes are persisted will not create + # two duplicate columns with the same uuids. + dataset.columns = sqla_table.get_sl_columns() # physical dataset - if target.sql is None: - physical_columns = [ - column for column in dataset.columns if column.is_physical - ] - - # if the table name changed we should create a new table instance, instead - # of reusing the original one + if not sqla_table.sql: + # if the table name changed we should relink the dataset to another table + # (and create one if necessary) if ( inspector.attrs.table_name.history.has_changes() or inspector.attrs.schema.history.has_changes() - or inspector.attrs.database_id.history.has_changes() + or inspector.attrs.database.history.has_changes() ): - # does the dataset point to an existing table? 
- table = ( - session.query(NewTable) - .filter_by( - database_id=target.database_id, - schema=target.schema, - name=target.table_name, - ) - .first() + tables = NewTable.bulk_load_or_create( + sqla_table.database, + [TableName(schema=sqla_table.schema, table=sqla_table.table_name)], + sync_columns=False, + default_props=dict( + changed_by=sqla_table.changed_by, + created_by=sqla_table.created_by, + is_managed_externally=sqla_table.is_managed_externally, + external_url=sqla_table.external_url, + ), ) - if not table: - # create new columns + if not tables[0].id: + # dataset columns will only be assigned to newly created tables + # existing tables should manage column syncing in another process physical_columns = [ - clone_model(column, ignore=["uuid"]) - for column in physical_columns + clone_model( + column, ignore=["uuid"], keep_relations=["changed_by"] + ) + for column in dataset.columns + if column.is_physical ] - - # create new table - table = NewTable( - name=target.table_name, - schema=target.schema, - catalog=None, - database_id=target.database_id, - columns=physical_columns, - is_managed_externally=target.is_managed_externally, - external_url=target.external_url, - ) - dataset.tables = [table] - elif dataset.tables: - table = dataset.tables[0] - table.columns = physical_columns + tables[0].columns = physical_columns + dataset.tables = tables # virtual dataset else: @@ -2237,29 +2248,34 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho column.is_physical = False # update referenced tables if SQL changed - if inspector.attrs.sql.history.has_changes(): - parsed = ParsedQuery(target.sql) - referenced_tables = parsed.tables - - predicate = or_( - *[ - and_( - NewTable.schema == (table.schema or target.schema), - NewTable.name == table.table, - ) - for table in referenced_tables - ] + if sqla_table.sql and inspector.attrs.sql.history.has_changes(): + referenced_tables = extract_table_references( + sqla_table.sql, sqla_table.database.get_dialect().name + ) + dataset.tables = NewTable.bulk_load_or_create( + sqla_table.database, + referenced_tables, + default_schema=sqla_table.schema, + # sync metadata is expensive, we'll do it in another process + # e.g. when users open a Table page + sync_columns=False, + default_props=dict( + changed_by=sqla_table.changed_by, + created_by=sqla_table.created_by, + is_managed_externally=sqla_table.is_managed_externally, + external_url=sqla_table.external_url, + ), ) - dataset.tables = session.query(NewTable).filter(predicate).all() # update other attributes - dataset.name = target.table_name - dataset.expression = target.sql or conditional_quote(target.table_name) - dataset.is_physical = target.sql is None + dataset.name = sqla_table.table_name + dataset.expression = sqla_table.sql or sqla_table.quote_identifier( + sqla_table.table_name + ) + dataset.is_physical = not sqla_table.sql - @staticmethod - def write_shadow_dataset( # pylint: disable=too-many-locals - dataset: "SqlaTable", database: Database, session: Session + def write_shadow_dataset( + self: "SqlaTable", ) -> None: """ Shadow write the dataset to new models. 
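For virtual datasets, the hunk above recomputes `dataset.tables` from the tables referenced in the dataset's SQL, delegating to `extract_table_references` (added to `superset.sql_parse` in this patch) and `NewTable.bulk_load_or_create` (the bulk replacement for the removed `load_or_create_tables` helper). The sketch below (not part of the upstream diff) only mirrors the shape of that flow; the naive FROM/JOIN scanner is a deliberately crude stand-in for the real SQL parsing:

import re
from typing import NamedTuple, Optional, Set

class TableName(NamedTuple):  # loosely modeled on superset.sql_parse.Table
    table: str
    schema: Optional[str] = None

def naive_table_references(sql: str) -> Set[TableName]:
    """Very rough stand-in for extract_table_references()."""
    found = set()
    for match in re.finditer(r"\b(?:from|join)\s+([\w.]+)", sql, re.IGNORECASE):
        parts = match.group(1).split(".")
        found.add(TableName(parts[-1], parts[-2] if len(parts) > 1 else None))
    return found

sql = "SELECT o.id, p.total FROM main.orders o JOIN payments p ON o.id = p.order_id"
assert naive_table_references(sql) == {
    TableName("orders", "main"),
    TableName("payments", None),
}
# In the patch, each reference is then handed to NewTable.bulk_load_or_create,
# which loads the matching sl_tables rows (or creates the missing ones) in bulk.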
@@ -2273,95 +2289,57 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho For more context: https://github.com/apache/superset/issues/14909 """ - - engine = database.get_sqla_engine(schema=dataset.schema) - conditional_quote = engine.dialect.identifier_preparer.quote + session = inspect(self).session + # make sure database points to the right instance, in case only + # `table.database_id` is updated and the changes haven't been + # consolidated by SQLA + if self.database_id and ( + not self.database or self.database.id != self.database_id + ): + self.database = session.query(Database).filter_by(id=self.database_id).one() # create columns columns = [] - for column in dataset.columns: - # ``is_active`` might be ``None`` at this point, but it defaults to ``True``. - if column.is_active is False: - continue - - try: - extra_json = json.loads(column.extra or "{}") - except json.decoder.JSONDecodeError: - extra_json = {} - for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}: - value = getattr(column, attr) - if value: - extra_json[attr] = value - - columns.append( - NewColumn( - name=column.column_name, - type=column.type or "Unknown", - expression=column.expression - or conditional_quote(column.column_name), - description=column.description, - is_temporal=column.is_dttm, - is_aggregation=False, - is_physical=column.expression is None, - is_spatial=False, - is_partition=False, - is_increase_desired=True, - extra_json=json.dumps(extra_json) if extra_json else None, - is_managed_externally=dataset.is_managed_externally, - external_url=dataset.external_url, - ), - ) - - # create metrics - for metric in dataset.metrics: - try: - extra_json = json.loads(metric.extra or "{}") - except json.decoder.JSONDecodeError: - extra_json = {} - for attr in {"verbose_name", "metric_type", "d3format"}: - value = getattr(metric, attr) - if value: - extra_json[attr] = value - - is_additive = ( - metric.metric_type - and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES - ) - - columns.append( - NewColumn( - name=metric.metric_name, - type="Unknown", # figuring this out would require a type inferrer - expression=metric.expression, - warning_text=metric.warning_text, - description=metric.description, - is_aggregation=True, - is_additive=is_additive, - is_physical=False, - is_spatial=False, - is_partition=False, - is_increase_desired=True, - extra_json=json.dumps(extra_json) if extra_json else None, - is_managed_externally=dataset.is_managed_externally, - external_url=dataset.external_url, - ), - ) + for item in self.columns + self.metrics: + item.created_by = self.created_by + item.changed_by = self.changed_by + # on `SqlaTable.after_insert`` event, although the table itself + # already has a `uuid`, the associated columns will not. + # Here we pre-assign a uuid so they can still be matched to the new + # Column after creation. 
+ if not item.uuid: + item.uuid = uuid4() + columns.append(item.to_sl_column()) # physical dataset - if not dataset.sql: - physical_columns = [column for column in columns if column.is_physical] - - # create table - table = NewTable( - name=dataset.table_name, - schema=dataset.schema, - catalog=None, # currently not supported - database_id=dataset.database_id, - columns=physical_columns, - is_managed_externally=dataset.is_managed_externally, - external_url=dataset.external_url, + if not self.sql: + # always create separate column entries for Dataset and Table + # so updating a dataset would not update columns in the related table + physical_columns = [ + clone_model( + column, + ignore=["uuid"], + # `created_by` will always be left empty because it'd always + # be created via some sort of automated system. + # But keep `changed_by` in case someone manually changes + # column attributes such as `is_dttm`. + keep_relations=["changed_by"], + ) + for column in columns + if column.is_physical + ] + tables = NewTable.bulk_load_or_create( + self.database, + [TableName(schema=self.schema, table=self.table_name)], + sync_columns=False, + default_props=dict( + created_by=self.created_by, + changed_by=self.changed_by, + is_managed_externally=self.is_managed_externally, + external_url=self.external_url, + ), ) - tables = [table] + tables[0].columns = physical_columns # virtual dataset else: @@ -2370,26 +2348,39 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho column.is_physical = False # find referenced tables - parsed = ParsedQuery(dataset.sql) - referenced_tables = parsed.tables - tables = load_or_create_tables( - session, - database, - dataset.schema, + referenced_tables = extract_table_references( + self.sql, self.database.get_dialect().name + ) + tables = NewTable.bulk_load_or_create( + self.database, referenced_tables, - conditional_quote, + default_schema=self.schema, + # syncing table columns can be slow so we are not doing it here + sync_columns=False, + default_props=dict( + created_by=self.created_by, + changed_by=self.changed_by, + is_managed_externally=self.is_managed_externally, + external_url=self.external_url, + ), ) # create the new dataset new_dataset = NewDataset( - sqlatable_id=dataset.id, - name=dataset.table_name, - expression=dataset.sql or conditional_quote(dataset.table_name), + uuid=self.uuid, + database_id=self.database_id, + created_on=self.created_on, + created_by=self.created_by, + changed_by=self.changed_by, + changed_on=self.changed_on, + owners=self.owners, + name=self.table_name, + expression=self.sql or self.quote_identifier(self.table_name), tables=tables, columns=columns, - is_physical=not dataset.sql, - is_managed_externally=dataset.is_managed_externally, - external_url=dataset.external_url, + is_physical=not self.sql, + is_managed_externally=self.is_managed_externally, + external_url=self.external_url, ) session.add(new_dataset) @@ -2399,7 +2390,9 @@ sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert) sa.event.listen(SqlaTable, "after_delete", SqlaTable.after_delete) sa.event.listen(SqlaTable, "after_update", SqlaTable.after_update) sa.event.listen(SqlMetric, "after_update", SqlaTable.update_table) +sa.event.listen(SqlMetric, "after_delete", SqlMetric.after_delete) sa.event.listen(TableColumn, "after_update", SqlaTable.update_table) +sa.event.listen(TableColumn, "after_delete", TableColumn.after_delete) RLSFilterRoles = Table( "rls_filter_roles", diff --git a/superset/connectors/sqla/utils.py 
b/superset/connectors/sqla/utils.py index f8ed7a9567..1786c5bf17 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -15,16 +15,28 @@ # specific language governing permissions and limitations # under the License. from contextlib import closing -from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Type, + TYPE_CHECKING, + TypeVar, +) +from uuid import UUID import sqlparse from flask_babel import lazy_gettext as _ -from sqlalchemy import and_, or_ +from sqlalchemy.engine.url import URL as SqlaURL from sqlalchemy.exc import NoSuchTableError +from sqlalchemy.ext.declarative import DeclarativeMeta from sqlalchemy.orm import Session from sqlalchemy.sql.type_api import TypeEngine -from superset.columns.models import Column as NewColumn from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( SupersetGenericDBErrorException, @@ -32,9 +44,9 @@ from superset.exceptions import ( ) from superset.models.core import Database from superset.result_set import SupersetResultSet -from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, Table +from superset.sql_parse import has_table_query, insert_rls, ParsedQuery from superset.superset_typing import ResultSetColumnType -from superset.tables.models import Table as NewTable +from superset.utils.memoized import memoized if TYPE_CHECKING: from superset.connectors.sqla.models import SqlaTable @@ -168,75 +180,38 @@ def validate_adhoc_subquery( return ";\n".join(str(statement) for statement in statements) -def load_or_create_tables( # pylint: disable=too-many-arguments +@memoized +def get_dialect_name(drivername: str) -> str: + return SqlaURL(drivername).get_dialect().name + + +@memoized +def get_identifier_quoter(drivername: str) -> Dict[str, Callable[[str], str]]: + return SqlaURL(drivername).get_dialect()().identifier_preparer.quote + + +DeclarativeModel = TypeVar("DeclarativeModel", bound=DeclarativeMeta) + + +def find_cached_objects_in_session( session: Session, - database: Database, - default_schema: Optional[str], - tables: Set[Table], - conditional_quote: Callable[[str], str], -) -> List[NewTable]: - """ - Load or create new table model instances. - """ - if not tables: - return [] + cls: Type[DeclarativeModel], + ids: Optional[Iterable[int]] = None, + uuids: Optional[Iterable[UUID]] = None, +) -> Iterator[DeclarativeModel]: + """Find known ORM instances in cached SQLA session states. 
- # set the default schema in tables that don't have it - if default_schema: - fixed_tables = list(tables) - for i, table in enumerate(fixed_tables): - if table.schema is None: - fixed_tables[i] = Table(table.table, default_schema, table.catalog) - tables = set(fixed_tables) - - # load existing tables - predicate = or_( - *[ - and_( - NewTable.database_id == database.id, - NewTable.schema == table.schema, - NewTable.name == table.table, - ) - for table in tables - ] + :param session: a SQLA session + :param cls: a SQLA DeclarativeModel + :param ids: ids of the desired model instances (optional) + :param uuids: uuids of the desired instances, will be ignored if `ids` are provides + """ + if not ids and not uuids: + return iter([]) + uuids = uuids or [] + return ( + item + # `session` is an iterator of all known items + for item in set(session) + if isinstance(item, cls) and (item.id in ids if ids else item.uuid in uuids) ) - new_tables = session.query(NewTable).filter(predicate).all() - - # add missing tables - existing = {(table.schema, table.name) for table in new_tables} - for table in tables: - if (table.schema, table.table) not in existing: - try: - column_metadata = get_physical_table_metadata( - database=database, - table_name=table.table, - schema_name=table.schema, - ) - except Exception: # pylint: disable=broad-except - continue - columns = [ - NewColumn( - name=column["name"], - type=str(column["type"]), - expression=conditional_quote(column["name"]), - is_temporal=column["is_dttm"], - is_aggregation=False, - is_physical=True, - is_spatial=False, - is_partition=False, - is_increase_desired=True, - ) - for column in column_metadata - ] - new_tables.append( - NewTable( - name=table.table, - schema=table.schema, - catalog=None, - database_id=database.id, - columns=columns, - ) - ) - existing.add((table.schema, table.table)) - - return new_tables diff --git a/superset/datasets/models.py b/superset/datasets/models.py index 56a6fbf400..b433709f2c 100644 --- a/superset/datasets/models.py +++ b/superset/datasets/models.py @@ -28,9 +28,11 @@ from typing import List import sqlalchemy as sa from flask_appbuilder import Model -from sqlalchemy.orm import relationship +from sqlalchemy.orm import backref, relationship +from superset import security_manager from superset.columns.models import Column +from superset.models.core import Database from superset.models.helpers import ( AuditMixinNullable, ExtraJSONMixin, @@ -38,18 +40,33 @@ from superset.models.helpers import ( ) from superset.tables.models import Table -column_association_table = sa.Table( +dataset_column_association_table = sa.Table( "sl_dataset_columns", Model.metadata, # pylint: disable=no-member - sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id")), - sa.Column("column_id", sa.ForeignKey("sl_columns.id")), + sa.Column( + "dataset_id", + sa.ForeignKey("sl_datasets.id"), + primary_key=True, + ), + sa.Column( + "column_id", + sa.ForeignKey("sl_columns.id"), + primary_key=True, + ), ) -table_association_table = sa.Table( +dataset_table_association_table = sa.Table( "sl_dataset_tables", Model.metadata, # pylint: disable=no-member - sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id")), - sa.Column("table_id", sa.ForeignKey("sl_tables.id")), + sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id"), primary_key=True), + sa.Column("table_id", sa.ForeignKey("sl_tables.id"), primary_key=True), +) + +dataset_user_association_table = sa.Table( + "sl_dataset_users", + Model.metadata, # pylint: disable=no-member + 
sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id"), primary_key=True), + sa.Column("user_id", sa.ForeignKey("ab_user.id"), primary_key=True), ) @@ -61,27 +78,27 @@ class Dataset(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): __tablename__ = "sl_datasets" id = sa.Column(sa.Integer, primary_key=True) - - # A temporary column, used for shadow writing to the new model. Once the ``SqlaTable`` - # model has been deleted this column can be removed. - sqlatable_id = sa.Column(sa.Integer, nullable=True, unique=True) - - # We use ``sa.Text`` for these attributes because (1) in modern databases the - # performance is the same as ``VARCHAR``[1] and (2) because some table names can be - # **really** long (eg, Google Sheets URLs). - # - # [1] https://www.postgresql.org/docs/9.1/datatype-character.html - name = sa.Column(sa.Text) - - expression = sa.Column(sa.Text) - - # n:n relationship - tables: List[Table] = relationship("Table", secondary=table_association_table) - - # The relationship between datasets and columns is 1:n, but we use a many-to-many - # association to differentiate between the relationship between tables and columns. + database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False) + database: Database = relationship( + "Database", + backref=backref("datasets", cascade="all, delete-orphan"), + foreign_keys=[database_id], + ) + # The relationship between datasets and columns is 1:n, but we use a + # many-to-many association table to avoid adding two mutually exclusive + # columns(dataset_id and table_id) to Column columns: List[Column] = relationship( - "Column", secondary=column_association_table, cascade="all, delete" + "Column", + secondary=dataset_column_association_table, + cascade="all, delete-orphan", + single_parent=True, + backref="datasets", + ) + owners = relationship( + security_manager.user_model, secondary=dataset_user_association_table + ) + tables: List[Table] = relationship( + "Table", secondary=dataset_table_association_table, backref="datasets" ) # Does the dataset point directly to a ``Table``? @@ -89,4 +106,15 @@ class Dataset(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): # Column is managed externally and should be read-only inside Superset is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False) + + # We use ``sa.Text`` for these attributes because (1) in modern databases the + # performance is the same as ``VARCHAR``[1] and (2) because some table names can be + # **really** long (eg, Google Sheets URLs). + # + # [1] https://www.postgresql.org/docs/9.1/datatype-character.html + name = sa.Column(sa.Text) + expression = sa.Column(sa.Text) external_url = sa.Column(sa.Text, nullable=True) + + def __repr__(self) -> str: + return f"" diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py index 1380958b2a..8d7c02799d 100644 --- a/superset/examples/birth_names.py +++ b/superset/examples/birth_names.py @@ -135,23 +135,26 @@ def _set_table_metadata(datasource: SqlaTable, database: "Database") -> None: def _add_table_metrics(datasource: SqlaTable) -> None: - if not any(col.column_name == "num_california" for col in datasource.columns): + # By accessing the attribute first, we make sure `datasource.columns` and + # `datasource.metrics` are already loaded. Otherwise accessing them later + # may trigger an unnecessary and unexpected `after_update` event. 
+ columns, metrics = datasource.columns, datasource.metrics + + if not any(col.column_name == "num_california" for col in columns): col_state = str(column("state").compile(db.engine)) col_num = str(column("num").compile(db.engine)) - datasource.columns.append( + columns.append( TableColumn( column_name="num_california", expression=f"CASE WHEN {col_state} = 'CA' THEN {col_num} ELSE 0 END", ) ) - if not any(col.metric_name == "sum__num" for col in datasource.metrics): + if not any(col.metric_name == "sum__num" for col in metrics): col = str(column("num").compile(db.engine)) - datasource.metrics.append( - SqlMetric(metric_name="sum__num", expression=f"SUM({col})") - ) + metrics.append(SqlMetric(metric_name="sum__num", expression=f"SUM({col})")) - for col in datasource.columns: + for col in columns: if col.column_name == "ds": col.is_dttm = True break diff --git a/superset/migrations/shared/utils.py b/superset/migrations/shared/utils.py index c54de83c42..4b0c4e1440 100644 --- a/superset/migrations/shared/utils.py +++ b/superset/migrations/shared/utils.py @@ -15,42 +15,22 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Iterator, Optional, Set +import os +import time +from typing import Any +from uuid import uuid4 from alembic import op from sqlalchemy import engine_from_config +from sqlalchemy.dialects.mysql.base import MySQLDialect +from sqlalchemy.dialects.postgresql.base import PGDialect from sqlalchemy.engine import reflection from sqlalchemy.exc import NoSuchTableError +from sqlalchemy.orm import Session -try: - from sqloxide import parse_sql -except ImportError: - parse_sql = None +logger = logging.getLogger(__name__) -from superset.sql_parse import ParsedQuery, Table - -logger = logging.getLogger("alembic") - - -# mapping between sqloxide and SQLAlchemy dialects -sqloxide_dialects = { - "ansi": {"trino", "trinonative", "presto"}, - "hive": {"hive", "databricks"}, - "ms": {"mssql"}, - "mysql": {"mysql"}, - "postgres": { - "cockroachdb", - "hana", - "netezza", - "postgres", - "postgresql", - "redshift", - "vertica", - }, - "snowflake": {"snowflake"}, - "sqlite": {"sqlite", "gsheets", "shillelagh"}, - "clickhouse": {"clickhouse"}, -} +DEFAULT_BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 1000)) def table_has_column(table: str, column: str) -> bool: @@ -61,7 +41,6 @@ def table_has_column(table: str, column: str) -> bool: :param column: A column name :returns: True iff the column exists in the table """ - config = op.get_context().config engine = engine_from_config( config.get_section(config.config_ini_section), prefix="sqlalchemy." @@ -73,42 +52,44 @@ def table_has_column(table: str, column: str) -> bool: return False -def find_nodes_by_key(element: Any, target: str) -> Iterator[Any]: - """ - Find all nodes in a SQL tree matching a given key. - """ - if isinstance(element, list): - for child in element: - yield from find_nodes_by_key(child, target) - elif isinstance(element, dict): - for key, value in element.items(): - if key == target: - yield value - else: - yield from find_nodes_by_key(value, target) +uuid_by_dialect = { + MySQLDialect: "UNHEX(REPLACE(CONVERT(UUID() using utf8mb4), '-', ''))", + PGDialect: "uuid_in(md5(random()::text || clock_timestamp()::text)::cstring)", +} -def extract_table_references(sql_text: str, sqla_dialect: str) -> Set[Table]: - """ - Return all the dependencies from a SQL sql_text. 
- """ - if not parse_sql: - parsed = ParsedQuery(sql_text) - return parsed.tables +def assign_uuids( + model: Any, session: Session, batch_size: int = DEFAULT_BATCH_SIZE +) -> None: + """Generate new UUIDs for all rows in a table""" + bind = op.get_bind() + table_name = model.__tablename__ + count = session.query(model).count() + # silently skip if the table is empty (suitable for db initialization) + if count == 0: + return - dialect = "generic" - for dialect, sqla_dialects in sqloxide_dialects.items(): - if sqla_dialect in sqla_dialects: - break - try: - tree = parse_sql(sql_text, dialect=dialect) - except Exception: # pylint: disable=broad-except - logger.warning("Unable to parse query with sqloxide: %s", sql_text) - # fallback to sqlparse - parsed = ParsedQuery(sql_text) - return parsed.tables + start_time = time.time() + print(f"\nAdding uuids for `{table_name}`...") + # Use dialect specific native SQL queries if possible + for dialect, sql in uuid_by_dialect.items(): + if isinstance(bind.dialect, dialect): + op.execute( + f"UPDATE {dialect().identifier_preparer.quote(table_name)} SET uuid = {sql}" + ) + print(f"Done. Assigned {count} uuids in {time.time() - start_time:.3f}s.\n") + return - return { - Table(*[part["value"] for part in table["name"][::-1]]) - for table in find_nodes_by_key(tree, "Table") - } + # Othwewise Use Python uuid function + start = 0 + while start < count: + end = min(start + batch_size, count) + for obj in session.query(model)[start:end]: + obj.uuid = uuid4() + session.merge(obj) + session.commit() + if start + batch_size < count: + print(f" uuid assigned to {end} out of {count}\r", end="") + start += batch_size + + print(f"Done. Assigned {count} uuids in {time.time() - start_time:.3f}s.\n") diff --git a/superset/migrations/versions/96e99fb176a0_add_import_mixing_to_saved_query.py b/superset/migrations/versions/96e99fb176a0_add_import_mixing_to_saved_query.py index 57d22aa089..f93deb1d0c 100644 --- a/superset/migrations/versions/96e99fb176a0_add_import_mixing_to_saved_query.py +++ b/superset/migrations/versions/96e99fb176a0_add_import_mixing_to_saved_query.py @@ -32,9 +32,7 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy_utils import UUIDType from superset import db -from superset.migrations.versions.b56500de1855_add_uuid_column_to_import_mixin import ( - add_uuids, -) +from superset.migrations.shared.utils import assign_uuids # revision identifiers, used by Alembic. revision = "96e99fb176a0" @@ -75,7 +73,7 @@ def upgrade(): # Ignore column update errors so that we can run upgrade multiple times pass - add_uuids(SavedQuery, "saved_query", session) + assign_uuids(SavedQuery, session) try: # Add uniqueness constraint diff --git a/superset/migrations/versions/9d8a8d575284_.py b/superset/migrations/versions/9d8a8d575284_.py index daa84a2ad0..fbbfac231b 100644 --- a/superset/migrations/versions/9d8a8d575284_.py +++ b/superset/migrations/versions/9d8a8d575284_.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-"""empty message +"""merge point Revision ID: 9d8a8d575284 Revises: ('8b841273bec3', 'b0d0249074e4') diff --git a/superset/migrations/versions/a9422eeaae74_new_dataset_models_take_2.py b/superset/migrations/versions/a9422eeaae74_new_dataset_models_take_2.py new file mode 100644 index 0000000000..efb7d1a01b --- /dev/null +++ b/superset/migrations/versions/a9422eeaae74_new_dataset_models_take_2.py @@ -0,0 +1,905 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""new_dataset_models_take_2 + +Revision ID: a9422eeaae74 +Revises: ad07e4fdbaba +Create Date: 2022-04-01 14:38:09.499483 + +""" + +# revision identifiers, used by Alembic. +revision = "a9422eeaae74" +down_revision = "ad07e4fdbaba" + +import json +import os +from datetime import datetime +from typing import List, Optional, Set, Type, Union +from uuid import uuid4 + +import sqlalchemy as sa +from alembic import op +from sqlalchemy import select +from sqlalchemy.ext.declarative import declarative_base, declared_attr +from sqlalchemy.orm import backref, relationship, Session +from sqlalchemy.schema import UniqueConstraint +from sqlalchemy.sql import functions as func +from sqlalchemy.sql.expression import and_, or_ +from sqlalchemy_utils import UUIDType + +from superset import app, db +from superset.connectors.sqla.models import ADDITIVE_METRIC_TYPES_LOWER +from superset.connectors.sqla.utils import get_dialect_name, get_identifier_quoter +from superset.extensions import encrypted_field_factory +from superset.migrations.shared.utils import assign_uuids +from superset.sql_parse import extract_table_references, Table +from superset.utils.core import MediumText + +Base = declarative_base() +custom_password_store = app.config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"] +DB_CONNECTION_MUTATOR = app.config["DB_CONNECTION_MUTATOR"] +SHOW_PROGRESS = os.environ.get("SHOW_PROGRESS") == "1" +UNKNOWN_TYPE = "UNKNOWN" + + +user_table = sa.Table( + "ab_user", Base.metadata, sa.Column("id", sa.Integer(), primary_key=True) +) + + +class UUIDMixin: + uuid = sa.Column( + UUIDType(binary=True), primary_key=False, unique=True, default=uuid4 + ) + + +class AuxiliaryColumnsMixin(UUIDMixin): + """ + Auxiliary columns, a combination of columns added by + AuditMixinNullable + ImportExportMixin + """ + + created_on = sa.Column(sa.DateTime, default=datetime.now, nullable=True) + changed_on = sa.Column( + sa.DateTime, default=datetime.now, onupdate=datetime.now, nullable=True + ) + + @declared_attr + def created_by_fk(cls): + return sa.Column(sa.Integer, sa.ForeignKey("ab_user.id"), nullable=True) + + @declared_attr + def changed_by_fk(cls): + return sa.Column(sa.Integer, sa.ForeignKey("ab_user.id"), nullable=True) + + +def insert_from_select( + target: Union[str, sa.Table, Type[Base]], source: sa.sql.expression.Select +) -> None: + """ 
+ Execute INSERT FROM SELECT to copy data from a SELECT query to the target table. + """ + if isinstance(target, sa.Table): + target_table = target + elif hasattr(target, "__tablename__"): + target_table: sa.Table = Base.metadata.tables[target.__tablename__] + else: + target_table: sa.Table = Base.metadata.tables[target] + cols = [col.name for col in source.columns if col.name in target_table.columns] + query = target_table.insert().from_select(cols, source) + return op.execute(query) + + +class Database(Base): + + __tablename__ = "dbs" + __table_args__ = (UniqueConstraint("database_name"),) + + id = sa.Column(sa.Integer, primary_key=True) + database_name = sa.Column(sa.String(250), unique=True, nullable=False) + sqlalchemy_uri = sa.Column(sa.String(1024), nullable=False) + password = sa.Column(encrypted_field_factory.create(sa.String(1024))) + impersonate_user = sa.Column(sa.Boolean, default=False) + encrypted_extra = sa.Column(encrypted_field_factory.create(sa.Text), nullable=True) + extra = sa.Column(sa.Text) + server_cert = sa.Column(encrypted_field_factory.create(sa.Text), nullable=True) + + +class TableColumn(AuxiliaryColumnsMixin, Base): + + __tablename__ = "table_columns" + __table_args__ = (UniqueConstraint("table_id", "column_name"),) + + id = sa.Column(sa.Integer, primary_key=True) + table_id = sa.Column(sa.Integer, sa.ForeignKey("tables.id")) + is_active = sa.Column(sa.Boolean, default=True) + extra = sa.Column(sa.Text) + column_name = sa.Column(sa.String(255), nullable=False) + type = sa.Column(sa.String(32)) + expression = sa.Column(MediumText()) + description = sa.Column(MediumText()) + is_dttm = sa.Column(sa.Boolean, default=False) + filterable = sa.Column(sa.Boolean, default=True) + groupby = sa.Column(sa.Boolean, default=True) + verbose_name = sa.Column(sa.String(1024)) + python_date_format = sa.Column(sa.String(255)) + + +class SqlMetric(AuxiliaryColumnsMixin, Base): + + __tablename__ = "sql_metrics" + __table_args__ = (UniqueConstraint("table_id", "metric_name"),) + + id = sa.Column(sa.Integer, primary_key=True) + table_id = sa.Column(sa.Integer, sa.ForeignKey("tables.id")) + extra = sa.Column(sa.Text) + metric_type = sa.Column(sa.String(32)) + metric_name = sa.Column(sa.String(255), nullable=False) + expression = sa.Column(MediumText(), nullable=False) + warning_text = sa.Column(MediumText()) + description = sa.Column(MediumText()) + d3format = sa.Column(sa.String(128)) + verbose_name = sa.Column(sa.String(1024)) + + +sqlatable_user_table = sa.Table( + "sqlatable_user", + Base.metadata, + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("user_id", sa.Integer, sa.ForeignKey("ab_user.id")), + sa.Column("table_id", sa.Integer, sa.ForeignKey("tables.id")), +) + + +class SqlaTable(AuxiliaryColumnsMixin, Base): + + __tablename__ = "tables" + __table_args__ = (UniqueConstraint("database_id", "schema", "table_name"),) + + id = sa.Column(sa.Integer, primary_key=True) + extra = sa.Column(sa.Text) + database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False) + database: Database = relationship( + "Database", + backref=backref("tables", cascade="all, delete-orphan"), + foreign_keys=[database_id], + ) + schema = sa.Column(sa.String(255)) + table_name = sa.Column(sa.String(250), nullable=False) + sql = sa.Column(MediumText()) + is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False) + external_url = sa.Column(sa.Text, nullable=True) + + +table_column_association_table = sa.Table( + "sl_table_columns", + Base.metadata, + 
sa.Column("table_id", sa.ForeignKey("sl_tables.id"), primary_key=True), + sa.Column("column_id", sa.ForeignKey("sl_columns.id"), primary_key=True), +) + +dataset_column_association_table = sa.Table( + "sl_dataset_columns", + Base.metadata, + sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id"), primary_key=True), + sa.Column("column_id", sa.ForeignKey("sl_columns.id"), primary_key=True), +) + +dataset_table_association_table = sa.Table( + "sl_dataset_tables", + Base.metadata, + sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id"), primary_key=True), + sa.Column("table_id", sa.ForeignKey("sl_tables.id"), primary_key=True), +) + +dataset_user_association_table = sa.Table( + "sl_dataset_users", + Base.metadata, + sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id"), primary_key=True), + sa.Column("user_id", sa.ForeignKey("ab_user.id"), primary_key=True), +) + + +class NewColumn(AuxiliaryColumnsMixin, Base): + + __tablename__ = "sl_columns" + + id = sa.Column(sa.Integer, primary_key=True) + # A temporary column to link physical columns with tables so we don't + # have to insert a record in the relationship table while creating new columns. + table_id = sa.Column(sa.Integer, nullable=True) + + is_aggregation = sa.Column(sa.Boolean, nullable=False, default=False) + is_additive = sa.Column(sa.Boolean, nullable=False, default=False) + is_dimensional = sa.Column(sa.Boolean, nullable=False, default=False) + is_filterable = sa.Column(sa.Boolean, nullable=False, default=True) + is_increase_desired = sa.Column(sa.Boolean, nullable=False, default=True) + is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False) + is_partition = sa.Column(sa.Boolean, nullable=False, default=False) + is_physical = sa.Column(sa.Boolean, nullable=False, default=False) + is_temporal = sa.Column(sa.Boolean, nullable=False, default=False) + is_spatial = sa.Column(sa.Boolean, nullable=False, default=False) + + name = sa.Column(sa.Text) + type = sa.Column(sa.Text) + unit = sa.Column(sa.Text) + expression = sa.Column(MediumText()) + description = sa.Column(MediumText()) + warning_text = sa.Column(MediumText()) + external_url = sa.Column(sa.Text, nullable=True) + extra_json = sa.Column(MediumText(), default="{}") + + +class NewTable(AuxiliaryColumnsMixin, Base): + + __tablename__ = "sl_tables" + + id = sa.Column(sa.Integer, primary_key=True) + # A temporary column to keep the link between NewTable to SqlaTable + sqlatable_id = sa.Column(sa.Integer, primary_key=False, nullable=True, unique=True) + database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False) + is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False) + catalog = sa.Column(sa.Text) + schema = sa.Column(sa.Text) + name = sa.Column(sa.Text) + external_url = sa.Column(sa.Text, nullable=True) + extra_json = sa.Column(MediumText(), default="{}") + database: Database = relationship( + "Database", + backref=backref("new_tables", cascade="all, delete-orphan"), + foreign_keys=[database_id], + ) + + +class NewDataset(Base, AuxiliaryColumnsMixin): + + __tablename__ = "sl_datasets" + + id = sa.Column(sa.Integer, primary_key=True) + database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False) + is_physical = sa.Column(sa.Boolean, default=False) + is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False) + name = sa.Column(sa.Text) + expression = sa.Column(MediumText()) + external_url = sa.Column(sa.Text, nullable=True) + extra_json = sa.Column(MediumText(), default="{}") + + +def 
find_tables( + session: Session, + database_id: int, + default_schema: Optional[str], + tables: Set[Table], +) -> List[int]: + """ + Look for NewTable's of from a specific database + """ + if not tables: + return [] + + predicate = or_( + *[ + and_( + NewTable.database_id == database_id, + NewTable.schema == (table.schema or default_schema), + NewTable.name == table.table, + ) + for table in tables + ] + ) + return session.query(NewTable.id).filter(predicate).all() + + +# helper SQLA elements for easier querying +is_physical_table = or_(SqlaTable.sql.is_(None), SqlaTable.sql == "") +is_physical_column = or_(TableColumn.expression.is_(None), TableColumn.expression == "") + +# filtering out table columns with valid associated SqlTable +active_table_columns = sa.join( + TableColumn, + SqlaTable, + TableColumn.table_id == SqlaTable.id, +) +active_metrics = sa.join(SqlMetric, SqlaTable, SqlMetric.table_id == SqlaTable.id) + + +def copy_tables(session: Session) -> None: + """Copy Physical tables""" + count = session.query(SqlaTable).filter(is_physical_table).count() + if not count: + return + print(f">> Copy {count:,} physical tables to sl_tables...") + insert_from_select( + NewTable, + select( + [ + # Tables need different uuid than datasets, since they are different + # entities. When INSERT FROM SELECT, we must provide a value for `uuid`, + # otherwise it'd use the default generated on Python side, which + # will cause duplicate values. They will be replaced by `assign_uuids` later. + SqlaTable.uuid, + SqlaTable.id.label("sqlatable_id"), + SqlaTable.created_on, + SqlaTable.changed_on, + SqlaTable.created_by_fk, + SqlaTable.changed_by_fk, + SqlaTable.table_name.label("name"), + SqlaTable.schema, + SqlaTable.database_id, + SqlaTable.is_managed_externally, + SqlaTable.external_url, + ] + ) + # use an inner join to filter out only tables with valid database ids + .select_from( + sa.join(SqlaTable, Database, SqlaTable.database_id == Database.id) + ).where(is_physical_table), + ) + + +def copy_datasets(session: Session) -> None: + """Copy all datasets""" + count = session.query(SqlaTable).count() + if not count: + return + print(f">> Copy {count:,} SqlaTable to sl_datasets...") + insert_from_select( + NewDataset, + select( + [ + SqlaTable.uuid, + SqlaTable.created_on, + SqlaTable.changed_on, + SqlaTable.created_by_fk, + SqlaTable.changed_by_fk, + SqlaTable.database_id, + SqlaTable.table_name.label("name"), + func.coalesce(SqlaTable.sql, SqlaTable.table_name).label("expression"), + is_physical_table.label("is_physical"), + SqlaTable.is_managed_externally, + SqlaTable.external_url, + SqlaTable.extra.label("extra_json"), + ] + ), + ) + + print(" Copy dataset owners...") + insert_from_select( + dataset_user_association_table, + select( + [NewDataset.id.label("dataset_id"), sqlatable_user_table.c.user_id] + ).select_from( + sqlatable_user_table.join( + SqlaTable, SqlaTable.id == sqlatable_user_table.c.table_id + ).join(NewDataset, NewDataset.uuid == SqlaTable.uuid) + ), + ) + + print(" Link physical datasets with tables...") + insert_from_select( + dataset_table_association_table, + select( + [ + NewDataset.id.label("dataset_id"), + NewTable.id.label("table_id"), + ] + ).select_from( + sa.join(SqlaTable, NewTable, NewTable.sqlatable_id == SqlaTable.id).join( + NewDataset, NewDataset.uuid == SqlaTable.uuid + ) + ), + ) + + +def copy_columns(session: Session) -> None: + """Copy columns with active associated SqlTable""" + count = session.query(TableColumn).select_from(active_table_columns).count() + if 
not count: + return + print(f">> Copy {count:,} table columns to sl_columns...") + insert_from_select( + NewColumn, + select( + [ + TableColumn.uuid, + TableColumn.created_on, + TableColumn.changed_on, + TableColumn.created_by_fk, + TableColumn.changed_by_fk, + TableColumn.groupby.label("is_dimensional"), + TableColumn.filterable.label("is_filterable"), + TableColumn.column_name.label("name"), + TableColumn.description, + func.coalesce(TableColumn.expression, TableColumn.column_name).label( + "expression" + ), + sa.literal(False).label("is_aggregation"), + is_physical_column.label("is_physical"), + TableColumn.is_dttm.label("is_temporal"), + func.coalesce(TableColumn.type, UNKNOWN_TYPE).label("type"), + TableColumn.extra.label("extra_json"), + ] + ).select_from(active_table_columns), + ) + + joined_columns_table = active_table_columns.join( + NewColumn, TableColumn.uuid == NewColumn.uuid + ) + print(" Link all columns to sl_datasets...") + insert_from_select( + dataset_column_association_table, + select( + [ + NewDataset.id.label("dataset_id"), + NewColumn.id.label("column_id"), + ], + ).select_from( + joined_columns_table.join(NewDataset, NewDataset.uuid == SqlaTable.uuid) + ), + ) + + +def copy_metrics(session: Session) -> None: + """Copy metrics as virtual columns""" + metrics_count = session.query(SqlMetric).select_from(active_metrics).count() + if not metrics_count: + return + + print(f">> Copy {metrics_count:,} metrics to sl_columns...") + insert_from_select( + NewColumn, + select( + [ + SqlMetric.uuid, + SqlMetric.created_on, + SqlMetric.changed_on, + SqlMetric.created_by_fk, + SqlMetric.changed_by_fk, + SqlMetric.metric_name.label("name"), + SqlMetric.expression, + SqlMetric.description, + sa.literal(UNKNOWN_TYPE).label("type"), + ( + func.coalesce( + sa.func.lower(SqlMetric.metric_type).in_( + ADDITIVE_METRIC_TYPES_LOWER + ), + sa.literal(False), + ).label("is_additive") + ), + sa.literal(True).label("is_aggregation"), + # metrics are by default not filterable + sa.literal(False).label("is_filterable"), + sa.literal(False).label("is_dimensional"), + sa.literal(False).label("is_physical"), + sa.literal(False).label("is_temporal"), + SqlMetric.extra.label("extra_json"), + SqlMetric.warning_text, + ] + ).select_from(active_metrics), + ) + + print(" Link metric columns to datasets...") + insert_from_select( + dataset_column_association_table, + select( + [ + NewDataset.id.label("dataset_id"), + NewColumn.id.label("column_id"), + ], + ).select_from( + active_metrics.join(NewDataset, NewDataset.uuid == SqlaTable.uuid).join( + NewColumn, NewColumn.uuid == SqlMetric.uuid + ) + ), + ) + + +def postprocess_datasets(session: Session) -> None: + """ + Postprocess datasets after insertion to + - Quote table names for physical datasets (if needed) + - Link referenced tables to virtual datasets + """ + total = session.query(SqlaTable).count() + if not total: + return + + offset = 0 + limit = 10000 + + joined_tables = sa.join( + NewDataset, + SqlaTable, + NewDataset.uuid == SqlaTable.uuid, + ).join( + Database, + Database.id == SqlaTable.database_id, + isouter=True, + ) + assert session.query(func.count()).select_from(joined_tables).scalar() == total + + print(f">> Run postprocessing on {total} datasets") + + update_count = 0 + + def print_update_count(): + if SHOW_PROGRESS: + print( + f" Will update {update_count} datasets" + " " * 20, + end="\r", + ) + + while offset < total: + print( + f" Process dataset {offset + 1}~{min(total, offset + limit)}..." 
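+            # (the trailing spaces below pad the message so it fully overwrites
+            # any shorter "Will update N datasets" text that print_update_count
+            # may have left on the same line via end="\r")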
+ + " " * 30 + ) + for ( + database_id, + dataset_id, + expression, + extra, + is_physical, + schema, + sqlalchemy_uri, + ) in session.execute( + select( + [ + NewDataset.database_id, + NewDataset.id.label("dataset_id"), + NewDataset.expression, + SqlaTable.extra, + NewDataset.is_physical, + SqlaTable.schema, + Database.sqlalchemy_uri, + ] + ) + .select_from(joined_tables) + .offset(offset) + .limit(limit) + ): + drivername = (sqlalchemy_uri or "").split("://")[0] + updates = {} + updated = False + if is_physical and drivername: + quoted_expression = get_identifier_quoter(drivername)(expression) + if quoted_expression != expression: + updates["expression"] = quoted_expression + + # add schema name to `dataset.extra_json` so we don't have to join + # tables in order to use datasets + if schema: + try: + extra_json = json.loads(extra) if extra else {} + except json.decoder.JSONDecodeError: + extra_json = {} + extra_json["schema"] = schema + updates["extra_json"] = json.dumps(extra_json) + + if updates: + session.execute( + sa.update(NewDataset) + .where(NewDataset.id == dataset_id) + .values(**updates) + ) + updated = True + + if not is_physical and expression: + table_refrences = extract_table_references( + expression, get_dialect_name(drivername), show_warning=False + ) + found_tables = find_tables( + session, database_id, schema, table_refrences + ) + if found_tables: + op.bulk_insert( + dataset_table_association_table, + [ + {"dataset_id": dataset_id, "table_id": table.id} + for table in found_tables + ], + ) + updated = True + + if updated: + update_count += 1 + print_update_count() + + session.flush() + offset += limit + + if SHOW_PROGRESS: + print("") + + +def postprocess_columns(session: Session) -> None: + """ + At this step, we will + - Add engine specific quotes to `expression` of physical columns + - Tuck some extra metadata to `extra_json` + """ + total = session.query(NewColumn).count() + if not total: + return + + def get_joined_tables(offset, limit): + return ( + sa.join( + session.query(NewColumn) + .offset(offset) + .limit(limit) + .subquery("sl_columns"), + dataset_column_association_table, + dataset_column_association_table.c.column_id == NewColumn.id, + ) + .join( + NewDataset, + NewDataset.id == dataset_column_association_table.c.dataset_id, + ) + .join( + dataset_table_association_table, + # Join tables with physical datasets + and_( + NewDataset.is_physical, + dataset_table_association_table.c.dataset_id == NewDataset.id, + ), + isouter=True, + ) + .join(Database, Database.id == NewDataset.database_id) + .join( + TableColumn, + TableColumn.uuid == NewColumn.uuid, + isouter=True, + ) + .join( + SqlMetric, + SqlMetric.uuid == NewColumn.uuid, + isouter=True, + ) + ) + + offset = 0 + limit = 100000 + + print(f">> Run postprocessing on {total:,} columns") + + update_count = 0 + + def print_update_count(): + if SHOW_PROGRESS: + print( + f" Will update {update_count} columns" + " " * 20, + end="\r", + ) + + while offset < total: + query = ( + select( + # sorted alphabetically + [ + NewColumn.id.label("column_id"), + TableColumn.column_name, + NewColumn.changed_by_fk, + NewColumn.changed_on, + NewColumn.created_on, + NewColumn.description, + SqlMetric.d3format, + NewDataset.external_url, + NewColumn.extra_json, + NewColumn.is_dimensional, + NewColumn.is_filterable, + NewDataset.is_managed_externally, + NewColumn.is_physical, + SqlMetric.metric_type, + TableColumn.python_date_format, + Database.sqlalchemy_uri, + dataset_table_association_table.c.table_id, + func.coalesce( + 
TableColumn.verbose_name, SqlMetric.verbose_name + ).label("verbose_name"), + NewColumn.warning_text, + ] + ) + .select_from(get_joined_tables(offset, limit)) + .where( + # pre-filter to columns with potential updates + or_( + NewColumn.is_physical, + TableColumn.verbose_name.isnot(None), + TableColumn.verbose_name.isnot(None), + SqlMetric.verbose_name.isnot(None), + SqlMetric.d3format.isnot(None), + SqlMetric.metric_type.isnot(None), + ) + ) + ) + + start = offset + 1 + end = min(total, offset + limit) + count = session.query(func.count()).select_from(query).scalar() + print(f" [Column {start:,} to {end:,}] {count:,} may be updated") + + physical_columns = [] + + for ( + # sorted alphabetically + column_id, + column_name, + changed_by_fk, + changed_on, + created_on, + description, + d3format, + external_url, + extra_json, + is_dimensional, + is_filterable, + is_managed_externally, + is_physical, + metric_type, + python_date_format, + sqlalchemy_uri, + table_id, + verbose_name, + warning_text, + ) in session.execute(query): + try: + extra = json.loads(extra_json) if extra_json else {} + except json.decoder.JSONDecodeError: + extra = {} + updated_extra = {**extra} + updates = {} + + if is_managed_externally: + updates["is_managed_externally"] = True + if external_url: + updates["external_url"] = external_url + + # update extra json + for (key, val) in ( + { + "verbose_name": verbose_name, + "python_date_format": python_date_format, + "d3format": d3format, + "metric_type": metric_type, + } + ).items(): + # save the original val, including if it's `false` + if val is not None: + updated_extra[key] = val + + if updated_extra != extra: + updates["extra_json"] = json.dumps(updated_extra) + + # update expression for physical table columns + if is_physical: + if column_name and sqlalchemy_uri: + drivername = sqlalchemy_uri.split("://")[0] + if is_physical and drivername: + quoted_expression = get_identifier_quoter(drivername)( + column_name + ) + if quoted_expression != column_name: + updates["expression"] = quoted_expression + # duplicate physical columns for tables + physical_columns.append( + dict( + created_on=created_on, + changed_on=changed_on, + changed_by_fk=changed_by_fk, + description=description, + expression=updates.get("expression", column_name), + external_url=external_url, + extra_json=updates.get("extra_json", extra_json), + is_aggregation=False, + is_dimensional=is_dimensional, + is_filterable=is_filterable, + is_managed_externally=is_managed_externally, + is_physical=True, + name=column_name, + table_id=table_id, + warning_text=warning_text, + ) + ) + + if updates: + session.execute( + sa.update(NewColumn) + .where(NewColumn.id == column_id) + .values(**updates) + ) + update_count += 1 + print_update_count() + + if physical_columns: + op.bulk_insert(NewColumn.__table__, physical_columns) + + session.flush() + offset += limit + + if SHOW_PROGRESS: + print("") + + print(" Assign table column relations...") + insert_from_select( + table_column_association_table, + select([NewColumn.table_id, NewColumn.id.label("column_id")]) + .select_from(NewColumn) + .where(and_(NewColumn.is_physical, NewColumn.table_id.isnot(None))), + ) + + +new_tables: sa.Table = [ + NewTable.__table__, + NewDataset.__table__, + NewColumn.__table__, + table_column_association_table, + dataset_column_association_table, + dataset_table_association_table, + dataset_user_association_table, +] + + +def reset_postgres_id_sequence(table: str) -> None: + op.execute( + f""" + SELECT setval( + 
pg_get_serial_sequence('{table}', 'id'), + COALESCE(max(id) + 1, 1), + false + ) + FROM {table}; + """ + ) + + +def upgrade() -> None: + bind = op.get_bind() + session: Session = db.Session(bind=bind) + Base.metadata.drop_all(bind=bind, tables=new_tables) + Base.metadata.create_all(bind=bind, tables=new_tables) + + copy_tables(session) + copy_datasets(session) + copy_columns(session) + copy_metrics(session) + session.commit() + + postprocess_columns(session) + session.commit() + + postprocess_datasets(session) + session.commit() + + # Table were created with the same uuids are datasets. They should + # have different uuids as they are different entities. + print(">> Assign new UUIDs to tables...") + assign_uuids(NewTable, session) + + print(">> Drop intermediate columns...") + # These columns are are used during migration, as datasets are independent of tables once created, + # dataset columns also the same to table columns. + with op.batch_alter_table(NewTable.__tablename__) as batch_op: + batch_op.drop_column("sqlatable_id") + with op.batch_alter_table(NewColumn.__tablename__) as batch_op: + batch_op.drop_column("table_id") + + +def downgrade(): + Base.metadata.drop_all(bind=op.get_bind(), tables=new_tables) diff --git a/superset/migrations/versions/b56500de1855_add_uuid_column_to_import_mixin.py b/superset/migrations/versions/b56500de1855_add_uuid_column_to_import_mixin.py index 747ec9fb4f..0872cf5b3b 100644 --- a/superset/migrations/versions/b56500de1855_add_uuid_column_to_import_mixin.py +++ b/superset/migrations/versions/b56500de1855_add_uuid_column_to_import_mixin.py @@ -23,19 +23,17 @@ Create Date: 2020-09-28 17:57:23.128142 """ import json import os -import time from json.decoder import JSONDecodeError from uuid import uuid4 import sqlalchemy as sa from alembic import op -from sqlalchemy.dialects.mysql.base import MySQLDialect -from sqlalchemy.dialects.postgresql.base import PGDialect from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import load_only from sqlalchemy_utils import UUIDType from superset import db +from superset.migrations.shared.utils import assign_uuids from superset.utils import core as utils # revision identifiers, used by Alembic. @@ -78,47 +76,6 @@ models["dashboards"].position_json = sa.Column(utils.MediumText()) default_batch_size = int(os.environ.get("BATCH_SIZE", 200)) -# Add uuids directly using built-in SQL uuid function -add_uuids_by_dialect = { - MySQLDialect: """UPDATE %s SET uuid = UNHEX(REPLACE(CONVERT(UUID() using utf8mb4), '-', ''));""", - PGDialect: """UPDATE %s SET uuid = uuid_in(md5(random()::text || clock_timestamp()::text)::cstring);""", -} - - -def add_uuids(model, table_name, session, batch_size=default_batch_size): - """Populate columns with pre-computed uuids""" - bind = op.get_bind() - objects_query = session.query(model) - count = objects_query.count() - - # silently skip if the table is empty (suitable for db initialization) - if count == 0: - return - - print(f"\nAdding uuids for `{table_name}`...") - start_time = time.time() - - # Use dialect specific native SQL queries if possible - for dialect, sql in add_uuids_by_dialect.items(): - if isinstance(bind.dialect, dialect): - op.execute(sql % table_name) - print(f"Done. 
Assigned {count} uuids in {time.time() - start_time:.3f}s.") - return - - # Othwewise Use Python uuid function - start = 0 - while start < count: - end = min(start + batch_size, count) - for obj, uuid in map(lambda obj: (obj, uuid4()), objects_query[start:end]): - obj.uuid = uuid - session.merge(obj) - session.commit() - if start + batch_size < count: - print(f" uuid assigned to {end} out of {count}\r", end="") - start += batch_size - - print(f"Done. Assigned {count} uuids in {time.time() - start_time:.3f}s.") - def update_position_json(dashboard, session, uuid_map): try: @@ -178,7 +135,7 @@ def upgrade(): ), ) - add_uuids(model, table_name, session) + assign_uuids(model, session) # add uniqueness constraint with op.batch_alter_table(table_name) as batch_op: @@ -203,7 +160,7 @@ def downgrade(): update_dashboards(session, {}) # remove uuid column - for table_name, model in models.items(): + for table_name in models: with op.batch_alter_table(table_name) as batch_op: batch_op.drop_constraint(f"uq_{table_name}_uuid", type_="unique") batch_op.drop_column("uuid") diff --git a/superset/migrations/versions/b8d3a24d9131_new_dataset_models.py b/superset/migrations/versions/b8d3a24d9131_new_dataset_models.py index 8728e9adb7..e69d1606e3 100644 --- a/superset/migrations/versions/b8d3a24d9131_new_dataset_models.py +++ b/superset/migrations/versions/b8d3a24d9131_new_dataset_models.py @@ -23,619 +23,23 @@ Revises: 5afbb1a5849b Create Date: 2021-11-11 16:41:53.266965 """ - -import json -from datetime import date, datetime, time, timedelta -from typing import Callable, List, Optional, Set -from uuid import uuid4 - -import sqlalchemy as sa -from alembic import op -from sqlalchemy import and_, inspect, or_ -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import backref, relationship, Session -from sqlalchemy.schema import UniqueConstraint -from sqlalchemy.sql.type_api import TypeEngine -from sqlalchemy_utils import UUIDType - -from superset import app, db -from superset.connectors.sqla.models import ADDITIVE_METRIC_TYPES -from superset.databases.utils import make_url_safe -from superset.extensions import encrypted_field_factory -from superset.migrations.shared.utils import extract_table_references -from superset.models.core import Database as OriginalDatabase -from superset.sql_parse import Table - # revision identifiers, used by Alembic. revision = "b8d3a24d9131" down_revision = "5afbb1a5849b" -Base = declarative_base() -custom_password_store = app.config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"] -DB_CONNECTION_MUTATOR = app.config["DB_CONNECTION_MUTATOR"] +# ===================== Notice ======================== +# +# Migrations made in this revision has been moved to `new_dataset_models_take_2` +# to fix performance issues as well as a couple of shortcomings in the original +# design. 
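+#
+# For context, the replacement revision relies on set-based INSERT ... SELECT
+# statements instead of per-row ORM inserts. A rough sketch of the pattern
+# (illustrative only, not the exact statements it issues):
+#
+#     INSERT INTO sl_datasets (uuid, name, expression, is_physical)
+#     SELECT uuid, table_name, COALESCE(sql, table_name), sql IS NULL OR sql = ''
+#     FROM tables;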
+# +# ====================================================== -class Database(Base): - __tablename__ = "dbs" - __table_args__ = (UniqueConstraint("database_name"),) - - id = sa.Column(sa.Integer, primary_key=True) - database_name = sa.Column(sa.String(250), unique=True, nullable=False) - sqlalchemy_uri = sa.Column(sa.String(1024), nullable=False) - password = sa.Column(encrypted_field_factory.create(sa.String(1024))) - impersonate_user = sa.Column(sa.Boolean, default=False) - encrypted_extra = sa.Column(encrypted_field_factory.create(sa.Text), nullable=True) - extra = sa.Column( - sa.Text, - default=json.dumps( - dict( - metadata_params={}, - engine_params={}, - metadata_cache_timeout={}, - schemas_allowed_for_file_upload=[], - ) - ), - ) - server_cert = sa.Column(encrypted_field_factory.create(sa.Text), nullable=True) - - -class TableColumn(Base): - - __tablename__ = "table_columns" - __table_args__ = (UniqueConstraint("table_id", "column_name"),) - - id = sa.Column(sa.Integer, primary_key=True) - table_id = sa.Column(sa.Integer, sa.ForeignKey("tables.id")) - is_active = sa.Column(sa.Boolean, default=True) - extra = sa.Column(sa.Text) - column_name = sa.Column(sa.String(255), nullable=False) - type = sa.Column(sa.String(32)) - expression = sa.Column(sa.Text) - description = sa.Column(sa.Text) - is_dttm = sa.Column(sa.Boolean, default=False) - filterable = sa.Column(sa.Boolean, default=True) - groupby = sa.Column(sa.Boolean, default=True) - verbose_name = sa.Column(sa.String(1024)) - python_date_format = sa.Column(sa.String(255)) - - -class SqlMetric(Base): - - __tablename__ = "sql_metrics" - __table_args__ = (UniqueConstraint("table_id", "metric_name"),) - - id = sa.Column(sa.Integer, primary_key=True) - table_id = sa.Column(sa.Integer, sa.ForeignKey("tables.id")) - extra = sa.Column(sa.Text) - metric_type = sa.Column(sa.String(32)) - metric_name = sa.Column(sa.String(255), nullable=False) - expression = sa.Column(sa.Text, nullable=False) - warning_text = sa.Column(sa.Text) - description = sa.Column(sa.Text) - d3format = sa.Column(sa.String(128)) - verbose_name = sa.Column(sa.String(1024)) - - -class SqlaTable(Base): - - __tablename__ = "tables" - __table_args__ = (UniqueConstraint("database_id", "schema", "table_name"),) - - def fetch_columns_and_metrics(self, session: Session) -> None: - self.columns = session.query(TableColumn).filter( - TableColumn.table_id == self.id - ) - self.metrics = session.query(SqlMetric).filter(TableColumn.table_id == self.id) - - id = sa.Column(sa.Integer, primary_key=True) - columns: List[TableColumn] = [] - column_class = TableColumn - metrics: List[SqlMetric] = [] - metric_class = SqlMetric - - database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False) - database: Database = relationship( - "Database", - backref=backref("tables", cascade="all, delete-orphan"), - foreign_keys=[database_id], - ) - schema = sa.Column(sa.String(255)) - table_name = sa.Column(sa.String(250), nullable=False) - sql = sa.Column(sa.Text) - is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False) - external_url = sa.Column(sa.Text, nullable=True) - - -table_column_association_table = sa.Table( - "sl_table_columns", - Base.metadata, - sa.Column("table_id", sa.ForeignKey("sl_tables.id")), - sa.Column("column_id", sa.ForeignKey("sl_columns.id")), -) - -dataset_column_association_table = sa.Table( - "sl_dataset_columns", - Base.metadata, - sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id")), - sa.Column("column_id", 
sa.ForeignKey("sl_columns.id")), -) - -dataset_table_association_table = sa.Table( - "sl_dataset_tables", - Base.metadata, - sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id")), - sa.Column("table_id", sa.ForeignKey("sl_tables.id")), -) - - -class NewColumn(Base): - - __tablename__ = "sl_columns" - - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Text) - type = sa.Column(sa.Text) - expression = sa.Column(sa.Text) - is_physical = sa.Column(sa.Boolean, default=True) - description = sa.Column(sa.Text) - warning_text = sa.Column(sa.Text) - is_temporal = sa.Column(sa.Boolean, default=False) - is_aggregation = sa.Column(sa.Boolean, default=False) - is_additive = sa.Column(sa.Boolean, default=False) - is_spatial = sa.Column(sa.Boolean, default=False) - is_partition = sa.Column(sa.Boolean, default=False) - is_increase_desired = sa.Column(sa.Boolean, default=True) - is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False) - external_url = sa.Column(sa.Text, nullable=True) - extra_json = sa.Column(sa.Text, default="{}") - - -class NewTable(Base): - - __tablename__ = "sl_tables" - __table_args__ = (UniqueConstraint("database_id", "catalog", "schema", "name"),) - - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Text) - schema = sa.Column(sa.Text) - catalog = sa.Column(sa.Text) - database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False) - database: Database = relationship( - "Database", - backref=backref("new_tables", cascade="all, delete-orphan"), - foreign_keys=[database_id], - ) - columns: List[NewColumn] = relationship( - "NewColumn", secondary=table_column_association_table, cascade="all, delete" - ) - is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False) - external_url = sa.Column(sa.Text, nullable=True) - - -class NewDataset(Base): - - __tablename__ = "sl_datasets" - - id = sa.Column(sa.Integer, primary_key=True) - sqlatable_id = sa.Column(sa.Integer, nullable=True, unique=True) - name = sa.Column(sa.Text) - expression = sa.Column(sa.Text) - tables: List[NewTable] = relationship( - "NewTable", secondary=dataset_table_association_table - ) - columns: List[NewColumn] = relationship( - "NewColumn", secondary=dataset_column_association_table, cascade="all, delete" - ) - is_physical = sa.Column(sa.Boolean, default=False) - is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False) - external_url = sa.Column(sa.Text, nullable=True) - - -TEMPORAL_TYPES = {date, datetime, time, timedelta} - - -def is_column_type_temporal(column_type: TypeEngine) -> bool: - try: - return column_type.python_type in TEMPORAL_TYPES - except NotImplementedError: - return False - - -def load_or_create_tables( - session: Session, - database_id: int, - default_schema: Optional[str], - tables: Set[Table], - conditional_quote: Callable[[str], str], -) -> List[NewTable]: - """ - Load or create new table model instances. 
- """ - if not tables: - return [] - - # set the default schema in tables that don't have it - if default_schema: - tables = list(tables) - for i, table in enumerate(tables): - if table.schema is None: - tables[i] = Table(table.table, default_schema, table.catalog) - - # load existing tables - predicate = or_( - *[ - and_( - NewTable.database_id == database_id, - NewTable.schema == table.schema, - NewTable.name == table.table, - ) - for table in tables - ] - ) - new_tables = session.query(NewTable).filter(predicate).all() - - # use original database model to get the engine - engine = ( - session.query(OriginalDatabase) - .filter_by(id=database_id) - .one() - .get_sqla_engine(default_schema) - ) - inspector = inspect(engine) - - # add missing tables - existing = {(table.schema, table.name) for table in new_tables} - for table in tables: - if (table.schema, table.table) not in existing: - column_metadata = inspector.get_columns(table.table, schema=table.schema) - columns = [ - NewColumn( - name=column["name"], - type=str(column["type"]), - expression=conditional_quote(column["name"]), - is_temporal=is_column_type_temporal(column["type"]), - is_aggregation=False, - is_physical=True, - is_spatial=False, - is_partition=False, - is_increase_desired=True, - ) - for column in column_metadata - ] - new_tables.append( - NewTable( - name=table.table, - schema=table.schema, - catalog=None, - database_id=database_id, - columns=columns, - ) - ) - existing.add((table.schema, table.table)) - - return new_tables - - -def after_insert(target: SqlaTable) -> None: # pylint: disable=too-many-locals - """ - Copy old datasets to the new models. - """ - session = inspect(target).session - - # get DB-specific conditional quoter for expressions that point to columns or - # table names - database = ( - target.database - or session.query(Database).filter_by(id=target.database_id).first() - ) - if not database: - return - url = make_url_safe(database.sqlalchemy_uri) - dialect_class = url.get_dialect() - conditional_quote = dialect_class().identifier_preparer.quote - - # create columns - columns = [] - for column in target.columns: - # ``is_active`` might be ``None`` at this point, but it defaults to ``True``. 
- if column.is_active is False: - continue - - try: - extra_json = json.loads(column.extra or "{}") - except json.decoder.JSONDecodeError: - extra_json = {} - for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}: - value = getattr(column, attr) - if value: - extra_json[attr] = value - - columns.append( - NewColumn( - name=column.column_name, - type=column.type or "Unknown", - expression=column.expression or conditional_quote(column.column_name), - description=column.description, - is_temporal=column.is_dttm, - is_aggregation=False, - is_physical=column.expression is None or column.expression == "", - is_spatial=False, - is_partition=False, - is_increase_desired=True, - extra_json=json.dumps(extra_json) if extra_json else None, - is_managed_externally=target.is_managed_externally, - external_url=target.external_url, - ), - ) - - # create metrics - for metric in target.metrics: - try: - extra_json = json.loads(metric.extra or "{}") - except json.decoder.JSONDecodeError: - extra_json = {} - for attr in {"verbose_name", "metric_type", "d3format"}: - value = getattr(metric, attr) - if value: - extra_json[attr] = value - - is_additive = ( - metric.metric_type and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES - ) - - columns.append( - NewColumn( - name=metric.metric_name, - type="Unknown", # figuring this out would require a type inferrer - expression=metric.expression, - warning_text=metric.warning_text, - description=metric.description, - is_aggregation=True, - is_additive=is_additive, - is_physical=False, - is_spatial=False, - is_partition=False, - is_increase_desired=True, - extra_json=json.dumps(extra_json) if extra_json else None, - is_managed_externally=target.is_managed_externally, - external_url=target.external_url, - ), - ) - - # physical dataset - if not target.sql: - physical_columns = [column for column in columns if column.is_physical] - - # create table - table = NewTable( - name=target.table_name, - schema=target.schema, - catalog=None, # currently not supported - database_id=target.database_id, - columns=physical_columns, - is_managed_externally=target.is_managed_externally, - external_url=target.external_url, - ) - tables = [table] - - # virtual dataset - else: - # mark all columns as virtual (not physical) - for column in columns: - column.is_physical = False - - # find referenced tables - referenced_tables = extract_table_references(target.sql, dialect_class.name) - tables = load_or_create_tables( - session, - target.database_id, - target.schema, - referenced_tables, - conditional_quote, - ) - - # create the new dataset - dataset = NewDataset( - sqlatable_id=target.id, - name=target.table_name, - expression=target.sql or conditional_quote(target.table_name), - tables=tables, - columns=columns, - is_physical=not target.sql, - is_managed_externally=target.is_managed_externally, - external_url=target.external_url, - ) - session.add(dataset) - - -def upgrade(): - # Create tables for the new models. 
- op.create_table( - "sl_columns", - # AuditMixinNullable - sa.Column("created_on", sa.DateTime(), nullable=True), - sa.Column("changed_on", sa.DateTime(), nullable=True), - sa.Column("created_by_fk", sa.Integer(), nullable=True), - sa.Column("changed_by_fk", sa.Integer(), nullable=True), - # ExtraJSONMixin - sa.Column("extra_json", sa.Text(), nullable=True), - # ImportExportMixin - sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4), - # Column - sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False), - sa.Column("name", sa.TEXT(), nullable=False), - sa.Column("type", sa.TEXT(), nullable=False), - sa.Column("expression", sa.TEXT(), nullable=False), - sa.Column( - "is_physical", - sa.BOOLEAN(), - nullable=False, - default=True, - ), - sa.Column("description", sa.TEXT(), nullable=True), - sa.Column("warning_text", sa.TEXT(), nullable=True), - sa.Column("unit", sa.TEXT(), nullable=True), - sa.Column("is_temporal", sa.BOOLEAN(), nullable=False), - sa.Column( - "is_spatial", - sa.BOOLEAN(), - nullable=False, - default=False, - ), - sa.Column( - "is_partition", - sa.BOOLEAN(), - nullable=False, - default=False, - ), - sa.Column( - "is_aggregation", - sa.BOOLEAN(), - nullable=False, - default=False, - ), - sa.Column( - "is_additive", - sa.BOOLEAN(), - nullable=False, - default=False, - ), - sa.Column( - "is_increase_desired", - sa.BOOLEAN(), - nullable=False, - default=True, - ), - sa.Column( - "is_managed_externally", - sa.Boolean(), - nullable=False, - server_default=sa.false(), - ), - sa.Column("external_url", sa.Text(), nullable=True), - sa.PrimaryKeyConstraint("id"), - ) - with op.batch_alter_table("sl_columns") as batch_op: - batch_op.create_unique_constraint("uq_sl_columns_uuid", ["uuid"]) - - op.create_table( - "sl_tables", - # AuditMixinNullable - sa.Column("created_on", sa.DateTime(), nullable=True), - sa.Column("changed_on", sa.DateTime(), nullable=True), - sa.Column("created_by_fk", sa.Integer(), nullable=True), - sa.Column("changed_by_fk", sa.Integer(), nullable=True), - # ExtraJSONMixin - sa.Column("extra_json", sa.Text(), nullable=True), - # ImportExportMixin - sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4), - # Table - sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False), - sa.Column("database_id", sa.INTEGER(), autoincrement=False, nullable=False), - sa.Column("catalog", sa.TEXT(), nullable=True), - sa.Column("schema", sa.TEXT(), nullable=True), - sa.Column("name", sa.TEXT(), nullable=False), - sa.Column( - "is_managed_externally", - sa.Boolean(), - nullable=False, - server_default=sa.false(), - ), - sa.Column("external_url", sa.Text(), nullable=True), - sa.ForeignKeyConstraint(["database_id"], ["dbs.id"], name="sl_tables_ibfk_1"), - sa.PrimaryKeyConstraint("id"), - ) - with op.batch_alter_table("sl_tables") as batch_op: - batch_op.create_unique_constraint("uq_sl_tables_uuid", ["uuid"]) - - op.create_table( - "sl_table_columns", - sa.Column("table_id", sa.INTEGER(), autoincrement=False, nullable=False), - sa.Column("column_id", sa.INTEGER(), autoincrement=False, nullable=False), - sa.ForeignKeyConstraint( - ["column_id"], ["sl_columns.id"], name="sl_table_columns_ibfk_2" - ), - sa.ForeignKeyConstraint( - ["table_id"], ["sl_tables.id"], name="sl_table_columns_ibfk_1" - ), - ) - - op.create_table( - "sl_datasets", - # AuditMixinNullable - sa.Column("created_on", sa.DateTime(), nullable=True), - sa.Column("changed_on", sa.DateTime(), nullable=True), - sa.Column("created_by_fk", sa.Integer(), nullable=True), 
- sa.Column("changed_by_fk", sa.Integer(), nullable=True), - # ExtraJSONMixin - sa.Column("extra_json", sa.Text(), nullable=True), - # ImportExportMixin - sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4), - # Dataset - sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False), - sa.Column("sqlatable_id", sa.INTEGER(), nullable=True), - sa.Column("name", sa.TEXT(), nullable=False), - sa.Column("expression", sa.TEXT(), nullable=False), - sa.Column( - "is_physical", - sa.BOOLEAN(), - nullable=False, - default=False, - ), - sa.Column( - "is_managed_externally", - sa.Boolean(), - nullable=False, - server_default=sa.false(), - ), - sa.Column("external_url", sa.Text(), nullable=True), - sa.PrimaryKeyConstraint("id"), - ) - with op.batch_alter_table("sl_datasets") as batch_op: - batch_op.create_unique_constraint("uq_sl_datasets_uuid", ["uuid"]) - batch_op.create_unique_constraint( - "uq_sl_datasets_sqlatable_id", ["sqlatable_id"] - ) - - op.create_table( - "sl_dataset_columns", - sa.Column("dataset_id", sa.INTEGER(), autoincrement=False, nullable=False), - sa.Column("column_id", sa.INTEGER(), autoincrement=False, nullable=False), - sa.ForeignKeyConstraint( - ["column_id"], ["sl_columns.id"], name="sl_dataset_columns_ibfk_2" - ), - sa.ForeignKeyConstraint( - ["dataset_id"], ["sl_datasets.id"], name="sl_dataset_columns_ibfk_1" - ), - ) - - op.create_table( - "sl_dataset_tables", - sa.Column("dataset_id", sa.INTEGER(), autoincrement=False, nullable=False), - sa.Column("table_id", sa.INTEGER(), autoincrement=False, nullable=False), - sa.ForeignKeyConstraint( - ["dataset_id"], ["sl_datasets.id"], name="sl_dataset_tables_ibfk_1" - ), - sa.ForeignKeyConstraint( - ["table_id"], ["sl_tables.id"], name="sl_dataset_tables_ibfk_2" - ), - ) - - # migrate existing datasets to the new models - bind = op.get_bind() - session = db.Session(bind=bind) # pylint: disable=no-member - - datasets = session.query(SqlaTable).all() - for dataset in datasets: - dataset.fetch_columns_and_metrics(session) - after_insert(target=dataset) +def upgrade() -> None: + pass def downgrade(): - op.drop_table("sl_dataset_columns") - op.drop_table("sl_dataset_tables") - op.drop_table("sl_datasets") - op.drop_table("sl_table_columns") - op.drop_table("sl_tables") - op.drop_table("sl_columns") + pass diff --git a/superset/migrations/versions/c501b7c653a3_add_missing_uuid_column.py b/superset/migrations/versions/c501b7c653a3_add_missing_uuid_column.py index 4cfbc104c0..786b41a1c7 100644 --- a/superset/migrations/versions/c501b7c653a3_add_missing_uuid_column.py +++ b/superset/migrations/versions/c501b7c653a3_add_missing_uuid_column.py @@ -38,7 +38,7 @@ from sqlalchemy_utils import UUIDType from superset import db from superset.migrations.versions.b56500de1855_add_uuid_column_to_import_mixin import ( - add_uuids, + assign_uuids, models, update_dashboards, ) @@ -73,7 +73,7 @@ def upgrade(): default=uuid4, ), ) - add_uuids(model, table_name, session) + assign_uuids(model, session) # add uniqueness constraint with op.batch_alter_table(table_name) as batch_op: diff --git a/superset/migrations/versions/f1410ed7ec95_migrate_native_filters_to_new_schema.py b/superset/migrations/versions/f1410ed7ec95_migrate_native_filters_to_new_schema.py index 630a7b1062..46b8e5f958 100644 --- a/superset/migrations/versions/f1410ed7ec95_migrate_native_filters_to_new_schema.py +++ b/superset/migrations/versions/f1410ed7ec95_migrate_native_filters_to_new_schema.py @@ -71,7 +71,7 @@ def downgrade_filters(native_filters: 
Iterable[Dict[str, Any]]) -> int: filter_state = default_data_mask.get("filterState") if filter_state is not None: changed_filters += 1 - value = filter_state["value"] + value = filter_state.get("value") native_filter["defaultValue"] = value return changed_filters diff --git a/superset/models/core.py b/superset/models/core.py index daa0fb9a7d..c2052749ad 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -408,12 +408,14 @@ class Database( except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) + @property + def quote_identifier(self) -> Callable[[str], str]: + """Add quotes to potential identifiter expressions if needed""" + return self.get_dialect().identifier_preparer.quote + def get_reserved_words(self) -> Set[str]: return self.get_dialect().preparer.reserved_words - def get_quoter(self) -> Callable[[str, Any], str]: - return self.get_dialect().identifier_preparer.quote - def get_df( # pylint: disable=too-many-locals self, sql: str, diff --git a/superset/models/helpers.py b/superset/models/helpers.py index baa0566c01..3b4e99159f 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -477,7 +477,7 @@ class ExtraJSONMixin: @property def extra(self) -> Dict[str, Any]: try: - return json.loads(self.extra_json) + return json.loads(self.extra_json) if self.extra_json else {} except (TypeError, JSONDecodeError) as exc: logger.error( "Unable to load an extra json: %r. Leaving empty.", exc, exc_info=True @@ -522,18 +522,23 @@ class CertificationMixin: def clone_model( - target: Model, ignore: Optional[List[str]] = None, **kwargs: Any + target: Model, + ignore: Optional[List[str]] = None, + keep_relations: Optional[List[str]] = None, + **kwargs: Any, ) -> Model: """ - Clone a SQLAlchemy model. + Clone a SQLAlchemy model. By default will only clone naive column attributes. + To include relationship attributes, use `keep_relations`. 
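+
+    A hypothetical usage sketch (the model instance and attribute names are
+    illustrative only):
+
+        copy = clone_model(dataset, ignore=["uuid"], keep_relations=["columns"])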
""" ignore = ignore or [] table = target.__table__ + primary_keys = table.primary_key.columns.keys() data = { attr: getattr(target, attr) - for attr in table.columns.keys() - if attr not in table.primary_key.columns.keys() and attr not in ignore + for attr in list(table.columns.keys()) + (keep_relations or []) + if attr not in primary_keys and attr not in ignore } data.update(kwargs) diff --git a/superset/sql_lab.py b/superset/sql_lab.py index d3e08de92a..567ff0d13d 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -186,7 +186,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals apply_ctas: bool = False, ) -> SupersetResultSet: """Executes a single SQL statement""" - database = query.database + database: Database = query.database db_engine_spec = database.db_engine_spec parsed_query = ParsedQuery(sql_statement) sql = parsed_query.stripped() diff --git a/superset/sql_parse.py b/superset/sql_parse.py index e3b2e7c196..d377986f56 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -18,7 +18,7 @@ import logging import re from dataclasses import dataclass from enum import Enum -from typing import cast, List, Optional, Set, Tuple +from typing import Any, cast, Iterator, List, Optional, Set, Tuple from urllib import parse import sqlparse @@ -47,10 +47,16 @@ from sqlparse.utils import imt from superset.exceptions import QueryClauseValidationException +try: + from sqloxide import parse_sql as sqloxide_parse +except: # pylint: disable=bare-except + sqloxide_parse = None + RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"} ON_KEYWORD = "ON" PRECEDES_TABLE_NAME = {"FROM", "JOIN", "DESCRIBE", "WITH", "LEFT JOIN", "RIGHT JOIN"} CTE_PREFIX = "CTE__" + logger = logging.getLogger(__name__) @@ -176,6 +182,9 @@ class Table: if part ) + def __eq__(self, __o: object) -> bool: + return str(self) == str(__o) + class ParsedQuery: def __init__(self, sql_statement: str, strip_comments: bool = False): @@ -698,3 +707,75 @@ def insert_rls( ) return token_list + + +# mapping between sqloxide and SQLAlchemy dialects +SQLOXITE_DIALECTS = { + "ansi": {"trino", "trinonative", "presto"}, + "hive": {"hive", "databricks"}, + "ms": {"mssql"}, + "mysql": {"mysql"}, + "postgres": { + "cockroachdb", + "hana", + "netezza", + "postgres", + "postgresql", + "redshift", + "vertica", + }, + "snowflake": {"snowflake"}, + "sqlite": {"sqlite", "gsheets", "shillelagh"}, + "clickhouse": {"clickhouse"}, +} + +RE_JINJA_VAR = re.compile(r"\{\{[^\{\}]+\}\}") +RE_JINJA_BLOCK = re.compile(r"\{[%#][^\{\}%#]+[%#]\}") + + +def extract_table_references( + sql_text: str, sqla_dialect: str, show_warning: bool = True +) -> Set["Table"]: + """ + Return all the dependencies from a SQL sql_text. + """ + dialect = "generic" + tree = None + + if sqloxide_parse: + for dialect, sqla_dialects in SQLOXITE_DIALECTS.items(): + if sqla_dialect in sqla_dialects: + break + sql_text = RE_JINJA_BLOCK.sub(" ", sql_text) + sql_text = RE_JINJA_VAR.sub("abc", sql_text) + try: + tree = sqloxide_parse(sql_text, dialect=dialect) + except Exception as ex: # pylint: disable=broad-except + if show_warning: + logger.warning( + "\nUnable to parse query with sqloxide:\n%s\n%s", sql_text, ex + ) + + # fallback to sqlparse + if not tree: + parsed = ParsedQuery(sql_text) + return parsed.tables + + def find_nodes_by_key(element: Any, target: str) -> Iterator[Any]: + """ + Find all nodes in a SQL tree matching a given key. 
+ """ + if isinstance(element, list): + for child in element: + yield from find_nodes_by_key(child, target) + elif isinstance(element, dict): + for key, value in element.items(): + if key == target: + yield value + else: + yield from find_nodes_by_key(value, target) + + return { + Table(*[part["value"] for part in table["name"][::-1]]) + for table in find_nodes_by_key(tree, "Table") + } diff --git a/superset/tables/models.py b/superset/tables/models.py index e2489445c6..9a0c07fdcf 100644 --- a/superset/tables/models.py +++ b/superset/tables/models.py @@ -24,26 +24,41 @@ addition to a table, new models for columns, metrics, and datasets were also int These models are not fully implemented, and shouldn't be used yet. """ -from typing import List +from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING import sqlalchemy as sa from flask_appbuilder import Model -from sqlalchemy.orm import backref, relationship +from sqlalchemy import inspect +from sqlalchemy.orm import backref, relationship, Session from sqlalchemy.schema import UniqueConstraint +from sqlalchemy.sql import and_, or_ from superset.columns.models import Column +from superset.connectors.sqla.utils import get_physical_table_metadata from superset.models.core import Database from superset.models.helpers import ( AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, ) +from superset.sql_parse import Table as TableName -association_table = sa.Table( +if TYPE_CHECKING: + from superset.datasets.models import Dataset + +table_column_association_table = sa.Table( "sl_table_columns", Model.metadata, # pylint: disable=no-member - sa.Column("table_id", sa.ForeignKey("sl_tables.id")), - sa.Column("column_id", sa.ForeignKey("sl_columns.id")), + sa.Column( + "table_id", + sa.ForeignKey("sl_tables.id", ondelete="cascade"), + primary_key=True, + ), + sa.Column( + "column_id", + sa.ForeignKey("sl_columns.id", ondelete="cascade"), + primary_key=True, + ), ) @@ -61,7 +76,6 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): __table_args__ = (UniqueConstraint("database_id", "catalog", "schema", "name"),) id = sa.Column(sa.Integer, primary_key=True) - database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False) database: Database = relationship( "Database", @@ -70,6 +84,19 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): backref=backref("new_tables", cascade="all, delete-orphan"), foreign_keys=[database_id], ) + # The relationship between datasets and columns is 1:n, but we use a + # many-to-many association table to avoid adding two mutually exclusive + # columns(dataset_id and table_id) to Column + columns: List[Column] = relationship( + "Column", + secondary=table_column_association_table, + cascade="all, delete-orphan", + single_parent=True, + # backref is needed for session to skip detaching `dataset` if only `column` + # is loaded. + backref="tables", + ) + datasets: List["Dataset"] # will be populated by Dataset.tables backref # We use ``sa.Text`` for these attributes because (1) in modern databases the # performance is the same as ``VARCHAR``[1] and (2) because some table names can be @@ -80,13 +107,96 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): schema = sa.Column(sa.Text) name = sa.Column(sa.Text) - # The relationship between tables and columns is 1:n, but we use a many-to-many - # association to differentiate between the relationship between datasets and - # columns. 
- columns: List[Column] = relationship( - "Column", secondary=association_table, cascade="all, delete" - ) - # Column is managed externally and should be read-only inside Superset is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False) external_url = sa.Column(sa.Text, nullable=True) + + @property + def fullname(self) -> str: + return str(TableName(table=self.name, schema=self.schema, catalog=self.catalog)) + + def __repr__(self) -> str: + return f"" + + def sync_columns(self) -> None: + """Sync table columns with the database. Keep metadata for existing columns""" + try: + column_metadata = get_physical_table_metadata( + self.database, self.name, self.schema + ) + except Exception: # pylint: disable=broad-except + column_metadata = [] + + existing_columns = {column.name: column for column in self.columns} + quote_identifier = self.database.quote_identifier + + def update_or_create_column(column_meta: Dict[str, Any]) -> Column: + column_name: str = column_meta["name"] + if column_name in existing_columns: + column = existing_columns[column_name] + else: + column = Column(name=column_name) + column.type = column_meta["type"] + column.is_temporal = column_meta["is_dttm"] + column.expression = quote_identifier(column_name) + column.is_aggregation = False + column.is_physical = True + column.is_spatial = False + column.is_partition = False # TODO: update with accurate is_partition + return column + + self.columns = [update_or_create_column(col) for col in column_metadata] + + @staticmethod + def bulk_load_or_create( + database: Database, + table_names: Iterable[TableName], + default_schema: Optional[str] = None, + sync_columns: Optional[bool] = False, + default_props: Optional[Dict[str, Any]] = None, + ) -> List["Table"]: + """ + Load or create multiple Table instances. + """ + if not table_names: + return [] + + if not database.id: + raise Exception("Database must be already saved to metastore") + + default_props = default_props or {} + session: Session = inspect(database).session + # load existing tables + predicate = or_( + *[ + and_( + Table.database_id == database.id, + Table.schema == (table.schema or default_schema), + Table.name == table.table, + ) + for table in table_names + ] + ) + all_tables = session.query(Table).filter(predicate).order_by(Table.id).all() + + # add missing tables and pull its columns + existing = {(table.schema, table.name) for table in all_tables} + for table in table_names: + schema = table.schema or default_schema + name = table.table + if (schema, name) not in existing: + new_table = Table( + database=database, + database_id=database.id, + name=name, + schema=schema, + catalog=None, + **default_props, + ) + if sync_columns: + new_table.sync_columns() + all_tables.append(new_table) + existing.add((schema, name)) + session.add(new_table) + + return all_tables diff --git a/tests/integration_tests/commands_test.py b/tests/integration_tests/commands_test.py index 5ff18b02a9..77fbad05f3 100644 --- a/tests/integration_tests/commands_test.py +++ b/tests/integration_tests/commands_test.py @@ -16,11 +16,11 @@ # under the License. 
import copy import json -from unittest.mock import patch import yaml +from flask import g -from superset import db, security_manager +from superset import db from superset.commands.exceptions import CommandInvalidError from superset.commands.importers.v1.assets import ImportAssetsCommand from superset.commands.importers.v1.utils import is_valid_config @@ -58,10 +58,13 @@ class TestImportersV1Utils(SupersetTestCase): class TestImportAssetsCommand(SupersetTestCase): - @patch("superset.dashboards.commands.importers.v1.utils.g") - def test_import_assets(self, mock_g): + def setUp(self): + user = self.get_user("admin") + self.user = user + setattr(g, "user", user) + + def test_import_assets(self): """Test that we can import multiple assets""" - mock_g.user = security_manager.find_user("admin") contents = { "metadata.yaml": yaml.safe_dump(metadata_config), "databases/imported_database.yaml": yaml.safe_dump(database_config), @@ -141,7 +144,7 @@ class TestImportAssetsCommand(SupersetTestCase): database = dataset.database assert str(database.uuid) == database_config["uuid"] - assert dashboard.owners == [mock_g.user] + assert dashboard.owners == [self.user] dashboard.owners = [] chart.owners = [] @@ -153,11 +156,8 @@ class TestImportAssetsCommand(SupersetTestCase): db.session.delete(database) db.session.commit() - @patch("superset.dashboards.commands.importers.v1.utils.g") - def test_import_v1_dashboard_overwrite(self, mock_g): + def test_import_v1_dashboard_overwrite(self): """Test that assets can be overwritten""" - mock_g.user = security_manager.find_user("admin") - contents = { "metadata.yaml": yaml.safe_dump(metadata_config), "databases/imported_database.yaml": yaml.safe_dump(database_config), diff --git a/tests/integration_tests/fixtures/world_bank_dashboard.py b/tests/integration_tests/fixtures/world_bank_dashboard.py index 1ac1706a9d..e767036b7d 100644 --- a/tests/integration_tests/fixtures/world_bank_dashboard.py +++ b/tests/integration_tests/fixtures/world_bank_dashboard.py @@ -111,11 +111,10 @@ def _commit_slices(slices: List[Slice]): def _create_world_bank_dashboard(table: SqlaTable, slices: List[Slice]) -> Dashboard: + from superset.examples.helpers import update_slice_ids from superset.examples.world_bank import dashboard_positions pos = dashboard_positions - from superset.examples.helpers import update_slice_ids - update_slice_ids(pos, slices) table.fetch_metadata() diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index bbe062e509..d23b95f53c 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -455,7 +455,8 @@ class TestDatabaseModel(SupersetTestCase): # make sure the columns have been mapped properly assert len(table.columns) == 4 - table.fetch_metadata() + table.fetch_metadata(commit=False) + # assert that the removed column has been dropped and # the physical and calculated columns are present assert {col.column_name for col in table.columns} == { @@ -473,6 +474,8 @@ class TestDatabaseModel(SupersetTestCase): assert VIRTUAL_TABLE_STRING_TYPES[backend].match(cols["mycase"].type) assert cols["expr"].expression == "case when 1 then 1 else 0 end" + db.session.delete(table) + @patch("superset.models.core.Database.db_engine_spec", BigQueryEngineSpec) def test_labels_expected_on_mutated_query(self): query_obj = { diff --git a/tests/integration_tests/utils_tests.py b/tests/integration_tests/utils_tests.py index 5add2c5f6e..7e8aede6a3 100644 --- 
a/tests/integration_tests/utils_tests.py +++ b/tests/integration_tests/utils_tests.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. # isort:skip_file -import unittest import uuid from datetime import date, datetime, time, timedelta from decimal import Decimal diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index 4987aaf0e0..86fb0127b8 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -17,7 +17,7 @@ # pylint: disable=redefined-outer-name, import-outside-toplevel import importlib -from typing import Any, Iterator +from typing import Any, Callable, Iterator import pytest from pytest_mock import MockFixture @@ -31,25 +31,33 @@ from superset.initialization import SupersetAppInitializer @pytest.fixture -def session(mocker: MockFixture) -> Iterator[Session]: +def get_session(mocker: MockFixture) -> Callable[[], Session]: """ Create an in-memory SQLite session to test models. """ engine = create_engine("sqlite://") - Session_ = sessionmaker(bind=engine) # pylint: disable=invalid-name - in_memory_session = Session_() - # flask calls session.remove() - in_memory_session.remove = lambda: None + def get_session(): + Session_ = sessionmaker(bind=engine) # pylint: disable=invalid-name + in_memory_session = Session_() - # patch session - mocker.patch( - "superset.security.SupersetSecurityManager.get_session", - return_value=in_memory_session, - ) - mocker.patch("superset.db.session", in_memory_session) + # flask calls session.remove() + in_memory_session.remove = lambda: None - yield in_memory_session + # patch session + mocker.patch( + "superset.security.SupersetSecurityManager.get_session", + return_value=in_memory_session, + ) + mocker.patch("superset.db.session", in_memory_session) + return in_memory_session + + return get_session + + +@pytest.fixture +def session(get_session) -> Iterator[Session]: + yield get_session() @pytest.fixture(scope="module") diff --git a/tests/unit_tests/datasets/conftest.py b/tests/unit_tests/datasets/conftest.py new file mode 100644 index 0000000000..9d9403934d --- /dev/null +++ b/tests/unit_tests/datasets/conftest.py @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
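[Editor's note on the reworked unit-test fixtures above] ``session`` is now derived from a ``get_session`` factory bound to a single in-memory SQLite engine, so a test can commit in one session and re-read the state from a fresh one (the pattern ``test_update_physical_sqlatable_metrics`` and ``test_update_physical_sqlatable_database`` rely on below). A hedged sketch of the intended usage; the test name and ``my_db`` are assumptions:

    from superset.models.core import Database

    def test_survives_new_session(app_context, get_session) -> None:
        # first session: create the schema and commit a row
        session = get_session()
        Database.metadata.create_all(session.get_bind())
        session.add(Database(database_name="my_db", sqlalchemy_uri="sqlite://"))
        session.commit()

        # a fresh session bound to the same in-memory engine sees the row
        session = get_session()
        assert session.query(Database).count() == 1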
+from typing import Any, Dict, TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from superset.connectors.sqla.models import SqlMetric, TableColumn + + +@pytest.fixture +def columns_default() -> Dict[str, Any]: + """Default props for new columns""" + return { + "changed_by": 1, + "created_by": 1, + "datasets": [], + "tables": [], + "is_additive": False, + "is_aggregation": False, + "is_dimensional": False, + "is_filterable": True, + "is_increase_desired": True, + "is_partition": False, + "is_physical": True, + "is_spatial": False, + "is_temporal": False, + "description": None, + "extra_json": "{}", + "unit": None, + "warning_text": None, + "is_managed_externally": False, + "external_url": None, + } + + +@pytest.fixture +def sample_columns() -> Dict["TableColumn", Dict[str, Any]]: + from superset.connectors.sqla.models import TableColumn + + return { + TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"): { + "name": "ds", + "expression": "ds", + "type": "TIMESTAMP", + "is_temporal": True, + "is_physical": True, + }, + TableColumn(column_name="num_boys", type="INTEGER", groupby=True): { + "name": "num_boys", + "expression": "num_boys", + "type": "INTEGER", + "is_dimensional": True, + "is_physical": True, + }, + TableColumn(column_name="region", type="VARCHAR", groupby=True): { + "name": "region", + "expression": "region", + "type": "VARCHAR", + "is_dimensional": True, + "is_physical": True, + }, + TableColumn( + column_name="profit", + type="INTEGER", + groupby=False, + expression="revenue-expenses", + ): { + "name": "profit", + "expression": "revenue-expenses", + "type": "INTEGER", + "is_physical": False, + }, + } + + +@pytest.fixture +def sample_metrics() -> Dict["SqlMetric", Dict[str, Any]]: + from superset.connectors.sqla.models import SqlMetric + + return { + SqlMetric(metric_name="cnt", expression="COUNT(*)", metric_type="COUNT"): { + "name": "cnt", + "expression": "COUNT(*)", + "extra_json": '{"metric_type": "COUNT"}', + "type": "UNKNOWN", + "is_additive": True, + "is_aggregation": True, + "is_filterable": False, + "is_physical": False, + }, + SqlMetric( + metric_name="avg revenue", expression="AVG(revenue)", metric_type="AVG" + ): { + "name": "avg revenue", + "expression": "AVG(revenue)", + "extra_json": '{"metric_type": "AVG"}', + "type": "UNKNOWN", + "is_additive": False, + "is_aggregation": True, + "is_filterable": False, + "is_physical": False, + }, + } diff --git a/tests/unit_tests/datasets/test_models.py b/tests/unit_tests/datasets/test_models.py index d21ef8ea60..08e0f11e0d 100644 --- a/tests/unit_tests/datasets/test_models.py +++ b/tests/unit_tests/datasets/test_models.py @@ -15,14 +15,17 @@ # specific language governing permissions and limitations # under the License. 
-# pylint: disable=import-outside-toplevel, unused-argument, unused-import, too-many-locals, invalid-name, too-many-lines - import json -from datetime import datetime, timezone +from typing import Any, Callable, Dict, List, TYPE_CHECKING from pytest_mock import MockFixture from sqlalchemy.orm.session import Session +from tests.unit_tests.utils.db import get_test_user + +if TYPE_CHECKING: + from superset.connectors.sqla.models import SqlMetric, TableColumn + def test_dataset_model(app_context: None, session: Session) -> None: """ @@ -50,6 +53,7 @@ def test_dataset_model(app_context: None, session: Session) -> None: session.flush() dataset = Dataset( + database=table.database, name="positions", expression=""" SELECT array_agg(array[longitude,latitude]) AS position @@ -148,6 +152,7 @@ def test_cascade_delete_dataset(app_context: None, session: Session) -> None: SELECT array_agg(array[longitude,latitude]) AS position FROM my_catalog.my_schema.my_table """, + database=table.database, tables=[table], columns=[ Column( @@ -185,7 +190,7 @@ def test_dataset_attributes(app_context: None, session: Session) -> None: columns = [ TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), - TableColumn(column_name="user_id", type="INTEGER"), + TableColumn(column_name="num_boys", type="INTEGER"), TableColumn(column_name="revenue", type="INTEGER"), TableColumn(column_name="expenses", type="INTEGER"), TableColumn( @@ -254,6 +259,7 @@ def test_dataset_attributes(app_context: None, session: Session) -> None: "main_dttm_col", "metrics", "offset", + "owners", "params", "perm", "schema", @@ -265,7 +271,13 @@ def test_dataset_attributes(app_context: None, session: Session) -> None: ] -def test_create_physical_sqlatable(app_context: None, session: Session) -> None: +def test_create_physical_sqlatable( + app_context: None, + session: Session, + sample_columns: Dict["TableColumn", Dict[str, Any]], + sample_metrics: Dict["SqlMetric", Dict[str, Any]], + columns_default: Dict[str, Any], +) -> None: """ Test shadow write when creating a new ``SqlaTable``. 
@@ -274,7 +286,7 @@ def test_create_physical_sqlatable(app_context: None, session: Session) -> None: """ from superset.columns.models import Column from superset.columns.schemas import ColumnSchema - from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn + from superset.connectors.sqla.models import SqlaTable from superset.datasets.models import Dataset from superset.datasets.schemas import DatasetSchema from superset.models.core import Database @@ -283,19 +295,11 @@ def test_create_physical_sqlatable(app_context: None, session: Session) -> None: engine = session.get_bind() Dataset.metadata.create_all(engine) # pylint: disable=no-member - - columns = [ - TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), - TableColumn(column_name="user_id", type="INTEGER"), - TableColumn(column_name="revenue", type="INTEGER"), - TableColumn(column_name="expenses", type="INTEGER"), - TableColumn( - column_name="profit", type="INTEGER", expression="revenue-expenses" - ), - ] - metrics = [ - SqlMetric(metric_name="cnt", expression="COUNT(*)"), - ] + user1 = get_test_user(1, "abc") + columns = list(sample_columns.keys()) + metrics = list(sample_metrics.keys()) + expected_table_columns = list(sample_columns.values()) + expected_metric_columns = list(sample_metrics.values()) sqla_table = SqlaTable( table_name="old_dataset", @@ -317,6 +321,9 @@ def test_create_physical_sqlatable(app_context: None, session: Session) -> None: "import_time": 1606677834, } ), + created_by=user1, + changed_by=user1, + owners=[user1], perm=None, filter_select_enabled=1, fetch_values_predicate="foo IN (1, 2)", @@ -329,164 +336,85 @@ def test_create_physical_sqlatable(app_context: None, session: Session) -> None: session.flush() # ignore these keys when comparing results - ignored_keys = {"created_on", "changed_on", "uuid"} + ignored_keys = {"created_on", "changed_on"} # check that columns were created column_schema = ColumnSchema() - column_schemas = [ + actual_columns = [ {k: v for k, v in column_schema.dump(column).items() if k not in ignored_keys} for column in session.query(Column).all() ] - assert column_schemas == [ - { - "changed_by": None, - "created_by": None, - "description": None, - "expression": "ds", - "extra_json": "{}", - "id": 1, - "is_increase_desired": True, - "is_additive": False, - "is_aggregation": False, - "is_partition": False, + num_physical_columns = len( + [col for col in expected_table_columns if col.get("is_physical") == True] + ) + num_dataset_table_columns = len(columns) + num_dataset_metric_columns = len(metrics) + assert ( + len(actual_columns) + == num_physical_columns + num_dataset_table_columns + num_dataset_metric_columns + ) + + # table columns are created before dataset columns are created + offset = 0 + for i in range(num_physical_columns): + assert actual_columns[i + offset] == { + **columns_default, + **expected_table_columns[i], + "id": i + offset + 1, + # physical columns for table have its own uuid + "uuid": actual_columns[i + offset]["uuid"], "is_physical": True, - "is_spatial": False, - "is_temporal": True, - "name": "ds", - "type": "TIMESTAMP", - "unit": None, - "warning_text": None, - "is_managed_externally": False, - "external_url": None, - }, - { - "changed_by": None, + # table columns do not have creators "created_by": None, - "description": None, - "expression": "user_id", - "extra_json": "{}", - "id": 2, - "is_increase_desired": True, - "is_additive": False, - "is_aggregation": False, - "is_partition": False, - "is_physical": True, - "is_spatial": 
False, - "is_temporal": False, - "name": "user_id", - "type": "INTEGER", - "unit": None, - "warning_text": None, - "is_managed_externally": False, - "external_url": None, - }, - { - "changed_by": None, - "created_by": None, - "description": None, - "expression": "revenue", - "extra_json": "{}", - "id": 3, - "is_increase_desired": True, - "is_additive": False, - "is_aggregation": False, - "is_partition": False, - "is_physical": True, - "is_spatial": False, - "is_temporal": False, - "name": "revenue", - "type": "INTEGER", - "unit": None, - "warning_text": None, - "is_managed_externally": False, - "external_url": None, - }, - { - "changed_by": None, - "created_by": None, - "description": None, - "expression": "expenses", - "extra_json": "{}", - "id": 4, - "is_increase_desired": True, - "is_additive": False, - "is_aggregation": False, - "is_partition": False, - "is_physical": True, - "is_spatial": False, - "is_temporal": False, - "name": "expenses", - "type": "INTEGER", - "unit": None, - "warning_text": None, - "is_managed_externally": False, - "external_url": None, - }, - { - "changed_by": None, - "created_by": None, - "description": None, - "expression": "revenue-expenses", - "extra_json": "{}", - "id": 5, - "is_increase_desired": True, - "is_additive": False, - "is_aggregation": False, - "is_partition": False, - "is_physical": False, - "is_spatial": False, - "is_temporal": False, - "name": "profit", - "type": "INTEGER", - "unit": None, - "warning_text": None, - "is_managed_externally": False, - "external_url": None, - }, - { - "changed_by": None, - "created_by": None, - "description": None, - "expression": "COUNT(*)", - "extra_json": "{}", - "id": 6, - "is_increase_desired": True, - "is_additive": False, - "is_aggregation": True, - "is_partition": False, - "is_physical": False, - "is_spatial": False, - "is_temporal": False, - "name": "cnt", - "type": "Unknown", - "unit": None, - "warning_text": None, - "is_managed_externally": False, - "external_url": None, - }, - ] + "tables": [1], + } + + offset += num_physical_columns + for i, column in enumerate(sqla_table.columns): + assert actual_columns[i + offset] == { + **columns_default, + **expected_table_columns[i], + "id": i + offset + 1, + # columns for dataset reuses the same uuid of TableColumn + "uuid": str(column.uuid), + "datasets": [1], + } + + offset += num_dataset_table_columns + for i, metric in enumerate(sqla_table.metrics): + assert actual_columns[i + offset] == { + **columns_default, + **expected_metric_columns[i], + "id": i + offset + 1, + "uuid": str(metric.uuid), + "datasets": [1], + } # check that table was created table_schema = TableSchema() tables = [ - {k: v for k, v in table_schema.dump(table).items() if k not in ignored_keys} + { + k: v + for k, v in table_schema.dump(table).items() + if k not in (ignored_keys | {"uuid"}) + } for table in session.query(Table).all() ] - assert tables == [ - { - "extra_json": "{}", - "catalog": None, - "schema": "my_schema", - "name": "old_dataset", - "id": 1, - "database": 1, - "columns": [1, 2, 3, 4], - "created_by": None, - "changed_by": None, - "is_managed_externally": False, - "external_url": None, - } - ] + assert len(tables) == 1 + assert tables[0] == { + "id": 1, + "database": 1, + "created_by": 1, + "changed_by": 1, + "datasets": [1], + "columns": [1, 2, 3], + "extra_json": "{}", + "catalog": None, + "schema": "my_schema", + "name": "old_dataset", + "is_managed_externally": False, + "external_url": None, + } # check that dataset was created dataset_schema = DatasetSchema() @@ 
-494,26 +422,32 @@ def test_create_physical_sqlatable(app_context: None, session: Session) -> None: {k: v for k, v in dataset_schema.dump(dataset).items() if k not in ignored_keys} for dataset in session.query(Dataset).all() ] - assert datasets == [ - { - "id": 1, - "sqlatable_id": 1, - "name": "old_dataset", - "changed_by": None, - "created_by": None, - "columns": [1, 2, 3, 4, 5, 6], - "is_physical": True, - "tables": [1], - "extra_json": "{}", - "expression": "old_dataset", - "is_managed_externally": False, - "external_url": None, - } - ] + assert len(datasets) == 1 + assert datasets[0] == { + "id": 1, + "uuid": str(sqla_table.uuid), + "created_by": 1, + "changed_by": 1, + "owners": [1], + "name": "old_dataset", + "columns": [4, 5, 6, 7, 8, 9], + "is_physical": True, + "database": 1, + "tables": [1], + "extra_json": "{}", + "expression": "old_dataset", + "is_managed_externally": False, + "external_url": None, + } def test_create_virtual_sqlatable( - mocker: MockFixture, app_context: None, session: Session + app_context: None, + mocker: MockFixture, + session: Session, + sample_columns: Dict["TableColumn", Dict[str, Any]], + sample_metrics: Dict["SqlMetric", Dict[str, Any]], + columns_default: Dict[str, Any], ) -> None: """ Test shadow write when creating a new ``SqlaTable``. @@ -528,7 +462,7 @@ def test_create_virtual_sqlatable( from superset.columns.models import Column from superset.columns.schemas import ColumnSchema - from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn + from superset.connectors.sqla.models import SqlaTable from superset.datasets.models import Dataset from superset.datasets.schemas import DatasetSchema from superset.models.core import Database @@ -536,8 +470,20 @@ def test_create_virtual_sqlatable( engine = session.get_bind() Dataset.metadata.create_all(engine) # pylint: disable=no-member - - # create the ``Table`` that the virtual dataset points to + user1 = get_test_user(1, "abc") + physical_table_columns: List[Dict[str, Any]] = [ + dict( + name="ds", + is_temporal=True, + type="TIMESTAMP", + expression="ds", + is_physical=True, + ), + dict(name="num_boys", type="INTEGER", expression="num_boys", is_physical=True), + dict(name="revenue", type="INTEGER", expression="revenue", is_physical=True), + dict(name="expenses", type="INTEGER", expression="expenses", is_physical=True), + ] + # create a physical ``Table`` that the virtual dataset points to database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") table = Table( name="some_table", @@ -545,30 +491,26 @@ def test_create_virtual_sqlatable( catalog=None, database=database, columns=[ - Column(name="ds", is_temporal=True, type="TIMESTAMP"), - Column(name="user_id", type="INTEGER"), - Column(name="revenue", type="INTEGER"), - Column(name="expenses", type="INTEGER"), + Column(**props, created_by=user1, changed_by=user1) + for props in physical_table_columns ], ) session.add(table) session.commit() + assert session.query(Table).count() == 1 + assert session.query(Dataset).count() == 0 + # create virtual dataset - columns = [ - TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), - TableColumn(column_name="user_id", type="INTEGER"), - TableColumn(column_name="revenue", type="INTEGER"), - TableColumn(column_name="expenses", type="INTEGER"), - TableColumn( - column_name="profit", type="INTEGER", expression="revenue-expenses" - ), - ] - metrics = [ - SqlMetric(metric_name="cnt", expression="COUNT(*)"), - ] + columns = list(sample_columns.keys()) + metrics = 
list(sample_metrics.keys()) + expected_table_columns = list(sample_columns.values()) + expected_metric_columns = list(sample_metrics.values()) sqla_table = SqlaTable( + created_by=user1, + changed_by=user1, + owners=[user1], table_name="old_dataset", columns=columns, metrics=metrics, @@ -583,7 +525,7 @@ def test_create_virtual_sqlatable( sql=""" SELECT ds, - user_id, + num_boys, revenue, expenses, revenue - expenses AS profit @@ -607,227 +549,54 @@ FROM session.add(sqla_table) session.flush() - # ignore these keys when comparing results - ignored_keys = {"created_on", "changed_on", "uuid"} + # should not add a new table + assert session.query(Table).count() == 1 + assert session.query(Dataset).count() == 1 - # check that columns were created + # ignore these keys when comparing results + ignored_keys = {"created_on", "changed_on"} column_schema = ColumnSchema() - column_schemas = [ + actual_columns = [ {k: v for k, v in column_schema.dump(column).items() if k not in ignored_keys} for column in session.query(Column).all() ] - assert column_schemas == [ - { - "type": "TIMESTAMP", - "is_additive": False, - "extra_json": "{}", - "is_partition": False, - "expression": None, - "unit": None, - "warning_text": None, - "created_by": None, - "is_increase_desired": True, - "description": None, - "is_spatial": False, - "name": "ds", - "is_physical": True, - "changed_by": None, - "is_temporal": True, - "id": 1, - "is_aggregation": False, - "external_url": None, - "is_managed_externally": False, - }, - { - "type": "INTEGER", - "is_additive": False, - "extra_json": "{}", - "is_partition": False, - "expression": None, - "unit": None, - "warning_text": None, - "created_by": None, - "is_increase_desired": True, - "description": None, - "is_spatial": False, - "name": "user_id", - "is_physical": True, - "changed_by": None, - "is_temporal": False, - "id": 2, - "is_aggregation": False, - "external_url": None, - "is_managed_externally": False, - }, - { - "type": "INTEGER", - "is_additive": False, - "extra_json": "{}", - "is_partition": False, - "expression": None, - "unit": None, - "warning_text": None, - "created_by": None, - "is_increase_desired": True, - "description": None, - "is_spatial": False, - "name": "revenue", - "is_physical": True, - "changed_by": None, - "is_temporal": False, - "id": 3, - "is_aggregation": False, - "external_url": None, - "is_managed_externally": False, - }, - { - "type": "INTEGER", - "is_additive": False, - "extra_json": "{}", - "is_partition": False, - "expression": None, - "unit": None, - "warning_text": None, - "created_by": None, - "is_increase_desired": True, - "description": None, - "is_spatial": False, - "name": "expenses", - "is_physical": True, - "changed_by": None, - "is_temporal": False, - "id": 4, - "is_aggregation": False, - "external_url": None, - "is_managed_externally": False, - }, - { - "type": "TIMESTAMP", - "is_additive": False, - "extra_json": "{}", - "is_partition": False, - "expression": "ds", - "unit": None, - "warning_text": None, - "created_by": None, - "is_increase_desired": True, - "description": None, - "is_spatial": False, - "name": "ds", + num_physical_columns = len(physical_table_columns) + num_dataset_table_columns = len(columns) + num_dataset_metric_columns = len(metrics) + assert ( + len(actual_columns) + == num_physical_columns + num_dataset_table_columns + num_dataset_metric_columns + ) + + for i, column in enumerate(table.columns): + assert actual_columns[i] == { + **columns_default, + **physical_table_columns[i], + "id": i + 1, + "uuid": 
str(column.uuid), + "tables": [1], + } + + offset = num_physical_columns + for i, column in enumerate(sqla_table.columns): + assert actual_columns[i + offset] == { + **columns_default, + **expected_table_columns[i], + "id": i + offset + 1, + "uuid": str(column.uuid), "is_physical": False, - "changed_by": None, - "is_temporal": True, - "id": 5, - "is_aggregation": False, - "external_url": None, - "is_managed_externally": False, - }, - { - "type": "INTEGER", - "is_additive": False, - "extra_json": "{}", - "is_partition": False, - "expression": "user_id", - "unit": None, - "warning_text": None, - "created_by": None, - "is_increase_desired": True, - "description": None, - "is_spatial": False, - "name": "user_id", - "is_physical": False, - "changed_by": None, - "is_temporal": False, - "id": 6, - "is_aggregation": False, - "external_url": None, - "is_managed_externally": False, - }, - { - "type": "INTEGER", - "is_additive": False, - "extra_json": "{}", - "is_partition": False, - "expression": "revenue", - "unit": None, - "warning_text": None, - "created_by": None, - "is_increase_desired": True, - "description": None, - "is_spatial": False, - "name": "revenue", - "is_physical": False, - "changed_by": None, - "is_temporal": False, - "id": 7, - "is_aggregation": False, - "external_url": None, - "is_managed_externally": False, - }, - { - "type": "INTEGER", - "is_additive": False, - "extra_json": "{}", - "is_partition": False, - "expression": "expenses", - "unit": None, - "warning_text": None, - "created_by": None, - "is_increase_desired": True, - "description": None, - "is_spatial": False, - "name": "expenses", - "is_physical": False, - "changed_by": None, - "is_temporal": False, - "id": 8, - "is_aggregation": False, - "external_url": None, - "is_managed_externally": False, - }, - { - "type": "INTEGER", - "is_additive": False, - "extra_json": "{}", - "is_partition": False, - "expression": "revenue-expenses", - "unit": None, - "warning_text": None, - "created_by": None, - "is_increase_desired": True, - "description": None, - "is_spatial": False, - "name": "profit", - "is_physical": False, - "changed_by": None, - "is_temporal": False, - "id": 9, - "is_aggregation": False, - "external_url": None, - "is_managed_externally": False, - }, - { - "type": "Unknown", - "is_additive": False, - "extra_json": "{}", - "is_partition": False, - "expression": "COUNT(*)", - "unit": None, - "warning_text": None, - "created_by": None, - "is_increase_desired": True, - "description": None, - "is_spatial": False, - "name": "cnt", - "is_physical": False, - "changed_by": None, - "is_temporal": False, - "id": 10, - "is_aggregation": True, - "external_url": None, - "is_managed_externally": False, - }, - ] + "datasets": [1], + } + + offset = num_physical_columns + num_dataset_table_columns + for i, metric in enumerate(sqla_table.metrics): + assert actual_columns[i + offset] == { + **columns_default, + **expected_metric_columns[i], + "id": i + offset + 1, + "uuid": str(metric.uuid), + "datasets": [1], + } # check that dataset was created, and has a reference to the table dataset_schema = DatasetSchema() @@ -835,30 +604,31 @@ FROM {k: v for k, v in dataset_schema.dump(dataset).items() if k not in ignored_keys} for dataset in session.query(Dataset).all() ] - assert datasets == [ - { - "id": 1, - "sqlatable_id": 1, - "name": "old_dataset", - "changed_by": None, - "created_by": None, - "columns": [5, 6, 7, 8, 9, 10], - "is_physical": False, - "tables": [1], - "extra_json": "{}", - "external_url": None, - "is_managed_externally": 
False, - "expression": """ + assert len(datasets) == 1 + assert datasets[0] == { + "id": 1, + "database": 1, + "uuid": str(sqla_table.uuid), + "name": "old_dataset", + "changed_by": 1, + "created_by": 1, + "owners": [1], + "columns": [5, 6, 7, 8, 9, 10], + "is_physical": False, + "tables": [1], + "extra_json": "{}", + "external_url": None, + "is_managed_externally": False, + "expression": """ SELECT ds, - user_id, + num_boys, revenue, expenses, revenue - expenses AS profit FROM some_table""", - } - ] + } def test_delete_sqlatable(app_context: None, session: Session) -> None: @@ -886,18 +656,21 @@ def test_delete_sqlatable(app_context: None, session: Session) -> None: session.add(sqla_table) session.flush() - datasets = session.query(Dataset).all() - assert len(datasets) == 1 + assert session.query(Dataset).count() == 1 + assert session.query(Table).count() == 1 + assert session.query(Column).count() == 2 session.delete(sqla_table) session.flush() - # test that dataset was also deleted - datasets = session.query(Dataset).all() - assert len(datasets) == 0 + # test that dataset and dataset columns are also deleted + # but the physical table and table columns are kept + assert session.query(Dataset).count() == 0 + assert session.query(Table).count() == 1 + assert session.query(Column).count() == 1 -def test_update_sqlatable( +def test_update_physical_sqlatable_columns( mocker: MockFixture, app_context: None, session: Session ) -> None: """ @@ -929,21 +702,33 @@ def test_update_sqlatable( session.add(sqla_table) session.flush() + assert session.query(Table).count() == 1 + assert session.query(Dataset).count() == 1 + assert session.query(Column).count() == 2 # 1 for table, 1 for dataset + dataset = session.query(Dataset).one() assert len(dataset.columns) == 1 # add a column to the original ``SqlaTable`` instance - sqla_table.columns.append(TableColumn(column_name="user_id", type="INTEGER")) + sqla_table.columns.append(TableColumn(column_name="num_boys", type="INTEGER")) session.flush() - # check that the column was added to the dataset + assert session.query(Column).count() == 3 dataset = session.query(Dataset).one() assert len(dataset.columns) == 2 + for table_column, dataset_column in zip(sqla_table.columns, dataset.columns): + assert table_column.uuid == dataset_column.uuid # delete the column in the original instance sqla_table.columns = sqla_table.columns[1:] session.flush() + # check that the column was added to the dataset and the added columns have + # the correct uuid. + assert session.query(TableColumn).count() == 1 + # the extra Dataset.column is deleted, but Table.column is kept + assert session.query(Column).count() == 2 + # check that the column was also removed from the dataset dataset = session.query(Dataset).one() assert len(dataset.columns) == 1 @@ -957,7 +742,7 @@ def test_update_sqlatable( assert dataset.columns[0].is_temporal is True -def test_update_sqlatable_schema( +def test_update_physical_sqlatable_schema( mocker: MockFixture, app_context: None, session: Session ) -> None: """ @@ -1003,8 +788,11 @@ def test_update_sqlatable_schema( assert new_dataset.tables[0].id == 2 -def test_update_sqlatable_metric( - mocker: MockFixture, app_context: None, session: Session +def test_update_physical_sqlatable_metrics( + mocker: MockFixture, + app_context: None, + session: Session, + get_session: Callable[[], Session], ) -> None: """ Test that updating a ``SqlaTable`` also updates the corresponding ``Dataset``. 
@@ -1042,6 +830,9 @@ def test_update_sqlatable_metric( session.flush() # check that the metric was created + # 1 physical column for table + (1 column + 1 metric for datasets) + assert session.query(Column).count() == 3 + column = session.query(Column).filter_by(is_physical=False).one() assert column.expression == "COUNT(*)" @@ -1051,6 +842,186 @@ def test_update_sqlatable_metric( assert column.expression == "MAX(ds)" + # in a new session, update new columns and metrics at the same time + # reload the sqla_table so we can test the case that accessing an not already + # loaded attribute (`sqla_table.metrics`) while there are updates on the instance + # may trigger `after_update` before the attribute is loaded + session = get_session() + sqla_table = session.query(SqlaTable).filter(SqlaTable.id == sqla_table.id).one() + sqla_table.columns.append( + TableColumn( + column_name="another_column", + is_dttm=0, + type="TIMESTAMP", + expression="concat('a', 'b')", + ) + ) + # Here `SqlaTable.after_update` is triggered + # before `sqla_table.metrics` is loaded + sqla_table.metrics.append( + SqlMetric(metric_name="another_metric", expression="COUNT(*)") + ) + # `SqlaTable.after_update` will trigger again at flushing + session.flush() + assert session.query(Column).count() == 5 + + +def test_update_physical_sqlatable_database( + mocker: MockFixture, + app_context: None, + session: Session, + get_session: Callable[[], Session], +) -> None: + """ + Test updating the table on a physical dataset. + + When updating the table on a physical dataset by pointing it somewhere else (change + in database ID, schema, or table name) we should point the ``Dataset`` to an + existing ``Table`` if possible, and create a new one otherwise. + """ + # patch session + mocker.patch( + "superset.security.SupersetSecurityManager.get_session", return_value=session + ) + mocker.patch("superset.datasets.dao.db.session", session) + + from superset.columns.models import Column + from superset.connectors.sqla.models import SqlaTable, TableColumn + from superset.datasets.models import Dataset, dataset_column_association_table + from superset.models.core import Database + from superset.tables.models import Table, table_column_association_table + from superset.tables.schemas import TableSchema + + engine = session.get_bind() + Dataset.metadata.create_all(engine) # pylint: disable=no-member + + columns = [ + TableColumn(column_name="a", type="INTEGER"), + ] + + original_database = Database( + database_name="my_database", sqlalchemy_uri="sqlite://" + ) + sqla_table = SqlaTable( + table_name="original_table", + columns=columns, + metrics=[], + database=original_database, + ) + session.add(sqla_table) + session.flush() + + assert session.query(Table).count() == 1 + assert session.query(Dataset).count() == 1 + assert session.query(Column).count() == 2 # 1 for table, 1 for dataset + + # check that the table was created, and that the created dataset points to it + table = session.query(Table).one() + assert table.id == 1 + assert table.name == "original_table" + assert table.schema is None + assert table.database_id == 1 + + dataset = session.query(Dataset).one() + assert dataset.tables == [table] + + # point ``SqlaTable`` to a different database + new_database = Database( + database_name="my_other_database", sqlalchemy_uri="sqlite://" + ) + session.add(new_database) + session.flush() + sqla_table.database = new_database + sqla_table.table_name = "new_table" + session.flush() + + assert session.query(Dataset).count() == 1 + assert 
session.query(Table).count() == 2 + # is kept for the old table + # is kept for the updated dataset + # is created for the new table + assert session.query(Column).count() == 3 + + # ignore these keys when comparing results + ignored_keys = {"created_on", "changed_on", "uuid"} + + # check that the old table still exists, and that the dataset points to the newly + # created table, column and dataset + table_schema = TableSchema() + tables = [ + {k: v for k, v in table_schema.dump(table).items() if k not in ignored_keys} + for table in session.query(Table).all() + ] + assert tables[0] == { + "id": 1, + "database": 1, + "columns": [1], + "datasets": [], + "created_by": None, + "changed_by": None, + "extra_json": "{}", + "catalog": None, + "schema": None, + "name": "original_table", + "external_url": None, + "is_managed_externally": False, + } + assert tables[1] == { + "id": 2, + "database": 2, + "datasets": [1], + "columns": [3], + "created_by": None, + "changed_by": None, + "catalog": None, + "schema": None, + "name": "new_table", + "is_managed_externally": False, + "extra_json": "{}", + "external_url": None, + } + + # check that dataset now points to the new table + assert dataset.tables[0].database_id == 2 + # and a new column is created + assert len(dataset.columns) == 1 + assert dataset.columns[0].id == 2 + + # point ``SqlaTable`` back + sqla_table.database = original_database + sqla_table.table_name = "original_table" + session.flush() + + # should not create more table and datasets + assert session.query(Dataset).count() == 1 + assert session.query(Table).count() == 2 + # is deleted for the old table + # is kept for the updated dataset + # is kept for the new table + assert session.query(Column.id).order_by(Column.id).all() == [ + (1,), + (2,), + (3,), + ] + assert session.query(dataset_column_association_table).all() == [(1, 2)] + assert session.query(table_column_association_table).all() == [(1, 1), (2, 3)] + assert session.query(Dataset).filter_by(id=1).one().columns[0].id == 2 + assert session.query(Table).filter_by(id=2).one().columns[0].id == 3 + assert session.query(Table).filter_by(id=1).one().columns[0].id == 1 + + # the dataset points back to the original table + assert dataset.tables[0].database_id == 1 + assert dataset.tables[0].name == "original_table" + + # kept the original column + assert dataset.columns[0].id == 2 + session.commit() + session.close() + + # querying in a new session should still return the same result + session = get_session() + assert session.query(table_column_association_table).all() == [(1, 1), (2, 3)] + def test_update_virtual_sqlatable_references( mocker: MockFixture, app_context: None, session: Session @@ -1108,7 +1079,7 @@ def test_update_virtual_sqlatable_references( session.flush() # check that new dataset has table1 - dataset = session.query(Dataset).one() + dataset: Dataset = session.query(Dataset).one() assert dataset.tables == [table1] # change SQL @@ -1116,20 +1087,26 @@ def test_update_virtual_sqlatable_references( session.flush() # check that new dataset has both tables - new_dataset = session.query(Dataset).one() + new_dataset: Dataset = session.query(Dataset).one() assert new_dataset.tables == [table1, table2] assert new_dataset.expression == "SELECT a, b FROM table_a JOIN table_b" + # automatically add new referenced table + sqla_table.sql = "SELECT a, b, c FROM table_a JOIN table_b JOIN table_c" + session.flush() + + new_dataset = session.query(Dataset).one() + assert len(new_dataset.tables) == 3 + assert new_dataset.tables[2].name 
== "table_c" + def test_quote_expressions(app_context: None, session: Session) -> None: """ Test that expressions are quoted appropriately in columns and datasets. """ - from superset.columns.models import Column from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.datasets.models import Dataset from superset.models.core import Database - from superset.tables.models import Table engine = session.get_bind() Dataset.metadata.create_all(engine) # pylint: disable=no-member @@ -1152,180 +1129,3 @@ def test_quote_expressions(app_context: None, session: Session) -> None: assert dataset.expression == '"old dataset"' assert dataset.columns[0].expression == '"has space"' assert dataset.columns[1].expression == "no_need" - - -def test_update_physical_sqlatable( - mocker: MockFixture, app_context: None, session: Session -) -> None: - """ - Test updating the table on a physical dataset. - - When updating the table on a physical dataset by pointing it somewhere else (change - in database ID, schema, or table name) we should point the ``Dataset`` to an - existing ``Table`` if possible, and create a new one otherwise. - """ - # patch session - mocker.patch( - "superset.security.SupersetSecurityManager.get_session", return_value=session - ) - mocker.patch("superset.datasets.dao.db.session", session) - - from superset.columns.models import Column - from superset.connectors.sqla.models import SqlaTable, TableColumn - from superset.datasets.models import Dataset - from superset.models.core import Database - from superset.tables.models import Table - from superset.tables.schemas import TableSchema - - engine = session.get_bind() - Dataset.metadata.create_all(engine) # pylint: disable=no-member - - columns = [ - TableColumn(column_name="a", type="INTEGER"), - ] - - sqla_table = SqlaTable( - table_name="old_dataset", - columns=columns, - metrics=[], - database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), - ) - session.add(sqla_table) - session.flush() - - # check that the table was created, and that the created dataset points to it - table = session.query(Table).one() - assert table.id == 1 - assert table.name == "old_dataset" - assert table.schema is None - assert table.database_id == 1 - - dataset = session.query(Dataset).one() - assert dataset.tables == [table] - - # point ``SqlaTable`` to a different database - new_database = Database( - database_name="my_other_database", sqlalchemy_uri="sqlite://" - ) - session.add(new_database) - session.flush() - sqla_table.database = new_database - session.flush() - - # ignore these keys when comparing results - ignored_keys = {"created_on", "changed_on", "uuid"} - - # check that the old table still exists, and that the dataset points to the newly - # created table (id=2) and column (id=2), on the new database (also id=2) - table_schema = TableSchema() - tables = [ - {k: v for k, v in table_schema.dump(table).items() if k not in ignored_keys} - for table in session.query(Table).all() - ] - assert tables == [ - { - "created_by": None, - "extra_json": "{}", - "name": "old_dataset", - "changed_by": None, - "catalog": None, - "columns": [1], - "database": 1, - "external_url": None, - "schema": None, - "id": 1, - "is_managed_externally": False, - }, - { - "created_by": None, - "extra_json": "{}", - "name": "old_dataset", - "changed_by": None, - "catalog": None, - "columns": [2], - "database": 2, - "external_url": None, - "schema": None, - "id": 2, - "is_managed_externally": False, - }, - ] - - # check that dataset now points 
to the new table - assert dataset.tables[0].database_id == 2 - - # point ``SqlaTable`` back - sqla_table.database_id = 1 - session.flush() - - # check that dataset points to the original table - assert dataset.tables[0].database_id == 1 - - -def test_update_physical_sqlatable_no_dataset( - mocker: MockFixture, app_context: None, session: Session -) -> None: - """ - Test updating the table on a physical dataset that it creates - a new dataset if one didn't already exist. - - When updating the table on a physical dataset by pointing it somewhere else (change - in database ID, schema, or table name) we should point the ``Dataset`` to an - existing ``Table`` if possible, and create a new one otherwise. - """ - # patch session - mocker.patch( - "superset.security.SupersetSecurityManager.get_session", return_value=session - ) - mocker.patch("superset.datasets.dao.db.session", session) - - from superset.columns.models import Column - from superset.connectors.sqla.models import SqlaTable, TableColumn - from superset.datasets.models import Dataset - from superset.models.core import Database - from superset.tables.models import Table - from superset.tables.schemas import TableSchema - - engine = session.get_bind() - Dataset.metadata.create_all(engine) # pylint: disable=no-member - - columns = [ - TableColumn(column_name="a", type="INTEGER"), - ] - - sqla_table = SqlaTable( - table_name="old_dataset", - columns=columns, - metrics=[], - database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), - ) - session.add(sqla_table) - session.flush() - - # check that the table was created - table = session.query(Table).one() - assert table.id == 1 - - dataset = session.query(Dataset).one() - assert dataset.tables == [table] - - # point ``SqlaTable`` to a different database - new_database = Database( - database_name="my_other_database", sqlalchemy_uri="sqlite://" - ) - session.add(new_database) - session.flush() - sqla_table.database = new_database - session.flush() - - new_dataset = session.query(Dataset).one() - - # check that dataset now points to the new table - assert new_dataset.tables[0].database_id == 2 - - # point ``SqlaTable`` back - sqla_table.database_id = 1 - session.flush() - - # check that dataset points to the original table - assert new_dataset.tables[0].database_id == 1 diff --git a/tests/unit_tests/migrations/shared/__init__.py b/tests/unit_tests/migrations/shared/__init__.py deleted file mode 100644 index 13a83393a9..0000000000 --- a/tests/unit_tests/migrations/shared/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
diff --git a/tests/unit_tests/migrations/shared/utils_test.py b/tests/unit_tests/migrations/shared/utils_test.py deleted file mode 100644 index cb5b2cbd0e..0000000000 --- a/tests/unit_tests/migrations/shared/utils_test.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=import-outside-toplevel, unused-argument - -""" -Test the SIP-68 migration. -""" - -from pytest_mock import MockerFixture - -from superset.sql_parse import Table - - -def test_extract_table_references(mocker: MockerFixture, app_context: None) -> None: - """ - Test the ``extract_table_references`` helper function. - """ - from superset.migrations.shared.utils import extract_table_references - - assert extract_table_references("SELECT 1", "trino") == set() - assert extract_table_references("SELECT 1 FROM some_table", "trino") == { - Table(table="some_table", schema=None, catalog=None) - } - assert extract_table_references( - "SELECT 1 FROM some_catalog.some_schema.some_table", "trino" - ) == {Table(table="some_table", schema="some_schema", catalog="some_catalog")} - assert extract_table_references( - "SELECT * FROM some_table JOIN other_table ON some_table.id = other_table.id", - "trino", - ) == { - Table(table="some_table", schema=None, catalog=None), - Table(table="other_table", schema=None, catalog=None), - } - - # test falling back to sqlparse - logger = mocker.patch("superset.migrations.shared.utils.logger") - sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table" - assert extract_table_references( - sql, - "trino", - ) == {Table(table="other_table", schema=None, catalog=None)} - logger.warning.assert_called_with("Unable to parse query with sqloxide: %s", sql) diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 4a1ff89d74..d9c5d64c59 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -29,6 +29,7 @@ from sqlparse.tokens import Name from superset.exceptions import QueryClauseValidationException from superset.sql_parse import ( add_table_name, + extract_table_references, get_rls_for_table, has_table_query, insert_rls, @@ -1468,3 +1469,51 @@ def test_get_rls_for_table(mocker: MockerFixture, app_context: None) -> None: dataset.get_sqla_row_level_filters.return_value = [] assert get_rls_for_table(candidate, 1, "public") is None + + +def test_extract_table_references(mocker: MockerFixture) -> None: + """ + Test the ``extract_table_references`` helper function. 
+ """ + assert extract_table_references("SELECT 1", "trino") == set() + assert extract_table_references("SELECT 1 FROM some_table", "trino") == { + Table(table="some_table", schema=None, catalog=None) + } + assert extract_table_references("SELECT {{ jinja }} FROM some_table", "trino") == { + Table(table="some_table", schema=None, catalog=None) + } + assert extract_table_references( + "SELECT 1 FROM some_catalog.some_schema.some_table", "trino" + ) == {Table(table="some_table", schema="some_schema", catalog="some_catalog")} + + # with identifier quotes + assert extract_table_references( + "SELECT 1 FROM `some_catalog`.`some_schema`.`some_table`", "mysql" + ) == {Table(table="some_table", schema="some_schema", catalog="some_catalog")} + assert extract_table_references( + 'SELECT 1 FROM "some_catalog".some_schema."some_table"', "trino" + ) == {Table(table="some_table", schema="some_schema", catalog="some_catalog")} + + assert extract_table_references( + "SELECT * FROM some_table JOIN other_table ON some_table.id = other_table.id", + "trino", + ) == { + Table(table="some_table", schema=None, catalog=None), + Table(table="other_table", schema=None, catalog=None), + } + + # test falling back to sqlparse + logger = mocker.patch("superset.sql_parse.logger") + sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table" + assert extract_table_references( + sql, + "trino", + ) == {Table(table="other_table", schema=None, catalog=None)} + logger.warning.assert_called_once() + + logger = mocker.patch("superset.migrations.shared.utils.logger") + sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table" + assert extract_table_references(sql, "trino", show_warning=False) == { + Table(table="other_table", schema=None, catalog=None) + } + logger.warning.assert_not_called() diff --git a/tests/unit_tests/migrations/__init__.py b/tests/unit_tests/utils/db.py similarity index 69% rename from tests/unit_tests/migrations/__init__.py rename to tests/unit_tests/utils/db.py index 13a83393a9..554c95bd43 100644 --- a/tests/unit_tests/migrations/__init__.py +++ b/tests/unit_tests/utils/db.py @@ -14,3 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Any + +from superset import security_manager + + +def get_test_user(id_: int, username: str) -> Any: + """Create a sample test user""" + return security_manager.user_model( + id=id_, + username=username, + first_name=username, + last_name=username, + email=f"{username}@example.com", + )