mirror of https://github.com/apache/superset.git
Revert interim test updates
This commit is contained in:
parent acfba3da29
commit 3e4325a176
@@ -239,8 +239,8 @@ basepython = python3.10
ignore_basepython_conflict = true
commands =
superset db upgrade
superset load_test_users
superset init
superset load-test-users
# use -s to be able to use break pointers.
# no args or tests/* can be passed as an argument to run all tests
pytest -s {posargs}
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=consider-using-transaction
from collections import defaultdict

from superset import db, security_manager
@@ -28,9 +28,9 @@ export SUPERSET_TESTENV=true
echo "Superset config module: $SUPERSET_CONFIG"

superset db upgrade
superset load_test_users
superset init
superset load-test-users

echo "Running tests"

pytest --durations-min=2 --cov-report= --cov=superset ./tests/integration_tests "$@"
pytest --durations-min=2 --maxfail=1 --cov-report= --cov=superset ./tests/integration_tests "$@"
@@ -37,15 +37,7 @@ def load_test_users() -> None:
Syncs permissions for those users/roles
"""
print(Fore.GREEN + "Loading a set of users for unit tests")
load_test_users_run()


def load_test_users_run() -> None:
"""
Loads admin, alpha, and gamma user for testing purposes

Syncs permissions for those users/roles
"""
if app.config["TESTING"]:
sm = security_manager
@@ -77,7 +77,7 @@ def import_chart(
if chart.id is None:
db.session.flush()

if user := get_user():
if (user := get_user()) and user not in chart.owners:
chart.owners.append(user)

return chart
@@ -38,8 +38,8 @@ from superset.daos.chart import ChartDAO
from superset.daos.dashboard import DashboardDAO
from superset.exceptions import SupersetSecurityException
from superset.models.slice import Slice
from superset.utils.decorators import on_error, transaction
from superset.tags.models import ObjectType
from superset.utils.decorators import on_error, transaction

logger = logging.getLogger(__name__)

@@ -62,8 +62,7 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
assert self._model

# Update tags
tags = self._properties.pop("tags", None)
if tags is not None:
if (tags := self._properties.pop("tags", None)) is not None:
update_tags(ObjectType.chart, self._model.id, self._model.tags, tags)

if self._properties.get("query_context_generation") is None:
@@ -188,7 +188,7 @@ def import_dashboard(
if dashboard.id is None:
db.session.flush()

if user := get_user():
if (user := get_user()) and user not in dashboard.owners:
dashboard.owners.append(user)

return dashboard
@@ -19,7 +19,6 @@ from functools import partial

from sqlalchemy.exc import SQLAlchemyError

from superset import db
from superset.commands.dashboard.permalink.base import BaseDashboardPermalinkCommand
from superset.commands.key_value.upsert import UpsertKeyValueCommand
from superset.daos.dashboard import DashboardDAO
@@ -78,6 +77,7 @@ class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
codec=self.codec,
).run()
assert key.id # for type checks
return encode_permalink_key(key=key.id, salt=self.salt)

def validate(self) -> None:
pass
@@ -53,13 +53,10 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand):
assert self._model

# Update tags
tags = self._properties.pop("tags", None)
if tags is not None:
update_tags(
ObjectType.dashboard, self._model.id, self._model.tags, tags
)
if (tags := self._properties.pop("tags", None)) is not None:
update_tags(ObjectType.dashboard, self._model.id, self._model.tags, tags)

dashboard = DashboardDAO.update(self._model, self._properties, commit=False)
dashboard = DashboardDAO.update(self._model, self._properties)
if self._properties.get("json_metadata"):
DashboardDAO.set_dash_metadata(
dashboard,
@@ -40,7 +40,7 @@ from superset.commands.database.ssh_tunnel.exceptions import (
)
from superset.commands.database.test_connection import TestConnectionDatabaseCommand
from superset.daos.database import DatabaseDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.exceptions import SupersetErrorsException
from superset.extensions import event_logger, security_manager
from superset.models.core import Database
@@ -1,194 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Any, Optional, TypedDict

import pandas as pd
from flask_babel import lazy_gettext as _

from superset import db
from superset.commands.base import BaseCommand
from superset.commands.database.exceptions import (
DatabaseNotFoundError,
DatabaseSchemaUploadNotAllowed,
DatabaseUploadFailed,
DatabaseUploadSaveMetadataFailed,
)
from superset.connectors.sqla.models import SqlaTable
from superset.daos.database import DatabaseDAO
from superset.models.core import Database
from superset.sql_parse import Table
from superset.utils.core import get_user
from superset.utils.decorators import on_error, transaction
from superset.views.database.validators import schema_allows_file_upload

logger = logging.getLogger(__name__)

READ_CSV_CHUNK_SIZE = 1000


class CSVImportOptions(TypedDict, total=False):
schema: str
delimiter: str
already_exists: str
column_data_types: dict[str, str]
column_dates: list[str]
column_labels: str
columns_read: list[str]
dataframe_index: str
day_first: bool
decimal_character: str
header_row: int
index_column: str
null_values: list[str]
overwrite_duplicates: bool
rows_to_read: int
skip_blank_lines: bool
skip_initial_space: bool
skip_rows: int


class CSVImportCommand(BaseCommand):
def __init__(
self,
model_id: int,
table_name: str,
file: Any,
options: CSVImportOptions,
) -> None:
self._model_id = model_id
self._model: Optional[Database] = None
self._table_name = table_name
self._schema = options.get("schema")
self._file = file
self._options = options

def _read_csv(self) -> pd.DataFrame:
"""
Read CSV file into a DataFrame

:return: pandas DataFrame
:throws DatabaseUploadFailed: if there is an error reading the CSV file
"""
try:
return pd.concat(
pd.read_csv(
chunksize=READ_CSV_CHUNK_SIZE,
encoding="utf-8",
filepath_or_buffer=self._file,
header=self._options.get("header_row", 0),
index_col=self._options.get("index_column"),
dayfirst=self._options.get("day_first", False),
iterator=True,
keep_default_na=not self._options.get("null_values"),
usecols=self._options.get("columns_read")
if self._options.get("columns_read") # None if an empty list
else None,
na_values=self._options.get("null_values")
if self._options.get("null_values") # None if an empty list
else None,
nrows=self._options.get("rows_to_read"),
parse_dates=self._options.get("column_dates"),
sep=self._options.get("delimiter", ","),
skip_blank_lines=self._options.get("skip_blank_lines", False),
skipinitialspace=self._options.get("skip_initial_space", False),
skiprows=self._options.get("skip_rows", 0),
dtype=self._options.get("column_data_types")
if self._options.get("column_data_types")
else None,
)
)
except (
pd.errors.ParserError,
pd.errors.EmptyDataError,
UnicodeDecodeError,
ValueError,
) as ex:
raise DatabaseUploadFailed(
message=_("Parsing error: %(error)s", error=str(ex))
) from ex
except Exception as ex:
raise DatabaseUploadFailed(_("Error reading CSV file")) from ex

def _dataframe_to_database(self, df: pd.DataFrame, database: Database) -> None:
"""
Upload DataFrame to database

:param df:
:throws DatabaseUploadFailed: if there is an error uploading the DataFrame
"""
try:
csv_table = Table(table=self._table_name, schema=self._schema)
database.db_engine_spec.df_to_sql(
database,
csv_table,
df,
to_sql_kwargs={
"chunksize": READ_CSV_CHUNK_SIZE,
"if_exists": self._options.get("already_exists", "fail"),
"index": self._options.get("index_column"),
"index_label": self._options.get("column_labels"),
},
)
except ValueError as ex:
raise DatabaseUploadFailed(
message=_(
"Table already exists. You can change your "
"'if table already exists' strategy to append or "
"replace or provide a different Table Name to use."
)
) from ex
except Exception as ex:
raise DatabaseUploadFailed(exception=ex) from ex

@transaction(on_error=partial(on_error, reraise=DatabaseUploadSaveMetadataFailed))
def run(self) -> None:
self.validate()
if not self._model:
return

df = self._read_csv()
self._dataframe_to_database(df, self._model)

sqla_table = (
db.session.query(SqlaTable)
.filter_by(
table_name=self._table_name,
schema=self._schema,
database_id=self._model_id,
)
.one_or_none()
)
if not sqla_table:
sqla_table = SqlaTable(
table_name=self._table_name,
database=self._model,
database_id=self._model_id,
owners=[get_user()],
schema=self._schema,
)
db.session.add(sqla_table)

sqla_table.fetch_metadata()

def validate(self) -> None:
self._model = DatabaseDAO.find_by_id(self._model_id)
if not self._model:
raise DatabaseNotFoundError()
if not schema_allows_file_upload(self._model, self._schema):
raise DatabaseSchemaUploadNotAllowed()
@@ -41,7 +41,6 @@ from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
from superset.daos.database import DatabaseDAO
from superset.daos.dataset import DatasetDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.extensions import db
from superset.models.core import Database
from superset.utils.decorators import on_error, transaction

@@ -78,11 +77,7 @@ class UpdateDatabaseCommand(BaseCommand):
original_database_name = self._model.database_name

try:
database = DatabaseDAO.update(
self._model,
self._properties,
commit=False,
)
database = DatabaseDAO.update(self._model, self._properties)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = self._handle_ssh_tunnel(database)
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
@@ -100,7 +95,6 @@ class UpdateDatabaseCommand(BaseCommand):
return None

if not is_feature_enabled("SSH_TUNNELING"):
db.session.rollback()
raise SSHTunnelingNotEnabledError()

current_ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
@@ -130,14 +124,11 @@ class UpdateDatabaseCommand(BaseCommand):
This method captures a generic exception, since errors could potentially come
from any of the 50+ database drivers we support.
"""
try:
return database.get_all_catalog_names(
force=True,
ssh_tunnel=ssh_tunnel,
)
except Exception as ex:
db.session.rollback()
raise DatabaseConnectionFailedError() from ex

return database.get_all_catalog_names(
force=True,
ssh_tunnel=ssh_tunnel,
)

def _get_schema_names(
self,
@@ -151,15 +142,12 @@ class UpdateDatabaseCommand(BaseCommand):
This method captures a generic exception, since errors could potentially come
from any of the 50+ database drivers we support.
"""
try:
return database.get_all_schema_names(
force=True,
catalog=catalog,
ssh_tunnel=ssh_tunnel,
)
except Exception as ex:
db.session.rollback()
raise DatabaseConnectionFailedError() from ex

return database.get_all_schema_names(
force=True,
catalog=catalog,
ssh_tunnel=ssh_tunnel,
)

def _refresh_catalogs(
self,
@@ -224,8 +212,6 @@ class UpdateDatabaseCommand(BaseCommand):
schemas,
)

db.session.commit()

def _refresh_schemas(
self,
database: Database,
@@ -43,7 +43,7 @@ class DeleteDatasetColumnCommand(BaseCommand):
def run(self) -> None:
self.validate()
assert self._model
DatasetColumnDAO.delete(self._model)
DatasetColumnDAO.delete([self._model])

def validate(self) -> None:
# Validate/populate model exists
@@ -32,7 +32,7 @@ from superset.commands.dataset.exceptions import (
)
from superset.daos.dataset import DatasetDAO
from superset.exceptions import SupersetSecurityException
from superset.extensions import db, security_manager
from superset.extensions import security_manager
from superset.sql_parse import Table
from superset.utils.decorators import on_error, transaction
@@ -178,7 +178,7 @@ def import_dataset(
if data_uri and (not table_exists or force_data):
load_data(data_uri, dataset, dataset.database)

if user := get_user():
if (user := get_user()) and user not in dataset.owners:
dataset.owners.append(user)

return dataset
@@ -43,7 +43,7 @@ class DeleteDatasetMetricCommand(BaseCommand):
def run(self) -> None:
self.validate()
assert self._model
DatasetMetricDAO.delete(self._model)
DatasetMetricDAO.delete([self._model])

def validate(self) -> None:
# Validate/populate model exists
@@ -21,6 +21,7 @@ from typing import Any, Optional

from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
from sqlalchemy.exc import SQLAlchemyError

from superset import security_manager
from superset.commands.base import BaseCommand, UpdateMixin
@@ -60,7 +61,16 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
self.override_columns = override_columns
self._properties["override_columns"] = override_columns

@transaction(on_error=partial(on_error, reraise=DatasetUpdateFailedError))
@transaction(
on_error=partial(
on_error,
catches=(
SQLAlchemyError,
ValueError,
),
reraise=DatasetUpdateFailedError,
)
)
def run(self) -> Model:
self.validate()
assert self._model
@@ -20,7 +20,6 @@ from typing import Any, Optional

from sqlalchemy.exc import SQLAlchemyError

from superset import db
from superset.commands.explore.permalink.base import BaseExplorePermalinkCommand
from superset.commands.key_value.create import CreateKeyValueCommand
from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError
@@ -74,27 +73,7 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
key = command.run()
if key.id is None:
raise ExplorePermalinkCreateFailedError("Unexpected missing key id")
db.session.commit()
return encode_permalink_key(key=key.id, salt=self.salt)
d_id, d_type = self.datasource.split("__")
datasource_id = int(d_id)
datasource_type = DatasourceType(d_type)
check_chart_access(datasource_id, self.chart_id, datasource_type)
value = {
"chartId": self.chart_id,
"datasourceId": datasource_id,
"datasourceType": datasource_type.value,
"datasource": self.datasource,
"state": self.state,
}
command = CreateKeyValueCommand(
resource=self.resource,
value=value,
codec=self.codec,
)
key = command.run()
return encode_permalink_key(key=key.id, salt=self.salt)
>>>>>>> c01dacb71a (chore(dao): Use nested session for operations)

def validate(self) -> None:
pass
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from functools import partial
from typing import Any, Optional

from marshmallow import Schema
@@ -44,7 +45,7 @@ from superset.datasets.schemas import ImportV1DatasetSchema
from superset.migrations.shared.native_filters import migrate_dashboard
from superset.models.dashboard import dashboard_slices
from superset.queries.saved_queries.schemas import ImportV1SavedQuerySchema
from superset.utils.decorators import transaction
from superset.utils.decorators import on_error, transaction


class ImportAssetsCommand(BaseCommand):
@@ -154,14 +155,16 @@ class ImportAssetsCommand(BaseCommand):
if chart.viz_type == "filter_box":
db.session.delete(chart)

@transaction()
@transaction(
on_error=partial(
on_error,
catches=(Exception,),
reraise=ImportFailedError,
)
)
def run(self) -> None:
self.validate()

try:
self._import(self._configs)
except Exception as ex:
raise ImportFailedError() from ex
self._import(self._configs)

def validate(self) -> None:
exceptions: list[ValidationError] = []
@@ -53,9 +53,11 @@ class DeleteKeyValueCommand(BaseCommand):
pass

def delete(self) -> bool:
filter_ = get_filter(self.resource, self.key)
if entry := db.session.query(KeyValueEntry).filter_by(**filter_).first():
if (
entry := db.session.query(KeyValueEntry)
.filter_by(**get_filter(self.resource, self.key))
.first()
):
db.session.delete(entry)
db.session.flush()
return True
return False
@@ -60,4 +60,3 @@ class DeleteExpiredKeyValueCommand(BaseCommand):
)
.delete()
)
db.session.flush()
@@ -21,6 +21,8 @@ from functools import partial
from typing import Any, Optional, Union
from uuid import UUID

from sqlalchemy.exc import SQLAlchemyError

from superset import db
from superset.commands.base import BaseCommand
from superset.commands.key_value.create import CreateKeyValueCommand
@@ -71,7 +73,7 @@ class UpsertKeyValueCommand(BaseCommand):
@transaction(
on_error=partial(
on_error,
catches=(KeyValueCreateFailedError,),
catches=(KeyValueCreateFailedError, SQLAlchemyError),
reraise=KeyValueUpsertFailedError,
),
)
@@ -82,16 +84,15 @@ class UpsertKeyValueCommand(BaseCommand):
pass

def upsert(self) -> Key:
filter_ = get_filter(self.resource, self.key)
entry: KeyValueEntry = (
db.session.query(KeyValueEntry).filter_by(**filter_).first()
)
if entry:
if (
entry := db.session.query(KeyValueEntry)
.filter_by(**get_filter(self.resource, self.key))
.first()
):
entry.value = self.codec.encode(self.value)
entry.expires_on = self.expires_on
entry.changed_on = datetime.now()
entry.changed_by_fk = get_user_id()
db.session.flush()
return Key(entry.id, entry.uuid)

return CreateKeyValueCommand(
@@ -137,6 +137,7 @@ class BaseReportState:
uuid=self._execution_id,
)
db.session.add(log)
db.session.commit() # pylint: disable=consider-using-transaction

def _get_url(
self,
@@ -49,7 +49,6 @@ class AsyncPruneReportScheduleLogCommand(BaseCommand):
report_schedule,
from_date,
)
db.session.commit()
logger.info(
"Deleted %s logs for report schedule id: %s",
str(row_count),
@@ -383,5 +383,4 @@ class TagDAO(BaseDAO[Tag]):
object_id,
tag.name,
)

db.session.add_all(tagged_objects)
@@ -32,7 +32,7 @@ from marshmallow import ValidationError
from werkzeug.wrappers import Response as WerkzeugResponse
from werkzeug.wsgi import FileWrapper

from superset import is_feature_enabled, thumbnail_cache
from superset import db, is_feature_enabled, thumbnail_cache
from superset.charts.schemas import ChartEntityResponseSchema
from superset.commands.dashboard.create import CreateDashboardCommand
from superset.commands.dashboard.delete import DeleteDashboardCommand
@@ -1314,7 +1314,13 @@ class DashboardRestApi(BaseSupersetModelRestApi):
"""
try:
body = self.embedded_config_schema.load(request.json)
embedded = EmbeddedDashboardDAO.upsert(dashboard, body["allowed_domains"])

with db.session.begin_nested():
embedded = EmbeddedDashboardDAO.upsert(
dashboard,
body["allowed_domains"],
)

result = self.embedded_response_schema.dump(embedded)
return self.response(200, result=result)
except ValidationError as error:
@@ -22,7 +22,6 @@ from uuid import UUID, uuid3
from flask import current_app, Flask, has_app_context
from flask_caching import BaseCache

from superset import db
from superset.key_value.exceptions import KeyValueCreateFailedError
from superset.key_value.types import (
KeyValueCodec,
@@ -95,7 +94,6 @@ class SupersetMetastoreCache(BaseCache):
codec=self.codec,
expires_on=self._get_expiry(timeout),
).run()
db.session.commit()
return True

def add(self, key: str, value: Any, timeout: Optional[int] = None) -> bool:
@@ -111,7 +109,6 @@ class SupersetMetastoreCache(BaseCache):
key=self.get_key(key),
expires_on=self._get_expiry(timeout),
).run()
db.session.commit()
return True
except KeyValueCreateFailedError:
return False
@@ -136,6 +133,4 @@ class SupersetMetastoreCache(BaseCache):
# pylint: disable=import-outside-toplevel
from superset.commands.key_value.delete import DeleteKeyValueCommand

ret = DeleteKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run()
db.session.commit()
return ret
return DeleteKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run()
@@ -18,7 +18,6 @@
from typing import Any, Optional
from uuid import uuid3

from superset import db
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource, SharedKey
from superset.key_value.utils import get_uuid_namespace, random_key

@@ -46,7 +45,6 @@ def set_shared_value(key: SharedKey, value: Any) -> None:
key=uuid_key,
codec=CODEC,
).run()
db.session.commit()


def get_permalink_salt(key: SharedKey) -> str:
@@ -334,22 +334,26 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
pyjwt_for_guest_token = _jwt_global_obj

@property
def get_session(self) -> Session:
def get_session2(self) -> Session:
"""
Flask-AppBuilder (FAB) which has a tendency to explicitly commit, thus violating
our definition of "unit of work".

By providing a monkey patched transaction for the FAB session ensures that any
explicit commit merely flushes and any rollback is a no-op.
By providing a monkey patched transaction for the FAB session, within the
confines of a nested session, ensures that any explicit commit merely flushes
and any rollback is a no-op.
"""

# pylint: disable=import-outside-toplevel
from superset import db

with db.session.begin_nested() as transaction:
transaction.session.commit = transaction.session.flush
transaction.session.rollback = lambda: None
return transaction.session
if db.session._proxied._nested_transaction: # pylint: disable=protected-access
with db.session.begin_nested() as transaction:
transaction.session.commit = transaction.session.flush
transaction.session.rollback = lambda: None
return transaction.session

return db.session

def create_login_manager(self, app: Flask) -> LoginManager:
lm = super().create_login_manager(app)
@@ -1035,7 +1039,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
== None, # noqa: E711
)
)
self.get_session.flush()
if deleted_count := pvms.delete():
logger.info("Deleted %i faulty permissions", deleted_count)

@@ -1065,7 +1068,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
)

self.create_missing_perms()
self.get_session.flush()
self.clean_perms()

def _get_all_pvms(self) -> list[PermissionView]:
@@ -1138,7 +1140,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
):
role_from_permissions.append(permission_view)
role_to.permissions = role_from_permissions
self.get_session.flush()

def set_role(
self,
@@ -1159,7 +1160,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
permission_view for permission_view in pvms if pvm_check(permission_view)
]
role.permissions = role_pvms
self.get_session.flush()

def _is_admin_only(self, pvm: PermissionView) -> bool:
"""
@@ -140,6 +140,7 @@ def get_tag(
if tag is None:
tag = Tag(name=escape(tag_name), type=type_)
session.add(tag)
session.commit()
return tag
@@ -59,6 +59,7 @@ def get_or_create_db(
if database and database.sqlalchemy_uri_decrypted != sqlalchemy_uri:
database.set_sqlalchemy_uri(sqlalchemy_uri)

db.session.flush()
return database

@@ -78,3 +79,4 @@ def remove_database(database: Database) -> None:
from superset import db

db.session.delete(database)
db.session.flush()
@@ -222,7 +222,7 @@ def on_error(

:param ex: The source exception
:param catches: The exception types the handler catches
:param reraise: The exception type the handler reraises after catching
:param reraise: The exception type the handler raises after catching
:raises Exception: If the exception is not swallowed
"""

@@ -252,13 +252,16 @@ def transaction( # pylint: disable=redefined-outer-name
from superset import db # pylint: disable=import-outside-toplevel

try:
with db.session.begin_nested():
return func(*args, **kwargs)
result = func(*args, **kwargs)
db.session.commit()
return result
except Exception as ex:
db.session.rollback()

if on_error:
return on_error(ex)

raise ex
raise

return wrapped
@@ -24,7 +24,6 @@ from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import Any, cast, TypeVar, Union

from superset import db
from superset.exceptions import CreateKeyValueDistributedLockFailedException
from superset.key_value.exceptions import KeyValueCreateFailedError
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
@@ -72,7 +71,6 @@ def KeyValueDistributedLock( # pylint: disable=invalid-name
store.

:param namespace: The namespace for which the lock is to be acquired.
:type namespace: str
:param kwargs: Additional keyword arguments.
:yields: A unique identifier (UUID) for the acquired lock (the KV key).
:raises CreateKeyValueDistributedLockFailedException: If the lock is taken.
@@ -93,12 +91,10 @@ def KeyValueDistributedLock( # pylint: disable=invalid-name
value=True,
expires_on=datetime.now() + LOCK_EXPIRATION,
).run()
db.session.commit()

yield key

DeleteKeyValueCommand(resource=KeyValueResource.LOCK, key=key).run()
db.session.commit()
logger.debug("Removed lock on namespace %s for key %s", namespace, key)
except KeyValueCreateFailedError as ex:
raise CreateKeyValueDistributedLockFailedException(
@@ -403,6 +403,7 @@ class DBEventLogger(AbstractEventLogger):
logs.append(log)
try:
db.session.bulk_save_objects(logs)
db.session.commit() # pylint: disable=consider-using-transaction
except SQLAlchemyError as ex:
logging.error("DBEventLogger failed to log event(s)")
logging.exception(ex)
@@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=consider-using-transaction
from typing import Any, TYPE_CHECKING
from typing import TYPE_CHECKING

from flask_appbuilder import expose
from flask_appbuilder.models.sqla.interface import SQLAInterface
@@ -349,15 +349,7 @@ class SupersetTestCase(TestCase):
self.grant_role_access_to_table(table, role_name)

def grant_role_access_to_table(self, table, role_name):
print(">>> grant_role_access_to_table <<<")
print(role_name)
print(db.session.get_bind())
from flask_appbuilder.security.sqla.models import Role

print(list(db.session.query(Role).all()))
role = security_manager.find_role(role_name)
print(role)

perms = db.session.query(ab_models.PermissionView).all()
for perm in perms:
if (
@@ -332,6 +332,9 @@ def test_run_async_cta_query_with_lower_limit(test_client, ctas_method):
tmp_table=tmp_table,
)
query = wait_for_success(result)
print(">>> test_run_async_cta_query_with_lower_limit <<<")
print(result)
print(query.to_dict())
assert QueryStatus.SUCCESS == query.status

sqlite_select_sql = f"SELECT\n *\nFROM {tmp_table}\nLIMIT {query.limit}\nOFFSET 0"
@@ -91,7 +91,7 @@ class TestCore(SupersetTestCase):
self.original_unsafe_db_setting = app.config["PREVENT_UNSAFE_DB_CONNECTIONS"]

def tearDown(self):
# db.session.query(Query).delete()
db.session.query(Query).delete()
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = self.original_unsafe_db_setting
super().tearDown()

@@ -235,7 +235,6 @@ class TestCore(SupersetTestCase):
)
for slc in slices:
db.session.delete(slc)
print(db.session.dirty)
db.session.commit()

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@@ -666,7 +665,6 @@ class TestCore(SupersetTestCase):
client_id="client_id_1",
username="admin",
)
print(resp)
count_ds = []
count_name = []
for series in data["data"]:
@@ -814,7 +812,7 @@ class TestCore(SupersetTestCase):
mock_cache.return_value = MockCache()

rv = self.client.get("/superset/explore_json/data/valid-cache-key")
self.assertEqual(rv.status_code, 401)
self.assertEqual(rv.status_code, 403)

def test_explore_json_data_invalid_cache_key(self):
self.login(ADMIN_USERNAME)
@@ -97,5 +97,6 @@ def create_dashboard(
if slices is not None:
dash.slices = slices
db.session.add(dash)
db.session.commit()

return dash
@@ -592,7 +592,6 @@ class TestImportDashboardsCommand(SupersetTestCase):
}
command = v1.ImportDashboardsCommand(contents, overwrite=True)
command.run()
command.run()

new_num_dashboards = db.session.query(Dashboard).count()
assert new_num_dashboards == num_dashboards + 1
@@ -48,15 +48,16 @@ class TestDashboardDAO(SupersetTestCase):
assert changed_on == DashboardDAO.get_dashboard_changed_on("world_health")

old_changed_on = dashboard.changed_on

# freezegun doesn't work for some reason, so we need to sleep here :(
time.sleep(1)
data = dashboard.data
positions = data["position_json"]
data.update({"positions": positions})
original_data = copy.deepcopy(data)

data.update({"foo": "bar"})
DashboardDAO.set_dash_metadata(dashboard, data)
db.session.flush()
db.session.commit()
new_changed_on = DashboardDAO.get_dashboard_changed_on(dashboard)
assert old_changed_on.replace(microsecond=0) < new_changed_on
@@ -68,7 +69,6 @@ class TestDashboardDAO(SupersetTestCase):
)

DashboardDAO.set_dash_metadata(dashboard, original_data)
db.session.flush()
db.session.commit()

@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
@@ -26,6 +26,7 @@ import prison
import pytest
import yaml
from sqlalchemy import inspect
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import joinedload
from sqlalchemy.sql import func
@@ -44,6 +44,7 @@ class TestEmbeddedDashboardApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
self.dash = db.session.query(Dashboard).filter_by(slug="births").first()
self.embedded = EmbeddedDashboardDAO.upsert(self.dash, [])
db.session.flush()
uri = f"api/v1/{self.resource_name}/{self.embedded.uuid}"
response = self.client.get(uri)
self.assert200(response)
@@ -34,17 +34,21 @@ class TestEmbeddedDashboardDAO(SupersetTestCase):
dash = db.session.query(Dashboard).filter_by(slug="world_health").first()
assert not dash.embedded
EmbeddedDashboardDAO.upsert(dash, ["test.example.com"])
db.session.flush()
assert dash.embedded
self.assertEqual(dash.embedded[0].allowed_domains, ["test.example.com"])
original_uuid = dash.embedded[0].uuid
self.assertIsNotNone(original_uuid)
EmbeddedDashboardDAO.upsert(dash, [])
db.session.flush()
self.assertEqual(dash.embedded[0].allowed_domains, [])
self.assertEqual(dash.embedded[0].uuid, original_uuid)

@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
def test_get_by_uuid(self):
dash = db.session.query(Dashboard).filter_by(slug="world_health").first()
uuid = str(EmbeddedDashboardDAO.upsert(dash, ["test.example.com"]).uuid)
EmbeddedDashboardDAO.upsert(dash, ["test.example.com"])
db.session.flush()
uuid = str(dash.embedded[0].uuid)
embedded = EmbeddedDashboardDAO.find_by_id(uuid)
self.assertIsNotNone(embedded)
@@ -44,6 +44,7 @@ if TYPE_CHECKING:
def test_get_embedded_dashboard(client: FlaskClient[Any]): # noqa: F811
dash = db.session.query(Dashboard).filter_by(slug="births").first()
embedded = EmbeddedDashboardDAO.upsert(dash, [])
db.session.flush()
uri = f"embedded/{embedded.uuid}"
response = client.get(uri)
assert response.status_code == 200
@@ -57,6 +58,7 @@ def test_get_embedded_dashboard(client: FlaskClient[Any]): # noqa: F811
def test_get_embedded_dashboard_referrer_not_allowed(client: FlaskClient[Any]): # noqa: F811
dash = db.session.query(Dashboard).filter_by(slug="births").first()
embedded = EmbeddedDashboardDAO.upsert(dash, ["test.example.com"])
db.session.flush()
uri = f"embedded/{embedded.uuid}"
response = client.get(uri)
assert response.status_code == 403
@@ -46,22 +46,25 @@ def load_birth_names_data(
@pytest.fixture()
def load_birth_names_dashboard_with_slices(load_birth_names_data):
with app.app_context():
_create_dashboards()
dash_id_to_delete, slices_ids_to_delete = _create_dashboards()
yield
_cleanup(dash_id_to_delete, slices_ids_to_delete)


@pytest.fixture(scope="module")
def load_birth_names_dashboard_with_slices_module_scope(load_birth_names_data):
with app.app_context():
_create_dashboards()
dash_id_to_delete, slices_ids_to_delete = _create_dashboards()
yield
_cleanup(dash_id_to_delete, slices_ids_to_delete)


@pytest.fixture(scope="class")
def load_birth_names_dashboard_with_slices_class_scope(load_birth_names_data):
with app.app_context():
_create_dashboards()
dash_id_to_delete, slices_ids_to_delete = _create_dashboards()
yield
_cleanup(dash_id_to_delete, slices_ids_to_delete)


def _create_dashboards():
@@ -74,7 +77,10 @@ def _create_dashboards():
from superset.examples.birth_names import create_dashboard, create_slices

slices, _ = create_slices(table)
create_dashboard(slices)
dash = create_dashboard(slices)
slices_ids_to_delete = [slice.id for slice in slices]
dash_id_to_delete = dash.id
return dash_id_to_delete, slices_ids_to_delete


def _create_table(
@@ -91,4 +97,20 @@ def _create_table(

_set_table_metadata(table, database)
_add_table_metrics(table)
db.session.commit()
return table


def _cleanup(dash_id: int, slice_ids: list[int]) -> None:
schema = get_example_default_schema()
for datasource in db.session.query(SqlaTable).filter_by(
table_name="birth_names", schema=schema
):
for col in datasource.columns + datasource.metrics:
db.session.delete(col)

for dash in db.session.query(Dashboard).filter_by(id=dash_id):
db.session.delete(dash)
for slc in db.session.query(Slice).filter(Slice.id.in_(slice_ids)):
db.session.delete(slc)
db.session.commit()
@@ -61,6 +61,7 @@ def load_energy_table_with_slice(load_energy_table_data):
with app.app_context():
slices = _create_energy_table()
yield slices
_cleanup()


def _get_dataframe():
@@ -109,6 +110,24 @@ def _create_and_commit_energy_slice(
return slice


def _cleanup() -> None:
for slice_data in _get_energy_slices():
slice = (
db.session.query(Slice)
.filter_by(slice_name=slice_data["slice_title"])
.first()
)
db.session.delete(slice)

metric = (
db.session.query(SqlMetric).filter_by(metric_name="sum__value").one_or_none()
)
if metric:
db.session.delete(metric)

db.session.commit()


def _get_energy_data():
data = []
for i in range(85):
@@ -135,4 +135,7 @@ def tabbed_dashboard(app_context):
slices=[],
)
db.session.add(dash)
yield
db.session.commit()
yield dash
db.session.query(Dashboard).filter_by(id=dash.id).delete()
db.session.commit()
@@ -61,6 +61,7 @@ def load_unicode_dashboard_with_slice(load_unicode_data):
with app.app_context():
dash = _create_unicode_dashboard(slice_name, None)
yield
_cleanup(dash, slice_name)


@pytest.fixture()
@@ -70,6 +71,7 @@ def load_unicode_dashboard_with_position(load_unicode_data):
with app.app_context():
dash = _create_unicode_dashboard(slice_name, position)
yield
_cleanup(dash, slice_name)


def _get_dataframe():
@@ -95,18 +97,25 @@ def _create_unicode_dashboard(slice_title: str, position: str) -> Dashboard:
table.fetch_metadata()

if slice_title:
slice = _create_unicode_slice(table, slice_title)
slice = _create_and_commit_unicode_slice(table, slice_title)

return create_dashboard("unicode-test", "Unicode Test", position, [slice])


def _create_unicode_slice(table: SqlaTable, title: str):
slc = create_slice(title, "word_cloud", table, {})
if (
obj := db.session.query(Slice)
.filter_by(slice_name=slc.slice_name)
.one_or_none()
def _create_and_commit_unicode_slice(table: SqlaTable, title: str):
slice = create_slice(title, "word_cloud", table, {})
o = db.session.query(Slice).filter_by(slice_name=slice.slice_name).one_or_none()
if o:
db.session.delete(o)
db.session.add(slice)
db.session.commit()
return slice


def _cleanup(dash: Dashboard, slice_name: str) -> None:
db.session.delete(dash)
if slice_name and (
slice := db.session.query(Slice).filter_by(slice_name=slice_name).one_or_none()
):
db.session.delete(obj)
db.session.add(slc)
return slc
db.session.delete(slice)
db.session.commit()
@@ -64,7 +64,6 @@ def load_world_bank_data():
)

yield

with app.app_context():
with get_example_database().get_sqla_engine() as engine:
engine.execute("DROP TABLE IF EXISTS wb_health_population")
@@ -73,14 +72,15 @@ def load_world_bank_data():
@pytest.fixture()
def load_world_bank_dashboard_with_slices(load_world_bank_data):
with app.app_context():
create_dashboard_for_loaded_data()
dash_id_to_delete, slices_ids_to_delete = create_dashboard_for_loaded_data()
yield
_cleanup(dash_id_to_delete, slices_ids_to_delete)


@pytest.fixture(scope="module")
def load_world_bank_dashboard_with_slices_module_scope(load_world_bank_data):
with app.app_context():
create_dashboard_for_loaded_data()
dash_id_to_delete, slices_ids_to_delete = create_dashboard_for_loaded_data()
yield
_cleanup_reports(dash_id_to_delete, slices_ids_to_delete)
_cleanup(dash_id_to_delete, slices_ids_to_delete)
@@ -89,8 +89,9 @@ def load_world_bank_dashboard_with_slices_module_scope(load_world_bank_data):
@pytest.fixture(scope="class")
def load_world_bank_dashboard_with_slices_class_scope(load_world_bank_data):
with app.app_context():
create_dashboard_for_loaded_data()
dash_id_to_delete, slices_ids_to_delete = create_dashboard_for_loaded_data()
yield
_cleanup(dash_id_to_delete, slices_ids_to_delete)


def create_dashboard_for_loaded_data():
@@ -102,22 +103,24 @@ def create_dashboard_for_loaded_data():
return dash_id_to_delete, slices_ids_to_delete


def _create_world_bank_slices(table: SqlaTable) -> None:
def _create_world_bank_slices(table: SqlaTable) -> list[Slice]:
from superset.examples.world_bank import create_slices

slices = create_slices(table)

for slc in slices:
if (
obj := db.session.query(Slice)
.filter_by(slice_name=slc.slice_name)
.one_or_none()
):
db.session.delete(obj)
db.session.add(slc)
_commit_slices(slices)
return slices


def _create_world_bank_dashboard(table: SqlaTable) -> None:
def _commit_slices(slices: list[Slice]):
for slice in slices:
o = db.session.query(Slice).filter_by(slice_name=slice.slice_name).one_or_none()
if o:
db.session.delete(o)
db.session.add(slice)
db.session.commit()


def _create_world_bank_dashboard(table: SqlaTable) -> Dashboard:
from superset.examples.helpers import update_slice_ids
from superset.examples.world_bank import dashboard_positions

@@ -130,6 +133,16 @@ def _create_world_bank_dashboard(table: SqlaTable) -> None:
"world_health", "World Bank's Data", json.dumps(pos), slices
)
dash.json_metadata = '{"mock_key": "mock_value"}'
db.session.commit()
return dash


def _cleanup(dash_id: int, slices_ids: list[int]) -> None:
dash = db.session.query(Dashboard).filter_by(id=dash_id).first()
db.session.delete(dash)
for slice_id in slices_ids:
db.session.query(Slice).filter_by(id=slice_id).delete()
db.session.commit()


def _cleanup_reports(dash_id: int, slices_ids: list[int]) -> None:
@@ -215,8 +215,6 @@ class TestRowLevelSecurity(SupersetTestCase):
},
)
self.assertEqual(rv.status_code, 422)
data = json.loads(rv.data.decode("utf-8"))
assert "Create failed" in data["message"]

@pytest.mark.usefixtures("create_dataset")
def test_model_view_rls_add_tables_required(self):
@@ -71,10 +71,8 @@ QUERY_3 = "SELECT * FROM birth_names LIMIT 10"
class TestSqlLab(SupersetTestCase):
"""Testings for Sql Lab"""

@pytest.mark.usefixtures("load_birth_names_data")
def run_some_queries(self):
db.session.query(Query).delete()
db.session.commit()
self.run_sql(QUERY_1, client_id="client_id_1", username="admin")
self.run_sql(QUERY_2, client_id="client_id_2", username="admin")
self.run_sql(QUERY_3, client_id="client_id_3", username="gamma_sqllab")
@@ -419,6 +417,7 @@ class TestSqlLab(SupersetTestCase):
self.assertEqual(len(data["data"]), 1200)
self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.NOT_LIMITED)

@pytest.mark.usefixtures("load_birth_names_data")
def test_query_api_filter(self) -> None:
"""
Test query api without can_only_access_owned_queries perm added to
@@ -438,6 +437,7 @@ class TestSqlLab(SupersetTestCase):
assert admin.first_name in user_queries
assert gamma_sqllab.first_name in user_queries

@pytest.mark.usefixtures("load_birth_names_data")
def test_query_api_can_access_all_queries(self) -> None:
"""
Test query api with can_access_all_queries perm added to
@@ -522,6 +522,7 @@ class TestSqlLab(SupersetTestCase):
{r.get("sql") for r in self.get_json_resp(url)["result"]},
)

@pytest.mark.usefixtures("load_birth_names_data")
def test_query_admin_can_access_all_queries(self) -> None:
"""
Test query api with all_query_access perm added to
@@ -187,6 +187,7 @@ class TestTagsDAO(SupersetTestCase):
TaggedObject.object_type == ObjectType.chart,
),
)
.join(Tag, TaggedObject.tag_id == Tag.id)
.distinct(Slice.id)
.count()
)
@@ -199,6 +200,7 @@ class TestTagsDAO(SupersetTestCase):
TaggedObject.object_type == ObjectType.dashboard,
),
)
.join(Tag, TaggedObject.tag_id == Tag.id)
.distinct(Dashboard.id)
.count()
+ num_charts
@@ -29,7 +29,6 @@ def database_with_catalog(mocker: MockerFixture) -> MagicMock:
"""
Mock a database with catalogs and schemas.
"""
mocker.patch("superset.commands.database.create.db")
mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand")

database = mocker.MagicMock()
@@ -53,7 +52,6 @@ def database_without_catalog(mocker: MockerFixture) -> MagicMock:
"""
Mock a database without catalogs.
"""
mocker.patch("superset.commands.database.create.db")
mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand")

database = mocker.MagicMock()
@@ -29,8 +29,6 @@ def database_with_catalog(mocker: MockerFixture) -> MagicMock:
"""
Mock a database with catalogs and schemas.
"""
mocker.patch("superset.commands.database.update.db")

database = mocker.MagicMock()
database.database_name = "my_db"
database.db_engine_spec.__name__ = "test_engine"
@@ -50,8 +48,6 @@ def database_without_catalog(mocker: MockerFixture) -> MagicMock:
"""
Mock a database without catalogs.
"""
mocker.patch("superset.commands.database.update.db")

database = mocker.MagicMock()
database.database_name = "my_db"
database.db_engine_spec.__name__ = "test_engine"
@@ -22,7 +22,6 @@ def test_user_favorite_tag(mocker):
from superset.daos.tag import TagDAO

# Mock the behavior of TagDAO and g
mock_session = mocker.patch("superset.daos.tag.db.session")
mock_TagDAO = mocker.patch(
"superset.daos.tag.TagDAO"
) # Replace with the actual path to TagDAO
@@ -45,7 +44,6 @@ def test_remove_user_favorite_tag(mocker):
from superset.daos.tag import TagDAO

# Mock the behavior of TagDAO and g
mock_session = mocker.patch("superset.daos.tag.db.session")
mock_TagDAO = mocker.patch("superset.daos.tag.TagDAO")
mock_tag = mocker.MagicMock(users_favorited=[])
mock_TagDAO.find_by_id.return_value = mock_tag
@@ -116,7 +116,7 @@ def test_post_with_uuid(
assert payload["result"]["uuid"] == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"

database = session.query(Database).one()
assert database.uuid == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"
assert database.uuid == UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb")


def test_password_mask(
@@ -22,8 +22,8 @@ from uuid import UUID

import pytest
from freezegun import freeze_time
from sqlalchemy.orm import Session, sessionmaker

from superset import db
from superset.exceptions import CreateKeyValueDistributedLockFailedException
from superset.key_value.types import JsonKeyValueCodec
from superset.utils.lock import get_key, KeyValueDistributedLock
@@ -32,56 +32,51 @@ MAIN_KEY = get_key("ns", a=1, b=2)
OTHER_KEY = get_key("ns2", a=1, b=2)


def _get_lock(key: UUID, session: Session) -> Any:
def _get_lock(key: UUID) -> Any:
from superset.key_value.models import KeyValueEntry

entry = session.query(KeyValueEntry).filter_by(uuid=key).first()
entry = db.session.query(KeyValueEntry).filter_by(uuid=key).first()
if entry is None or entry.is_expired():
return None

return JsonKeyValueCodec().decode(entry.value)


def _get_other_session() -> Session:
# This session is used to simulate what another worker will find in the metastore
# during the locking process.
from superset import db

bind = db.session.get_bind()
SessionMaker = sessionmaker(bind=bind)
return SessionMaker()


def test_key_value_distributed_lock_happy_path() -> None:
"""
Test successfully acquiring and returning the distributed lock.

Note we use a nested transaction to ensure that the cleanup from the outer context
manager is correctly invoked, otherwise a partial rollback would occur leaving the
database in a fractured state.
"""
session = _get_other_session()

with freeze_time("2021-01-01"):
assert _get_lock(MAIN_KEY, session) is None
assert _get_lock(MAIN_KEY) is None

with KeyValueDistributedLock("ns", a=1, b=2) as key:
assert key == MAIN_KEY
assert _get_lock(key, session) is True
assert _get_lock(OTHER_KEY, session) is None
with pytest.raises(CreateKeyValueDistributedLockFailedException):
with KeyValueDistributedLock("ns", a=1, b=2):
pass
assert _get_lock(key) is True
assert _get_lock(OTHER_KEY) is None

assert _get_lock(MAIN_KEY, session) is None
with db.session.begin_nested():
with pytest.raises(CreateKeyValueDistributedLockFailedException):
with KeyValueDistributedLock("ns", a=1, b=2):
pass

assert _get_lock(MAIN_KEY) is None


def test_key_value_distributed_lock_expired() -> None:
"""
Test expiration of the distributed lock
"""
session = _get_other_session()

with freeze_time("2021-01-01T"):
assert _get_lock(MAIN_KEY, session) is None
with freeze_time("2021-01-01"):
assert _get_lock(MAIN_KEY) is None
with KeyValueDistributedLock("ns", a=1, b=2):
assert _get_lock(MAIN_KEY, session) is True
with freeze_time("2022-01-01T"):
assert _get_lock(MAIN_KEY, session) is None
assert _get_lock(MAIN_KEY) is True
with freeze_time("2022-01-01"):
assert _get_lock(MAIN_KEY) is None

assert _get_lock(MAIN_KEY, session) is None
assert _get_lock(MAIN_KEY) is None