diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 57730cc711..61d708f021 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -439,6 +439,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin): self, known_columns: Optional[Dict[str, NewColumn]] = None ) -> NewColumn: """Convert a TableColumn to NewColumn""" + session: Session = inspect(self).session column = known_columns.get(self.uuid) if known_columns else None if not column: column = NewColumn() @@ -452,6 +453,21 @@ class TableColumn(Model, BaseColumn, CertificationMixin): if value: extra_json[attr] = value + if not column.id: + with session.no_autoflush: + saved_column = ( + session.query(NewColumn).filter_by(uuid=self.uuid).one_or_none() + ) + if saved_column: + logger.warning( + "sl_column already exists. Assigning existing id %s", self + ) + + # uuid isn't a primary key, so add the id of the existing column to + # ensure that the column is modified instead of created + # in order to avoid a uuid collision + column.id = saved_column.id + column.uuid = self.uuid column.created_on = self.created_on column.changed_on = self.changed_on @@ -555,6 +571,7 @@ class SqlMetric(Model, BaseMetric, CertificationMixin): ) -> NewColumn: """Convert a SqlMetric to NewColumn. Find and update existing or create a new one.""" + session: Session = inspect(self).session column = known_columns.get(self.uuid) if known_columns else None if not column: column = NewColumn() @@ -568,6 +585,20 @@ class SqlMetric(Model, BaseMetric, CertificationMixin): self.metric_type and self.metric_type.lower() in ADDITIVE_METRIC_TYPES_LOWER ) + if not column.id: + with session.no_autoflush: + saved_column = ( + session.query(NewColumn).filter_by(uuid=self.uuid).one_or_none() + ) + if saved_column: + logger.warning( + "sl_column already exists. Assigning existing id %s", self + ) + # uuid isn't a primary key, so add the id of the existing column to + # ensure that the column is modified instead of created + # in order to avoid a uuid collision + column.id = saved_column.id + column.uuid = self.uuid column.name = self.metric_name column.created_on = self.created_on @@ -2149,10 +2180,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho uuids.remove(column.uuid) if uuids: - # load those not found from db - existing_columns |= set( - session.query(NewColumn).filter(NewColumn.uuid.in_(uuids)) - ) + with session.no_autoflush: + # load those not found from db + existing_columns |= set( + session.query(NewColumn).filter(NewColumn.uuid.in_(uuids)) + ) known_columns = {column.uuid: column for column in existing_columns} return [ @@ -2192,9 +2224,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho # update changed_on timestamp session.execute(update(NewDataset).where(NewDataset.id == dataset.id)) try: - column = session.query(NewColumn).filter_by(uuid=target.uuid).one() - # update `Column` model as well - session.merge(target.to_sl_column({target.uuid: column})) + with session.no_autoflush: + column = session.query(NewColumn).filter_by(uuid=target.uuid).one() + # update `Column` model as well + session.merge(target.to_sl_column({target.uuid: column})) except NoResultFound: logger.warning("No column was found for %s", target) # see if the column is in cache @@ -2204,14 +2237,15 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho ), None, ) + if column: + logger.warning("New column was found in cache: %s", column) - if not column: + else: # to be safe, use a different uuid and create a new column uuid = uuid4() target.uuid = uuid - column = NewColumn(uuid=uuid) - session.add(target.to_sl_column({column.uuid: column})) + session.add(target.to_sl_column()) @staticmethod def after_insert( diff --git a/tests/integration_tests/datasets/model_tests.py b/tests/integration_tests/datasets/model_tests.py index 31abce5494..3bcc4c0793 100644 --- a/tests/integration_tests/datasets/model_tests.py +++ b/tests/integration_tests/datasets/model_tests.py @@ -17,8 +17,10 @@ from unittest import mock import pytest +from sqlalchemy import inspect from sqlalchemy.orm.exc import NoResultFound +from superset.columns.models import Column from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.extensions import db from tests.integration_tests.base_tests import SupersetTestCase @@ -59,6 +61,10 @@ class SqlaTableModelTest(SupersetTestCase): with mock.patch("sqlalchemy.orm.query.Query.one", side_effect=NoResultFound): SqlaTable.update_column(None, None, target=column) + session = inspect(column).session + + session.flush() + # refetch dataset = db.session.query(SqlaTable).filter_by(id=dataset.id).one() # it should create a new uuid @@ -67,3 +73,15 @@ class SqlaTableModelTest(SupersetTestCase): # reset column.uuid = column_uuid SqlaTable.update_column(None, None, target=column) + + @pytest.mark.usefixtures("load_dataset_with_columns") + def test_to_sl_column_no_known_columns(self) -> None: + """ + Test that the function returns a new column + """ + dataset = db.session.query(SqlaTable).filter_by(table_name="students").first() + column = dataset.columns[0] + new_column = column.to_sl_column() + + # it should use the same uuid + assert column.uuid == new_column.uuid diff --git a/tests/unit_tests/datasets/test_models.py b/tests/unit_tests/datasets/test_models.py index cf3dccefe6..cacaef5ef5 100644 --- a/tests/unit_tests/datasets/test_models.py +++ b/tests/unit_tests/datasets/test_models.py @@ -718,6 +718,7 @@ def test_update_physical_sqlatable_columns( metrics=[], database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), ) + session.add(sqla_table) session.flush() @@ -735,8 +736,11 @@ def test_update_physical_sqlatable_columns( assert session.query(Column).count() == 3 dataset = session.query(Dataset).one() assert len(dataset.columns) == 2 - for table_column, dataset_column in zip(sqla_table.columns, dataset.columns): - assert table_column.uuid == dataset_column.uuid + + # check that both lists have the same uuids + assert [col.uuid for col in sqla_table.columns].sort() == [ + col.uuid for col in dataset.columns + ].sort() # delete the column in the original instance sqla_table.columns = sqla_table.columns[1:]