diff --git a/pyproject.toml b/pyproject.toml
index d9aeee440b..2a778f4ab9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -239,8 +239,8 @@ basepython = python3.10
 ignore_basepython_conflict = true
 commands =
     superset db upgrade
-    superset load_test_users
     superset init
+    superset load-test-users
     # use -s to be able to use break pointers.
     # no args or tests/* can be passed as an argument to run all tests
     pytest -s {posargs}
diff --git a/scripts/permissions_cleanup.py b/scripts/permissions_cleanup.py
index c80ef231b3..0b1980b146 100644
--- a/scripts/permissions_cleanup.py
+++ b/scripts/permissions_cleanup.py
@@ -14,7 +14,6 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=consider-using-transaction
 from collections import defaultdict
 
 from superset import db, security_manager
diff --git a/scripts/python_tests.sh b/scripts/python_tests.sh
index 443b1d5d61..e127d0c020 100755
--- a/scripts/python_tests.sh
+++ b/scripts/python_tests.sh
@@ -28,9 +28,9 @@ export SUPERSET_TESTENV=true
 
 echo "Superset config module: $SUPERSET_CONFIG"
 superset db upgrade
-superset load_test_users
 superset init
+superset load-test-users
 
 echo "Running tests"
 
-pytest --durations-min=2 --cov-report= --cov=superset ./tests/integration_tests "$@"
+pytest --durations-min=2 --maxfail=1 --cov-report= --cov=superset ./tests/integration_tests "$@"
diff --git a/superset/cli/test.py b/superset/cli/test.py
index 33b777b1ef..60ea532cbd 100755
--- a/superset/cli/test.py
+++ b/superset/cli/test.py
@@ -37,15 +37,7 @@ def load_test_users() -> None:
     Syncs permissions for those users/roles
     """
     print(Fore.GREEN + "Loading a set of users for unit tests")
-    load_test_users_run()
-
-def load_test_users_run() -> None:
-    """
-    Loads admin, alpha, and gamma user for testing purposes
-
-    Syncs permissions for those users/roles
-    """
     if app.config["TESTING"]:
         sm = security_manager
diff --git a/superset/commands/chart/importers/v1/utils.py b/superset/commands/chart/importers/v1/utils.py
index 39ca49a5d5..35a7f6e270 100644
--- a/superset/commands/chart/importers/v1/utils.py
+++ b/superset/commands/chart/importers/v1/utils.py
@@ -77,7 +77,7 @@ def import_chart(
     if chart.id is None:
         db.session.flush()
 
-    if user := get_user():
+    if (user := get_user()) and user not in chart.owners:
         chart.owners.append(user)
 
     return chart
diff --git a/superset/commands/chart/update.py b/superset/commands/chart/update.py
index 1ea698ba0d..d6b212d5ce 100644
--- a/superset/commands/chart/update.py
+++ b/superset/commands/chart/update.py
@@ -38,8 +38,8 @@ from superset.daos.chart import ChartDAO
 from superset.daos.dashboard import DashboardDAO
 from superset.exceptions import SupersetSecurityException
 from superset.models.slice import Slice
-from superset.utils.decorators import on_error, transaction
 from superset.tags.models import ObjectType
+from superset.utils.decorators import on_error, transaction
 
 logger = logging.getLogger(__name__)
 
@@ -62,14 +62,13 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
         assert self._model
 
         # Update tags
-        tags = self._properties.pop("tags", None)
-        if tags is not None:
+        if (tags := self._properties.pop("tags", None)) is not None:
             update_tags(ObjectType.chart, self._model.id, self._model.tags, tags)
 
         if self._properties.get("query_context_generation") is None:
             self._properties["last_saved_at"] = datetime.now()
             self._properties["last_saved_by"] = g.user
-
+
         return ChartDAO.update(self._model, self._properties)
 
     def validate(self)
-> None: diff --git a/superset/commands/dashboard/importers/v1/utils.py b/superset/commands/dashboard/importers/v1/utils.py index f10afd12bc..5e949093b8 100644 --- a/superset/commands/dashboard/importers/v1/utils.py +++ b/superset/commands/dashboard/importers/v1/utils.py @@ -188,7 +188,7 @@ def import_dashboard( if dashboard.id is None: db.session.flush() - if user := get_user(): + if (user := get_user()) and user not in dashboard.owners: dashboard.owners.append(user) return dashboard diff --git a/superset/commands/dashboard/permalink/create.py b/superset/commands/dashboard/permalink/create.py index f6bff344c8..7d08f78e9a 100644 --- a/superset/commands/dashboard/permalink/create.py +++ b/superset/commands/dashboard/permalink/create.py @@ -19,7 +19,6 @@ from functools import partial from sqlalchemy.exc import SQLAlchemyError -from superset import db from superset.commands.dashboard.permalink.base import BaseDashboardPermalinkCommand from superset.commands.key_value.upsert import UpsertKeyValueCommand from superset.daos.dashboard import DashboardDAO @@ -78,6 +77,7 @@ class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand): codec=self.codec, ).run() assert key.id # for type checks + return encode_permalink_key(key=key.id, salt=self.salt) def validate(self) -> None: pass diff --git a/superset/commands/dashboard/update.py b/superset/commands/dashboard/update.py index 5294d049ec..2effd7bd2e 100644 --- a/superset/commands/dashboard/update.py +++ b/superset/commands/dashboard/update.py @@ -53,19 +53,16 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand): assert self._model # Update tags - tags = self._properties.pop("tags", None) - if tags is not None: - update_tags( - ObjectType.dashboard, self._model.id, self._model.tags, tags - ) + if (tags := self._properties.pop("tags", None)) is not None: + update_tags(ObjectType.dashboard, self._model.id, self._model.tags, tags) - dashboard = DashboardDAO.update(self._model, self._properties, commit=False) + dashboard = DashboardDAO.update(self._model, self._properties) if self._properties.get("json_metadata"): DashboardDAO.set_dash_metadata( dashboard, data=json.loads(self._properties.get("json_metadata", "{}")), ) - + return dashboard def validate(self) -> None: diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index 842f69a7ab..76dd6087be 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -40,7 +40,7 @@ from superset.commands.database.ssh_tunnel.exceptions import ( ) from superset.commands.database.test_connection import TestConnectionDatabaseCommand from superset.daos.database import DatabaseDAO -from superset.daos.exceptions import DAOCreateFailedError +from superset.databases.ssh_tunnel.models import SSHTunnel from superset.exceptions import SupersetErrorsException from superset.extensions import event_logger, security_manager from superset.models.core import Database diff --git a/superset/commands/database/csv_import.py b/superset/commands/database/csv_import.py deleted file mode 100644 index 3354a81a4d..0000000000 --- a/superset/commands/database/csv_import.py +++ /dev/null @@ -1,194 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import logging -from functools import partial -from typing import Any, Optional, TypedDict - -import pandas as pd -from flask_babel import lazy_gettext as _ - -from superset import db -from superset.commands.base import BaseCommand -from superset.commands.database.exceptions import ( - DatabaseNotFoundError, - DatabaseSchemaUploadNotAllowed, - DatabaseUploadFailed, - DatabaseUploadSaveMetadataFailed, -) -from superset.connectors.sqla.models import SqlaTable -from superset.daos.database import DatabaseDAO -from superset.models.core import Database -from superset.sql_parse import Table -from superset.utils.core import get_user -from superset.utils.decorators import on_error, transaction -from superset.views.database.validators import schema_allows_file_upload - -logger = logging.getLogger(__name__) - -READ_CSV_CHUNK_SIZE = 1000 - - -class CSVImportOptions(TypedDict, total=False): - schema: str - delimiter: str - already_exists: str - column_data_types: dict[str, str] - column_dates: list[str] - column_labels: str - columns_read: list[str] - dataframe_index: str - day_first: bool - decimal_character: str - header_row: int - index_column: str - null_values: list[str] - overwrite_duplicates: bool - rows_to_read: int - skip_blank_lines: bool - skip_initial_space: bool - skip_rows: int - - -class CSVImportCommand(BaseCommand): - def __init__( - self, - model_id: int, - table_name: str, - file: Any, - options: CSVImportOptions, - ) -> None: - self._model_id = model_id - self._model: Optional[Database] = None - self._table_name = table_name - self._schema = options.get("schema") - self._file = file - self._options = options - - def _read_csv(self) -> pd.DataFrame: - """ - Read CSV file into a DataFrame - - :return: pandas DataFrame - :throws DatabaseUploadFailed: if there is an error reading the CSV file - """ - try: - return pd.concat( - pd.read_csv( - chunksize=READ_CSV_CHUNK_SIZE, - encoding="utf-8", - filepath_or_buffer=self._file, - header=self._options.get("header_row", 0), - index_col=self._options.get("index_column"), - dayfirst=self._options.get("day_first", False), - iterator=True, - keep_default_na=not self._options.get("null_values"), - usecols=self._options.get("columns_read") - if self._options.get("columns_read") # None if an empty list - else None, - na_values=self._options.get("null_values") - if self._options.get("null_values") # None if an empty list - else None, - nrows=self._options.get("rows_to_read"), - parse_dates=self._options.get("column_dates"), - sep=self._options.get("delimiter", ","), - skip_blank_lines=self._options.get("skip_blank_lines", False), - skipinitialspace=self._options.get("skip_initial_space", False), - skiprows=self._options.get("skip_rows", 0), - dtype=self._options.get("column_data_types") - if self._options.get("column_data_types") - else None, - ) - ) - except ( - pd.errors.ParserError, - pd.errors.EmptyDataError, - UnicodeDecodeError, - ValueError, - ) as ex: - raise 
DatabaseUploadFailed( - message=_("Parsing error: %(error)s", error=str(ex)) - ) from ex - except Exception as ex: - raise DatabaseUploadFailed(_("Error reading CSV file")) from ex - - def _dataframe_to_database(self, df: pd.DataFrame, database: Database) -> None: - """ - Upload DataFrame to database - - :param df: - :throws DatabaseUploadFailed: if there is an error uploading the DataFrame - """ - try: - csv_table = Table(table=self._table_name, schema=self._schema) - database.db_engine_spec.df_to_sql( - database, - csv_table, - df, - to_sql_kwargs={ - "chunksize": READ_CSV_CHUNK_SIZE, - "if_exists": self._options.get("already_exists", "fail"), - "index": self._options.get("index_column"), - "index_label": self._options.get("column_labels"), - }, - ) - except ValueError as ex: - raise DatabaseUploadFailed( - message=_( - "Table already exists. You can change your " - "'if table already exists' strategy to append or " - "replace or provide a different Table Name to use." - ) - ) from ex - except Exception as ex: - raise DatabaseUploadFailed(exception=ex) from ex - - @transaction(on_error=partial(on_error, reraise=DatabaseUploadSaveMetadataFailed)) - def run(self) -> None: - self.validate() - if not self._model: - return - - df = self._read_csv() - self._dataframe_to_database(df, self._model) - - sqla_table = ( - db.session.query(SqlaTable) - .filter_by( - table_name=self._table_name, - schema=self._schema, - database_id=self._model_id, - ) - .one_or_none() - ) - if not sqla_table: - sqla_table = SqlaTable( - table_name=self._table_name, - database=self._model, - database_id=self._model_id, - owners=[get_user()], - schema=self._schema, - ) - db.session.add(sqla_table) - - sqla_table.fetch_metadata() - - def validate(self) -> None: - self._model = DatabaseDAO.find_by_id(self._model_id) - if not self._model: - raise DatabaseNotFoundError() - if not schema_allows_file_upload(self._model, self._schema): - raise DatabaseSchemaUploadNotAllowed() diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py index 1fd7d786dc..cc8046b889 100644 --- a/superset/commands/database/update.py +++ b/superset/commands/database/update.py @@ -41,7 +41,6 @@ from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand from superset.daos.database import DatabaseDAO from superset.daos.dataset import DatasetDAO from superset.databases.ssh_tunnel.models import SSHTunnel -from superset.extensions import db from superset.models.core import Database from superset.utils.decorators import on_error, transaction @@ -78,11 +77,7 @@ class UpdateDatabaseCommand(BaseCommand): original_database_name = self._model.database_name try: - database = DatabaseDAO.update( - self._model, - self._properties, - commit=False, - ) + database = DatabaseDAO.update(self._model, self._properties) database.set_sqlalchemy_uri(database.sqlalchemy_uri) ssh_tunnel = self._handle_ssh_tunnel(database) self._refresh_catalogs(database, original_database_name, ssh_tunnel) @@ -100,7 +95,6 @@ class UpdateDatabaseCommand(BaseCommand): return None if not is_feature_enabled("SSH_TUNNELING"): - db.session.rollback() raise SSHTunnelingNotEnabledError() current_ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id) @@ -130,14 +124,11 @@ class UpdateDatabaseCommand(BaseCommand): This method captures a generic exception, since errors could potentially come from any of the 50+ database drivers we support. 
""" - try: - return database.get_all_catalog_names( - force=True, - ssh_tunnel=ssh_tunnel, - ) - except Exception as ex: - db.session.rollback() - raise DatabaseConnectionFailedError() from ex + + return database.get_all_catalog_names( + force=True, + ssh_tunnel=ssh_tunnel, + ) def _get_schema_names( self, @@ -151,15 +142,12 @@ class UpdateDatabaseCommand(BaseCommand): This method captures a generic exception, since errors could potentially come from any of the 50+ database drivers we support. """ - try: - return database.get_all_schema_names( - force=True, - catalog=catalog, - ssh_tunnel=ssh_tunnel, - ) - except Exception as ex: - db.session.rollback() - raise DatabaseConnectionFailedError() from ex + + return database.get_all_schema_names( + force=True, + catalog=catalog, + ssh_tunnel=ssh_tunnel, + ) def _refresh_catalogs( self, @@ -224,8 +212,6 @@ class UpdateDatabaseCommand(BaseCommand): schemas, ) - db.session.commit() - def _refresh_schemas( self, database: Database, diff --git a/superset/commands/dataset/columns/delete.py b/superset/commands/dataset/columns/delete.py index 1fb2863b1b..821528de74 100644 --- a/superset/commands/dataset/columns/delete.py +++ b/superset/commands/dataset/columns/delete.py @@ -43,7 +43,7 @@ class DeleteDatasetColumnCommand(BaseCommand): def run(self) -> None: self.validate() assert self._model - DatasetColumnDAO.delete(self._model) + DatasetColumnDAO.delete([self._model]) def validate(self) -> None: # Validate/populate model exists diff --git a/superset/commands/dataset/create.py b/superset/commands/dataset/create.py index 28983e74f8..a2d81e548b 100644 --- a/superset/commands/dataset/create.py +++ b/superset/commands/dataset/create.py @@ -32,7 +32,7 @@ from superset.commands.dataset.exceptions import ( ) from superset.daos.dataset import DatasetDAO from superset.exceptions import SupersetSecurityException -from superset.extensions import db, security_manager +from superset.extensions import security_manager from superset.sql_parse import Table from superset.utils.decorators import on_error, transaction diff --git a/superset/commands/dataset/importers/v1/utils.py b/superset/commands/dataset/importers/v1/utils.py index da39be4721..1c508fe252 100644 --- a/superset/commands/dataset/importers/v1/utils.py +++ b/superset/commands/dataset/importers/v1/utils.py @@ -178,7 +178,7 @@ def import_dataset( if data_uri and (not table_exists or force_data): load_data(data_uri, dataset, dataset.database) - if user := get_user(): + if (user := get_user()) and user not in dataset.owners: dataset.owners.append(user) return dataset diff --git a/superset/commands/dataset/metrics/delete.py b/superset/commands/dataset/metrics/delete.py index e4d65236c3..0a749295dc 100644 --- a/superset/commands/dataset/metrics/delete.py +++ b/superset/commands/dataset/metrics/delete.py @@ -43,7 +43,7 @@ class DeleteDatasetMetricCommand(BaseCommand): def run(self) -> None: self.validate() assert self._model - DatasetMetricDAO.delete(self._model) + DatasetMetricDAO.delete([self._model]) def validate(self) -> None: # Validate/populate model exists diff --git a/superset/commands/dataset/update.py b/superset/commands/dataset/update.py index 40b1bf18ba..14d1c5ef44 100644 --- a/superset/commands/dataset/update.py +++ b/superset/commands/dataset/update.py @@ -21,6 +21,7 @@ from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError +from sqlalchemy.exc import SQLAlchemyError from superset import security_manager from superset.commands.base 
import BaseCommand, UpdateMixin @@ -60,7 +61,16 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand): self.override_columns = override_columns self._properties["override_columns"] = override_columns - @transaction(on_error=partial(on_error, reraise=DatasetUpdateFailedError)) + @transaction( + on_error=partial( + on_error, + catches=( + SQLAlchemyError, + ValueError, + ), + reraise=DatasetUpdateFailedError, + ) + ) def run(self) -> Model: self.validate() assert self._model diff --git a/superset/commands/explore/permalink/create.py b/superset/commands/explore/permalink/create.py index 03efdc584a..2128fa4b8c 100644 --- a/superset/commands/explore/permalink/create.py +++ b/superset/commands/explore/permalink/create.py @@ -20,7 +20,6 @@ from typing import Any, Optional from sqlalchemy.exc import SQLAlchemyError -from superset import db from superset.commands.explore.permalink.base import BaseExplorePermalinkCommand from superset.commands.key_value.create import CreateKeyValueCommand from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError @@ -74,27 +73,7 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand): key = command.run() if key.id is None: raise ExplorePermalinkCreateFailedError("Unexpected missing key id") - db.session.commit() return encode_permalink_key(key=key.id, salt=self.salt) - d_id, d_type = self.datasource.split("__") - datasource_id = int(d_id) - datasource_type = DatasourceType(d_type) - check_chart_access(datasource_id, self.chart_id, datasource_type) - value = { - "chartId": self.chart_id, - "datasourceId": datasource_id, - "datasourceType": datasource_type.value, - "datasource": self.datasource, - "state": self.state, - } - command = CreateKeyValueCommand( - resource=self.resource, - value=value, - codec=self.codec, - ) - key = command.run() - return encode_permalink_key(key=key.id, salt=self.salt) ->>>>>>> c01dacb71a (chore(dao): Use nested session for operations) def validate(self) -> None: pass diff --git a/superset/commands/importers/v1/assets.py b/superset/commands/importers/v1/assets.py index 5f955db3bf..78a2251a29 100644 --- a/superset/commands/importers/v1/assets.py +++ b/superset/commands/importers/v1/assets.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from functools import partial from typing import Any, Optional from marshmallow import Schema @@ -44,7 +45,7 @@ from superset.datasets.schemas import ImportV1DatasetSchema from superset.migrations.shared.native_filters import migrate_dashboard from superset.models.dashboard import dashboard_slices from superset.queries.saved_queries.schemas import ImportV1SavedQuerySchema -from superset.utils.decorators import transaction +from superset.utils.decorators import on_error, transaction class ImportAssetsCommand(BaseCommand): @@ -154,14 +155,16 @@ class ImportAssetsCommand(BaseCommand): if chart.viz_type == "filter_box": db.session.delete(chart) - @transaction() + @transaction( + on_error=partial( + on_error, + catches=(Exception,), + reraise=ImportFailedError, + ) + ) def run(self) -> None: self.validate() - - try: - self._import(self._configs) - except Exception as ex: - raise ImportFailedError() from ex + self._import(self._configs) def validate(self) -> None: exceptions: list[ValidationError] = [] diff --git a/superset/commands/key_value/delete.py b/superset/commands/key_value/delete.py index ec386675b5..a3fdf079c7 100644 --- a/superset/commands/key_value/delete.py +++ b/superset/commands/key_value/delete.py @@ -53,9 +53,11 @@ class DeleteKeyValueCommand(BaseCommand): pass def delete(self) -> bool: - filter_ = get_filter(self.resource, self.key) - if entry := db.session.query(KeyValueEntry).filter_by(**filter_).first(): + if ( + entry := db.session.query(KeyValueEntry) + .filter_by(**get_filter(self.resource, self.key)) + .first() + ): db.session.delete(entry) - db.session.flush() return True return False diff --git a/superset/commands/key_value/delete_expired.py b/superset/commands/key_value/delete_expired.py index e3e75bf0f1..54991c7531 100644 --- a/superset/commands/key_value/delete_expired.py +++ b/superset/commands/key_value/delete_expired.py @@ -60,4 +60,3 @@ class DeleteExpiredKeyValueCommand(BaseCommand): ) .delete() ) - db.session.flush() diff --git a/superset/commands/key_value/upsert.py b/superset/commands/key_value/upsert.py index e5c6eb7425..32918d9b14 100644 --- a/superset/commands/key_value/upsert.py +++ b/superset/commands/key_value/upsert.py @@ -21,6 +21,8 @@ from functools import partial from typing import Any, Optional, Union from uuid import UUID +from sqlalchemy.exc import SQLAlchemyError + from superset import db from superset.commands.base import BaseCommand from superset.commands.key_value.create import CreateKeyValueCommand @@ -71,7 +73,7 @@ class UpsertKeyValueCommand(BaseCommand): @transaction( on_error=partial( on_error, - catches=(KeyValueCreateFailedError,), + catches=(KeyValueCreateFailedError, SQLAlchemyError), reraise=KeyValueUpsertFailedError, ), ) @@ -82,16 +84,15 @@ class UpsertKeyValueCommand(BaseCommand): pass def upsert(self) -> Key: - filter_ = get_filter(self.resource, self.key) - entry: KeyValueEntry = ( - db.session.query(KeyValueEntry).filter_by(**filter_).first() - ) - if entry: + if ( + entry := db.session.query(KeyValueEntry) + .filter_by(**get_filter(self.resource, self.key)) + .first() + ): entry.value = self.codec.encode(self.value) entry.expires_on = self.expires_on entry.changed_on = datetime.now() entry.changed_by_fk = get_user_id() - db.session.flush() return Key(entry.id, entry.uuid) return CreateKeyValueCommand( diff --git a/superset/commands/report/execute.py b/superset/commands/report/execute.py index 000c87f514..c57828eac4 100644 --- a/superset/commands/report/execute.py +++ b/superset/commands/report/execute.py @@ -137,6 +137,7 @@ 
class BaseReportState: uuid=self._execution_id, ) db.session.add(log) + db.session.commit() # pylint: disable=consider-using-transaction def _get_url( self, diff --git a/superset/commands/report/log_prune.py b/superset/commands/report/log_prune.py index 493c16ed77..a780bf51e0 100644 --- a/superset/commands/report/log_prune.py +++ b/superset/commands/report/log_prune.py @@ -49,7 +49,6 @@ class AsyncPruneReportScheduleLogCommand(BaseCommand): report_schedule, from_date, ) - db.session.commit() logger.info( "Deleted %s logs for report schedule id: %s", str(row_count), diff --git a/superset/daos/tag.py b/superset/daos/tag.py index 98c83dbe8e..b155cf15c1 100644 --- a/superset/daos/tag.py +++ b/superset/daos/tag.py @@ -383,5 +383,4 @@ class TagDAO(BaseDAO[Tag]): object_id, tag.name, ) - db.session.add_all(tagged_objects) diff --git a/superset/dashboards/api.py b/superset/dashboards/api.py index 3fe557a684..823bfdfa8c 100644 --- a/superset/dashboards/api.py +++ b/superset/dashboards/api.py @@ -32,7 +32,7 @@ from marshmallow import ValidationError from werkzeug.wrappers import Response as WerkzeugResponse from werkzeug.wsgi import FileWrapper -from superset import is_feature_enabled, thumbnail_cache +from superset import db, is_feature_enabled, thumbnail_cache from superset.charts.schemas import ChartEntityResponseSchema from superset.commands.dashboard.create import CreateDashboardCommand from superset.commands.dashboard.delete import DeleteDashboardCommand @@ -1314,7 +1314,13 @@ class DashboardRestApi(BaseSupersetModelRestApi): """ try: body = self.embedded_config_schema.load(request.json) - embedded = EmbeddedDashboardDAO.upsert(dashboard, body["allowed_domains"]) + + with db.session.begin_nested(): + embedded = EmbeddedDashboardDAO.upsert( + dashboard, + body["allowed_domains"], + ) + result = self.embedded_response_schema.dump(embedded) return self.response(200, result=result) except ValidationError as error: diff --git a/superset/extensions/metastore_cache.py b/superset/extensions/metastore_cache.py index 7b4e39677e..1c89e84597 100644 --- a/superset/extensions/metastore_cache.py +++ b/superset/extensions/metastore_cache.py @@ -22,7 +22,6 @@ from uuid import UUID, uuid3 from flask import current_app, Flask, has_app_context from flask_caching import BaseCache -from superset import db from superset.key_value.exceptions import KeyValueCreateFailedError from superset.key_value.types import ( KeyValueCodec, @@ -95,7 +94,6 @@ class SupersetMetastoreCache(BaseCache): codec=self.codec, expires_on=self._get_expiry(timeout), ).run() - db.session.commit() return True def add(self, key: str, value: Any, timeout: Optional[int] = None) -> bool: @@ -111,7 +109,6 @@ class SupersetMetastoreCache(BaseCache): key=self.get_key(key), expires_on=self._get_expiry(timeout), ).run() - db.session.commit() return True except KeyValueCreateFailedError: return False @@ -136,6 +133,4 @@ class SupersetMetastoreCache(BaseCache): # pylint: disable=import-outside-toplevel from superset.commands.key_value.delete import DeleteKeyValueCommand - ret = DeleteKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run() - db.session.commit() - return ret + return DeleteKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run() diff --git a/superset/key_value/shared_entries.py b/superset/key_value/shared_entries.py index f472838d2e..130313157a 100644 --- a/superset/key_value/shared_entries.py +++ b/superset/key_value/shared_entries.py @@ -18,7 +18,6 @@ from typing import Any, Optional from uuid import uuid3 -from superset 
import db from superset.key_value.types import JsonKeyValueCodec, KeyValueResource, SharedKey from superset.key_value.utils import get_uuid_namespace, random_key @@ -46,7 +45,6 @@ def set_shared_value(key: SharedKey, value: Any) -> None: key=uuid_key, codec=CODEC, ).run() - db.session.commit() def get_permalink_salt(key: SharedKey) -> str: diff --git a/superset/security/manager.py b/superset/security/manager.py index a807e33122..e79155354d 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -334,22 +334,26 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods pyjwt_for_guest_token = _jwt_global_obj @property - def get_session(self) -> Session: + def get_session2(self) -> Session: """ Flask-AppBuilder (FAB) which has a tendency to explicitly commit, thus violating our definition of "unit of work". - By providing a monkey patched transaction for the FAB session ensures that any - explicit commit merely flushes and any rollback is a no-op. + By providing a monkey patched transaction for the FAB session, within the + confines of a nested session, ensures that any explicit commit merely flushes + and any rollback is a no-op. """ # pylint: disable=import-outside-toplevel from superset import db - with db.session.begin_nested() as transaction: - transaction.session.commit = transaction.session.flush - transaction.session.rollback = lambda: None - return transaction.session + if db.session._proxied._nested_transaction: # pylint: disable=protected-access + with db.session.begin_nested() as transaction: + transaction.session.commit = transaction.session.flush + transaction.session.rollback = lambda: None + return transaction.session + + return db.session def create_login_manager(self, app: Flask) -> LoginManager: lm = super().create_login_manager(app) @@ -1035,7 +1039,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods == None, # noqa: E711 ) ) - self.get_session.flush() if deleted_count := pvms.delete(): logger.info("Deleted %i faulty permissions", deleted_count) @@ -1065,7 +1068,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods ) self.create_missing_perms() - self.get_session.flush() self.clean_perms() def _get_all_pvms(self) -> list[PermissionView]: @@ -1138,7 +1140,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods ): role_from_permissions.append(permission_view) role_to.permissions = role_from_permissions - self.get_session.flush() def set_role( self, @@ -1159,7 +1160,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods permission_view for permission_view in pvms if pvm_check(permission_view) ] role.permissions = role_pvms - self.get_session.flush() def _is_admin_only(self, pvm: PermissionView) -> bool: """ diff --git a/superset/tags/models.py b/superset/tags/models.py index 8c3e53b314..31975c3e8e 100644 --- a/superset/tags/models.py +++ b/superset/tags/models.py @@ -140,6 +140,7 @@ def get_tag( if tag is None: tag = Tag(name=escape(tag_name), type=type_) session.add(tag) + session.commit() return tag diff --git a/superset/utils/database.py b/superset/utils/database.py index 7ed3156502..719e7f2d77 100644 --- a/superset/utils/database.py +++ b/superset/utils/database.py @@ -59,6 +59,7 @@ def get_or_create_db( if database and database.sqlalchemy_uri_decrypted != sqlalchemy_uri: database.set_sqlalchemy_uri(sqlalchemy_uri) + db.session.flush() return database @@ -78,3 +79,4 @@ def remove_database(database: Database) -> None: from superset import 
db db.session.delete(database) + db.session.flush() diff --git a/superset/utils/decorators.py b/superset/utils/decorators.py index 30c668bab5..26b94ffaaa 100644 --- a/superset/utils/decorators.py +++ b/superset/utils/decorators.py @@ -222,7 +222,7 @@ def on_error( :param ex: The source exception :param catches: The exception types the handler catches - :param reraise: The exception type the handler reraises after catching + :param reraise: The exception type the handler raises after catching :raises Exception: If the exception is not swallowed """ @@ -252,13 +252,16 @@ def transaction( # pylint: disable=redefined-outer-name from superset import db # pylint: disable=import-outside-toplevel try: - with db.session.begin_nested(): - return func(*args, **kwargs) + result = func(*args, **kwargs) + db.session.commit() + return result except Exception as ex: + db.session.rollback() + if on_error: return on_error(ex) - raise ex + raise return wrapped diff --git a/superset/utils/lock.py b/superset/utils/lock.py index 3cd3c8ead5..4723b57fa1 100644 --- a/superset/utils/lock.py +++ b/superset/utils/lock.py @@ -24,7 +24,6 @@ from contextlib import contextmanager from datetime import datetime, timedelta from typing import Any, cast, TypeVar, Union -from superset import db from superset.exceptions import CreateKeyValueDistributedLockFailedException from superset.key_value.exceptions import KeyValueCreateFailedError from superset.key_value.types import JsonKeyValueCodec, KeyValueResource @@ -72,7 +71,6 @@ def KeyValueDistributedLock( # pylint: disable=invalid-name store. :param namespace: The namespace for which the lock is to be acquired. - :type namespace: str :param kwargs: Additional keyword arguments. :yields: A unique identifier (UUID) for the acquired lock (the KV key). :raises CreateKeyValueDistributedLockFailedException: If the lock is taken. @@ -93,12 +91,10 @@ def KeyValueDistributedLock( # pylint: disable=invalid-name value=True, expires_on=datetime.now() + LOCK_EXPIRATION, ).run() - db.session.commit() yield key DeleteKeyValueCommand(resource=KeyValueResource.LOCK, key=key).run() - db.session.commit() logger.debug("Removed lock on namespace %s for key %s", namespace, key) except KeyValueCreateFailedError as ex: raise CreateKeyValueDistributedLockFailedException( diff --git a/superset/utils/log.py b/superset/utils/log.py index 730bb7c43f..71c5528833 100644 --- a/superset/utils/log.py +++ b/superset/utils/log.py @@ -403,6 +403,7 @@ class DBEventLogger(AbstractEventLogger): logs.append(log) try: db.session.bulk_save_objects(logs) + db.session.commit() # pylint: disable=consider-using-transaction except SQLAlchemyError as ex: logging.error("DBEventLogger failed to log event(s)") logging.exception(ex) diff --git a/superset/views/database/views.py b/superset/views/database/views.py index d2ccd49ba5..019dc1138b 100644 --- a/superset/views/database/views.py +++ b/superset/views/database/views.py @@ -14,8 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# pylint: disable=consider-using-transaction -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING from flask_appbuilder import expose from flask_appbuilder.models.sqla.interface import SQLAInterface diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py index 3f4cab16ad..77633d6564 100644 --- a/tests/integration_tests/base_tests.py +++ b/tests/integration_tests/base_tests.py @@ -349,15 +349,7 @@ class SupersetTestCase(TestCase): self.grant_role_access_to_table(table, role_name) def grant_role_access_to_table(self, table, role_name): - print(">>> grant_role_access_to_table <<<") - print(role_name) - print(db.session.get_bind()) - from flask_appbuilder.security.sqla.models import Role - - print(list(db.session.query(Role).all())) role = security_manager.find_role(role_name) - print(role) - perms = db.session.query(ab_models.PermissionView).all() for perm in perms: if ( diff --git a/tests/integration_tests/celery_tests.py b/tests/integration_tests/celery_tests.py index 3bd82211e5..320497d7c5 100644 --- a/tests/integration_tests/celery_tests.py +++ b/tests/integration_tests/celery_tests.py @@ -332,6 +332,9 @@ def test_run_async_cta_query_with_lower_limit(test_client, ctas_method): tmp_table=tmp_table, ) query = wait_for_success(result) + print(">>> test_run_async_cta_query_with_lower_limit <<<") + print(result) + print(query.to_dict()) assert QueryStatus.SUCCESS == query.status sqlite_select_sql = f"SELECT\n *\nFROM {tmp_table}\nLIMIT {query.limit}\nOFFSET 0" diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index 480bd286ea..211aeb07f4 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -91,7 +91,7 @@ class TestCore(SupersetTestCase): self.original_unsafe_db_setting = app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] def tearDown(self): - # db.session.query(Query).delete() + db.session.query(Query).delete() app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = self.original_unsafe_db_setting super().tearDown() @@ -235,7 +235,6 @@ class TestCore(SupersetTestCase): ) for slc in slices: db.session.delete(slc) - print(db.session.dirty) db.session.commit() @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @@ -666,7 +665,6 @@ class TestCore(SupersetTestCase): client_id="client_id_1", username="admin", ) - print(resp) count_ds = [] count_name = [] for series in data["data"]: @@ -814,7 +812,7 @@ class TestCore(SupersetTestCase): mock_cache.return_value = MockCache() rv = self.client.get("/superset/explore_json/data/valid-cache-key") - self.assertEqual(rv.status_code, 401) + self.assertEqual(rv.status_code, 403) def test_explore_json_data_invalid_cache_key(self): self.login(ADMIN_USERNAME) diff --git a/tests/integration_tests/dashboard_utils.py b/tests/integration_tests/dashboard_utils.py index 3a0c3ef217..1b900ecbcc 100644 --- a/tests/integration_tests/dashboard_utils.py +++ b/tests/integration_tests/dashboard_utils.py @@ -97,5 +97,6 @@ def create_dashboard( if slices is not None: dash.slices = slices db.session.add(dash) + db.session.commit() return dash diff --git a/tests/integration_tests/dashboards/commands_tests.py b/tests/integration_tests/dashboards/commands_tests.py index 06edd6c6d0..334e0425cf 100644 --- a/tests/integration_tests/dashboards/commands_tests.py +++ b/tests/integration_tests/dashboards/commands_tests.py @@ -592,7 +592,6 @@ class TestImportDashboardsCommand(SupersetTestCase): } command = v1.ImportDashboardsCommand(contents, 
overwrite=True) command.run() - command.run() new_num_dashboards = db.session.query(Dashboard).count() assert new_num_dashboards == num_dashboards + 1 diff --git a/tests/integration_tests/dashboards/dao_tests.py b/tests/integration_tests/dashboards/dao_tests.py index eb9207423e..83ef02730b 100644 --- a/tests/integration_tests/dashboards/dao_tests.py +++ b/tests/integration_tests/dashboards/dao_tests.py @@ -48,15 +48,16 @@ class TestDashboardDAO(SupersetTestCase): assert changed_on == DashboardDAO.get_dashboard_changed_on("world_health") old_changed_on = dashboard.changed_on + # freezegun doesn't work for some reason, so we need to sleep here :( time.sleep(1) data = dashboard.data positions = data["position_json"] data.update({"positions": positions}) original_data = copy.deepcopy(data) + data.update({"foo": "bar"}) DashboardDAO.set_dash_metadata(dashboard, data) - db.session.flush() db.session.commit() new_changed_on = DashboardDAO.get_dashboard_changed_on(dashboard) assert old_changed_on.replace(microsecond=0) < new_changed_on @@ -68,7 +69,6 @@ class TestDashboardDAO(SupersetTestCase): ) DashboardDAO.set_dash_metadata(dashboard, original_data) - db.session.flush() db.session.commit() @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 05942ec22a..84d8a44066 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -26,6 +26,7 @@ import prison import pytest import yaml from sqlalchemy import inspect +from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import joinedload from sqlalchemy.sql import func diff --git a/tests/integration_tests/embedded/api_tests.py b/tests/integration_tests/embedded/api_tests.py index 533f1311d3..64afaa1784 100644 --- a/tests/integration_tests/embedded/api_tests.py +++ b/tests/integration_tests/embedded/api_tests.py @@ -44,6 +44,7 @@ class TestEmbeddedDashboardApi(SupersetTestCase): self.login(ADMIN_USERNAME) self.dash = db.session.query(Dashboard).filter_by(slug="births").first() self.embedded = EmbeddedDashboardDAO.upsert(self.dash, []) + db.session.flush() uri = f"api/v1/{self.resource_name}/{self.embedded.uuid}" response = self.client.get(uri) self.assert200(response) diff --git a/tests/integration_tests/embedded/dao_tests.py b/tests/integration_tests/embedded/dao_tests.py index e1f72feb89..eed161581f 100644 --- a/tests/integration_tests/embedded/dao_tests.py +++ b/tests/integration_tests/embedded/dao_tests.py @@ -34,17 +34,21 @@ class TestEmbeddedDashboardDAO(SupersetTestCase): dash = db.session.query(Dashboard).filter_by(slug="world_health").first() assert not dash.embedded EmbeddedDashboardDAO.upsert(dash, ["test.example.com"]) + db.session.flush() assert dash.embedded self.assertEqual(dash.embedded[0].allowed_domains, ["test.example.com"]) original_uuid = dash.embedded[0].uuid self.assertIsNotNone(original_uuid) EmbeddedDashboardDAO.upsert(dash, []) + db.session.flush() self.assertEqual(dash.embedded[0].allowed_domains, []) self.assertEqual(dash.embedded[0].uuid, original_uuid) @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") def test_get_by_uuid(self): dash = db.session.query(Dashboard).filter_by(slug="world_health").first() - uuid = str(EmbeddedDashboardDAO.upsert(dash, ["test.example.com"]).uuid) + EmbeddedDashboardDAO.upsert(dash, ["test.example.com"]) + db.session.flush() + uuid = str(dash.embedded[0].uuid) embedded = 
EmbeddedDashboardDAO.find_by_id(uuid) self.assertIsNotNone(embedded) diff --git a/tests/integration_tests/embedded/test_view.py b/tests/integration_tests/embedded/test_view.py index 7fcfcdba9f..f4d5ae6925 100644 --- a/tests/integration_tests/embedded/test_view.py +++ b/tests/integration_tests/embedded/test_view.py @@ -44,6 +44,7 @@ if TYPE_CHECKING: def test_get_embedded_dashboard(client: FlaskClient[Any]): # noqa: F811 dash = db.session.query(Dashboard).filter_by(slug="births").first() embedded = EmbeddedDashboardDAO.upsert(dash, []) + db.session.flush() uri = f"embedded/{embedded.uuid}" response = client.get(uri) assert response.status_code == 200 @@ -57,6 +58,7 @@ def test_get_embedded_dashboard(client: FlaskClient[Any]): # noqa: F811 def test_get_embedded_dashboard_referrer_not_allowed(client: FlaskClient[Any]): # noqa: F811 dash = db.session.query(Dashboard).filter_by(slug="births").first() embedded = EmbeddedDashboardDAO.upsert(dash, ["test.example.com"]) + db.session.flush() uri = f"embedded/{embedded.uuid}" response = client.get(uri) assert response.status_code == 403 diff --git a/tests/integration_tests/fixtures/birth_names_dashboard.py b/tests/integration_tests/fixtures/birth_names_dashboard.py index 3fe2de5944..513a9f84a2 100644 --- a/tests/integration_tests/fixtures/birth_names_dashboard.py +++ b/tests/integration_tests/fixtures/birth_names_dashboard.py @@ -46,22 +46,25 @@ def load_birth_names_data( @pytest.fixture() def load_birth_names_dashboard_with_slices(load_birth_names_data): with app.app_context(): - _create_dashboards() + dash_id_to_delete, slices_ids_to_delete = _create_dashboards() yield + _cleanup(dash_id_to_delete, slices_ids_to_delete) @pytest.fixture(scope="module") def load_birth_names_dashboard_with_slices_module_scope(load_birth_names_data): with app.app_context(): - _create_dashboards() + dash_id_to_delete, slices_ids_to_delete = _create_dashboards() yield + _cleanup(dash_id_to_delete, slices_ids_to_delete) @pytest.fixture(scope="class") def load_birth_names_dashboard_with_slices_class_scope(load_birth_names_data): with app.app_context(): - _create_dashboards() + dash_id_to_delete, slices_ids_to_delete = _create_dashboards() yield + _cleanup(dash_id_to_delete, slices_ids_to_delete) def _create_dashboards(): @@ -74,7 +77,10 @@ def _create_dashboards(): from superset.examples.birth_names import create_dashboard, create_slices slices, _ = create_slices(table) - create_dashboard(slices) + dash = create_dashboard(slices) + slices_ids_to_delete = [slice.id for slice in slices] + dash_id_to_delete = dash.id + return dash_id_to_delete, slices_ids_to_delete def _create_table( @@ -91,4 +97,20 @@ def _create_table( _set_table_metadata(table, database) _add_table_metrics(table) + db.session.commit() return table + + +def _cleanup(dash_id: int, slice_ids: list[int]) -> None: + schema = get_example_default_schema() + for datasource in db.session.query(SqlaTable).filter_by( + table_name="birth_names", schema=schema + ): + for col in datasource.columns + datasource.metrics: + db.session.delete(col) + + for dash in db.session.query(Dashboard).filter_by(id=dash_id): + db.session.delete(dash) + for slc in db.session.query(Slice).filter(Slice.id.in_(slice_ids)): + db.session.delete(slc) + db.session.commit() diff --git a/tests/integration_tests/fixtures/energy_dashboard.py b/tests/integration_tests/fixtures/energy_dashboard.py index dfac9644ae..5d938e0541 100644 --- a/tests/integration_tests/fixtures/energy_dashboard.py +++ b/tests/integration_tests/fixtures/energy_dashboard.py 
@@ -61,6 +61,7 @@ def load_energy_table_with_slice(load_energy_table_data): with app.app_context(): slices = _create_energy_table() yield slices + _cleanup() def _get_dataframe(): @@ -109,6 +110,24 @@ def _create_and_commit_energy_slice( return slice +def _cleanup() -> None: + for slice_data in _get_energy_slices(): + slice = ( + db.session.query(Slice) + .filter_by(slice_name=slice_data["slice_title"]) + .first() + ) + db.session.delete(slice) + + metric = ( + db.session.query(SqlMetric).filter_by(metric_name="sum__value").one_or_none() + ) + if metric: + db.session.delete(metric) + + db.session.commit() + + def _get_energy_data(): data = [] for i in range(85): diff --git a/tests/integration_tests/fixtures/tabbed_dashboard.py b/tests/integration_tests/fixtures/tabbed_dashboard.py index cf5b9f109c..d4ddff5796 100644 --- a/tests/integration_tests/fixtures/tabbed_dashboard.py +++ b/tests/integration_tests/fixtures/tabbed_dashboard.py @@ -135,4 +135,7 @@ def tabbed_dashboard(app_context): slices=[], ) db.session.add(dash) - yield + db.session.commit() + yield dash + db.session.query(Dashboard).filter_by(id=dash.id).delete() + db.session.commit() diff --git a/tests/integration_tests/fixtures/unicode_dashboard.py b/tests/integration_tests/fixtures/unicode_dashboard.py index e123279e75..9708457830 100644 --- a/tests/integration_tests/fixtures/unicode_dashboard.py +++ b/tests/integration_tests/fixtures/unicode_dashboard.py @@ -61,6 +61,7 @@ def load_unicode_dashboard_with_slice(load_unicode_data): with app.app_context(): dash = _create_unicode_dashboard(slice_name, None) yield + _cleanup(dash, slice_name) @pytest.fixture() @@ -70,6 +71,7 @@ def load_unicode_dashboard_with_position(load_unicode_data): with app.app_context(): dash = _create_unicode_dashboard(slice_name, position) yield + _cleanup(dash, slice_name) def _get_dataframe(): @@ -95,18 +97,25 @@ def _create_unicode_dashboard(slice_title: str, position: str) -> Dashboard: table.fetch_metadata() if slice_title: - slice = _create_unicode_slice(table, slice_title) + slice = _create_and_commit_unicode_slice(table, slice_title) return create_dashboard("unicode-test", "Unicode Test", position, [slice]) -def _create_unicode_slice(table: SqlaTable, title: str): - slc = create_slice(title, "word_cloud", table, {}) - if ( - obj := db.session.query(Slice) - .filter_by(slice_name=slc.slice_name) - .one_or_none() +def _create_and_commit_unicode_slice(table: SqlaTable, title: str): + slice = create_slice(title, "word_cloud", table, {}) + o = db.session.query(Slice).filter_by(slice_name=slice.slice_name).one_or_none() + if o: + db.session.delete(o) + db.session.add(slice) + db.session.commit() + return slice + + +def _cleanup(dash: Dashboard, slice_name: str) -> None: + db.session.delete(dash) + if slice_name and ( + slice := db.session.query(Slice).filter_by(slice_name=slice_name).one_or_none() ): - db.session.delete(obj) - db.session.add(slc) - return slc + db.session.delete(slice) + db.session.commit() diff --git a/tests/integration_tests/fixtures/world_bank_dashboard.py b/tests/integration_tests/fixtures/world_bank_dashboard.py index 99f8b57375..6e2b408600 100644 --- a/tests/integration_tests/fixtures/world_bank_dashboard.py +++ b/tests/integration_tests/fixtures/world_bank_dashboard.py @@ -64,7 +64,6 @@ def load_world_bank_data(): ) yield - with app.app_context(): with get_example_database().get_sqla_engine() as engine: engine.execute("DROP TABLE IF EXISTS wb_health_population") @@ -73,14 +72,15 @@ def load_world_bank_data(): @pytest.fixture() def 
load_world_bank_dashboard_with_slices(load_world_bank_data): with app.app_context(): - create_dashboard_for_loaded_data() + dash_id_to_delete, slices_ids_to_delete = create_dashboard_for_loaded_data() yield + _cleanup(dash_id_to_delete, slices_ids_to_delete) @pytest.fixture(scope="module") def load_world_bank_dashboard_with_slices_module_scope(load_world_bank_data): with app.app_context(): - create_dashboard_for_loaded_data() + dash_id_to_delete, slices_ids_to_delete = create_dashboard_for_loaded_data() yield _cleanup_reports(dash_id_to_delete, slices_ids_to_delete) _cleanup(dash_id_to_delete, slices_ids_to_delete) @@ -89,8 +89,9 @@ def load_world_bank_dashboard_with_slices_module_scope(load_world_bank_data): @pytest.fixture(scope="class") def load_world_bank_dashboard_with_slices_class_scope(load_world_bank_data): with app.app_context(): - create_dashboard_for_loaded_data() + dash_id_to_delete, slices_ids_to_delete = create_dashboard_for_loaded_data() yield + _cleanup(dash_id_to_delete, slices_ids_to_delete) def create_dashboard_for_loaded_data(): @@ -102,22 +103,24 @@ def create_dashboard_for_loaded_data(): return dash_id_to_delete, slices_ids_to_delete -def _create_world_bank_slices(table: SqlaTable) -> None: +def _create_world_bank_slices(table: SqlaTable) -> list[Slice]: from superset.examples.world_bank import create_slices slices = create_slices(table) - - for slc in slices: - if ( - obj := db.session.query(Slice) - .filter_by(slice_name=slc.slice_name) - .one_or_none() - ): - db.session.delete(obj) - db.session.add(slc) + _commit_slices(slices) + return slices -def _create_world_bank_dashboard(table: SqlaTable) -> None: +def _commit_slices(slices: list[Slice]): + for slice in slices: + o = db.session.query(Slice).filter_by(slice_name=slice.slice_name).one_or_none() + if o: + db.session.delete(o) + db.session.add(slice) + db.session.commit() + + +def _create_world_bank_dashboard(table: SqlaTable) -> Dashboard: from superset.examples.helpers import update_slice_ids from superset.examples.world_bank import dashboard_positions @@ -130,6 +133,16 @@ def _create_world_bank_dashboard(table: SqlaTable) -> None: "world_health", "World Bank's Data", json.dumps(pos), slices ) dash.json_metadata = '{"mock_key": "mock_value"}' + db.session.commit() + return dash + + +def _cleanup(dash_id: int, slices_ids: list[int]) -> None: + dash = db.session.query(Dashboard).filter_by(id=dash_id).first() + db.session.delete(dash) + for slice_id in slices_ids: + db.session.query(Slice).filter_by(id=slice_id).delete() + db.session.commit() def _cleanup_reports(dash_id: int, slices_ids: list[int]) -> None: diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py index 2c8a13a71f..71bb1484e0 100644 --- a/tests/integration_tests/security/row_level_security_tests.py +++ b/tests/integration_tests/security/row_level_security_tests.py @@ -215,8 +215,6 @@ class TestRowLevelSecurity(SupersetTestCase): }, ) self.assertEqual(rv.status_code, 422) - data = json.loads(rv.data.decode("utf-8")) - assert "Create failed" in data["message"] @pytest.mark.usefixtures("create_dataset") def test_model_view_rls_add_tables_required(self): diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index 3f5a211db3..829854d966 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -71,10 +71,8 @@ QUERY_3 = "SELECT * FROM birth_names LIMIT 10" class TestSqlLab(SupersetTestCase): 
"""Testings for Sql Lab""" - @pytest.mark.usefixtures("load_birth_names_data") def run_some_queries(self): db.session.query(Query).delete() - db.session.commit() self.run_sql(QUERY_1, client_id="client_id_1", username="admin") self.run_sql(QUERY_2, client_id="client_id_2", username="admin") self.run_sql(QUERY_3, client_id="client_id_3", username="gamma_sqllab") @@ -419,6 +417,7 @@ class TestSqlLab(SupersetTestCase): self.assertEqual(len(data["data"]), 1200) self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.NOT_LIMITED) + @pytest.mark.usefixtures("load_birth_names_data") def test_query_api_filter(self) -> None: """ Test query api without can_only_access_owned_queries perm added to @@ -438,6 +437,7 @@ class TestSqlLab(SupersetTestCase): assert admin.first_name in user_queries assert gamma_sqllab.first_name in user_queries + @pytest.mark.usefixtures("load_birth_names_data") def test_query_api_can_access_all_queries(self) -> None: """ Test query api with can_access_all_queries perm added to @@ -522,6 +522,7 @@ class TestSqlLab(SupersetTestCase): {r.get("sql") for r in self.get_json_resp(url)["result"]}, ) + @pytest.mark.usefixtures("load_birth_names_data") def test_query_admin_can_access_all_queries(self) -> None: """ Test query api with all_query_access perm added to diff --git a/tests/integration_tests/tags/dao_tests.py b/tests/integration_tests/tags/dao_tests.py index dbd0360aa7..8a6ba6e5f4 100644 --- a/tests/integration_tests/tags/dao_tests.py +++ b/tests/integration_tests/tags/dao_tests.py @@ -187,6 +187,7 @@ class TestTagsDAO(SupersetTestCase): TaggedObject.object_type == ObjectType.chart, ), ) + .join(Tag, TaggedObject.tag_id == Tag.id) .distinct(Slice.id) .count() ) @@ -199,6 +200,7 @@ class TestTagsDAO(SupersetTestCase): TaggedObject.object_type == ObjectType.dashboard, ), ) + .join(Tag, TaggedObject.tag_id == Tag.id) .distinct(Dashboard.id) .count() + num_charts diff --git a/tests/unit_tests/commands/databases/create_test.py b/tests/unit_tests/commands/databases/create_test.py index 405238827d..09d5744afd 100644 --- a/tests/unit_tests/commands/databases/create_test.py +++ b/tests/unit_tests/commands/databases/create_test.py @@ -29,7 +29,6 @@ def database_with_catalog(mocker: MockerFixture) -> MagicMock: """ Mock a database with catalogs and schemas. """ - mocker.patch("superset.commands.database.create.db") mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand") database = mocker.MagicMock() @@ -53,7 +52,6 @@ def database_without_catalog(mocker: MockerFixture) -> MagicMock: """ Mock a database without catalogs. """ - mocker.patch("superset.commands.database.create.db") mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand") database = mocker.MagicMock() diff --git a/tests/unit_tests/commands/databases/update_test.py b/tests/unit_tests/commands/databases/update_test.py index 300efb62e7..37500d5214 100644 --- a/tests/unit_tests/commands/databases/update_test.py +++ b/tests/unit_tests/commands/databases/update_test.py @@ -29,8 +29,6 @@ def database_with_catalog(mocker: MockerFixture) -> MagicMock: """ Mock a database with catalogs and schemas. """ - mocker.patch("superset.commands.database.update.db") - database = mocker.MagicMock() database.database_name = "my_db" database.db_engine_spec.__name__ = "test_engine" @@ -50,8 +48,6 @@ def database_without_catalog(mocker: MockerFixture) -> MagicMock: """ Mock a database without catalogs. 
""" - mocker.patch("superset.commands.database.update.db") - database = mocker.MagicMock() database.database_name = "my_db" database.db_engine_spec.__name__ = "test_engine" diff --git a/tests/unit_tests/dao/tag_test.py b/tests/unit_tests/dao/tag_test.py index d1907fb6cb..7662393d4f 100644 --- a/tests/unit_tests/dao/tag_test.py +++ b/tests/unit_tests/dao/tag_test.py @@ -22,7 +22,6 @@ def test_user_favorite_tag(mocker): from superset.daos.tag import TagDAO # Mock the behavior of TagDAO and g - mock_session = mocker.patch("superset.daos.tag.db.session") mock_TagDAO = mocker.patch( "superset.daos.tag.TagDAO" ) # Replace with the actual path to TagDAO @@ -45,7 +44,6 @@ def test_remove_user_favorite_tag(mocker): from superset.daos.tag import TagDAO # Mock the behavior of TagDAO and g - mock_session = mocker.patch("superset.daos.tag.db.session") mock_TagDAO = mocker.patch("superset.daos.tag.TagDAO") mock_tag = mocker.MagicMock(users_favorited=[]) mock_TagDAO.find_by_id.return_value = mock_tag diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py index 6eeb7ff162..f4534d216b 100644 --- a/tests/unit_tests/databases/api_test.py +++ b/tests/unit_tests/databases/api_test.py @@ -116,7 +116,7 @@ def test_post_with_uuid( assert payload["result"]["uuid"] == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb" database = session.query(Database).one() - assert database.uuid == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb" + assert database.uuid == UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb") def test_password_mask( diff --git a/tests/unit_tests/utils/lock_tests.py b/tests/unit_tests/utils/lock_tests.py index aa231bb0cf..4c9121fe38 100644 --- a/tests/unit_tests/utils/lock_tests.py +++ b/tests/unit_tests/utils/lock_tests.py @@ -22,8 +22,8 @@ from uuid import UUID import pytest from freezegun import freeze_time -from sqlalchemy.orm import Session, sessionmaker +from superset import db from superset.exceptions import CreateKeyValueDistributedLockFailedException from superset.key_value.types import JsonKeyValueCodec from superset.utils.lock import get_key, KeyValueDistributedLock @@ -32,56 +32,51 @@ MAIN_KEY = get_key("ns", a=1, b=2) OTHER_KEY = get_key("ns2", a=1, b=2) -def _get_lock(key: UUID, session: Session) -> Any: +def _get_lock(key: UUID) -> Any: from superset.key_value.models import KeyValueEntry - entry = session.query(KeyValueEntry).filter_by(uuid=key).first() + entry = db.session.query(KeyValueEntry).filter_by(uuid=key).first() if entry is None or entry.is_expired(): return None return JsonKeyValueCodec().decode(entry.value) -def _get_other_session() -> Session: - # This session is used to simulate what another worker will find in the metastore - # during the locking process. - from superset import db - - bind = db.session.get_bind() - SessionMaker = sessionmaker(bind=bind) - return SessionMaker() - - def test_key_value_distributed_lock_happy_path() -> None: """ Test successfully acquiring and returning the distributed lock. + + Note we use a nested transaction to ensure that the cleanup from the outer context + manager is correctly invoked, otherwise a partial rollback would occur leaving the + database in a fractured state. 
""" - session = _get_other_session() with freeze_time("2021-01-01"): - assert _get_lock(MAIN_KEY, session) is None + assert _get_lock(MAIN_KEY) is None + with KeyValueDistributedLock("ns", a=1, b=2) as key: assert key == MAIN_KEY - assert _get_lock(key, session) is True - assert _get_lock(OTHER_KEY, session) is None - with pytest.raises(CreateKeyValueDistributedLockFailedException): - with KeyValueDistributedLock("ns", a=1, b=2): - pass + assert _get_lock(key) is True + assert _get_lock(OTHER_KEY) is None - assert _get_lock(MAIN_KEY, session) is None + with db.session.begin_nested(): + with pytest.raises(CreateKeyValueDistributedLockFailedException): + with KeyValueDistributedLock("ns", a=1, b=2): + pass + + assert _get_lock(MAIN_KEY) is None def test_key_value_distributed_lock_expired() -> None: """ Test expiration of the distributed lock """ - session = _get_other_session() - with freeze_time("2021-01-01T"): - assert _get_lock(MAIN_KEY, session) is None + with freeze_time("2021-01-01"): + assert _get_lock(MAIN_KEY) is None with KeyValueDistributedLock("ns", a=1, b=2): - assert _get_lock(MAIN_KEY, session) is True - with freeze_time("2022-01-01T"): - assert _get_lock(MAIN_KEY, session) is None + assert _get_lock(MAIN_KEY) is True + with freeze_time("2022-01-01"): + assert _get_lock(MAIN_KEY) is None - assert _get_lock(MAIN_KEY, session) is None + assert _get_lock(MAIN_KEY) is None