remove autoflush for queries during dual write (#20460)

Elizabeth Thompson 2022-06-23 14:50:30 -07:00 committed by GitHub
parent 661ab35bd0
commit 44f0b511dd
3 changed files with 68 additions and 12 deletions
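Context for the commit title: the dual-write helpers below query for existing rows while the session may still hold half-populated pending objects, and an autoflushing query would try to write those objects first. The change wraps those lookups in SQLAlchemy's no_autoflush context manager. A minimal sketch of the pattern, with an illustrative helper name (find_by_uuid is not part of this change):

from sqlalchemy.orm import Session

def find_by_uuid(session: Session, model, uuid):
    # Queries normally autoflush pending objects before they execute; suspending
    # autoflush keeps half-populated rows from being written mid dual-write.
    with session.no_autoflush:
        return session.query(model).filter_by(uuid=uuid).one_or_none()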


@@ -439,6 +439,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
         self, known_columns: Optional[Dict[str, NewColumn]] = None
     ) -> NewColumn:
         """Convert a TableColumn to NewColumn"""
+        session: Session = inspect(self).session
         column = known_columns.get(self.uuid) if known_columns else None
         if not column:
             column = NewColumn()
@@ -452,6 +453,21 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
             if value:
                 extra_json[attr] = value
+        if not column.id:
+            with session.no_autoflush:
+                saved_column = (
+                    session.query(NewColumn).filter_by(uuid=self.uuid).one_or_none()
+                )
+                if saved_column:
+                    logger.warning(
+                        "sl_column already exists. Assigning existing id %s", self
+                    )
+                    # uuid isn't a primary key, so add the id of the existing column to
+                    # ensure that the column is modified instead of created
+                    # in order to avoid a uuid collision
+                    column.id = saved_column.id
         column.uuid = self.uuid
         column.created_on = self.created_on
         column.changed_on = self.changed_on
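The id reassignment above works because the ORM matches rows by primary key, not by the uuid column, so carrying the existing id over lets a later merge() update the row instead of inserting a duplicate with a colliding uuid. A rough sketch of that behavior (helper and argument names are illustrative):

def reuse_existing_id(session, column, saved_column):
    # uuid is not the primary key on the target table, so merge() would otherwise
    # emit a second INSERT; matching by primary key turns the write into an UPDATE
    column.id = saved_column.id
    return session.merge(column)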
@@ -555,6 +571,7 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
     ) -> NewColumn:
         """Convert a SqlMetric to NewColumn. Find and update existing or
         create a new one."""
+        session: Session = inspect(self).session
         column = known_columns.get(self.uuid) if known_columns else None
         if not column:
             column = NewColumn()
@@ -568,6 +585,20 @@
             self.metric_type and self.metric_type.lower() in ADDITIVE_METRIC_TYPES_LOWER
         )
+        if not column.id:
+            with session.no_autoflush:
+                saved_column = (
+                    session.query(NewColumn).filter_by(uuid=self.uuid).one_or_none()
+                )
+                if saved_column:
+                    logger.warning(
+                        "sl_column already exists. Assigning existing id %s", self
+                    )
+                    # uuid isn't a primary key, so add the id of the existing column to
+                    # ensure that the column is modified instead of created
+                    # in order to avoid a uuid collision
+                    column.id = saved_column.id
         column.uuid = self.uuid
         column.name = self.metric_name
         column.created_on = self.created_on
@@ -2149,10 +2180,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
             uuids.remove(column.uuid)

         if uuids:
-            # load those not found from db
-            existing_columns |= set(
-                session.query(NewColumn).filter(NewColumn.uuid.in_(uuids))
-            )
+            with session.no_autoflush:
+                # load those not found from db
+                existing_columns |= set(
+                    session.query(NewColumn).filter(NewColumn.uuid.in_(uuids))
+                )

         known_columns = {column.uuid: column for column in existing_columns}
         return [
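The hunk above applies the same treatment to the bulk path, which fetches every uuid not already cached in a single IN query. A sketch of that lookup in isolation (helper name is illustrative):

def find_many_by_uuid(session, model, uuids):
    # one IN query for all the uuids that were not found in the session cache,
    # again without triggering an autoflush of pending objects
    with session.no_autoflush:
        return set(session.query(model).filter(model.uuid.in_(uuids)))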
@@ -2192,9 +2224,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
         # update changed_on timestamp
         session.execute(update(NewDataset).where(NewDataset.id == dataset.id))
         try:
-            column = session.query(NewColumn).filter_by(uuid=target.uuid).one()
-            # update `Column` model as well
-            session.merge(target.to_sl_column({target.uuid: column}))
+            with session.no_autoflush:
+                column = session.query(NewColumn).filter_by(uuid=target.uuid).one()
+                # update `Column` model as well
+                session.merge(target.to_sl_column({target.uuid: column}))
         except NoResultFound:
             logger.warning("No column was found for %s", target)
             # see if the column is in cache
@@ -2204,14 +2237,15 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
                 ),
                 None,
             )
-            if not column:
+            if column:
+                logger.warning("New column was found in cache: %s", column)
+            else:
                 # to be safe, use a different uuid and create a new column
                 uuid = uuid4()
                 target.uuid = uuid
-                column = NewColumn(uuid=uuid)
-                session.add(target.to_sl_column({column.uuid: column}))
+                session.add(target.to_sl_column())

     @staticmethod
     def after_insert(
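For reference, update_column and after_insert are mapper-level hooks with the standard (mapper, connection, target) listener signature. A generic sketch of how such hooks get attached with SQLAlchemy's event API; the actual registration code is not part of this diff and may differ:

import sqlalchemy as sa

from superset.connectors.sqla.models import SqlaTable, TableColumn

# hook the staticmethods above into the ORM flush lifecycle: "after_update"
# fires with (mapper, connection, target) for each updated TableColumn, and
# "after_insert" for each newly inserted SqlaTable
sa.event.listen(TableColumn, "after_update", SqlaTable.update_column)
sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert)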


@@ -17,8 +17,10 @@
 from unittest import mock

 import pytest
+from sqlalchemy import inspect
 from sqlalchemy.orm.exc import NoResultFound

+from superset.columns.models import Column
 from superset.connectors.sqla.models import SqlaTable, TableColumn
 from superset.extensions import db
 from tests.integration_tests.base_tests import SupersetTestCase
@@ -59,6 +61,10 @@ class SqlaTableModelTest(SupersetTestCase):
         with mock.patch("sqlalchemy.orm.query.Query.one", side_effect=NoResultFound):
             SqlaTable.update_column(None, None, target=column)
+        session = inspect(column).session
+        session.flush()
         # refetch
         dataset = db.session.query(SqlaTable).filter_by(id=dataset.id).one()
         # it should create a new uuid
@@ -67,3 +73,15 @@
         # reset
         column.uuid = column_uuid
         SqlaTable.update_column(None, None, target=column)
+
+    @pytest.mark.usefixtures("load_dataset_with_columns")
+    def test_to_sl_column_no_known_columns(self) -> None:
+        """
+        Test that the function returns a new column
+        """
+        dataset = db.session.query(SqlaTable).filter_by(table_name="students").first()
+        column = dataset.columns[0]
+        new_column = column.to_sl_column()
+        # it should use the same uuid
+        assert column.uuid == new_column.uuid


@@ -718,6 +718,7 @@ def test_update_physical_sqlatable_columns(
         metrics=[],
         database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
     )
     session.add(sqla_table)
     session.flush()
@@ -735,8 +736,11 @@
     assert session.query(Column).count() == 3
     dataset = session.query(Dataset).one()
     assert len(dataset.columns) == 2
-    for table_column, dataset_column in zip(sqla_table.columns, dataset.columns):
-        assert table_column.uuid == dataset_column.uuid
+    # check that both lists have the same uuids
+    assert [col.uuid for col in sqla_table.columns].sort() == [
+        col.uuid for col in dataset.columns
+    ].sort()

     # delete the column in the original instance
     sqla_table.columns = sqla_table.columns[1:]