remove autoflush for queries during dual write (#20460)

This commit is contained in:
Elizabeth Thompson 2022-06-23 14:50:30 -07:00 committed by GitHub
parent 661ab35bd0
commit 44f0b511dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 68 additions and 12 deletions

View File

@ -439,6 +439,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
self, known_columns: Optional[Dict[str, NewColumn]] = None
) -> NewColumn:
"""Convert a TableColumn to NewColumn"""
session: Session = inspect(self).session
column = known_columns.get(self.uuid) if known_columns else None
if not column:
column = NewColumn()
@ -452,6 +453,21 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
if value:
extra_json[attr] = value
if not column.id:
with session.no_autoflush:
saved_column = (
session.query(NewColumn).filter_by(uuid=self.uuid).one_or_none()
)
if saved_column:
logger.warning(
"sl_column already exists. Assigning existing id %s", self
)
# uuid isn't a primary key, so add the id of the existing column to
# ensure that the column is modified instead of created
# in order to avoid a uuid collision
column.id = saved_column.id
column.uuid = self.uuid
column.created_on = self.created_on
column.changed_on = self.changed_on
@ -555,6 +571,7 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
) -> NewColumn:
"""Convert a SqlMetric to NewColumn. Find and update existing or
create a new one."""
session: Session = inspect(self).session
column = known_columns.get(self.uuid) if known_columns else None
if not column:
column = NewColumn()
@ -568,6 +585,20 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
self.metric_type and self.metric_type.lower() in ADDITIVE_METRIC_TYPES_LOWER
)
if not column.id:
with session.no_autoflush:
saved_column = (
session.query(NewColumn).filter_by(uuid=self.uuid).one_or_none()
)
if saved_column:
logger.warning(
"sl_column already exists. Assigning existing id %s", self
)
# uuid isn't a primary key, so add the id of the existing column to
# ensure that the column is modified instead of created
# in order to avoid a uuid collision
column.id = saved_column.id
column.uuid = self.uuid
column.name = self.metric_name
column.created_on = self.created_on
@ -2149,10 +2180,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
uuids.remove(column.uuid)
if uuids:
# load those not found from db
existing_columns |= set(
session.query(NewColumn).filter(NewColumn.uuid.in_(uuids))
)
with session.no_autoflush:
# load those not found from db
existing_columns |= set(
session.query(NewColumn).filter(NewColumn.uuid.in_(uuids))
)
known_columns = {column.uuid: column for column in existing_columns}
return [
@ -2192,9 +2224,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
# update changed_on timestamp
session.execute(update(NewDataset).where(NewDataset.id == dataset.id))
try:
column = session.query(NewColumn).filter_by(uuid=target.uuid).one()
# update `Column` model as well
session.merge(target.to_sl_column({target.uuid: column}))
with session.no_autoflush:
column = session.query(NewColumn).filter_by(uuid=target.uuid).one()
# update `Column` model as well
session.merge(target.to_sl_column({target.uuid: column}))
except NoResultFound:
logger.warning("No column was found for %s", target)
# see if the column is in cache
@ -2204,14 +2237,15 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
),
None,
)
if column:
logger.warning("New column was found in cache: %s", column)
if not column:
else:
# to be safe, use a different uuid and create a new column
uuid = uuid4()
target.uuid = uuid
column = NewColumn(uuid=uuid)
session.add(target.to_sl_column({column.uuid: column}))
session.add(target.to_sl_column())
@staticmethod
def after_insert(

View File

@ -17,8 +17,10 @@
from unittest import mock
import pytest
from sqlalchemy import inspect
from sqlalchemy.orm.exc import NoResultFound
from superset.columns.models import Column
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.extensions import db
from tests.integration_tests.base_tests import SupersetTestCase
@ -59,6 +61,10 @@ class SqlaTableModelTest(SupersetTestCase):
with mock.patch("sqlalchemy.orm.query.Query.one", side_effect=NoResultFound):
SqlaTable.update_column(None, None, target=column)
session = inspect(column).session
session.flush()
# refetch
dataset = db.session.query(SqlaTable).filter_by(id=dataset.id).one()
# it should create a new uuid
@ -67,3 +73,15 @@ class SqlaTableModelTest(SupersetTestCase):
# reset
column.uuid = column_uuid
SqlaTable.update_column(None, None, target=column)
@pytest.mark.usefixtures("load_dataset_with_columns")
def test_to_sl_column_no_known_columns(self) -> None:
"""
Test that the function returns a new column
"""
dataset = db.session.query(SqlaTable).filter_by(table_name="students").first()
column = dataset.columns[0]
new_column = column.to_sl_column()
# it should use the same uuid
assert column.uuid == new_column.uuid

View File

@ -718,6 +718,7 @@ def test_update_physical_sqlatable_columns(
metrics=[],
database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
)
session.add(sqla_table)
session.flush()
@ -735,8 +736,11 @@ def test_update_physical_sqlatable_columns(
assert session.query(Column).count() == 3
dataset = session.query(Dataset).one()
assert len(dataset.columns) == 2
for table_column, dataset_column in zip(sqla_table.columns, dataset.columns):
assert table_column.uuid == dataset_column.uuid
# check that both lists have the same uuids
assert [col.uuid for col in sqla_table.columns].sort() == [
col.uuid for col in dataset.columns
].sort()
# delete the column in the original instance
sqla_table.columns = sqla_table.columns[1:]