mirror of https://github.com/apache/superset.git
perf: refactor SIP-68 db migrations with INSERT SELECT FROM (#19421)
This commit is contained in:
parent 1c5d3b73df
commit 231716cb50
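The headline change is replacing row-by-row ORM copying in the SIP-68 shadow-write migrations with a single server-side statement per table. A minimal sketch of the INSERT ... SELECT FROM pattern in SQLAlchemy Core follows; the table and column names are illustrative stand-ins, not the migration's actual DDL.

import sqlalchemy as sa

metadata = sa.MetaData()

# hypothetical source (legacy) and target (SIP-68) tables
table_columns = sa.Table(
    "table_columns",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("column_name", sa.Text),
    sa.Column("type", sa.Text),
)
sl_columns = sa.Table(
    "sl_columns",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("name", sa.Text),
    sa.Column("type", sa.Text),
)

# one INSERT ... SELECT, executed entirely inside the database; no rows are
# pulled into Python, which is what makes the migration fast
copy_stmt = sa.insert(sl_columns).from_select(
    ["name", "type"],
    sa.select([table_columns.c.column_name, table_columns.c.type]),
)
# with engine.begin() as conn:
#     conn.execute(copy_stmt)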
@@ -23,7 +23,6 @@ tables, metrics, and datasets were also introduced.

These models are not fully implemented, and shouldn't be used yet.
"""

import sqlalchemy as sa
from flask_appbuilder import Model

@@ -33,6 +32,8 @@ from superset.models.helpers import (
    ImportExportMixin,
)

UNKOWN_TYPE = "UNKNOWN"


class Column(
    Model,
@@ -52,51 +53,58 @@ class Column(

    id = sa.Column(sa.Integer, primary_key=True)

    # We use ``sa.Text`` for these attributes because (1) in modern databases the
    # performance is the same as ``VARCHAR``[1] and (2) because some table names can be
    # **really** long (eg, Google Sheets URLs).
    #
    # [1] https://www.postgresql.org/docs/9.1/datatype-character.html
    name = sa.Column(sa.Text)
    type = sa.Column(sa.Text)

    # Columns are defined by expressions. For tables, these are the actual column names,
    # and should match the ``name`` attribute. For datasets, these can be any valid SQL
    # expression. If the SQL expression is an aggregation the column is a metric,
    # otherwise it's a computed column.
    expression = sa.Column(sa.Text)

    # Does the expression point directly to a physical column?
    is_physical = sa.Column(sa.Boolean, default=True)

    # Additional metadata describing the column.
    description = sa.Column(sa.Text)
    warning_text = sa.Column(sa.Text)
    unit = sa.Column(sa.Text)

    # Is this a time column? Useful for plotting time series.
    is_temporal = sa.Column(sa.Boolean, default=False)

    # Is this a spatial column? This could be leveraged in the future for spatial
    # visualizations.
    is_spatial = sa.Column(sa.Boolean, default=False)

    # Is this column a partition? Useful for scheduling queries and previewing the latest
    # data.
    is_partition = sa.Column(sa.Boolean, default=False)

    # Is this column an aggregation (metric)?
    is_aggregation = sa.Column(sa.Boolean, default=False)

    # Assuming the column is an aggregation, is it additive? Useful for determining which
    # aggregations can be done on the metric. Eg, ``COUNT(DISTINCT user_id)`` is not
    # additive, so it shouldn't be used in a ``SUM``.
    is_additive = sa.Column(sa.Boolean, default=False)

    # Is this column an aggregation (metric)?
    is_aggregation = sa.Column(sa.Boolean, default=False)

    is_filterable = sa.Column(sa.Boolean, nullable=False, default=True)
    is_dimensional = sa.Column(sa.Boolean, nullable=False, default=False)

    # Is an increase desired? Useful for displaying the results of A/B tests, or setting
    # up alerts. Eg, this is true for "revenue", but false for "latency".
    is_increase_desired = sa.Column(sa.Boolean, default=True)

    # Column is managed externally and should be read-only inside Superset
    is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)

    # Is this column a partition? Useful for scheduling queries and previewing the latest
    # data.
    is_partition = sa.Column(sa.Boolean, default=False)

    # Does the expression point directly to a physical column?
    is_physical = sa.Column(sa.Boolean, default=True)

    # Is this a spatial column? This could be leveraged in the future for spatial
    # visualizations.
    is_spatial = sa.Column(sa.Boolean, default=False)

    # Is this a time column? Useful for plotting time series.
    is_temporal = sa.Column(sa.Boolean, default=False)

    # We use ``sa.Text`` for these attributes because (1) in modern databases the
    # performance is the same as ``VARCHAR``[1] and (2) because some table names can be
    # **really** long (eg, Google Sheets URLs).
    #
    # [1] https://www.postgresql.org/docs/9.1/datatype-character.html
    name = sa.Column(sa.Text)
    # Raw type as returned and used by db engine.
    type = sa.Column(sa.Text, default=UNKOWN_TYPE)

    # Columns are defined by expressions. For tables, these are the actual column names,
    # and should match the ``name`` attribute. For datasets, these can be any valid SQL
    # expression. If the SQL expression is an aggregation the column is a metric,
    # otherwise it's a computed column.
    expression = sa.Column(sa.Text)
    unit = sa.Column(sa.Text)

    # Additional metadata describing the column.
    description = sa.Column(sa.Text)
    warning_text = sa.Column(sa.Text)
    external_url = sa.Column(sa.Text, nullable=True)

    def __repr__(self) -> str:
        return f"<Column id={self.id}>"
@@ -31,7 +31,7 @@ from superset.models.helpers import AuditMixinNullable, ImportExportMixin, Query
from superset.models.slice import Slice
from superset.superset_typing import FilterValue, FilterValues, QueryObjectDict
from superset.utils import core as utils
from superset.utils.core import GenericDataType
from superset.utils.core import GenericDataType, MediumText

METRIC_FORM_DATA_PARAMS = [
    "metric",
@@ -586,7 +586,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
    type = Column(Text)
    groupby = Column(Boolean, default=True)
    filterable = Column(Boolean, default=True)
    description = Column(Text)
    description = Column(MediumText())
    is_dttm = None

    # [optional] Set this to support import/export functionality
@@ -672,7 +672,7 @@ class BaseMetric(AuditMixinNullable, ImportExportMixin):
    metric_name = Column(String(255), nullable=False)
    verbose_name = Column(String(1024))
    metric_type = Column(String(32))
    description = Column(Text)
    description = Column(MediumText())
    d3format = Column(String(128))
    warning_text = Column(Text)
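Several columns in this commit move from ``Text`` to ``MediumText()``. On MySQL, ``TEXT`` tops out at 64KB, which long chart SQL can exceed; ``MediumText`` (imported above from ``superset.utils.core``) resolves to ``MEDIUMTEXT`` on MySQL and stays plain ``TEXT`` elsewhere. A sketch of the likely shape of that helper, assumed rather than copied from the source:

from sqlalchemy import Text
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.types import Variant


def MediumText() -> Variant:
    # 16MB MEDIUMTEXT on MySQL; other dialects keep their unbounded TEXT
    return Text().with_variant(MEDIUMTEXT(), "mysql")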
@@ -24,6 +24,7 @@ from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Hashable,
@@ -34,6 +35,7 @@ from typing import (
    Type,
    Union,
)
from uuid import uuid4

import dateutil.parser
import numpy as np
@@ -72,13 +74,13 @@ from sqlalchemy.sql.expression import Label, Select, TextAsFrom
from sqlalchemy.sql.selectable import Alias, TableClause

from superset import app, db, is_feature_enabled, security_manager
from superset.columns.models import Column as NewColumn
from superset.columns.models import Column as NewColumn, UNKOWN_TYPE
from superset.common.db_query_status import QueryStatus
from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
from superset.connectors.sqla.utils import (
    find_cached_objects_in_session,
    get_physical_table_metadata,
    get_virtual_table_metadata,
    load_or_create_tables,
    validate_adhoc_subquery,
)
from superset.datasets.models import Dataset as NewDataset
@@ -100,7 +102,12 @@ from superset.models.helpers import (
    clone_model,
    QueryResult,
)
from superset.sql_parse import ParsedQuery, sanitize_clause
from superset.sql_parse import (
    extract_table_references,
    ParsedQuery,
    sanitize_clause,
    Table as TableName,
)
from superset.superset_typing import (
    AdhocColumn,
    AdhocMetric,
@@ -114,6 +121,7 @@ from superset.utils.core import (
    GenericDataType,
    get_column_name,
    is_adhoc_column,
    MediumText,
    QueryObjectFilterClause,
    remove_duplicates,
)
@@ -130,6 +138,7 @@ ADDITIVE_METRIC_TYPES = {
    "sum",
    "doubleSum",
}
ADDITIVE_METRIC_TYPES_LOWER = {op.lower() for op in ADDITIVE_METRIC_TYPES}


class SqlaQuery(NamedTuple):
@@ -215,13 +224,13 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
    __tablename__ = "table_columns"
    __table_args__ = (UniqueConstraint("table_id", "column_name"),)
    table_id = Column(Integer, ForeignKey("tables.id"))
    table = relationship(
    table: "SqlaTable" = relationship(
        "SqlaTable",
        backref=backref("columns", cascade="all, delete-orphan"),
        foreign_keys=[table_id],
    )
    is_dttm = Column(Boolean, default=False)
    expression = Column(Text)
    expression = Column(MediumText())
    python_date_format = Column(String(255))
    extra = Column(Text)
@@ -417,6 +426,59 @@ class TableColumn(Model, BaseColumn, CertificationMixin):

        return attr_dict

    def to_sl_column(
        self, known_columns: Optional[Dict[str, NewColumn]] = None
    ) -> NewColumn:
        """Convert a TableColumn to NewColumn"""
        column = known_columns.get(self.uuid) if known_columns else None
        if not column:
            column = NewColumn()

        extra_json = self.get_extra_dict()
        for attr in {
            "verbose_name",
            "python_date_format",
        }:
            value = getattr(self, attr)
            if value:
                extra_json[attr] = value

        column.uuid = self.uuid
        column.created_on = self.created_on
        column.changed_on = self.changed_on
        column.created_by = self.created_by
        column.changed_by = self.changed_by
        column.name = self.column_name
        column.type = self.type or UNKOWN_TYPE
        column.expression = self.expression or self.table.quote_identifier(
            self.column_name
        )
        column.description = self.description
        column.is_aggregation = False
        column.is_dimensional = self.groupby
        column.is_filterable = self.filterable
        column.is_increase_desired = True
        column.is_managed_externally = self.table.is_managed_externally
        column.is_partition = False
        column.is_physical = not self.expression
        column.is_spatial = False
        column.is_temporal = self.is_dttm
        column.extra_json = json.dumps(extra_json) if extra_json else None
        column.external_url = self.table.external_url

        return column

    @staticmethod
    def after_delete(  # pylint: disable=unused-argument
        mapper: Mapper,
        connection: Connection,
        target: "TableColumn",
    ) -> None:
        session = inspect(target).session
        column = session.query(NewColumn).filter_by(uuid=target.uuid).one_or_none()
        if column:
            session.delete(column)


class SqlMetric(Model, BaseMetric, CertificationMixin):
@@ -430,7 +492,7 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
        backref=backref("metrics", cascade="all, delete-orphan"),
        foreign_keys=[table_id],
    )
    expression = Column(Text, nullable=False)
    expression = Column(MediumText(), nullable=False)
    extra = Column(Text)

    export_fields = [
@@ -479,6 +541,58 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
        attr_dict.update(super().data)
        return attr_dict

    def to_sl_column(
        self, known_columns: Optional[Dict[str, NewColumn]] = None
    ) -> NewColumn:
        """Convert a SqlMetric to NewColumn. Find and update existing or
        create a new one."""
        column = known_columns.get(self.uuid) if known_columns else None
        if not column:
            column = NewColumn()

        extra_json = self.get_extra_dict()
        for attr in {"verbose_name", "metric_type", "d3format"}:
            value = getattr(self, attr)
            if value is not None:
                extra_json[attr] = value
        is_additive = (
            self.metric_type and self.metric_type.lower() in ADDITIVE_METRIC_TYPES_LOWER
        )

        column.uuid = self.uuid
        column.name = self.metric_name
        column.created_on = self.created_on
        column.changed_on = self.changed_on
        column.created_by = self.created_by
        column.changed_by = self.changed_by
        column.type = UNKOWN_TYPE
        column.expression = self.expression
        column.warning_text = self.warning_text
        column.description = self.description
        column.is_aggregation = True
        column.is_additive = is_additive
        column.is_filterable = False
        column.is_increase_desired = True
        column.is_managed_externally = self.table.is_managed_externally
        column.is_partition = False
        column.is_physical = False
        column.is_spatial = False
        column.extra_json = json.dumps(extra_json) if extra_json else None
        column.external_url = self.table.external_url

        return column

    @staticmethod
    def after_delete(  # pylint: disable=unused-argument
        mapper: Mapper,
        connection: Connection,
        target: "SqlMetric",
    ) -> None:
        session = inspect(target).session
        column = session.query(NewColumn).filter_by(uuid=target.uuid).one_or_none()
        if column:
            session.delete(column)


sqlatable_user = Table(
    "sqlatable_user",
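Both ``to_sl_column`` converters above share the same "reuse or create, keyed by uuid" shape: when the caller passes a dict of already-known ``NewColumn`` instances, the converter mutates the cached object instead of minting a duplicate. The pattern in isolation, using a generic stand-in class rather than the Superset models:

from typing import Dict, Optional
from uuid import UUID, uuid4


class Record:
    """Illustrative stand-in for NewColumn."""

    def __init__(self) -> None:
        self.uuid: Optional[UUID] = None
        self.name = ""


def to_record(
    uuid: UUID, name: str, known: Optional[Dict[UUID, Record]] = None
) -> Record:
    # reuse the cached instance when one exists, otherwise create a new one
    record = known.get(uuid) if known else None
    if not record:
        record = Record()
    record.uuid = uuid
    record.name = name
    return record


known: Dict[UUID, Record] = {}
first = to_record(uuid4(), "num_california", known)
known[first.uuid] = first
# converting again with the same uuid updates the cached instance in place
assert to_record(first.uuid, "renamed", known) is first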
@@ -544,7 +658,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
        foreign_keys=[database_id],
    )
    schema = Column(String(255))
    sql = Column(Text)
    sql = Column(MediumText())
    is_sqllab_view = Column(Boolean, default=False)
    template_params = Column(Text)
    extra = Column(Text)
@@ -1731,7 +1845,19 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
        metrics = []
        any_date_col = None
        db_engine_spec = self.db_engine_spec
        old_columns = db.session.query(TableColumn).filter(TableColumn.table == self)

        # If no `self.id`, then this is a new table, no need to fetch columns
        # from db. Passing in `self.id` to query will actually automatically
        # generate a new id, which can be tricky during certain transactions.
        old_columns = (
            (
                db.session.query(TableColumn)
                .filter(TableColumn.table_id == self.id)
                .all()
            )
            if self.id
            else self.columns
        )

        old_columns_by_name: Dict[str, TableColumn] = {
            col.column_name: col for col in old_columns
@@ -1745,13 +1871,15 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
        )

        # clear old columns before adding modified columns back
        self.columns = []
        columns = []
        for col in new_columns:
            old_column = old_columns_by_name.pop(col["name"], None)
            if not old_column:
                results.added.append(col["name"])
                new_column = TableColumn(
                    column_name=col["name"], type=col["type"], table=self
                    column_name=col["name"],
                    type=col["type"],
                    table=self,
                )
                new_column.is_dttm = new_column.is_temporal
                db_engine_spec.alter_new_orm_column(new_column)
@@ -1763,12 +1891,14 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                new_column.expression = ""
                new_column.groupby = True
                new_column.filterable = True
            self.columns.append(new_column)
            columns.append(new_column)
            if not any_date_col and new_column.is_temporal:
                any_date_col = col["name"]
        self.columns.extend(
            [col for col in old_columns_by_name.values() if col.expression]
        )

        # add back calculated (virtual) columns
        columns.extend([col for col in old_columns if col.expression])
        self.columns = columns

        metrics.append(
            SqlMetric(
                metric_name="count",
@@ -1854,6 +1984,10 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
            extra_cache_keys += sqla_query.extra_cache_keys
        return extra_cache_keys

    @property
    def quote_identifier(self) -> Callable[[str], str]:
        return self.database.quote_identifier

    @staticmethod
    def before_update(
        mapper: Mapper,  # pylint: disable=unused-argument
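The new ``quote_identifier`` property delegates to the database model so expressions can be quoted per dialect without building an engine at each call site. What that delegation ultimately reaches is SQLAlchemy's identifier preparer, which quotes conditionally:

from sqlalchemy.dialects import mysql, postgresql

pg_quote = postgresql.dialect().identifier_preparer.quote
my_quote = mysql.dialect().identifier_preparer.quote

print(pg_quote("select"))   # "select"   (reserved word, so it gets quoted)
print(my_quote("select"))   # `select`   (MySQL uses backticks)
print(pg_quote("revenue"))  # revenue    (plain identifier, left unquoted)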
@@ -1895,14 +2029,44 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
        ):
            raise Exception(get_dataset_exist_error_msg(target.full_name))

    def get_sl_columns(self) -> List[NewColumn]:
        """
        Convert `SqlaTable.columns` and `SqlaTable.metrics` to the new Column model
        """
        session: Session = inspect(self).session

        uuids = set()
        for column_or_metric in self.columns + self.metrics:
            # pre-assign uuid after new columns or metrics are inserted so
            # the related `NewColumn` can have a deterministic uuid, too
            if not column_or_metric.uuid:
                column_or_metric.uuid = uuid4()
            else:
                uuids.add(column_or_metric.uuid)

        # load existing columns from cached session states first
        existing_columns = set(
            find_cached_objects_in_session(session, NewColumn, uuids=uuids)
        )
        for column in existing_columns:
            uuids.remove(column.uuid)

        if uuids:
            # load those not found from db
            existing_columns |= set(
                session.query(NewColumn).filter(NewColumn.uuid.in_(uuids))
            )

        known_columns = {column.uuid: column for column in existing_columns}
        return [
            item.to_sl_column(known_columns) for item in self.columns + self.metrics
        ]

    @staticmethod
    def update_table(  # pylint: disable=unused-argument
        mapper: Mapper, connection: Connection, target: Union[SqlMetric, TableColumn]
    ) -> None:
        """
        Forces an update to the table's changed_on value when a metric or column on the
        table is updated. This busts the cache key for all charts that use the table.

        :param mapper: Unused.
        :param connection: Unused.
        :param target: The metric or column that was updated.
@@ -1910,90 +2074,43 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
        inspector = inspect(target)
        session = inspector.session

        # get DB-specific conditional quoter for expressions that point to columns or
        # table names
        database = (
            target.table.database
            or session.query(Database).filter_by(id=target.database_id).one()
        )
        engine = database.get_sqla_engine(schema=target.table.schema)
        conditional_quote = engine.dialect.identifier_preparer.quote

        # Forces an update to the table's changed_on value when a metric or column on the
        # table is updated. This busts the cache key for all charts that use the table.
        session.execute(update(SqlaTable).where(SqlaTable.id == target.table.id))

        dataset = (
            session.query(NewDataset)
            .filter_by(sqlatable_id=target.table.id)
            .one_or_none()
        )

        if not dataset:
            # if dataset is not found create a new copy
            # of the dataset instead of updating the existing

            SqlaTable.write_shadow_dataset(target.table, database, session)
            return

        # update ``Column`` model as well
        if isinstance(target, TableColumn):
            columns = [
                column
                for column in dataset.columns
                if column.name == target.column_name
            ]
            if not columns:
        # if table itself has changed, shadow-writing will happen in `after_update` anyway
        if target.table not in session.dirty:
            dataset: NewDataset = (
                session.query(NewDataset)
                .filter_by(uuid=target.table.uuid)
                .one_or_none()
            )
            # Update shadow dataset and columns
            # did we find the dataset?
            if not dataset:
                # if dataset is not found create a new copy
                target.table.write_shadow_dataset()
                return

            column = columns[0]
            extra_json = json.loads(target.extra or "{}")
            for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}:
                value = getattr(target, attr)
                if value:
                    extra_json[attr] = value
            # update changed_on timestamp
            session.execute(update(NewDataset).where(NewDataset.id == dataset.id))

            column.name = target.column_name
            column.type = target.type or "Unknown"
            column.expression = target.expression or conditional_quote(
                target.column_name
            # update `Column` model as well
            session.add(
                target.to_sl_column(
                    {
                        target.uuid: session.query(NewColumn)
                        .filter_by(uuid=target.uuid)
                        .one_or_none()
                    }
                )
            )
            column.description = target.description
            column.is_temporal = target.is_dttm
            column.is_physical = target.expression is None
            column.extra_json = json.dumps(extra_json) if extra_json else None

        else:  # SqlMetric
            columns = [
                column
                for column in dataset.columns
                if column.name == target.metric_name
            ]
            if not columns:
                return

            column = columns[0]
            extra_json = json.loads(target.extra or "{}")
            for attr in {"verbose_name", "metric_type", "d3format"}:
                value = getattr(target, attr)
                if value:
                    extra_json[attr] = value

            is_additive = (
                target.metric_type
                and target.metric_type.lower() in ADDITIVE_METRIC_TYPES
            )

            column.name = target.metric_name
            column.expression = target.expression
            column.warning_text = target.warning_text
            column.description = target.description
            column.is_additive = is_additive
            column.extra_json = json.dumps(extra_json) if extra_json else None

    @staticmethod
    def after_insert(
        mapper: Mapper,
        connection: Connection,
        target: "SqlaTable",
        sqla_table: "SqlaTable",
    ) -> None:
        """
        Shadow write the dataset to new models.
@@ -2007,24 +2124,14 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho

        For more context: https://github.com/apache/superset/issues/14909
        """
        session = inspect(target).session
        # set permissions
        security_manager.set_perm(mapper, connection, target)

        # get DB-specific conditional quoter for expressions that point to columns or
        # table names
        database = (
            target.database
            or session.query(Database).filter_by(id=target.database_id).one()
        )

        SqlaTable.write_shadow_dataset(target, database, session)
        security_manager.set_perm(mapper, connection, sqla_table)
        sqla_table.write_shadow_dataset()

    @staticmethod
    def after_delete(  # pylint: disable=unused-argument
        mapper: Mapper,
        connection: Connection,
        target: "SqlaTable",
        sqla_table: "SqlaTable",
    ) -> None:
        """
        Shadow write the dataset to new models.
@@ -2038,18 +2145,18 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho

        For more context: https://github.com/apache/superset/issues/14909
        """
        session = inspect(target).session
        session = inspect(sqla_table).session
        dataset = (
            session.query(NewDataset).filter_by(sqlatable_id=target.id).one_or_none()
            session.query(NewDataset).filter_by(uuid=sqla_table.uuid).one_or_none()
        )
        if dataset:
            session.delete(dataset)

    @staticmethod
    def after_update(  # pylint: disable=too-many-branches, too-many-locals, too-many-statements
    def after_update(
        mapper: Mapper,
        connection: Connection,
        target: "SqlaTable",
        sqla_table: "SqlaTable",
    ) -> None:
        """
        Shadow write the dataset to new models.
@@ -2063,172 +2170,76 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho

        For more context: https://github.com/apache/superset/issues/14909
        """
        inspector = inspect(target)
        # set permissions
        security_manager.set_perm(mapper, connection, sqla_table)

        inspector = inspect(sqla_table)
        session = inspector.session

        # double-check that ``UPDATE``s are actually pending (this method is called even
        # for instances that have no net changes to their column-based attributes)
        if not session.is_modified(target, include_collections=True):
        if not session.is_modified(sqla_table, include_collections=True):
            return

        # set permissions
        security_manager.set_perm(mapper, connection, target)

        dataset = (
            session.query(NewDataset).filter_by(sqlatable_id=target.id).one_or_none()
        # find the dataset from the known instance list first
        # (it could be either from a previous query or newly created)
        dataset = next(
            find_cached_objects_in_session(
                session, NewDataset, uuids=[sqla_table.uuid]
            ),
            None,
        )
        # if not found, pull from database
        if not dataset:
            dataset = (
                session.query(NewDataset).filter_by(uuid=sqla_table.uuid).one_or_none()
            )
        if not dataset:
            sqla_table.write_shadow_dataset()
            return

        # get DB-specific conditional quoter for expressions that point to columns or
        # table names
        database = (
            target.database
            or session.query(Database).filter_by(id=target.database_id).one()
        )
        engine = database.get_sqla_engine(schema=target.schema)
        conditional_quote = engine.dialect.identifier_preparer.quote

        # update columns
        if inspector.attrs.columns.history.has_changes():
            # handle deleted columns
            if inspector.attrs.columns.history.deleted:
                column_names = {
                    column.column_name
                    for column in inspector.attrs.columns.history.deleted
                }
                dataset.columns = [
                    column
                    for column in dataset.columns
                    if column.name not in column_names
                ]

            # handle inserted columns
            for column in inspector.attrs.columns.history.added:
                # ``is_active`` might be ``None``, but it defaults to ``True``.
                if column.is_active is False:
                    continue

                extra_json = json.loads(column.extra or "{}")
                for attr in {
                    "groupby",
                    "filterable",
                    "verbose_name",
                    "python_date_format",
                }:
                    value = getattr(column, attr)
                    if value:
                        extra_json[attr] = value

                dataset.columns.append(
                    NewColumn(
                        name=column.column_name,
                        type=column.type or "Unknown",
                        expression=column.expression
                        or conditional_quote(column.column_name),
                        description=column.description,
                        is_temporal=column.is_dttm,
                        is_aggregation=False,
                        is_physical=column.expression is None,
                        is_spatial=False,
                        is_partition=False,
                        is_increase_desired=True,
                        extra_json=json.dumps(extra_json) if extra_json else None,
                        is_managed_externally=target.is_managed_externally,
                        external_url=target.external_url,
                    )
                )

        # update metrics
        if inspector.attrs.metrics.history.has_changes():
            # handle deleted metrics
            if inspector.attrs.metrics.history.deleted:
                column_names = {
                    metric.metric_name
                    for metric in inspector.attrs.metrics.history.deleted
                }
                dataset.columns = [
                    column
                    for column in dataset.columns
                    if column.name not in column_names
                ]

            # handle inserted metrics
            for metric in inspector.attrs.metrics.history.added:
                extra_json = json.loads(metric.extra or "{}")
                for attr in {"verbose_name", "metric_type", "d3format"}:
                    value = getattr(metric, attr)
                    if value:
                        extra_json[attr] = value

                is_additive = (
                    metric.metric_type
                    and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES
                )

                dataset.columns.append(
                    NewColumn(
                        name=metric.metric_name,
                        type="Unknown",
                        expression=metric.expression,
                        warning_text=metric.warning_text,
                        description=metric.description,
                        is_aggregation=True,
                        is_additive=is_additive,
                        is_physical=False,
                        is_spatial=False,
                        is_partition=False,
                        is_increase_desired=True,
                        extra_json=json.dumps(extra_json) if extra_json else None,
                        is_managed_externally=target.is_managed_externally,
                        external_url=target.external_url,
                    )
                )
        # sync column list and delete removed columns
        if (
            inspector.attrs.columns.history.has_changes()
            or inspector.attrs.metrics.history.has_changes()
        ):
            # add pending new columns to the known columns list, too, so that calling
            # `after_update` twice before changes are persisted will not create
            # two duplicate columns with the same uuids.
            dataset.columns = sqla_table.get_sl_columns()

        # physical dataset
        if target.sql is None:
            physical_columns = [
                column for column in dataset.columns if column.is_physical
            ]

            # if the table name changed we should create a new table instance, instead
            # of reusing the original one
        if not sqla_table.sql:
            # if the table name changed we should relink the dataset to another table
            # (and create one if necessary)
            if (
                inspector.attrs.table_name.history.has_changes()
                or inspector.attrs.schema.history.has_changes()
                or inspector.attrs.database_id.history.has_changes()
                or inspector.attrs.database.history.has_changes()
            ):
                # does the dataset point to an existing table?
                table = (
                    session.query(NewTable)
                    .filter_by(
                        database_id=target.database_id,
                        schema=target.schema,
                        name=target.table_name,
                    )
                    .first()
                tables = NewTable.bulk_load_or_create(
                    sqla_table.database,
                    [TableName(schema=sqla_table.schema, table=sqla_table.table_name)],
                    sync_columns=False,
                    default_props=dict(
                        changed_by=sqla_table.changed_by,
                        created_by=sqla_table.created_by,
                        is_managed_externally=sqla_table.is_managed_externally,
                        external_url=sqla_table.external_url,
                    ),
                )
                if not table:
                    # create new columns
                if not tables[0].id:
                    # dataset columns will only be assigned to newly created tables
                    # existing tables should manage column syncing in another process
                    physical_columns = [
                        clone_model(column, ignore=["uuid"])
                        for column in physical_columns
                        clone_model(
                            column, ignore=["uuid"], keep_relations=["changed_by"]
                        )
                        for column in dataset.columns
                        if column.is_physical
                    ]

                    # create new table
                    table = NewTable(
                        name=target.table_name,
                        schema=target.schema,
                        catalog=None,
                        database_id=target.database_id,
                        columns=physical_columns,
                        is_managed_externally=target.is_managed_externally,
                        external_url=target.external_url,
                    )
                dataset.tables = [table]
            elif dataset.tables:
                table = dataset.tables[0]
                table.columns = physical_columns
                    tables[0].columns = physical_columns
                dataset.tables = tables

        # virtual dataset
        else:
@@ -2237,29 +2248,34 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                column.is_physical = False

            # update referenced tables if SQL changed
            if inspector.attrs.sql.history.has_changes():
                parsed = ParsedQuery(target.sql)
                referenced_tables = parsed.tables

                predicate = or_(
                    *[
                        and_(
                            NewTable.schema == (table.schema or target.schema),
                            NewTable.name == table.table,
                        )
                        for table in referenced_tables
                    ]
            if sqla_table.sql and inspector.attrs.sql.history.has_changes():
                referenced_tables = extract_table_references(
                    sqla_table.sql, sqla_table.database.get_dialect().name
                )
                dataset.tables = NewTable.bulk_load_or_create(
                    sqla_table.database,
                    referenced_tables,
                    default_schema=sqla_table.schema,
                    # sync metadata is expensive, we'll do it in another process
                    # e.g. when users open a Table page
                    sync_columns=False,
                    default_props=dict(
                        changed_by=sqla_table.changed_by,
                        created_by=sqla_table.created_by,
                        is_managed_externally=sqla_table.is_managed_externally,
                        external_url=sqla_table.external_url,
                    ),
                )
                dataset.tables = session.query(NewTable).filter(predicate).all()

        # update other attributes
        dataset.name = target.table_name
        dataset.expression = target.sql or conditional_quote(target.table_name)
        dataset.is_physical = target.sql is None
        dataset.name = sqla_table.table_name
        dataset.expression = sqla_table.sql or sqla_table.quote_identifier(
            sqla_table.table_name
        )
        dataset.is_physical = not sqla_table.sql

    @staticmethod
    def write_shadow_dataset(  # pylint: disable=too-many-locals
        dataset: "SqlaTable", database: Database, session: Session
    def write_shadow_dataset(
        self: "SqlaTable",
    ) -> None:
        """
        Shadow write the dataset to new models.
@@ -2273,95 +2289,57 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho

        For more context: https://github.com/apache/superset/issues/14909
        """

        engine = database.get_sqla_engine(schema=dataset.schema)
        conditional_quote = engine.dialect.identifier_preparer.quote
        session = inspect(self).session
        # make sure database points to the right instance, in case only
        # `table.database_id` is updated and the changes haven't been
        # consolidated by SQLA
        if self.database_id and (
            not self.database or self.database.id != self.database_id
        ):
            self.database = session.query(Database).filter_by(id=self.database_id).one()

        # create columns
        columns = []
        for column in dataset.columns:
            # ``is_active`` might be ``None`` at this point, but it defaults to ``True``.
            if column.is_active is False:
                continue

            try:
                extra_json = json.loads(column.extra or "{}")
            except json.decoder.JSONDecodeError:
                extra_json = {}
            for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}:
                value = getattr(column, attr)
                if value:
                    extra_json[attr] = value

            columns.append(
                NewColumn(
                    name=column.column_name,
                    type=column.type or "Unknown",
                    expression=column.expression
                    or conditional_quote(column.column_name),
                    description=column.description,
                    is_temporal=column.is_dttm,
                    is_aggregation=False,
                    is_physical=column.expression is None,
                    is_spatial=False,
                    is_partition=False,
                    is_increase_desired=True,
                    extra_json=json.dumps(extra_json) if extra_json else None,
                    is_managed_externally=dataset.is_managed_externally,
                    external_url=dataset.external_url,
                ),
            )

        # create metrics
        for metric in dataset.metrics:
            try:
                extra_json = json.loads(metric.extra or "{}")
            except json.decoder.JSONDecodeError:
                extra_json = {}
            for attr in {"verbose_name", "metric_type", "d3format"}:
                value = getattr(metric, attr)
                if value:
                    extra_json[attr] = value

            is_additive = (
                metric.metric_type
                and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES
            )

            columns.append(
                NewColumn(
                    name=metric.metric_name,
                    type="Unknown",  # figuring this out would require a type inferrer
                    expression=metric.expression,
                    warning_text=metric.warning_text,
                    description=metric.description,
                    is_aggregation=True,
                    is_additive=is_additive,
                    is_physical=False,
                    is_spatial=False,
                    is_partition=False,
                    is_increase_desired=True,
                    extra_json=json.dumps(extra_json) if extra_json else None,
                    is_managed_externally=dataset.is_managed_externally,
                    external_url=dataset.external_url,
                ),
            )
        for item in self.columns + self.metrics:
            item.created_by = self.created_by
            item.changed_by = self.changed_by
            # on the `SqlaTable.after_insert` event, although the table itself
            # already has a `uuid`, the associated columns will not.
            # Here we pre-assign a uuid so they can still be matched to the new
            # Column after creation.
            if not item.uuid:
                item.uuid = uuid4()
            columns.append(item.to_sl_column())

        # physical dataset
        if not dataset.sql:
            physical_columns = [column for column in columns if column.is_physical]

            # create table
            table = NewTable(
                name=dataset.table_name,
                schema=dataset.schema,
                catalog=None,  # currently not supported
                database_id=dataset.database_id,
                columns=physical_columns,
                is_managed_externally=dataset.is_managed_externally,
                external_url=dataset.external_url,
        if not self.sql:
            # always create separate column entries for Dataset and Table
            # so updating a dataset would not update columns in the related table
            physical_columns = [
                clone_model(
                    column,
                    ignore=["uuid"],
                    # `created_by` will always be left empty because it'd always
                    # be created via some sort of automated system.
                    # But keep `changed_by` in case someone manually changes
                    # column attributes such as `is_dttm`.
                    keep_relations=["changed_by"],
                )
                for column in columns
                if column.is_physical
            ]
            tables = NewTable.bulk_load_or_create(
                self.database,
                [TableName(schema=self.schema, table=self.table_name)],
                sync_columns=False,
                default_props=dict(
                    created_by=self.created_by,
                    changed_by=self.changed_by,
                    is_managed_externally=self.is_managed_externally,
                    external_url=self.external_url,
                ),
            )
            tables = [table]
            tables[0].columns = physical_columns

        # virtual dataset
        else:
@@ -2370,26 +2348,39 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                column.is_physical = False

            # find referenced tables
            parsed = ParsedQuery(dataset.sql)
            referenced_tables = parsed.tables
            tables = load_or_create_tables(
                session,
                database,
                dataset.schema,
            referenced_tables = extract_table_references(
                self.sql, self.database.get_dialect().name
            )
            tables = NewTable.bulk_load_or_create(
                self.database,
                referenced_tables,
                conditional_quote,
                default_schema=self.schema,
                # syncing table columns can be slow so we are not doing it here
                sync_columns=False,
                default_props=dict(
                    created_by=self.created_by,
                    changed_by=self.changed_by,
                    is_managed_externally=self.is_managed_externally,
                    external_url=self.external_url,
                ),
            )

        # create the new dataset
        new_dataset = NewDataset(
            sqlatable_id=dataset.id,
            name=dataset.table_name,
            expression=dataset.sql or conditional_quote(dataset.table_name),
            uuid=self.uuid,
            database_id=self.database_id,
            created_on=self.created_on,
            created_by=self.created_by,
            changed_by=self.changed_by,
            changed_on=self.changed_on,
            owners=self.owners,
            name=self.table_name,
            expression=self.sql or self.quote_identifier(self.table_name),
            tables=tables,
            columns=columns,
            is_physical=not dataset.sql,
            is_managed_externally=dataset.is_managed_externally,
            external_url=dataset.external_url,
            is_physical=not self.sql,
            is_managed_externally=self.is_managed_externally,
            external_url=self.external_url,
        )
        session.add(new_dataset)

@@ -2399,7 +2390,9 @@ sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert)
sa.event.listen(SqlaTable, "after_delete", SqlaTable.after_delete)
sa.event.listen(SqlaTable, "after_update", SqlaTable.after_update)
sa.event.listen(SqlMetric, "after_update", SqlaTable.update_table)
sa.event.listen(SqlMetric, "after_delete", SqlMetric.after_delete)
sa.event.listen(TableColumn, "after_update", SqlaTable.update_table)
sa.event.listen(TableColumn, "after_delete", TableColumn.after_delete)

RLSFilterRoles = Table(
    "rls_filter_roles",
@@ -15,16 +15,28 @@
# specific language governing permissions and limitations
# under the License.
from contextlib import closing
from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    Type,
    TYPE_CHECKING,
    TypeVar,
)
from uuid import UUID

import sqlparse
from flask_babel import lazy_gettext as _
from sqlalchemy import and_, or_
from sqlalchemy.engine.url import URL as SqlaURL
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.orm import Session
from sqlalchemy.sql.type_api import TypeEngine

from superset.columns.models import Column as NewColumn
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
    SupersetGenericDBErrorException,
@@ -32,9 +44,9 @@ from superset.exceptions import (
)
from superset.models.core import Database
from superset.result_set import SupersetResultSet
from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, Table
from superset.sql_parse import has_table_query, insert_rls, ParsedQuery
from superset.superset_typing import ResultSetColumnType
from superset.tables.models import Table as NewTable
from superset.utils.memoized import memoized

if TYPE_CHECKING:
    from superset.connectors.sqla.models import SqlaTable
@@ -168,75 +180,38 @@ def validate_adhoc_subquery(
    return ";\n".join(str(statement) for statement in statements)


def load_or_create_tables(  # pylint: disable=too-many-arguments
@memoized
def get_dialect_name(drivername: str) -> str:
    return SqlaURL(drivername).get_dialect().name


@memoized
def get_identifier_quoter(drivername: str) -> Dict[str, Callable[[str], str]]:
    return SqlaURL(drivername).get_dialect()().identifier_preparer.quote


DeclarativeModel = TypeVar("DeclarativeModel", bound=DeclarativeMeta)


def find_cached_objects_in_session(
    session: Session,
    database: Database,
    default_schema: Optional[str],
    tables: Set[Table],
    conditional_quote: Callable[[str], str],
) -> List[NewTable]:
    """
    Load or create new table model instances.
    """
    if not tables:
        return []
    cls: Type[DeclarativeModel],
    ids: Optional[Iterable[int]] = None,
    uuids: Optional[Iterable[UUID]] = None,
) -> Iterator[DeclarativeModel]:
    """Find known ORM instances in cached SQLA session states.

    # set the default schema in tables that don't have it
    if default_schema:
        fixed_tables = list(tables)
        for i, table in enumerate(fixed_tables):
            if table.schema is None:
                fixed_tables[i] = Table(table.table, default_schema, table.catalog)
        tables = set(fixed_tables)

    # load existing tables
    predicate = or_(
        *[
            and_(
                NewTable.database_id == database.id,
                NewTable.schema == table.schema,
                NewTable.name == table.table,
            )
            for table in tables
        ]
    :param session: a SQLA session
    :param cls: a SQLA DeclarativeModel
    :param ids: ids of the desired model instances (optional)
    :param uuids: uuids of the desired instances, will be ignored if `ids` are provided
    """
    if not ids and not uuids:
        return iter([])
    uuids = uuids or []
    return (
        item
        # `session` is an iterator of all known items
        for item in set(session)
        if isinstance(item, cls) and (item.id in ids if ids else item.uuid in uuids)
    )
    new_tables = session.query(NewTable).filter(predicate).all()

    # add missing tables
    existing = {(table.schema, table.name) for table in new_tables}
    for table in tables:
        if (table.schema, table.table) not in existing:
            try:
                column_metadata = get_physical_table_metadata(
                    database=database,
                    table_name=table.table,
                    schema_name=table.schema,
                )
            except Exception:  # pylint: disable=broad-except
                continue
            columns = [
                NewColumn(
                    name=column["name"],
                    type=str(column["type"]),
                    expression=conditional_quote(column["name"]),
                    is_temporal=column["is_dttm"],
                    is_aggregation=False,
                    is_physical=True,
                    is_spatial=False,
                    is_partition=False,
                    is_increase_desired=True,
                )
                for column in column_metadata
            ]
            new_tables.append(
                NewTable(
                    name=table.table,
                    schema=table.schema,
                    catalog=None,
                    database_id=database.id,
                    columns=columns,
                )
            )
            existing.add((table.schema, table.table))

    return new_tables
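The new ``find_cached_objects_in_session`` helper relies on the fact that a SQLAlchemy ``Session`` is iterable, yielding every instance it currently tracks, so matches can be found without a database round trip. A self-contained demonstration of that property, using a toy model rather than a Superset one:

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    name = Column(String(50))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(User(id=1, name="alice"))
    # iterate the session's cached state instead of emitting a SELECT
    cached = [obj for obj in set(session) if isinstance(obj, User)]
    assert cached and cached[0].name == "alice"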
@@ -28,9 +28,11 @@ from typing import List

import sqlalchemy as sa
from flask_appbuilder import Model
from sqlalchemy.orm import relationship
from sqlalchemy.orm import backref, relationship

from superset import security_manager
from superset.columns.models import Column
from superset.models.core import Database
from superset.models.helpers import (
    AuditMixinNullable,
    ExtraJSONMixin,
@@ -38,18 +40,33 @@ from superset.models.helpers import (
)
from superset.tables.models import Table

column_association_table = sa.Table(
dataset_column_association_table = sa.Table(
    "sl_dataset_columns",
    Model.metadata,  # pylint: disable=no-member
    sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id")),
    sa.Column("column_id", sa.ForeignKey("sl_columns.id")),
    sa.Column(
        "dataset_id",
        sa.ForeignKey("sl_datasets.id"),
        primary_key=True,
    ),
    sa.Column(
        "column_id",
        sa.ForeignKey("sl_columns.id"),
        primary_key=True,
    ),
)

table_association_table = sa.Table(
dataset_table_association_table = sa.Table(
    "sl_dataset_tables",
    Model.metadata,  # pylint: disable=no-member
    sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id")),
    sa.Column("table_id", sa.ForeignKey("sl_tables.id")),
    sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id"), primary_key=True),
    sa.Column("table_id", sa.ForeignKey("sl_tables.id"), primary_key=True),
)

dataset_user_association_table = sa.Table(
    "sl_dataset_users",
    Model.metadata,  # pylint: disable=no-member
    sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id"), primary_key=True),
    sa.Column("user_id", sa.ForeignKey("ab_user.id"), primary_key=True),
)


@@ -61,27 +78,27 @@ class Dataset(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
    __tablename__ = "sl_datasets"

    id = sa.Column(sa.Integer, primary_key=True)

    # A temporary column, used for shadow writing to the new model. Once the ``SqlaTable``
    # model has been deleted this column can be removed.
    sqlatable_id = sa.Column(sa.Integer, nullable=True, unique=True)

    # We use ``sa.Text`` for these attributes because (1) in modern databases the
    # performance is the same as ``VARCHAR``[1] and (2) because some table names can be
    # **really** long (eg, Google Sheets URLs).
    #
    # [1] https://www.postgresql.org/docs/9.1/datatype-character.html
    name = sa.Column(sa.Text)

    expression = sa.Column(sa.Text)

    # n:n relationship
    tables: List[Table] = relationship("Table", secondary=table_association_table)

    # The relationship between datasets and columns is 1:n, but we use a many-to-many
    # association to differentiate between the relationship between tables and columns.
    database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False)
    database: Database = relationship(
        "Database",
        backref=backref("datasets", cascade="all, delete-orphan"),
        foreign_keys=[database_id],
    )
    # The relationship between datasets and columns is 1:n, but we use a
    # many-to-many association table to avoid adding two mutually exclusive
    # columns (dataset_id and table_id) to Column
    columns: List[Column] = relationship(
        "Column", secondary=column_association_table, cascade="all, delete"
        "Column",
        secondary=dataset_column_association_table,
        cascade="all, delete-orphan",
        single_parent=True,
        backref="datasets",
    )
    owners = relationship(
        security_manager.user_model, secondary=dataset_user_association_table
    )
    tables: List[Table] = relationship(
        "Table", secondary=dataset_table_association_table, backref="datasets"
    )

    # Does the dataset point directly to a ``Table``?
@@ -89,4 +106,15 @@ class Dataset(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):

    # Column is managed externally and should be read-only inside Superset
    is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)

    # We use ``sa.Text`` for these attributes because (1) in modern databases the
    # performance is the same as ``VARCHAR``[1] and (2) because some table names can be
    # **really** long (eg, Google Sheets URLs).
    #
    # [1] https://www.postgresql.org/docs/9.1/datatype-character.html
    name = sa.Column(sa.Text)
    expression = sa.Column(sa.Text)
    external_url = sa.Column(sa.Text, nullable=True)

    def __repr__(self) -> str:
        return f"<Dataset id={self.id} database_id={self.database_id} {self.name}>"
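The relationships above boil down to one idea: the logical dataset-to-column link is 1:n, but it is stored through an association table so that ``Column`` rows never need mutually exclusive ``dataset_id``/``table_id`` foreign keys. Reduced to its essentials, as a simplified sketch rather than the full models:

import sqlalchemy as sa
from sqlalchemy.orm import declarative_base, relationship

Base = declarative_base()

dataset_columns = sa.Table(
    "sl_dataset_columns",
    Base.metadata,
    sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id"), primary_key=True),
    sa.Column("column_id", sa.ForeignKey("sl_columns.id"), primary_key=True),
)


class SLColumn(Base):
    __tablename__ = "sl_columns"
    id = sa.Column(sa.Integer, primary_key=True)


class SLDataset(Base):
    __tablename__ = "sl_datasets"
    id = sa.Column(sa.Integer, primary_key=True)
    # delete-orphan plus single_parent makes the m:n table behave like 1:n
    columns = relationship(
        SLColumn,
        secondary=dataset_columns,
        cascade="all, delete-orphan",
        single_parent=True,
        backref="datasets",
    )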
@@ -135,23 +135,26 @@ def _set_table_metadata(datasource: SqlaTable, database: "Database") -> None:


def _add_table_metrics(datasource: SqlaTable) -> None:
    if not any(col.column_name == "num_california" for col in datasource.columns):
    # By accessing the attribute first, we make sure `datasource.columns` and
    # `datasource.metrics` are already loaded. Otherwise accessing them later
    # may trigger an unnecessary and unexpected `after_update` event.
    columns, metrics = datasource.columns, datasource.metrics

    if not any(col.column_name == "num_california" for col in columns):
        col_state = str(column("state").compile(db.engine))
        col_num = str(column("num").compile(db.engine))
        datasource.columns.append(
        columns.append(
            TableColumn(
                column_name="num_california",
                expression=f"CASE WHEN {col_state} = 'CA' THEN {col_num} ELSE 0 END",
            )
        )

    if not any(col.metric_name == "sum__num" for col in datasource.metrics):
    if not any(col.metric_name == "sum__num" for col in metrics):
        col = str(column("num").compile(db.engine))
        datasource.metrics.append(
            SqlMetric(metric_name="sum__num", expression=f"SUM({col})")
        )
        metrics.append(SqlMetric(metric_name="sum__num", expression=f"SUM({col})"))

    for col in datasource.columns:
    for col in columns:
        if col.column_name == "ds":
            col.is_dttm = True
            break
@ -15,42 +15,22 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Any, Iterator, Optional, Set
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy.dialects.mysql.base import MySQLDialect
|
||||
from sqlalchemy.dialects.postgresql.base import PGDialect
|
||||
from sqlalchemy.engine import reflection
|
||||
from sqlalchemy.exc import NoSuchTableError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
try:
|
||||
from sqloxide import parse_sql
|
||||
except ImportError:
|
||||
parse_sql = None
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from superset.sql_parse import ParsedQuery, Table
|
||||
|
||||
logger = logging.getLogger("alembic")
|
||||
|
||||
|
||||
# mapping between sqloxide and SQLAlchemy dialects
|
||||
sqloxide_dialects = {
|
||||
"ansi": {"trino", "trinonative", "presto"},
|
||||
"hive": {"hive", "databricks"},
|
||||
"ms": {"mssql"},
|
||||
"mysql": {"mysql"},
|
||||
"postgres": {
|
||||
"cockroachdb",
|
||||
"hana",
|
||||
"netezza",
|
||||
"postgres",
|
||||
"postgresql",
|
||||
"redshift",
|
||||
"vertica",
|
||||
},
|
||||
"snowflake": {"snowflake"},
|
||||
"sqlite": {"sqlite", "gsheets", "shillelagh"},
|
||||
"clickhouse": {"clickhouse"},
|
||||
}
|
||||
DEFAULT_BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 1000))
|
||||
|
||||
|
||||
def table_has_column(table: str, column: str) -> bool:
|
||||
|
@ -61,7 +41,6 @@ def table_has_column(table: str, column: str) -> bool:
|
|||
:param column: A column name
|
||||
:returns: True iff the column exists in the table
|
||||
"""
|
||||
|
||||
config = op.get_context().config
|
||||
engine = engine_from_config(
|
||||
config.get_section(config.config_ini_section), prefix="sqlalchemy."
|
||||
|
@ -73,42 +52,44 @@ def table_has_column(table: str, column: str) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def find_nodes_by_key(element: Any, target: str) -> Iterator[Any]:
    """
    Find all nodes in a SQL tree matching a given key.
    """
    if isinstance(element, list):
        for child in element:
            yield from find_nodes_by_key(child, target)
    elif isinstance(element, dict):
        for key, value in element.items():
            if key == target:
                yield value
            else:
                yield from find_nodes_by_key(value, target)


def extract_table_references(sql_text: str, sqla_dialect: str) -> Set[Table]:
    """
    Return all the dependencies from a SQL sql_text.
    """
    if not parse_sql:
        parsed = ParsedQuery(sql_text)
        return parsed.tables

    dialect = "generic"
    for dialect, sqla_dialects in sqloxide_dialects.items():
        if sqla_dialect in sqla_dialects:
            break
    try:
        tree = parse_sql(sql_text, dialect=dialect)
    except Exception:  # pylint: disable=broad-except
        logger.warning("Unable to parse query with sqloxide: %s", sql_text)
        # fallback to sqlparse
        parsed = ParsedQuery(sql_text)
        return parsed.tables

    return {
        Table(*[part["value"] for part in table["name"][::-1]])
        for table in find_nodes_by_key(tree, "Table")
    }


uuid_by_dialect = {
    MySQLDialect: "UNHEX(REPLACE(CONVERT(UUID() using utf8mb4), '-', ''))",
    PGDialect: "uuid_in(md5(random()::text || clock_timestamp()::text)::cstring)",
}


def assign_uuids(
    model: Any, session: Session, batch_size: int = DEFAULT_BATCH_SIZE
) -> None:
    """Generate new UUIDs for all rows in a table"""
    bind = op.get_bind()
    table_name = model.__tablename__
    count = session.query(model).count()
    # silently skip if the table is empty (suitable for db initialization)
    if count == 0:
        return

    start_time = time.time()
    print(f"\nAdding uuids for `{table_name}`...")
    # Use dialect-specific native SQL queries if possible
    for dialect, sql in uuid_by_dialect.items():
        if isinstance(bind.dialect, dialect):
            op.execute(
                f"UPDATE {dialect().identifier_preparer.quote(table_name)} SET uuid = {sql}"
            )
            print(f"Done. Assigned {count} uuids in {time.time() - start_time:.3f}s.\n")
            return

    # Otherwise use the Python uuid function
    start = 0
    while start < count:
        end = min(start + batch_size, count)
        for obj in session.query(model)[start:end]:
            obj.uuid = uuid4()
            session.merge(obj)
        session.commit()
        if start + batch_size < count:
            print(f"  uuid assigned to {end} out of {count}\r", end="")
        start += batch_size

    print(f"Done. Assigned {count} uuids in {time.time() - start_time:.3f}s.\n")
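# Usage sketch: how ``assign_uuids`` is typically driven from a migration's
# ``upgrade()``. The ``model`` argument is a placeholder for any declarative
# model with a ``uuid`` column and a ``__tablename__``; it is not part of the
# diff above.
def _assign_uuids_usage_sketch(model: Any) -> None:
    session = Session(bind=op.get_bind())
    # On MySQL/Postgres this issues a single native UPDATE; on other backends
    # it batches Python uuid4() values, DEFAULT_BATCH_SIZE rows at a time.
    assign_uuids(model, session)
    session.commit()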
@@ -32,9 +32,7 @@ from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import UUIDType

from superset import db
from superset.migrations.versions.b56500de1855_add_uuid_column_to_import_mixin import (
    add_uuids,
)
from superset.migrations.shared.utils import assign_uuids

# revision identifiers, used by Alembic.
revision = "96e99fb176a0"

@@ -75,7 +73,7 @@ def upgrade():
    # Ignore column update errors so that we can run upgrade multiple times
    pass

    add_uuids(SavedQuery, "saved_query", session)
    assign_uuids(SavedQuery, session)

    try:
        # Add uniqueness constraint
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""empty message
"""merge point

Revision ID: 9d8a8d575284
Revises: ('8b841273bec3', 'b0d0249074e4')
@@ -0,0 +1,905 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""new_dataset_models_take_2

Revision ID: a9422eeaae74
Revises: ad07e4fdbaba
Create Date: 2022-04-01 14:38:09.499483

"""

# revision identifiers, used by Alembic.
revision = "a9422eeaae74"
down_revision = "ad07e4fdbaba"

import json
import os
from datetime import datetime
from typing import List, Optional, Set, Type, Union
from uuid import uuid4

import sqlalchemy as sa
from alembic import op
from sqlalchemy import select
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import backref, relationship, Session
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import functions as func
from sqlalchemy.sql.expression import and_, or_
from sqlalchemy_utils import UUIDType

from superset import app, db
from superset.connectors.sqla.models import ADDITIVE_METRIC_TYPES_LOWER
from superset.connectors.sqla.utils import get_dialect_name, get_identifier_quoter
from superset.extensions import encrypted_field_factory
from superset.migrations.shared.utils import assign_uuids
from superset.sql_parse import extract_table_references, Table
from superset.utils.core import MediumText

Base = declarative_base()
custom_password_store = app.config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
DB_CONNECTION_MUTATOR = app.config["DB_CONNECTION_MUTATOR"]
SHOW_PROGRESS = os.environ.get("SHOW_PROGRESS") == "1"
UNKNOWN_TYPE = "UNKNOWN"


user_table = sa.Table(
    "ab_user", Base.metadata, sa.Column("id", sa.Integer(), primary_key=True)
)
class UUIDMixin:
    uuid = sa.Column(
        UUIDType(binary=True), primary_key=False, unique=True, default=uuid4
    )


class AuxiliaryColumnsMixin(UUIDMixin):
    """
    Auxiliary columns, a combination of columns added by
    AuditMixinNullable + ImportExportMixin
    """

    created_on = sa.Column(sa.DateTime, default=datetime.now, nullable=True)
    changed_on = sa.Column(
        sa.DateTime, default=datetime.now, onupdate=datetime.now, nullable=True
    )

    @declared_attr
    def created_by_fk(cls):
        return sa.Column(sa.Integer, sa.ForeignKey("ab_user.id"), nullable=True)

    @declared_attr
    def changed_by_fk(cls):
        return sa.Column(sa.Integer, sa.ForeignKey("ab_user.id"), nullable=True)
def insert_from_select(
    target: Union[str, sa.Table, Type[Base]], source: sa.sql.expression.Select
) -> None:
    """
    Execute INSERT FROM SELECT to copy data from a SELECT query to the target table.
    """
    if isinstance(target, sa.Table):
        target_table = target
    elif hasattr(target, "__tablename__"):
        target_table: sa.Table = Base.metadata.tables[target.__tablename__]
    else:
        target_table: sa.Table = Base.metadata.tables[target]
    cols = [col.name for col in source.columns if col.name in target_table.columns]
    query = target_table.insert().from_select(cols, source)
    return op.execute(query)
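# Usage sketch: copying selected columns from ``tables`` into ``sl_datasets``
# with ``insert_from_select``. The SELECT's labels must match target column
# names; anything unmatched is dropped by the ``cols`` filter above. The two
# columns chosen here are illustrative only.
def _insert_from_select_sketch() -> None:
    source = select(
        [
            SqlaTable.uuid,
            SqlaTable.table_name.label("name"),  # "name" exists on sl_datasets
        ]
    )
    # compiles to roughly: INSERT INTO sl_datasets (uuid, name) SELECT ... FROM tables
    insert_from_select(NewDataset, source)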
class Database(Base):

    __tablename__ = "dbs"
    __table_args__ = (UniqueConstraint("database_name"),)

    id = sa.Column(sa.Integer, primary_key=True)
    database_name = sa.Column(sa.String(250), unique=True, nullable=False)
    sqlalchemy_uri = sa.Column(sa.String(1024), nullable=False)
    password = sa.Column(encrypted_field_factory.create(sa.String(1024)))
    impersonate_user = sa.Column(sa.Boolean, default=False)
    encrypted_extra = sa.Column(encrypted_field_factory.create(sa.Text), nullable=True)
    extra = sa.Column(sa.Text)
    server_cert = sa.Column(encrypted_field_factory.create(sa.Text), nullable=True)


class TableColumn(AuxiliaryColumnsMixin, Base):

    __tablename__ = "table_columns"
    __table_args__ = (UniqueConstraint("table_id", "column_name"),)

    id = sa.Column(sa.Integer, primary_key=True)
    table_id = sa.Column(sa.Integer, sa.ForeignKey("tables.id"))
    is_active = sa.Column(sa.Boolean, default=True)
    extra = sa.Column(sa.Text)
    column_name = sa.Column(sa.String(255), nullable=False)
    type = sa.Column(sa.String(32))
    expression = sa.Column(MediumText())
    description = sa.Column(MediumText())
    is_dttm = sa.Column(sa.Boolean, default=False)
    filterable = sa.Column(sa.Boolean, default=True)
    groupby = sa.Column(sa.Boolean, default=True)
    verbose_name = sa.Column(sa.String(1024))
    python_date_format = sa.Column(sa.String(255))


class SqlMetric(AuxiliaryColumnsMixin, Base):

    __tablename__ = "sql_metrics"
    __table_args__ = (UniqueConstraint("table_id", "metric_name"),)

    id = sa.Column(sa.Integer, primary_key=True)
    table_id = sa.Column(sa.Integer, sa.ForeignKey("tables.id"))
    extra = sa.Column(sa.Text)
    metric_type = sa.Column(sa.String(32))
    metric_name = sa.Column(sa.String(255), nullable=False)
    expression = sa.Column(MediumText(), nullable=False)
    warning_text = sa.Column(MediumText())
    description = sa.Column(MediumText())
    d3format = sa.Column(sa.String(128))
    verbose_name = sa.Column(sa.String(1024))


sqlatable_user_table = sa.Table(
    "sqlatable_user",
    Base.metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("user_id", sa.Integer, sa.ForeignKey("ab_user.id")),
    sa.Column("table_id", sa.Integer, sa.ForeignKey("tables.id")),
)


class SqlaTable(AuxiliaryColumnsMixin, Base):

    __tablename__ = "tables"
    __table_args__ = (UniqueConstraint("database_id", "schema", "table_name"),)

    id = sa.Column(sa.Integer, primary_key=True)
    extra = sa.Column(sa.Text)
    database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False)
    database: Database = relationship(
        "Database",
        backref=backref("tables", cascade="all, delete-orphan"),
        foreign_keys=[database_id],
    )
    schema = sa.Column(sa.String(255))
    table_name = sa.Column(sa.String(250), nullable=False)
    sql = sa.Column(MediumText())
    is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)
    external_url = sa.Column(sa.Text, nullable=True)


table_column_association_table = sa.Table(
    "sl_table_columns",
    Base.metadata,
    sa.Column("table_id", sa.ForeignKey("sl_tables.id"), primary_key=True),
    sa.Column("column_id", sa.ForeignKey("sl_columns.id"), primary_key=True),
)

dataset_column_association_table = sa.Table(
    "sl_dataset_columns",
    Base.metadata,
    sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id"), primary_key=True),
    sa.Column("column_id", sa.ForeignKey("sl_columns.id"), primary_key=True),
)

dataset_table_association_table = sa.Table(
    "sl_dataset_tables",
    Base.metadata,
    sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id"), primary_key=True),
    sa.Column("table_id", sa.ForeignKey("sl_tables.id"), primary_key=True),
)

dataset_user_association_table = sa.Table(
    "sl_dataset_users",
    Base.metadata,
    sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id"), primary_key=True),
    sa.Column("user_id", sa.ForeignKey("ab_user.id"), primary_key=True),
)
class NewColumn(AuxiliaryColumnsMixin, Base):

    __tablename__ = "sl_columns"

    id = sa.Column(sa.Integer, primary_key=True)
    # A temporary column to link physical columns with tables so we don't
    # have to insert a record in the relationship table while creating new columns.
    table_id = sa.Column(sa.Integer, nullable=True)

    is_aggregation = sa.Column(sa.Boolean, nullable=False, default=False)
    is_additive = sa.Column(sa.Boolean, nullable=False, default=False)
    is_dimensional = sa.Column(sa.Boolean, nullable=False, default=False)
    is_filterable = sa.Column(sa.Boolean, nullable=False, default=True)
    is_increase_desired = sa.Column(sa.Boolean, nullable=False, default=True)
    is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)
    is_partition = sa.Column(sa.Boolean, nullable=False, default=False)
    is_physical = sa.Column(sa.Boolean, nullable=False, default=False)
    is_temporal = sa.Column(sa.Boolean, nullable=False, default=False)
    is_spatial = sa.Column(sa.Boolean, nullable=False, default=False)

    name = sa.Column(sa.Text)
    type = sa.Column(sa.Text)
    unit = sa.Column(sa.Text)
    expression = sa.Column(MediumText())
    description = sa.Column(MediumText())
    warning_text = sa.Column(MediumText())
    external_url = sa.Column(sa.Text, nullable=True)
    extra_json = sa.Column(MediumText(), default="{}")


class NewTable(AuxiliaryColumnsMixin, Base):

    __tablename__ = "sl_tables"

    id = sa.Column(sa.Integer, primary_key=True)
    # A temporary column to keep the link between NewTable and SqlaTable
    sqlatable_id = sa.Column(sa.Integer, primary_key=False, nullable=True, unique=True)
    database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False)
    is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)
    catalog = sa.Column(sa.Text)
    schema = sa.Column(sa.Text)
    name = sa.Column(sa.Text)
    external_url = sa.Column(sa.Text, nullable=True)
    extra_json = sa.Column(MediumText(), default="{}")
    database: Database = relationship(
        "Database",
        backref=backref("new_tables", cascade="all, delete-orphan"),
        foreign_keys=[database_id],
    )


class NewDataset(Base, AuxiliaryColumnsMixin):

    __tablename__ = "sl_datasets"

    id = sa.Column(sa.Integer, primary_key=True)
    database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False)
    is_physical = sa.Column(sa.Boolean, default=False)
    is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)
    name = sa.Column(sa.Text)
    expression = sa.Column(MediumText())
    external_url = sa.Column(sa.Text, nullable=True)
    extra_json = sa.Column(MediumText(), default="{}")
def find_tables(
    session: Session,
    database_id: int,
    default_schema: Optional[str],
    tables: Set[Table],
) -> List[int]:
    """
    Look up NewTable records from a specific database by schema and name.
    """
    if not tables:
        return []

    predicate = or_(
        *[
            and_(
                NewTable.database_id == database_id,
                NewTable.schema == (table.schema or default_schema),
                NewTable.name == table.table,
            )
            for table in tables
        ]
    )
    return session.query(NewTable.id).filter(predicate).all()
# helper SQLA elements for easier querying
is_physical_table = or_(SqlaTable.sql.is_(None), SqlaTable.sql == "")
is_physical_column = or_(TableColumn.expression.is_(None), TableColumn.expression == "")

# filter table columns down to those with a valid associated SqlaTable
active_table_columns = sa.join(
    TableColumn,
    SqlaTable,
    TableColumn.table_id == SqlaTable.id,
)
active_metrics = sa.join(SqlMetric, SqlaTable, SqlMetric.table_id == SqlaTable.id)
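# A quick sanity sketch: these predicates are plain SQLA expressions, so they
# can be compiled to inspect the SQL fragment they contribute to the queries
# below; the rendered text varies slightly by dialect.
def _show_is_physical_table_sql() -> None:
    compiled = is_physical_table.compile(compile_kwargs={"literal_binds": True})
    print(compiled)  # roughly: tables.sql IS NULL OR tables.sql = ''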
def copy_tables(session: Session) -> None:
    """Copy physical tables"""
    count = session.query(SqlaTable).filter(is_physical_table).count()
    if not count:
        return
    print(f">> Copy {count:,} physical tables to sl_tables...")
    insert_from_select(
        NewTable,
        select(
            [
                # Tables need a different uuid than datasets, since they are
                # different entities. With INSERT FROM SELECT we must provide a
                # value for `uuid`; otherwise every row would get the default
                # generated on the Python side, causing duplicate values. These
                # uuids are replaced by `assign_uuids` later.
                SqlaTable.uuid,
                SqlaTable.id.label("sqlatable_id"),
                SqlaTable.created_on,
                SqlaTable.changed_on,
                SqlaTable.created_by_fk,
                SqlaTable.changed_by_fk,
                SqlaTable.table_name.label("name"),
                SqlaTable.schema,
                SqlaTable.database_id,
                SqlaTable.is_managed_externally,
                SqlaTable.external_url,
            ]
        )
        # use an inner join to filter out only tables with valid database ids
        .select_from(
            sa.join(SqlaTable, Database, SqlaTable.database_id == Database.id)
        ).where(is_physical_table),
    )
def copy_datasets(session: Session) -> None:
    """Copy all datasets"""
    count = session.query(SqlaTable).count()
    if not count:
        return
    print(f">> Copy {count:,} SqlaTable to sl_datasets...")
    insert_from_select(
        NewDataset,
        select(
            [
                SqlaTable.uuid,
                SqlaTable.created_on,
                SqlaTable.changed_on,
                SqlaTable.created_by_fk,
                SqlaTable.changed_by_fk,
                SqlaTable.database_id,
                SqlaTable.table_name.label("name"),
                func.coalesce(SqlaTable.sql, SqlaTable.table_name).label("expression"),
                is_physical_table.label("is_physical"),
                SqlaTable.is_managed_externally,
                SqlaTable.external_url,
                SqlaTable.extra.label("extra_json"),
            ]
        ),
    )

    print("   Copy dataset owners...")
    insert_from_select(
        dataset_user_association_table,
        select(
            [NewDataset.id.label("dataset_id"), sqlatable_user_table.c.user_id]
        ).select_from(
            sqlatable_user_table.join(
                SqlaTable, SqlaTable.id == sqlatable_user_table.c.table_id
            ).join(NewDataset, NewDataset.uuid == SqlaTable.uuid)
        ),
    )

    print("   Link physical datasets with tables...")
    insert_from_select(
        dataset_table_association_table,
        select(
            [
                NewDataset.id.label("dataset_id"),
                NewTable.id.label("table_id"),
            ]
        ).select_from(
            sa.join(SqlaTable, NewTable, NewTable.sqlatable_id == SqlaTable.id).join(
                NewDataset, NewDataset.uuid == SqlaTable.uuid
            )
        ),
    )
def copy_columns(session: Session) -> None:
    """Copy columns with an active associated SqlaTable"""
    count = session.query(TableColumn).select_from(active_table_columns).count()
    if not count:
        return
    print(f">> Copy {count:,} table columns to sl_columns...")
    insert_from_select(
        NewColumn,
        select(
            [
                TableColumn.uuid,
                TableColumn.created_on,
                TableColumn.changed_on,
                TableColumn.created_by_fk,
                TableColumn.changed_by_fk,
                TableColumn.groupby.label("is_dimensional"),
                TableColumn.filterable.label("is_filterable"),
                TableColumn.column_name.label("name"),
                TableColumn.description,
                func.coalesce(TableColumn.expression, TableColumn.column_name).label(
                    "expression"
                ),
                sa.literal(False).label("is_aggregation"),
                is_physical_column.label("is_physical"),
                TableColumn.is_dttm.label("is_temporal"),
                func.coalesce(TableColumn.type, UNKNOWN_TYPE).label("type"),
                TableColumn.extra.label("extra_json"),
            ]
        ).select_from(active_table_columns),
    )

    joined_columns_table = active_table_columns.join(
        NewColumn, TableColumn.uuid == NewColumn.uuid
    )
    print("   Link all columns to sl_datasets...")
    insert_from_select(
        dataset_column_association_table,
        select(
            [
                NewDataset.id.label("dataset_id"),
                NewColumn.id.label("column_id"),
            ],
        ).select_from(
            joined_columns_table.join(NewDataset, NewDataset.uuid == SqlaTable.uuid)
        ),
    )
def copy_metrics(session: Session) -> None:
    """Copy metrics as virtual columns"""
    metrics_count = session.query(SqlMetric).select_from(active_metrics).count()
    if not metrics_count:
        return

    print(f">> Copy {metrics_count:,} metrics to sl_columns...")
    insert_from_select(
        NewColumn,
        select(
            [
                SqlMetric.uuid,
                SqlMetric.created_on,
                SqlMetric.changed_on,
                SqlMetric.created_by_fk,
                SqlMetric.changed_by_fk,
                SqlMetric.metric_name.label("name"),
                SqlMetric.expression,
                SqlMetric.description,
                sa.literal(UNKNOWN_TYPE).label("type"),
                (
                    func.coalesce(
                        sa.func.lower(SqlMetric.metric_type).in_(
                            ADDITIVE_METRIC_TYPES_LOWER
                        ),
                        sa.literal(False),
                    ).label("is_additive")
                ),
                sa.literal(True).label("is_aggregation"),
                # metrics are by default not filterable
                sa.literal(False).label("is_filterable"),
                sa.literal(False).label("is_dimensional"),
                sa.literal(False).label("is_physical"),
                sa.literal(False).label("is_temporal"),
                SqlMetric.extra.label("extra_json"),
                SqlMetric.warning_text,
            ]
        ).select_from(active_metrics),
    )

    print("   Link metric columns to datasets...")
    insert_from_select(
        dataset_column_association_table,
        select(
            [
                NewDataset.id.label("dataset_id"),
                NewColumn.id.label("column_id"),
            ],
        ).select_from(
            active_metrics.join(NewDataset, NewDataset.uuid == SqlaTable.uuid).join(
                NewColumn, NewColumn.uuid == SqlMetric.uuid
            )
        ),
    )
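# For clarity, a rough Python equivalent of the ``is_additive`` expression in
# ``copy_metrics`` above. The actual members of ADDITIVE_METRIC_TYPES_LOWER
# are defined in superset.connectors.sqla.models; metric types such as "sum",
# "count", "min", and "max" would be typical.
def _is_additive_py(metric_type: Optional[str]) -> bool:
    if metric_type is None:
        # the COALESCE(..., false) branch: NULL metric types are not additive
        return False
    return metric_type.lower() in ADDITIVE_METRIC_TYPES_LOWER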
def postprocess_datasets(session: Session) -> None:
    """
    Postprocess datasets after insertion to
      - Quote table names for physical datasets (if needed)
      - Link referenced tables to virtual datasets
    """
    total = session.query(SqlaTable).count()
    if not total:
        return

    offset = 0
    limit = 10000

    joined_tables = sa.join(
        NewDataset,
        SqlaTable,
        NewDataset.uuid == SqlaTable.uuid,
    ).join(
        Database,
        Database.id == SqlaTable.database_id,
        isouter=True,
    )
    assert session.query(func.count()).select_from(joined_tables).scalar() == total

    print(f">> Run postprocessing on {total} datasets")

    update_count = 0

    def print_update_count():
        if SHOW_PROGRESS:
            print(
                f"   Will update {update_count} datasets" + " " * 20,
                end="\r",
            )

    while offset < total:
        print(
            f"   Process dataset {offset + 1}~{min(total, offset + limit)}..."
            + " " * 30
        )
        for (
            database_id,
            dataset_id,
            expression,
            extra,
            is_physical,
            schema,
            sqlalchemy_uri,
        ) in session.execute(
            select(
                [
                    NewDataset.database_id,
                    NewDataset.id.label("dataset_id"),
                    NewDataset.expression,
                    SqlaTable.extra,
                    NewDataset.is_physical,
                    SqlaTable.schema,
                    Database.sqlalchemy_uri,
                ]
            )
            .select_from(joined_tables)
            .offset(offset)
            .limit(limit)
        ):
            drivername = (sqlalchemy_uri or "").split("://")[0]
            updates = {}
            updated = False
            if is_physical and drivername:
                quoted_expression = get_identifier_quoter(drivername)(expression)
                if quoted_expression != expression:
                    updates["expression"] = quoted_expression

                # add the schema name to `dataset.extra_json` so we don't have
                # to join tables in order to use datasets
                if schema:
                    try:
                        extra_json = json.loads(extra) if extra else {}
                    except json.decoder.JSONDecodeError:
                        extra_json = {}
                    extra_json["schema"] = schema
                    updates["extra_json"] = json.dumps(extra_json)

            if updates:
                session.execute(
                    sa.update(NewDataset)
                    .where(NewDataset.id == dataset_id)
                    .values(**updates)
                )
                updated = True

            if not is_physical and expression:
                table_references = extract_table_references(
                    expression, get_dialect_name(drivername), show_warning=False
                )
                found_tables = find_tables(
                    session, database_id, schema, table_references
                )
                if found_tables:
                    op.bulk_insert(
                        dataset_table_association_table,
                        [
                            {"dataset_id": dataset_id, "table_id": table.id}
                            for table in found_tables
                        ],
                    )
                    updated = True

            if updated:
                update_count += 1
                print_update_count()

        session.flush()
        offset += limit

    if SHOW_PROGRESS:
        print("")
def postprocess_columns(session: Session) -> None:
    """
    At this step, we will
      - Add engine-specific quotes to the `expression` of physical columns
      - Tuck some extra metadata into `extra_json`
    """
    total = session.query(NewColumn).count()
    if not total:
        return

    def get_joined_tables(offset, limit):
        return (
            sa.join(
                session.query(NewColumn)
                .offset(offset)
                .limit(limit)
                .subquery("sl_columns"),
                dataset_column_association_table,
                dataset_column_association_table.c.column_id == NewColumn.id,
            )
            .join(
                NewDataset,
                NewDataset.id == dataset_column_association_table.c.dataset_id,
            )
            .join(
                dataset_table_association_table,
                # join tables with physical datasets
                and_(
                    NewDataset.is_physical,
                    dataset_table_association_table.c.dataset_id == NewDataset.id,
                ),
                isouter=True,
            )
            .join(Database, Database.id == NewDataset.database_id)
            .join(
                TableColumn,
                TableColumn.uuid == NewColumn.uuid,
                isouter=True,
            )
            .join(
                SqlMetric,
                SqlMetric.uuid == NewColumn.uuid,
                isouter=True,
            )
        )

    offset = 0
    limit = 100000

    print(f">> Run postprocessing on {total:,} columns")

    update_count = 0

    def print_update_count():
        if SHOW_PROGRESS:
            print(
                f"   Will update {update_count} columns" + " " * 20,
                end="\r",
            )

    while offset < total:
        query = (
            select(
                # sorted alphabetically
                [
                    NewColumn.id.label("column_id"),
                    TableColumn.column_name,
                    NewColumn.changed_by_fk,
                    NewColumn.changed_on,
                    NewColumn.created_on,
                    NewColumn.description,
                    SqlMetric.d3format,
                    NewDataset.external_url,
                    NewColumn.extra_json,
                    NewColumn.is_dimensional,
                    NewColumn.is_filterable,
                    NewDataset.is_managed_externally,
                    NewColumn.is_physical,
                    SqlMetric.metric_type,
                    TableColumn.python_date_format,
                    Database.sqlalchemy_uri,
                    dataset_table_association_table.c.table_id,
                    func.coalesce(
                        TableColumn.verbose_name, SqlMetric.verbose_name
                    ).label("verbose_name"),
                    NewColumn.warning_text,
                ]
            )
            .select_from(get_joined_tables(offset, limit))
            .where(
                # pre-filter to columns with potential updates
                or_(
                    NewColumn.is_physical,
                    TableColumn.verbose_name.isnot(None),
                    SqlMetric.verbose_name.isnot(None),
                    SqlMetric.d3format.isnot(None),
                    SqlMetric.metric_type.isnot(None),
                )
            )
        )

        start = offset + 1
        end = min(total, offset + limit)
        count = session.query(func.count()).select_from(query).scalar()
        print(f"   [Column {start:,} to {end:,}] {count:,} may be updated")

        physical_columns = []

        for (
            # sorted alphabetically
            column_id,
            column_name,
            changed_by_fk,
            changed_on,
            created_on,
            description,
            d3format,
            external_url,
            extra_json,
            is_dimensional,
            is_filterable,
            is_managed_externally,
            is_physical,
            metric_type,
            python_date_format,
            sqlalchemy_uri,
            table_id,
            verbose_name,
            warning_text,
        ) in session.execute(query):
            try:
                extra = json.loads(extra_json) if extra_json else {}
            except json.decoder.JSONDecodeError:
                extra = {}
            updated_extra = {**extra}
            updates = {}

            if is_managed_externally:
                updates["is_managed_externally"] = True
            if external_url:
                updates["external_url"] = external_url

            # update extra json
            for (key, val) in (
                {
                    "verbose_name": verbose_name,
                    "python_date_format": python_date_format,
                    "d3format": d3format,
                    "metric_type": metric_type,
                }
            ).items():
                # save the original val, including if it's `false`
                if val is not None:
                    updated_extra[key] = val

            if updated_extra != extra:
                updates["extra_json"] = json.dumps(updated_extra)

            # update expression for physical table columns
            if is_physical:
                if column_name and sqlalchemy_uri:
                    drivername = sqlalchemy_uri.split("://")[0]
                    if drivername:
                        quoted_expression = get_identifier_quoter(drivername)(
                            column_name
                        )
                        if quoted_expression != column_name:
                            updates["expression"] = quoted_expression
                # duplicate physical columns for tables
                physical_columns.append(
                    dict(
                        created_on=created_on,
                        changed_on=changed_on,
                        changed_by_fk=changed_by_fk,
                        description=description,
                        expression=updates.get("expression", column_name),
                        external_url=external_url,
                        extra_json=updates.get("extra_json", extra_json),
                        is_aggregation=False,
                        is_dimensional=is_dimensional,
                        is_filterable=is_filterable,
                        is_managed_externally=is_managed_externally,
                        is_physical=True,
                        name=column_name,
                        table_id=table_id,
                        warning_text=warning_text,
                    )
                )

            if updates:
                session.execute(
                    sa.update(NewColumn)
                    .where(NewColumn.id == column_id)
                    .values(**updates)
                )
                update_count += 1
                print_update_count()

        if physical_columns:
            op.bulk_insert(NewColumn.__table__, physical_columns)

        session.flush()
        offset += limit

    if SHOW_PROGRESS:
        print("")

    print("   Assign table column relations...")
    insert_from_select(
        table_column_association_table,
        select([NewColumn.table_id, NewColumn.id.label("column_id")])
        .select_from(NewColumn)
        .where(and_(NewColumn.is_physical, NewColumn.table_id.isnot(None))),
    )
new_tables: List[sa.Table] = [
    NewTable.__table__,
    NewDataset.__table__,
    NewColumn.__table__,
    table_column_association_table,
    dataset_column_association_table,
    dataset_table_association_table,
    dataset_user_association_table,
]


def reset_postgres_id_sequence(table: str) -> None:
    op.execute(
        f"""
        SELECT setval(
            pg_get_serial_sequence('{table}', 'id'),
            COALESCE(max(id) + 1, 1),
            false
        )
        FROM {table};
    """
    )
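# Why the sequence reset helper exists, as a sketch: INSERT ... SELECT copies
# rows with explicit ids, so on Postgres the serial sequence can lag behind
# max(id), and the next autogenerated id would collide. A plausible call site,
# right after the copies, could look like this:
def _reset_sequences_sketch() -> None:
    from sqlalchemy.dialects.postgresql.base import PGDialect

    if isinstance(op.get_bind().dialect, PGDialect):
        for table in ("sl_tables", "sl_datasets", "sl_columns"):
            reset_postgres_id_sequence(table)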
def upgrade() -> None:
    bind = op.get_bind()
    session: Session = db.Session(bind=bind)
    Base.metadata.drop_all(bind=bind, tables=new_tables)
    Base.metadata.create_all(bind=bind, tables=new_tables)

    copy_tables(session)
    copy_datasets(session)
    copy_columns(session)
    copy_metrics(session)
    session.commit()

    postprocess_columns(session)
    session.commit()

    postprocess_datasets(session)
    session.commit()

    # Tables were created with the same uuids as datasets. They should
    # have different uuids, as they are different entities.
    print(">> Assign new UUIDs to tables...")
    assign_uuids(NewTable, session)

    print(">> Drop intermediate columns...")
    # These columns were only used during the migration: once created, datasets
    # are independent of tables, and dataset columns are likewise independent
    # of table columns.
    with op.batch_alter_table(NewTable.__tablename__) as batch_op:
        batch_op.drop_column("sqlatable_id")
    with op.batch_alter_table(NewColumn.__tablename__) as batch_op:
        batch_op.drop_column("table_id")


def downgrade():
    Base.metadata.drop_all(bind=op.get_bind(), tables=new_tables)
@@ -23,19 +23,17 @@ Create Date: 2020-09-28 17:57:23.128142

"""
import json
import os
import time
from json.decoder import JSONDecodeError
from uuid import uuid4

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.mysql.base import MySQLDialect
from sqlalchemy.dialects.postgresql.base import PGDialect
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import load_only
from sqlalchemy_utils import UUIDType

from superset import db
from superset.migrations.shared.utils import assign_uuids
from superset.utils import core as utils

# revision identifiers, used by Alembic.

@@ -78,47 +76,6 @@ models["dashboards"].position_json = sa.Column(utils.MediumText())

default_batch_size = int(os.environ.get("BATCH_SIZE", 200))

# Add uuids directly using the built-in SQL uuid function
add_uuids_by_dialect = {
    MySQLDialect: """UPDATE %s SET uuid = UNHEX(REPLACE(CONVERT(UUID() using utf8mb4), '-', ''));""",
    PGDialect: """UPDATE %s SET uuid = uuid_in(md5(random()::text || clock_timestamp()::text)::cstring);""",
}


def add_uuids(model, table_name, session, batch_size=default_batch_size):
    """Populate columns with pre-computed uuids"""
    bind = op.get_bind()
    objects_query = session.query(model)
    count = objects_query.count()

    # silently skip if the table is empty (suitable for db initialization)
    if count == 0:
        return

    print(f"\nAdding uuids for `{table_name}`...")
    start_time = time.time()

    # Use dialect-specific native SQL queries if possible
    for dialect, sql in add_uuids_by_dialect.items():
        if isinstance(bind.dialect, dialect):
            op.execute(sql % table_name)
            print(f"Done. Assigned {count} uuids in {time.time() - start_time:.3f}s.")
            return

    # Otherwise use the Python uuid function
    start = 0
    while start < count:
        end = min(start + batch_size, count)
        for obj, uuid in map(lambda obj: (obj, uuid4()), objects_query[start:end]):
            obj.uuid = uuid
            session.merge(obj)
        session.commit()
        if start + batch_size < count:
            print(f"  uuid assigned to {end} out of {count}\r", end="")
        start += batch_size

    print(f"Done. Assigned {count} uuids in {time.time() - start_time:.3f}s.")


def update_position_json(dashboard, session, uuid_map):
    try:
@@ -178,7 +135,7 @@ def upgrade():
        ),
    )

    add_uuids(model, table_name, session)
    assign_uuids(model, session)

    # add uniqueness constraint
    with op.batch_alter_table(table_name) as batch_op:

@@ -203,7 +160,7 @@ def downgrade():
    update_dashboards(session, {})

    # remove uuid column
    for table_name, model in models.items():
    for table_name in models:
        with op.batch_alter_table(table_name) as batch_op:
            batch_op.drop_constraint(f"uq_{table_name}_uuid", type_="unique")
            batch_op.drop_column("uuid")
@@ -23,619 +23,23 @@ Revises: 5afbb1a5849b
Create Date: 2021-11-11 16:41:53.266965

"""

import json
from datetime import date, datetime, time, timedelta
from typing import Callable, List, Optional, Set
from uuid import uuid4

import sqlalchemy as sa
from alembic import op
from sqlalchemy import and_, inspect, or_
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import backref, relationship, Session
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy_utils import UUIDType

from superset import app, db
from superset.connectors.sqla.models import ADDITIVE_METRIC_TYPES
from superset.databases.utils import make_url_safe
from superset.extensions import encrypted_field_factory
from superset.migrations.shared.utils import extract_table_references
from superset.models.core import Database as OriginalDatabase
from superset.sql_parse import Table

# revision identifiers, used by Alembic.
revision = "b8d3a24d9131"
down_revision = "5afbb1a5849b"

Base = declarative_base()
custom_password_store = app.config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
DB_CONNECTION_MUTATOR = app.config["DB_CONNECTION_MUTATOR"]

# ===================== Notice ========================
#
# The migrations made in this revision have been moved to
# `new_dataset_models_take_2` to fix performance issues as well as a couple
# of shortcomings in the original design.
#
# ======================================================
class Database(Base):

    __tablename__ = "dbs"
    __table_args__ = (UniqueConstraint("database_name"),)

    id = sa.Column(sa.Integer, primary_key=True)
    database_name = sa.Column(sa.String(250), unique=True, nullable=False)
    sqlalchemy_uri = sa.Column(sa.String(1024), nullable=False)
    password = sa.Column(encrypted_field_factory.create(sa.String(1024)))
    impersonate_user = sa.Column(sa.Boolean, default=False)
    encrypted_extra = sa.Column(encrypted_field_factory.create(sa.Text), nullable=True)
    extra = sa.Column(
        sa.Text,
        default=json.dumps(
            dict(
                metadata_params={},
                engine_params={},
                metadata_cache_timeout={},
                schemas_allowed_for_file_upload=[],
            )
        ),
    )
    server_cert = sa.Column(encrypted_field_factory.create(sa.Text), nullable=True)


class TableColumn(Base):

    __tablename__ = "table_columns"
    __table_args__ = (UniqueConstraint("table_id", "column_name"),)

    id = sa.Column(sa.Integer, primary_key=True)
    table_id = sa.Column(sa.Integer, sa.ForeignKey("tables.id"))
    is_active = sa.Column(sa.Boolean, default=True)
    extra = sa.Column(sa.Text)
    column_name = sa.Column(sa.String(255), nullable=False)
    type = sa.Column(sa.String(32))
    expression = sa.Column(sa.Text)
    description = sa.Column(sa.Text)
    is_dttm = sa.Column(sa.Boolean, default=False)
    filterable = sa.Column(sa.Boolean, default=True)
    groupby = sa.Column(sa.Boolean, default=True)
    verbose_name = sa.Column(sa.String(1024))
    python_date_format = sa.Column(sa.String(255))


class SqlMetric(Base):

    __tablename__ = "sql_metrics"
    __table_args__ = (UniqueConstraint("table_id", "metric_name"),)

    id = sa.Column(sa.Integer, primary_key=True)
    table_id = sa.Column(sa.Integer, sa.ForeignKey("tables.id"))
    extra = sa.Column(sa.Text)
    metric_type = sa.Column(sa.String(32))
    metric_name = sa.Column(sa.String(255), nullable=False)
    expression = sa.Column(sa.Text, nullable=False)
    warning_text = sa.Column(sa.Text)
    description = sa.Column(sa.Text)
    d3format = sa.Column(sa.String(128))
    verbose_name = sa.Column(sa.String(1024))


class SqlaTable(Base):

    __tablename__ = "tables"
    __table_args__ = (UniqueConstraint("database_id", "schema", "table_name"),)

    def fetch_columns_and_metrics(self, session: Session) -> None:
        self.columns = session.query(TableColumn).filter(
            TableColumn.table_id == self.id
        )
        self.metrics = session.query(SqlMetric).filter(TableColumn.table_id == self.id)

    id = sa.Column(sa.Integer, primary_key=True)
    columns: List[TableColumn] = []
    column_class = TableColumn
    metrics: List[SqlMetric] = []
    metric_class = SqlMetric

    database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False)
    database: Database = relationship(
        "Database",
        backref=backref("tables", cascade="all, delete-orphan"),
        foreign_keys=[database_id],
    )
    schema = sa.Column(sa.String(255))
    table_name = sa.Column(sa.String(250), nullable=False)
    sql = sa.Column(sa.Text)
    is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)
    external_url = sa.Column(sa.Text, nullable=True)


table_column_association_table = sa.Table(
    "sl_table_columns",
    Base.metadata,
    sa.Column("table_id", sa.ForeignKey("sl_tables.id")),
    sa.Column("column_id", sa.ForeignKey("sl_columns.id")),
)

dataset_column_association_table = sa.Table(
    "sl_dataset_columns",
    Base.metadata,
    sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id")),
    sa.Column("column_id", sa.ForeignKey("sl_columns.id")),
)

dataset_table_association_table = sa.Table(
    "sl_dataset_tables",
    Base.metadata,
    sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id")),
    sa.Column("table_id", sa.ForeignKey("sl_tables.id")),
)
class NewColumn(Base):

    __tablename__ = "sl_columns"

    id = sa.Column(sa.Integer, primary_key=True)
    name = sa.Column(sa.Text)
    type = sa.Column(sa.Text)
    expression = sa.Column(sa.Text)
    is_physical = sa.Column(sa.Boolean, default=True)
    description = sa.Column(sa.Text)
    warning_text = sa.Column(sa.Text)
    is_temporal = sa.Column(sa.Boolean, default=False)
    is_aggregation = sa.Column(sa.Boolean, default=False)
    is_additive = sa.Column(sa.Boolean, default=False)
    is_spatial = sa.Column(sa.Boolean, default=False)
    is_partition = sa.Column(sa.Boolean, default=False)
    is_increase_desired = sa.Column(sa.Boolean, default=True)
    is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)
    external_url = sa.Column(sa.Text, nullable=True)
    extra_json = sa.Column(sa.Text, default="{}")


class NewTable(Base):

    __tablename__ = "sl_tables"
    __table_args__ = (UniqueConstraint("database_id", "catalog", "schema", "name"),)

    id = sa.Column(sa.Integer, primary_key=True)
    name = sa.Column(sa.Text)
    schema = sa.Column(sa.Text)
    catalog = sa.Column(sa.Text)
    database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False)
    database: Database = relationship(
        "Database",
        backref=backref("new_tables", cascade="all, delete-orphan"),
        foreign_keys=[database_id],
    )
    columns: List[NewColumn] = relationship(
        "NewColumn", secondary=table_column_association_table, cascade="all, delete"
    )
    is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)
    external_url = sa.Column(sa.Text, nullable=True)


class NewDataset(Base):

    __tablename__ = "sl_datasets"

    id = sa.Column(sa.Integer, primary_key=True)
    sqlatable_id = sa.Column(sa.Integer, nullable=True, unique=True)
    name = sa.Column(sa.Text)
    expression = sa.Column(sa.Text)
    tables: List[NewTable] = relationship(
        "NewTable", secondary=dataset_table_association_table
    )
    columns: List[NewColumn] = relationship(
        "NewColumn", secondary=dataset_column_association_table, cascade="all, delete"
    )
    is_physical = sa.Column(sa.Boolean, default=False)
    is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)
    external_url = sa.Column(sa.Text, nullable=True)


TEMPORAL_TYPES = {date, datetime, time, timedelta}


def is_column_type_temporal(column_type: TypeEngine) -> bool:
    try:
        return column_type.python_type in TEMPORAL_TYPES
    except NotImplementedError:
        return False
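# Usage sketch: ``python_type`` maps sa.DateTime to datetime and sa.Integer to
# int, so only time-like columns are flagged as temporal by the helper above.
def _is_column_type_temporal_examples() -> None:
    assert is_column_type_temporal(sa.DateTime()) is True
    assert is_column_type_temporal(sa.Integer()) is False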
def load_or_create_tables(
    session: Session,
    database_id: int,
    default_schema: Optional[str],
    tables: Set[Table],
    conditional_quote: Callable[[str], str],
) -> List[NewTable]:
    """
    Load or create new table model instances.
    """
    if not tables:
        return []

    # set the default schema in tables that don't have it
    if default_schema:
        tables = list(tables)
        for i, table in enumerate(tables):
            if table.schema is None:
                tables[i] = Table(table.table, default_schema, table.catalog)

    # load existing tables
    predicate = or_(
        *[
            and_(
                NewTable.database_id == database_id,
                NewTable.schema == table.schema,
                NewTable.name == table.table,
            )
            for table in tables
        ]
    )
    new_tables = session.query(NewTable).filter(predicate).all()

    # use the original database model to get the engine
    engine = (
        session.query(OriginalDatabase)
        .filter_by(id=database_id)
        .one()
        .get_sqla_engine(default_schema)
    )
    inspector = inspect(engine)

    # add missing tables
    existing = {(table.schema, table.name) for table in new_tables}
    for table in tables:
        if (table.schema, table.table) not in existing:
            column_metadata = inspector.get_columns(table.table, schema=table.schema)
            columns = [
                NewColumn(
                    name=column["name"],
                    type=str(column["type"]),
                    expression=conditional_quote(column["name"]),
                    is_temporal=is_column_type_temporal(column["type"]),
                    is_aggregation=False,
                    is_physical=True,
                    is_spatial=False,
                    is_partition=False,
                    is_increase_desired=True,
                )
                for column in column_metadata
            ]
            new_tables.append(
                NewTable(
                    name=table.table,
                    schema=table.schema,
                    catalog=None,
                    database_id=database_id,
                    columns=columns,
                )
            )
            existing.add((table.schema, table.table))

    return new_tables
def after_insert(target: SqlaTable) -> None:  # pylint: disable=too-many-locals
    """
    Copy old datasets to the new models.
    """
    session = inspect(target).session

    # get a DB-specific conditional quoter for expressions that point to
    # columns or table names
    database = (
        target.database
        or session.query(Database).filter_by(id=target.database_id).first()
    )
    if not database:
        return
    url = make_url_safe(database.sqlalchemy_uri)
    dialect_class = url.get_dialect()
    conditional_quote = dialect_class().identifier_preparer.quote

    # create columns
    columns = []
    for column in target.columns:
        # ``is_active`` might be ``None`` at this point, but it defaults to ``True``.
        if column.is_active is False:
            continue

        try:
            extra_json = json.loads(column.extra or "{}")
        except json.decoder.JSONDecodeError:
            extra_json = {}
        for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}:
            value = getattr(column, attr)
            if value:
                extra_json[attr] = value

        columns.append(
            NewColumn(
                name=column.column_name,
                type=column.type or "Unknown",
                expression=column.expression or conditional_quote(column.column_name),
                description=column.description,
                is_temporal=column.is_dttm,
                is_aggregation=False,
                is_physical=column.expression is None or column.expression == "",
                is_spatial=False,
                is_partition=False,
                is_increase_desired=True,
                extra_json=json.dumps(extra_json) if extra_json else None,
                is_managed_externally=target.is_managed_externally,
                external_url=target.external_url,
            ),
        )

    # create metrics
    for metric in target.metrics:
        try:
            extra_json = json.loads(metric.extra or "{}")
        except json.decoder.JSONDecodeError:
            extra_json = {}
        for attr in {"verbose_name", "metric_type", "d3format"}:
            value = getattr(metric, attr)
            if value:
                extra_json[attr] = value

        is_additive = (
            metric.metric_type and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES
        )

        columns.append(
            NewColumn(
                name=metric.metric_name,
                type="Unknown",  # figuring this out would require a type inferrer
                expression=metric.expression,
                warning_text=metric.warning_text,
                description=metric.description,
                is_aggregation=True,
                is_additive=is_additive,
                is_physical=False,
                is_spatial=False,
                is_partition=False,
                is_increase_desired=True,
                extra_json=json.dumps(extra_json) if extra_json else None,
                is_managed_externally=target.is_managed_externally,
                external_url=target.external_url,
            ),
        )

    # physical dataset
    if not target.sql:
        physical_columns = [column for column in columns if column.is_physical]

        # create the table
        table = NewTable(
            name=target.table_name,
            schema=target.schema,
            catalog=None,  # currently not supported
            database_id=target.database_id,
            columns=physical_columns,
            is_managed_externally=target.is_managed_externally,
            external_url=target.external_url,
        )
        tables = [table]

    # virtual dataset
    else:
        # mark all columns as virtual (not physical)
        for column in columns:
            column.is_physical = False

        # find referenced tables
        referenced_tables = extract_table_references(target.sql, dialect_class.name)
        tables = load_or_create_tables(
            session,
            target.database_id,
            target.schema,
            referenced_tables,
            conditional_quote,
        )

    # create the new dataset
    dataset = NewDataset(
        sqlatable_id=target.id,
        name=target.table_name,
        expression=target.sql or conditional_quote(target.table_name),
        tables=tables,
        columns=columns,
        is_physical=not target.sql,
        is_managed_externally=target.is_managed_externally,
        external_url=target.external_url,
    )
    session.add(dataset)
def upgrade():
|
||||
# Create tables for the new models.
|
||||
op.create_table(
|
||||
"sl_columns",
|
||||
# AuditMixinNullable
|
||||
sa.Column("created_on", sa.DateTime(), nullable=True),
|
||||
sa.Column("changed_on", sa.DateTime(), nullable=True),
|
||||
sa.Column("created_by_fk", sa.Integer(), nullable=True),
|
||||
sa.Column("changed_by_fk", sa.Integer(), nullable=True),
|
||||
# ExtraJSONMixin
|
||||
sa.Column("extra_json", sa.Text(), nullable=True),
|
||||
# ImportExportMixin
|
||||
sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4),
|
||||
# Column
|
||||
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
|
||||
sa.Column("name", sa.TEXT(), nullable=False),
|
||||
sa.Column("type", sa.TEXT(), nullable=False),
|
||||
sa.Column("expression", sa.TEXT(), nullable=False),
|
||||
sa.Column(
|
||||
"is_physical",
|
||||
sa.BOOLEAN(),
|
||||
nullable=False,
|
||||
default=True,
|
||||
),
|
||||
sa.Column("description", sa.TEXT(), nullable=True),
|
||||
sa.Column("warning_text", sa.TEXT(), nullable=True),
|
||||
sa.Column("unit", sa.TEXT(), nullable=True),
|
||||
sa.Column("is_temporal", sa.BOOLEAN(), nullable=False),
|
||||
sa.Column(
|
||||
"is_spatial",
|
||||
sa.BOOLEAN(),
|
||||
nullable=False,
|
||||
default=False,
|
||||
),
|
||||
sa.Column(
|
||||
"is_partition",
|
||||
sa.BOOLEAN(),
|
||||
nullable=False,
|
||||
default=False,
|
||||
),
|
||||
sa.Column(
|
||||
"is_aggregation",
|
||||
sa.BOOLEAN(),
|
||||
nullable=False,
|
||||
default=False,
|
||||
),
|
||||
sa.Column(
|
||||
"is_additive",
|
||||
sa.BOOLEAN(),
|
||||
nullable=False,
|
||||
default=False,
|
||||
),
|
||||
sa.Column(
|
||||
"is_increase_desired",
|
||||
sa.BOOLEAN(),
|
||||
nullable=False,
|
||||
default=True,
|
||||
),
|
||||
sa.Column(
|
||||
"is_managed_externally",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
sa.Column("external_url", sa.Text(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
with op.batch_alter_table("sl_columns") as batch_op:
|
||||
    batch_op.create_unique_constraint("uq_sl_columns_uuid", ["uuid"])

op.create_table(
    "sl_tables",
    # AuditMixinNullable
    sa.Column("created_on", sa.DateTime(), nullable=True),
    sa.Column("changed_on", sa.DateTime(), nullable=True),
    sa.Column("created_by_fk", sa.Integer(), nullable=True),
    sa.Column("changed_by_fk", sa.Integer(), nullable=True),
    # ExtraJSONMixin
    sa.Column("extra_json", sa.Text(), nullable=True),
    # ImportExportMixin
    sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4),
    # Table
    sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
    sa.Column("database_id", sa.INTEGER(), autoincrement=False, nullable=False),
    sa.Column("catalog", sa.TEXT(), nullable=True),
    sa.Column("schema", sa.TEXT(), nullable=True),
    sa.Column("name", sa.TEXT(), nullable=False),
    sa.Column(
        "is_managed_externally",
        sa.Boolean(),
        nullable=False,
        server_default=sa.false(),
    ),
    sa.Column("external_url", sa.Text(), nullable=True),
    sa.ForeignKeyConstraint(["database_id"], ["dbs.id"], name="sl_tables_ibfk_1"),
    sa.PrimaryKeyConstraint("id"),
)
with op.batch_alter_table("sl_tables") as batch_op:
    batch_op.create_unique_constraint("uq_sl_tables_uuid", ["uuid"])

op.create_table(
    "sl_table_columns",
    sa.Column("table_id", sa.INTEGER(), autoincrement=False, nullable=False),
    sa.Column("column_id", sa.INTEGER(), autoincrement=False, nullable=False),
    sa.ForeignKeyConstraint(
        ["column_id"], ["sl_columns.id"], name="sl_table_columns_ibfk_2"
    ),
    sa.ForeignKeyConstraint(
        ["table_id"], ["sl_tables.id"], name="sl_table_columns_ibfk_1"
    ),
)

op.create_table(
    "sl_datasets",
    # AuditMixinNullable
    sa.Column("created_on", sa.DateTime(), nullable=True),
    sa.Column("changed_on", sa.DateTime(), nullable=True),
    sa.Column("created_by_fk", sa.Integer(), nullable=True),
    sa.Column("changed_by_fk", sa.Integer(), nullable=True),
    # ExtraJSONMixin
    sa.Column("extra_json", sa.Text(), nullable=True),
    # ImportExportMixin
    sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4),
    # Dataset
    sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
    sa.Column("sqlatable_id", sa.INTEGER(), nullable=True),
    sa.Column("name", sa.TEXT(), nullable=False),
    sa.Column("expression", sa.TEXT(), nullable=False),
    sa.Column(
        "is_physical",
        sa.BOOLEAN(),
        nullable=False,
        default=False,
    ),
    sa.Column(
        "is_managed_externally",
        sa.Boolean(),
        nullable=False,
        server_default=sa.false(),
    ),
    sa.Column("external_url", sa.Text(), nullable=True),
    sa.PrimaryKeyConstraint("id"),
)
with op.batch_alter_table("sl_datasets") as batch_op:
    batch_op.create_unique_constraint("uq_sl_datasets_uuid", ["uuid"])
    batch_op.create_unique_constraint(
        "uq_sl_datasets_sqlatable_id", ["sqlatable_id"]
    )

op.create_table(
    "sl_dataset_columns",
    sa.Column("dataset_id", sa.INTEGER(), autoincrement=False, nullable=False),
    sa.Column("column_id", sa.INTEGER(), autoincrement=False, nullable=False),
    sa.ForeignKeyConstraint(
        ["column_id"], ["sl_columns.id"], name="sl_dataset_columns_ibfk_2"
    ),
    sa.ForeignKeyConstraint(
        ["dataset_id"], ["sl_datasets.id"], name="sl_dataset_columns_ibfk_1"
    ),
)

op.create_table(
    "sl_dataset_tables",
    sa.Column("dataset_id", sa.INTEGER(), autoincrement=False, nullable=False),
    sa.Column("table_id", sa.INTEGER(), autoincrement=False, nullable=False),
    sa.ForeignKeyConstraint(
        ["dataset_id"], ["sl_datasets.id"], name="sl_dataset_tables_ibfk_1"
    ),
    sa.ForeignKeyConstraint(
        ["table_id"], ["sl_tables.id"], name="sl_dataset_tables_ibfk_2"
    ),
)

# migrate existing datasets to the new models
bind = op.get_bind()
session = db.Session(bind=bind)  # pylint: disable=no-member

datasets = session.query(SqlaTable).all()
for dataset in datasets:
    dataset.fetch_columns_and_metrics(session)
    after_insert(target=dataset)


def upgrade() -> None:
    pass


def downgrade():
    op.drop_table("sl_dataset_columns")
    op.drop_table("sl_dataset_tables")
    op.drop_table("sl_datasets")
    op.drop_table("sl_table_columns")
    op.drop_table("sl_tables")
    op.drop_table("sl_columns")
    pass
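The commit title's optimization replaces per-row ORM loops like the one above with set-based INSERT ... SELECT statements that run entirely inside the database. Below is a minimal sketch of that pattern with SQLAlchemy Core; the table and column names are illustrative stand-ins, not the exact ones used by the real migration.

# Sketch: copy rows into a new table with one INSERT ... SELECT statement
# instead of materializing every ORM object in Python. All names here are
# hypothetical stand-ins for the real Superset tables.
import sqlalchemy as sa

metadata = sa.MetaData()
src = sa.Table(
    "old_tables",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("database_id", sa.Integer),
    sa.Column("schema", sa.Text),
    sa.Column("table_name", sa.Text),
)
dst = sa.Table(
    "sl_tables",
    metadata,
    sa.Column("database_id", sa.Integer),
    sa.Column("schema", sa.Text),
    sa.Column("name", sa.Text),
)

# One round trip: the database copies the rows itself.
insert = dst.insert().from_select(
    ["database_id", "schema", "name"],
    sa.select([src.c.database_id, src.c.schema, src.c.table_name]),
)
# Inside an Alembic migration this would run as: op.get_bind().execute(insert)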
@@ -38,7 +38,7 @@ from sqlalchemy_utils import UUIDType

from superset import db
from superset.migrations.versions.b56500de1855_add_uuid_column_to_import_mixin import (
    add_uuids,
    assign_uuids,
    models,
    update_dashboards,
)
@@ -73,7 +73,7 @@ def upgrade():
            default=uuid4,
        ),
    )
    add_uuids(model, table_name, session)
    assign_uuids(model, session)

    # add uniqueness constraint
    with op.batch_alter_table(table_name) as batch_op:
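For context, a backfill helper like assign_uuids generally walks every row of a model and writes a fresh UUID. This is an assumed sketch of the idea, not the function's actual body; the batch size and flush cadence are illustrative.

# Assumed sketch of a uuid backfill (not the real assign_uuids body):
# iterate rows in batches and give each one a fresh uuid.
from uuid import uuid4

def assign_uuids_sketch(model, session, batch_size=512):
    for i, obj in enumerate(session.query(model).yield_per(batch_size)):
        obj.uuid = uuid4()
        if (i + 1) % batch_size == 0:
            session.flush()  # keep the pending unit of work small
    session.commit()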
@@ -71,7 +71,7 @@ def downgrade_filters(native_filters: Iterable[Dict[str, Any]]) -> int:
        filter_state = default_data_mask.get("filterState")
        if filter_state is not None:
            changed_filters += 1
            value = filter_state["value"]
            value = filter_state.get("value")
            native_filter["defaultValue"] = value
    return changed_filters
@@ -408,12 +408,14 @@ class Database(
        except Exception as ex:
            raise self.db_engine_spec.get_dbapi_mapped_exception(ex)

    @property
    def quote_identifier(self) -> Callable[[str], str]:
        """Add quotes to potential identifier expressions if needed"""
        return self.get_dialect().identifier_preparer.quote

    def get_reserved_words(self) -> Set[str]:
        return self.get_dialect().preparer.reserved_words

    def get_quoter(self) -> Callable[[str, Any], str]:
        return self.get_dialect().identifier_preparer.quote

    def get_df(  # pylint: disable=too-many-locals
        self,
        sql: str,
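To illustrate what the new quote_identifier property delegates to: SQLAlchemy's IdentifierPreparer.quote only adds quotes when the name actually requires them. A small standalone example, with the dialect chosen purely for illustration:

from sqlalchemy.dialects import postgresql

quote = postgresql.dialect().identifier_preparer.quote
quote("revenue")    # 'revenue'      -- plain lowercase name, left alone
quote("select")     # '"select"'     -- reserved word, quoted
quote("MixedCase")  # '"MixedCase"'  -- case-sensitive name, quoted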
@@ -477,7 +477,7 @@ class ExtraJSONMixin:
    @property
    def extra(self) -> Dict[str, Any]:
        try:
            return json.loads(self.extra_json)
            return json.loads(self.extra_json) if self.extra_json else {}
        except (TypeError, JSONDecodeError) as exc:
            logger.error(
                "Unable to load an extra json: %r. Leaving empty.", exc, exc_info=True
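The one-line change above short-circuits before parsing; a quick illustration of why the guard matters:

import json

extra_json = None
# json.loads(None) raises TypeError, which previously had to be caught
# and logged; falling back to {} avoids the exception path entirely.
extra = json.loads(extra_json) if extra_json else {}
assert extra == {}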
@@ -522,18 +522,23 @@ class CertificationMixin:


def clone_model(
    target: Model, ignore: Optional[List[str]] = None, **kwargs: Any
    target: Model,
    ignore: Optional[List[str]] = None,
    keep_relations: Optional[List[str]] = None,
    **kwargs: Any,
) -> Model:
    """
    Clone a SQLAlchemy model.
    Clone a SQLAlchemy model. By default it will only clone naive column attributes.
    To include relationship attributes, use `keep_relations`.
    """
    ignore = ignore or []

    table = target.__table__
    primary_keys = table.primary_key.columns.keys()
    data = {
        attr: getattr(target, attr)
        for attr in table.columns.keys()
        if attr not in table.primary_key.columns.keys() and attr not in ignore
        for attr in list(table.columns.keys()) + (keep_relations or [])
        if attr not in primary_keys and attr not in ignore
    }
    data.update(kwargs)
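A hypothetical call site for the extended signature; the model and attribute names below are illustrative, not taken from the codebase:

# Clone a chart, copying the owners relationship as well, and override
# one column attribute via kwargs. Names are hypothetical.
new_chart = clone_model(
    existing_chart,
    ignore=["uuid"],               # skip the unique identifier
    keep_relations=["owners"],     # also copy this relationship attribute
    slice_name="Copy of revenue",  # kwargs override column attributes
)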
@@ -186,7 +186,7 @@ def execute_sql_statement(  # pylint: disable=too-many-arguments,too-many-locals
    apply_ctas: bool = False,
) -> SupersetResultSet:
    """Executes a single SQL statement"""
    database = query.database
    database: Database = query.database
    db_engine_spec = database.db_engine_spec
    parsed_query = ParsedQuery(sql_statement)
    sql = parsed_query.stripped()
@@ -18,7 +18,7 @@ import logging
import re
from dataclasses import dataclass
from enum import Enum
from typing import cast, List, Optional, Set, Tuple
from typing import Any, cast, Iterator, List, Optional, Set, Tuple
from urllib import parse

import sqlparse

@@ -47,10 +47,16 @@ from sqlparse.utils import imt

from superset.exceptions import QueryClauseValidationException

try:
    from sqloxide import parse_sql as sqloxide_parse
except:  # pylint: disable=bare-except
    sqloxide_parse = None

RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
ON_KEYWORD = "ON"
PRECEDES_TABLE_NAME = {"FROM", "JOIN", "DESCRIBE", "WITH", "LEFT JOIN", "RIGHT JOIN"}
CTE_PREFIX = "CTE__"

logger = logging.getLogger(__name__)
@@ -176,6 +182,9 @@ class Table:
            if part
        )

    def __eq__(self, __o: object) -> bool:
        return str(self) == str(__o)


class ParsedQuery:
    def __init__(self, sql_statement: str, strip_comments: bool = False):
@@ -698,3 +707,75 @@ def insert_rls(
    )

    return token_list


# mapping between sqloxide and SQLAlchemy dialects
SQLOXITE_DIALECTS = {
    "ansi": {"trino", "trinonative", "presto"},
    "hive": {"hive", "databricks"},
    "ms": {"mssql"},
    "mysql": {"mysql"},
    "postgres": {
        "cockroachdb",
        "hana",
        "netezza",
        "postgres",
        "postgresql",
        "redshift",
        "vertica",
    },
    "snowflake": {"snowflake"},
    "sqlite": {"sqlite", "gsheets", "shillelagh"},
    "clickhouse": {"clickhouse"},
}

RE_JINJA_VAR = re.compile(r"\{\{[^\{\}]+\}\}")
RE_JINJA_BLOCK = re.compile(r"\{[%#][^\{\}%#]+[%#]\}")


def extract_table_references(
    sql_text: str, sqla_dialect: str, show_warning: bool = True
) -> Set["Table"]:
    """
    Return all the table references from a SQL statement.
    """
    dialect = "generic"
    tree = None

    if sqloxide_parse:
        for dialect, sqla_dialects in SQLOXITE_DIALECTS.items():
            if sqla_dialect in sqla_dialects:
                break
        sql_text = RE_JINJA_BLOCK.sub(" ", sql_text)
        sql_text = RE_JINJA_VAR.sub("abc", sql_text)
        try:
            tree = sqloxide_parse(sql_text, dialect=dialect)
        except Exception as ex:  # pylint: disable=broad-except
            if show_warning:
                logger.warning(
                    "\nUnable to parse query with sqloxide:\n%s\n%s", sql_text, ex
                )

    # fallback to sqlparse
    if not tree:
        parsed = ParsedQuery(sql_text)
        return parsed.tables

    def find_nodes_by_key(element: Any, target: str) -> Iterator[Any]:
        """
        Find all nodes in a SQL tree matching a given key.
        """
        if isinstance(element, list):
            for child in element:
                yield from find_nodes_by_key(child, target)
        elif isinstance(element, dict):
            for key, value in element.items():
                if key == target:
                    yield value
                else:
                    yield from find_nodes_by_key(value, target)

    return {
        Table(*[part["value"] for part in table["name"][::-1]])
        for table in find_nodes_by_key(tree, "Table")
    }
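Usage, as exercised by the tests further down in this diff: the function maps the SQLAlchemy dialect name to a sqloxide dialect, scrubs Jinja templating, and falls back to sqlparse when sqloxide cannot parse the statement.

from superset.sql_parse import extract_table_references, Table

assert extract_table_references(
    "SELECT 1 FROM some_catalog.some_schema.some_table", "trino"
) == {Table(table="some_table", schema="some_schema", catalog="some_catalog")}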
@@ -24,26 +24,41 @@ addition to a table, new models for columns, metrics, and datasets were also int
These models are not fully implemented, and shouldn't be used yet.
"""

from typing import List
from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING

import sqlalchemy as sa
from flask_appbuilder import Model
from sqlalchemy.orm import backref, relationship
from sqlalchemy import inspect
from sqlalchemy.orm import backref, relationship, Session
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import and_, or_

from superset.columns.models import Column
from superset.connectors.sqla.utils import get_physical_table_metadata
from superset.models.core import Database
from superset.models.helpers import (
    AuditMixinNullable,
    ExtraJSONMixin,
    ImportExportMixin,
)
from superset.sql_parse import Table as TableName

association_table = sa.Table(
if TYPE_CHECKING:
    from superset.datasets.models import Dataset

table_column_association_table = sa.Table(
    "sl_table_columns",
    Model.metadata,  # pylint: disable=no-member
    sa.Column("table_id", sa.ForeignKey("sl_tables.id")),
    sa.Column("column_id", sa.ForeignKey("sl_columns.id")),
    sa.Column(
        "table_id",
        sa.ForeignKey("sl_tables.id", ondelete="cascade"),
        primary_key=True,
    ),
    sa.Column(
        "column_id",
        sa.ForeignKey("sl_columns.id", ondelete="cascade"),
        primary_key=True,
    ),
)
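The association table now carries a composite primary key and ON DELETE CASCADE on both foreign keys, so the database removes link rows together with either parent. A minimal standalone demo of that behaviour, with illustrative table names on SQLite:

import sqlalchemy as sa

metadata = sa.MetaData()
parent = sa.Table("parent", metadata, sa.Column("id", sa.Integer, primary_key=True))
link = sa.Table(
    "link",
    metadata,
    sa.Column(
        "parent_id",
        sa.ForeignKey("parent.id", ondelete="cascade"),
        primary_key=True,  # part of the composite primary key
    ),
    sa.Column("other_id", sa.Integer, primary_key=True),
)

engine = sa.create_engine("sqlite://")
metadata.create_all(engine)
with engine.connect() as conn:
    conn.execute(sa.text("PRAGMA foreign_keys=ON"))  # SQLite requires opt-in
    conn.execute(parent.insert().values(id=1))
    conn.execute(link.insert().values(parent_id=1, other_id=1))
    conn.execute(parent.delete())  # cascades to the link row
    assert conn.execute(sa.select([sa.func.count()]).select_from(link)).scalar() == 0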
@@ -61,7 +76,6 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
    __table_args__ = (UniqueConstraint("database_id", "catalog", "schema", "name"),)

    id = sa.Column(sa.Integer, primary_key=True)

    database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False)
    database: Database = relationship(
        "Database",

@@ -70,6 +84,19 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
        backref=backref("new_tables", cascade="all, delete-orphan"),
        foreign_keys=[database_id],
    )
    # The relationship between datasets and columns is 1:n, but we use a
    # many-to-many association table to avoid adding two mutually exclusive
    # columns (dataset_id and table_id) to Column
    columns: List[Column] = relationship(
        "Column",
        secondary=table_column_association_table,
        cascade="all, delete-orphan",
        single_parent=True,
        # backref is needed for session to skip detaching `dataset` if only `column`
        # is loaded.
        backref="tables",
    )
    datasets: List["Dataset"]  # will be populated by Dataset.tables backref

    # We use ``sa.Text`` for these attributes because (1) in modern databases the
    # performance is the same as ``VARCHAR``[1] and (2) because some table names can be
@@ -80,13 +107,96 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
    schema = sa.Column(sa.Text)
    name = sa.Column(sa.Text)

    # The relationship between tables and columns is 1:n, but we use a many-to-many
    # association to differentiate between the relationship between datasets and
    # columns.
    columns: List[Column] = relationship(
        "Column", secondary=association_table, cascade="all, delete"
    )

    # Table is managed externally and should be read-only inside Superset
    is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False)
    external_url = sa.Column(sa.Text, nullable=True)

    @property
    def fullname(self) -> str:
        return str(TableName(table=self.name, schema=self.schema, catalog=self.catalog))

    def __repr__(self) -> str:
        return f"<Table id={self.id} database_id={self.database_id} {self.fullname}>"

    def sync_columns(self) -> None:
        """Sync table columns with the database. Keep metadata for existing columns"""
        try:
            column_metadata = get_physical_table_metadata(
                self.database, self.name, self.schema
            )
        except Exception:  # pylint: disable=broad-except
            column_metadata = []

        existing_columns = {column.name: column for column in self.columns}
        quote_identifier = self.database.quote_identifier

        def update_or_create_column(column_meta: Dict[str, Any]) -> Column:
            column_name: str = column_meta["name"]
            if column_name in existing_columns:
                column = existing_columns[column_name]
            else:
                column = Column(name=column_name)
            column.type = column_meta["type"]
            column.is_temporal = column_meta["is_dttm"]
            column.expression = quote_identifier(column_name)
            column.is_aggregation = False
            column.is_physical = True
            column.is_spatial = False
            column.is_partition = False  # TODO: update with accurate is_partition
            return column

        self.columns = [update_or_create_column(col) for col in column_metadata]

    @staticmethod
    def bulk_load_or_create(
        database: Database,
        table_names: Iterable[TableName],
        default_schema: Optional[str] = None,
        sync_columns: Optional[bool] = False,
        default_props: Optional[Dict[str, Any]] = None,
    ) -> List["Table"]:
        """
        Load or create multiple Table instances.
        """
        if not table_names:
            return []

        if not database.id:
            raise Exception("Database must already be saved to metastore")

        default_props = default_props or {}
        session: Session = inspect(database).session
        # load existing tables
        predicate = or_(
            *[
                and_(
                    Table.database_id == database.id,
                    Table.schema == (table.schema or default_schema),
                    Table.name == table.table,
                )
                for table in table_names
            ]
        )
        all_tables = session.query(Table).filter(predicate).order_by(Table.id).all()

        # add missing tables and pull their columns
        existing = {(table.schema, table.name) for table in all_tables}
        for table in table_names:
            schema = table.schema or default_schema
            name = table.table
            if (schema, name) not in existing:
                new_table = Table(
                    database=database,
                    database_id=database.id,
                    name=name,
                    schema=schema,
                    catalog=None,
                    **default_props,
                )
                if sync_columns:
                    new_table.sync_columns()
                all_tables.append(new_table)
                existing.add((schema, name))
                session.add(new_table)

        return all_tables
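A hypothetical call for the new helper (the schema and flag values are illustrative): given a set of parsed table references, it loads the existing Table rows in one query and creates only the missing ones.

from superset.sql_parse import Table as TableName

tables = Table.bulk_load_or_create(
    database,
    table_names=[TableName("t1", "public"), TableName("t2")],
    default_schema="public",  # applied to references with no schema
    sync_columns=True,        # fetch column metadata for new tables
)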
@@ -16,11 +16,11 @@
# under the License.
import copy
import json
from unittest.mock import patch

import yaml
from flask import g

from superset import db, security_manager
from superset import db
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.v1.assets import ImportAssetsCommand
from superset.commands.importers.v1.utils import is_valid_config
@@ -58,10 +58,13 @@ class TestImportersV1Utils(SupersetTestCase):


class TestImportAssetsCommand(SupersetTestCase):
    @patch("superset.dashboards.commands.importers.v1.utils.g")
    def test_import_assets(self, mock_g):
    def setUp(self):
        user = self.get_user("admin")
        self.user = user
        setattr(g, "user", user)

    def test_import_assets(self):
        """Test that we can import multiple assets"""
        mock_g.user = security_manager.find_user("admin")
        contents = {
            "metadata.yaml": yaml.safe_dump(metadata_config),
            "databases/imported_database.yaml": yaml.safe_dump(database_config),
@@ -141,7 +144,7 @@ class TestImportAssetsCommand(SupersetTestCase):
        database = dataset.database
        assert str(database.uuid) == database_config["uuid"]

        assert dashboard.owners == [mock_g.user]
        assert dashboard.owners == [self.user]

        dashboard.owners = []
        chart.owners = []
@@ -153,11 +156,8 @@ class TestImportAssetsCommand(SupersetTestCase):
        db.session.delete(database)
        db.session.commit()

    @patch("superset.dashboards.commands.importers.v1.utils.g")
    def test_import_v1_dashboard_overwrite(self, mock_g):
    def test_import_v1_dashboard_overwrite(self):
        """Test that assets can be overwritten"""
        mock_g.user = security_manager.find_user("admin")

        contents = {
            "metadata.yaml": yaml.safe_dump(metadata_config),
            "databases/imported_database.yaml": yaml.safe_dump(database_config),
@@ -111,11 +111,10 @@ def _commit_slices(slices: List[Slice]):


def _create_world_bank_dashboard(table: SqlaTable, slices: List[Slice]) -> Dashboard:
    from superset.examples.helpers import update_slice_ids
    from superset.examples.world_bank import dashboard_positions

    pos = dashboard_positions
    from superset.examples.helpers import update_slice_ids

    update_slice_ids(pos, slices)

    table.fetch_metadata()
@@ -455,7 +455,8 @@ class TestDatabaseModel(SupersetTestCase):

        # make sure the columns have been mapped properly
        assert len(table.columns) == 4
        table.fetch_metadata()
        table.fetch_metadata(commit=False)

        # assert that the removed column has been dropped and
        # the physical and calculated columns are present
        assert {col.column_name for col in table.columns} == {

@@ -473,6 +474,8 @@ class TestDatabaseModel(SupersetTestCase):
        assert VIRTUAL_TABLE_STRING_TYPES[backend].match(cols["mycase"].type)
        assert cols["expr"].expression == "case when 1 then 1 else 0 end"

        db.session.delete(table)

    @patch("superset.models.core.Database.db_engine_spec", BigQueryEngineSpec)
    def test_labels_expected_on_mutated_query(self):
        query_obj = {
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
import unittest
import uuid
from datetime import date, datetime, time, timedelta
from decimal import Decimal
@@ -17,7 +17,7 @@
# pylint: disable=redefined-outer-name, import-outside-toplevel

import importlib
from typing import Any, Iterator
from typing import Any, Callable, Iterator

import pytest
from pytest_mock import MockFixture

@@ -31,25 +31,33 @@ from superset.initialization import SupersetAppInitializer


@pytest.fixture
def session(mocker: MockFixture) -> Iterator[Session]:
def get_session(mocker: MockFixture) -> Callable[[], Session]:
    """
    Create an in-memory SQLite session to test models.
    """
    engine = create_engine("sqlite://")
    Session_ = sessionmaker(bind=engine)  # pylint: disable=invalid-name
    in_memory_session = Session_()

    # flask calls session.remove()
    in_memory_session.remove = lambda: None
    def get_session():
        Session_ = sessionmaker(bind=engine)  # pylint: disable=invalid-name
        in_memory_session = Session_()

    # patch session
    mocker.patch(
        "superset.security.SupersetSecurityManager.get_session",
        return_value=in_memory_session,
    )
    mocker.patch("superset.db.session", in_memory_session)
        # flask calls session.remove()
        in_memory_session.remove = lambda: None

    yield in_memory_session
        # patch session
        mocker.patch(
            "superset.security.SupersetSecurityManager.get_session",
            return_value=in_memory_session,
        )
        mocker.patch("superset.db.session", in_memory_session)
        return in_memory_session

    return get_session


@pytest.fixture
def session(get_session) -> Iterator[Session]:
    yield get_session()


@pytest.fixture(scope="module")
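The session fixture is now built on pytest's factory-as-fixture pattern, so one test can mint several independent sessions. A minimal self-contained sketch of the pattern follows; the names are illustrative, not taken from the codebase.

import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session

@pytest.fixture
def make_session():
    engine = create_engine("sqlite://")

    def _make() -> Session:
        return sessionmaker(bind=engine)()  # fresh session, shared engine

    return _make

def test_two_sessions(make_session):
    first, second = make_session(), make_session()
    assert first is not second  # each call yields a new session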
@@ -0,0 +1,118 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, TYPE_CHECKING

import pytest

if TYPE_CHECKING:
    from superset.connectors.sqla.models import SqlMetric, TableColumn


@pytest.fixture
def columns_default() -> Dict[str, Any]:
    """Default props for new columns"""
    return {
        "changed_by": 1,
        "created_by": 1,
        "datasets": [],
        "tables": [],
        "is_additive": False,
        "is_aggregation": False,
        "is_dimensional": False,
        "is_filterable": True,
        "is_increase_desired": True,
        "is_partition": False,
        "is_physical": True,
        "is_spatial": False,
        "is_temporal": False,
        "description": None,
        "extra_json": "{}",
        "unit": None,
        "warning_text": None,
        "is_managed_externally": False,
        "external_url": None,
    }


@pytest.fixture
def sample_columns() -> Dict["TableColumn", Dict[str, Any]]:
    from superset.connectors.sqla.models import TableColumn

    return {
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"): {
            "name": "ds",
            "expression": "ds",
            "type": "TIMESTAMP",
            "is_temporal": True,
            "is_physical": True,
        },
        TableColumn(column_name="num_boys", type="INTEGER", groupby=True): {
            "name": "num_boys",
            "expression": "num_boys",
            "type": "INTEGER",
            "is_dimensional": True,
            "is_physical": True,
        },
        TableColumn(column_name="region", type="VARCHAR", groupby=True): {
            "name": "region",
            "expression": "region",
            "type": "VARCHAR",
            "is_dimensional": True,
            "is_physical": True,
        },
        TableColumn(
            column_name="profit",
            type="INTEGER",
            groupby=False,
            expression="revenue-expenses",
        ): {
            "name": "profit",
            "expression": "revenue-expenses",
            "type": "INTEGER",
            "is_physical": False,
        },
    }


@pytest.fixture
def sample_metrics() -> Dict["SqlMetric", Dict[str, Any]]:
    from superset.connectors.sqla.models import SqlMetric

    return {
        SqlMetric(metric_name="cnt", expression="COUNT(*)", metric_type="COUNT"): {
            "name": "cnt",
            "expression": "COUNT(*)",
            "extra_json": '{"metric_type": "COUNT"}',
            "type": "UNKNOWN",
            "is_additive": True,
            "is_aggregation": True,
            "is_filterable": False,
            "is_physical": False,
        },
        SqlMetric(
            metric_name="avg revenue", expression="AVG(revenue)", metric_type="AVG"
        ): {
            "name": "avg revenue",
            "expression": "AVG(revenue)",
            "extra_json": '{"metric_type": "AVG"}',
            "type": "UNKNOWN",
            "is_additive": False,
            "is_aggregation": True,
            "is_filterable": False,
            "is_physical": False,
        },
    }
File diff suppressed because it is too large
@@ -1,16 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
@@ -1,56 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-outside-toplevel, unused-argument

"""
Test the SIP-68 migration.
"""

from pytest_mock import MockerFixture

from superset.sql_parse import Table


def test_extract_table_references(mocker: MockerFixture, app_context: None) -> None:
    """
    Test the ``extract_table_references`` helper function.
    """
    from superset.migrations.shared.utils import extract_table_references

    assert extract_table_references("SELECT 1", "trino") == set()
    assert extract_table_references("SELECT 1 FROM some_table", "trino") == {
        Table(table="some_table", schema=None, catalog=None)
    }
    assert extract_table_references(
        "SELECT 1 FROM some_catalog.some_schema.some_table", "trino"
    ) == {Table(table="some_table", schema="some_schema", catalog="some_catalog")}
    assert extract_table_references(
        "SELECT * FROM some_table JOIN other_table ON some_table.id = other_table.id",
        "trino",
    ) == {
        Table(table="some_table", schema=None, catalog=None),
        Table(table="other_table", schema=None, catalog=None),
    }

    # test falling back to sqlparse
    logger = mocker.patch("superset.migrations.shared.utils.logger")
    sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
    assert extract_table_references(
        sql,
        "trino",
    ) == {Table(table="other_table", schema=None, catalog=None)}
    logger.warning.assert_called_with("Unable to parse query with sqloxide: %s", sql)
@@ -29,6 +29,7 @@ from sqlparse.tokens import Name
from superset.exceptions import QueryClauseValidationException
from superset.sql_parse import (
    add_table_name,
    extract_table_references,
    get_rls_for_table,
    has_table_query,
    insert_rls,
@@ -1468,3 +1469,51 @@ def test_get_rls_for_table(mocker: MockerFixture, app_context: None) -> None:

    dataset.get_sqla_row_level_filters.return_value = []
    assert get_rls_for_table(candidate, 1, "public") is None


def test_extract_table_references(mocker: MockerFixture) -> None:
    """
    Test the ``extract_table_references`` helper function.
    """
    assert extract_table_references("SELECT 1", "trino") == set()
    assert extract_table_references("SELECT 1 FROM some_table", "trino") == {
        Table(table="some_table", schema=None, catalog=None)
    }
    assert extract_table_references("SELECT {{ jinja }} FROM some_table", "trino") == {
        Table(table="some_table", schema=None, catalog=None)
    }
    assert extract_table_references(
        "SELECT 1 FROM some_catalog.some_schema.some_table", "trino"
    ) == {Table(table="some_table", schema="some_schema", catalog="some_catalog")}

    # with identifier quotes
    assert extract_table_references(
        "SELECT 1 FROM `some_catalog`.`some_schema`.`some_table`", "mysql"
    ) == {Table(table="some_table", schema="some_schema", catalog="some_catalog")}
    assert extract_table_references(
        'SELECT 1 FROM "some_catalog".some_schema."some_table"', "trino"
    ) == {Table(table="some_table", schema="some_schema", catalog="some_catalog")}

    assert extract_table_references(
        "SELECT * FROM some_table JOIN other_table ON some_table.id = other_table.id",
        "trino",
    ) == {
        Table(table="some_table", schema=None, catalog=None),
        Table(table="other_table", schema=None, catalog=None),
    }

    # test falling back to sqlparse
    logger = mocker.patch("superset.sql_parse.logger")
    sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
    assert extract_table_references(
        sql,
        "trino",
    ) == {Table(table="other_table", schema=None, catalog=None)}
    logger.warning.assert_called_once()

    logger = mocker.patch("superset.migrations.shared.utils.logger")
    sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
    assert extract_table_references(sql, "trino", show_warning=False) == {
        Table(table="other_table", schema=None, catalog=None)
    }
    logger.warning.assert_not_called()
@@ -14,3 +14,17 @@
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any

from superset import security_manager


def get_test_user(id_: int, username: str) -> Any:
    """Create a sample test user"""
    return security_manager.user_model(
        id=id_,
        username=username,
        first_name=username,
        last_name=username,
        email=f"{username}@example.com",
    )