perf: refactor SIP-68 db migrations with INSERT SELECT FROM (#19421)

Jesse Yang 2022-04-19 18:58:18 -07:00 committed by GitHub
parent 1c5d3b73df
commit 231716cb50
30 changed files with 2356 additions and 1988 deletions
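The refactor replaces per-row ORM copying in the SIP-68 shadow-write migration with bulk INSERT ... SELECT statements run by the database itself (see the `insert_from_select` helper and the `copy_*` functions in the new migration below). A minimal sketch of the pattern with SQLAlchemy Core, using illustrative table names rather than the migration's real schema:

import sqlalchemy as sa

metadata = sa.MetaData()
src = sa.Table(
    "old_items", metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("name", sa.Text),
)
dst = sa.Table(
    "new_items", metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("name", sa.Text),
)

# Copy rows with a single INSERT ... SELECT so the data never has to be
# streamed through Python one ORM object at a time.
query = dst.insert().from_select(["id", "name"], sa.select([src.c.id, src.c.name]))
# op.execute(query)      # inside an Alembic migration
# engine.execute(query)  # or directly against an engine/connection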

View File

@ -23,7 +23,6 @@ tables, metrics, and datasets were also introduced.
These models are not fully implemented, and shouldn't be used yet.
"""
import sqlalchemy as sa
from flask_appbuilder import Model
@ -33,6 +32,8 @@ from superset.models.helpers import (
ImportExportMixin,
)
UNKOWN_TYPE = "UNKNOWN"
class Column(
Model,
@ -52,51 +53,58 @@ class Column(
id = sa.Column(sa.Integer, primary_key=True)
# We use ``sa.Text`` for these attributes because (1) in modern databases the
# performance is the same as ``VARCHAR``[1] and (2) because some table names can be
# **really** long (eg, Google Sheets URLs).
#
# [1] https://www.postgresql.org/docs/9.1/datatype-character.html
name = sa.Column(sa.Text)
type = sa.Column(sa.Text)
# Columns are defined by expressions. For tables, these are the actual column names,
# and should match the ``name`` attribute. For datasets, these can be any valid SQL
# expression. If the SQL expression is an aggregation the column is a metric,
# otherwise it's a computed column.
expression = sa.Column(sa.Text)
# Does the expression point directly to a physical column?
is_physical = sa.Column(sa.Boolean, default=True)
# Additional metadata describing the column.
description = sa.Column(sa.Text)
warning_text = sa.Column(sa.Text)
unit = sa.Column(sa.Text)
# Is this a time column? Useful for plotting time series.
is_temporal = sa.Column(sa.Boolean, default=False)
# Is this a spatial column? This could be leveraged in the future for spatial
# visualizations.
is_spatial = sa.Column(sa.Boolean, default=False)
# Is this column a partition? Useful for scheduling queries and previewing the latest
# data.
is_partition = sa.Column(sa.Boolean, default=False)
# Is this column an aggregation (metric)?
is_aggregation = sa.Column(sa.Boolean, default=False)
# Assuming the column is an aggregation, is it additive? Useful for determining which
# aggregations can be done on the metric. Eg, ``COUNT(DISTINCT user_id)`` is not
# additive, so it shouldn't be used in a ``SUM``.
is_additive = sa.Column(sa.Boolean, default=False)
# Is this column an aggregation (metric)?
is_aggregation = sa.Column(sa.Boolean, default=False)
is_filterable = sa.Column(sa.Boolean, nullable=False, default=True)
is_dimensional = sa.Column(sa.Boolean, nullable=False, default=False)
# Is an increase desired? Useful for displaying the results of A/B tests, or setting
# up alerts. Eg, this is true for "revenue", but false for "latency".
is_increase_desired = sa.Column(sa.Boolean, default=True)
# Column is managed externally and should be read-only inside Superset
is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)
# Is this column a partition? Useful for scheduling queries and previewing the latest
# data.
is_partition = sa.Column(sa.Boolean, default=False)
# Does the expression point directly to a physical column?
is_physical = sa.Column(sa.Boolean, default=True)
# Is this a spatial column? This could be leveraged in the future for spatial
# visualizations.
is_spatial = sa.Column(sa.Boolean, default=False)
# Is this a time column? Useful for plotting time series.
is_temporal = sa.Column(sa.Boolean, default=False)
# We use ``sa.Text`` for these attributes because (1) in modern databases the
# performance is the same as ``VARCHAR``[1] and (2) because some table names can be
# **really** long (eg, Google Sheets URLs).
#
# [1] https://www.postgresql.org/docs/9.1/datatype-character.html
name = sa.Column(sa.Text)
# Raw type as returned and used by db engine.
type = sa.Column(sa.Text, default=UNKOWN_TYPE)
# Columns are defined by expressions. For tables, these are the actual column names,
# and should match the ``name`` attribute. For datasets, these can be any valid SQL
# expression. If the SQL expression is an aggregation the column is a metric,
# otherwise it's a computed column.
expression = sa.Column(sa.Text)
unit = sa.Column(sa.Text)
# Additional metadata describing the column.
description = sa.Column(sa.Text)
warning_text = sa.Column(sa.Text)
external_url = sa.Column(sa.Text, nullable=True)
def __repr__(self) -> str:
return f"<Column id={self.id}>"

View File

@ -31,7 +31,7 @@ from superset.models.helpers import AuditMixinNullable, ImportExportMixin, Query
from superset.models.slice import Slice
from superset.superset_typing import FilterValue, FilterValues, QueryObjectDict
from superset.utils import core as utils
from superset.utils.core import GenericDataType
from superset.utils.core import GenericDataType, MediumText
METRIC_FORM_DATA_PARAMS = [
"metric",
@ -586,7 +586,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
type = Column(Text)
groupby = Column(Boolean, default=True)
filterable = Column(Boolean, default=True)
description = Column(Text)
description = Column(MediumText())
is_dttm = None
# [optional] Set this to support import/export functionality
@ -672,7 +672,7 @@ class BaseMetric(AuditMixinNullable, ImportExportMixin):
metric_name = Column(String(255), nullable=False)
verbose_name = Column(String(1024))
metric_type = Column(String(32))
description = Column(Text)
description = Column(MediumText())
d3format = Column(String(128))
warning_text = Column(Text)
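The `Text` → `MediumText()` changes widen these columns on MySQL, where plain `TEXT` is capped at 64KB. A rough sketch of how such a type can be declared with SQLAlchemy's compiler extension, assuming `MediumText` (imported from `superset.utils.core` above) compiles to MEDIUMTEXT on MySQL and falls back to the default TEXT rendering elsewhere:

from sqlalchemy import Text
from sqlalchemy.ext.compiler import compiles

class MediumText(Text):
    """A TEXT column that renders as MEDIUMTEXT on MySQL."""

@compiles(MediumText, "mysql")
def compile_mediumtext(element, compiler, **kw):
    # only MySQL needs the wider type; other dialects keep compiling TEXT
    return "MEDIUMTEXT"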

View File

@ -24,6 +24,7 @@ from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import (
Any,
Callable,
cast,
Dict,
Hashable,
@ -34,6 +35,7 @@ from typing import (
Type,
Union,
)
from uuid import uuid4
import dateutil.parser
import numpy as np
@ -72,13 +74,13 @@ from sqlalchemy.sql.expression import Label, Select, TextAsFrom
from sqlalchemy.sql.selectable import Alias, TableClause
from superset import app, db, is_feature_enabled, security_manager
from superset.columns.models import Column as NewColumn
from superset.columns.models import Column as NewColumn, UNKOWN_TYPE
from superset.common.db_query_status import QueryStatus
from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
from superset.connectors.sqla.utils import (
find_cached_objects_in_session,
get_physical_table_metadata,
get_virtual_table_metadata,
load_or_create_tables,
validate_adhoc_subquery,
)
from superset.datasets.models import Dataset as NewDataset
@ -100,7 +102,12 @@ from superset.models.helpers import (
clone_model,
QueryResult,
)
from superset.sql_parse import ParsedQuery, sanitize_clause
from superset.sql_parse import (
extract_table_references,
ParsedQuery,
sanitize_clause,
Table as TableName,
)
from superset.superset_typing import (
AdhocColumn,
AdhocMetric,
@ -114,6 +121,7 @@ from superset.utils.core import (
GenericDataType,
get_column_name,
is_adhoc_column,
MediumText,
QueryObjectFilterClause,
remove_duplicates,
)
@ -130,6 +138,7 @@ ADDITIVE_METRIC_TYPES = {
"sum",
"doubleSum",
}
ADDITIVE_METRIC_TYPES_LOWER = {op.lower() for op in ADDITIVE_METRIC_TYPES}
class SqlaQuery(NamedTuple):
@ -215,13 +224,13 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
__tablename__ = "table_columns"
__table_args__ = (UniqueConstraint("table_id", "column_name"),)
table_id = Column(Integer, ForeignKey("tables.id"))
table = relationship(
table: "SqlaTable" = relationship(
"SqlaTable",
backref=backref("columns", cascade="all, delete-orphan"),
foreign_keys=[table_id],
)
is_dttm = Column(Boolean, default=False)
expression = Column(Text)
expression = Column(MediumText())
python_date_format = Column(String(255))
extra = Column(Text)
@ -417,6 +426,59 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
return attr_dict
def to_sl_column(
self, known_columns: Optional[Dict[str, NewColumn]] = None
) -> NewColumn:
"""Convert a TableColumn to NewColumn"""
column = known_columns.get(self.uuid) if known_columns else None
if not column:
column = NewColumn()
extra_json = self.get_extra_dict()
for attr in {
"verbose_name",
"python_date_format",
}:
value = getattr(self, attr)
if value:
extra_json[attr] = value
column.uuid = self.uuid
column.created_on = self.created_on
column.changed_on = self.changed_on
column.created_by = self.created_by
column.changed_by = self.changed_by
column.name = self.column_name
column.type = self.type or UNKOWN_TYPE
column.expression = self.expression or self.table.quote_identifier(
self.column_name
)
column.description = self.description
column.is_aggregation = False
column.is_dimensional = self.groupby
column.is_filterable = self.filterable
column.is_increase_desired = True
column.is_managed_externally = self.table.is_managed_externally
column.is_partition = False
column.is_physical = not self.expression
column.is_spatial = False
column.is_temporal = self.is_dttm
column.extra_json = json.dumps(extra_json) if extra_json else None
column.external_url = self.table.external_url
return column
@staticmethod
def after_delete( # pylint: disable=unused-argument
mapper: Mapper,
connection: Connection,
target: "TableColumn",
) -> None:
session = inspect(target).session
column = session.query(NewColumn).filter_by(uuid=target.uuid).one_or_none()
if column:
session.delete(column)
class SqlMetric(Model, BaseMetric, CertificationMixin):
@ -430,7 +492,7 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
backref=backref("metrics", cascade="all, delete-orphan"),
foreign_keys=[table_id],
)
expression = Column(Text, nullable=False)
expression = Column(MediumText(), nullable=False)
extra = Column(Text)
export_fields = [
@ -479,6 +541,58 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
attr_dict.update(super().data)
return attr_dict
def to_sl_column(
self, known_columns: Optional[Dict[str, NewColumn]] = None
) -> NewColumn:
"""Convert a SqlMetric to NewColumn. Find and update existing or
create a new one."""
column = known_columns.get(self.uuid) if known_columns else None
if not column:
column = NewColumn()
extra_json = self.get_extra_dict()
for attr in {"verbose_name", "metric_type", "d3format"}:
value = getattr(self, attr)
if value is not None:
extra_json[attr] = value
is_additive = (
self.metric_type and self.metric_type.lower() in ADDITIVE_METRIC_TYPES_LOWER
)
column.uuid = self.uuid
column.name = self.metric_name
column.created_on = self.created_on
column.changed_on = self.changed_on
column.created_by = self.created_by
column.changed_by = self.changed_by
column.type = UNKOWN_TYPE
column.expression = self.expression
column.warning_text = self.warning_text
column.description = self.description
column.is_aggregation = True
column.is_additive = is_additive
column.is_filterable = False
column.is_increase_desired = True
column.is_managed_externally = self.table.is_managed_externally
column.is_partition = False
column.is_physical = False
column.is_spatial = False
column.extra_json = json.dumps(extra_json) if extra_json else None
column.external_url = self.table.external_url
return column
@staticmethod
def after_delete( # pylint: disable=unused-argument
mapper: Mapper,
connection: Connection,
target: "SqlMetric",
) -> None:
session = inspect(target).session
column = session.query(NewColumn).filter_by(uuid=target.uuid).one_or_none()
if column:
session.delete(column)
sqlatable_user = Table(
"sqlatable_user",
@ -544,7 +658,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
foreign_keys=[database_id],
)
schema = Column(String(255))
sql = Column(Text)
sql = Column(MediumText())
is_sqllab_view = Column(Boolean, default=False)
template_params = Column(Text)
extra = Column(Text)
@ -1731,7 +1845,19 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
metrics = []
any_date_col = None
db_engine_spec = self.db_engine_spec
old_columns = db.session.query(TableColumn).filter(TableColumn.table == self)
# If no `self.id`, then this is a new table, no need to fetch columns
# from db. Passing in `self.id` to query will actually automatically
# generate a new id, which can be tricky during certain transactions.
old_columns = (
(
db.session.query(TableColumn)
.filter(TableColumn.table_id == self.id)
.all()
)
if self.id
else self.columns
)
old_columns_by_name: Dict[str, TableColumn] = {
col.column_name: col for col in old_columns
@ -1745,13 +1871,15 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
)
# clear old columns before adding modified columns back
self.columns = []
columns = []
for col in new_columns:
old_column = old_columns_by_name.pop(col["name"], None)
if not old_column:
results.added.append(col["name"])
new_column = TableColumn(
column_name=col["name"], type=col["type"], table=self
column_name=col["name"],
type=col["type"],
table=self,
)
new_column.is_dttm = new_column.is_temporal
db_engine_spec.alter_new_orm_column(new_column)
@ -1763,12 +1891,14 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
new_column.expression = ""
new_column.groupby = True
new_column.filterable = True
self.columns.append(new_column)
columns.append(new_column)
if not any_date_col and new_column.is_temporal:
any_date_col = col["name"]
self.columns.extend(
[col for col in old_columns_by_name.values() if col.expression]
)
# add back calculated (virtual) columns
columns.extend([col for col in old_columns if col.expression])
self.columns = columns
metrics.append(
SqlMetric(
metric_name="count",
@ -1854,6 +1984,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
extra_cache_keys += sqla_query.extra_cache_keys
return extra_cache_keys
@property
def quote_identifier(self) -> Callable[[str], str]:
return self.database.quote_identifier
@staticmethod
def before_update(
mapper: Mapper, # pylint: disable=unused-argument
@ -1895,14 +2029,44 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
):
raise Exception(get_dataset_exist_error_msg(target.full_name))
def get_sl_columns(self) -> List[NewColumn]:
"""
Convert `SqlaTable.columns` and `SqlaTable.metrics` to the new Column model
"""
session: Session = inspect(self).session
uuids = set()
for column_or_metric in self.columns + self.metrics:
# pre-assign a uuid to newly added columns or metrics so the related
# `NewColumn` can be created with a deterministic uuid, too
if not column_or_metric.uuid:
column_or_metric.uuid = uuid4()
else:
uuids.add(column_or_metric.uuid)
# load existing columns from cached session states first
existing_columns = set(
find_cached_objects_in_session(session, NewColumn, uuids=uuids)
)
for column in existing_columns:
uuids.remove(column.uuid)
if uuids:
# load those not found from db
existing_columns |= set(
session.query(NewColumn).filter(NewColumn.uuid.in_(uuids))
)
known_columns = {column.uuid: column for column in existing_columns}
return [
item.to_sl_column(known_columns) for item in self.columns + self.metrics
]
@staticmethod
def update_table( # pylint: disable=unused-argument
mapper: Mapper, connection: Connection, target: Union[SqlMetric, TableColumn]
) -> None:
"""
Forces an update to the table's changed_on value when a metric or column on the
table is updated. This busts the cache key for all charts that use the table.
:param mapper: Unused.
:param connection: Unused.
:param target: The metric or column that was updated.
@ -1910,90 +2074,43 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
inspector = inspect(target)
session = inspector.session
# get DB-specific conditional quoter for expressions that point to columns or
# table names
database = (
target.table.database
or session.query(Database).filter_by(id=target.database_id).one()
)
engine = database.get_sqla_engine(schema=target.table.schema)
conditional_quote = engine.dialect.identifier_preparer.quote
# Forces an update to the table's changed_on value when a metric or column on the
# table is updated. This busts the cache key for all charts that use the table.
session.execute(update(SqlaTable).where(SqlaTable.id == target.table.id))
dataset = (
session.query(NewDataset)
.filter_by(sqlatable_id=target.table.id)
.one_or_none()
)
if not dataset:
# if dataset is not found create a new copy
# of the dataset instead of updating the existing
SqlaTable.write_shadow_dataset(target.table, database, session)
return
# update ``Column`` model as well
if isinstance(target, TableColumn):
columns = [
column
for column in dataset.columns
if column.name == target.column_name
]
if not columns:
# if the table itself has changed, shadow-writing will happen in `after_update` anyway
if target.table not in session.dirty:
dataset: NewDataset = (
session.query(NewDataset)
.filter_by(uuid=target.table.uuid)
.one_or_none()
)
# Update shadow dataset and columns
# did we find the dataset?
if not dataset:
# if dataset is not found create a new copy
target.table.write_shadow_dataset()
return
column = columns[0]
extra_json = json.loads(target.extra or "{}")
for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}:
value = getattr(target, attr)
if value:
extra_json[attr] = value
# update changed_on timestamp
session.execute(update(NewDataset).where(NewDataset.id == dataset.id))
column.name = target.column_name
column.type = target.type or "Unknown"
column.expression = target.expression or conditional_quote(
target.column_name
# update `Column` model as well
session.add(
target.to_sl_column(
{
target.uuid: session.query(NewColumn)
.filter_by(uuid=target.uuid)
.one_or_none()
}
)
)
column.description = target.description
column.is_temporal = target.is_dttm
column.is_physical = target.expression is None
column.extra_json = json.dumps(extra_json) if extra_json else None
else: # SqlMetric
columns = [
column
for column in dataset.columns
if column.name == target.metric_name
]
if not columns:
return
column = columns[0]
extra_json = json.loads(target.extra or "{}")
for attr in {"verbose_name", "metric_type", "d3format"}:
value = getattr(target, attr)
if value:
extra_json[attr] = value
is_additive = (
target.metric_type
and target.metric_type.lower() in ADDITIVE_METRIC_TYPES
)
column.name = target.metric_name
column.expression = target.expression
column.warning_text = target.warning_text
column.description = target.description
column.is_additive = is_additive
column.extra_json = json.dumps(extra_json) if extra_json else None
@staticmethod
def after_insert(
mapper: Mapper,
connection: Connection,
target: "SqlaTable",
sqla_table: "SqlaTable",
) -> None:
"""
Shadow write the dataset to new models.
@ -2007,24 +2124,14 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
For more context: https://github.com/apache/superset/issues/14909
"""
session = inspect(target).session
# set permissions
security_manager.set_perm(mapper, connection, target)
# get DB-specific conditional quoter for expressions that point to columns or
# table names
database = (
target.database
or session.query(Database).filter_by(id=target.database_id).one()
)
SqlaTable.write_shadow_dataset(target, database, session)
security_manager.set_perm(mapper, connection, sqla_table)
sqla_table.write_shadow_dataset()
@staticmethod
def after_delete( # pylint: disable=unused-argument
mapper: Mapper,
connection: Connection,
target: "SqlaTable",
sqla_table: "SqlaTable",
) -> None:
"""
Shadow write the dataset to new models.
@ -2038,18 +2145,18 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
For more context: https://github.com/apache/superset/issues/14909
"""
session = inspect(target).session
session = inspect(sqla_table).session
dataset = (
session.query(NewDataset).filter_by(sqlatable_id=target.id).one_or_none()
session.query(NewDataset).filter_by(uuid=sqla_table.uuid).one_or_none()
)
if dataset:
session.delete(dataset)
@staticmethod
def after_update( # pylint: disable=too-many-branches, too-many-locals, too-many-statements
def after_update(
mapper: Mapper,
connection: Connection,
target: "SqlaTable",
sqla_table: "SqlaTable",
) -> None:
"""
Shadow write the dataset to new models.
@ -2063,172 +2170,76 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
For more context: https://github.com/apache/superset/issues/14909
"""
inspector = inspect(target)
# set permissions
security_manager.set_perm(mapper, connection, sqla_table)
inspector = inspect(sqla_table)
session = inspector.session
# double-check that ``UPDATE``s are actually pending (this method is called even
# for instances that have no net changes to their column-based attributes)
if not session.is_modified(target, include_collections=True):
if not session.is_modified(sqla_table, include_collections=True):
return
# set permissions
security_manager.set_perm(mapper, connection, target)
dataset = (
session.query(NewDataset).filter_by(sqlatable_id=target.id).one_or_none()
# find the dataset from the known instance list first
# (it could be either from a previous query or newly created)
dataset = next(
find_cached_objects_in_session(
session, NewDataset, uuids=[sqla_table.uuid]
),
None,
)
# if not found, pull from database
if not dataset:
dataset = (
session.query(NewDataset).filter_by(uuid=sqla_table.uuid).one_or_none()
)
if not dataset:
sqla_table.write_shadow_dataset()
return
# get DB-specific conditional quoter for expressions that point to columns or
# table names
database = (
target.database
or session.query(Database).filter_by(id=target.database_id).one()
)
engine = database.get_sqla_engine(schema=target.schema)
conditional_quote = engine.dialect.identifier_preparer.quote
# update columns
if inspector.attrs.columns.history.has_changes():
# handle deleted columns
if inspector.attrs.columns.history.deleted:
column_names = {
column.column_name
for column in inspector.attrs.columns.history.deleted
}
dataset.columns = [
column
for column in dataset.columns
if column.name not in column_names
]
# handle inserted columns
for column in inspector.attrs.columns.history.added:
# ``is_active`` might be ``None``, but it defaults to ``True``.
if column.is_active is False:
continue
extra_json = json.loads(column.extra or "{}")
for attr in {
"groupby",
"filterable",
"verbose_name",
"python_date_format",
}:
value = getattr(column, attr)
if value:
extra_json[attr] = value
dataset.columns.append(
NewColumn(
name=column.column_name,
type=column.type or "Unknown",
expression=column.expression
or conditional_quote(column.column_name),
description=column.description,
is_temporal=column.is_dttm,
is_aggregation=False,
is_physical=column.expression is None,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
extra_json=json.dumps(extra_json) if extra_json else None,
is_managed_externally=target.is_managed_externally,
external_url=target.external_url,
)
)
# update metrics
if inspector.attrs.metrics.history.has_changes():
# handle deleted metrics
if inspector.attrs.metrics.history.deleted:
column_names = {
metric.metric_name
for metric in inspector.attrs.metrics.history.deleted
}
dataset.columns = [
column
for column in dataset.columns
if column.name not in column_names
]
# handle inserted metrics
for metric in inspector.attrs.metrics.history.added:
extra_json = json.loads(metric.extra or "{}")
for attr in {"verbose_name", "metric_type", "d3format"}:
value = getattr(metric, attr)
if value:
extra_json[attr] = value
is_additive = (
metric.metric_type
and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES
)
dataset.columns.append(
NewColumn(
name=metric.metric_name,
type="Unknown",
expression=metric.expression,
warning_text=metric.warning_text,
description=metric.description,
is_aggregation=True,
is_additive=is_additive,
is_physical=False,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
extra_json=json.dumps(extra_json) if extra_json else None,
is_managed_externally=target.is_managed_externally,
external_url=target.external_url,
)
)
# sync column list and delete removed columns
if (
inspector.attrs.columns.history.has_changes()
or inspector.attrs.metrics.history.has_changes()
):
# add pending new columns to the known columns list, too, so that calling
# `after_update` twice before changes are persisted will not create
# duplicate columns with the same uuids.
dataset.columns = sqla_table.get_sl_columns()
# physical dataset
if target.sql is None:
physical_columns = [
column for column in dataset.columns if column.is_physical
]
# if the table name changed we should create a new table instance, instead
# of reusing the original one
if not sqla_table.sql:
# if the table name changed we should relink the dataset to another table
# (and create one if necessary)
if (
inspector.attrs.table_name.history.has_changes()
or inspector.attrs.schema.history.has_changes()
or inspector.attrs.database_id.history.has_changes()
or inspector.attrs.database.history.has_changes()
):
# does the dataset point to an existing table?
table = (
session.query(NewTable)
.filter_by(
database_id=target.database_id,
schema=target.schema,
name=target.table_name,
)
.first()
tables = NewTable.bulk_load_or_create(
sqla_table.database,
[TableName(schema=sqla_table.schema, table=sqla_table.table_name)],
sync_columns=False,
default_props=dict(
changed_by=sqla_table.changed_by,
created_by=sqla_table.created_by,
is_managed_externally=sqla_table.is_managed_externally,
external_url=sqla_table.external_url,
),
)
if not table:
# create new columns
if not tables[0].id:
# dataset columns will only be assigned to newly created tables
# existing tables should manage column syncing in another process
physical_columns = [
clone_model(column, ignore=["uuid"])
for column in physical_columns
clone_model(
column, ignore=["uuid"], keep_relations=["changed_by"]
)
for column in dataset.columns
if column.is_physical
]
# create new table
table = NewTable(
name=target.table_name,
schema=target.schema,
catalog=None,
database_id=target.database_id,
columns=physical_columns,
is_managed_externally=target.is_managed_externally,
external_url=target.external_url,
)
dataset.tables = [table]
elif dataset.tables:
table = dataset.tables[0]
table.columns = physical_columns
tables[0].columns = physical_columns
dataset.tables = tables
# virtual dataset
else:
@ -2237,29 +2248,34 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
column.is_physical = False
# update referenced tables if SQL changed
if inspector.attrs.sql.history.has_changes():
parsed = ParsedQuery(target.sql)
referenced_tables = parsed.tables
predicate = or_(
*[
and_(
NewTable.schema == (table.schema or target.schema),
NewTable.name == table.table,
)
for table in referenced_tables
]
if sqla_table.sql and inspector.attrs.sql.history.has_changes():
referenced_tables = extract_table_references(
sqla_table.sql, sqla_table.database.get_dialect().name
)
dataset.tables = NewTable.bulk_load_or_create(
sqla_table.database,
referenced_tables,
default_schema=sqla_table.schema,
# sync metadata is expensive, we'll do it in another process
# e.g. when users open a Table page
sync_columns=False,
default_props=dict(
changed_by=sqla_table.changed_by,
created_by=sqla_table.created_by,
is_managed_externally=sqla_table.is_managed_externally,
external_url=sqla_table.external_url,
),
)
dataset.tables = session.query(NewTable).filter(predicate).all()
# update other attributes
dataset.name = target.table_name
dataset.expression = target.sql or conditional_quote(target.table_name)
dataset.is_physical = target.sql is None
dataset.name = sqla_table.table_name
dataset.expression = sqla_table.sql or sqla_table.quote_identifier(
sqla_table.table_name
)
dataset.is_physical = not sqla_table.sql
@staticmethod
def write_shadow_dataset( # pylint: disable=too-many-locals
dataset: "SqlaTable", database: Database, session: Session
def write_shadow_dataset(
self: "SqlaTable",
) -> None:
"""
Shadow write the dataset to new models.
@ -2273,95 +2289,57 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
For more context: https://github.com/apache/superset/issues/14909
"""
engine = database.get_sqla_engine(schema=dataset.schema)
conditional_quote = engine.dialect.identifier_preparer.quote
session = inspect(self).session
# make sure database points to the right instance, in case only
# `table.database_id` is updated and the changes haven't been
# consolidated by SQLA
if self.database_id and (
not self.database or self.database.id != self.database_id
):
self.database = session.query(Database).filter_by(id=self.database_id).one()
# create columns
columns = []
for column in dataset.columns:
# ``is_active`` might be ``None`` at this point, but it defaults to ``True``.
if column.is_active is False:
continue
try:
extra_json = json.loads(column.extra or "{}")
except json.decoder.JSONDecodeError:
extra_json = {}
for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}:
value = getattr(column, attr)
if value:
extra_json[attr] = value
columns.append(
NewColumn(
name=column.column_name,
type=column.type or "Unknown",
expression=column.expression
or conditional_quote(column.column_name),
description=column.description,
is_temporal=column.is_dttm,
is_aggregation=False,
is_physical=column.expression is None,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
extra_json=json.dumps(extra_json) if extra_json else None,
is_managed_externally=dataset.is_managed_externally,
external_url=dataset.external_url,
),
)
# create metrics
for metric in dataset.metrics:
try:
extra_json = json.loads(metric.extra or "{}")
except json.decoder.JSONDecodeError:
extra_json = {}
for attr in {"verbose_name", "metric_type", "d3format"}:
value = getattr(metric, attr)
if value:
extra_json[attr] = value
is_additive = (
metric.metric_type
and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES
)
columns.append(
NewColumn(
name=metric.metric_name,
type="Unknown", # figuring this out would require a type inferrer
expression=metric.expression,
warning_text=metric.warning_text,
description=metric.description,
is_aggregation=True,
is_additive=is_additive,
is_physical=False,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
extra_json=json.dumps(extra_json) if extra_json else None,
is_managed_externally=dataset.is_managed_externally,
external_url=dataset.external_url,
),
)
for item in self.columns + self.metrics:
item.created_by = self.created_by
item.changed_by = self.changed_by
# on the `SqlaTable.after_insert` event, although the table itself
# already has a `uuid`, the associated columns will not.
# Here we pre-assign a uuid so they can still be matched to the new
# Column after creation.
if not item.uuid:
item.uuid = uuid4()
columns.append(item.to_sl_column())
# physical dataset
if not dataset.sql:
physical_columns = [column for column in columns if column.is_physical]
# create table
table = NewTable(
name=dataset.table_name,
schema=dataset.schema,
catalog=None, # currently not supported
database_id=dataset.database_id,
columns=physical_columns,
is_managed_externally=dataset.is_managed_externally,
external_url=dataset.external_url,
if not self.sql:
# always create separate column entries for Dataset and Table
# so updating a dataset would not update columns in the related table
physical_columns = [
clone_model(
column,
ignore=["uuid"],
# `created_by` will always be left empty because it'd always
# be created via some sort of automated system.
# But keep `changed_by` in case someone manually changes
# column attributes such as `is_dttm`.
keep_relations=["changed_by"],
)
for column in columns
if column.is_physical
]
tables = NewTable.bulk_load_or_create(
self.database,
[TableName(schema=self.schema, table=self.table_name)],
sync_columns=False,
default_props=dict(
created_by=self.created_by,
changed_by=self.changed_by,
is_managed_externally=self.is_managed_externally,
external_url=self.external_url,
),
)
tables = [table]
tables[0].columns = physical_columns
# virtual dataset
else:
@ -2370,26 +2348,39 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
column.is_physical = False
# find referenced tables
parsed = ParsedQuery(dataset.sql)
referenced_tables = parsed.tables
tables = load_or_create_tables(
session,
database,
dataset.schema,
referenced_tables = extract_table_references(
self.sql, self.database.get_dialect().name
)
tables = NewTable.bulk_load_or_create(
self.database,
referenced_tables,
conditional_quote,
default_schema=self.schema,
# syncing table columns can be slow so we are not doing it here
sync_columns=False,
default_props=dict(
created_by=self.created_by,
changed_by=self.changed_by,
is_managed_externally=self.is_managed_externally,
external_url=self.external_url,
),
)
# create the new dataset
new_dataset = NewDataset(
sqlatable_id=dataset.id,
name=dataset.table_name,
expression=dataset.sql or conditional_quote(dataset.table_name),
uuid=self.uuid,
database_id=self.database_id,
created_on=self.created_on,
created_by=self.created_by,
changed_by=self.changed_by,
changed_on=self.changed_on,
owners=self.owners,
name=self.table_name,
expression=self.sql or self.quote_identifier(self.table_name),
tables=tables,
columns=columns,
is_physical=not dataset.sql,
is_managed_externally=dataset.is_managed_externally,
external_url=dataset.external_url,
is_physical=not self.sql,
is_managed_externally=self.is_managed_externally,
external_url=self.external_url,
)
session.add(new_dataset)
@ -2399,7 +2390,9 @@ sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert)
sa.event.listen(SqlaTable, "after_delete", SqlaTable.after_delete)
sa.event.listen(SqlaTable, "after_update", SqlaTable.after_update)
sa.event.listen(SqlMetric, "after_update", SqlaTable.update_table)
sa.event.listen(SqlMetric, "after_delete", SqlMetric.after_delete)
sa.event.listen(TableColumn, "after_update", SqlaTable.update_table)
sa.event.listen(TableColumn, "after_delete", TableColumn.after_delete)
RLSFilterRoles = Table(
"rls_filter_roles",

View File

@ -15,16 +15,28 @@
# specific language governing permissions and limitations
# under the License.
from contextlib import closing
from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Type,
TYPE_CHECKING,
TypeVar,
)
from uuid import UUID
import sqlparse
from flask_babel import lazy_gettext as _
from sqlalchemy import and_, or_
from sqlalchemy.engine.url import URL as SqlaURL
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.orm import Session
from sqlalchemy.sql.type_api import TypeEngine
from superset.columns.models import Column as NewColumn
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
SupersetGenericDBErrorException,
@ -32,9 +44,9 @@ from superset.exceptions import (
)
from superset.models.core import Database
from superset.result_set import SupersetResultSet
from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, Table
from superset.sql_parse import has_table_query, insert_rls, ParsedQuery
from superset.superset_typing import ResultSetColumnType
from superset.tables.models import Table as NewTable
from superset.utils.memoized import memoized
if TYPE_CHECKING:
from superset.connectors.sqla.models import SqlaTable
@ -168,75 +180,38 @@ def validate_adhoc_subquery(
return ";\n".join(str(statement) for statement in statements)
def load_or_create_tables( # pylint: disable=too-many-arguments
@memoized
def get_dialect_name(drivername: str) -> str:
return SqlaURL(drivername).get_dialect().name
@memoized
def get_identifier_quoter(drivername: str) -> Callable[[str], str]:
return SqlaURL(drivername).get_dialect()().identifier_preparer.quote
DeclarativeModel = TypeVar("DeclarativeModel", bound=DeclarativeMeta)
def find_cached_objects_in_session(
session: Session,
database: Database,
default_schema: Optional[str],
tables: Set[Table],
conditional_quote: Callable[[str], str],
) -> List[NewTable]:
"""
Load or create new table model instances.
"""
if not tables:
return []
cls: Type[DeclarativeModel],
ids: Optional[Iterable[int]] = None,
uuids: Optional[Iterable[UUID]] = None,
) -> Iterator[DeclarativeModel]:
"""Find known ORM instances in cached SQLA session states.
# set the default schema in tables that don't have it
if default_schema:
fixed_tables = list(tables)
for i, table in enumerate(fixed_tables):
if table.schema is None:
fixed_tables[i] = Table(table.table, default_schema, table.catalog)
tables = set(fixed_tables)
# load existing tables
predicate = or_(
*[
and_(
NewTable.database_id == database.id,
NewTable.schema == table.schema,
NewTable.name == table.table,
)
for table in tables
]
:param session: a SQLA session
:param cls: a SQLA DeclarativeModel
:param ids: ids of the desired model instances (optional)
:param uuids: uuids of the desired instances; ignored if `ids` is provided
"""
if not ids and not uuids:
return iter([])
uuids = uuids or []
return (
item
# `session` is an iterator of all known items
for item in set(session)
if isinstance(item, cls) and (item.id in ids if ids else item.uuid in uuids)
)
new_tables = session.query(NewTable).filter(predicate).all()
# add missing tables
existing = {(table.schema, table.name) for table in new_tables}
for table in tables:
if (table.schema, table.table) not in existing:
try:
column_metadata = get_physical_table_metadata(
database=database,
table_name=table.table,
schema_name=table.schema,
)
except Exception: # pylint: disable=broad-except
continue
columns = [
NewColumn(
name=column["name"],
type=str(column["type"]),
expression=conditional_quote(column["name"]),
is_temporal=column["is_dttm"],
is_aggregation=False,
is_physical=True,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
)
for column in column_metadata
]
new_tables.append(
NewTable(
name=table.table,
schema=table.schema,
catalog=None,
database_id=database.id,
columns=columns,
)
)
existing.add((table.schema, table.table))
return new_tables
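A usage sketch of the new `find_cached_objects_in_session` helper, mirroring how the shadow-write code in `SqlaTable.get_sl_columns` uses it (assuming `session` is an active SQLA session and `uuids` is a set of column uuids to resolve):

# check the session's cached/pending instances first to avoid extra queries
cached = set(find_cached_objects_in_session(session, NewColumn, uuids=uuids))
remaining = uuids - {column.uuid for column in cached}
if remaining:
    # anything not already tracked by the session is loaded from the database
    cached |= set(session.query(NewColumn).filter(NewColumn.uuid.in_(remaining)))
known_columns = {column.uuid: column for column in cached}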

View File

@ -28,9 +28,11 @@ from typing import List
import sqlalchemy as sa
from flask_appbuilder import Model
from sqlalchemy.orm import relationship
from sqlalchemy.orm import backref, relationship
from superset import security_manager
from superset.columns.models import Column
from superset.models.core import Database
from superset.models.helpers import (
AuditMixinNullable,
ExtraJSONMixin,
@ -38,18 +40,33 @@ from superset.models.helpers import (
)
from superset.tables.models import Table
column_association_table = sa.Table(
dataset_column_association_table = sa.Table(
"sl_dataset_columns",
Model.metadata, # pylint: disable=no-member
sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id")),
sa.Column("column_id", sa.ForeignKey("sl_columns.id")),
sa.Column(
"dataset_id",
sa.ForeignKey("sl_datasets.id"),
primary_key=True,
),
sa.Column(
"column_id",
sa.ForeignKey("sl_columns.id"),
primary_key=True,
),
)
table_association_table = sa.Table(
dataset_table_association_table = sa.Table(
"sl_dataset_tables",
Model.metadata, # pylint: disable=no-member
sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id")),
sa.Column("table_id", sa.ForeignKey("sl_tables.id")),
sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id"), primary_key=True),
sa.Column("table_id", sa.ForeignKey("sl_tables.id"), primary_key=True),
)
dataset_user_association_table = sa.Table(
"sl_dataset_users",
Model.metadata, # pylint: disable=no-member
sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id"), primary_key=True),
sa.Column("user_id", sa.ForeignKey("ab_user.id"), primary_key=True),
)
@ -61,27 +78,27 @@ class Dataset(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
__tablename__ = "sl_datasets"
id = sa.Column(sa.Integer, primary_key=True)
# A temporary column, used for shadow writing to the new model. Once the ``SqlaTable``
# model has been deleted this column can be removed.
sqlatable_id = sa.Column(sa.Integer, nullable=True, unique=True)
# We use ``sa.Text`` for these attributes because (1) in modern databases the
# performance is the same as ``VARCHAR``[1] and (2) because some table names can be
# **really** long (eg, Google Sheets URLs).
#
# [1] https://www.postgresql.org/docs/9.1/datatype-character.html
name = sa.Column(sa.Text)
expression = sa.Column(sa.Text)
# n:n relationship
tables: List[Table] = relationship("Table", secondary=table_association_table)
# The relationship between datasets and columns is 1:n, but we use a many-to-many
# association to differentiate between the relationship between tables and columns.
database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False)
database: Database = relationship(
"Database",
backref=backref("datasets", cascade="all, delete-orphan"),
foreign_keys=[database_id],
)
# The relationship between datasets and columns is 1:n, but we use a
# many-to-many association table to avoid adding two mutually exclusive
# columns(dataset_id and table_id) to Column
columns: List[Column] = relationship(
"Column", secondary=column_association_table, cascade="all, delete"
"Column",
secondary=dataset_column_association_table,
cascade="all, delete-orphan",
single_parent=True,
backref="datasets",
)
owners = relationship(
security_manager.user_model, secondary=dataset_user_association_table
)
tables: List[Table] = relationship(
"Table", secondary=dataset_table_association_table, backref="datasets"
)
# Does the dataset point directly to a ``Table``?
@ -89,4 +106,15 @@ class Dataset(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
# Column is managed externally and should be read-only inside Superset
is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)
# We use ``sa.Text`` for these attributes because (1) in modern databases the
# performance is the same as ``VARCHAR``[1] and (2) because some table names can be
# **really** long (eg, Google Sheets URLs).
#
# [1] https://www.postgresql.org/docs/9.1/datatype-character.html
name = sa.Column(sa.Text)
expression = sa.Column(sa.Text)
external_url = sa.Column(sa.Text, nullable=True)
def __repr__(self) -> str:
return f"<Dataset id={self.id} database_id={self.database_id} {self.name}>"

View File

@ -135,23 +135,26 @@ def _set_table_metadata(datasource: SqlaTable, database: "Database") -> None:
def _add_table_metrics(datasource: SqlaTable) -> None:
if not any(col.column_name == "num_california" for col in datasource.columns):
# By accessing the attribute first, we make sure `datasource.columns` and
# `datasource.metrics` are already loaded. Otherwise accessing them later
# may trigger an unnecessary and unexpected `after_update` event.
columns, metrics = datasource.columns, datasource.metrics
if not any(col.column_name == "num_california" for col in columns):
col_state = str(column("state").compile(db.engine))
col_num = str(column("num").compile(db.engine))
datasource.columns.append(
columns.append(
TableColumn(
column_name="num_california",
expression=f"CASE WHEN {col_state} = 'CA' THEN {col_num} ELSE 0 END",
)
)
if not any(col.metric_name == "sum__num" for col in datasource.metrics):
if not any(col.metric_name == "sum__num" for col in metrics):
col = str(column("num").compile(db.engine))
datasource.metrics.append(
SqlMetric(metric_name="sum__num", expression=f"SUM({col})")
)
metrics.append(SqlMetric(metric_name="sum__num", expression=f"SUM({col})"))
for col in datasource.columns:
for col in columns:
if col.column_name == "ds":
col.is_dttm = True
break

View File

@ -15,42 +15,22 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Iterator, Optional, Set
import os
import time
from typing import Any
from uuid import uuid4
from alembic import op
from sqlalchemy import engine_from_config
from sqlalchemy.dialects.mysql.base import MySQLDialect
from sqlalchemy.dialects.postgresql.base import PGDialect
from sqlalchemy.engine import reflection
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.orm import Session
try:
from sqloxide import parse_sql
except ImportError:
parse_sql = None
logger = logging.getLogger(__name__)
from superset.sql_parse import ParsedQuery, Table
logger = logging.getLogger("alembic")
# mapping between sqloxide and SQLAlchemy dialects
sqloxide_dialects = {
"ansi": {"trino", "trinonative", "presto"},
"hive": {"hive", "databricks"},
"ms": {"mssql"},
"mysql": {"mysql"},
"postgres": {
"cockroachdb",
"hana",
"netezza",
"postgres",
"postgresql",
"redshift",
"vertica",
},
"snowflake": {"snowflake"},
"sqlite": {"sqlite", "gsheets", "shillelagh"},
"clickhouse": {"clickhouse"},
}
DEFAULT_BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 1000))
def table_has_column(table: str, column: str) -> bool:
@ -61,7 +41,6 @@ def table_has_column(table: str, column: str) -> bool:
:param column: A column name
:returns: True iff the column exists in the table
"""
config = op.get_context().config
engine = engine_from_config(
config.get_section(config.config_ini_section), prefix="sqlalchemy."
@ -73,42 +52,44 @@ def table_has_column(table: str, column: str) -> bool:
return False
def find_nodes_by_key(element: Any, target: str) -> Iterator[Any]:
"""
Find all nodes in a SQL tree matching a given key.
"""
if isinstance(element, list):
for child in element:
yield from find_nodes_by_key(child, target)
elif isinstance(element, dict):
for key, value in element.items():
if key == target:
yield value
else:
yield from find_nodes_by_key(value, target)
uuid_by_dialect = {
MySQLDialect: "UNHEX(REPLACE(CONVERT(UUID() using utf8mb4), '-', ''))",
PGDialect: "uuid_in(md5(random()::text || clock_timestamp()::text)::cstring)",
}
def extract_table_references(sql_text: str, sqla_dialect: str) -> Set[Table]:
"""
Return all the dependencies from a SQL sql_text.
"""
if not parse_sql:
parsed = ParsedQuery(sql_text)
return parsed.tables
def assign_uuids(
model: Any, session: Session, batch_size: int = DEFAULT_BATCH_SIZE
) -> None:
"""Generate new UUIDs for all rows in a table"""
bind = op.get_bind()
table_name = model.__tablename__
count = session.query(model).count()
# silently skip if the table is empty (suitable for db initialization)
if count == 0:
return
dialect = "generic"
for dialect, sqla_dialects in sqloxide_dialects.items():
if sqla_dialect in sqla_dialects:
break
try:
tree = parse_sql(sql_text, dialect=dialect)
except Exception: # pylint: disable=broad-except
logger.warning("Unable to parse query with sqloxide: %s", sql_text)
# fallback to sqlparse
parsed = ParsedQuery(sql_text)
return parsed.tables
start_time = time.time()
print(f"\nAdding uuids for `{table_name}`...")
# Use dialect specific native SQL queries if possible
for dialect, sql in uuid_by_dialect.items():
if isinstance(bind.dialect, dialect):
op.execute(
f"UPDATE {dialect().identifier_preparer.quote(table_name)} SET uuid = {sql}"
)
print(f"Done. Assigned {count} uuids in {time.time() - start_time:.3f}s.\n")
return
return {
Table(*[part["value"] for part in table["name"][::-1]])
for table in find_nodes_by_key(tree, "Table")
}
# Otherwise use the Python uuid function
start = 0
while start < count:
end = min(start + batch_size, count)
for obj in session.query(model)[start:end]:
obj.uuid = uuid4()
session.merge(obj)
session.commit()
if start + batch_size < count:
print(f" uuid assigned to {end} out of {count}\r", end="")
start += batch_size
print(f"Done. Assigned {count} uuids in {time.time() - start_time:.3f}s.\n")

View File

@ -32,9 +32,7 @@ from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import UUIDType
from superset import db
from superset.migrations.versions.b56500de1855_add_uuid_column_to_import_mixin import (
add_uuids,
)
from superset.migrations.shared.utils import assign_uuids
# revision identifiers, used by Alembic.
revision = "96e99fb176a0"
@ -75,7 +73,7 @@ def upgrade():
# Ignore column update errors so that we can run upgrade multiple times
pass
add_uuids(SavedQuery, "saved_query", session)
assign_uuids(SavedQuery, session)
try:
# Add uniqueness constraint

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""empty message
"""merge point
Revision ID: 9d8a8d575284
Revises: ('8b841273bec3', 'b0d0249074e4')

View File

@ -0,0 +1,905 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""new_dataset_models_take_2
Revision ID: a9422eeaae74
Revises: ad07e4fdbaba
Create Date: 2022-04-01 14:38:09.499483
"""
# revision identifiers, used by Alembic.
revision = "a9422eeaae74"
down_revision = "ad07e4fdbaba"
import json
import os
from datetime import datetime
from typing import List, Optional, Set, Type, Union
from uuid import uuid4
import sqlalchemy as sa
from alembic import op
from sqlalchemy import select
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import backref, relationship, Session
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import functions as func
from sqlalchemy.sql.expression import and_, or_
from sqlalchemy_utils import UUIDType
from superset import app, db
from superset.connectors.sqla.models import ADDITIVE_METRIC_TYPES_LOWER
from superset.connectors.sqla.utils import get_dialect_name, get_identifier_quoter
from superset.extensions import encrypted_field_factory
from superset.migrations.shared.utils import assign_uuids
from superset.sql_parse import extract_table_references, Table
from superset.utils.core import MediumText
Base = declarative_base()
custom_password_store = app.config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
DB_CONNECTION_MUTATOR = app.config["DB_CONNECTION_MUTATOR"]
SHOW_PROGRESS = os.environ.get("SHOW_PROGRESS") == "1"
UNKNOWN_TYPE = "UNKNOWN"
user_table = sa.Table(
"ab_user", Base.metadata, sa.Column("id", sa.Integer(), primary_key=True)
)
class UUIDMixin:
uuid = sa.Column(
UUIDType(binary=True), primary_key=False, unique=True, default=uuid4
)
class AuxiliaryColumnsMixin(UUIDMixin):
"""
Auxiliary columns, a combination of columns added by
AuditMixinNullable + ImportExportMixin
"""
created_on = sa.Column(sa.DateTime, default=datetime.now, nullable=True)
changed_on = sa.Column(
sa.DateTime, default=datetime.now, onupdate=datetime.now, nullable=True
)
@declared_attr
def created_by_fk(cls):
return sa.Column(sa.Integer, sa.ForeignKey("ab_user.id"), nullable=True)
@declared_attr
def changed_by_fk(cls):
return sa.Column(sa.Integer, sa.ForeignKey("ab_user.id"), nullable=True)
def insert_from_select(
target: Union[str, sa.Table, Type[Base]], source: sa.sql.expression.Select
) -> None:
"""
Execute INSERT FROM SELECT to copy data from a SELECT query to the target table.
"""
if isinstance(target, sa.Table):
target_table = target
elif hasattr(target, "__tablename__"):
target_table: sa.Table = Base.metadata.tables[target.__tablename__]
else:
target_table: sa.Table = Base.metadata.tables[target]
cols = [col.name for col in source.columns if col.name in target_table.columns]
query = target_table.insert().from_select(cols, source)
return op.execute(query)
class Database(Base):
__tablename__ = "dbs"
__table_args__ = (UniqueConstraint("database_name"),)
id = sa.Column(sa.Integer, primary_key=True)
database_name = sa.Column(sa.String(250), unique=True, nullable=False)
sqlalchemy_uri = sa.Column(sa.String(1024), nullable=False)
password = sa.Column(encrypted_field_factory.create(sa.String(1024)))
impersonate_user = sa.Column(sa.Boolean, default=False)
encrypted_extra = sa.Column(encrypted_field_factory.create(sa.Text), nullable=True)
extra = sa.Column(sa.Text)
server_cert = sa.Column(encrypted_field_factory.create(sa.Text), nullable=True)
class TableColumn(AuxiliaryColumnsMixin, Base):
__tablename__ = "table_columns"
__table_args__ = (UniqueConstraint("table_id", "column_name"),)
id = sa.Column(sa.Integer, primary_key=True)
table_id = sa.Column(sa.Integer, sa.ForeignKey("tables.id"))
is_active = sa.Column(sa.Boolean, default=True)
extra = sa.Column(sa.Text)
column_name = sa.Column(sa.String(255), nullable=False)
type = sa.Column(sa.String(32))
expression = sa.Column(MediumText())
description = sa.Column(MediumText())
is_dttm = sa.Column(sa.Boolean, default=False)
filterable = sa.Column(sa.Boolean, default=True)
groupby = sa.Column(sa.Boolean, default=True)
verbose_name = sa.Column(sa.String(1024))
python_date_format = sa.Column(sa.String(255))
class SqlMetric(AuxiliaryColumnsMixin, Base):
__tablename__ = "sql_metrics"
__table_args__ = (UniqueConstraint("table_id", "metric_name"),)
id = sa.Column(sa.Integer, primary_key=True)
table_id = sa.Column(sa.Integer, sa.ForeignKey("tables.id"))
extra = sa.Column(sa.Text)
metric_type = sa.Column(sa.String(32))
metric_name = sa.Column(sa.String(255), nullable=False)
expression = sa.Column(MediumText(), nullable=False)
warning_text = sa.Column(MediumText())
description = sa.Column(MediumText())
d3format = sa.Column(sa.String(128))
verbose_name = sa.Column(sa.String(1024))
sqlatable_user_table = sa.Table(
"sqlatable_user",
Base.metadata,
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("user_id", sa.Integer, sa.ForeignKey("ab_user.id")),
sa.Column("table_id", sa.Integer, sa.ForeignKey("tables.id")),
)
class SqlaTable(AuxiliaryColumnsMixin, Base):
__tablename__ = "tables"
__table_args__ = (UniqueConstraint("database_id", "schema", "table_name"),)
id = sa.Column(sa.Integer, primary_key=True)
extra = sa.Column(sa.Text)
database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False)
database: Database = relationship(
"Database",
backref=backref("tables", cascade="all, delete-orphan"),
foreign_keys=[database_id],
)
schema = sa.Column(sa.String(255))
table_name = sa.Column(sa.String(250), nullable=False)
sql = sa.Column(MediumText())
is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)
external_url = sa.Column(sa.Text, nullable=True)
table_column_association_table = sa.Table(
"sl_table_columns",
Base.metadata,
sa.Column("table_id", sa.ForeignKey("sl_tables.id"), primary_key=True),
sa.Column("column_id", sa.ForeignKey("sl_columns.id"), primary_key=True),
)
dataset_column_association_table = sa.Table(
"sl_dataset_columns",
Base.metadata,
sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id"), primary_key=True),
sa.Column("column_id", sa.ForeignKey("sl_columns.id"), primary_key=True),
)
dataset_table_association_table = sa.Table(
"sl_dataset_tables",
Base.metadata,
sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id"), primary_key=True),
sa.Column("table_id", sa.ForeignKey("sl_tables.id"), primary_key=True),
)
dataset_user_association_table = sa.Table(
"sl_dataset_users",
Base.metadata,
sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id"), primary_key=True),
sa.Column("user_id", sa.ForeignKey("ab_user.id"), primary_key=True),
)
class NewColumn(AuxiliaryColumnsMixin, Base):
__tablename__ = "sl_columns"
id = sa.Column(sa.Integer, primary_key=True)
# A temporary column to link physical columns with tables so we don't
# have to insert a record in the relationship table while creating new columns.
table_id = sa.Column(sa.Integer, nullable=True)
is_aggregation = sa.Column(sa.Boolean, nullable=False, default=False)
is_additive = sa.Column(sa.Boolean, nullable=False, default=False)
is_dimensional = sa.Column(sa.Boolean, nullable=False, default=False)
is_filterable = sa.Column(sa.Boolean, nullable=False, default=True)
is_increase_desired = sa.Column(sa.Boolean, nullable=False, default=True)
is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)
is_partition = sa.Column(sa.Boolean, nullable=False, default=False)
is_physical = sa.Column(sa.Boolean, nullable=False, default=False)
is_temporal = sa.Column(sa.Boolean, nullable=False, default=False)
is_spatial = sa.Column(sa.Boolean, nullable=False, default=False)
name = sa.Column(sa.Text)
type = sa.Column(sa.Text)
unit = sa.Column(sa.Text)
expression = sa.Column(MediumText())
description = sa.Column(MediumText())
warning_text = sa.Column(MediumText())
external_url = sa.Column(sa.Text, nullable=True)
extra_json = sa.Column(MediumText(), default="{}")
class NewTable(AuxiliaryColumnsMixin, Base):
__tablename__ = "sl_tables"
id = sa.Column(sa.Integer, primary_key=True)
# A temporary column to keep the link between NewTable and SqlaTable
sqlatable_id = sa.Column(sa.Integer, primary_key=False, nullable=True, unique=True)
database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False)
is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)
catalog = sa.Column(sa.Text)
schema = sa.Column(sa.Text)
name = sa.Column(sa.Text)
external_url = sa.Column(sa.Text, nullable=True)
extra_json = sa.Column(MediumText(), default="{}")
database: Database = relationship(
"Database",
backref=backref("new_tables", cascade="all, delete-orphan"),
foreign_keys=[database_id],
)
class NewDataset(Base, AuxiliaryColumnsMixin):
__tablename__ = "sl_datasets"
id = sa.Column(sa.Integer, primary_key=True)
database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False)
is_physical = sa.Column(sa.Boolean, default=False)
is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)
name = sa.Column(sa.Text)
expression = sa.Column(MediumText())
external_url = sa.Column(sa.Text, nullable=True)
extra_json = sa.Column(MediumText(), default="{}")
def find_tables(
session: Session,
database_id: int,
default_schema: Optional[str],
tables: Set[Table],
) -> List[int]:
"""
Look up NewTable records from a specific database
"""
if not tables:
return []
predicate = or_(
*[
and_(
NewTable.database_id == database_id,
NewTable.schema == (table.schema or default_schema),
NewTable.name == table.table,
)
for table in tables
]
)
return session.query(NewTable.id).filter(predicate).all()
# helper SQLA elements for easier querying
is_physical_table = or_(SqlaTable.sql.is_(None), SqlaTable.sql == "")
is_physical_column = or_(TableColumn.expression.is_(None), TableColumn.expression == "")
# keep only table columns that have a valid associated SqlaTable
active_table_columns = sa.join(
TableColumn,
SqlaTable,
TableColumn.table_id == SqlaTable.id,
)
active_metrics = sa.join(SqlMetric, SqlaTable, SqlMetric.table_id == SqlaTable.id)
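# NOTE (illustration only, not part of this diff): ``insert_from_select`` is the
# helper this migration leans on; its definition is not shown in this excerpt.
# It takes a target model or table plus a SELECT and presumably issues a single
# ``INSERT INTO ... SELECT ...`` statement, avoiding per-row ORM inserts.
# A minimal sketch of what such a helper could look like, assuming it simply
# wraps SQLAlchemy's ``Insert.from_select`` (the sketch name is hypothetical):
def _insert_from_select_sketch(target, select_stmt) -> None:
    # resolve ORM classes (e.g. NewTable) to their underlying ``sa.Table``
    table = target if isinstance(target, sa.Table) else target.__table__
    # column keys follow the labels used in the SELECT, e.g. ``sqlatable_id``
    column_names = [col.name for col in select_stmt.inner_columns]
    op.execute(table.insert().from_select(column_names, select_stmt))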
def copy_tables(session: Session) -> None:
"""Copy Physical tables"""
count = session.query(SqlaTable).filter(is_physical_table).count()
if not count:
return
print(f">> Copy {count:,} physical tables to sl_tables...")
insert_from_select(
NewTable,
select(
[
# Tables need a different uuid than datasets, since they are different
# entities. When using INSERT FROM SELECT, we must provide a value for `uuid`,
# otherwise it would use the default generated on the Python side, which
# would cause duplicate values. These uuids are replaced by `assign_uuids` later.
SqlaTable.uuid,
SqlaTable.id.label("sqlatable_id"),
SqlaTable.created_on,
SqlaTable.changed_on,
SqlaTable.created_by_fk,
SqlaTable.changed_by_fk,
SqlaTable.table_name.label("name"),
SqlaTable.schema,
SqlaTable.database_id,
SqlaTable.is_managed_externally,
SqlaTable.external_url,
]
)
# use an inner join to filter out only tables with valid database ids
.select_from(
sa.join(SqlaTable, Database, SqlaTable.database_id == Database.id)
).where(is_physical_table),
)
def copy_datasets(session: Session) -> None:
"""Copy all datasets"""
count = session.query(SqlaTable).count()
if not count:
return
print(f">> Copy {count:,} SqlaTable to sl_datasets...")
insert_from_select(
NewDataset,
select(
[
SqlaTable.uuid,
SqlaTable.created_on,
SqlaTable.changed_on,
SqlaTable.created_by_fk,
SqlaTable.changed_by_fk,
SqlaTable.database_id,
SqlaTable.table_name.label("name"),
func.coalesce(SqlaTable.sql, SqlaTable.table_name).label("expression"),
is_physical_table.label("is_physical"),
SqlaTable.is_managed_externally,
SqlaTable.external_url,
SqlaTable.extra.label("extra_json"),
]
),
)
print(" Copy dataset owners...")
insert_from_select(
dataset_user_association_table,
select(
[NewDataset.id.label("dataset_id"), sqlatable_user_table.c.user_id]
).select_from(
sqlatable_user_table.join(
SqlaTable, SqlaTable.id == sqlatable_user_table.c.table_id
).join(NewDataset, NewDataset.uuid == SqlaTable.uuid)
),
)
print(" Link physical datasets with tables...")
insert_from_select(
dataset_table_association_table,
select(
[
NewDataset.id.label("dataset_id"),
NewTable.id.label("table_id"),
]
).select_from(
sa.join(SqlaTable, NewTable, NewTable.sqlatable_id == SqlaTable.id).join(
NewDataset, NewDataset.uuid == SqlaTable.uuid
)
),
)
def copy_columns(session: Session) -> None:
"""Copy columns with active associated SqlTable"""
count = session.query(TableColumn).select_from(active_table_columns).count()
if not count:
return
print(f">> Copy {count:,} table columns to sl_columns...")
insert_from_select(
NewColumn,
select(
[
TableColumn.uuid,
TableColumn.created_on,
TableColumn.changed_on,
TableColumn.created_by_fk,
TableColumn.changed_by_fk,
TableColumn.groupby.label("is_dimensional"),
TableColumn.filterable.label("is_filterable"),
TableColumn.column_name.label("name"),
TableColumn.description,
func.coalesce(TableColumn.expression, TableColumn.column_name).label(
"expression"
),
sa.literal(False).label("is_aggregation"),
is_physical_column.label("is_physical"),
TableColumn.is_dttm.label("is_temporal"),
func.coalesce(TableColumn.type, UNKNOWN_TYPE).label("type"),
TableColumn.extra.label("extra_json"),
]
).select_from(active_table_columns),
)
joined_columns_table = active_table_columns.join(
NewColumn, TableColumn.uuid == NewColumn.uuid
)
print(" Link all columns to sl_datasets...")
insert_from_select(
dataset_column_association_table,
select(
[
NewDataset.id.label("dataset_id"),
NewColumn.id.label("column_id"),
],
).select_from(
joined_columns_table.join(NewDataset, NewDataset.uuid == SqlaTable.uuid)
),
)
def copy_metrics(session: Session) -> None:
"""Copy metrics as virtual columns"""
metrics_count = session.query(SqlMetric).select_from(active_metrics).count()
if not metrics_count:
return
print(f">> Copy {metrics_count:,} metrics to sl_columns...")
insert_from_select(
NewColumn,
select(
[
SqlMetric.uuid,
SqlMetric.created_on,
SqlMetric.changed_on,
SqlMetric.created_by_fk,
SqlMetric.changed_by_fk,
SqlMetric.metric_name.label("name"),
SqlMetric.expression,
SqlMetric.description,
sa.literal(UNKNOWN_TYPE).label("type"),
(
func.coalesce(
sa.func.lower(SqlMetric.metric_type).in_(
ADDITIVE_METRIC_TYPES_LOWER
),
sa.literal(False),
).label("is_additive")
),
sa.literal(True).label("is_aggregation"),
# metrics are by default not filterable
sa.literal(False).label("is_filterable"),
sa.literal(False).label("is_dimensional"),
sa.literal(False).label("is_physical"),
sa.literal(False).label("is_temporal"),
SqlMetric.extra.label("extra_json"),
SqlMetric.warning_text,
]
).select_from(active_metrics),
)
print(" Link metric columns to datasets...")
insert_from_select(
dataset_column_association_table,
select(
[
NewDataset.id.label("dataset_id"),
NewColumn.id.label("column_id"),
],
).select_from(
active_metrics.join(NewDataset, NewDataset.uuid == SqlaTable.uuid).join(
NewColumn, NewColumn.uuid == SqlMetric.uuid
)
),
)
def postprocess_datasets(session: Session) -> None:
"""
Postprocess datasets after insertion to
- Quote table names for physical datasets (if needed)
- Link referenced tables to virtual datasets
"""
total = session.query(SqlaTable).count()
if not total:
return
offset = 0
limit = 10000
joined_tables = sa.join(
NewDataset,
SqlaTable,
NewDataset.uuid == SqlaTable.uuid,
).join(
Database,
Database.id == SqlaTable.database_id,
isouter=True,
)
assert session.query(func.count()).select_from(joined_tables).scalar() == total
print(f">> Run postprocessing on {total} datasets")
update_count = 0
def print_update_count():
if SHOW_PROGRESS:
print(
f" Will update {update_count} datasets" + " " * 20,
end="\r",
)
while offset < total:
print(
f" Process dataset {offset + 1}~{min(total, offset + limit)}..."
+ " " * 30
)
for (
database_id,
dataset_id,
expression,
extra,
is_physical,
schema,
sqlalchemy_uri,
) in session.execute(
select(
[
NewDataset.database_id,
NewDataset.id.label("dataset_id"),
NewDataset.expression,
SqlaTable.extra,
NewDataset.is_physical,
SqlaTable.schema,
Database.sqlalchemy_uri,
]
)
.select_from(joined_tables)
.offset(offset)
.limit(limit)
):
drivername = (sqlalchemy_uri or "").split("://")[0]
updates = {}
updated = False
if is_physical and drivername:
quoted_expression = get_identifier_quoter(drivername)(expression)
if quoted_expression != expression:
updates["expression"] = quoted_expression
# add schema name to `dataset.extra_json` so we don't have to join
# tables in order to use datasets
if schema:
try:
extra_json = json.loads(extra) if extra else {}
except json.decoder.JSONDecodeError:
extra_json = {}
extra_json["schema"] = schema
updates["extra_json"] = json.dumps(extra_json)
if updates:
session.execute(
sa.update(NewDataset)
.where(NewDataset.id == dataset_id)
.values(**updates)
)
updated = True
if not is_physical and expression:
table_references = extract_table_references(
expression, get_dialect_name(drivername), show_warning=False
)
found_tables = find_tables(
session, database_id, schema, table_references
)
if found_tables:
op.bulk_insert(
dataset_table_association_table,
[
{"dataset_id": dataset_id, "table_id": table.id}
for table in found_tables
],
)
updated = True
if updated:
update_count += 1
print_update_count()
session.flush()
offset += limit
if SHOW_PROGRESS:
print("")
def postprocess_columns(session: Session) -> None:
"""
At this step, we will
- Add engine-specific quotes to `expression` of physical columns
- Tuck some extra metadata into `extra_json`
"""
total = session.query(NewColumn).count()
if not total:
return
def get_joined_tables(offset, limit):
return (
sa.join(
session.query(NewColumn)
.offset(offset)
.limit(limit)
.subquery("sl_columns"),
dataset_column_association_table,
dataset_column_association_table.c.column_id == NewColumn.id,
)
.join(
NewDataset,
NewDataset.id == dataset_column_association_table.c.dataset_id,
)
.join(
dataset_table_association_table,
# Join tables with physical datasets
and_(
NewDataset.is_physical,
dataset_table_association_table.c.dataset_id == NewDataset.id,
),
isouter=True,
)
.join(Database, Database.id == NewDataset.database_id)
.join(
TableColumn,
TableColumn.uuid == NewColumn.uuid,
isouter=True,
)
.join(
SqlMetric,
SqlMetric.uuid == NewColumn.uuid,
isouter=True,
)
)
offset = 0
limit = 100000
print(f">> Run postprocessing on {total:,} columns")
update_count = 0
def print_update_count():
if SHOW_PROGRESS:
print(
f" Will update {update_count} columns" + " " * 20,
end="\r",
)
while offset < total:
query = (
select(
# sorted alphabetically
[
NewColumn.id.label("column_id"),
TableColumn.column_name,
NewColumn.changed_by_fk,
NewColumn.changed_on,
NewColumn.created_on,
NewColumn.description,
SqlMetric.d3format,
NewDataset.external_url,
NewColumn.extra_json,
NewColumn.is_dimensional,
NewColumn.is_filterable,
NewDataset.is_managed_externally,
NewColumn.is_physical,
SqlMetric.metric_type,
TableColumn.python_date_format,
Database.sqlalchemy_uri,
dataset_table_association_table.c.table_id,
func.coalesce(
TableColumn.verbose_name, SqlMetric.verbose_name
).label("verbose_name"),
NewColumn.warning_text,
]
)
.select_from(get_joined_tables(offset, limit))
.where(
# pre-filter to columns with potential updates
or_(
NewColumn.is_physical,
TableColumn.verbose_name.isnot(None),
SqlMetric.verbose_name.isnot(None),
SqlMetric.d3format.isnot(None),
SqlMetric.metric_type.isnot(None),
)
)
)
start = offset + 1
end = min(total, offset + limit)
count = session.query(func.count()).select_from(query).scalar()
print(f" [Column {start:,} to {end:,}] {count:,} may be updated")
physical_columns = []
for (
# sorted alphabetically
column_id,
column_name,
changed_by_fk,
changed_on,
created_on,
description,
d3format,
external_url,
extra_json,
is_dimensional,
is_filterable,
is_managed_externally,
is_physical,
metric_type,
python_date_format,
sqlalchemy_uri,
table_id,
verbose_name,
warning_text,
) in session.execute(query):
try:
extra = json.loads(extra_json) if extra_json else {}
except json.decoder.JSONDecodeError:
extra = {}
updated_extra = {**extra}
updates = {}
if is_managed_externally:
updates["is_managed_externally"] = True
if external_url:
updates["external_url"] = external_url
# update extra json
for (key, val) in (
{
"verbose_name": verbose_name,
"python_date_format": python_date_format,
"d3format": d3format,
"metric_type": metric_type,
}
).items():
# save the original val, including if it's `false`
if val is not None:
updated_extra[key] = val
if updated_extra != extra:
updates["extra_json"] = json.dumps(updated_extra)
# update expression for physical table columns
if is_physical:
if column_name and sqlalchemy_uri:
drivername = sqlalchemy_uri.split("://")[0]
if is_physical and drivername:
quoted_expression = get_identifier_quoter(drivername)(
column_name
)
if quoted_expression != column_name:
updates["expression"] = quoted_expression
# duplicate physical columns for tables
physical_columns.append(
dict(
created_on=created_on,
changed_on=changed_on,
changed_by_fk=changed_by_fk,
description=description,
expression=updates.get("expression", column_name),
external_url=external_url,
extra_json=updates.get("extra_json", extra_json),
is_aggregation=False,
is_dimensional=is_dimensional,
is_filterable=is_filterable,
is_managed_externally=is_managed_externally,
is_physical=True,
name=column_name,
table_id=table_id,
warning_text=warning_text,
)
)
if updates:
session.execute(
sa.update(NewColumn)
.where(NewColumn.id == column_id)
.values(**updates)
)
update_count += 1
print_update_count()
if physical_columns:
op.bulk_insert(NewColumn.__table__, physical_columns)
session.flush()
offset += limit
if SHOW_PROGRESS:
print("")
print(" Assign table column relations...")
insert_from_select(
table_column_association_table,
select([NewColumn.table_id, NewColumn.id.label("column_id")])
.select_from(NewColumn)
.where(and_(NewColumn.is_physical, NewColumn.table_id.isnot(None))),
)
new_tables: List[sa.Table] = [
NewTable.__table__,
NewDataset.__table__,
NewColumn.__table__,
table_column_association_table,
dataset_column_association_table,
dataset_table_association_table,
dataset_user_association_table,
]
def reset_postgres_id_sequence(table: str) -> None:
op.execute(
f"""
SELECT setval(
pg_get_serial_sequence('{table}', 'id'),
COALESCE(max(id) + 1, 1),
false
)
FROM {table};
"""
)
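# Hypothetical usage (illustration only, not from the original file): on a
# PostgreSQL metadata database, after rows are bulk-inserted with explicit
# ``id`` values the serial sequences would need realigning so that later
# inserts don't collide, e.g.:
#
#     for table_name in ("sl_tables", "sl_datasets", "sl_columns"):
#         reset_postgres_id_sequence(table_name)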
def upgrade() -> None:
bind = op.get_bind()
session: Session = db.Session(bind=bind)
Base.metadata.drop_all(bind=bind, tables=new_tables)
Base.metadata.create_all(bind=bind, tables=new_tables)
copy_tables(session)
copy_datasets(session)
copy_columns(session)
copy_metrics(session)
session.commit()
postprocess_columns(session)
session.commit()
postprocess_datasets(session)
session.commit()
# Tables were created with the same uuids as datasets. They should
# have different uuids, as they are different entities.
print(">> Assign new UUIDs to tables...")
assign_uuids(NewTable, session)
print(">> Drop intermediate columns...")
# These columns are only used during the migration: once created, datasets are
# independent of tables, and dataset columns are independent of table columns.
with op.batch_alter_table(NewTable.__tablename__) as batch_op:
batch_op.drop_column("sqlatable_id")
with op.batch_alter_table(NewColumn.__tablename__) as batch_op:
batch_op.drop_column("table_id")
def downgrade():
Base.metadata.drop_all(bind=op.get_bind(), tables=new_tables)


@ -23,19 +23,17 @@ Create Date: 2020-09-28 17:57:23.128142
"""
import json
import os
import time
from json.decoder import JSONDecodeError
from uuid import uuid4
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.mysql.base import MySQLDialect
from sqlalchemy.dialects.postgresql.base import PGDialect
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import load_only
from sqlalchemy_utils import UUIDType
from superset import db
from superset.migrations.shared.utils import assign_uuids
from superset.utils import core as utils
# revision identifiers, used by Alembic.
@ -78,47 +76,6 @@ models["dashboards"].position_json = sa.Column(utils.MediumText())
default_batch_size = int(os.environ.get("BATCH_SIZE", 200))
# Add uuids directly using built-in SQL uuid function
add_uuids_by_dialect = {
MySQLDialect: """UPDATE %s SET uuid = UNHEX(REPLACE(CONVERT(UUID() using utf8mb4), '-', ''));""",
PGDialect: """UPDATE %s SET uuid = uuid_in(md5(random()::text || clock_timestamp()::text)::cstring);""",
}
def add_uuids(model, table_name, session, batch_size=default_batch_size):
"""Populate columns with pre-computed uuids"""
bind = op.get_bind()
objects_query = session.query(model)
count = objects_query.count()
# silently skip if the table is empty (suitable for db initialization)
if count == 0:
return
print(f"\nAdding uuids for `{table_name}`...")
start_time = time.time()
# Use dialect specific native SQL queries if possible
for dialect, sql in add_uuids_by_dialect.items():
if isinstance(bind.dialect, dialect):
op.execute(sql % table_name)
print(f"Done. Assigned {count} uuids in {time.time() - start_time:.3f}s.")
return
# Otherwise use the Python uuid function
start = 0
while start < count:
end = min(start + batch_size, count)
for obj, uuid in map(lambda obj: (obj, uuid4()), objects_query[start:end]):
obj.uuid = uuid
session.merge(obj)
session.commit()
if start + batch_size < count:
print(f" uuid assigned to {end} out of {count}\r", end="")
start += batch_size
print(f"Done. Assigned {count} uuids in {time.time() - start_time:.3f}s.")
def update_position_json(dashboard, session, uuid_map):
try:
@ -178,7 +135,7 @@ def upgrade():
),
)
add_uuids(model, table_name, session)
assign_uuids(model, session)
# add uniqueness constraint
with op.batch_alter_table(table_name) as batch_op:
@ -203,7 +160,7 @@ def downgrade():
update_dashboards(session, {})
# remove uuid column
for table_name, model in models.items():
for table_name in models:
with op.batch_alter_table(table_name) as batch_op:
batch_op.drop_constraint(f"uq_{table_name}_uuid", type_="unique")
batch_op.drop_column("uuid")


@ -23,619 +23,23 @@ Revises: 5afbb1a5849b
Create Date: 2021-11-11 16:41:53.266965
"""
import json
from datetime import date, datetime, time, timedelta
from typing import Callable, List, Optional, Set
from uuid import uuid4
import sqlalchemy as sa
from alembic import op
from sqlalchemy import and_, inspect, or_
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import backref, relationship, Session
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy_utils import UUIDType
from superset import app, db
from superset.connectors.sqla.models import ADDITIVE_METRIC_TYPES
from superset.databases.utils import make_url_safe
from superset.extensions import encrypted_field_factory
from superset.migrations.shared.utils import extract_table_references
from superset.models.core import Database as OriginalDatabase
from superset.sql_parse import Table
# revision identifiers, used by Alembic.
revision = "b8d3a24d9131"
down_revision = "5afbb1a5849b"
Base = declarative_base()
custom_password_store = app.config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
DB_CONNECTION_MUTATOR = app.config["DB_CONNECTION_MUTATOR"]
# ===================== Notice ========================
#
# Migrations made in this revision have been moved to `new_dataset_models_take_2`
# to fix performance issues as well as a couple of shortcomings in the original
# design.
#
# ======================================================
class Database(Base):
__tablename__ = "dbs"
__table_args__ = (UniqueConstraint("database_name"),)
id = sa.Column(sa.Integer, primary_key=True)
database_name = sa.Column(sa.String(250), unique=True, nullable=False)
sqlalchemy_uri = sa.Column(sa.String(1024), nullable=False)
password = sa.Column(encrypted_field_factory.create(sa.String(1024)))
impersonate_user = sa.Column(sa.Boolean, default=False)
encrypted_extra = sa.Column(encrypted_field_factory.create(sa.Text), nullable=True)
extra = sa.Column(
sa.Text,
default=json.dumps(
dict(
metadata_params={},
engine_params={},
metadata_cache_timeout={},
schemas_allowed_for_file_upload=[],
)
),
)
server_cert = sa.Column(encrypted_field_factory.create(sa.Text), nullable=True)
class TableColumn(Base):
__tablename__ = "table_columns"
__table_args__ = (UniqueConstraint("table_id", "column_name"),)
id = sa.Column(sa.Integer, primary_key=True)
table_id = sa.Column(sa.Integer, sa.ForeignKey("tables.id"))
is_active = sa.Column(sa.Boolean, default=True)
extra = sa.Column(sa.Text)
column_name = sa.Column(sa.String(255), nullable=False)
type = sa.Column(sa.String(32))
expression = sa.Column(sa.Text)
description = sa.Column(sa.Text)
is_dttm = sa.Column(sa.Boolean, default=False)
filterable = sa.Column(sa.Boolean, default=True)
groupby = sa.Column(sa.Boolean, default=True)
verbose_name = sa.Column(sa.String(1024))
python_date_format = sa.Column(sa.String(255))
class SqlMetric(Base):
__tablename__ = "sql_metrics"
__table_args__ = (UniqueConstraint("table_id", "metric_name"),)
id = sa.Column(sa.Integer, primary_key=True)
table_id = sa.Column(sa.Integer, sa.ForeignKey("tables.id"))
extra = sa.Column(sa.Text)
metric_type = sa.Column(sa.String(32))
metric_name = sa.Column(sa.String(255), nullable=False)
expression = sa.Column(sa.Text, nullable=False)
warning_text = sa.Column(sa.Text)
description = sa.Column(sa.Text)
d3format = sa.Column(sa.String(128))
verbose_name = sa.Column(sa.String(1024))
class SqlaTable(Base):
__tablename__ = "tables"
__table_args__ = (UniqueConstraint("database_id", "schema", "table_name"),)
def fetch_columns_and_metrics(self, session: Session) -> None:
self.columns = session.query(TableColumn).filter(
TableColumn.table_id == self.id
)
self.metrics = session.query(SqlMetric).filter(TableColumn.table_id == self.id)
id = sa.Column(sa.Integer, primary_key=True)
columns: List[TableColumn] = []
column_class = TableColumn
metrics: List[SqlMetric] = []
metric_class = SqlMetric
database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False)
database: Database = relationship(
"Database",
backref=backref("tables", cascade="all, delete-orphan"),
foreign_keys=[database_id],
)
schema = sa.Column(sa.String(255))
table_name = sa.Column(sa.String(250), nullable=False)
sql = sa.Column(sa.Text)
is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)
external_url = sa.Column(sa.Text, nullable=True)
table_column_association_table = sa.Table(
"sl_table_columns",
Base.metadata,
sa.Column("table_id", sa.ForeignKey("sl_tables.id")),
sa.Column("column_id", sa.ForeignKey("sl_columns.id")),
)
dataset_column_association_table = sa.Table(
"sl_dataset_columns",
Base.metadata,
sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id")),
sa.Column("column_id", sa.ForeignKey("sl_columns.id")),
)
dataset_table_association_table = sa.Table(
"sl_dataset_tables",
Base.metadata,
sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id")),
sa.Column("table_id", sa.ForeignKey("sl_tables.id")),
)
class NewColumn(Base):
__tablename__ = "sl_columns"
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Text)
type = sa.Column(sa.Text)
expression = sa.Column(sa.Text)
is_physical = sa.Column(sa.Boolean, default=True)
description = sa.Column(sa.Text)
warning_text = sa.Column(sa.Text)
is_temporal = sa.Column(sa.Boolean, default=False)
is_aggregation = sa.Column(sa.Boolean, default=False)
is_additive = sa.Column(sa.Boolean, default=False)
is_spatial = sa.Column(sa.Boolean, default=False)
is_partition = sa.Column(sa.Boolean, default=False)
is_increase_desired = sa.Column(sa.Boolean, default=True)
is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)
external_url = sa.Column(sa.Text, nullable=True)
extra_json = sa.Column(sa.Text, default="{}")
class NewTable(Base):
__tablename__ = "sl_tables"
__table_args__ = (UniqueConstraint("database_id", "catalog", "schema", "name"),)
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Text)
schema = sa.Column(sa.Text)
catalog = sa.Column(sa.Text)
database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False)
database: Database = relationship(
"Database",
backref=backref("new_tables", cascade="all, delete-orphan"),
foreign_keys=[database_id],
)
columns: List[NewColumn] = relationship(
"NewColumn", secondary=table_column_association_table, cascade="all, delete"
)
is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)
external_url = sa.Column(sa.Text, nullable=True)
class NewDataset(Base):
__tablename__ = "sl_datasets"
id = sa.Column(sa.Integer, primary_key=True)
sqlatable_id = sa.Column(sa.Integer, nullable=True, unique=True)
name = sa.Column(sa.Text)
expression = sa.Column(sa.Text)
tables: List[NewTable] = relationship(
"NewTable", secondary=dataset_table_association_table
)
columns: List[NewColumn] = relationship(
"NewColumn", secondary=dataset_column_association_table, cascade="all, delete"
)
is_physical = sa.Column(sa.Boolean, default=False)
is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)
external_url = sa.Column(sa.Text, nullable=True)
TEMPORAL_TYPES = {date, datetime, time, timedelta}
def is_column_type_temporal(column_type: TypeEngine) -> bool:
try:
return column_type.python_type in TEMPORAL_TYPES
except NotImplementedError:
return False
def load_or_create_tables(
session: Session,
database_id: int,
default_schema: Optional[str],
tables: Set[Table],
conditional_quote: Callable[[str], str],
) -> List[NewTable]:
"""
Load or create new table model instances.
"""
if not tables:
return []
# set the default schema in tables that don't have it
if default_schema:
tables = list(tables)
for i, table in enumerate(tables):
if table.schema is None:
tables[i] = Table(table.table, default_schema, table.catalog)
# load existing tables
predicate = or_(
*[
and_(
NewTable.database_id == database_id,
NewTable.schema == table.schema,
NewTable.name == table.table,
)
for table in tables
]
)
new_tables = session.query(NewTable).filter(predicate).all()
# use original database model to get the engine
engine = (
session.query(OriginalDatabase)
.filter_by(id=database_id)
.one()
.get_sqla_engine(default_schema)
)
inspector = inspect(engine)
# add missing tables
existing = {(table.schema, table.name) for table in new_tables}
for table in tables:
if (table.schema, table.table) not in existing:
column_metadata = inspector.get_columns(table.table, schema=table.schema)
columns = [
NewColumn(
name=column["name"],
type=str(column["type"]),
expression=conditional_quote(column["name"]),
is_temporal=is_column_type_temporal(column["type"]),
is_aggregation=False,
is_physical=True,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
)
for column in column_metadata
]
new_tables.append(
NewTable(
name=table.table,
schema=table.schema,
catalog=None,
database_id=database_id,
columns=columns,
)
)
existing.add((table.schema, table.table))
return new_tables
def after_insert(target: SqlaTable) -> None: # pylint: disable=too-many-locals
"""
Copy old datasets to the new models.
"""
session = inspect(target).session
# get DB-specific conditional quoter for expressions that point to columns or
# table names
database = (
target.database
or session.query(Database).filter_by(id=target.database_id).first()
)
if not database:
return
url = make_url_safe(database.sqlalchemy_uri)
dialect_class = url.get_dialect()
conditional_quote = dialect_class().identifier_preparer.quote
# create columns
columns = []
for column in target.columns:
# ``is_active`` might be ``None`` at this point, but it defaults to ``True``.
if column.is_active is False:
continue
try:
extra_json = json.loads(column.extra or "{}")
except json.decoder.JSONDecodeError:
extra_json = {}
for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}:
value = getattr(column, attr)
if value:
extra_json[attr] = value
columns.append(
NewColumn(
name=column.column_name,
type=column.type or "Unknown",
expression=column.expression or conditional_quote(column.column_name),
description=column.description,
is_temporal=column.is_dttm,
is_aggregation=False,
is_physical=column.expression is None or column.expression == "",
is_spatial=False,
is_partition=False,
is_increase_desired=True,
extra_json=json.dumps(extra_json) if extra_json else None,
is_managed_externally=target.is_managed_externally,
external_url=target.external_url,
),
)
# create metrics
for metric in target.metrics:
try:
extra_json = json.loads(metric.extra or "{}")
except json.decoder.JSONDecodeError:
extra_json = {}
for attr in {"verbose_name", "metric_type", "d3format"}:
value = getattr(metric, attr)
if value:
extra_json[attr] = value
is_additive = (
metric.metric_type and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES
)
columns.append(
NewColumn(
name=metric.metric_name,
type="Unknown", # figuring this out would require a type inferrer
expression=metric.expression,
warning_text=metric.warning_text,
description=metric.description,
is_aggregation=True,
is_additive=is_additive,
is_physical=False,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
extra_json=json.dumps(extra_json) if extra_json else None,
is_managed_externally=target.is_managed_externally,
external_url=target.external_url,
),
)
# physical dataset
if not target.sql:
physical_columns = [column for column in columns if column.is_physical]
# create table
table = NewTable(
name=target.table_name,
schema=target.schema,
catalog=None, # currently not supported
database_id=target.database_id,
columns=physical_columns,
is_managed_externally=target.is_managed_externally,
external_url=target.external_url,
)
tables = [table]
# virtual dataset
else:
# mark all columns as virtual (not physical)
for column in columns:
column.is_physical = False
# find referenced tables
referenced_tables = extract_table_references(target.sql, dialect_class.name)
tables = load_or_create_tables(
session,
target.database_id,
target.schema,
referenced_tables,
conditional_quote,
)
# create the new dataset
dataset = NewDataset(
sqlatable_id=target.id,
name=target.table_name,
expression=target.sql or conditional_quote(target.table_name),
tables=tables,
columns=columns,
is_physical=not target.sql,
is_managed_externally=target.is_managed_externally,
external_url=target.external_url,
)
session.add(dataset)
def upgrade():
# Create tables for the new models.
op.create_table(
"sl_columns",
# AuditMixinNullable
sa.Column("created_on", sa.DateTime(), nullable=True),
sa.Column("changed_on", sa.DateTime(), nullable=True),
sa.Column("created_by_fk", sa.Integer(), nullable=True),
sa.Column("changed_by_fk", sa.Integer(), nullable=True),
# ExtraJSONMixin
sa.Column("extra_json", sa.Text(), nullable=True),
# ImportExportMixin
sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4),
# Column
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column("name", sa.TEXT(), nullable=False),
sa.Column("type", sa.TEXT(), nullable=False),
sa.Column("expression", sa.TEXT(), nullable=False),
sa.Column(
"is_physical",
sa.BOOLEAN(),
nullable=False,
default=True,
),
sa.Column("description", sa.TEXT(), nullable=True),
sa.Column("warning_text", sa.TEXT(), nullable=True),
sa.Column("unit", sa.TEXT(), nullable=True),
sa.Column("is_temporal", sa.BOOLEAN(), nullable=False),
sa.Column(
"is_spatial",
sa.BOOLEAN(),
nullable=False,
default=False,
),
sa.Column(
"is_partition",
sa.BOOLEAN(),
nullable=False,
default=False,
),
sa.Column(
"is_aggregation",
sa.BOOLEAN(),
nullable=False,
default=False,
),
sa.Column(
"is_additive",
sa.BOOLEAN(),
nullable=False,
default=False,
),
sa.Column(
"is_increase_desired",
sa.BOOLEAN(),
nullable=False,
default=True,
),
sa.Column(
"is_managed_externally",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
sa.Column("external_url", sa.Text(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
with op.batch_alter_table("sl_columns") as batch_op:
batch_op.create_unique_constraint("uq_sl_columns_uuid", ["uuid"])
op.create_table(
"sl_tables",
# AuditMixinNullable
sa.Column("created_on", sa.DateTime(), nullable=True),
sa.Column("changed_on", sa.DateTime(), nullable=True),
sa.Column("created_by_fk", sa.Integer(), nullable=True),
sa.Column("changed_by_fk", sa.Integer(), nullable=True),
# ExtraJSONMixin
sa.Column("extra_json", sa.Text(), nullable=True),
# ImportExportMixin
sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4),
# Table
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column("database_id", sa.INTEGER(), autoincrement=False, nullable=False),
sa.Column("catalog", sa.TEXT(), nullable=True),
sa.Column("schema", sa.TEXT(), nullable=True),
sa.Column("name", sa.TEXT(), nullable=False),
sa.Column(
"is_managed_externally",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
sa.Column("external_url", sa.Text(), nullable=True),
sa.ForeignKeyConstraint(["database_id"], ["dbs.id"], name="sl_tables_ibfk_1"),
sa.PrimaryKeyConstraint("id"),
)
with op.batch_alter_table("sl_tables") as batch_op:
batch_op.create_unique_constraint("uq_sl_tables_uuid", ["uuid"])
op.create_table(
"sl_table_columns",
sa.Column("table_id", sa.INTEGER(), autoincrement=False, nullable=False),
sa.Column("column_id", sa.INTEGER(), autoincrement=False, nullable=False),
sa.ForeignKeyConstraint(
["column_id"], ["sl_columns.id"], name="sl_table_columns_ibfk_2"
),
sa.ForeignKeyConstraint(
["table_id"], ["sl_tables.id"], name="sl_table_columns_ibfk_1"
),
)
op.create_table(
"sl_datasets",
# AuditMixinNullable
sa.Column("created_on", sa.DateTime(), nullable=True),
sa.Column("changed_on", sa.DateTime(), nullable=True),
sa.Column("created_by_fk", sa.Integer(), nullable=True),
sa.Column("changed_by_fk", sa.Integer(), nullable=True),
# ExtraJSONMixin
sa.Column("extra_json", sa.Text(), nullable=True),
# ImportExportMixin
sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4),
# Dataset
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column("sqlatable_id", sa.INTEGER(), nullable=True),
sa.Column("name", sa.TEXT(), nullable=False),
sa.Column("expression", sa.TEXT(), nullable=False),
sa.Column(
"is_physical",
sa.BOOLEAN(),
nullable=False,
default=False,
),
sa.Column(
"is_managed_externally",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
sa.Column("external_url", sa.Text(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
with op.batch_alter_table("sl_datasets") as batch_op:
batch_op.create_unique_constraint("uq_sl_datasets_uuid", ["uuid"])
batch_op.create_unique_constraint(
"uq_sl_datasets_sqlatable_id", ["sqlatable_id"]
)
op.create_table(
"sl_dataset_columns",
sa.Column("dataset_id", sa.INTEGER(), autoincrement=False, nullable=False),
sa.Column("column_id", sa.INTEGER(), autoincrement=False, nullable=False),
sa.ForeignKeyConstraint(
["column_id"], ["sl_columns.id"], name="sl_dataset_columns_ibfk_2"
),
sa.ForeignKeyConstraint(
["dataset_id"], ["sl_datasets.id"], name="sl_dataset_columns_ibfk_1"
),
)
op.create_table(
"sl_dataset_tables",
sa.Column("dataset_id", sa.INTEGER(), autoincrement=False, nullable=False),
sa.Column("table_id", sa.INTEGER(), autoincrement=False, nullable=False),
sa.ForeignKeyConstraint(
["dataset_id"], ["sl_datasets.id"], name="sl_dataset_tables_ibfk_1"
),
sa.ForeignKeyConstraint(
["table_id"], ["sl_tables.id"], name="sl_dataset_tables_ibfk_2"
),
)
# migrate existing datasets to the new models
bind = op.get_bind()
session = db.Session(bind=bind) # pylint: disable=no-member
datasets = session.query(SqlaTable).all()
for dataset in datasets:
dataset.fetch_columns_and_metrics(session)
after_insert(target=dataset)
def upgrade() -> None:
pass
def downgrade():
op.drop_table("sl_dataset_columns")
op.drop_table("sl_dataset_tables")
op.drop_table("sl_datasets")
op.drop_table("sl_table_columns")
op.drop_table("sl_tables")
op.drop_table("sl_columns")
pass


@ -38,7 +38,7 @@ from sqlalchemy_utils import UUIDType
from superset import db
from superset.migrations.versions.b56500de1855_add_uuid_column_to_import_mixin import (
add_uuids,
assign_uuids,
models,
update_dashboards,
)
@ -73,7 +73,7 @@ def upgrade():
default=uuid4,
),
)
add_uuids(model, table_name, session)
assign_uuids(model, session)
# add uniqueness constraint
with op.batch_alter_table(table_name) as batch_op:


@ -71,7 +71,7 @@ def downgrade_filters(native_filters: Iterable[Dict[str, Any]]) -> int:
filter_state = default_data_mask.get("filterState")
if filter_state is not None:
changed_filters += 1
value = filter_state["value"]
value = filter_state.get("value")
native_filter["defaultValue"] = value
return changed_filters


@ -408,12 +408,14 @@ class Database(
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
@property
def quote_identifier(self) -> Callable[[str], str]:
"""Add quotes to potential identifiter expressions if needed"""
return self.get_dialect().identifier_preparer.quote
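# Illustration only (not part of this change): assuming a PostgreSQL backend,
# ``database.quote_identifier("select")`` would return ``'"select"'`` since
# ``select`` is a reserved word, while an already-safe name like ``revenue``
# would be returned unchanged.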
def get_reserved_words(self) -> Set[str]:
return self.get_dialect().preparer.reserved_words
def get_quoter(self) -> Callable[[str, Any], str]:
return self.get_dialect().identifier_preparer.quote
def get_df( # pylint: disable=too-many-locals
self,
sql: str,


@ -477,7 +477,7 @@ class ExtraJSONMixin:
@property
def extra(self) -> Dict[str, Any]:
try:
return json.loads(self.extra_json)
return json.loads(self.extra_json) if self.extra_json else {}
except (TypeError, JSONDecodeError) as exc:
logger.error(
"Unable to load an extra json: %r. Leaving empty.", exc, exc_info=True
@ -522,18 +522,23 @@ class CertificationMixin:
def clone_model(
target: Model, ignore: Optional[List[str]] = None, **kwargs: Any
target: Model,
ignore: Optional[List[str]] = None,
keep_relations: Optional[List[str]] = None,
**kwargs: Any,
) -> Model:
"""
Clone a SQLAlchemy model.
Clone a SQLAlchemy model. By default, only plain column attributes are cloned.
To include relationship attributes, use `keep_relations`.
"""
ignore = ignore or []
table = target.__table__
primary_keys = table.primary_key.columns.keys()
data = {
attr: getattr(target, attr)
for attr in table.columns.keys()
if attr not in table.primary_key.columns.keys() and attr not in ignore
for attr in list(table.columns.keys()) + (keep_relations or [])
if attr not in primary_keys and attr not in ignore
}
data.update(kwargs)
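# Illustration only (this hunk stops before the function presumably builds and
# returns the copy): assuming a saved ``Slice`` instance, a call such as
#
#     clone_model(slc, ignore=["uuid"], keep_relations=["owners"], slice_name="Copy")
#
# would carry over the plain column attributes plus the ``owners`` relationship,
# with ``slice_name`` overridden via ``kwargs``.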


@ -186,7 +186,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals
apply_ctas: bool = False,
) -> SupersetResultSet:
"""Executes a single SQL statement"""
database = query.database
database: Database = query.database
db_engine_spec = database.db_engine_spec
parsed_query = ParsedQuery(sql_statement)
sql = parsed_query.stripped()


@ -18,7 +18,7 @@ import logging
import re
from dataclasses import dataclass
from enum import Enum
from typing import cast, List, Optional, Set, Tuple
from typing import Any, cast, Iterator, List, Optional, Set, Tuple
from urllib import parse
import sqlparse
@ -47,10 +47,16 @@ from sqlparse.utils import imt
from superset.exceptions import QueryClauseValidationException
try:
from sqloxide import parse_sql as sqloxide_parse
except: # pylint: disable=bare-except
sqloxide_parse = None
RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
ON_KEYWORD = "ON"
PRECEDES_TABLE_NAME = {"FROM", "JOIN", "DESCRIBE", "WITH", "LEFT JOIN", "RIGHT JOIN"}
CTE_PREFIX = "CTE__"
logger = logging.getLogger(__name__)
@ -176,6 +182,9 @@ class Table:
if part
)
def __eq__(self, __o: object) -> bool:
return str(self) == str(__o)
class ParsedQuery:
def __init__(self, sql_statement: str, strip_comments: bool = False):
@ -698,3 +707,75 @@ def insert_rls(
)
return token_list
# mapping between sqloxide and SQLAlchemy dialects
SQLOXITE_DIALECTS = {
"ansi": {"trino", "trinonative", "presto"},
"hive": {"hive", "databricks"},
"ms": {"mssql"},
"mysql": {"mysql"},
"postgres": {
"cockroachdb",
"hana",
"netezza",
"postgres",
"postgresql",
"redshift",
"vertica",
},
"snowflake": {"snowflake"},
"sqlite": {"sqlite", "gsheets", "shillelagh"},
"clickhouse": {"clickhouse"},
}
RE_JINJA_VAR = re.compile(r"\{\{[^\{\}]+\}\}")
RE_JINJA_BLOCK = re.compile(r"\{[%#][^\{\}%#]+[%#]\}")
def extract_table_references(
sql_text: str, sqla_dialect: str, show_warning: bool = True
) -> Set["Table"]:
"""
Return all the table references from a SQL statement.
"""
dialect = "generic"
tree = None
if sqloxide_parse:
for dialect, sqla_dialects in SQLOXITE_DIALECTS.items():
if sqla_dialect in sqla_dialects:
break
sql_text = RE_JINJA_BLOCK.sub(" ", sql_text)
sql_text = RE_JINJA_VAR.sub("abc", sql_text)
try:
tree = sqloxide_parse(sql_text, dialect=dialect)
except Exception as ex: # pylint: disable=broad-except
if show_warning:
logger.warning(
"\nUnable to parse query with sqloxide:\n%s\n%s", sql_text, ex
)
# fallback to sqlparse
if not tree:
parsed = ParsedQuery(sql_text)
return parsed.tables
def find_nodes_by_key(element: Any, target: str) -> Iterator[Any]:
"""
Find all nodes in a SQL tree matching a given key.
"""
if isinstance(element, list):
for child in element:
yield from find_nodes_by_key(child, target)
elif isinstance(element, dict):
for key, value in element.items():
if key == target:
yield value
else:
yield from find_nodes_by_key(value, target)
return {
Table(*[part["value"] for part in table["name"][::-1]])
for table in find_nodes_by_key(tree, "Table")
}


@ -24,26 +24,41 @@ addition to a table, new models for columns, metrics, and datasets were also int
These models are not fully implemented, and shouldn't be used yet.
"""
from typing import List
from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING
import sqlalchemy as sa
from flask_appbuilder import Model
from sqlalchemy.orm import backref, relationship
from sqlalchemy import inspect
from sqlalchemy.orm import backref, relationship, Session
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import and_, or_
from superset.columns.models import Column
from superset.connectors.sqla.utils import get_physical_table_metadata
from superset.models.core import Database
from superset.models.helpers import (
AuditMixinNullable,
ExtraJSONMixin,
ImportExportMixin,
)
from superset.sql_parse import Table as TableName
association_table = sa.Table(
if TYPE_CHECKING:
from superset.datasets.models import Dataset
table_column_association_table = sa.Table(
"sl_table_columns",
Model.metadata, # pylint: disable=no-member
sa.Column("table_id", sa.ForeignKey("sl_tables.id")),
sa.Column("column_id", sa.ForeignKey("sl_columns.id")),
sa.Column(
"table_id",
sa.ForeignKey("sl_tables.id", ondelete="cascade"),
primary_key=True,
),
sa.Column(
"column_id",
sa.ForeignKey("sl_columns.id", ondelete="cascade"),
primary_key=True,
),
)
@ -61,7 +76,6 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
__table_args__ = (UniqueConstraint("database_id", "catalog", "schema", "name"),)
id = sa.Column(sa.Integer, primary_key=True)
database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False)
database: Database = relationship(
"Database",
@ -70,6 +84,19 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
backref=backref("new_tables", cascade="all, delete-orphan"),
foreign_keys=[database_id],
)
# The relationship between datasets and columns is 1:n, but we use a
# many-to-many association table to avoid adding two mutually exclusive
# columns (dataset_id and table_id) to Column
columns: List[Column] = relationship(
"Column",
secondary=table_column_association_table,
cascade="all, delete-orphan",
single_parent=True,
# backref is needed for session to skip detaching `dataset` if only `column`
# is loaded.
backref="tables",
)
datasets: List["Dataset"] # will be populated by Dataset.tables backref
# We use ``sa.Text`` for these attributes because (1) in modern databases the
# performance is the same as ``VARCHAR``[1] and (2) because some table names can be
@ -80,13 +107,96 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
schema = sa.Column(sa.Text)
name = sa.Column(sa.Text)
# The relationship between tables and columns is 1:n, but we use a many-to-many
# association to differentiate between the relationship between datasets and
# columns.
columns: List[Column] = relationship(
"Column", secondary=association_table, cascade="all, delete"
)
# Column is managed externally and should be read-only inside Superset
is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)
external_url = sa.Column(sa.Text, nullable=True)
@property
def fullname(self) -> str:
return str(TableName(table=self.name, schema=self.schema, catalog=self.catalog))
def __repr__(self) -> str:
return f"<Table id={self.id} database_id={self.database_id} {self.fullname}>"
def sync_columns(self) -> None:
"""Sync table columns with the database. Keep metadata for existing columns"""
try:
column_metadata = get_physical_table_metadata(
self.database, self.name, self.schema
)
except Exception: # pylint: disable=broad-except
column_metadata = []
existing_columns = {column.name: column for column in self.columns}
quote_identifier = self.database.quote_identifier
def update_or_create_column(column_meta: Dict[str, Any]) -> Column:
column_name: str = column_meta["name"]
if column_name in existing_columns:
column = existing_columns[column_name]
else:
column = Column(name=column_name)
column.type = column_meta["type"]
column.is_temporal = column_meta["is_dttm"]
column.expression = quote_identifier(column_name)
column.is_aggregation = False
column.is_physical = True
column.is_spatial = False
column.is_partition = False # TODO: update with accurate is_partition
return column
self.columns = [update_or_create_column(col) for col in column_metadata]
@staticmethod
def bulk_load_or_create(
database: Database,
table_names: Iterable[TableName],
default_schema: Optional[str] = None,
sync_columns: Optional[bool] = False,
default_props: Optional[Dict[str, Any]] = None,
) -> List["Table"]:
"""
Load or create multiple Table instances.
"""
if not table_names:
return []
if not database.id:
raise Exception("Database must be already saved to metastore")
default_props = default_props or {}
session: Session = inspect(database).session
# load existing tables
predicate = or_(
*[
and_(
Table.database_id == database.id,
Table.schema == (table.schema or default_schema),
Table.name == table.table,
)
for table in table_names
]
)
all_tables = session.query(Table).filter(predicate).order_by(Table.id).all()
# add missing tables and pull their columns
existing = {(table.schema, table.name) for table in all_tables}
for table in table_names:
schema = table.schema or default_schema
name = table.table
if (schema, name) not in existing:
new_table = Table(
database=database,
database_id=database.id,
name=name,
schema=schema,
catalog=None,
**default_props,
)
if sync_columns:
new_table.sync_columns()
all_tables.append(new_table)
existing.add((schema, name))
session.add(new_table)
return all_tables
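# Illustration only (these models "shouldn't be used yet" per the module
# docstring): assuming an existing ``database`` row, a call along the lines of
#
#     Table.bulk_load_or_create(
#         database,
#         [TableName(table="birth_names", schema="public")],
#         sync_columns=True,
#     )
#
# would return the matching ``Table`` rows, creating (and column-syncing) any
# that are missing.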


@ -16,11 +16,11 @@
# under the License.
import copy
import json
from unittest.mock import patch
import yaml
from flask import g
from superset import db, security_manager
from superset import db
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.v1.assets import ImportAssetsCommand
from superset.commands.importers.v1.utils import is_valid_config
@ -58,10 +58,13 @@ class TestImportersV1Utils(SupersetTestCase):
class TestImportAssetsCommand(SupersetTestCase):
@patch("superset.dashboards.commands.importers.v1.utils.g")
def test_import_assets(self, mock_g):
def setUp(self):
user = self.get_user("admin")
self.user = user
setattr(g, "user", user)
def test_import_assets(self):
"""Test that we can import multiple assets"""
mock_g.user = security_manager.find_user("admin")
contents = {
"metadata.yaml": yaml.safe_dump(metadata_config),
"databases/imported_database.yaml": yaml.safe_dump(database_config),
@ -141,7 +144,7 @@ class TestImportAssetsCommand(SupersetTestCase):
database = dataset.database
assert str(database.uuid) == database_config["uuid"]
assert dashboard.owners == [mock_g.user]
assert dashboard.owners == [self.user]
dashboard.owners = []
chart.owners = []
@ -153,11 +156,8 @@ class TestImportAssetsCommand(SupersetTestCase):
db.session.delete(database)
db.session.commit()
@patch("superset.dashboards.commands.importers.v1.utils.g")
def test_import_v1_dashboard_overwrite(self, mock_g):
def test_import_v1_dashboard_overwrite(self):
"""Test that assets can be overwritten"""
mock_g.user = security_manager.find_user("admin")
contents = {
"metadata.yaml": yaml.safe_dump(metadata_config),
"databases/imported_database.yaml": yaml.safe_dump(database_config),


@ -111,11 +111,10 @@ def _commit_slices(slices: List[Slice]):
def _create_world_bank_dashboard(table: SqlaTable, slices: List[Slice]) -> Dashboard:
from superset.examples.helpers import update_slice_ids
from superset.examples.world_bank import dashboard_positions
pos = dashboard_positions
from superset.examples.helpers import update_slice_ids
update_slice_ids(pos, slices)
table.fetch_metadata()


@ -455,7 +455,8 @@ class TestDatabaseModel(SupersetTestCase):
# make sure the columns have been mapped properly
assert len(table.columns) == 4
table.fetch_metadata()
table.fetch_metadata(commit=False)
# assert that the removed column has been dropped and
# the physical and calculated columns are present
assert {col.column_name for col in table.columns} == {
@ -473,6 +474,8 @@ class TestDatabaseModel(SupersetTestCase):
assert VIRTUAL_TABLE_STRING_TYPES[backend].match(cols["mycase"].type)
assert cols["expr"].expression == "case when 1 then 1 else 0 end"
db.session.delete(table)
@patch("superset.models.core.Database.db_engine_spec", BigQueryEngineSpec)
def test_labels_expected_on_mutated_query(self):
query_obj = {


@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
import unittest
import uuid
from datetime import date, datetime, time, timedelta
from decimal import Decimal


@ -17,7 +17,7 @@
# pylint: disable=redefined-outer-name, import-outside-toplevel
import importlib
from typing import Any, Iterator
from typing import Any, Callable, Iterator
import pytest
from pytest_mock import MockFixture
@ -31,25 +31,33 @@ from superset.initialization import SupersetAppInitializer
@pytest.fixture
def session(mocker: MockFixture) -> Iterator[Session]:
def get_session(mocker: MockFixture) -> Callable[[], Session]:
"""
Create an in-memory SQLite session to test models.
"""
engine = create_engine("sqlite://")
Session_ = sessionmaker(bind=engine) # pylint: disable=invalid-name
in_memory_session = Session_()
# flask calls session.remove()
in_memory_session.remove = lambda: None
def get_session():
Session_ = sessionmaker(bind=engine) # pylint: disable=invalid-name
in_memory_session = Session_()
# patch session
mocker.patch(
"superset.security.SupersetSecurityManager.get_session",
return_value=in_memory_session,
)
mocker.patch("superset.db.session", in_memory_session)
# flask calls session.remove()
in_memory_session.remove = lambda: None
yield in_memory_session
# patch session
mocker.patch(
"superset.security.SupersetSecurityManager.get_session",
return_value=in_memory_session,
)
mocker.patch("superset.db.session", in_memory_session)
return in_memory_session
return get_session
@pytest.fixture
def session(get_session) -> Iterator[Session]:
yield get_session()
@pytest.fixture(scope="module")


@ -0,0 +1,118 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, TYPE_CHECKING
import pytest
if TYPE_CHECKING:
from superset.connectors.sqla.models import SqlMetric, TableColumn
@pytest.fixture
def columns_default() -> Dict[str, Any]:
"""Default props for new columns"""
return {
"changed_by": 1,
"created_by": 1,
"datasets": [],
"tables": [],
"is_additive": False,
"is_aggregation": False,
"is_dimensional": False,
"is_filterable": True,
"is_increase_desired": True,
"is_partition": False,
"is_physical": True,
"is_spatial": False,
"is_temporal": False,
"description": None,
"extra_json": "{}",
"unit": None,
"warning_text": None,
"is_managed_externally": False,
"external_url": None,
}
@pytest.fixture
def sample_columns() -> Dict["TableColumn", Dict[str, Any]]:
from superset.connectors.sqla.models import TableColumn
return {
TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"): {
"name": "ds",
"expression": "ds",
"type": "TIMESTAMP",
"is_temporal": True,
"is_physical": True,
},
TableColumn(column_name="num_boys", type="INTEGER", groupby=True): {
"name": "num_boys",
"expression": "num_boys",
"type": "INTEGER",
"is_dimensional": True,
"is_physical": True,
},
TableColumn(column_name="region", type="VARCHAR", groupby=True): {
"name": "region",
"expression": "region",
"type": "VARCHAR",
"is_dimensional": True,
"is_physical": True,
},
TableColumn(
column_name="profit",
type="INTEGER",
groupby=False,
expression="revenue-expenses",
): {
"name": "profit",
"expression": "revenue-expenses",
"type": "INTEGER",
"is_physical": False,
},
}
@pytest.fixture
def sample_metrics() -> Dict["SqlMetric", Dict[str, Any]]:
from superset.connectors.sqla.models import SqlMetric
return {
SqlMetric(metric_name="cnt", expression="COUNT(*)", metric_type="COUNT"): {
"name": "cnt",
"expression": "COUNT(*)",
"extra_json": '{"metric_type": "COUNT"}',
"type": "UNKNOWN",
"is_additive": True,
"is_aggregation": True,
"is_filterable": False,
"is_physical": False,
},
SqlMetric(
metric_name="avg revenue", expression="AVG(revenue)", metric_type="AVG"
): {
"name": "avg revenue",
"expression": "AVG(revenue)",
"extra_json": '{"metric_type": "AVG"}',
"type": "UNKNOWN",
"is_additive": False,
"is_aggregation": True,
"is_filterable": False,
"is_physical": False,
},
}

File diff suppressed because it is too large


@ -1,16 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


@ -1,56 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-outside-toplevel, unused-argument
"""
Test the SIP-68 migration.
"""
from pytest_mock import MockerFixture
from superset.sql_parse import Table
def test_extract_table_references(mocker: MockerFixture, app_context: None) -> None:
"""
Test the ``extract_table_references`` helper function.
"""
from superset.migrations.shared.utils import extract_table_references
assert extract_table_references("SELECT 1", "trino") == set()
assert extract_table_references("SELECT 1 FROM some_table", "trino") == {
Table(table="some_table", schema=None, catalog=None)
}
assert extract_table_references(
"SELECT 1 FROM some_catalog.some_schema.some_table", "trino"
) == {Table(table="some_table", schema="some_schema", catalog="some_catalog")}
assert extract_table_references(
"SELECT * FROM some_table JOIN other_table ON some_table.id = other_table.id",
"trino",
) == {
Table(table="some_table", schema=None, catalog=None),
Table(table="other_table", schema=None, catalog=None),
}
# test falling back to sqlparse
logger = mocker.patch("superset.migrations.shared.utils.logger")
sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
assert extract_table_references(
sql,
"trino",
) == {Table(table="other_table", schema=None, catalog=None)}
logger.warning.assert_called_with("Unable to parse query with sqloxide: %s", sql)


@ -29,6 +29,7 @@ from sqlparse.tokens import Name
from superset.exceptions import QueryClauseValidationException
from superset.sql_parse import (
add_table_name,
extract_table_references,
get_rls_for_table,
has_table_query,
insert_rls,
@ -1468,3 +1469,51 @@ def test_get_rls_for_table(mocker: MockerFixture, app_context: None) -> None:
dataset.get_sqla_row_level_filters.return_value = []
assert get_rls_for_table(candidate, 1, "public") is None
def test_extract_table_references(mocker: MockerFixture) -> None:
"""
Test the ``extract_table_references`` helper function.
"""
assert extract_table_references("SELECT 1", "trino") == set()
assert extract_table_references("SELECT 1 FROM some_table", "trino") == {
Table(table="some_table", schema=None, catalog=None)
}
assert extract_table_references("SELECT {{ jinja }} FROM some_table", "trino") == {
Table(table="some_table", schema=None, catalog=None)
}
assert extract_table_references(
"SELECT 1 FROM some_catalog.some_schema.some_table", "trino"
) == {Table(table="some_table", schema="some_schema", catalog="some_catalog")}
# with identifier quotes
assert extract_table_references(
"SELECT 1 FROM `some_catalog`.`some_schema`.`some_table`", "mysql"
) == {Table(table="some_table", schema="some_schema", catalog="some_catalog")}
assert extract_table_references(
'SELECT 1 FROM "some_catalog".some_schema."some_table"', "trino"
) == {Table(table="some_table", schema="some_schema", catalog="some_catalog")}
assert extract_table_references(
"SELECT * FROM some_table JOIN other_table ON some_table.id = other_table.id",
"trino",
) == {
Table(table="some_table", schema=None, catalog=None),
Table(table="other_table", schema=None, catalog=None),
}
# test falling back to sqlparse
logger = mocker.patch("superset.sql_parse.logger")
sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
assert extract_table_references(
sql,
"trino",
) == {Table(table="other_table", schema=None, catalog=None)}
logger.warning.assert_called_once()
logger = mocker.patch("superset.migrations.shared.utils.logger")
sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
assert extract_table_references(sql, "trino", show_warning=False) == {
Table(table="other_table", schema=None, catalog=None)
}
logger.warning.assert_not_called()


@ -14,3 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any
from superset import security_manager
def get_test_user(id_: int, username: str) -> Any:
"""Create a sample test user"""
return security_manager.user_model(
id=id_,
username=username,
first_name=username,
last_name=username,
email=f"{username}@example.com",
)