fix dataset update table (#19269)

Elizabeth Thompson 2022-03-21 09:43:51 -07:00, committed by GitHub
parent c07a707eab
commit 88029e21b6
2 changed files with 225 additions and 128 deletions


@@ -1863,11 +1863,20 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
session.execute(update(SqlaTable).where(SqlaTable.id == target.table.id))
# update ``Column`` model as well
dataset = (
session.query(NewDataset).filter_by(sqlatable_id=target.table.id).one()
session.query(NewDataset)
.filter_by(sqlatable_id=target.table.id)
.one_or_none()
)
if not dataset:
# if the dataset is not found, create a new copy
# of the dataset instead of updating the existing one
SqlaTable.write_shadow_dataset(target.table, database, session)
return
# update ``Column`` model as well
if isinstance(target, TableColumn):
columns = [
column
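
The essence of the fix above is swapping ``.one()`` (which raises ``NoResultFound`` when the shadow dataset is missing) for ``.one_or_none()`` plus a create-on-miss fallback. Below is a minimal, self-contained sketch of that lookup-or-create pattern; ``ShadowDataset`` and ``sync_shadow_dataset`` are hypothetical stand-ins for Superset's real models, with an in-memory SQLite engine just to make the flow runnable:

# Minimal sketch of the lookup-or-create pattern from the hunk above.
# ``ShadowDataset`` is a hypothetical stand-in for Superset's ``NewDataset``.
from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class ShadowDataset(Base):
    __tablename__ = "shadow_dataset"
    id = Column(Integer, primary_key=True)
    sqlatable_id = Column(Integer, unique=True)

def sync_shadow_dataset(session: Session, sqlatable_id: int) -> None:
    dataset = (
        session.query(ShadowDataset)
        .filter_by(sqlatable_id=sqlatable_id)
        .one_or_none()  # ``.one()`` would raise NoResultFound on a miss
    )
    if not dataset:
        # no shadow copy yet: write a fresh one instead of updating
        session.add(ShadowDataset(sqlatable_id=sqlatable_id))
        return
    # shadow copy exists: update it in place (elided here)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    sync_shadow_dataset(session, 42)  # creates the shadow row
    session.flush()
    sync_shadow_dataset(session, 42)  # finds and updates it
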
@@ -1923,7 +1932,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
column.extra_json = json.dumps(extra_json) if extra_json else None
@staticmethod
def after_insert( # pylint: disable=too-many-locals
def after_insert(
mapper: Mapper, connection: Connection, target: "SqlaTable",
) -> None:
"""
@@ -1938,135 +1947,18 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
For more context: https://github.com/apache/superset/issues/14909
"""
session = inspect(target).session
# set permissions
security_manager.set_perm(mapper, connection, target)
session = inspect(target).session
# get DB-specific conditional quoter for expressions that point to columns or
# table names
database = (
target.database
or session.query(Database).filter_by(id=target.database_id).one()
)
engine = database.get_sqla_engine(schema=target.schema)
conditional_quote = engine.dialect.identifier_preparer.quote
# create columns
columns = []
for column in target.columns:
# ``is_active`` might be ``None`` at this point, but it defaults to ``True``.
if column.is_active is False:
continue
extra_json = json.loads(column.extra or "{}")
for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}:
value = getattr(column, attr)
if value:
extra_json[attr] = value
columns.append(
NewColumn(
name=column.column_name,
type=column.type or "Unknown",
expression=column.expression
or conditional_quote(column.column_name),
description=column.description,
is_temporal=column.is_dttm,
is_aggregation=False,
is_physical=column.expression is None,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
extra_json=json.dumps(extra_json) if extra_json else None,
is_managed_externally=target.is_managed_externally,
external_url=target.external_url,
),
)
# create metrics
for metric in target.metrics:
extra_json = json.loads(metric.extra or "{}")
for attr in {"verbose_name", "metric_type", "d3format"}:
value = getattr(metric, attr)
if value:
extra_json[attr] = value
is_additive = (
metric.metric_type
and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES
)
columns.append(
NewColumn(
name=metric.metric_name,
type="Unknown", # figuring this out would require a type inferrer
expression=metric.expression,
warning_text=metric.warning_text,
description=metric.description,
is_aggregation=True,
is_additive=is_additive,
is_physical=False,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
extra_json=json.dumps(extra_json) if extra_json else None,
is_managed_externally=target.is_managed_externally,
external_url=target.external_url,
),
)
# physical dataset
tables = []
if target.sql is None:
physical_columns = [column for column in columns if column.is_physical]
# create table
table = NewTable(
name=target.table_name,
schema=target.schema,
catalog=None, # currently not supported
database_id=target.database_id,
columns=physical_columns,
is_managed_externally=target.is_managed_externally,
external_url=target.external_url,
)
tables.append(table)
# virtual dataset
else:
# mark all columns as virtual (not physical)
for column in columns:
column.is_physical = False
# find referenced tables
parsed = ParsedQuery(target.sql)
referenced_tables = parsed.tables
# predicate for finding the referenced tables
predicate = or_(
*[
and_(
NewTable.schema == (table.schema or target.schema),
NewTable.name == table.table,
)
for table in referenced_tables
]
)
tables = session.query(NewTable).filter(predicate).all()
# create the new dataset
dataset = NewDataset(
sqlatable_id=target.id,
name=target.table_name,
expression=target.sql or conditional_quote(target.table_name),
tables=tables,
columns=columns,
is_physical=target.sql is None,
is_managed_externally=target.is_managed_externally,
external_url=target.external_url,
)
session.add(dataset)
SqlaTable.write_shadow_dataset(target, database, session)
@staticmethod
def after_delete( # pylint: disable=unused-argument
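
For readability, here is the slimmed-down ``after_insert`` pieced together from this hunk's surviving context lines (a reconstruction from the diff with the docstring elided, not a verbatim copy of the file):

    @staticmethod
    def after_insert(
        mapper: Mapper, connection: Connection, target: "SqlaTable",
    ) -> None:
        session = inspect(target).session

        # set permissions
        security_manager.set_perm(mapper, connection, target)

        database = (
            target.database
            or session.query(Database).filter_by(id=target.database_id).one()
        )
        # all of the shadow-write logic now lives in the shared helper
        SqlaTable.write_shadow_dataset(target, database, session)
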
@@ -2301,6 +2193,142 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
dataset.expression = target.sql or conditional_quote(target.table_name)
dataset.is_physical = target.sql is None
@staticmethod
def write_shadow_dataset( # pylint: disable=too-many-locals
dataset: "SqlaTable", database: Database, session: Session
) -> None:
"""
Shadow write the dataset to new models.
The ``SqlaTable`` model is currently being migrated to two new models, ``Table``
and ``Dataset``. In the first phase of the migration the new models are populated
whenever ``SqlaTable`` is modified (created, updated, or deleted).
In the second phase of the migration reads will be done from the new models.
Finally, in the third phase of the migration the old models will be removed.
For more context: https://github.com/apache/superset/issues/14909
"""
engine = database.get_sqla_engine(schema=dataset.schema)
conditional_quote = engine.dialect.identifier_preparer.quote
# create columns
columns = []
for column in dataset.columns:
# ``is_active`` might be ``None`` at this point, but it defaults to ``True``.
if column.is_active is False:
continue
extra_json = json.loads(column.extra or "{}")
for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}:
value = getattr(column, attr)
if value:
extra_json[attr] = value
columns.append(
NewColumn(
name=column.column_name,
type=column.type or "Unknown",
expression=column.expression
or conditional_quote(column.column_name),
description=column.description,
is_temporal=column.is_dttm,
is_aggregation=False,
is_physical=column.expression is None,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
extra_json=json.dumps(extra_json) if extra_json else None,
is_managed_externally=dataset.is_managed_externally,
external_url=dataset.external_url,
),
)
# create metrics
for metric in dataset.metrics:
extra_json = json.loads(metric.extra or "{}")
for attr in {"verbose_name", "metric_type", "d3format"}:
value = getattr(metric, attr)
if value:
extra_json[attr] = value
is_additive = (
metric.metric_type
and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES
)
columns.append(
NewColumn(
name=metric.metric_name,
type="Unknown", # figuring this out would require a type inferrer
expression=metric.expression,
warning_text=metric.warning_text,
description=metric.description,
is_aggregation=True,
is_additive=is_additive,
is_physical=False,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
extra_json=json.dumps(extra_json) if extra_json else None,
is_managed_externally=dataset.is_managed_externally,
external_url=dataset.external_url,
),
)
# physical dataset
tables = []
if dataset.sql is None:
physical_columns = [column for column in columns if column.is_physical]
# create table
table = NewTable(
name=dataset.table_name,
schema=dataset.schema,
catalog=None, # currently not supported
database_id=dataset.database_id,
columns=physical_columns,
is_managed_externally=dataset.is_managed_externally,
external_url=dataset.external_url,
)
tables.append(table)
# virtual dataset
else:
# mark all columns as virtual (not physical)
for column in columns:
column.is_physical = False
# find referenced tables
parsed = ParsedQuery(dataset.sql)
referenced_tables = parsed.tables
# predicate for finding the referenced tables
predicate = or_(
*[
and_(
NewTable.schema == (table.schema or dataset.schema),
NewTable.name == table.table,
)
for table in referenced_tables
]
)
tables = session.query(NewTable).filter(predicate).all()
# create the new dataset
new_dataset = NewDataset(
sqlatable_id=dataset.id,
name=dataset.table_name,
expression=dataset.sql or conditional_quote(dataset.table_name),
tables=tables,
columns=columns,
is_physical=dataset.sql is None,
is_managed_externally=dataset.is_managed_externally,
external_url=dataset.external_url,
)
session.add(new_dataset)
sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update)
sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert)
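
These ``sa.event.listen`` registrations are what make the shadow writes automatic: the hooks run inside the ORM flush, which is what the tests below rely on. A toy, runnable illustration of the same SQLAlchemy mapper-event mechanism, using a hypothetical ``Widget`` model rather than Superset's:

# Toy demonstration of the mapper-event mechanism registered above:
# an "after_insert" listener fires during session.flush().
import sqlalchemy as sa
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Widget(Base):  # hypothetical model standing in for SqlaTable
    __tablename__ = "widget"
    id = Column(Integer, primary_key=True)
    name = Column(String)

def after_insert(mapper, connection, target):
    # runs inside the flush, just like SqlaTable.after_insert
    print(f"shadow-writing {target.name!r}")

sa.event.listen(Widget, "after_insert", after_insert)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Widget(name="my_table"))
    session.flush()  # the listener fires here, before commit
    session.commit()
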


@@ -980,9 +980,9 @@ def test_update_sqlatable_schema(
sqla_table.schema = "new_schema"
session.flush()
dataset = session.query(Dataset).one()
assert dataset.tables[0].schema == "new_schema"
assert dataset.tables[0].id == 2
new_dataset = session.query(Dataset).one()
assert new_dataset.tables[0].schema == "new_schema"
assert new_dataset.tables[0].id == 2
def test_update_sqlatable_metric(
@@ -1098,9 +1098,9 @@ def test_update_virtual_sqlatable_references(
session.flush()
# check that new dataset has both tables
dataset = session.query(Dataset).one()
assert dataset.tables == [table1, table2]
assert dataset.expression == "SELECT a, b FROM table_a JOIN table_b"
new_dataset = session.query(Dataset).one()
assert new_dataset.tables == [table1, table2]
assert new_dataset.expression == "SELECT a, b FROM table_a JOIN table_b"
def test_quote_expressions(app_context: None, session: Session) -> None:
@@ -1242,3 +1242,72 @@ def test_update_physical_sqlatable(
# check that dataset points to the original table
assert dataset.tables[0].database_id == 1
def test_update_physical_sqlatable_no_dataset(
mocker: MockFixture, app_context: None, session: Session
) -> None:
"""
Test that updating the table on a physical dataset creates
a new dataset if one didn't already exist.
When updating the table on a physical dataset by pointing it somewhere else (a change
in database ID, schema, or table name), we should point the ``Dataset`` to an
existing ``Table`` if possible, and create a new one otherwise.
"""
# patch session
mocker.patch(
"superset.security.SupersetSecurityManager.get_session", return_value=session
)
mocker.patch("superset.datasets.dao.db.session", session)
from superset.columns.models import Column
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.datasets.models import Dataset
from superset.models.core import Database
from superset.tables.models import Table
from superset.tables.schemas import TableSchema
engine = session.get_bind()
Dataset.metadata.create_all(engine) # pylint: disable=no-member
columns = [
TableColumn(column_name="a", type="INTEGER"),
]
sqla_table = SqlaTable(
table_name="old_dataset",
columns=columns,
metrics=[],
database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
)
session.add(sqla_table)
session.flush()
# check that the table was created
table = session.query(Table).one()
assert table.id == 1
dataset = session.query(Dataset).one()
assert dataset.tables == [table]
# point ``SqlaTable`` to a different database
new_database = Database(
database_name="my_other_database", sqlalchemy_uri="sqlite://"
)
session.add(new_database)
session.flush()
sqla_table.database = new_database
session.flush()
new_dataset = session.query(Dataset).one()
# check that dataset now points to the new table
assert new_dataset.tables[0].database_id == 2
# point ``SqlaTable`` back
sqla_table.database_id = 1
session.flush()
# check that dataset points to the original table
assert new_dataset.tables[0].database_id == 1