Revert interim test updates

John Bodley 2024-04-19 15:16:18 -07:00
parent acfba3da29
commit 3e4325a176
58 changed files with 239 additions and 411 deletions

View File

@@ -239,8 +239,8 @@ basepython = python3.10
 ignore_basepython_conflict = true
 commands =
     superset db upgrade
-    superset load_test_users
     superset init
+    superset load-test-users
     # use -s to be able to use break pointers.
    # no args or tests/* can be passed as an argument to run all tests
     pytest -s {posargs}

View File

@@ -14,7 +14,6 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=consider-using-transaction
 from collections import defaultdict

 from superset import db, security_manager

View File

@@ -28,9 +28,9 @@ export SUPERSET_TESTENV=true
 echo "Superset config module: $SUPERSET_CONFIG"
 superset db upgrade
-superset load_test_users
 superset init
+superset load-test-users

 echo "Running tests"

-pytest --durations-min=2 --cov-report= --cov=superset ./tests/integration_tests "$@"
+pytest --durations-min=2 --maxfail=1 --cov-report= --cov=superset ./tests/integration_tests "$@"

View File

@@ -37,15 +37,7 @@ def load_test_users() -> None:
     Syncs permissions for those users/roles
     """
     print(Fore.GREEN + "Loading a set of users for unit tests")
-    load_test_users_run()
-
-
-def load_test_users_run() -> None:
-    """
-    Loads admin, alpha, and gamma user for testing purposes
-
-    Syncs permissions for those users/roles
-    """
     if app.config["TESTING"]:
         sm = security_manager

View File

@@ -77,7 +77,7 @@ def import_chart(
     if chart.id is None:
         db.session.flush()

-    if user := get_user():
+    if (user := get_user()) and user not in chart.owners:
         chart.owners.append(user)

     return chart

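Note: the reverted condition performs the user lookup and a duplicate-owner check in one expression. Below is a self-contained sketch of that walrus-operator guard; the User/Chart models and get_user helper are invented stand-ins for illustration, not Superset's actual code.

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class User:
    username: str


@dataclass
class Chart:
    owners: list[User] = field(default_factory=list)


def get_user() -> Optional[User]:
    # Stand-in for the current request user; may be None for anonymous imports.
    return User("admin")


chart = Chart(owners=[User("admin")])

# Bind and test in one condition: append only when a user exists and is new.
if (user := get_user()) and user not in chart.owners:
    chart.owners.append(user)

print(chart.owners)  # still one owner; the duplicate was skipped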
View File

@@ -38,8 +38,8 @@ from superset.daos.chart import ChartDAO
 from superset.daos.dashboard import DashboardDAO
 from superset.exceptions import SupersetSecurityException
 from superset.models.slice import Slice
-from superset.utils.decorators import on_error, transaction
 from superset.tags.models import ObjectType
+from superset.utils.decorators import on_error, transaction

 logger = logging.getLogger(__name__)
@@ -62,14 +62,13 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
         assert self._model

         # Update tags
-        tags = self._properties.pop("tags", None)
-        if tags is not None:
+        if (tags := self._properties.pop("tags", None)) is not None:
             update_tags(ObjectType.chart, self._model.id, self._model.tags, tags)

         if self._properties.get("query_context_generation") is None:
             self._properties["last_saved_at"] = datetime.now()
             self._properties["last_saved_by"] = g.user
         return ChartDAO.update(self._model, self._properties)

     def validate(self) -> None:

View File

@@ -188,7 +188,7 @@ def import_dashboard(
     if dashboard.id is None:
         db.session.flush()

-    if user := get_user():
+    if (user := get_user()) and user not in dashboard.owners:
         dashboard.owners.append(user)

     return dashboard

View File

@@ -19,7 +19,6 @@ from functools import partial
 from sqlalchemy.exc import SQLAlchemyError

-from superset import db
 from superset.commands.dashboard.permalink.base import BaseDashboardPermalinkCommand
 from superset.commands.key_value.upsert import UpsertKeyValueCommand
 from superset.daos.dashboard import DashboardDAO
@@ -78,6 +77,7 @@ class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
                 codec=self.codec,
             ).run()
             assert key.id  # for type checks
+            return encode_permalink_key(key=key.id, salt=self.salt)

     def validate(self) -> None:
         pass

View File

@@ -53,19 +53,16 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand):
         assert self._model

         # Update tags
-        tags = self._properties.pop("tags", None)
-        if tags is not None:
-            update_tags(
-                ObjectType.dashboard, self._model.id, self._model.tags, tags
-            )
+        if (tags := self._properties.pop("tags", None)) is not None:
+            update_tags(ObjectType.dashboard, self._model.id, self._model.tags, tags)

-        dashboard = DashboardDAO.update(self._model, self._properties, commit=False)
+        dashboard = DashboardDAO.update(self._model, self._properties)
         if self._properties.get("json_metadata"):
             DashboardDAO.set_dash_metadata(
                 dashboard,
                 data=json.loads(self._properties.get("json_metadata", "{}")),
             )
         return dashboard

     def validate(self) -> None:

View File

@@ -40,7 +40,7 @@ from superset.commands.database.ssh_tunnel.exceptions import (
 )
 from superset.commands.database.test_connection import TestConnectionDatabaseCommand
 from superset.daos.database import DatabaseDAO
-from superset.daos.exceptions import DAOCreateFailedError
+from superset.databases.ssh_tunnel.models import SSHTunnel
 from superset.exceptions import SupersetErrorsException
 from superset.extensions import event_logger, security_manager
 from superset.models.core import Database

View File

@@ -1,194 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import logging
-from functools import partial
-from typing import Any, Optional, TypedDict
-
-import pandas as pd
-from flask_babel import lazy_gettext as _
-
-from superset import db
-from superset.commands.base import BaseCommand
-from superset.commands.database.exceptions import (
-    DatabaseNotFoundError,
-    DatabaseSchemaUploadNotAllowed,
-    DatabaseUploadFailed,
-    DatabaseUploadSaveMetadataFailed,
-)
-from superset.connectors.sqla.models import SqlaTable
-from superset.daos.database import DatabaseDAO
-from superset.models.core import Database
-from superset.sql_parse import Table
-from superset.utils.core import get_user
-from superset.utils.decorators import on_error, transaction
-from superset.views.database.validators import schema_allows_file_upload
-
-logger = logging.getLogger(__name__)
-
-READ_CSV_CHUNK_SIZE = 1000
-
-
-class CSVImportOptions(TypedDict, total=False):
-    schema: str
-    delimiter: str
-    already_exists: str
-    column_data_types: dict[str, str]
-    column_dates: list[str]
-    column_labels: str
-    columns_read: list[str]
-    dataframe_index: str
-    day_first: bool
-    decimal_character: str
-    header_row: int
-    index_column: str
-    null_values: list[str]
-    overwrite_duplicates: bool
-    rows_to_read: int
-    skip_blank_lines: bool
-    skip_initial_space: bool
-    skip_rows: int
-
-
-class CSVImportCommand(BaseCommand):
-    def __init__(
-        self,
-        model_id: int,
-        table_name: str,
-        file: Any,
-        options: CSVImportOptions,
-    ) -> None:
-        self._model_id = model_id
-        self._model: Optional[Database] = None
-        self._table_name = table_name
-        self._schema = options.get("schema")
-        self._file = file
-        self._options = options
-
-    def _read_csv(self) -> pd.DataFrame:
-        """
-        Read CSV file into a DataFrame
-
-        :return: pandas DataFrame
-        :throws DatabaseUploadFailed: if there is an error reading the CSV file
-        """
-        try:
-            return pd.concat(
-                pd.read_csv(
-                    chunksize=READ_CSV_CHUNK_SIZE,
-                    encoding="utf-8",
-                    filepath_or_buffer=self._file,
-                    header=self._options.get("header_row", 0),
-                    index_col=self._options.get("index_column"),
-                    dayfirst=self._options.get("day_first", False),
-                    iterator=True,
-                    keep_default_na=not self._options.get("null_values"),
-                    usecols=self._options.get("columns_read")
-                    if self._options.get("columns_read")  # None if an empty list
-                    else None,
-                    na_values=self._options.get("null_values")
-                    if self._options.get("null_values")  # None if an empty list
-                    else None,
-                    nrows=self._options.get("rows_to_read"),
-                    parse_dates=self._options.get("column_dates"),
-                    sep=self._options.get("delimiter", ","),
-                    skip_blank_lines=self._options.get("skip_blank_lines", False),
-                    skipinitialspace=self._options.get("skip_initial_space", False),
-                    skiprows=self._options.get("skip_rows", 0),
-                    dtype=self._options.get("column_data_types")
-                    if self._options.get("column_data_types")
-                    else None,
-                )
-            )
-        except (
-            pd.errors.ParserError,
-            pd.errors.EmptyDataError,
-            UnicodeDecodeError,
-            ValueError,
-        ) as ex:
-            raise DatabaseUploadFailed(
-                message=_("Parsing error: %(error)s", error=str(ex))
-            ) from ex
-        except Exception as ex:
-            raise DatabaseUploadFailed(_("Error reading CSV file")) from ex
-
-    def _dataframe_to_database(self, df: pd.DataFrame, database: Database) -> None:
-        """
-        Upload DataFrame to database
-
-        :param df:
-        :throws DatabaseUploadFailed: if there is an error uploading the DataFrame
-        """
-        try:
-            csv_table = Table(table=self._table_name, schema=self._schema)
-            database.db_engine_spec.df_to_sql(
-                database,
-                csv_table,
-                df,
-                to_sql_kwargs={
-                    "chunksize": READ_CSV_CHUNK_SIZE,
-                    "if_exists": self._options.get("already_exists", "fail"),
-                    "index": self._options.get("index_column"),
-                    "index_label": self._options.get("column_labels"),
-                },
-            )
-        except ValueError as ex:
-            raise DatabaseUploadFailed(
-                message=_(
-                    "Table already exists. You can change your "
-                    "'if table already exists' strategy to append or "
-                    "replace or provide a different Table Name to use."
-                )
-            ) from ex
-        except Exception as ex:
-            raise DatabaseUploadFailed(exception=ex) from ex
-
-    @transaction(on_error=partial(on_error, reraise=DatabaseUploadSaveMetadataFailed))
-    def run(self) -> None:
-        self.validate()
-        if not self._model:
-            return
-
-        df = self._read_csv()
-        self._dataframe_to_database(df, self._model)
-
-        sqla_table = (
-            db.session.query(SqlaTable)
-            .filter_by(
-                table_name=self._table_name,
-                schema=self._schema,
-                database_id=self._model_id,
-            )
-            .one_or_none()
-        )
-        if not sqla_table:
-            sqla_table = SqlaTable(
-                table_name=self._table_name,
-                database=self._model,
-                database_id=self._model_id,
-                owners=[get_user()],
-                schema=self._schema,
-            )
-            db.session.add(sqla_table)
-
-        sqla_table.fetch_metadata()
-
-    def validate(self) -> None:
-        self._model = DatabaseDAO.find_by_id(self._model_id)
-        if not self._model:
-            raise DatabaseNotFoundError()
-        if not schema_allows_file_upload(self._model, self._schema):
-            raise DatabaseSchemaUploadNotAllowed()

View File

@@ -41,7 +41,6 @@ from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
 from superset.daos.database import DatabaseDAO
 from superset.daos.dataset import DatasetDAO
 from superset.databases.ssh_tunnel.models import SSHTunnel
-from superset.extensions import db
 from superset.models.core import Database
 from superset.utils.decorators import on_error, transaction
@@ -78,11 +77,7 @@ class UpdateDatabaseCommand(BaseCommand):
         original_database_name = self._model.database_name

         try:
-            database = DatabaseDAO.update(
-                self._model,
-                self._properties,
-                commit=False,
-            )
+            database = DatabaseDAO.update(self._model, self._properties)
             database.set_sqlalchemy_uri(database.sqlalchemy_uri)
             ssh_tunnel = self._handle_ssh_tunnel(database)
             self._refresh_catalogs(database, original_database_name, ssh_tunnel)
@@ -100,7 +95,6 @@ class UpdateDatabaseCommand(BaseCommand):
             return None

         if not is_feature_enabled("SSH_TUNNELING"):
-            db.session.rollback()
             raise SSHTunnelingNotEnabledError()

         current_ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
@@ -130,14 +124,11 @@ class UpdateDatabaseCommand(BaseCommand):
         This method captures a generic exception, since errors could potentially come
         from any of the 50+ database drivers we support.
         """
-        try:
-            return database.get_all_catalog_names(
-                force=True,
-                ssh_tunnel=ssh_tunnel,
-            )
-        except Exception as ex:
-            db.session.rollback()
-            raise DatabaseConnectionFailedError() from ex
+        return database.get_all_catalog_names(
+            force=True,
+            ssh_tunnel=ssh_tunnel,
+        )

     def _get_schema_names(
         self,
@@ -151,15 +142,12 @@ class UpdateDatabaseCommand(BaseCommand):
         This method captures a generic exception, since errors could potentially come
         from any of the 50+ database drivers we support.
         """
-        try:
-            return database.get_all_schema_names(
-                force=True,
-                catalog=catalog,
-                ssh_tunnel=ssh_tunnel,
-            )
-        except Exception as ex:
-            db.session.rollback()
-            raise DatabaseConnectionFailedError() from ex
+        return database.get_all_schema_names(
+            force=True,
+            catalog=catalog,
+            ssh_tunnel=ssh_tunnel,
+        )

     def _refresh_catalogs(
         self,
@@ -224,8 +212,6 @@ class UpdateDatabaseCommand(BaseCommand):
                 schemas,
             )

-        db.session.commit()
-
     def _refresh_schemas(
         self,
         database: Database,

View File

@@ -43,7 +43,7 @@ class DeleteDatasetColumnCommand(BaseCommand):
     def run(self) -> None:
         self.validate()
         assert self._model
-        DatasetColumnDAO.delete(self._model)
+        DatasetColumnDAO.delete([self._model])

     def validate(self) -> None:
         # Validate/populate model exists

View File

@@ -32,7 +32,7 @@ from superset.commands.dataset.exceptions import (
 )
 from superset.daos.dataset import DatasetDAO
 from superset.exceptions import SupersetSecurityException
-from superset.extensions import db, security_manager
+from superset.extensions import security_manager
 from superset.sql_parse import Table
 from superset.utils.decorators import on_error, transaction

View File

@@ -178,7 +178,7 @@ def import_dataset(
     if data_uri and (not table_exists or force_data):
         load_data(data_uri, dataset, dataset.database)

-    if user := get_user():
+    if (user := get_user()) and user not in dataset.owners:
         dataset.owners.append(user)

     return dataset

View File

@@ -43,7 +43,7 @@ class DeleteDatasetMetricCommand(BaseCommand):
     def run(self) -> None:
         self.validate()
         assert self._model
-        DatasetMetricDAO.delete(self._model)
+        DatasetMetricDAO.delete([self._model])

     def validate(self) -> None:
         # Validate/populate model exists

View File

@@ -21,6 +21,7 @@ from typing import Any, Optional

 from flask_appbuilder.models.sqla import Model
 from marshmallow import ValidationError
+from sqlalchemy.exc import SQLAlchemyError

 from superset import security_manager
 from superset.commands.base import BaseCommand, UpdateMixin
@@ -60,7 +61,16 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
         self.override_columns = override_columns
         self._properties["override_columns"] = override_columns

-    @transaction(on_error=partial(on_error, reraise=DatasetUpdateFailedError))
+    @transaction(
+        on_error=partial(
+            on_error,
+            catches=(
+                SQLAlchemyError,
+                ValueError,
+            ),
+            reraise=DatasetUpdateFailedError,
+        )
+    )
     def run(self) -> Model:
         self.validate()
         assert self._model

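Note: the restored decorator narrows error translation to specific exception types via functools.partial. A minimal sketch of how such an on_error handler can work; the signature is an assumption for illustration, since the real helper lives in superset.utils.decorators and may differ.

from functools import partial


class DatasetUpdateFailedError(Exception):
    """Domain-level error surfaced to the API layer."""


def on_error(ex: Exception, catches=(Exception,), reraise=None):
    # Translate expected low-level failures into a domain error; let others escape.
    if isinstance(ex, catches) and reraise is not None:
        raise reraise() from ex
    raise ex


handler = partial(on_error, catches=(ValueError,), reraise=DatasetUpdateFailedError)

try:
    handler(ValueError("bad column"))
except DatasetUpdateFailedError as ex:
    print(f"translated from: {ex.__cause__!r}")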
View File

@@ -20,7 +20,6 @@
 from typing import Any, Optional

 from sqlalchemy.exc import SQLAlchemyError

-from superset import db
 from superset.commands.explore.permalink.base import BaseExplorePermalinkCommand
 from superset.commands.key_value.create import CreateKeyValueCommand
 from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError
@@ -74,27 +73,7 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
             key = command.run()
             if key.id is None:
                 raise ExplorePermalinkCreateFailedError("Unexpected missing key id")
-            db.session.commit()
             return encode_permalink_key(key=key.id, salt=self.salt)
-            d_id, d_type = self.datasource.split("__")
-            datasource_id = int(d_id)
-            datasource_type = DatasourceType(d_type)
-            check_chart_access(datasource_id, self.chart_id, datasource_type)
-            value = {
-                "chartId": self.chart_id,
-                "datasourceId": datasource_id,
-                "datasourceType": datasource_type.value,
-                "datasource": self.datasource,
-                "state": self.state,
-            }
-            command = CreateKeyValueCommand(
-                resource=self.resource,
-                value=value,
-                codec=self.codec,
-            )
-            key = command.run()
-            return encode_permalink_key(key=key.id, salt=self.salt)
->>>>>>> c01dacb71a (chore(dao): Use nested session for operations)

     def validate(self) -> None:
         pass

View File

@@ -14,6 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from functools import partial
 from typing import Any, Optional

 from marshmallow import Schema
@@ -44,7 +45,7 @@ from superset.datasets.schemas import ImportV1DatasetSchema
 from superset.migrations.shared.native_filters import migrate_dashboard
 from superset.models.dashboard import dashboard_slices
 from superset.queries.saved_queries.schemas import ImportV1SavedQuerySchema
-from superset.utils.decorators import transaction
+from superset.utils.decorators import on_error, transaction


 class ImportAssetsCommand(BaseCommand):
@@ -154,14 +155,16 @@ class ImportAssetsCommand(BaseCommand):
                 if chart.viz_type == "filter_box":
                     db.session.delete(chart)

-    @transaction()
+    @transaction(
+        on_error=partial(
+            on_error,
+            catches=(Exception,),
+            reraise=ImportFailedError,
+        )
+    )
     def run(self) -> None:
         self.validate()
-        try:
-            self._import(self._configs)
-        except Exception as ex:
-            raise ImportFailedError() from ex
+        self._import(self._configs)

     def validate(self) -> None:
         exceptions: list[ValidationError] = []

View File

@@ -53,9 +53,11 @@ class DeleteKeyValueCommand(BaseCommand):
         pass

     def delete(self) -> bool:
-        filter_ = get_filter(self.resource, self.key)
-        if entry := db.session.query(KeyValueEntry).filter_by(**filter_).first():
+        if (
+            entry := db.session.query(KeyValueEntry)
+            .filter_by(**get_filter(self.resource, self.key))
+            .first()
+        ):
             db.session.delete(entry)
-            db.session.flush()
             return True
         return False

View File

@@ -60,4 +60,3 @@ class DeleteExpiredKeyValueCommand(BaseCommand):
             )
             .delete()
         )
-        db.session.flush()

View File

@@ -21,6 +21,8 @@ from functools import partial
 from typing import Any, Optional, Union
 from uuid import UUID

+from sqlalchemy.exc import SQLAlchemyError
+
 from superset import db
 from superset.commands.base import BaseCommand
 from superset.commands.key_value.create import CreateKeyValueCommand
@@ -71,7 +73,7 @@ class UpsertKeyValueCommand(BaseCommand):
     @transaction(
         on_error=partial(
             on_error,
-            catches=(KeyValueCreateFailedError,),
+            catches=(KeyValueCreateFailedError, SQLAlchemyError),
             reraise=KeyValueUpsertFailedError,
         ),
     )
@@ -82,16 +84,15 @@ class UpsertKeyValueCommand(BaseCommand):
         pass

     def upsert(self) -> Key:
-        filter_ = get_filter(self.resource, self.key)
-        entry: KeyValueEntry = (
-            db.session.query(KeyValueEntry).filter_by(**filter_).first()
-        )
-        if entry:
+        if (
+            entry := db.session.query(KeyValueEntry)
+            .filter_by(**get_filter(self.resource, self.key))
+            .first()
+        ):
             entry.value = self.codec.encode(self.value)
             entry.expires_on = self.expires_on
             entry.changed_on = datetime.now()
             entry.changed_by_fk = get_user_id()
-            db.session.flush()
             return Key(entry.id, entry.uuid)

         return CreateKeyValueCommand(

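Note: both hunks in this file settle on the same shape: inline the filter, bind the query result with a walrus operator, update the row if it exists, and create it otherwise. A runnable miniature of that upsert shape against an in-memory SQLite database; the KVEntry model and field names here are toy stand-ins, not Superset's KeyValueEntry.

from datetime import datetime

from sqlalchemy import Column, DateTime, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class KVEntry(Base):  # toy stand-in for KeyValueEntry
    __tablename__ = "kv"
    id = Column(Integer, primary_key=True)
    resource = Column(String)
    key = Column(String)
    value = Column(String)
    changed_on = Column(DateTime)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(KVEntry(resource="permalink", key="abc", value="old"))
    session.flush()

    filter_ = {"resource": "permalink", "key": "abc"}
    # Update-if-present, create-otherwise: the upsert shape from the hunk above.
    if entry := session.query(KVEntry).filter_by(**filter_).first():
        entry.value = "new"
        entry.changed_on = datetime.now()
    else:
        session.add(KVEntry(**filter_, value="new", changed_on=datetime.now()))
    session.commit()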
View File

@@ -137,6 +137,7 @@ class BaseReportState:
             uuid=self._execution_id,
         )
         db.session.add(log)
+        db.session.commit()  # pylint: disable=consider-using-transaction

     def _get_url(
         self,

View File

@@ -49,7 +49,6 @@ class AsyncPruneReportScheduleLogCommand(BaseCommand):
             report_schedule,
             from_date,
         )
-        db.session.commit()
         logger.info(
             "Deleted %s logs for report schedule id: %s",
             str(row_count),

View File

@@ -383,5 +383,4 @@ class TagDAO(BaseDAO[Tag]):
                 object_id,
                 tag.name,
             )
-
         db.session.add_all(tagged_objects)

View File

@@ -32,7 +32,7 @@ from marshmallow import ValidationError
 from werkzeug.wrappers import Response as WerkzeugResponse
 from werkzeug.wsgi import FileWrapper

-from superset import is_feature_enabled, thumbnail_cache
+from superset import db, is_feature_enabled, thumbnail_cache
 from superset.charts.schemas import ChartEntityResponseSchema
 from superset.commands.dashboard.create import CreateDashboardCommand
 from superset.commands.dashboard.delete import DeleteDashboardCommand
@@ -1314,7 +1314,13 @@ class DashboardRestApi(BaseSupersetModelRestApi):
         """
         try:
             body = self.embedded_config_schema.load(request.json)
-            embedded = EmbeddedDashboardDAO.upsert(dashboard, body["allowed_domains"])
+
+            with db.session.begin_nested():
+                embedded = EmbeddedDashboardDAO.upsert(
+                    dashboard,
+                    body["allowed_domains"],
+                )
+
             result = self.embedded_response_schema.dump(embedded)
             return self.response(200, result=result)
         except ValidationError as error:

View File

@@ -22,7 +22,6 @@ from uuid import UUID, uuid3

 from flask import current_app, Flask, has_app_context
 from flask_caching import BaseCache

-from superset import db
 from superset.key_value.exceptions import KeyValueCreateFailedError
 from superset.key_value.types import (
     KeyValueCodec,
@@ -95,7 +94,6 @@ class SupersetMetastoreCache(BaseCache):
             codec=self.codec,
             expires_on=self._get_expiry(timeout),
         ).run()
-        db.session.commit()
         return True

     def add(self, key: str, value: Any, timeout: Optional[int] = None) -> bool:
@@ -111,7 +109,6 @@ class SupersetMetastoreCache(BaseCache):
                 key=self.get_key(key),
                 expires_on=self._get_expiry(timeout),
             ).run()
-            db.session.commit()
             return True
         except KeyValueCreateFailedError:
             return False
@@ -136,6 +133,4 @@ class SupersetMetastoreCache(BaseCache):
         # pylint: disable=import-outside-toplevel
         from superset.commands.key_value.delete import DeleteKeyValueCommand

-        ret = DeleteKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run()
-        db.session.commit()
-        return ret
+        return DeleteKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run()

View File

@@ -18,7 +18,6 @@
 from typing import Any, Optional
 from uuid import uuid3

-from superset import db
 from superset.key_value.types import JsonKeyValueCodec, KeyValueResource, SharedKey
 from superset.key_value.utils import get_uuid_namespace, random_key
@@ -46,7 +45,6 @@ def set_shared_value(key: SharedKey, value: Any) -> None:
         key=uuid_key,
         codec=CODEC,
     ).run()
-    db.session.commit()


 def get_permalink_salt(key: SharedKey) -> str:

View File

@@ -334,22 +334,26 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
     pyjwt_for_guest_token = _jwt_global_obj

     @property
-    def get_session(self) -> Session:
+    def get_session2(self) -> Session:
         """
         Flask-AppBuilder (FAB) which has a tendency to explicitly commit, thus violating
         our definition of "unit of work".

-        By providing a monkey patched transaction for the FAB session ensures that any
-        explicit commit merely flushes and any rollback is a no-op.
+        By providing a monkey patched transaction for the FAB session, within the
+        confines of a nested session, ensures that any explicit commit merely flushes
+        and any rollback is a no-op.
         """

         # pylint: disable=import-outside-toplevel
         from superset import db

-        with db.session.begin_nested() as transaction:
-            transaction.session.commit = transaction.session.flush
-            transaction.session.rollback = lambda: None
-            return transaction.session
+        if db.session._proxied._nested_transaction:  # pylint: disable=protected-access
+            with db.session.begin_nested() as transaction:
+                transaction.session.commit = transaction.session.flush
+                transaction.session.rollback = lambda: None
+                return transaction.session
+        return db.session

     def create_login_manager(self, app: Flask) -> LoginManager:
         lm = super().create_login_manager(app)
@@ -1035,7 +1039,6 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
                 == None,  # noqa: E711
             )
         )
-        self.get_session.flush()

         if deleted_count := pvms.delete():
             logger.info("Deleted %i faulty permissions", deleted_count)
@@ -1065,7 +1068,6 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         )

         self.create_missing_perms()
-        self.get_session.flush()
         self.clean_perms()

     def _get_all_pvms(self) -> list[PermissionView]:
@@ -1138,7 +1140,6 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         ):
             role_from_permissions.append(permission_view)
         role_to.permissions = role_from_permissions
-        self.get_session.flush()

     def set_role(
         self,
@@ -1159,7 +1160,6 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
             permission_view for permission_view in pvms if pvm_check(permission_view)
         ]
         role.permissions = role_pvms
-        self.get_session.flush()

     def _is_admin_only(self, pvm: PermissionView) -> bool:
         """

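Note: the monkey patch above downgrades commit to flush and disables rollback so that Flask-AppBuilder cannot end the enclosing unit of work. The effect is easiest to see with a toy stand-in session (illustration only, not SQLAlchemy):

class FakeSession:
    """Toy stand-in for a SQLAlchemy session (illustration only)."""

    def flush(self):
        print("flush: pending changes sent; transaction stays open")

    def commit(self):
        print("commit: unit of work ended")

    def rollback(self):
        print("rollback: unit of work undone")


session = FakeSession()

# The patch: third-party code that calls commit() now merely flushes,
# and rollback() becomes a no-op, so the outer unit of work survives.
session.commit = session.flush
session.rollback = lambda: None

session.commit()    # prints the flush message instead of committing
session.rollback()  # does nothing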
View File

@@ -140,6 +140,7 @@ def get_tag(
     if tag is None:
         tag = Tag(name=escape(tag_name), type=type_)
         session.add(tag)
+        session.commit()

     return tag

View File

@@ -59,6 +59,7 @@ def get_or_create_db(
     if database and database.sqlalchemy_uri_decrypted != sqlalchemy_uri:
         database.set_sqlalchemy_uri(sqlalchemy_uri)
+        db.session.flush()

     return database
@@ -78,3 +79,4 @@ def remove_database(database: Database) -> None:
     from superset import db

     db.session.delete(database)
+    db.session.flush()

View File

@@ -222,7 +222,7 @@ def on_error(

     :param ex: The source exception
     :param catches: The exception types the handler catches
-    :param reraise: The exception type the handler reraises after catching
+    :param reraise: The exception type the handler raises after catching
     :raises Exception: If the exception is not swallowed
     """
@@ -252,13 +252,16 @@ def transaction(  # pylint: disable=redefined-outer-name
         from superset import db  # pylint: disable=import-outside-toplevel

         try:
-            with db.session.begin_nested():
-                return func(*args, **kwargs)
+            result = func(*args, **kwargs)
+            db.session.commit()
+            return result
         except Exception as ex:
+            db.session.rollback()
             if on_error:
                 return on_error(ex)
-            raise ex
+            raise

     return wrapped

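Note: the reverted transaction decorator runs the wrapped unit of work, commits on success, and rolls back before delegating on failure. A minimal sketch of that control flow with a fake session; this mirrors the hunk's shape only, not Superset's full decorator.

from functools import wraps


class FakeSession:
    """Toy session recording the transaction outcome."""

    def commit(self):
        print("commit")

    def rollback(self):
        print("rollback")


session = FakeSession()


def transaction(on_error=None):
    def decorator(func):
        @wraps(func)
        def wrapped(*args, **kwargs):
            try:
                result = func(*args, **kwargs)
                session.commit()  # success: end the unit of work
                return result
            except Exception as ex:  # pylint: disable=broad-except
                session.rollback()  # failure: undo, then delegate or re-raise
                if on_error:
                    return on_error(ex)
                raise

        return wrapped

    return decorator


@transaction(on_error=lambda ex: print(f"handled: {ex}"))
def failing():
    raise RuntimeError("boom")


failing()  # prints "rollback" then "handled: boom"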
View File

@@ -24,7 +24,6 @@ from contextlib import contextmanager
 from datetime import datetime, timedelta
 from typing import Any, cast, TypeVar, Union

-from superset import db
 from superset.exceptions import CreateKeyValueDistributedLockFailedException
 from superset.key_value.exceptions import KeyValueCreateFailedError
 from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
@@ -72,7 +71,6 @@ def KeyValueDistributedLock(  # pylint: disable=invalid-name
     store.

     :param namespace: The namespace for which the lock is to be acquired.
-    :type namespace: str
     :param kwargs: Additional keyword arguments.
     :yields: A unique identifier (UUID) for the acquired lock (the KV key).
     :raises CreateKeyValueDistributedLockFailedException: If the lock is taken.
@@ -93,12 +91,10 @@ def KeyValueDistributedLock(  # pylint: disable=invalid-name
             value=True,
             expires_on=datetime.now() + LOCK_EXPIRATION,
         ).run()
-        db.session.commit()

         yield key

         DeleteKeyValueCommand(resource=KeyValueResource.LOCK, key=key).run()
-        db.session.commit()
         logger.debug("Removed lock on namespace %s for key %s", namespace, key)
     except KeyValueCreateFailedError as ex:
         raise CreateKeyValueDistributedLockFailedException(

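Note: the surrounding context manager treats a short-lived key-value entry as a distributed mutex: create the key, yield it to the caller, then delete it. A generic skeleton of the same acquire/yield/release shape; the names and the in-memory store are invented, not the Superset API.

import logging
import uuid
from contextlib import contextmanager

logger = logging.getLogger(__name__)

_locks: set[str] = set()  # in-memory stand-in for the key-value store


class LockTakenError(Exception):
    """Raised when another holder already owns the namespace."""


@contextmanager
def distributed_lock(namespace: str):
    if any(key.startswith(f"{namespace}:") for key in _locks):
        raise LockTakenError(namespace)
    key = f"{namespace}:{uuid.uuid4()}"
    _locks.add(key)  # acquire: create the lock entry
    try:
        yield key  # caller works while holding the lock
    finally:
        _locks.discard(key)  # release: delete the lock entry
        logger.debug("Removed lock on namespace %s for key %s", namespace, key)


with distributed_lock("refresh-database-1") as key:
    print(f"holding {key}")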
View File

@@ -403,6 +403,7 @@ class DBEventLogger(AbstractEventLogger):
             logs.append(log)
         try:
             db.session.bulk_save_objects(logs)
+            db.session.commit()  # pylint: disable=consider-using-transaction
         except SQLAlchemyError as ex:
             logging.error("DBEventLogger failed to log event(s)")
             logging.exception(ex)

View File

@@ -14,8 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=consider-using-transaction
-from typing import Any, TYPE_CHECKING
+from typing import TYPE_CHECKING

 from flask_appbuilder import expose
 from flask_appbuilder.models.sqla.interface import SQLAInterface

View File

@@ -349,15 +349,7 @@ class SupersetTestCase(TestCase):
             self.grant_role_access_to_table(table, role_name)

     def grant_role_access_to_table(self, table, role_name):
-        print(">>> grant_role_access_to_table <<<")
-        print(role_name)
-        print(db.session.get_bind())
-
-        from flask_appbuilder.security.sqla.models import Role
-
-        print(list(db.session.query(Role).all()))
         role = security_manager.find_role(role_name)
-        print(role)
         perms = db.session.query(ab_models.PermissionView).all()
         for perm in perms:
             if (

View File

@@ -332,6 +332,9 @@ def test_run_async_cta_query_with_lower_limit(test_client, ctas_method):
         tmp_table=tmp_table,
     )
     query = wait_for_success(result)
+    print(">>> test_run_async_cta_query_with_lower_limit <<<")
+    print(result)
+    print(query.to_dict())
     assert QueryStatus.SUCCESS == query.status

     sqlite_select_sql = f"SELECT\n *\nFROM {tmp_table}\nLIMIT {query.limit}\nOFFSET 0"

View File

@@ -91,7 +91,7 @@ class TestCore(SupersetTestCase):
         self.original_unsafe_db_setting = app.config["PREVENT_UNSAFE_DB_CONNECTIONS"]

     def tearDown(self):
-        # db.session.query(Query).delete()
+        db.session.query(Query).delete()
         app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = self.original_unsafe_db_setting
         super().tearDown()
@@ -235,7 +235,6 @@ class TestCore(SupersetTestCase):
         )
         for slc in slices:
             db.session.delete(slc)
-        print(db.session.dirty)
         db.session.commit()

     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@@ -666,7 +665,6 @@ class TestCore(SupersetTestCase):
             client_id="client_id_1",
             username="admin",
         )
-        print(resp)
         count_ds = []
         count_name = []
         for series in data["data"]:
@@ -814,7 +812,7 @@ class TestCore(SupersetTestCase):
         mock_cache.return_value = MockCache()

         rv = self.client.get("/superset/explore_json/data/valid-cache-key")
-        self.assertEqual(rv.status_code, 401)
+        self.assertEqual(rv.status_code, 403)

     def test_explore_json_data_invalid_cache_key(self):
         self.login(ADMIN_USERNAME)

View File

@@ -97,5 +97,6 @@ def create_dashboard(
     if slices is not None:
         dash.slices = slices
     db.session.add(dash)
+    db.session.commit()

     return dash

View File

@@ -592,7 +592,6 @@ class TestImportDashboardsCommand(SupersetTestCase):
         }
         command = v1.ImportDashboardsCommand(contents, overwrite=True)
         command.run()
-        command.run()

         new_num_dashboards = db.session.query(Dashboard).count()
         assert new_num_dashboards == num_dashboards + 1

View File

@@ -48,15 +48,16 @@ class TestDashboardDAO(SupersetTestCase):
         assert changed_on == DashboardDAO.get_dashboard_changed_on("world_health")

         old_changed_on = dashboard.changed_on

         # freezegun doesn't work for some reason, so we need to sleep here :(
         time.sleep(1)
         data = dashboard.data
         positions = data["position_json"]
         data.update({"positions": positions})
         original_data = copy.deepcopy(data)
         data.update({"foo": "bar"})
         DashboardDAO.set_dash_metadata(dashboard, data)
+        db.session.flush()
         db.session.commit()

         new_changed_on = DashboardDAO.get_dashboard_changed_on(dashboard)
         assert old_changed_on.replace(microsecond=0) < new_changed_on
@@ -68,7 +69,6 @@ class TestDashboardDAO(SupersetTestCase):
         )

         DashboardDAO.set_dash_metadata(dashboard, original_data)
-        db.session.flush()
         db.session.commit()

     @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")

View File

@@ -26,6 +26,7 @@ import prison
 import pytest
 import yaml
 from sqlalchemy import inspect
+from sqlalchemy.exc import SQLAlchemyError
 from sqlalchemy.orm import joinedload
 from sqlalchemy.sql import func

View File

@@ -44,6 +44,7 @@ class TestEmbeddedDashboardApi(SupersetTestCase):
         self.login(ADMIN_USERNAME)
         self.dash = db.session.query(Dashboard).filter_by(slug="births").first()
         self.embedded = EmbeddedDashboardDAO.upsert(self.dash, [])
+        db.session.flush()
         uri = f"api/v1/{self.resource_name}/{self.embedded.uuid}"
         response = self.client.get(uri)
         self.assert200(response)

View File

@@ -34,17 +34,21 @@ class TestEmbeddedDashboardDAO(SupersetTestCase):
         dash = db.session.query(Dashboard).filter_by(slug="world_health").first()
         assert not dash.embedded
         EmbeddedDashboardDAO.upsert(dash, ["test.example.com"])
+        db.session.flush()
         assert dash.embedded
         self.assertEqual(dash.embedded[0].allowed_domains, ["test.example.com"])
         original_uuid = dash.embedded[0].uuid
         self.assertIsNotNone(original_uuid)
         EmbeddedDashboardDAO.upsert(dash, [])
+        db.session.flush()
         self.assertEqual(dash.embedded[0].allowed_domains, [])
         self.assertEqual(dash.embedded[0].uuid, original_uuid)

     @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
     def test_get_by_uuid(self):
         dash = db.session.query(Dashboard).filter_by(slug="world_health").first()
-        uuid = str(EmbeddedDashboardDAO.upsert(dash, ["test.example.com"]).uuid)
+        EmbeddedDashboardDAO.upsert(dash, ["test.example.com"])
+        db.session.flush()
+        uuid = str(dash.embedded[0].uuid)
         embedded = EmbeddedDashboardDAO.find_by_id(uuid)
         self.assertIsNotNone(embedded)

View File

@@ -44,6 +44,7 @@ if TYPE_CHECKING:
 def test_get_embedded_dashboard(client: FlaskClient[Any]):  # noqa: F811
     dash = db.session.query(Dashboard).filter_by(slug="births").first()
     embedded = EmbeddedDashboardDAO.upsert(dash, [])
+    db.session.flush()
     uri = f"embedded/{embedded.uuid}"
     response = client.get(uri)
     assert response.status_code == 200
@@ -57,6 +58,7 @@ def test_get_embedded_dashboard(client: FlaskClient[Any]):  # noqa: F811
 def test_get_embedded_dashboard_referrer_not_allowed(client: FlaskClient[Any]):  # noqa: F811
     dash = db.session.query(Dashboard).filter_by(slug="births").first()
     embedded = EmbeddedDashboardDAO.upsert(dash, ["test.example.com"])
+    db.session.flush()
     uri = f"embedded/{embedded.uuid}"
     response = client.get(uri)
     assert response.status_code == 403

View File

@@ -46,22 +46,25 @@ def load_birth_names_data(
 @pytest.fixture()
 def load_birth_names_dashboard_with_slices(load_birth_names_data):
     with app.app_context():
-        _create_dashboards()
+        dash_id_to_delete, slices_ids_to_delete = _create_dashboards()
         yield
+        _cleanup(dash_id_to_delete, slices_ids_to_delete)


 @pytest.fixture(scope="module")
 def load_birth_names_dashboard_with_slices_module_scope(load_birth_names_data):
     with app.app_context():
-        _create_dashboards()
+        dash_id_to_delete, slices_ids_to_delete = _create_dashboards()
         yield
+        _cleanup(dash_id_to_delete, slices_ids_to_delete)


 @pytest.fixture(scope="class")
 def load_birth_names_dashboard_with_slices_class_scope(load_birth_names_data):
     with app.app_context():
-        _create_dashboards()
+        dash_id_to_delete, slices_ids_to_delete = _create_dashboards()
         yield
+        _cleanup(dash_id_to_delete, slices_ids_to_delete)


 def _create_dashboards():
@@ -74,7 +77,10 @@ def _create_dashboards():
     from superset.examples.birth_names import create_dashboard, create_slices

     slices, _ = create_slices(table)
-    create_dashboard(slices)
+    dash = create_dashboard(slices)
+    slices_ids_to_delete = [slice.id for slice in slices]
+    dash_id_to_delete = dash.id
+    return dash_id_to_delete, slices_ids_to_delete


 def _create_table(
@@ -91,4 +97,20 @@ def _create_table(
     _set_table_metadata(table, database)
     _add_table_metrics(table)
+    db.session.commit()

     return table
+
+
+def _cleanup(dash_id: int, slice_ids: list[int]) -> None:
+    schema = get_example_default_schema()
+
+    for datasource in db.session.query(SqlaTable).filter_by(
+        table_name="birth_names", schema=schema
+    ):
+        for col in datasource.columns + datasource.metrics:
+            db.session.delete(col)
+
+    for dash in db.session.query(Dashboard).filter_by(id=dash_id):
+        db.session.delete(dash)
+
+    for slc in db.session.query(Slice).filter(Slice.id.in_(slice_ids)):
+        db.session.delete(slc)
+
+    db.session.commit()

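Note: the restored fixtures follow pytest's setup/yield/teardown idiom, capturing ids at creation so teardown can delete exactly what was created. A tiny self-contained version of the pattern; an in-memory dict stands in for the examples database.

import pytest

FAKE_DB: dict[int, str] = {}  # stands in for the examples database


def _create_dashboards() -> int:
    dash_id = len(FAKE_DB) + 1
    FAKE_DB[dash_id] = "birth_names"
    return dash_id


def _cleanup(dash_id: int) -> None:
    FAKE_DB.pop(dash_id, None)  # teardown mirrors setup, keyed by captured id


@pytest.fixture()
def dashboard_with_slices():
    dash_id = _create_dashboards()  # setup runs before the test body
    yield dash_id                   # the test executes here
    _cleanup(dash_id)               # teardown runs afterwards


def test_dashboard_exists(dashboard_with_slices):
    assert dashboard_with_slices in FAKE_DB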
View File

@@ -61,6 +61,7 @@ def load_energy_table_with_slice(load_energy_table_data):
     with app.app_context():
         slices = _create_energy_table()
         yield slices
+        _cleanup()


 def _get_dataframe():
@@ -109,6 +110,24 @@ def _create_and_commit_energy_slice(
     return slice


+def _cleanup() -> None:
+    for slice_data in _get_energy_slices():
+        slice = (
+            db.session.query(Slice)
+            .filter_by(slice_name=slice_data["slice_title"])
+            .first()
+        )
+        db.session.delete(slice)
+
+    metric = (
+        db.session.query(SqlMetric).filter_by(metric_name="sum__value").one_or_none()
+    )
+    if metric:
+        db.session.delete(metric)
+
+    db.session.commit()
+
+
 def _get_energy_data():
     data = []
     for i in range(85):

View File

@@ -135,4 +135,7 @@ def tabbed_dashboard(app_context):
         slices=[],
     )
     db.session.add(dash)
-    yield
+    db.session.commit()
+    yield dash
+    db.session.query(Dashboard).filter_by(id=dash.id).delete()
+    db.session.commit()

View File

@@ -61,6 +61,7 @@ def load_unicode_dashboard_with_slice(load_unicode_data):
     with app.app_context():
         dash = _create_unicode_dashboard(slice_name, None)
         yield
+        _cleanup(dash, slice_name)


 @pytest.fixture()
@@ -70,6 +71,7 @@ def load_unicode_dashboard_with_position(load_unicode_data):
     with app.app_context():
         dash = _create_unicode_dashboard(slice_name, position)
         yield
+        _cleanup(dash, slice_name)


 def _get_dataframe():
@@ -95,18 +97,25 @@ def _create_unicode_dashboard(slice_title: str, position: str) -> Dashboard:
         table.fetch_metadata()

     if slice_title:
-        slice = _create_unicode_slice(table, slice_title)
+        slice = _create_and_commit_unicode_slice(table, slice_title)

     return create_dashboard("unicode-test", "Unicode Test", position, [slice])


-def _create_unicode_slice(table: SqlaTable, title: str):
-    slc = create_slice(title, "word_cloud", table, {})
-    if (
-        obj := db.session.query(Slice)
-        .filter_by(slice_name=slc.slice_name)
-        .one_or_none()
-    ):
-        db.session.delete(obj)
-    db.session.add(slc)
-    return slc
+def _create_and_commit_unicode_slice(table: SqlaTable, title: str):
+    slice = create_slice(title, "word_cloud", table, {})
+    o = db.session.query(Slice).filter_by(slice_name=slice.slice_name).one_or_none()
+    if o:
+        db.session.delete(o)
+    db.session.add(slice)
+    db.session.commit()
+    return slice
+
+
+def _cleanup(dash: Dashboard, slice_name: str) -> None:
+    db.session.delete(dash)
+    if slice_name and (
+        slice := db.session.query(Slice).filter_by(slice_name=slice_name).one_or_none()
+    ):
+        db.session.delete(slice)
+    db.session.commit()

View File

@@ -64,7 +64,6 @@ def load_world_bank_data():
         )

     yield
-
     with app.app_context():
         with get_example_database().get_sqla_engine() as engine:
             engine.execute("DROP TABLE IF EXISTS wb_health_population")
@@ -73,14 +72,15 @@ def load_world_bank_data():
 @pytest.fixture()
 def load_world_bank_dashboard_with_slices(load_world_bank_data):
     with app.app_context():
-        create_dashboard_for_loaded_data()
+        dash_id_to_delete, slices_ids_to_delete = create_dashboard_for_loaded_data()
         yield
+        _cleanup(dash_id_to_delete, slices_ids_to_delete)


 @pytest.fixture(scope="module")
 def load_world_bank_dashboard_with_slices_module_scope(load_world_bank_data):
     with app.app_context():
-        create_dashboard_for_loaded_data()
+        dash_id_to_delete, slices_ids_to_delete = create_dashboard_for_loaded_data()
         yield
+        _cleanup_reports(dash_id_to_delete, slices_ids_to_delete)
+        _cleanup(dash_id_to_delete, slices_ids_to_delete)
@@ -89,8 +89,9 @@ def load_world_bank_dashboard_with_slices_module_scope(load_world_bank_data):
 @pytest.fixture(scope="class")
 def load_world_bank_dashboard_with_slices_class_scope(load_world_bank_data):
     with app.app_context():
-        create_dashboard_for_loaded_data()
+        dash_id_to_delete, slices_ids_to_delete = create_dashboard_for_loaded_data()
         yield
+        _cleanup(dash_id_to_delete, slices_ids_to_delete)


 def create_dashboard_for_loaded_data():
@@ -102,22 +103,24 @@ def create_dashboard_for_loaded_data():
     return dash_id_to_delete, slices_ids_to_delete


-def _create_world_bank_slices(table: SqlaTable) -> None:
+def _create_world_bank_slices(table: SqlaTable) -> list[Slice]:
     from superset.examples.world_bank import create_slices

     slices = create_slices(table)
-
-    for slc in slices:
-        if (
-            obj := db.session.query(Slice)
-            .filter_by(slice_name=slc.slice_name)
-            .one_or_none()
-        ):
-            db.session.delete(obj)
-        db.session.add(slc)
+    _commit_slices(slices)
+    return slices


-def _create_world_bank_dashboard(table: SqlaTable) -> None:
+def _commit_slices(slices: list[Slice]):
+    for slice in slices:
+        o = db.session.query(Slice).filter_by(slice_name=slice.slice_name).one_or_none()
+        if o:
+            db.session.delete(o)
+        db.session.add(slice)
+    db.session.commit()
+
+
+def _create_world_bank_dashboard(table: SqlaTable) -> Dashboard:
     from superset.examples.helpers import update_slice_ids
     from superset.examples.world_bank import dashboard_positions
@@ -130,6 +133,16 @@ def _create_world_bank_dashboard(table: SqlaTable) -> None:
         "world_health", "World Bank's Data", json.dumps(pos), slices
     )
     dash.json_metadata = '{"mock_key": "mock_value"}'
+    db.session.commit()
+    return dash
+
+
+def _cleanup(dash_id: int, slices_ids: list[int]) -> None:
+    dash = db.session.query(Dashboard).filter_by(id=dash_id).first()
+    db.session.delete(dash)
+    for slice_id in slices_ids:
+        db.session.query(Slice).filter_by(id=slice_id).delete()
+    db.session.commit()


 def _cleanup_reports(dash_id: int, slices_ids: list[int]) -> None:

View File

@@ -215,8 +215,6 @@ class TestRowLevelSecurity(SupersetTestCase):
             },
         )
         self.assertEqual(rv.status_code, 422)
-        data = json.loads(rv.data.decode("utf-8"))
-        assert "Create failed" in data["message"]

     @pytest.mark.usefixtures("create_dataset")
     def test_model_view_rls_add_tables_required(self):

View File

@ -71,10 +71,8 @@ QUERY_3 = "SELECT * FROM birth_names LIMIT 10"
class TestSqlLab(SupersetTestCase): class TestSqlLab(SupersetTestCase):
"""Testings for Sql Lab""" """Testings for Sql Lab"""
@pytest.mark.usefixtures("load_birth_names_data")
def run_some_queries(self): def run_some_queries(self):
db.session.query(Query).delete() db.session.query(Query).delete()
db.session.commit()
self.run_sql(QUERY_1, client_id="client_id_1", username="admin") self.run_sql(QUERY_1, client_id="client_id_1", username="admin")
self.run_sql(QUERY_2, client_id="client_id_2", username="admin") self.run_sql(QUERY_2, client_id="client_id_2", username="admin")
self.run_sql(QUERY_3, client_id="client_id_3", username="gamma_sqllab") self.run_sql(QUERY_3, client_id="client_id_3", username="gamma_sqllab")
@ -419,6 +417,7 @@ class TestSqlLab(SupersetTestCase):
self.assertEqual(len(data["data"]), 1200) self.assertEqual(len(data["data"]), 1200)
self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.NOT_LIMITED) self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.NOT_LIMITED)
@pytest.mark.usefixtures("load_birth_names_data")
def test_query_api_filter(self) -> None: def test_query_api_filter(self) -> None:
""" """
Test query api without can_only_access_owned_queries perm added to Test query api without can_only_access_owned_queries perm added to
@ -438,6 +437,7 @@ class TestSqlLab(SupersetTestCase):
assert admin.first_name in user_queries assert admin.first_name in user_queries
assert gamma_sqllab.first_name in user_queries assert gamma_sqllab.first_name in user_queries
@pytest.mark.usefixtures("load_birth_names_data")
def test_query_api_can_access_all_queries(self) -> None: def test_query_api_can_access_all_queries(self) -> None:
""" """
Test query api with can_access_all_queries perm added to Test query api with can_access_all_queries perm added to
@ -522,6 +522,7 @@ class TestSqlLab(SupersetTestCase):
{r.get("sql") for r in self.get_json_resp(url)["result"]}, {r.get("sql") for r in self.get_json_resp(url)["result"]},
) )
@pytest.mark.usefixtures("load_birth_names_data")
def test_query_admin_can_access_all_queries(self) -> None: def test_query_admin_can_access_all_queries(self) -> None:
""" """
Test query api with all_query_access perm added to Test query api with all_query_access perm added to
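These hunks move `@pytest.mark.usefixtures("load_birth_names_data")` off the `run_some_queries` helper and onto each test method. pytest only applies `usefixtures` marks to items it collects as tests, so a mark on a plain helper method is silently ignored; each `test_*` method has to declare the fixture itself. A short sketch of the working placement (class, base class, and fixture names follow the diff; the bodies are illustrative):

    import pytest

    from tests.integration_tests.base_tests import SupersetTestCase

    class TestSqlLab(SupersetTestCase):
        def run_some_queries(self):  # helper: a mark here would be ignored
            ...

        @pytest.mark.usefixtures("load_birth_names_data")  # per collected test
        def test_query_api_filter(self) -> None:
            self.run_some_queries()
            ...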
View File
@ -187,6 +187,7 @@ class TestTagsDAO(SupersetTestCase):
TaggedObject.object_type == ObjectType.chart, TaggedObject.object_type == ObjectType.chart,
), ),
) )
.join(Tag, TaggedObject.tag_id == Tag.id)
.distinct(Slice.id) .distinct(Slice.id)
.count() .count()
) )
@ -199,6 +200,7 @@ class TestTagsDAO(SupersetTestCase):
TaggedObject.object_type == ObjectType.dashboard, TaggedObject.object_type == ObjectType.dashboard,
), ),
) )
.join(Tag, TaggedObject.tag_id == Tag.id)
.distinct(Dashboard.id) .distinct(Dashboard.id)
.count() .count()
+ num_charts + num_charts
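Both tag counts gain an explicit `.join(Tag, TaggedObject.tag_id == Tag.id)`; without the ON clause, referencing `Tag` elsewhere in the statement leaves it unjoined and produces a cartesian product, which newer SQLAlchemy versions warn about. A sketch of the chart count's shape assembled from the hunks above (the `object_id` condition in the first join is assumed, as it falls outside the visible context):

    from sqlalchemy import and_

    from superset import db
    from superset.models.slice import Slice
    from superset.tags.models import ObjectType, Tag, TaggedObject

    num_tagged_charts = (
        db.session.query(Slice)
        .join(
            TaggedObject,
            and_(
                TaggedObject.object_id == Slice.id,  # assumed condition
                TaggedObject.object_type == ObjectType.chart,
            ),
        )
        # Explicit ON clause ties TaggedObject to Tag, avoiding a cross join.
        .join(Tag, TaggedObject.tag_id == Tag.id)
        .distinct(Slice.id)
        .count()
    )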
View File
@ -29,7 +29,6 @@ def database_with_catalog(mocker: MockerFixture) -> MagicMock:
""" """
Mock a database with catalogs and schemas. Mock a database with catalogs and schemas.
""" """
mocker.patch("superset.commands.database.create.db")
mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand") mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand")
database = mocker.MagicMock() database = mocker.MagicMock()
@ -53,7 +52,6 @@ def database_without_catalog(mocker: MockerFixture) -> MagicMock:
""" """
Mock a database without catalogs. Mock a database without catalogs.
""" """
mocker.patch("superset.commands.database.create.db")
mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand") mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand")
database = mocker.MagicMock() database = mocker.MagicMock()
View File
@ -29,8 +29,6 @@ def database_with_catalog(mocker: MockerFixture) -> MagicMock:
""" """
Mock a database with catalogs and schemas. Mock a database with catalogs and schemas.
""" """
mocker.patch("superset.commands.database.update.db")
database = mocker.MagicMock() database = mocker.MagicMock()
database.database_name = "my_db" database.database_name = "my_db"
database.db_engine_spec.__name__ = "test_engine" database.db_engine_spec.__name__ = "test_engine"
@ -50,8 +48,6 @@ def database_without_catalog(mocker: MockerFixture) -> MagicMock:
""" """
Mock a database without catalogs. Mock a database without catalogs.
""" """
mocker.patch("superset.commands.database.update.db")
database = mocker.MagicMock() database = mocker.MagicMock()
database.database_name = "my_db" database.database_name = "my_db"
database.db_engine_spec.__name__ = "test_engine" database.db_engine_spec.__name__ = "test_engine"
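In both the `create` and `update` command test fixtures, the `mocker.patch("superset.commands.database.create.db")` / `...update.db` lines are dropped, so the commands run against the real `db.session` provided by the test app; only the expensive connection test stays stubbed in the create case. A sketch of the resulting fixture shape for the create module (the `@pytest.fixture` decorator is assumed from context):

    from unittest.mock import MagicMock

    import pytest
    from pytest_mock import MockerFixture

    @pytest.fixture
    def database_with_catalog(mocker: MockerFixture) -> MagicMock:
        """Mock a database with catalogs and schemas."""
        # `db` is intentionally left unpatched: session writes should hit the
        # real (test) database so transactional behavior is exercised.
        mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand")
        database = mocker.MagicMock()
        database.database_name = "my_db"
        database.db_engine_spec.__name__ = "test_engine"
        return database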
View File
@ -22,7 +22,6 @@ def test_user_favorite_tag(mocker):
from superset.daos.tag import TagDAO from superset.daos.tag import TagDAO
# Mock the behavior of TagDAO and g # Mock the behavior of TagDAO and g
mock_session = mocker.patch("superset.daos.tag.db.session")
mock_TagDAO = mocker.patch( mock_TagDAO = mocker.patch(
"superset.daos.tag.TagDAO" "superset.daos.tag.TagDAO"
) # Replace with the actual path to TagDAO ) # Replace with the actual path to TagDAO
@ -45,7 +44,6 @@ def test_remove_user_favorite_tag(mocker):
from superset.daos.tag import TagDAO from superset.daos.tag import TagDAO
# Mock the behavior of TagDAO and g # Mock the behavior of TagDAO and g
mock_session = mocker.patch("superset.daos.tag.db.session")
mock_TagDAO = mocker.patch("superset.daos.tag.TagDAO") mock_TagDAO = mocker.patch("superset.daos.tag.TagDAO")
mock_tag = mocker.MagicMock(users_favorited=[]) mock_tag = mocker.MagicMock(users_favorited=[])
mock_TagDAO.find_by_id.return_value = mock_tag mock_TagDAO.find_by_id.return_value = mock_tag
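Same cleanup here: the unused `mock_session` patches go away and the tests lean entirely on patching `TagDAO`, keeping these unit tests free of session plumbing. A condensed sketch of the remaining arrangement (the trailing invocation and assertion are illustrative):

    def test_remove_user_favorite_tag(mocker):
        from superset.daos.tag import TagDAO

        # Stub the DAO so no database access happens.
        mock_TagDAO = mocker.patch("superset.daos.tag.TagDAO")
        mock_tag = mocker.MagicMock(users_favorited=[])
        mock_TagDAO.find_by_id.return_value = mock_tag
        # ... invoke the code under test, then assert against mock_tag ...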
View File
@ -116,7 +116,7 @@ def test_post_with_uuid(
assert payload["result"]["uuid"] == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb" assert payload["result"]["uuid"] == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"
database = session.query(Database).one() database = session.query(Database).one()
assert database.uuid == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb" assert database.uuid == UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb")
def test_password_mask( def test_password_mask(
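The expected value is now wrapped in `UUID(...)`: the model's UUID column type deserializes the attribute to a `uuid.UUID`, and a `UUID` never compares equal to its string form, so comparing against the bare string would fail whenever the attribute is a real `UUID`. A quick illustration:

    from uuid import UUID

    value = UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb")
    assert value == UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb")
    assert value != "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"  # UUID != str, always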
View File
@ -22,8 +22,8 @@ from uuid import UUID
import pytest import pytest
from freezegun import freeze_time from freezegun import freeze_time
from sqlalchemy.orm import Session, sessionmaker
from superset import db
from superset.exceptions import CreateKeyValueDistributedLockFailedException from superset.exceptions import CreateKeyValueDistributedLockFailedException
from superset.key_value.types import JsonKeyValueCodec from superset.key_value.types import JsonKeyValueCodec
from superset.utils.lock import get_key, KeyValueDistributedLock from superset.utils.lock import get_key, KeyValueDistributedLock
@ -32,56 +32,51 @@ MAIN_KEY = get_key("ns", a=1, b=2)
OTHER_KEY = get_key("ns2", a=1, b=2) OTHER_KEY = get_key("ns2", a=1, b=2)
def _get_lock(key: UUID, session: Session) -> Any: def _get_lock(key: UUID) -> Any:
from superset.key_value.models import KeyValueEntry from superset.key_value.models import KeyValueEntry
entry = session.query(KeyValueEntry).filter_by(uuid=key).first() entry = db.session.query(KeyValueEntry).filter_by(uuid=key).first()
if entry is None or entry.is_expired(): if entry is None or entry.is_expired():
return None return None
return JsonKeyValueCodec().decode(entry.value) return JsonKeyValueCodec().decode(entry.value)
def _get_other_session() -> Session:
# This session is used to simulate what another worker will find in the metastore
# during the locking process.
from superset import db
bind = db.session.get_bind()
SessionMaker = sessionmaker(bind=bind)
return SessionMaker()
def test_key_value_distributed_lock_happy_path() -> None: def test_key_value_distributed_lock_happy_path() -> None:
""" """
Test successfully acquiring and returning the distributed lock. Test successfully acquiring and returning the distributed lock.
Note we use a nested transaction to ensure that the cleanup from the outer context
Note we use a nested transaction to ensure that the cleanup from the outer context
manager is correctly invoked, otherwise a partial rollback would occur leaving the
database in a fractured state.
""" """
session = _get_other_session()
with freeze_time("2021-01-01"): with freeze_time("2021-01-01"):
assert _get_lock(MAIN_KEY, session) is None assert _get_lock(MAIN_KEY) is None
with KeyValueDistributedLock("ns", a=1, b=2) as key: with KeyValueDistributedLock("ns", a=1, b=2) as key:
assert key == MAIN_KEY assert key == MAIN_KEY
assert _get_lock(key, session) is True assert _get_lock(key) is True
assert _get_lock(OTHER_KEY, session) is None assert _get_lock(OTHER_KEY) is None
with pytest.raises(CreateKeyValueDistributedLockFailedException):
with KeyValueDistributedLock("ns", a=1, b=2):
pass
assert _get_lock(MAIN_KEY, session) is None with db.session.begin_nested():
with pytest.raises(CreateKeyValueDistributedLockFailedException):
with KeyValueDistributedLock("ns", a=1, b=2):
pass
assert _get_lock(MAIN_KEY) is None
def test_key_value_distributed_lock_expired() -> None: def test_key_value_distributed_lock_expired() -> None:
""" """
Test expiration of the distributed lock Test expiration of the distributed lock
""" """
session = _get_other_session()
with freeze_time("2021-01-01T"): with freeze_time("2021-01-01"):
assert _get_lock(MAIN_KEY, session) is None assert _get_lock(MAIN_KEY) is None
with KeyValueDistributedLock("ns", a=1, b=2): with KeyValueDistributedLock("ns", a=1, b=2):
assert _get_lock(MAIN_KEY, session) is True assert _get_lock(MAIN_KEY) is True
with freeze_time("2022-01-01T"): with freeze_time("2022-01-01"):
assert _get_lock(MAIN_KEY, session) is None assert _get_lock(MAIN_KEY) is None
assert _get_lock(MAIN_KEY, session) is None assert _get_lock(MAIN_KEY) is None
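The rewritten tests read lock state through the shared `db.session` (dropping the second `sessionmaker` session) and wrap the expected failure in `db.session.begin_nested()`, so the rollback performed by the failed acquisition only unwinds a SAVEPOINT rather than the outer transaction, per the docstring note above. Reassembled with plausible nesting (`_get_lock` and the keys as defined above; the block structure is inferred, since the diff flattens indentation), the happy path reads roughly:

    import pytest
    from freezegun import freeze_time

    from superset import db
    from superset.exceptions import CreateKeyValueDistributedLockFailedException
    from superset.utils.lock import KeyValueDistributedLock

    with freeze_time("2021-01-01"):
        assert _get_lock(MAIN_KEY) is None
        with KeyValueDistributedLock("ns", a=1, b=2) as key:
            assert key == MAIN_KEY
            assert _get_lock(key) is True
            assert _get_lock(OTHER_KEY) is None

            # A second acquisition fails; begin_nested() confines its
            # rollback to a SAVEPOINT, so the outer transaction (and the
            # lock row created above) survives intact.
            with db.session.begin_nested():
                with pytest.raises(CreateKeyValueDistributedLockFailedException):
                    with KeyValueDistributedLock("ns", a=1, b=2):
                        pass

        # Leaving the outer context releases the lock.
        assert _get_lock(MAIN_KEY) is None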