mirror of https://github.com/apache/superset.git
fix dataset update table (#19269)
parent c07a707eab
commit 88029e21b6
@@ -1863,11 +1863,20 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
         session.execute(update(SqlaTable).where(SqlaTable.id == target.table.id))
 
-        # update ``Column`` model as well
         dataset = (
-            session.query(NewDataset).filter_by(sqlatable_id=target.table.id).one()
+            session.query(NewDataset)
+            .filter_by(sqlatable_id=target.table.id)
+            .one_or_none()
         )
 
+        if not dataset:
+            # if the dataset is not found, create a new copy
+            # of the dataset instead of updating the existing one
+
+            SqlaTable.write_shadow_dataset(target.table, database, session)
+            return
+
+        # update ``Column`` model as well
         if isinstance(target, TableColumn):
             columns = [
                 column
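The crux of the fix is visible in this hunk: the old code fetched the shadow ``NewDataset`` with ``.one()``, which raises ``NoResultFound`` whenever the shadow row does not exist (for example, for tables created before shadow writes were introduced). Switching to ``.one_or_none()`` lets ``before_update`` fall back to creating the shadow dataset instead of failing. A minimal, self-contained sketch of the pattern, using a throwaway model rather than Superset's:

    from sqlalchemy import Column, Integer, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()


    class NewDataset(Base):  # illustrative stand-in for the shadow model
        __tablename__ = "sl_datasets"
        id = Column(Integer, primary_key=True)
        sqlatable_id = Column(Integer)


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        # no shadow row exists yet; .one() would raise NoResultFound here
        dataset = session.query(NewDataset).filter_by(sqlatable_id=42).one_or_none()
        if not dataset:
            # fall back: create the missing shadow copy instead of crashing
            session.add(NewDataset(sqlatable_id=42))
            session.commit()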
@@ -1923,7 +1932,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
             column.extra_json = json.dumps(extra_json) if extra_json else None
 
     @staticmethod
-    def after_insert(  # pylint: disable=too-many-locals
+    def after_insert(
         mapper: Mapper, connection: Connection, target: "SqlaTable",
     ) -> None:
         """
@@ -1938,135 +1947,18 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
 
         For more context: https://github.com/apache/superset/issues/14909
         """
+        session = inspect(target).session
         # set permissions
         security_manager.set_perm(mapper, connection, target)
 
-        session = inspect(target).session
-
-        # get DB-specific conditional quoter for expressions that point to columns or
-        # table names
         database = (
             target.database
             or session.query(Database).filter_by(id=target.database_id).one()
         )
-        engine = database.get_sqla_engine(schema=target.schema)
-        conditional_quote = engine.dialect.identifier_preparer.quote
-
-        # create columns
-        columns = []
-        for column in target.columns:
-            # ``is_active`` might be ``None`` at this point, but it defaults to ``True``.
-            if column.is_active is False:
-                continue
-
-            extra_json = json.loads(column.extra or "{}")
-            for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}:
-                value = getattr(column, attr)
-                if value:
-                    extra_json[attr] = value
-
-            columns.append(
-                NewColumn(
-                    name=column.column_name,
-                    type=column.type or "Unknown",
-                    expression=column.expression
-                    or conditional_quote(column.column_name),
-                    description=column.description,
-                    is_temporal=column.is_dttm,
-                    is_aggregation=False,
-                    is_physical=column.expression is None,
-                    is_spatial=False,
-                    is_partition=False,
-                    is_increase_desired=True,
-                    extra_json=json.dumps(extra_json) if extra_json else None,
-                    is_managed_externally=target.is_managed_externally,
-                    external_url=target.external_url,
-                ),
-            )
-
-        # create metrics
-        for metric in target.metrics:
-            extra_json = json.loads(metric.extra or "{}")
-            for attr in {"verbose_name", "metric_type", "d3format"}:
-                value = getattr(metric, attr)
-                if value:
-                    extra_json[attr] = value
-
-            is_additive = (
-                metric.metric_type
-                and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES
-            )
-
-            columns.append(
-                NewColumn(
-                    name=metric.metric_name,
-                    type="Unknown",  # figuring this out would require a type inferrer
-                    expression=metric.expression,
-                    warning_text=metric.warning_text,
-                    description=metric.description,
-                    is_aggregation=True,
-                    is_additive=is_additive,
-                    is_physical=False,
-                    is_spatial=False,
-                    is_partition=False,
-                    is_increase_desired=True,
-                    extra_json=json.dumps(extra_json) if extra_json else None,
-                    is_managed_externally=target.is_managed_externally,
-                    external_url=target.external_url,
-                ),
-            )
-
-        # physical dataset
-        tables = []
-        if target.sql is None:
-            physical_columns = [column for column in columns if column.is_physical]
-
-            # create table
-            table = NewTable(
-                name=target.table_name,
-                schema=target.schema,
-                catalog=None,  # currently not supported
-                database_id=target.database_id,
-                columns=physical_columns,
-                is_managed_externally=target.is_managed_externally,
-                external_url=target.external_url,
-            )
-            tables.append(table)
-
-        # virtual dataset
-        else:
-            # mark all columns as virtual (not physical)
-            for column in columns:
-                column.is_physical = False
-
-            # find referenced tables
-            parsed = ParsedQuery(target.sql)
-            referenced_tables = parsed.tables
-
-            # predicate for finding the referenced tables
-            predicate = or_(
-                *[
-                    and_(
-                        NewTable.schema == (table.schema or target.schema),
-                        NewTable.name == table.table,
-                    )
-                    for table in referenced_tables
-                ]
-            )
-            tables = session.query(NewTable).filter(predicate).all()
-
-        # create the new dataset
-        dataset = NewDataset(
-            sqlatable_id=target.id,
-            name=target.table_name,
-            expression=target.sql or conditional_quote(target.table_name),
-            tables=tables,
-            columns=columns,
-            is_physical=target.sql is None,
-            is_managed_externally=target.is_managed_externally,
-            external_url=target.external_url,
-        )
-        session.add(dataset)
+        SqlaTable.write_shadow_dataset(target, database, session)
 
     @staticmethod
     def after_delete(  # pylint: disable=unused-argument
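For context, ``after_insert`` (like ``before_update`` and ``after_delete``) is a SQLAlchemy mapper-event hook, registered with ``sa.event.listen`` at module level (see the end of the next hunk). A minimal sketch of that mechanism with illustrative names, assuming SQLAlchemy 1.4+:

    import sqlalchemy as sa
    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()


    class SqlaTable(Base):  # illustrative stand-in
        __tablename__ = "tables"
        id = Column(Integer, primary_key=True)
        table_name = Column(String(250))


    def after_insert(mapper, connection, target):
        # runs inside the flush, right after the row is INSERTed;
        # in Superset this is where the shadow write happens
        print(f"shadow-writing dataset for {target.table_name!r}")


    sa.event.listen(SqlaTable, "after_insert", after_insert)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        session.add(SqlaTable(table_name="my_table"))
        session.flush()  # fires the after_insert hook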
@@ -2301,6 +2193,142 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
         dataset.expression = target.sql or conditional_quote(target.table_name)
         dataset.is_physical = target.sql is None
 
+    @staticmethod
+    def write_shadow_dataset(  # pylint: disable=too-many-locals
+        dataset: "SqlaTable", database: Database, session: Session
+    ) -> None:
+        """
+        Shadow write the dataset to new models.
+
+        The ``SqlaTable`` model is currently being migrated to two new models, ``Table``
+        and ``Dataset``. In the first phase of the migration the new models are populated
+        whenever ``SqlaTable`` is modified (created, updated, or deleted).
+
+        In the second phase of the migration reads will be done from the new models.
+        Finally, in the third phase of the migration the old models will be removed.
+
+        For more context: https://github.com/apache/superset/issues/14909
+        """
+        engine = database.get_sqla_engine(schema=dataset.schema)
+        conditional_quote = engine.dialect.identifier_preparer.quote
+
+        # create columns
+        columns = []
+        for column in dataset.columns:
+            # ``is_active`` might be ``None`` at this point, but it defaults to ``True``.
+            if column.is_active is False:
+                continue
+
+            extra_json = json.loads(column.extra or "{}")
+            for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}:
+                value = getattr(column, attr)
+                if value:
+                    extra_json[attr] = value
+
+            columns.append(
+                NewColumn(
+                    name=column.column_name,
+                    type=column.type or "Unknown",
+                    expression=column.expression
+                    or conditional_quote(column.column_name),
+                    description=column.description,
+                    is_temporal=column.is_dttm,
+                    is_aggregation=False,
+                    is_physical=column.expression is None,
+                    is_spatial=False,
+                    is_partition=False,
+                    is_increase_desired=True,
+                    extra_json=json.dumps(extra_json) if extra_json else None,
+                    is_managed_externally=dataset.is_managed_externally,
+                    external_url=dataset.external_url,
+                ),
+            )
+
+        # create metrics
+        for metric in dataset.metrics:
+            extra_json = json.loads(metric.extra or "{}")
+            for attr in {"verbose_name", "metric_type", "d3format"}:
+                value = getattr(metric, attr)
+                if value:
+                    extra_json[attr] = value
+
+            is_additive = (
+                metric.metric_type
+                and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES
+            )
+
+            columns.append(
+                NewColumn(
+                    name=metric.metric_name,
+                    type="Unknown",  # figuring this out would require a type inferrer
+                    expression=metric.expression,
+                    warning_text=metric.warning_text,
+                    description=metric.description,
+                    is_aggregation=True,
+                    is_additive=is_additive,
+                    is_physical=False,
+                    is_spatial=False,
+                    is_partition=False,
+                    is_increase_desired=True,
+                    extra_json=json.dumps(extra_json) if extra_json else None,
+                    is_managed_externally=dataset.is_managed_externally,
+                    external_url=dataset.external_url,
+                ),
+            )
+
+        # physical dataset
+        tables = []
+        if dataset.sql is None:
+            physical_columns = [column for column in columns if column.is_physical]
+
+            # create table
+            table = NewTable(
+                name=dataset.table_name,
+                schema=dataset.schema,
+                catalog=None,  # currently not supported
+                database_id=dataset.database_id,
+                columns=physical_columns,
+                is_managed_externally=dataset.is_managed_externally,
+                external_url=dataset.external_url,
+            )
+            tables.append(table)
+
+        # virtual dataset
+        else:
+            # mark all columns as virtual (not physical)
+            for column in columns:
+                column.is_physical = False
+
+            # find referenced tables
+            parsed = ParsedQuery(dataset.sql)
+            referenced_tables = parsed.tables
+
+            # predicate for finding the referenced tables
+            predicate = or_(
+                *[
+                    and_(
+                        NewTable.schema == (table.schema or dataset.schema),
+                        NewTable.name == table.table,
+                    )
+                    for table in referenced_tables
+                ]
+            )
+            tables = session.query(NewTable).filter(predicate).all()
+
+        # create the new dataset
+        new_dataset = NewDataset(
+            sqlatable_id=dataset.id,
+            name=dataset.table_name,
+            expression=dataset.sql or conditional_quote(dataset.table_name),
+            tables=tables,
+            columns=columns,
+            is_physical=dataset.sql is None,
+            is_managed_externally=dataset.is_managed_externally,
+            external_url=dataset.external_url,
+        )
+        session.add(new_dataset)
 
 
 sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update)
 sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert)
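In the virtual-dataset branch above, ``write_shadow_dataset`` combines one ``and_()`` clause per referenced table under a single ``or_()`` so all referenced ``NewTable`` rows are fetched in one query. A runnable sketch of that predicate shape, with throwaway models standing in for Superset's:

    from sqlalchemy import Column, Integer, String, and_, create_engine, or_
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()


    class NewTable(Base):  # illustrative stand-in
        __tablename__ = "sl_tables"
        id = Column(Integer, primary_key=True)
        schema = Column(String(255))
        name = Column(String(255))


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    referenced = [("main", "table_a"), ("main", "table_b")]
    predicate = or_(
        *[
            and_(NewTable.schema == schema, NewTable.name == name)
            for schema, name in referenced
        ]
    )

    with Session(engine) as session:
        session.add_all(
            [NewTable(schema="main", name="table_a"), NewTable(schema="main", name="table_c")]
        )
        session.flush()
        tables = session.query(NewTable).filter(predicate).all()
        assert [table.name for table in tables] == ["table_a"]  # only the referenced match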
@@ -980,9 +980,9 @@ def test_update_sqlatable_schema(
     sqla_table.schema = "new_schema"
     session.flush()
 
-    dataset = session.query(Dataset).one()
-    assert dataset.tables[0].schema == "new_schema"
-    assert dataset.tables[0].id == 2
+    new_dataset = session.query(Dataset).one()
+    assert new_dataset.tables[0].schema == "new_schema"
+    assert new_dataset.tables[0].id == 2
 
 
 def test_update_sqlatable_metric(
@@ -1098,9 +1098,9 @@ def test_update_virtual_sqlatable_references(
     session.flush()
 
     # check that new dataset has both tables
-    dataset = session.query(Dataset).one()
-    assert dataset.tables == [table1, table2]
-    assert dataset.expression == "SELECT a, b FROM table_a JOIN table_b"
+    new_dataset = session.query(Dataset).one()
+    assert new_dataset.tables == [table1, table2]
+    assert new_dataset.expression == "SELECT a, b FROM table_a JOIN table_b"
 
 
 def test_quote_expressions(app_context: None, session: Session) -> None:
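The ``test_quote_expressions`` context line above exercises the "conditional quoter" that the shadow writer uses for default column and dataset expressions: SQLAlchemy's dialect-specific ``IdentifierPreparer.quote`` adds quotes only when an identifier needs them. A quick demonstration against a throwaway SQLite engine:

    from sqlalchemy import create_engine

    engine = create_engine("sqlite://")
    quote = engine.dialect.identifier_preparer.quote

    print(quote("my_table"))    # my_table       -- plain lowercase name, left alone
    print(quote("select"))      # "select"       -- reserved word, quoted
    print(quote("Mixed Case"))  # "Mixed Case"   -- space and capitals, quoted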
@@ -1242,3 +1242,72 @@ def test_update_physical_sqlatable(
 
     # check that dataset points to the original table
     assert dataset.tables[0].database_id == 1
+
+
+def test_update_physical_sqlatable_no_dataset(
+    mocker: MockFixture, app_context: None, session: Session
+) -> None:
+    """
+    Test that updating the table on a physical dataset creates
+    a new dataset if one didn't already exist.
+
+    When updating the table on a physical dataset by pointing it somewhere else (change
+    in database ID, schema, or table name) we should point the ``Dataset`` to an
+    existing ``Table`` if possible, and create a new one otherwise.
+    """
+    # patch session
+    mocker.patch(
+        "superset.security.SupersetSecurityManager.get_session", return_value=session
+    )
+    mocker.patch("superset.datasets.dao.db.session", session)
+
+    from superset.columns.models import Column
+    from superset.connectors.sqla.models import SqlaTable, TableColumn
+    from superset.datasets.models import Dataset
+    from superset.models.core import Database
+    from superset.tables.models import Table
+    from superset.tables.schemas import TableSchema
+
+    engine = session.get_bind()
+    Dataset.metadata.create_all(engine)  # pylint: disable=no-member
+
+    columns = [
+        TableColumn(column_name="a", type="INTEGER"),
+    ]
+
+    sqla_table = SqlaTable(
+        table_name="old_dataset",
+        columns=columns,
+        metrics=[],
+        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
+    )
+    session.add(sqla_table)
+    session.flush()
+
+    # check that the table was created
+    table = session.query(Table).one()
+    assert table.id == 1
+
+    dataset = session.query(Dataset).one()
+    assert dataset.tables == [table]
+
+    # point ``SqlaTable`` to a different database
+    new_database = Database(
+        database_name="my_other_database", sqlalchemy_uri="sqlite://"
+    )
+    session.add(new_database)
+    session.flush()
+    sqla_table.database = new_database
+    session.flush()
+
+    new_dataset = session.query(Dataset).one()
+
+    # check that dataset now points to the new table
+    assert new_dataset.tables[0].database_id == 2
+
+    # point ``SqlaTable`` back
+    sqla_table.database_id = 1
+    session.flush()
+
+    # check that dataset points to the original table
+    assert new_dataset.tables[0].database_id == 1
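The docstring of this new test spells out the intended behavior: point the ``Dataset`` at an existing ``Table`` when one matches, and create a new one otherwise. A hedged sketch of that find-or-create pattern with an illustrative model (not Superset's actual lookup code):

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()


    class Table(Base):  # illustrative stand-in
        __tablename__ = "tables"
        id = Column(Integer, primary_key=True)
        database_id = Column(Integer)
        name = Column(String(255))


    def find_or_create(session: Session, database_id: int, name: str) -> Table:
        # reuse an existing row when possible, create one otherwise
        table = (
            session.query(Table)
            .filter_by(database_id=database_id, name=name)
            .one_or_none()
        )
        if table is None:
            table = Table(database_id=database_id, name=name)
            session.add(table)
            session.flush()
        return table


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        first = find_or_create(session, database_id=1, name="old_dataset")
        second = find_or_create(session, database_id=1, name="old_dataset")
        assert first is second  # the second call reuses the existing row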