Revert interim test updates

This commit is contained in:
John Bodley 2024-04-19 15:16:18 -07:00
parent acfba3da29
commit 3e4325a176
58 changed files with 239 additions and 411 deletions

View File

@ -239,8 +239,8 @@ basepython = python3.10
ignore_basepython_conflict = true
commands =
superset db upgrade
superset load_test_users
superset init
superset load-test-users
# use -s to be able to use break pointers.
# no args or tests/* can be passed as an argument to run all tests
pytest -s {posargs}

View File

@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=consider-using-transaction
from collections import defaultdict
from superset import db, security_manager

View File

@ -28,9 +28,9 @@ export SUPERSET_TESTENV=true
echo "Superset config module: $SUPERSET_CONFIG"
superset db upgrade
superset load_test_users
superset init
superset load-test-users
echo "Running tests"
pytest --durations-min=2 --cov-report= --cov=superset ./tests/integration_tests "$@"
pytest --durations-min=2 --maxfail=1 --cov-report= --cov=superset ./tests/integration_tests "$@"

View File

@ -37,15 +37,7 @@ def load_test_users() -> None:
Syncs permissions for those users/roles
"""
print(Fore.GREEN + "Loading a set of users for unit tests")
load_test_users_run()
def load_test_users_run() -> None:
"""
Loads admin, alpha, and gamma user for testing purposes
Syncs permissions for those users/roles
"""
if app.config["TESTING"]:
sm = security_manager

View File

@ -77,7 +77,7 @@ def import_chart(
if chart.id is None:
db.session.flush()
if user := get_user():
if (user := get_user()) and user not in chart.owners:
chart.owners.append(user)
return chart

View File

@ -38,8 +38,8 @@ from superset.daos.chart import ChartDAO
from superset.daos.dashboard import DashboardDAO
from superset.exceptions import SupersetSecurityException
from superset.models.slice import Slice
from superset.utils.decorators import on_error, transaction
from superset.tags.models import ObjectType
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -62,14 +62,13 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
assert self._model
# Update tags
tags = self._properties.pop("tags", None)
if tags is not None:
if (tags := self._properties.pop("tags", None)) is not None:
update_tags(ObjectType.chart, self._model.id, self._model.tags, tags)
if self._properties.get("query_context_generation") is None:
self._properties["last_saved_at"] = datetime.now()
self._properties["last_saved_by"] = g.user
return ChartDAO.update(self._model, self._properties)
def validate(self) -> None:

View File

@ -188,7 +188,7 @@ def import_dashboard(
if dashboard.id is None:
db.session.flush()
if user := get_user():
if (user := get_user()) and user not in dashboard.owners:
dashboard.owners.append(user)
return dashboard

View File

@ -19,7 +19,6 @@ from functools import partial
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.dashboard.permalink.base import BaseDashboardPermalinkCommand
from superset.commands.key_value.upsert import UpsertKeyValueCommand
from superset.daos.dashboard import DashboardDAO
@ -78,6 +77,7 @@ class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
codec=self.codec,
).run()
assert key.id # for type checks
return encode_permalink_key(key=key.id, salt=self.salt)
def validate(self) -> None:
pass

View File

@ -53,19 +53,16 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand):
assert self._model
# Update tags
tags = self._properties.pop("tags", None)
if tags is not None:
update_tags(
ObjectType.dashboard, self._model.id, self._model.tags, tags
)
if (tags := self._properties.pop("tags", None)) is not None:
update_tags(ObjectType.dashboard, self._model.id, self._model.tags, tags)
dashboard = DashboardDAO.update(self._model, self._properties, commit=False)
dashboard = DashboardDAO.update(self._model, self._properties)
if self._properties.get("json_metadata"):
DashboardDAO.set_dash_metadata(
dashboard,
data=json.loads(self._properties.get("json_metadata", "{}")),
)
return dashboard
def validate(self) -> None:

View File

@ -40,7 +40,7 @@ from superset.commands.database.ssh_tunnel.exceptions import (
)
from superset.commands.database.test_connection import TestConnectionDatabaseCommand
from superset.daos.database import DatabaseDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.exceptions import SupersetErrorsException
from superset.extensions import event_logger, security_manager
from superset.models.core import Database

View File

@ -1,194 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Any, Optional, TypedDict
import pandas as pd
from flask_babel import lazy_gettext as _
from superset import db
from superset.commands.base import BaseCommand
from superset.commands.database.exceptions import (
DatabaseNotFoundError,
DatabaseSchemaUploadNotAllowed,
DatabaseUploadFailed,
DatabaseUploadSaveMetadataFailed,
)
from superset.connectors.sqla.models import SqlaTable
from superset.daos.database import DatabaseDAO
from superset.models.core import Database
from superset.sql_parse import Table
from superset.utils.core import get_user
from superset.utils.decorators import on_error, transaction
from superset.views.database.validators import schema_allows_file_upload
logger = logging.getLogger(__name__)
READ_CSV_CHUNK_SIZE = 1000
class CSVImportOptions(TypedDict, total=False):
schema: str
delimiter: str
already_exists: str
column_data_types: dict[str, str]
column_dates: list[str]
column_labels: str
columns_read: list[str]
dataframe_index: str
day_first: bool
decimal_character: str
header_row: int
index_column: str
null_values: list[str]
overwrite_duplicates: bool
rows_to_read: int
skip_blank_lines: bool
skip_initial_space: bool
skip_rows: int
class CSVImportCommand(BaseCommand):
def __init__(
self,
model_id: int,
table_name: str,
file: Any,
options: CSVImportOptions,
) -> None:
self._model_id = model_id
self._model: Optional[Database] = None
self._table_name = table_name
self._schema = options.get("schema")
self._file = file
self._options = options
def _read_csv(self) -> pd.DataFrame:
"""
Read CSV file into a DataFrame
:return: pandas DataFrame
:throws DatabaseUploadFailed: if there is an error reading the CSV file
"""
try:
return pd.concat(
pd.read_csv(
chunksize=READ_CSV_CHUNK_SIZE,
encoding="utf-8",
filepath_or_buffer=self._file,
header=self._options.get("header_row", 0),
index_col=self._options.get("index_column"),
dayfirst=self._options.get("day_first", False),
iterator=True,
keep_default_na=not self._options.get("null_values"),
usecols=self._options.get("columns_read")
if self._options.get("columns_read") # None if an empty list
else None,
na_values=self._options.get("null_values")
if self._options.get("null_values") # None if an empty list
else None,
nrows=self._options.get("rows_to_read"),
parse_dates=self._options.get("column_dates"),
sep=self._options.get("delimiter", ","),
skip_blank_lines=self._options.get("skip_blank_lines", False),
skipinitialspace=self._options.get("skip_initial_space", False),
skiprows=self._options.get("skip_rows", 0),
dtype=self._options.get("column_data_types")
if self._options.get("column_data_types")
else None,
)
)
except (
pd.errors.ParserError,
pd.errors.EmptyDataError,
UnicodeDecodeError,
ValueError,
) as ex:
raise DatabaseUploadFailed(
message=_("Parsing error: %(error)s", error=str(ex))
) from ex
except Exception as ex:
raise DatabaseUploadFailed(_("Error reading CSV file")) from ex
def _dataframe_to_database(self, df: pd.DataFrame, database: Database) -> None:
"""
Upload DataFrame to database
:param df:
:throws DatabaseUploadFailed: if there is an error uploading the DataFrame
"""
try:
csv_table = Table(table=self._table_name, schema=self._schema)
database.db_engine_spec.df_to_sql(
database,
csv_table,
df,
to_sql_kwargs={
"chunksize": READ_CSV_CHUNK_SIZE,
"if_exists": self._options.get("already_exists", "fail"),
"index": self._options.get("index_column"),
"index_label": self._options.get("column_labels"),
},
)
except ValueError as ex:
raise DatabaseUploadFailed(
message=_(
"Table already exists. You can change your "
"'if table already exists' strategy to append or "
"replace or provide a different Table Name to use."
)
) from ex
except Exception as ex:
raise DatabaseUploadFailed(exception=ex) from ex
@transaction(on_error=partial(on_error, reraise=DatabaseUploadSaveMetadataFailed))
def run(self) -> None:
self.validate()
if not self._model:
return
df = self._read_csv()
self._dataframe_to_database(df, self._model)
sqla_table = (
db.session.query(SqlaTable)
.filter_by(
table_name=self._table_name,
schema=self._schema,
database_id=self._model_id,
)
.one_or_none()
)
if not sqla_table:
sqla_table = SqlaTable(
table_name=self._table_name,
database=self._model,
database_id=self._model_id,
owners=[get_user()],
schema=self._schema,
)
db.session.add(sqla_table)
sqla_table.fetch_metadata()
def validate(self) -> None:
self._model = DatabaseDAO.find_by_id(self._model_id)
if not self._model:
raise DatabaseNotFoundError()
if not schema_allows_file_upload(self._model, self._schema):
raise DatabaseSchemaUploadNotAllowed()

View File

@ -41,7 +41,6 @@ from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
from superset.daos.database import DatabaseDAO
from superset.daos.dataset import DatasetDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.extensions import db
from superset.models.core import Database
from superset.utils.decorators import on_error, transaction
@ -78,11 +77,7 @@ class UpdateDatabaseCommand(BaseCommand):
original_database_name = self._model.database_name
try:
database = DatabaseDAO.update(
self._model,
self._properties,
commit=False,
)
database = DatabaseDAO.update(self._model, self._properties)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = self._handle_ssh_tunnel(database)
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
@ -100,7 +95,6 @@ class UpdateDatabaseCommand(BaseCommand):
return None
if not is_feature_enabled("SSH_TUNNELING"):
db.session.rollback()
raise SSHTunnelingNotEnabledError()
current_ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
@ -130,14 +124,11 @@ class UpdateDatabaseCommand(BaseCommand):
This method captures a generic exception, since errors could potentially come
from any of the 50+ database drivers we support.
"""
try:
return database.get_all_catalog_names(
force=True,
ssh_tunnel=ssh_tunnel,
)
except Exception as ex:
db.session.rollback()
raise DatabaseConnectionFailedError() from ex
return database.get_all_catalog_names(
force=True,
ssh_tunnel=ssh_tunnel,
)
def _get_schema_names(
self,
@ -151,15 +142,12 @@ class UpdateDatabaseCommand(BaseCommand):
This method captures a generic exception, since errors could potentially come
from any of the 50+ database drivers we support.
"""
try:
return database.get_all_schema_names(
force=True,
catalog=catalog,
ssh_tunnel=ssh_tunnel,
)
except Exception as ex:
db.session.rollback()
raise DatabaseConnectionFailedError() from ex
return database.get_all_schema_names(
force=True,
catalog=catalog,
ssh_tunnel=ssh_tunnel,
)
def _refresh_catalogs(
self,
@ -224,8 +212,6 @@ class UpdateDatabaseCommand(BaseCommand):
schemas,
)
db.session.commit()
def _refresh_schemas(
self,
database: Database,

View File

@ -43,7 +43,7 @@ class DeleteDatasetColumnCommand(BaseCommand):
def run(self) -> None:
self.validate()
assert self._model
DatasetColumnDAO.delete(self._model)
DatasetColumnDAO.delete([self._model])
def validate(self) -> None:
# Validate/populate model exists

View File

@ -32,7 +32,7 @@ from superset.commands.dataset.exceptions import (
)
from superset.daos.dataset import DatasetDAO
from superset.exceptions import SupersetSecurityException
from superset.extensions import db, security_manager
from superset.extensions import security_manager
from superset.sql_parse import Table
from superset.utils.decorators import on_error, transaction

View File

@ -178,7 +178,7 @@ def import_dataset(
if data_uri and (not table_exists or force_data):
load_data(data_uri, dataset, dataset.database)
if user := get_user():
if (user := get_user()) and user not in dataset.owners:
dataset.owners.append(user)
return dataset

View File

@ -43,7 +43,7 @@ class DeleteDatasetMetricCommand(BaseCommand):
def run(self) -> None:
self.validate()
assert self._model
DatasetMetricDAO.delete(self._model)
DatasetMetricDAO.delete([self._model])
def validate(self) -> None:
# Validate/populate model exists

View File

@ -21,6 +21,7 @@ from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
from sqlalchemy.exc import SQLAlchemyError
from superset import security_manager
from superset.commands.base import BaseCommand, UpdateMixin
@ -60,7 +61,16 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
self.override_columns = override_columns
self._properties["override_columns"] = override_columns
@transaction(on_error=partial(on_error, reraise=DatasetUpdateFailedError))
@transaction(
on_error=partial(
on_error,
catches=(
SQLAlchemyError,
ValueError,
),
reraise=DatasetUpdateFailedError,
)
)
def run(self) -> Model:
self.validate()
assert self._model

View File

@ -20,7 +20,6 @@ from typing import Any, Optional
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.explore.permalink.base import BaseExplorePermalinkCommand
from superset.commands.key_value.create import CreateKeyValueCommand
from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError
@ -74,27 +73,7 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
key = command.run()
if key.id is None:
raise ExplorePermalinkCreateFailedError("Unexpected missing key id")
db.session.commit()
return encode_permalink_key(key=key.id, salt=self.salt)
d_id, d_type = self.datasource.split("__")
datasource_id = int(d_id)
datasource_type = DatasourceType(d_type)
check_chart_access(datasource_id, self.chart_id, datasource_type)
value = {
"chartId": self.chart_id,
"datasourceId": datasource_id,
"datasourceType": datasource_type.value,
"datasource": self.datasource,
"state": self.state,
}
command = CreateKeyValueCommand(
resource=self.resource,
value=value,
codec=self.codec,
)
key = command.run()
return encode_permalink_key(key=key.id, salt=self.salt)
>>>>>>> c01dacb71a (chore(dao): Use nested session for operations)
def validate(self) -> None:
pass

View File

@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from functools import partial
from typing import Any, Optional
from marshmallow import Schema
@ -44,7 +45,7 @@ from superset.datasets.schemas import ImportV1DatasetSchema
from superset.migrations.shared.native_filters import migrate_dashboard
from superset.models.dashboard import dashboard_slices
from superset.queries.saved_queries.schemas import ImportV1SavedQuerySchema
from superset.utils.decorators import transaction
from superset.utils.decorators import on_error, transaction
class ImportAssetsCommand(BaseCommand):
@ -154,14 +155,16 @@ class ImportAssetsCommand(BaseCommand):
if chart.viz_type == "filter_box":
db.session.delete(chart)
@transaction()
@transaction(
on_error=partial(
on_error,
catches=(Exception,),
reraise=ImportFailedError,
)
)
def run(self) -> None:
self.validate()
try:
self._import(self._configs)
except Exception as ex:
raise ImportFailedError() from ex
self._import(self._configs)
def validate(self) -> None:
exceptions: list[ValidationError] = []

View File

@ -53,9 +53,11 @@ class DeleteKeyValueCommand(BaseCommand):
pass
def delete(self) -> bool:
filter_ = get_filter(self.resource, self.key)
if entry := db.session.query(KeyValueEntry).filter_by(**filter_).first():
if (
entry := db.session.query(KeyValueEntry)
.filter_by(**get_filter(self.resource, self.key))
.first()
):
db.session.delete(entry)
db.session.flush()
return True
return False

View File

@ -60,4 +60,3 @@ class DeleteExpiredKeyValueCommand(BaseCommand):
)
.delete()
)
db.session.flush()

View File

@ -21,6 +21,8 @@ from functools import partial
from typing import Any, Optional, Union
from uuid import UUID
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.base import BaseCommand
from superset.commands.key_value.create import CreateKeyValueCommand
@ -71,7 +73,7 @@ class UpsertKeyValueCommand(BaseCommand):
@transaction(
on_error=partial(
on_error,
catches=(KeyValueCreateFailedError,),
catches=(KeyValueCreateFailedError, SQLAlchemyError),
reraise=KeyValueUpsertFailedError,
),
)
@ -82,16 +84,15 @@ class UpsertKeyValueCommand(BaseCommand):
pass
def upsert(self) -> Key:
filter_ = get_filter(self.resource, self.key)
entry: KeyValueEntry = (
db.session.query(KeyValueEntry).filter_by(**filter_).first()
)
if entry:
if (
entry := db.session.query(KeyValueEntry)
.filter_by(**get_filter(self.resource, self.key))
.first()
):
entry.value = self.codec.encode(self.value)
entry.expires_on = self.expires_on
entry.changed_on = datetime.now()
entry.changed_by_fk = get_user_id()
db.session.flush()
return Key(entry.id, entry.uuid)
return CreateKeyValueCommand(

View File

@ -137,6 +137,7 @@ class BaseReportState:
uuid=self._execution_id,
)
db.session.add(log)
db.session.commit() # pylint: disable=consider-using-transaction
def _get_url(
self,

View File

@ -49,7 +49,6 @@ class AsyncPruneReportScheduleLogCommand(BaseCommand):
report_schedule,
from_date,
)
db.session.commit()
logger.info(
"Deleted %s logs for report schedule id: %s",
str(row_count),

View File

@ -383,5 +383,4 @@ class TagDAO(BaseDAO[Tag]):
object_id,
tag.name,
)
db.session.add_all(tagged_objects)

View File

@ -32,7 +32,7 @@ from marshmallow import ValidationError
from werkzeug.wrappers import Response as WerkzeugResponse
from werkzeug.wsgi import FileWrapper
from superset import is_feature_enabled, thumbnail_cache
from superset import db, is_feature_enabled, thumbnail_cache
from superset.charts.schemas import ChartEntityResponseSchema
from superset.commands.dashboard.create import CreateDashboardCommand
from superset.commands.dashboard.delete import DeleteDashboardCommand
@ -1314,7 +1314,13 @@ class DashboardRestApi(BaseSupersetModelRestApi):
"""
try:
body = self.embedded_config_schema.load(request.json)
embedded = EmbeddedDashboardDAO.upsert(dashboard, body["allowed_domains"])
with db.session.begin_nested():
embedded = EmbeddedDashboardDAO.upsert(
dashboard,
body["allowed_domains"],
)
result = self.embedded_response_schema.dump(embedded)
return self.response(200, result=result)
except ValidationError as error:

View File

@ -22,7 +22,6 @@ from uuid import UUID, uuid3
from flask import current_app, Flask, has_app_context
from flask_caching import BaseCache
from superset import db
from superset.key_value.exceptions import KeyValueCreateFailedError
from superset.key_value.types import (
KeyValueCodec,
@ -95,7 +94,6 @@ class SupersetMetastoreCache(BaseCache):
codec=self.codec,
expires_on=self._get_expiry(timeout),
).run()
db.session.commit()
return True
def add(self, key: str, value: Any, timeout: Optional[int] = None) -> bool:
@ -111,7 +109,6 @@ class SupersetMetastoreCache(BaseCache):
key=self.get_key(key),
expires_on=self._get_expiry(timeout),
).run()
db.session.commit()
return True
except KeyValueCreateFailedError:
return False
@ -136,6 +133,4 @@ class SupersetMetastoreCache(BaseCache):
# pylint: disable=import-outside-toplevel
from superset.commands.key_value.delete import DeleteKeyValueCommand
ret = DeleteKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run()
db.session.commit()
return ret
return DeleteKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run()

View File

@ -18,7 +18,6 @@
from typing import Any, Optional
from uuid import uuid3
from superset import db
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource, SharedKey
from superset.key_value.utils import get_uuid_namespace, random_key
@ -46,7 +45,6 @@ def set_shared_value(key: SharedKey, value: Any) -> None:
key=uuid_key,
codec=CODEC,
).run()
db.session.commit()
def get_permalink_salt(key: SharedKey) -> str:

View File

@ -334,22 +334,26 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
pyjwt_for_guest_token = _jwt_global_obj
@property
def get_session(self) -> Session:
def get_session2(self) -> Session:
"""
Flask-AppBuilder (FAB) which has a tendency to explicitly commit, thus violating
our definition of "unit of work".
By providing a monkey patched transaction for the FAB session ensures that any
explicit commit merely flushes and any rollback is a no-op.
By providing a monkey patched transaction for the FAB session, within the
confines of a nested session, ensures that any explicit commit merely flushes
and any rollback is a no-op.
"""
# pylint: disable=import-outside-toplevel
from superset import db
with db.session.begin_nested() as transaction:
transaction.session.commit = transaction.session.flush
transaction.session.rollback = lambda: None
return transaction.session
if db.session._proxied._nested_transaction: # pylint: disable=protected-access
with db.session.begin_nested() as transaction:
transaction.session.commit = transaction.session.flush
transaction.session.rollback = lambda: None
return transaction.session
return db.session
def create_login_manager(self, app: Flask) -> LoginManager:
lm = super().create_login_manager(app)
@ -1035,7 +1039,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
== None, # noqa: E711
)
)
self.get_session.flush()
if deleted_count := pvms.delete():
logger.info("Deleted %i faulty permissions", deleted_count)
@ -1065,7 +1068,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
)
self.create_missing_perms()
self.get_session.flush()
self.clean_perms()
def _get_all_pvms(self) -> list[PermissionView]:
@ -1138,7 +1140,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
):
role_from_permissions.append(permission_view)
role_to.permissions = role_from_permissions
self.get_session.flush()
def set_role(
self,
@ -1159,7 +1160,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
permission_view for permission_view in pvms if pvm_check(permission_view)
]
role.permissions = role_pvms
self.get_session.flush()
def _is_admin_only(self, pvm: PermissionView) -> bool:
"""

View File

@ -140,6 +140,7 @@ def get_tag(
if tag is None:
tag = Tag(name=escape(tag_name), type=type_)
session.add(tag)
session.commit()
return tag

View File

@ -59,6 +59,7 @@ def get_or_create_db(
if database and database.sqlalchemy_uri_decrypted != sqlalchemy_uri:
database.set_sqlalchemy_uri(sqlalchemy_uri)
db.session.flush()
return database
@ -78,3 +79,4 @@ def remove_database(database: Database) -> None:
from superset import db
db.session.delete(database)
db.session.flush()

View File

@ -222,7 +222,7 @@ def on_error(
:param ex: The source exception
:param catches: The exception types the handler catches
:param reraise: The exception type the handler reraises after catching
:param reraise: The exception type the handler raises after catching
:raises Exception: If the exception is not swallowed
"""
@ -252,13 +252,16 @@ def transaction( # pylint: disable=redefined-outer-name
from superset import db # pylint: disable=import-outside-toplevel
try:
with db.session.begin_nested():
return func(*args, **kwargs)
result = func(*args, **kwargs)
db.session.commit()
return result
except Exception as ex:
db.session.rollback()
if on_error:
return on_error(ex)
raise ex
raise
return wrapped

View File

@ -24,7 +24,6 @@ from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import Any, cast, TypeVar, Union
from superset import db
from superset.exceptions import CreateKeyValueDistributedLockFailedException
from superset.key_value.exceptions import KeyValueCreateFailedError
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
@ -72,7 +71,6 @@ def KeyValueDistributedLock( # pylint: disable=invalid-name
store.
:param namespace: The namespace for which the lock is to be acquired.
:type namespace: str
:param kwargs: Additional keyword arguments.
:yields: A unique identifier (UUID) for the acquired lock (the KV key).
:raises CreateKeyValueDistributedLockFailedException: If the lock is taken.
@ -93,12 +91,10 @@ def KeyValueDistributedLock( # pylint: disable=invalid-name
value=True,
expires_on=datetime.now() + LOCK_EXPIRATION,
).run()
db.session.commit()
yield key
DeleteKeyValueCommand(resource=KeyValueResource.LOCK, key=key).run()
db.session.commit()
logger.debug("Removed lock on namespace %s for key %s", namespace, key)
except KeyValueCreateFailedError as ex:
raise CreateKeyValueDistributedLockFailedException(

View File

@ -403,6 +403,7 @@ class DBEventLogger(AbstractEventLogger):
logs.append(log)
try:
db.session.bulk_save_objects(logs)
db.session.commit() # pylint: disable=consider-using-transaction
except SQLAlchemyError as ex:
logging.error("DBEventLogger failed to log event(s)")
logging.exception(ex)

View File

@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=consider-using-transaction
from typing import Any, TYPE_CHECKING
from typing import TYPE_CHECKING
from flask_appbuilder import expose
from flask_appbuilder.models.sqla.interface import SQLAInterface

View File

@ -349,15 +349,7 @@ class SupersetTestCase(TestCase):
self.grant_role_access_to_table(table, role_name)
def grant_role_access_to_table(self, table, role_name):
print(">>> grant_role_access_to_table <<<")
print(role_name)
print(db.session.get_bind())
from flask_appbuilder.security.sqla.models import Role
print(list(db.session.query(Role).all()))
role = security_manager.find_role(role_name)
print(role)
perms = db.session.query(ab_models.PermissionView).all()
for perm in perms:
if (

View File

@ -332,6 +332,9 @@ def test_run_async_cta_query_with_lower_limit(test_client, ctas_method):
tmp_table=tmp_table,
)
query = wait_for_success(result)
print(">>> test_run_async_cta_query_with_lower_limit <<<")
print(result)
print(query.to_dict())
assert QueryStatus.SUCCESS == query.status
sqlite_select_sql = f"SELECT\n *\nFROM {tmp_table}\nLIMIT {query.limit}\nOFFSET 0"

View File

@ -91,7 +91,7 @@ class TestCore(SupersetTestCase):
self.original_unsafe_db_setting = app.config["PREVENT_UNSAFE_DB_CONNECTIONS"]
def tearDown(self):
# db.session.query(Query).delete()
db.session.query(Query).delete()
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = self.original_unsafe_db_setting
super().tearDown()
@ -235,7 +235,6 @@ class TestCore(SupersetTestCase):
)
for slc in slices:
db.session.delete(slc)
print(db.session.dirty)
db.session.commit()
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@ -666,7 +665,6 @@ class TestCore(SupersetTestCase):
client_id="client_id_1",
username="admin",
)
print(resp)
count_ds = []
count_name = []
for series in data["data"]:
@ -814,7 +812,7 @@ class TestCore(SupersetTestCase):
mock_cache.return_value = MockCache()
rv = self.client.get("/superset/explore_json/data/valid-cache-key")
self.assertEqual(rv.status_code, 401)
self.assertEqual(rv.status_code, 403)
def test_explore_json_data_invalid_cache_key(self):
self.login(ADMIN_USERNAME)

View File

@ -97,5 +97,6 @@ def create_dashboard(
if slices is not None:
dash.slices = slices
db.session.add(dash)
db.session.commit()
return dash

View File

@ -592,7 +592,6 @@ class TestImportDashboardsCommand(SupersetTestCase):
}
command = v1.ImportDashboardsCommand(contents, overwrite=True)
command.run()
command.run()
new_num_dashboards = db.session.query(Dashboard).count()
assert new_num_dashboards == num_dashboards + 1

View File

@ -48,15 +48,16 @@ class TestDashboardDAO(SupersetTestCase):
assert changed_on == DashboardDAO.get_dashboard_changed_on("world_health")
old_changed_on = dashboard.changed_on
# freezegun doesn't work for some reason, so we need to sleep here :(
time.sleep(1)
data = dashboard.data
positions = data["position_json"]
data.update({"positions": positions})
original_data = copy.deepcopy(data)
data.update({"foo": "bar"})
DashboardDAO.set_dash_metadata(dashboard, data)
db.session.flush()
db.session.commit()
new_changed_on = DashboardDAO.get_dashboard_changed_on(dashboard)
assert old_changed_on.replace(microsecond=0) < new_changed_on
@ -68,7 +69,6 @@ class TestDashboardDAO(SupersetTestCase):
)
DashboardDAO.set_dash_metadata(dashboard, original_data)
db.session.flush()
db.session.commit()
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")

View File

@ -26,6 +26,7 @@ import prison
import pytest
import yaml
from sqlalchemy import inspect
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import joinedload
from sqlalchemy.sql import func

View File

@ -44,6 +44,7 @@ class TestEmbeddedDashboardApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
self.dash = db.session.query(Dashboard).filter_by(slug="births").first()
self.embedded = EmbeddedDashboardDAO.upsert(self.dash, [])
db.session.flush()
uri = f"api/v1/{self.resource_name}/{self.embedded.uuid}"
response = self.client.get(uri)
self.assert200(response)

View File

@ -34,17 +34,21 @@ class TestEmbeddedDashboardDAO(SupersetTestCase):
dash = db.session.query(Dashboard).filter_by(slug="world_health").first()
assert not dash.embedded
EmbeddedDashboardDAO.upsert(dash, ["test.example.com"])
db.session.flush()
assert dash.embedded
self.assertEqual(dash.embedded[0].allowed_domains, ["test.example.com"])
original_uuid = dash.embedded[0].uuid
self.assertIsNotNone(original_uuid)
EmbeddedDashboardDAO.upsert(dash, [])
db.session.flush()
self.assertEqual(dash.embedded[0].allowed_domains, [])
self.assertEqual(dash.embedded[0].uuid, original_uuid)
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
def test_get_by_uuid(self):
dash = db.session.query(Dashboard).filter_by(slug="world_health").first()
uuid = str(EmbeddedDashboardDAO.upsert(dash, ["test.example.com"]).uuid)
EmbeddedDashboardDAO.upsert(dash, ["test.example.com"])
db.session.flush()
uuid = str(dash.embedded[0].uuid)
embedded = EmbeddedDashboardDAO.find_by_id(uuid)
self.assertIsNotNone(embedded)

View File

@ -44,6 +44,7 @@ if TYPE_CHECKING:
def test_get_embedded_dashboard(client: FlaskClient[Any]): # noqa: F811
dash = db.session.query(Dashboard).filter_by(slug="births").first()
embedded = EmbeddedDashboardDAO.upsert(dash, [])
db.session.flush()
uri = f"embedded/{embedded.uuid}"
response = client.get(uri)
assert response.status_code == 200
@ -57,6 +58,7 @@ def test_get_embedded_dashboard(client: FlaskClient[Any]): # noqa: F811
def test_get_embedded_dashboard_referrer_not_allowed(client: FlaskClient[Any]): # noqa: F811
dash = db.session.query(Dashboard).filter_by(slug="births").first()
embedded = EmbeddedDashboardDAO.upsert(dash, ["test.example.com"])
db.session.flush()
uri = f"embedded/{embedded.uuid}"
response = client.get(uri)
assert response.status_code == 403

View File

@ -46,22 +46,25 @@ def load_birth_names_data(
@pytest.fixture()
def load_birth_names_dashboard_with_slices(load_birth_names_data):
with app.app_context():
_create_dashboards()
dash_id_to_delete, slices_ids_to_delete = _create_dashboards()
yield
_cleanup(dash_id_to_delete, slices_ids_to_delete)
@pytest.fixture(scope="module")
def load_birth_names_dashboard_with_slices_module_scope(load_birth_names_data):
with app.app_context():
_create_dashboards()
dash_id_to_delete, slices_ids_to_delete = _create_dashboards()
yield
_cleanup(dash_id_to_delete, slices_ids_to_delete)
@pytest.fixture(scope="class")
def load_birth_names_dashboard_with_slices_class_scope(load_birth_names_data):
with app.app_context():
_create_dashboards()
dash_id_to_delete, slices_ids_to_delete = _create_dashboards()
yield
_cleanup(dash_id_to_delete, slices_ids_to_delete)
def _create_dashboards():
@ -74,7 +77,10 @@ def _create_dashboards():
from superset.examples.birth_names import create_dashboard, create_slices
slices, _ = create_slices(table)
create_dashboard(slices)
dash = create_dashboard(slices)
slices_ids_to_delete = [slice.id for slice in slices]
dash_id_to_delete = dash.id
return dash_id_to_delete, slices_ids_to_delete
def _create_table(
@ -91,4 +97,20 @@ def _create_table(
_set_table_metadata(table, database)
_add_table_metrics(table)
db.session.commit()
return table
def _cleanup(dash_id: int, slice_ids: list[int]) -> None:
schema = get_example_default_schema()
for datasource in db.session.query(SqlaTable).filter_by(
table_name="birth_names", schema=schema
):
for col in datasource.columns + datasource.metrics:
db.session.delete(col)
for dash in db.session.query(Dashboard).filter_by(id=dash_id):
db.session.delete(dash)
for slc in db.session.query(Slice).filter(Slice.id.in_(slice_ids)):
db.session.delete(slc)
db.session.commit()

View File

@ -61,6 +61,7 @@ def load_energy_table_with_slice(load_energy_table_data):
with app.app_context():
slices = _create_energy_table()
yield slices
_cleanup()
def _get_dataframe():
@ -109,6 +110,24 @@ def _create_and_commit_energy_slice(
return slice
def _cleanup() -> None:
for slice_data in _get_energy_slices():
slice = (
db.session.query(Slice)
.filter_by(slice_name=slice_data["slice_title"])
.first()
)
db.session.delete(slice)
metric = (
db.session.query(SqlMetric).filter_by(metric_name="sum__value").one_or_none()
)
if metric:
db.session.delete(metric)
db.session.commit()
def _get_energy_data():
data = []
for i in range(85):

View File

@ -135,4 +135,7 @@ def tabbed_dashboard(app_context):
slices=[],
)
db.session.add(dash)
yield
db.session.commit()
yield dash
db.session.query(Dashboard).filter_by(id=dash.id).delete()
db.session.commit()

View File

@ -61,6 +61,7 @@ def load_unicode_dashboard_with_slice(load_unicode_data):
with app.app_context():
dash = _create_unicode_dashboard(slice_name, None)
yield
_cleanup(dash, slice_name)
@pytest.fixture()
@ -70,6 +71,7 @@ def load_unicode_dashboard_with_position(load_unicode_data):
with app.app_context():
dash = _create_unicode_dashboard(slice_name, position)
yield
_cleanup(dash, slice_name)
def _get_dataframe():
@ -95,18 +97,25 @@ def _create_unicode_dashboard(slice_title: str, position: str) -> Dashboard:
table.fetch_metadata()
if slice_title:
slice = _create_unicode_slice(table, slice_title)
slice = _create_and_commit_unicode_slice(table, slice_title)
return create_dashboard("unicode-test", "Unicode Test", position, [slice])
def _create_unicode_slice(table: SqlaTable, title: str):
slc = create_slice(title, "word_cloud", table, {})
if (
obj := db.session.query(Slice)
.filter_by(slice_name=slc.slice_name)
.one_or_none()
def _create_and_commit_unicode_slice(table: SqlaTable, title: str):
slice = create_slice(title, "word_cloud", table, {})
o = db.session.query(Slice).filter_by(slice_name=slice.slice_name).one_or_none()
if o:
db.session.delete(o)
db.session.add(slice)
db.session.commit()
return slice
def _cleanup(dash: Dashboard, slice_name: str) -> None:
db.session.delete(dash)
if slice_name and (
slice := db.session.query(Slice).filter_by(slice_name=slice_name).one_or_none()
):
db.session.delete(obj)
db.session.add(slc)
return slc
db.session.delete(slice)
db.session.commit()

View File

@ -64,7 +64,6 @@ def load_world_bank_data():
)
yield
with app.app_context():
with get_example_database().get_sqla_engine() as engine:
engine.execute("DROP TABLE IF EXISTS wb_health_population")
@ -73,14 +72,15 @@ def load_world_bank_data():
@pytest.fixture()
def load_world_bank_dashboard_with_slices(load_world_bank_data):
with app.app_context():
create_dashboard_for_loaded_data()
dash_id_to_delete, slices_ids_to_delete = create_dashboard_for_loaded_data()
yield
_cleanup(dash_id_to_delete, slices_ids_to_delete)
@pytest.fixture(scope="module")
def load_world_bank_dashboard_with_slices_module_scope(load_world_bank_data):
with app.app_context():
create_dashboard_for_loaded_data()
dash_id_to_delete, slices_ids_to_delete = create_dashboard_for_loaded_data()
yield
_cleanup_reports(dash_id_to_delete, slices_ids_to_delete)
_cleanup(dash_id_to_delete, slices_ids_to_delete)
@ -89,8 +89,9 @@ def load_world_bank_dashboard_with_slices_module_scope(load_world_bank_data):
@pytest.fixture(scope="class")
def load_world_bank_dashboard_with_slices_class_scope(load_world_bank_data):
with app.app_context():
create_dashboard_for_loaded_data()
dash_id_to_delete, slices_ids_to_delete = create_dashboard_for_loaded_data()
yield
_cleanup(dash_id_to_delete, slices_ids_to_delete)
def create_dashboard_for_loaded_data():
@ -102,22 +103,24 @@ def create_dashboard_for_loaded_data():
return dash_id_to_delete, slices_ids_to_delete
def _create_world_bank_slices(table: SqlaTable) -> None:
def _create_world_bank_slices(table: SqlaTable) -> list[Slice]:
from superset.examples.world_bank import create_slices
slices = create_slices(table)
for slc in slices:
if (
obj := db.session.query(Slice)
.filter_by(slice_name=slc.slice_name)
.one_or_none()
):
db.session.delete(obj)
db.session.add(slc)
_commit_slices(slices)
return slices
def _create_world_bank_dashboard(table: SqlaTable) -> None:
def _commit_slices(slices: list[Slice]):
for slice in slices:
o = db.session.query(Slice).filter_by(slice_name=slice.slice_name).one_or_none()
if o:
db.session.delete(o)
db.session.add(slice)
db.session.commit()
def _create_world_bank_dashboard(table: SqlaTable) -> Dashboard:
from superset.examples.helpers import update_slice_ids
from superset.examples.world_bank import dashboard_positions
@ -130,6 +133,16 @@ def _create_world_bank_dashboard(table: SqlaTable) -> None:
"world_health", "World Bank's Data", json.dumps(pos), slices
)
dash.json_metadata = '{"mock_key": "mock_value"}'
db.session.commit()
return dash
def _cleanup(dash_id: int, slices_ids: list[int]) -> None:
dash = db.session.query(Dashboard).filter_by(id=dash_id).first()
db.session.delete(dash)
for slice_id in slices_ids:
db.session.query(Slice).filter_by(id=slice_id).delete()
db.session.commit()
def _cleanup_reports(dash_id: int, slices_ids: list[int]) -> None:

View File

@ -215,8 +215,6 @@ class TestRowLevelSecurity(SupersetTestCase):
},
)
self.assertEqual(rv.status_code, 422)
data = json.loads(rv.data.decode("utf-8"))
assert "Create failed" in data["message"]
@pytest.mark.usefixtures("create_dataset")
def test_model_view_rls_add_tables_required(self):

View File

@ -71,10 +71,8 @@ QUERY_3 = "SELECT * FROM birth_names LIMIT 10"
class TestSqlLab(SupersetTestCase):
"""Testings for Sql Lab"""
@pytest.mark.usefixtures("load_birth_names_data")
def run_some_queries(self):
db.session.query(Query).delete()
db.session.commit()
self.run_sql(QUERY_1, client_id="client_id_1", username="admin")
self.run_sql(QUERY_2, client_id="client_id_2", username="admin")
self.run_sql(QUERY_3, client_id="client_id_3", username="gamma_sqllab")
@ -419,6 +417,7 @@ class TestSqlLab(SupersetTestCase):
self.assertEqual(len(data["data"]), 1200)
self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.NOT_LIMITED)
@pytest.mark.usefixtures("load_birth_names_data")
def test_query_api_filter(self) -> None:
"""
Test query api without can_only_access_owned_queries perm added to
@ -438,6 +437,7 @@ class TestSqlLab(SupersetTestCase):
assert admin.first_name in user_queries
assert gamma_sqllab.first_name in user_queries
@pytest.mark.usefixtures("load_birth_names_data")
def test_query_api_can_access_all_queries(self) -> None:
"""
Test query api with can_access_all_queries perm added to
@ -522,6 +522,7 @@ class TestSqlLab(SupersetTestCase):
{r.get("sql") for r in self.get_json_resp(url)["result"]},
)
@pytest.mark.usefixtures("load_birth_names_data")
def test_query_admin_can_access_all_queries(self) -> None:
"""
Test query api with all_query_access perm added to

View File

@ -187,6 +187,7 @@ class TestTagsDAO(SupersetTestCase):
TaggedObject.object_type == ObjectType.chart,
),
)
.join(Tag, TaggedObject.tag_id == Tag.id)
.distinct(Slice.id)
.count()
)
@ -199,6 +200,7 @@ class TestTagsDAO(SupersetTestCase):
TaggedObject.object_type == ObjectType.dashboard,
),
)
.join(Tag, TaggedObject.tag_id == Tag.id)
.distinct(Dashboard.id)
.count()
+ num_charts

View File

@ -29,7 +29,6 @@ def database_with_catalog(mocker: MockerFixture) -> MagicMock:
"""
Mock a database with catalogs and schemas.
"""
mocker.patch("superset.commands.database.create.db")
mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand")
database = mocker.MagicMock()
@ -53,7 +52,6 @@ def database_without_catalog(mocker: MockerFixture) -> MagicMock:
"""
Mock a database without catalogs.
"""
mocker.patch("superset.commands.database.create.db")
mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand")
database = mocker.MagicMock()

View File

@ -29,8 +29,6 @@ def database_with_catalog(mocker: MockerFixture) -> MagicMock:
"""
Mock a database with catalogs and schemas.
"""
mocker.patch("superset.commands.database.update.db")
database = mocker.MagicMock()
database.database_name = "my_db"
database.db_engine_spec.__name__ = "test_engine"
@ -50,8 +48,6 @@ def database_without_catalog(mocker: MockerFixture) -> MagicMock:
"""
Mock a database without catalogs.
"""
mocker.patch("superset.commands.database.update.db")
database = mocker.MagicMock()
database.database_name = "my_db"
database.db_engine_spec.__name__ = "test_engine"

View File

@ -22,7 +22,6 @@ def test_user_favorite_tag(mocker):
from superset.daos.tag import TagDAO
# Mock the behavior of TagDAO and g
mock_session = mocker.patch("superset.daos.tag.db.session")
mock_TagDAO = mocker.patch(
"superset.daos.tag.TagDAO"
) # Replace with the actual path to TagDAO
@ -45,7 +44,6 @@ def test_remove_user_favorite_tag(mocker):
from superset.daos.tag import TagDAO
# Mock the behavior of TagDAO and g
mock_session = mocker.patch("superset.daos.tag.db.session")
mock_TagDAO = mocker.patch("superset.daos.tag.TagDAO")
mock_tag = mocker.MagicMock(users_favorited=[])
mock_TagDAO.find_by_id.return_value = mock_tag

View File

@ -116,7 +116,7 @@ def test_post_with_uuid(
assert payload["result"]["uuid"] == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"
database = session.query(Database).one()
assert database.uuid == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"
assert database.uuid == UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb")
def test_password_mask(

View File

@ -22,8 +22,8 @@ from uuid import UUID
import pytest
from freezegun import freeze_time
from sqlalchemy.orm import Session, sessionmaker
from superset import db
from superset.exceptions import CreateKeyValueDistributedLockFailedException
from superset.key_value.types import JsonKeyValueCodec
from superset.utils.lock import get_key, KeyValueDistributedLock
@ -32,56 +32,51 @@ MAIN_KEY = get_key("ns", a=1, b=2)
OTHER_KEY = get_key("ns2", a=1, b=2)
def _get_lock(key: UUID, session: Session) -> Any:
def _get_lock(key: UUID) -> Any:
from superset.key_value.models import KeyValueEntry
entry = session.query(KeyValueEntry).filter_by(uuid=key).first()
entry = db.session.query(KeyValueEntry).filter_by(uuid=key).first()
if entry is None or entry.is_expired():
return None
return JsonKeyValueCodec().decode(entry.value)
def _get_other_session() -> Session:
# This session is used to simulate what another worker will find in the metastore
# during the locking process.
from superset import db
bind = db.session.get_bind()
SessionMaker = sessionmaker(bind=bind)
return SessionMaker()
def test_key_value_distributed_lock_happy_path() -> None:
"""
Test successfully acquiring and returning the distributed lock.
Note we use a nested transaction to ensure that the cleanup from the outer context
manager is correctly invoked, otherwise a partial rollback would occur leaving the
database in a fractured state.
"""
session = _get_other_session()
with freeze_time("2021-01-01"):
assert _get_lock(MAIN_KEY, session) is None
assert _get_lock(MAIN_KEY) is None
with KeyValueDistributedLock("ns", a=1, b=2) as key:
assert key == MAIN_KEY
assert _get_lock(key, session) is True
assert _get_lock(OTHER_KEY, session) is None
with pytest.raises(CreateKeyValueDistributedLockFailedException):
with KeyValueDistributedLock("ns", a=1, b=2):
pass
assert _get_lock(key) is True
assert _get_lock(OTHER_KEY) is None
assert _get_lock(MAIN_KEY, session) is None
with db.session.begin_nested():
with pytest.raises(CreateKeyValueDistributedLockFailedException):
with KeyValueDistributedLock("ns", a=1, b=2):
pass
assert _get_lock(MAIN_KEY) is None
def test_key_value_distributed_lock_expired() -> None:
"""
Test expiration of the distributed lock
"""
session = _get_other_session()
with freeze_time("2021-01-01T"):
assert _get_lock(MAIN_KEY, session) is None
with freeze_time("2021-01-01"):
assert _get_lock(MAIN_KEY) is None
with KeyValueDistributedLock("ns", a=1, b=2):
assert _get_lock(MAIN_KEY, session) is True
with freeze_time("2022-01-01T"):
assert _get_lock(MAIN_KEY, session) is None
assert _get_lock(MAIN_KEY) is True
with freeze_time("2022-01-01"):
assert _get_lock(MAIN_KEY) is None
assert _get_lock(MAIN_KEY, session) is None
assert _get_lock(MAIN_KEY) is None