mirror of https://github.com/apache/superset.git
Revert interim test updates
This commit is contained in:
parent
acfba3da29
commit
3e4325a176
|
@ -239,8 +239,8 @@ basepython = python3.10
|
||||||
ignore_basepython_conflict = true
|
ignore_basepython_conflict = true
|
||||||
commands =
|
commands =
|
||||||
superset db upgrade
|
superset db upgrade
|
||||||
superset load_test_users
|
|
||||||
superset init
|
superset init
|
||||||
|
superset load-test-users
|
||||||
# use -s to be able to use break pointers.
|
# use -s to be able to use break pointers.
|
||||||
# no args or tests/* can be passed as an argument to run all tests
|
# no args or tests/* can be passed as an argument to run all tests
|
||||||
pytest -s {posargs}
|
pytest -s {posargs}
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
# pylint: disable=consider-using-transaction
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
from superset import db, security_manager
|
from superset import db, security_manager
|
||||||
|
|
|
@ -28,9 +28,9 @@ export SUPERSET_TESTENV=true
|
||||||
echo "Superset config module: $SUPERSET_CONFIG"
|
echo "Superset config module: $SUPERSET_CONFIG"
|
||||||
|
|
||||||
superset db upgrade
|
superset db upgrade
|
||||||
superset load_test_users
|
|
||||||
superset init
|
superset init
|
||||||
|
superset load-test-users
|
||||||
|
|
||||||
echo "Running tests"
|
echo "Running tests"
|
||||||
|
|
||||||
pytest --durations-min=2 --cov-report= --cov=superset ./tests/integration_tests "$@"
|
pytest --durations-min=2 --maxfail=1 --cov-report= --cov=superset ./tests/integration_tests "$@"
|
||||||
|
|
|
@ -37,15 +37,7 @@ def load_test_users() -> None:
|
||||||
Syncs permissions for those users/roles
|
Syncs permissions for those users/roles
|
||||||
"""
|
"""
|
||||||
print(Fore.GREEN + "Loading a set of users for unit tests")
|
print(Fore.GREEN + "Loading a set of users for unit tests")
|
||||||
load_test_users_run()
|
|
||||||
|
|
||||||
|
|
||||||
def load_test_users_run() -> None:
|
|
||||||
"""
|
|
||||||
Loads admin, alpha, and gamma user for testing purposes
|
|
||||||
|
|
||||||
Syncs permissions for those users/roles
|
|
||||||
"""
|
|
||||||
if app.config["TESTING"]:
|
if app.config["TESTING"]:
|
||||||
sm = security_manager
|
sm = security_manager
|
||||||
|
|
||||||
|
|
|
@ -77,7 +77,7 @@ def import_chart(
|
||||||
if chart.id is None:
|
if chart.id is None:
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
|
||||||
if user := get_user():
|
if (user := get_user()) and user not in chart.owners:
|
||||||
chart.owners.append(user)
|
chart.owners.append(user)
|
||||||
|
|
||||||
return chart
|
return chart
|
||||||
|
|
|
@ -38,8 +38,8 @@ from superset.daos.chart import ChartDAO
|
||||||
from superset.daos.dashboard import DashboardDAO
|
from superset.daos.dashboard import DashboardDAO
|
||||||
from superset.exceptions import SupersetSecurityException
|
from superset.exceptions import SupersetSecurityException
|
||||||
from superset.models.slice import Slice
|
from superset.models.slice import Slice
|
||||||
from superset.utils.decorators import on_error, transaction
|
|
||||||
from superset.tags.models import ObjectType
|
from superset.tags.models import ObjectType
|
||||||
|
from superset.utils.decorators import on_error, transaction
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -62,14 +62,13 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
|
||||||
assert self._model
|
assert self._model
|
||||||
|
|
||||||
# Update tags
|
# Update tags
|
||||||
tags = self._properties.pop("tags", None)
|
if (tags := self._properties.pop("tags", None)) is not None:
|
||||||
if tags is not None:
|
|
||||||
update_tags(ObjectType.chart, self._model.id, self._model.tags, tags)
|
update_tags(ObjectType.chart, self._model.id, self._model.tags, tags)
|
||||||
|
|
||||||
if self._properties.get("query_context_generation") is None:
|
if self._properties.get("query_context_generation") is None:
|
||||||
self._properties["last_saved_at"] = datetime.now()
|
self._properties["last_saved_at"] = datetime.now()
|
||||||
self._properties["last_saved_by"] = g.user
|
self._properties["last_saved_by"] = g.user
|
||||||
|
|
||||||
return ChartDAO.update(self._model, self._properties)
|
return ChartDAO.update(self._model, self._properties)
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
|
|
|
@ -188,7 +188,7 @@ def import_dashboard(
|
||||||
if dashboard.id is None:
|
if dashboard.id is None:
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
|
||||||
if user := get_user():
|
if (user := get_user()) and user not in dashboard.owners:
|
||||||
dashboard.owners.append(user)
|
dashboard.owners.append(user)
|
||||||
|
|
||||||
return dashboard
|
return dashboard
|
||||||
|
|
|
@ -19,7 +19,6 @@ from functools import partial
|
||||||
|
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
from superset import db
|
|
||||||
from superset.commands.dashboard.permalink.base import BaseDashboardPermalinkCommand
|
from superset.commands.dashboard.permalink.base import BaseDashboardPermalinkCommand
|
||||||
from superset.commands.key_value.upsert import UpsertKeyValueCommand
|
from superset.commands.key_value.upsert import UpsertKeyValueCommand
|
||||||
from superset.daos.dashboard import DashboardDAO
|
from superset.daos.dashboard import DashboardDAO
|
||||||
|
@ -78,6 +77,7 @@ class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
|
||||||
codec=self.codec,
|
codec=self.codec,
|
||||||
).run()
|
).run()
|
||||||
assert key.id # for type checks
|
assert key.id # for type checks
|
||||||
|
return encode_permalink_key(key=key.id, salt=self.salt)
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -53,19 +53,16 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand):
|
||||||
assert self._model
|
assert self._model
|
||||||
|
|
||||||
# Update tags
|
# Update tags
|
||||||
tags = self._properties.pop("tags", None)
|
if (tags := self._properties.pop("tags", None)) is not None:
|
||||||
if tags is not None:
|
update_tags(ObjectType.dashboard, self._model.id, self._model.tags, tags)
|
||||||
update_tags(
|
|
||||||
ObjectType.dashboard, self._model.id, self._model.tags, tags
|
|
||||||
)
|
|
||||||
|
|
||||||
dashboard = DashboardDAO.update(self._model, self._properties, commit=False)
|
dashboard = DashboardDAO.update(self._model, self._properties)
|
||||||
if self._properties.get("json_metadata"):
|
if self._properties.get("json_metadata"):
|
||||||
DashboardDAO.set_dash_metadata(
|
DashboardDAO.set_dash_metadata(
|
||||||
dashboard,
|
dashboard,
|
||||||
data=json.loads(self._properties.get("json_metadata", "{}")),
|
data=json.loads(self._properties.get("json_metadata", "{}")),
|
||||||
)
|
)
|
||||||
|
|
||||||
return dashboard
|
return dashboard
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
|
|
|
@ -40,7 +40,7 @@ from superset.commands.database.ssh_tunnel.exceptions import (
|
||||||
)
|
)
|
||||||
from superset.commands.database.test_connection import TestConnectionDatabaseCommand
|
from superset.commands.database.test_connection import TestConnectionDatabaseCommand
|
||||||
from superset.daos.database import DatabaseDAO
|
from superset.daos.database import DatabaseDAO
|
||||||
from superset.daos.exceptions import DAOCreateFailedError
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
from superset.exceptions import SupersetErrorsException
|
from superset.exceptions import SupersetErrorsException
|
||||||
from superset.extensions import event_logger, security_manager
|
from superset.extensions import event_logger, security_manager
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
|
|
|
@ -1,194 +0,0 @@
|
||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
# or more contributor license agreements. See the NOTICE file
|
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
import logging
|
|
||||||
from functools import partial
|
|
||||||
from typing import Any, Optional, TypedDict
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
from flask_babel import lazy_gettext as _
|
|
||||||
|
|
||||||
from superset import db
|
|
||||||
from superset.commands.base import BaseCommand
|
|
||||||
from superset.commands.database.exceptions import (
|
|
||||||
DatabaseNotFoundError,
|
|
||||||
DatabaseSchemaUploadNotAllowed,
|
|
||||||
DatabaseUploadFailed,
|
|
||||||
DatabaseUploadSaveMetadataFailed,
|
|
||||||
)
|
|
||||||
from superset.connectors.sqla.models import SqlaTable
|
|
||||||
from superset.daos.database import DatabaseDAO
|
|
||||||
from superset.models.core import Database
|
|
||||||
from superset.sql_parse import Table
|
|
||||||
from superset.utils.core import get_user
|
|
||||||
from superset.utils.decorators import on_error, transaction
|
|
||||||
from superset.views.database.validators import schema_allows_file_upload
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
READ_CSV_CHUNK_SIZE = 1000
|
|
||||||
|
|
||||||
|
|
||||||
class CSVImportOptions(TypedDict, total=False):
|
|
||||||
schema: str
|
|
||||||
delimiter: str
|
|
||||||
already_exists: str
|
|
||||||
column_data_types: dict[str, str]
|
|
||||||
column_dates: list[str]
|
|
||||||
column_labels: str
|
|
||||||
columns_read: list[str]
|
|
||||||
dataframe_index: str
|
|
||||||
day_first: bool
|
|
||||||
decimal_character: str
|
|
||||||
header_row: int
|
|
||||||
index_column: str
|
|
||||||
null_values: list[str]
|
|
||||||
overwrite_duplicates: bool
|
|
||||||
rows_to_read: int
|
|
||||||
skip_blank_lines: bool
|
|
||||||
skip_initial_space: bool
|
|
||||||
skip_rows: int
|
|
||||||
|
|
||||||
|
|
||||||
class CSVImportCommand(BaseCommand):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: int,
|
|
||||||
table_name: str,
|
|
||||||
file: Any,
|
|
||||||
options: CSVImportOptions,
|
|
||||||
) -> None:
|
|
||||||
self._model_id = model_id
|
|
||||||
self._model: Optional[Database] = None
|
|
||||||
self._table_name = table_name
|
|
||||||
self._schema = options.get("schema")
|
|
||||||
self._file = file
|
|
||||||
self._options = options
|
|
||||||
|
|
||||||
def _read_csv(self) -> pd.DataFrame:
|
|
||||||
"""
|
|
||||||
Read CSV file into a DataFrame
|
|
||||||
|
|
||||||
:return: pandas DataFrame
|
|
||||||
:throws DatabaseUploadFailed: if there is an error reading the CSV file
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return pd.concat(
|
|
||||||
pd.read_csv(
|
|
||||||
chunksize=READ_CSV_CHUNK_SIZE,
|
|
||||||
encoding="utf-8",
|
|
||||||
filepath_or_buffer=self._file,
|
|
||||||
header=self._options.get("header_row", 0),
|
|
||||||
index_col=self._options.get("index_column"),
|
|
||||||
dayfirst=self._options.get("day_first", False),
|
|
||||||
iterator=True,
|
|
||||||
keep_default_na=not self._options.get("null_values"),
|
|
||||||
usecols=self._options.get("columns_read")
|
|
||||||
if self._options.get("columns_read") # None if an empty list
|
|
||||||
else None,
|
|
||||||
na_values=self._options.get("null_values")
|
|
||||||
if self._options.get("null_values") # None if an empty list
|
|
||||||
else None,
|
|
||||||
nrows=self._options.get("rows_to_read"),
|
|
||||||
parse_dates=self._options.get("column_dates"),
|
|
||||||
sep=self._options.get("delimiter", ","),
|
|
||||||
skip_blank_lines=self._options.get("skip_blank_lines", False),
|
|
||||||
skipinitialspace=self._options.get("skip_initial_space", False),
|
|
||||||
skiprows=self._options.get("skip_rows", 0),
|
|
||||||
dtype=self._options.get("column_data_types")
|
|
||||||
if self._options.get("column_data_types")
|
|
||||||
else None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except (
|
|
||||||
pd.errors.ParserError,
|
|
||||||
pd.errors.EmptyDataError,
|
|
||||||
UnicodeDecodeError,
|
|
||||||
ValueError,
|
|
||||||
) as ex:
|
|
||||||
raise DatabaseUploadFailed(
|
|
||||||
message=_("Parsing error: %(error)s", error=str(ex))
|
|
||||||
) from ex
|
|
||||||
except Exception as ex:
|
|
||||||
raise DatabaseUploadFailed(_("Error reading CSV file")) from ex
|
|
||||||
|
|
||||||
def _dataframe_to_database(self, df: pd.DataFrame, database: Database) -> None:
|
|
||||||
"""
|
|
||||||
Upload DataFrame to database
|
|
||||||
|
|
||||||
:param df:
|
|
||||||
:throws DatabaseUploadFailed: if there is an error uploading the DataFrame
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
csv_table = Table(table=self._table_name, schema=self._schema)
|
|
||||||
database.db_engine_spec.df_to_sql(
|
|
||||||
database,
|
|
||||||
csv_table,
|
|
||||||
df,
|
|
||||||
to_sql_kwargs={
|
|
||||||
"chunksize": READ_CSV_CHUNK_SIZE,
|
|
||||||
"if_exists": self._options.get("already_exists", "fail"),
|
|
||||||
"index": self._options.get("index_column"),
|
|
||||||
"index_label": self._options.get("column_labels"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except ValueError as ex:
|
|
||||||
raise DatabaseUploadFailed(
|
|
||||||
message=_(
|
|
||||||
"Table already exists. You can change your "
|
|
||||||
"'if table already exists' strategy to append or "
|
|
||||||
"replace or provide a different Table Name to use."
|
|
||||||
)
|
|
||||||
) from ex
|
|
||||||
except Exception as ex:
|
|
||||||
raise DatabaseUploadFailed(exception=ex) from ex
|
|
||||||
|
|
||||||
@transaction(on_error=partial(on_error, reraise=DatabaseUploadSaveMetadataFailed))
|
|
||||||
def run(self) -> None:
|
|
||||||
self.validate()
|
|
||||||
if not self._model:
|
|
||||||
return
|
|
||||||
|
|
||||||
df = self._read_csv()
|
|
||||||
self._dataframe_to_database(df, self._model)
|
|
||||||
|
|
||||||
sqla_table = (
|
|
||||||
db.session.query(SqlaTable)
|
|
||||||
.filter_by(
|
|
||||||
table_name=self._table_name,
|
|
||||||
schema=self._schema,
|
|
||||||
database_id=self._model_id,
|
|
||||||
)
|
|
||||||
.one_or_none()
|
|
||||||
)
|
|
||||||
if not sqla_table:
|
|
||||||
sqla_table = SqlaTable(
|
|
||||||
table_name=self._table_name,
|
|
||||||
database=self._model,
|
|
||||||
database_id=self._model_id,
|
|
||||||
owners=[get_user()],
|
|
||||||
schema=self._schema,
|
|
||||||
)
|
|
||||||
db.session.add(sqla_table)
|
|
||||||
|
|
||||||
sqla_table.fetch_metadata()
|
|
||||||
|
|
||||||
def validate(self) -> None:
|
|
||||||
self._model = DatabaseDAO.find_by_id(self._model_id)
|
|
||||||
if not self._model:
|
|
||||||
raise DatabaseNotFoundError()
|
|
||||||
if not schema_allows_file_upload(self._model, self._schema):
|
|
||||||
raise DatabaseSchemaUploadNotAllowed()
|
|
|
@ -41,7 +41,6 @@ from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
|
||||||
from superset.daos.database import DatabaseDAO
|
from superset.daos.database import DatabaseDAO
|
||||||
from superset.daos.dataset import DatasetDAO
|
from superset.daos.dataset import DatasetDAO
|
||||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
from superset.extensions import db
|
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
from superset.utils.decorators import on_error, transaction
|
from superset.utils.decorators import on_error, transaction
|
||||||
|
|
||||||
|
@ -78,11 +77,7 @@ class UpdateDatabaseCommand(BaseCommand):
|
||||||
original_database_name = self._model.database_name
|
original_database_name = self._model.database_name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
database = DatabaseDAO.update(
|
database = DatabaseDAO.update(self._model, self._properties)
|
||||||
self._model,
|
|
||||||
self._properties,
|
|
||||||
commit=False,
|
|
||||||
)
|
|
||||||
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
||||||
ssh_tunnel = self._handle_ssh_tunnel(database)
|
ssh_tunnel = self._handle_ssh_tunnel(database)
|
||||||
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
|
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
|
||||||
|
@ -100,7 +95,6 @@ class UpdateDatabaseCommand(BaseCommand):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if not is_feature_enabled("SSH_TUNNELING"):
|
if not is_feature_enabled("SSH_TUNNELING"):
|
||||||
db.session.rollback()
|
|
||||||
raise SSHTunnelingNotEnabledError()
|
raise SSHTunnelingNotEnabledError()
|
||||||
|
|
||||||
current_ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
|
current_ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
|
||||||
|
@ -130,14 +124,11 @@ class UpdateDatabaseCommand(BaseCommand):
|
||||||
This method captures a generic exception, since errors could potentially come
|
This method captures a generic exception, since errors could potentially come
|
||||||
from any of the 50+ database drivers we support.
|
from any of the 50+ database drivers we support.
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
return database.get_all_catalog_names(
|
return database.get_all_catalog_names(
|
||||||
force=True,
|
force=True,
|
||||||
ssh_tunnel=ssh_tunnel,
|
ssh_tunnel=ssh_tunnel,
|
||||||
)
|
)
|
||||||
except Exception as ex:
|
|
||||||
db.session.rollback()
|
|
||||||
raise DatabaseConnectionFailedError() from ex
|
|
||||||
|
|
||||||
def _get_schema_names(
|
def _get_schema_names(
|
||||||
self,
|
self,
|
||||||
|
@ -151,15 +142,12 @@ class UpdateDatabaseCommand(BaseCommand):
|
||||||
This method captures a generic exception, since errors could potentially come
|
This method captures a generic exception, since errors could potentially come
|
||||||
from any of the 50+ database drivers we support.
|
from any of the 50+ database drivers we support.
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
return database.get_all_schema_names(
|
return database.get_all_schema_names(
|
||||||
force=True,
|
force=True,
|
||||||
catalog=catalog,
|
catalog=catalog,
|
||||||
ssh_tunnel=ssh_tunnel,
|
ssh_tunnel=ssh_tunnel,
|
||||||
)
|
)
|
||||||
except Exception as ex:
|
|
||||||
db.session.rollback()
|
|
||||||
raise DatabaseConnectionFailedError() from ex
|
|
||||||
|
|
||||||
def _refresh_catalogs(
|
def _refresh_catalogs(
|
||||||
self,
|
self,
|
||||||
|
@ -224,8 +212,6 @@ class UpdateDatabaseCommand(BaseCommand):
|
||||||
schemas,
|
schemas,
|
||||||
)
|
)
|
||||||
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
def _refresh_schemas(
|
def _refresh_schemas(
|
||||||
self,
|
self,
|
||||||
database: Database,
|
database: Database,
|
||||||
|
|
|
@ -43,7 +43,7 @@ class DeleteDatasetColumnCommand(BaseCommand):
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
self.validate()
|
self.validate()
|
||||||
assert self._model
|
assert self._model
|
||||||
DatasetColumnDAO.delete(self._model)
|
DatasetColumnDAO.delete([self._model])
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
# Validate/populate model exists
|
# Validate/populate model exists
|
||||||
|
|
|
@ -32,7 +32,7 @@ from superset.commands.dataset.exceptions import (
|
||||||
)
|
)
|
||||||
from superset.daos.dataset import DatasetDAO
|
from superset.daos.dataset import DatasetDAO
|
||||||
from superset.exceptions import SupersetSecurityException
|
from superset.exceptions import SupersetSecurityException
|
||||||
from superset.extensions import db, security_manager
|
from superset.extensions import security_manager
|
||||||
from superset.sql_parse import Table
|
from superset.sql_parse import Table
|
||||||
from superset.utils.decorators import on_error, transaction
|
from superset.utils.decorators import on_error, transaction
|
||||||
|
|
||||||
|
|
|
@ -178,7 +178,7 @@ def import_dataset(
|
||||||
if data_uri and (not table_exists or force_data):
|
if data_uri and (not table_exists or force_data):
|
||||||
load_data(data_uri, dataset, dataset.database)
|
load_data(data_uri, dataset, dataset.database)
|
||||||
|
|
||||||
if user := get_user():
|
if (user := get_user()) and user not in dataset.owners:
|
||||||
dataset.owners.append(user)
|
dataset.owners.append(user)
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
|
@ -43,7 +43,7 @@ class DeleteDatasetMetricCommand(BaseCommand):
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
self.validate()
|
self.validate()
|
||||||
assert self._model
|
assert self._model
|
||||||
DatasetMetricDAO.delete(self._model)
|
DatasetMetricDAO.delete([self._model])
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
# Validate/populate model exists
|
# Validate/populate model exists
|
||||||
|
|
|
@ -21,6 +21,7 @@ from typing import Any, Optional
|
||||||
|
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
from superset import security_manager
|
from superset import security_manager
|
||||||
from superset.commands.base import BaseCommand, UpdateMixin
|
from superset.commands.base import BaseCommand, UpdateMixin
|
||||||
|
@ -60,7 +61,16 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
|
||||||
self.override_columns = override_columns
|
self.override_columns = override_columns
|
||||||
self._properties["override_columns"] = override_columns
|
self._properties["override_columns"] = override_columns
|
||||||
|
|
||||||
@transaction(on_error=partial(on_error, reraise=DatasetUpdateFailedError))
|
@transaction(
|
||||||
|
on_error=partial(
|
||||||
|
on_error,
|
||||||
|
catches=(
|
||||||
|
SQLAlchemyError,
|
||||||
|
ValueError,
|
||||||
|
),
|
||||||
|
reraise=DatasetUpdateFailedError,
|
||||||
|
)
|
||||||
|
)
|
||||||
def run(self) -> Model:
|
def run(self) -> Model:
|
||||||
self.validate()
|
self.validate()
|
||||||
assert self._model
|
assert self._model
|
||||||
|
|
|
@ -20,7 +20,6 @@ from typing import Any, Optional
|
||||||
|
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
from superset import db
|
|
||||||
from superset.commands.explore.permalink.base import BaseExplorePermalinkCommand
|
from superset.commands.explore.permalink.base import BaseExplorePermalinkCommand
|
||||||
from superset.commands.key_value.create import CreateKeyValueCommand
|
from superset.commands.key_value.create import CreateKeyValueCommand
|
||||||
from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError
|
from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError
|
||||||
|
@ -74,27 +73,7 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
|
||||||
key = command.run()
|
key = command.run()
|
||||||
if key.id is None:
|
if key.id is None:
|
||||||
raise ExplorePermalinkCreateFailedError("Unexpected missing key id")
|
raise ExplorePermalinkCreateFailedError("Unexpected missing key id")
|
||||||
db.session.commit()
|
|
||||||
return encode_permalink_key(key=key.id, salt=self.salt)
|
return encode_permalink_key(key=key.id, salt=self.salt)
|
||||||
d_id, d_type = self.datasource.split("__")
|
|
||||||
datasource_id = int(d_id)
|
|
||||||
datasource_type = DatasourceType(d_type)
|
|
||||||
check_chart_access(datasource_id, self.chart_id, datasource_type)
|
|
||||||
value = {
|
|
||||||
"chartId": self.chart_id,
|
|
||||||
"datasourceId": datasource_id,
|
|
||||||
"datasourceType": datasource_type.value,
|
|
||||||
"datasource": self.datasource,
|
|
||||||
"state": self.state,
|
|
||||||
}
|
|
||||||
command = CreateKeyValueCommand(
|
|
||||||
resource=self.resource,
|
|
||||||
value=value,
|
|
||||||
codec=self.codec,
|
|
||||||
)
|
|
||||||
key = command.run()
|
|
||||||
return encode_permalink_key(key=key.id, salt=self.salt)
|
|
||||||
>>>>>>> c01dacb71a (chore(dao): Use nested session for operations)
|
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
from functools import partial
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from marshmallow import Schema
|
from marshmallow import Schema
|
||||||
|
@ -44,7 +45,7 @@ from superset.datasets.schemas import ImportV1DatasetSchema
|
||||||
from superset.migrations.shared.native_filters import migrate_dashboard
|
from superset.migrations.shared.native_filters import migrate_dashboard
|
||||||
from superset.models.dashboard import dashboard_slices
|
from superset.models.dashboard import dashboard_slices
|
||||||
from superset.queries.saved_queries.schemas import ImportV1SavedQuerySchema
|
from superset.queries.saved_queries.schemas import ImportV1SavedQuerySchema
|
||||||
from superset.utils.decorators import transaction
|
from superset.utils.decorators import on_error, transaction
|
||||||
|
|
||||||
|
|
||||||
class ImportAssetsCommand(BaseCommand):
|
class ImportAssetsCommand(BaseCommand):
|
||||||
|
@ -154,14 +155,16 @@ class ImportAssetsCommand(BaseCommand):
|
||||||
if chart.viz_type == "filter_box":
|
if chart.viz_type == "filter_box":
|
||||||
db.session.delete(chart)
|
db.session.delete(chart)
|
||||||
|
|
||||||
@transaction()
|
@transaction(
|
||||||
|
on_error=partial(
|
||||||
|
on_error,
|
||||||
|
catches=(Exception,),
|
||||||
|
reraise=ImportFailedError,
|
||||||
|
)
|
||||||
|
)
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
self.validate()
|
self.validate()
|
||||||
|
self._import(self._configs)
|
||||||
try:
|
|
||||||
self._import(self._configs)
|
|
||||||
except Exception as ex:
|
|
||||||
raise ImportFailedError() from ex
|
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
exceptions: list[ValidationError] = []
|
exceptions: list[ValidationError] = []
|
||||||
|
|
|
@ -53,9 +53,11 @@ class DeleteKeyValueCommand(BaseCommand):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def delete(self) -> bool:
|
def delete(self) -> bool:
|
||||||
filter_ = get_filter(self.resource, self.key)
|
if (
|
||||||
if entry := db.session.query(KeyValueEntry).filter_by(**filter_).first():
|
entry := db.session.query(KeyValueEntry)
|
||||||
|
.filter_by(**get_filter(self.resource, self.key))
|
||||||
|
.first()
|
||||||
|
):
|
||||||
db.session.delete(entry)
|
db.session.delete(entry)
|
||||||
db.session.flush()
|
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
|
@ -60,4 +60,3 @@ class DeleteExpiredKeyValueCommand(BaseCommand):
|
||||||
)
|
)
|
||||||
.delete()
|
.delete()
|
||||||
)
|
)
|
||||||
db.session.flush()
|
|
||||||
|
|
|
@ -21,6 +21,8 @@ from functools import partial
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
from superset import db
|
from superset import db
|
||||||
from superset.commands.base import BaseCommand
|
from superset.commands.base import BaseCommand
|
||||||
from superset.commands.key_value.create import CreateKeyValueCommand
|
from superset.commands.key_value.create import CreateKeyValueCommand
|
||||||
|
@ -71,7 +73,7 @@ class UpsertKeyValueCommand(BaseCommand):
|
||||||
@transaction(
|
@transaction(
|
||||||
on_error=partial(
|
on_error=partial(
|
||||||
on_error,
|
on_error,
|
||||||
catches=(KeyValueCreateFailedError,),
|
catches=(KeyValueCreateFailedError, SQLAlchemyError),
|
||||||
reraise=KeyValueUpsertFailedError,
|
reraise=KeyValueUpsertFailedError,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -82,16 +84,15 @@ class UpsertKeyValueCommand(BaseCommand):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def upsert(self) -> Key:
|
def upsert(self) -> Key:
|
||||||
filter_ = get_filter(self.resource, self.key)
|
if (
|
||||||
entry: KeyValueEntry = (
|
entry := db.session.query(KeyValueEntry)
|
||||||
db.session.query(KeyValueEntry).filter_by(**filter_).first()
|
.filter_by(**get_filter(self.resource, self.key))
|
||||||
)
|
.first()
|
||||||
if entry:
|
):
|
||||||
entry.value = self.codec.encode(self.value)
|
entry.value = self.codec.encode(self.value)
|
||||||
entry.expires_on = self.expires_on
|
entry.expires_on = self.expires_on
|
||||||
entry.changed_on = datetime.now()
|
entry.changed_on = datetime.now()
|
||||||
entry.changed_by_fk = get_user_id()
|
entry.changed_by_fk = get_user_id()
|
||||||
db.session.flush()
|
|
||||||
return Key(entry.id, entry.uuid)
|
return Key(entry.id, entry.uuid)
|
||||||
|
|
||||||
return CreateKeyValueCommand(
|
return CreateKeyValueCommand(
|
||||||
|
|
|
@ -137,6 +137,7 @@ class BaseReportState:
|
||||||
uuid=self._execution_id,
|
uuid=self._execution_id,
|
||||||
)
|
)
|
||||||
db.session.add(log)
|
db.session.add(log)
|
||||||
|
db.session.commit() # pylint: disable=consider-using-transaction
|
||||||
|
|
||||||
def _get_url(
|
def _get_url(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -49,7 +49,6 @@ class AsyncPruneReportScheduleLogCommand(BaseCommand):
|
||||||
report_schedule,
|
report_schedule,
|
||||||
from_date,
|
from_date,
|
||||||
)
|
)
|
||||||
db.session.commit()
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Deleted %s logs for report schedule id: %s",
|
"Deleted %s logs for report schedule id: %s",
|
||||||
str(row_count),
|
str(row_count),
|
||||||
|
|
|
@ -383,5 +383,4 @@ class TagDAO(BaseDAO[Tag]):
|
||||||
object_id,
|
object_id,
|
||||||
tag.name,
|
tag.name,
|
||||||
)
|
)
|
||||||
|
|
||||||
db.session.add_all(tagged_objects)
|
db.session.add_all(tagged_objects)
|
||||||
|
|
|
@ -32,7 +32,7 @@ from marshmallow import ValidationError
|
||||||
from werkzeug.wrappers import Response as WerkzeugResponse
|
from werkzeug.wrappers import Response as WerkzeugResponse
|
||||||
from werkzeug.wsgi import FileWrapper
|
from werkzeug.wsgi import FileWrapper
|
||||||
|
|
||||||
from superset import is_feature_enabled, thumbnail_cache
|
from superset import db, is_feature_enabled, thumbnail_cache
|
||||||
from superset.charts.schemas import ChartEntityResponseSchema
|
from superset.charts.schemas import ChartEntityResponseSchema
|
||||||
from superset.commands.dashboard.create import CreateDashboardCommand
|
from superset.commands.dashboard.create import CreateDashboardCommand
|
||||||
from superset.commands.dashboard.delete import DeleteDashboardCommand
|
from superset.commands.dashboard.delete import DeleteDashboardCommand
|
||||||
|
@ -1314,7 +1314,13 @@ class DashboardRestApi(BaseSupersetModelRestApi):
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
body = self.embedded_config_schema.load(request.json)
|
body = self.embedded_config_schema.load(request.json)
|
||||||
embedded = EmbeddedDashboardDAO.upsert(dashboard, body["allowed_domains"])
|
|
||||||
|
with db.session.begin_nested():
|
||||||
|
embedded = EmbeddedDashboardDAO.upsert(
|
||||||
|
dashboard,
|
||||||
|
body["allowed_domains"],
|
||||||
|
)
|
||||||
|
|
||||||
result = self.embedded_response_schema.dump(embedded)
|
result = self.embedded_response_schema.dump(embedded)
|
||||||
return self.response(200, result=result)
|
return self.response(200, result=result)
|
||||||
except ValidationError as error:
|
except ValidationError as error:
|
||||||
|
|
|
@ -22,7 +22,6 @@ from uuid import UUID, uuid3
|
||||||
from flask import current_app, Flask, has_app_context
|
from flask import current_app, Flask, has_app_context
|
||||||
from flask_caching import BaseCache
|
from flask_caching import BaseCache
|
||||||
|
|
||||||
from superset import db
|
|
||||||
from superset.key_value.exceptions import KeyValueCreateFailedError
|
from superset.key_value.exceptions import KeyValueCreateFailedError
|
||||||
from superset.key_value.types import (
|
from superset.key_value.types import (
|
||||||
KeyValueCodec,
|
KeyValueCodec,
|
||||||
|
@ -95,7 +94,6 @@ class SupersetMetastoreCache(BaseCache):
|
||||||
codec=self.codec,
|
codec=self.codec,
|
||||||
expires_on=self._get_expiry(timeout),
|
expires_on=self._get_expiry(timeout),
|
||||||
).run()
|
).run()
|
||||||
db.session.commit()
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def add(self, key: str, value: Any, timeout: Optional[int] = None) -> bool:
|
def add(self, key: str, value: Any, timeout: Optional[int] = None) -> bool:
|
||||||
|
@ -111,7 +109,6 @@ class SupersetMetastoreCache(BaseCache):
|
||||||
key=self.get_key(key),
|
key=self.get_key(key),
|
||||||
expires_on=self._get_expiry(timeout),
|
expires_on=self._get_expiry(timeout),
|
||||||
).run()
|
).run()
|
||||||
db.session.commit()
|
|
||||||
return True
|
return True
|
||||||
except KeyValueCreateFailedError:
|
except KeyValueCreateFailedError:
|
||||||
return False
|
return False
|
||||||
|
@ -136,6 +133,4 @@ class SupersetMetastoreCache(BaseCache):
|
||||||
# pylint: disable=import-outside-toplevel
|
# pylint: disable=import-outside-toplevel
|
||||||
from superset.commands.key_value.delete import DeleteKeyValueCommand
|
from superset.commands.key_value.delete import DeleteKeyValueCommand
|
||||||
|
|
||||||
ret = DeleteKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run()
|
return DeleteKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run()
|
||||||
db.session.commit()
|
|
||||||
return ret
|
|
||||||
|
|
|
@ -18,7 +18,6 @@
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
from uuid import uuid3
|
from uuid import uuid3
|
||||||
|
|
||||||
from superset import db
|
|
||||||
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource, SharedKey
|
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource, SharedKey
|
||||||
from superset.key_value.utils import get_uuid_namespace, random_key
|
from superset.key_value.utils import get_uuid_namespace, random_key
|
||||||
|
|
||||||
|
@ -46,7 +45,6 @@ def set_shared_value(key: SharedKey, value: Any) -> None:
|
||||||
key=uuid_key,
|
key=uuid_key,
|
||||||
codec=CODEC,
|
codec=CODEC,
|
||||||
).run()
|
).run()
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
|
|
||||||
def get_permalink_salt(key: SharedKey) -> str:
|
def get_permalink_salt(key: SharedKey) -> str:
|
||||||
|
|
|
@ -334,22 +334,26 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||||
pyjwt_for_guest_token = _jwt_global_obj
|
pyjwt_for_guest_token = _jwt_global_obj
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def get_session(self) -> Session:
|
def get_session2(self) -> Session:
|
||||||
"""
|
"""
|
||||||
Flask-AppBuilder (FAB) which has a tendency to explicitly commit, thus violating
|
Flask-AppBuilder (FAB) which has a tendency to explicitly commit, thus violating
|
||||||
our definition of "unit of work".
|
our definition of "unit of work".
|
||||||
|
|
||||||
By providing a monkey patched transaction for the FAB session ensures that any
|
By providing a monkey patched transaction for the FAB session, within the
|
||||||
explicit commit merely flushes and any rollback is a no-op.
|
confines of a nested session, ensures that any explicit commit merely flushes
|
||||||
|
and any rollback is a no-op.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=import-outside-toplevel
|
# pylint: disable=import-outside-toplevel
|
||||||
from superset import db
|
from superset import db
|
||||||
|
|
||||||
with db.session.begin_nested() as transaction:
|
if db.session._proxied._nested_transaction: # pylint: disable=protected-access
|
||||||
transaction.session.commit = transaction.session.flush
|
with db.session.begin_nested() as transaction:
|
||||||
transaction.session.rollback = lambda: None
|
transaction.session.commit = transaction.session.flush
|
||||||
return transaction.session
|
transaction.session.rollback = lambda: None
|
||||||
|
return transaction.session
|
||||||
|
|
||||||
|
return db.session
|
||||||
|
|
||||||
def create_login_manager(self, app: Flask) -> LoginManager:
|
def create_login_manager(self, app: Flask) -> LoginManager:
|
||||||
lm = super().create_login_manager(app)
|
lm = super().create_login_manager(app)
|
||||||
|
@ -1035,7 +1039,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||||
== None, # noqa: E711
|
== None, # noqa: E711
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.get_session.flush()
|
|
||||||
if deleted_count := pvms.delete():
|
if deleted_count := pvms.delete():
|
||||||
logger.info("Deleted %i faulty permissions", deleted_count)
|
logger.info("Deleted %i faulty permissions", deleted_count)
|
||||||
|
|
||||||
|
@ -1065,7 +1068,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||||
)
|
)
|
||||||
|
|
||||||
self.create_missing_perms()
|
self.create_missing_perms()
|
||||||
self.get_session.flush()
|
|
||||||
self.clean_perms()
|
self.clean_perms()
|
||||||
|
|
||||||
def _get_all_pvms(self) -> list[PermissionView]:
|
def _get_all_pvms(self) -> list[PermissionView]:
|
||||||
|
@ -1138,7 +1140,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||||
):
|
):
|
||||||
role_from_permissions.append(permission_view)
|
role_from_permissions.append(permission_view)
|
||||||
role_to.permissions = role_from_permissions
|
role_to.permissions = role_from_permissions
|
||||||
self.get_session.flush()
|
|
||||||
|
|
||||||
def set_role(
|
def set_role(
|
||||||
self,
|
self,
|
||||||
|
@ -1159,7 +1160,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||||
permission_view for permission_view in pvms if pvm_check(permission_view)
|
permission_view for permission_view in pvms if pvm_check(permission_view)
|
||||||
]
|
]
|
||||||
role.permissions = role_pvms
|
role.permissions = role_pvms
|
||||||
self.get_session.flush()
|
|
||||||
|
|
||||||
def _is_admin_only(self, pvm: PermissionView) -> bool:
|
def _is_admin_only(self, pvm: PermissionView) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -140,6 +140,7 @@ def get_tag(
|
||||||
if tag is None:
|
if tag is None:
|
||||||
tag = Tag(name=escape(tag_name), type=type_)
|
tag = Tag(name=escape(tag_name), type=type_)
|
||||||
session.add(tag)
|
session.add(tag)
|
||||||
|
session.commit()
|
||||||
return tag
|
return tag
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -59,6 +59,7 @@ def get_or_create_db(
|
||||||
if database and database.sqlalchemy_uri_decrypted != sqlalchemy_uri:
|
if database and database.sqlalchemy_uri_decrypted != sqlalchemy_uri:
|
||||||
database.set_sqlalchemy_uri(sqlalchemy_uri)
|
database.set_sqlalchemy_uri(sqlalchemy_uri)
|
||||||
|
|
||||||
|
db.session.flush()
|
||||||
return database
|
return database
|
||||||
|
|
||||||
|
|
||||||
|
@ -78,3 +79,4 @@ def remove_database(database: Database) -> None:
|
||||||
from superset import db
|
from superset import db
|
||||||
|
|
||||||
db.session.delete(database)
|
db.session.delete(database)
|
||||||
|
db.session.flush()
|
||||||
|
|
|
@ -222,7 +222,7 @@ def on_error(
|
||||||
|
|
||||||
:param ex: The source exception
|
:param ex: The source exception
|
||||||
:param catches: The exception types the handler catches
|
:param catches: The exception types the handler catches
|
||||||
:param reraise: The exception type the handler reraises after catching
|
:param reraise: The exception type the handler raises after catching
|
||||||
:raises Exception: If the exception is not swallowed
|
:raises Exception: If the exception is not swallowed
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -252,13 +252,16 @@ def transaction( # pylint: disable=redefined-outer-name
|
||||||
from superset import db # pylint: disable=import-outside-toplevel
|
from superset import db # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with db.session.begin_nested():
|
result = func(*args, **kwargs)
|
||||||
return func(*args, **kwargs)
|
db.session.commit()
|
||||||
|
return result
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
|
db.session.rollback()
|
||||||
|
|
||||||
if on_error:
|
if on_error:
|
||||||
return on_error(ex)
|
return on_error(ex)
|
||||||
|
|
||||||
raise ex
|
raise
|
||||||
|
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,6 @@ from contextlib import contextmanager
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, cast, TypeVar, Union
|
from typing import Any, cast, TypeVar, Union
|
||||||
|
|
||||||
from superset import db
|
|
||||||
from superset.exceptions import CreateKeyValueDistributedLockFailedException
|
from superset.exceptions import CreateKeyValueDistributedLockFailedException
|
||||||
from superset.key_value.exceptions import KeyValueCreateFailedError
|
from superset.key_value.exceptions import KeyValueCreateFailedError
|
||||||
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
|
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
|
||||||
|
@ -72,7 +71,6 @@ def KeyValueDistributedLock( # pylint: disable=invalid-name
|
||||||
store.
|
store.
|
||||||
|
|
||||||
:param namespace: The namespace for which the lock is to be acquired.
|
:param namespace: The namespace for which the lock is to be acquired.
|
||||||
:type namespace: str
|
|
||||||
:param kwargs: Additional keyword arguments.
|
:param kwargs: Additional keyword arguments.
|
||||||
:yields: A unique identifier (UUID) for the acquired lock (the KV key).
|
:yields: A unique identifier (UUID) for the acquired lock (the KV key).
|
||||||
:raises CreateKeyValueDistributedLockFailedException: If the lock is taken.
|
:raises CreateKeyValueDistributedLockFailedException: If the lock is taken.
|
||||||
|
@ -93,12 +91,10 @@ def KeyValueDistributedLock( # pylint: disable=invalid-name
|
||||||
value=True,
|
value=True,
|
||||||
expires_on=datetime.now() + LOCK_EXPIRATION,
|
expires_on=datetime.now() + LOCK_EXPIRATION,
|
||||||
).run()
|
).run()
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
yield key
|
yield key
|
||||||
|
|
||||||
DeleteKeyValueCommand(resource=KeyValueResource.LOCK, key=key).run()
|
DeleteKeyValueCommand(resource=KeyValueResource.LOCK, key=key).run()
|
||||||
db.session.commit()
|
|
||||||
logger.debug("Removed lock on namespace %s for key %s", namespace, key)
|
logger.debug("Removed lock on namespace %s for key %s", namespace, key)
|
||||||
except KeyValueCreateFailedError as ex:
|
except KeyValueCreateFailedError as ex:
|
||||||
raise CreateKeyValueDistributedLockFailedException(
|
raise CreateKeyValueDistributedLockFailedException(
|
||||||
|
|
|
@ -403,6 +403,7 @@ class DBEventLogger(AbstractEventLogger):
|
||||||
logs.append(log)
|
logs.append(log)
|
||||||
try:
|
try:
|
||||||
db.session.bulk_save_objects(logs)
|
db.session.bulk_save_objects(logs)
|
||||||
|
db.session.commit() # pylint: disable=consider-using-transaction
|
||||||
except SQLAlchemyError as ex:
|
except SQLAlchemyError as ex:
|
||||||
logging.error("DBEventLogger failed to log event(s)")
|
logging.error("DBEventLogger failed to log event(s)")
|
||||||
logging.exception(ex)
|
logging.exception(ex)
|
||||||
|
|
|
@ -14,8 +14,7 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
# pylint: disable=consider-using-transaction
|
from typing import TYPE_CHECKING
|
||||||
from typing import Any, TYPE_CHECKING
|
|
||||||
|
|
||||||
from flask_appbuilder import expose
|
from flask_appbuilder import expose
|
||||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||||
|
|
|
@ -349,15 +349,7 @@ class SupersetTestCase(TestCase):
|
||||||
self.grant_role_access_to_table(table, role_name)
|
self.grant_role_access_to_table(table, role_name)
|
||||||
|
|
||||||
def grant_role_access_to_table(self, table, role_name):
|
def grant_role_access_to_table(self, table, role_name):
|
||||||
print(">>> grant_role_access_to_table <<<")
|
|
||||||
print(role_name)
|
|
||||||
print(db.session.get_bind())
|
|
||||||
from flask_appbuilder.security.sqla.models import Role
|
|
||||||
|
|
||||||
print(list(db.session.query(Role).all()))
|
|
||||||
role = security_manager.find_role(role_name)
|
role = security_manager.find_role(role_name)
|
||||||
print(role)
|
|
||||||
|
|
||||||
perms = db.session.query(ab_models.PermissionView).all()
|
perms = db.session.query(ab_models.PermissionView).all()
|
||||||
for perm in perms:
|
for perm in perms:
|
||||||
if (
|
if (
|
||||||
|
|
|
@ -332,6 +332,9 @@ def test_run_async_cta_query_with_lower_limit(test_client, ctas_method):
|
||||||
tmp_table=tmp_table,
|
tmp_table=tmp_table,
|
||||||
)
|
)
|
||||||
query = wait_for_success(result)
|
query = wait_for_success(result)
|
||||||
|
print(">>> test_run_async_cta_query_with_lower_limit <<<")
|
||||||
|
print(result)
|
||||||
|
print(query.to_dict())
|
||||||
assert QueryStatus.SUCCESS == query.status
|
assert QueryStatus.SUCCESS == query.status
|
||||||
|
|
||||||
sqlite_select_sql = f"SELECT\n *\nFROM {tmp_table}\nLIMIT {query.limit}\nOFFSET 0"
|
sqlite_select_sql = f"SELECT\n *\nFROM {tmp_table}\nLIMIT {query.limit}\nOFFSET 0"
|
||||||
|
|
|
@ -91,7 +91,7 @@ class TestCore(SupersetTestCase):
|
||||||
self.original_unsafe_db_setting = app.config["PREVENT_UNSAFE_DB_CONNECTIONS"]
|
self.original_unsafe_db_setting = app.config["PREVENT_UNSAFE_DB_CONNECTIONS"]
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
# db.session.query(Query).delete()
|
db.session.query(Query).delete()
|
||||||
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = self.original_unsafe_db_setting
|
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = self.original_unsafe_db_setting
|
||||||
super().tearDown()
|
super().tearDown()
|
||||||
|
|
||||||
|
@ -235,7 +235,6 @@ class TestCore(SupersetTestCase):
|
||||||
)
|
)
|
||||||
for slc in slices:
|
for slc in slices:
|
||||||
db.session.delete(slc)
|
db.session.delete(slc)
|
||||||
print(db.session.dirty)
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||||
|
@ -666,7 +665,6 @@ class TestCore(SupersetTestCase):
|
||||||
client_id="client_id_1",
|
client_id="client_id_1",
|
||||||
username="admin",
|
username="admin",
|
||||||
)
|
)
|
||||||
print(resp)
|
|
||||||
count_ds = []
|
count_ds = []
|
||||||
count_name = []
|
count_name = []
|
||||||
for series in data["data"]:
|
for series in data["data"]:
|
||||||
|
@ -814,7 +812,7 @@ class TestCore(SupersetTestCase):
|
||||||
mock_cache.return_value = MockCache()
|
mock_cache.return_value = MockCache()
|
||||||
|
|
||||||
rv = self.client.get("/superset/explore_json/data/valid-cache-key")
|
rv = self.client.get("/superset/explore_json/data/valid-cache-key")
|
||||||
self.assertEqual(rv.status_code, 401)
|
self.assertEqual(rv.status_code, 403)
|
||||||
|
|
||||||
def test_explore_json_data_invalid_cache_key(self):
|
def test_explore_json_data_invalid_cache_key(self):
|
||||||
self.login(ADMIN_USERNAME)
|
self.login(ADMIN_USERNAME)
|
||||||
|
|
|
@ -97,5 +97,6 @@ def create_dashboard(
|
||||||
if slices is not None:
|
if slices is not None:
|
||||||
dash.slices = slices
|
dash.slices = slices
|
||||||
db.session.add(dash)
|
db.session.add(dash)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
return dash
|
return dash
|
||||||
|
|
|
@ -592,7 +592,6 @@ class TestImportDashboardsCommand(SupersetTestCase):
|
||||||
}
|
}
|
||||||
command = v1.ImportDashboardsCommand(contents, overwrite=True)
|
command = v1.ImportDashboardsCommand(contents, overwrite=True)
|
||||||
command.run()
|
command.run()
|
||||||
command.run()
|
|
||||||
|
|
||||||
new_num_dashboards = db.session.query(Dashboard).count()
|
new_num_dashboards = db.session.query(Dashboard).count()
|
||||||
assert new_num_dashboards == num_dashboards + 1
|
assert new_num_dashboards == num_dashboards + 1
|
||||||
|
|
|
@ -48,15 +48,16 @@ class TestDashboardDAO(SupersetTestCase):
|
||||||
assert changed_on == DashboardDAO.get_dashboard_changed_on("world_health")
|
assert changed_on == DashboardDAO.get_dashboard_changed_on("world_health")
|
||||||
|
|
||||||
old_changed_on = dashboard.changed_on
|
old_changed_on = dashboard.changed_on
|
||||||
|
|
||||||
# freezegun doesn't work for some reason, so we need to sleep here :(
|
# freezegun doesn't work for some reason, so we need to sleep here :(
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
data = dashboard.data
|
data = dashboard.data
|
||||||
positions = data["position_json"]
|
positions = data["position_json"]
|
||||||
data.update({"positions": positions})
|
data.update({"positions": positions})
|
||||||
original_data = copy.deepcopy(data)
|
original_data = copy.deepcopy(data)
|
||||||
|
|
||||||
data.update({"foo": "bar"})
|
data.update({"foo": "bar"})
|
||||||
DashboardDAO.set_dash_metadata(dashboard, data)
|
DashboardDAO.set_dash_metadata(dashboard, data)
|
||||||
db.session.flush()
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
new_changed_on = DashboardDAO.get_dashboard_changed_on(dashboard)
|
new_changed_on = DashboardDAO.get_dashboard_changed_on(dashboard)
|
||||||
assert old_changed_on.replace(microsecond=0) < new_changed_on
|
assert old_changed_on.replace(microsecond=0) < new_changed_on
|
||||||
|
@ -68,7 +69,6 @@ class TestDashboardDAO(SupersetTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
DashboardDAO.set_dash_metadata(dashboard, original_data)
|
DashboardDAO.set_dash_metadata(dashboard, original_data)
|
||||||
db.session.flush()
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
|
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
|
||||||
|
|
|
@ -26,6 +26,7 @@ import prison
|
||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
from sqlalchemy import inspect
|
from sqlalchemy import inspect
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
|
|
||||||
|
|
|
@ -44,6 +44,7 @@ class TestEmbeddedDashboardApi(SupersetTestCase):
|
||||||
self.login(ADMIN_USERNAME)
|
self.login(ADMIN_USERNAME)
|
||||||
self.dash = db.session.query(Dashboard).filter_by(slug="births").first()
|
self.dash = db.session.query(Dashboard).filter_by(slug="births").first()
|
||||||
self.embedded = EmbeddedDashboardDAO.upsert(self.dash, [])
|
self.embedded = EmbeddedDashboardDAO.upsert(self.dash, [])
|
||||||
|
db.session.flush()
|
||||||
uri = f"api/v1/{self.resource_name}/{self.embedded.uuid}"
|
uri = f"api/v1/{self.resource_name}/{self.embedded.uuid}"
|
||||||
response = self.client.get(uri)
|
response = self.client.get(uri)
|
||||||
self.assert200(response)
|
self.assert200(response)
|
||||||
|
|
|
@ -34,17 +34,21 @@ class TestEmbeddedDashboardDAO(SupersetTestCase):
|
||||||
dash = db.session.query(Dashboard).filter_by(slug="world_health").first()
|
dash = db.session.query(Dashboard).filter_by(slug="world_health").first()
|
||||||
assert not dash.embedded
|
assert not dash.embedded
|
||||||
EmbeddedDashboardDAO.upsert(dash, ["test.example.com"])
|
EmbeddedDashboardDAO.upsert(dash, ["test.example.com"])
|
||||||
|
db.session.flush()
|
||||||
assert dash.embedded
|
assert dash.embedded
|
||||||
self.assertEqual(dash.embedded[0].allowed_domains, ["test.example.com"])
|
self.assertEqual(dash.embedded[0].allowed_domains, ["test.example.com"])
|
||||||
original_uuid = dash.embedded[0].uuid
|
original_uuid = dash.embedded[0].uuid
|
||||||
self.assertIsNotNone(original_uuid)
|
self.assertIsNotNone(original_uuid)
|
||||||
EmbeddedDashboardDAO.upsert(dash, [])
|
EmbeddedDashboardDAO.upsert(dash, [])
|
||||||
|
db.session.flush()
|
||||||
self.assertEqual(dash.embedded[0].allowed_domains, [])
|
self.assertEqual(dash.embedded[0].allowed_domains, [])
|
||||||
self.assertEqual(dash.embedded[0].uuid, original_uuid)
|
self.assertEqual(dash.embedded[0].uuid, original_uuid)
|
||||||
|
|
||||||
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
|
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
|
||||||
def test_get_by_uuid(self):
|
def test_get_by_uuid(self):
|
||||||
dash = db.session.query(Dashboard).filter_by(slug="world_health").first()
|
dash = db.session.query(Dashboard).filter_by(slug="world_health").first()
|
||||||
uuid = str(EmbeddedDashboardDAO.upsert(dash, ["test.example.com"]).uuid)
|
EmbeddedDashboardDAO.upsert(dash, ["test.example.com"])
|
||||||
|
db.session.flush()
|
||||||
|
uuid = str(dash.embedded[0].uuid)
|
||||||
embedded = EmbeddedDashboardDAO.find_by_id(uuid)
|
embedded = EmbeddedDashboardDAO.find_by_id(uuid)
|
||||||
self.assertIsNotNone(embedded)
|
self.assertIsNotNone(embedded)
|
||||||
|
|
|
@ -44,6 +44,7 @@ if TYPE_CHECKING:
|
||||||
def test_get_embedded_dashboard(client: FlaskClient[Any]): # noqa: F811
|
def test_get_embedded_dashboard(client: FlaskClient[Any]): # noqa: F811
|
||||||
dash = db.session.query(Dashboard).filter_by(slug="births").first()
|
dash = db.session.query(Dashboard).filter_by(slug="births").first()
|
||||||
embedded = EmbeddedDashboardDAO.upsert(dash, [])
|
embedded = EmbeddedDashboardDAO.upsert(dash, [])
|
||||||
|
db.session.flush()
|
||||||
uri = f"embedded/{embedded.uuid}"
|
uri = f"embedded/{embedded.uuid}"
|
||||||
response = client.get(uri)
|
response = client.get(uri)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
@ -57,6 +58,7 @@ def test_get_embedded_dashboard(client: FlaskClient[Any]): # noqa: F811
|
||||||
def test_get_embedded_dashboard_referrer_not_allowed(client: FlaskClient[Any]): # noqa: F811
|
def test_get_embedded_dashboard_referrer_not_allowed(client: FlaskClient[Any]): # noqa: F811
|
||||||
dash = db.session.query(Dashboard).filter_by(slug="births").first()
|
dash = db.session.query(Dashboard).filter_by(slug="births").first()
|
||||||
embedded = EmbeddedDashboardDAO.upsert(dash, ["test.example.com"])
|
embedded = EmbeddedDashboardDAO.upsert(dash, ["test.example.com"])
|
||||||
|
db.session.flush()
|
||||||
uri = f"embedded/{embedded.uuid}"
|
uri = f"embedded/{embedded.uuid}"
|
||||||
response = client.get(uri)
|
response = client.get(uri)
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
|
|
|
@ -46,22 +46,25 @@ def load_birth_names_data(
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def load_birth_names_dashboard_with_slices(load_birth_names_data):
|
def load_birth_names_dashboard_with_slices(load_birth_names_data):
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
_create_dashboards()
|
dash_id_to_delete, slices_ids_to_delete = _create_dashboards()
|
||||||
yield
|
yield
|
||||||
|
_cleanup(dash_id_to_delete, slices_ids_to_delete)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def load_birth_names_dashboard_with_slices_module_scope(load_birth_names_data):
|
def load_birth_names_dashboard_with_slices_module_scope(load_birth_names_data):
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
_create_dashboards()
|
dash_id_to_delete, slices_ids_to_delete = _create_dashboards()
|
||||||
yield
|
yield
|
||||||
|
_cleanup(dash_id_to_delete, slices_ids_to_delete)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="class")
|
@pytest.fixture(scope="class")
|
||||||
def load_birth_names_dashboard_with_slices_class_scope(load_birth_names_data):
|
def load_birth_names_dashboard_with_slices_class_scope(load_birth_names_data):
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
_create_dashboards()
|
dash_id_to_delete, slices_ids_to_delete = _create_dashboards()
|
||||||
yield
|
yield
|
||||||
|
_cleanup(dash_id_to_delete, slices_ids_to_delete)
|
||||||
|
|
||||||
|
|
||||||
def _create_dashboards():
|
def _create_dashboards():
|
||||||
|
@ -74,7 +77,10 @@ def _create_dashboards():
|
||||||
from superset.examples.birth_names import create_dashboard, create_slices
|
from superset.examples.birth_names import create_dashboard, create_slices
|
||||||
|
|
||||||
slices, _ = create_slices(table)
|
slices, _ = create_slices(table)
|
||||||
create_dashboard(slices)
|
dash = create_dashboard(slices)
|
||||||
|
slices_ids_to_delete = [slice.id for slice in slices]
|
||||||
|
dash_id_to_delete = dash.id
|
||||||
|
return dash_id_to_delete, slices_ids_to_delete
|
||||||
|
|
||||||
|
|
||||||
def _create_table(
|
def _create_table(
|
||||||
|
@ -91,4 +97,20 @@ def _create_table(
|
||||||
|
|
||||||
_set_table_metadata(table, database)
|
_set_table_metadata(table, database)
|
||||||
_add_table_metrics(table)
|
_add_table_metrics(table)
|
||||||
|
db.session.commit()
|
||||||
return table
|
return table
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup(dash_id: int, slice_ids: list[int]) -> None:
|
||||||
|
schema = get_example_default_schema()
|
||||||
|
for datasource in db.session.query(SqlaTable).filter_by(
|
||||||
|
table_name="birth_names", schema=schema
|
||||||
|
):
|
||||||
|
for col in datasource.columns + datasource.metrics:
|
||||||
|
db.session.delete(col)
|
||||||
|
|
||||||
|
for dash in db.session.query(Dashboard).filter_by(id=dash_id):
|
||||||
|
db.session.delete(dash)
|
||||||
|
for slc in db.session.query(Slice).filter(Slice.id.in_(slice_ids)):
|
||||||
|
db.session.delete(slc)
|
||||||
|
db.session.commit()
|
||||||
|
|
|
@ -61,6 +61,7 @@ def load_energy_table_with_slice(load_energy_table_data):
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
slices = _create_energy_table()
|
slices = _create_energy_table()
|
||||||
yield slices
|
yield slices
|
||||||
|
_cleanup()
|
||||||
|
|
||||||
|
|
||||||
def _get_dataframe():
|
def _get_dataframe():
|
||||||
|
@ -109,6 +110,24 @@ def _create_and_commit_energy_slice(
|
||||||
return slice
|
return slice
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup() -> None:
|
||||||
|
for slice_data in _get_energy_slices():
|
||||||
|
slice = (
|
||||||
|
db.session.query(Slice)
|
||||||
|
.filter_by(slice_name=slice_data["slice_title"])
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
db.session.delete(slice)
|
||||||
|
|
||||||
|
metric = (
|
||||||
|
db.session.query(SqlMetric).filter_by(metric_name="sum__value").one_or_none()
|
||||||
|
)
|
||||||
|
if metric:
|
||||||
|
db.session.delete(metric)
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
|
||||||
def _get_energy_data():
|
def _get_energy_data():
|
||||||
data = []
|
data = []
|
||||||
for i in range(85):
|
for i in range(85):
|
||||||
|
|
|
@ -135,4 +135,7 @@ def tabbed_dashboard(app_context):
|
||||||
slices=[],
|
slices=[],
|
||||||
)
|
)
|
||||||
db.session.add(dash)
|
db.session.add(dash)
|
||||||
yield
|
db.session.commit()
|
||||||
|
yield dash
|
||||||
|
db.session.query(Dashboard).filter_by(id=dash.id).delete()
|
||||||
|
db.session.commit()
|
||||||
|
|
|
@ -61,6 +61,7 @@ def load_unicode_dashboard_with_slice(load_unicode_data):
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
dash = _create_unicode_dashboard(slice_name, None)
|
dash = _create_unicode_dashboard(slice_name, None)
|
||||||
yield
|
yield
|
||||||
|
_cleanup(dash, slice_name)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
|
@ -70,6 +71,7 @@ def load_unicode_dashboard_with_position(load_unicode_data):
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
dash = _create_unicode_dashboard(slice_name, position)
|
dash = _create_unicode_dashboard(slice_name, position)
|
||||||
yield
|
yield
|
||||||
|
_cleanup(dash, slice_name)
|
||||||
|
|
||||||
|
|
||||||
def _get_dataframe():
|
def _get_dataframe():
|
||||||
|
@ -95,18 +97,25 @@ def _create_unicode_dashboard(slice_title: str, position: str) -> Dashboard:
|
||||||
table.fetch_metadata()
|
table.fetch_metadata()
|
||||||
|
|
||||||
if slice_title:
|
if slice_title:
|
||||||
slice = _create_unicode_slice(table, slice_title)
|
slice = _create_and_commit_unicode_slice(table, slice_title)
|
||||||
|
|
||||||
return create_dashboard("unicode-test", "Unicode Test", position, [slice])
|
return create_dashboard("unicode-test", "Unicode Test", position, [slice])
|
||||||
|
|
||||||
|
|
||||||
def _create_unicode_slice(table: SqlaTable, title: str):
|
def _create_and_commit_unicode_slice(table: SqlaTable, title: str):
|
||||||
slc = create_slice(title, "word_cloud", table, {})
|
slice = create_slice(title, "word_cloud", table, {})
|
||||||
if (
|
o = db.session.query(Slice).filter_by(slice_name=slice.slice_name).one_or_none()
|
||||||
obj := db.session.query(Slice)
|
if o:
|
||||||
.filter_by(slice_name=slc.slice_name)
|
db.session.delete(o)
|
||||||
.one_or_none()
|
db.session.add(slice)
|
||||||
|
db.session.commit()
|
||||||
|
return slice
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup(dash: Dashboard, slice_name: str) -> None:
|
||||||
|
db.session.delete(dash)
|
||||||
|
if slice_name and (
|
||||||
|
slice := db.session.query(Slice).filter_by(slice_name=slice_name).one_or_none()
|
||||||
):
|
):
|
||||||
db.session.delete(obj)
|
db.session.delete(slice)
|
||||||
db.session.add(slc)
|
db.session.commit()
|
||||||
return slc
|
|
||||||
|
|
|
@ -64,7 +64,6 @@ def load_world_bank_data():
|
||||||
)
|
)
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
with get_example_database().get_sqla_engine() as engine:
|
with get_example_database().get_sqla_engine() as engine:
|
||||||
engine.execute("DROP TABLE IF EXISTS wb_health_population")
|
engine.execute("DROP TABLE IF EXISTS wb_health_population")
|
||||||
|
@ -73,14 +72,15 @@ def load_world_bank_data():
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def load_world_bank_dashboard_with_slices(load_world_bank_data):
|
def load_world_bank_dashboard_with_slices(load_world_bank_data):
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
create_dashboard_for_loaded_data()
|
dash_id_to_delete, slices_ids_to_delete = create_dashboard_for_loaded_data()
|
||||||
yield
|
yield
|
||||||
|
_cleanup(dash_id_to_delete, slices_ids_to_delete)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def load_world_bank_dashboard_with_slices_module_scope(load_world_bank_data):
|
def load_world_bank_dashboard_with_slices_module_scope(load_world_bank_data):
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
create_dashboard_for_loaded_data()
|
dash_id_to_delete, slices_ids_to_delete = create_dashboard_for_loaded_data()
|
||||||
yield
|
yield
|
||||||
_cleanup_reports(dash_id_to_delete, slices_ids_to_delete)
|
_cleanup_reports(dash_id_to_delete, slices_ids_to_delete)
|
||||||
_cleanup(dash_id_to_delete, slices_ids_to_delete)
|
_cleanup(dash_id_to_delete, slices_ids_to_delete)
|
||||||
|
@ -89,8 +89,9 @@ def load_world_bank_dashboard_with_slices_module_scope(load_world_bank_data):
|
||||||
@pytest.fixture(scope="class")
|
@pytest.fixture(scope="class")
|
||||||
def load_world_bank_dashboard_with_slices_class_scope(load_world_bank_data):
|
def load_world_bank_dashboard_with_slices_class_scope(load_world_bank_data):
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
create_dashboard_for_loaded_data()
|
dash_id_to_delete, slices_ids_to_delete = create_dashboard_for_loaded_data()
|
||||||
yield
|
yield
|
||||||
|
_cleanup(dash_id_to_delete, slices_ids_to_delete)
|
||||||
|
|
||||||
|
|
||||||
def create_dashboard_for_loaded_data():
|
def create_dashboard_for_loaded_data():
|
||||||
|
@ -102,22 +103,24 @@ def create_dashboard_for_loaded_data():
|
||||||
return dash_id_to_delete, slices_ids_to_delete
|
return dash_id_to_delete, slices_ids_to_delete
|
||||||
|
|
||||||
|
|
||||||
def _create_world_bank_slices(table: SqlaTable) -> None:
|
def _create_world_bank_slices(table: SqlaTable) -> list[Slice]:
|
||||||
from superset.examples.world_bank import create_slices
|
from superset.examples.world_bank import create_slices
|
||||||
|
|
||||||
slices = create_slices(table)
|
slices = create_slices(table)
|
||||||
|
_commit_slices(slices)
|
||||||
for slc in slices:
|
return slices
|
||||||
if (
|
|
||||||
obj := db.session.query(Slice)
|
|
||||||
.filter_by(slice_name=slc.slice_name)
|
|
||||||
.one_or_none()
|
|
||||||
):
|
|
||||||
db.session.delete(obj)
|
|
||||||
db.session.add(slc)
|
|
||||||
|
|
||||||
|
|
||||||
def _create_world_bank_dashboard(table: SqlaTable) -> None:
|
def _commit_slices(slices: list[Slice]):
|
||||||
|
for slice in slices:
|
||||||
|
o = db.session.query(Slice).filter_by(slice_name=slice.slice_name).one_or_none()
|
||||||
|
if o:
|
||||||
|
db.session.delete(o)
|
||||||
|
db.session.add(slice)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def _create_world_bank_dashboard(table: SqlaTable) -> Dashboard:
|
||||||
from superset.examples.helpers import update_slice_ids
|
from superset.examples.helpers import update_slice_ids
|
||||||
from superset.examples.world_bank import dashboard_positions
|
from superset.examples.world_bank import dashboard_positions
|
||||||
|
|
||||||
|
@ -130,6 +133,16 @@ def _create_world_bank_dashboard(table: SqlaTable) -> None:
|
||||||
"world_health", "World Bank's Data", json.dumps(pos), slices
|
"world_health", "World Bank's Data", json.dumps(pos), slices
|
||||||
)
|
)
|
||||||
dash.json_metadata = '{"mock_key": "mock_value"}'
|
dash.json_metadata = '{"mock_key": "mock_value"}'
|
||||||
|
db.session.commit()
|
||||||
|
return dash
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup(dash_id: int, slices_ids: list[int]) -> None:
|
||||||
|
dash = db.session.query(Dashboard).filter_by(id=dash_id).first()
|
||||||
|
db.session.delete(dash)
|
||||||
|
for slice_id in slices_ids:
|
||||||
|
db.session.query(Slice).filter_by(id=slice_id).delete()
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
|
||||||
def _cleanup_reports(dash_id: int, slices_ids: list[int]) -> None:
|
def _cleanup_reports(dash_id: int, slices_ids: list[int]) -> None:
|
||||||
|
|
|
@ -215,8 +215,6 @@ class TestRowLevelSecurity(SupersetTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.assertEqual(rv.status_code, 422)
|
self.assertEqual(rv.status_code, 422)
|
||||||
data = json.loads(rv.data.decode("utf-8"))
|
|
||||||
assert "Create failed" in data["message"]
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("create_dataset")
|
@pytest.mark.usefixtures("create_dataset")
|
||||||
def test_model_view_rls_add_tables_required(self):
|
def test_model_view_rls_add_tables_required(self):
|
||||||
|
|
|
@ -71,10 +71,8 @@ QUERY_3 = "SELECT * FROM birth_names LIMIT 10"
|
||||||
class TestSqlLab(SupersetTestCase):
|
class TestSqlLab(SupersetTestCase):
|
||||||
"""Testings for Sql Lab"""
|
"""Testings for Sql Lab"""
|
||||||
|
|
||||||
@pytest.mark.usefixtures("load_birth_names_data")
|
|
||||||
def run_some_queries(self):
|
def run_some_queries(self):
|
||||||
db.session.query(Query).delete()
|
db.session.query(Query).delete()
|
||||||
db.session.commit()
|
|
||||||
self.run_sql(QUERY_1, client_id="client_id_1", username="admin")
|
self.run_sql(QUERY_1, client_id="client_id_1", username="admin")
|
||||||
self.run_sql(QUERY_2, client_id="client_id_2", username="admin")
|
self.run_sql(QUERY_2, client_id="client_id_2", username="admin")
|
||||||
self.run_sql(QUERY_3, client_id="client_id_3", username="gamma_sqllab")
|
self.run_sql(QUERY_3, client_id="client_id_3", username="gamma_sqllab")
|
||||||
|
@ -419,6 +417,7 @@ class TestSqlLab(SupersetTestCase):
|
||||||
self.assertEqual(len(data["data"]), 1200)
|
self.assertEqual(len(data["data"]), 1200)
|
||||||
self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.NOT_LIMITED)
|
self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.NOT_LIMITED)
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("load_birth_names_data")
|
||||||
def test_query_api_filter(self) -> None:
|
def test_query_api_filter(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test query api without can_only_access_owned_queries perm added to
|
Test query api without can_only_access_owned_queries perm added to
|
||||||
|
@ -438,6 +437,7 @@ class TestSqlLab(SupersetTestCase):
|
||||||
assert admin.first_name in user_queries
|
assert admin.first_name in user_queries
|
||||||
assert gamma_sqllab.first_name in user_queries
|
assert gamma_sqllab.first_name in user_queries
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("load_birth_names_data")
|
||||||
def test_query_api_can_access_all_queries(self) -> None:
|
def test_query_api_can_access_all_queries(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test query api with can_access_all_queries perm added to
|
Test query api with can_access_all_queries perm added to
|
||||||
|
@ -522,6 +522,7 @@ class TestSqlLab(SupersetTestCase):
|
||||||
{r.get("sql") for r in self.get_json_resp(url)["result"]},
|
{r.get("sql") for r in self.get_json_resp(url)["result"]},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("load_birth_names_data")
|
||||||
def test_query_admin_can_access_all_queries(self) -> None:
|
def test_query_admin_can_access_all_queries(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test query api with all_query_access perm added to
|
Test query api with all_query_access perm added to
|
||||||
|
|
|
@ -187,6 +187,7 @@ class TestTagsDAO(SupersetTestCase):
|
||||||
TaggedObject.object_type == ObjectType.chart,
|
TaggedObject.object_type == ObjectType.chart,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
.join(Tag, TaggedObject.tag_id == Tag.id)
|
||||||
.distinct(Slice.id)
|
.distinct(Slice.id)
|
||||||
.count()
|
.count()
|
||||||
)
|
)
|
||||||
|
@ -199,6 +200,7 @@ class TestTagsDAO(SupersetTestCase):
|
||||||
TaggedObject.object_type == ObjectType.dashboard,
|
TaggedObject.object_type == ObjectType.dashboard,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
.join(Tag, TaggedObject.tag_id == Tag.id)
|
||||||
.distinct(Dashboard.id)
|
.distinct(Dashboard.id)
|
||||||
.count()
|
.count()
|
||||||
+ num_charts
|
+ num_charts
|
||||||
|
|
|
@ -29,7 +29,6 @@ def database_with_catalog(mocker: MockerFixture) -> MagicMock:
|
||||||
"""
|
"""
|
||||||
Mock a database with catalogs and schemas.
|
Mock a database with catalogs and schemas.
|
||||||
"""
|
"""
|
||||||
mocker.patch("superset.commands.database.create.db")
|
|
||||||
mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand")
|
mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand")
|
||||||
|
|
||||||
database = mocker.MagicMock()
|
database = mocker.MagicMock()
|
||||||
|
@ -53,7 +52,6 @@ def database_without_catalog(mocker: MockerFixture) -> MagicMock:
|
||||||
"""
|
"""
|
||||||
Mock a database without catalogs.
|
Mock a database without catalogs.
|
||||||
"""
|
"""
|
||||||
mocker.patch("superset.commands.database.create.db")
|
|
||||||
mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand")
|
mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand")
|
||||||
|
|
||||||
database = mocker.MagicMock()
|
database = mocker.MagicMock()
|
||||||
|
|
|
@ -29,8 +29,6 @@ def database_with_catalog(mocker: MockerFixture) -> MagicMock:
|
||||||
"""
|
"""
|
||||||
Mock a database with catalogs and schemas.
|
Mock a database with catalogs and schemas.
|
||||||
"""
|
"""
|
||||||
mocker.patch("superset.commands.database.update.db")
|
|
||||||
|
|
||||||
database = mocker.MagicMock()
|
database = mocker.MagicMock()
|
||||||
database.database_name = "my_db"
|
database.database_name = "my_db"
|
||||||
database.db_engine_spec.__name__ = "test_engine"
|
database.db_engine_spec.__name__ = "test_engine"
|
||||||
|
@ -50,8 +48,6 @@ def database_without_catalog(mocker: MockerFixture) -> MagicMock:
|
||||||
"""
|
"""
|
||||||
Mock a database without catalogs.
|
Mock a database without catalogs.
|
||||||
"""
|
"""
|
||||||
mocker.patch("superset.commands.database.update.db")
|
|
||||||
|
|
||||||
database = mocker.MagicMock()
|
database = mocker.MagicMock()
|
||||||
database.database_name = "my_db"
|
database.database_name = "my_db"
|
||||||
database.db_engine_spec.__name__ = "test_engine"
|
database.db_engine_spec.__name__ = "test_engine"
|
||||||
|
|
|
@ -22,7 +22,6 @@ def test_user_favorite_tag(mocker):
|
||||||
from superset.daos.tag import TagDAO
|
from superset.daos.tag import TagDAO
|
||||||
|
|
||||||
# Mock the behavior of TagDAO and g
|
# Mock the behavior of TagDAO and g
|
||||||
mock_session = mocker.patch("superset.daos.tag.db.session")
|
|
||||||
mock_TagDAO = mocker.patch(
|
mock_TagDAO = mocker.patch(
|
||||||
"superset.daos.tag.TagDAO"
|
"superset.daos.tag.TagDAO"
|
||||||
) # Replace with the actual path to TagDAO
|
) # Replace with the actual path to TagDAO
|
||||||
|
@ -45,7 +44,6 @@ def test_remove_user_favorite_tag(mocker):
|
||||||
from superset.daos.tag import TagDAO
|
from superset.daos.tag import TagDAO
|
||||||
|
|
||||||
# Mock the behavior of TagDAO and g
|
# Mock the behavior of TagDAO and g
|
||||||
mock_session = mocker.patch("superset.daos.tag.db.session")
|
|
||||||
mock_TagDAO = mocker.patch("superset.daos.tag.TagDAO")
|
mock_TagDAO = mocker.patch("superset.daos.tag.TagDAO")
|
||||||
mock_tag = mocker.MagicMock(users_favorited=[])
|
mock_tag = mocker.MagicMock(users_favorited=[])
|
||||||
mock_TagDAO.find_by_id.return_value = mock_tag
|
mock_TagDAO.find_by_id.return_value = mock_tag
|
||||||
|
|
|
@ -116,7 +116,7 @@ def test_post_with_uuid(
|
||||||
assert payload["result"]["uuid"] == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"
|
assert payload["result"]["uuid"] == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"
|
||||||
|
|
||||||
database = session.query(Database).one()
|
database = session.query(Database).one()
|
||||||
assert database.uuid == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"
|
assert database.uuid == UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb")
|
||||||
|
|
||||||
|
|
||||||
def test_password_mask(
|
def test_password_mask(
|
||||||
|
|
|
@ -22,8 +22,8 @@ from uuid import UUID
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
|
||||||
|
|
||||||
|
from superset import db
|
||||||
from superset.exceptions import CreateKeyValueDistributedLockFailedException
|
from superset.exceptions import CreateKeyValueDistributedLockFailedException
|
||||||
from superset.key_value.types import JsonKeyValueCodec
|
from superset.key_value.types import JsonKeyValueCodec
|
||||||
from superset.utils.lock import get_key, KeyValueDistributedLock
|
from superset.utils.lock import get_key, KeyValueDistributedLock
|
||||||
|
@ -32,56 +32,51 @@ MAIN_KEY = get_key("ns", a=1, b=2)
|
||||||
OTHER_KEY = get_key("ns2", a=1, b=2)
|
OTHER_KEY = get_key("ns2", a=1, b=2)
|
||||||
|
|
||||||
|
|
||||||
def _get_lock(key: UUID, session: Session) -> Any:
|
def _get_lock(key: UUID) -> Any:
|
||||||
from superset.key_value.models import KeyValueEntry
|
from superset.key_value.models import KeyValueEntry
|
||||||
|
|
||||||
entry = session.query(KeyValueEntry).filter_by(uuid=key).first()
|
entry = db.session.query(KeyValueEntry).filter_by(uuid=key).first()
|
||||||
if entry is None or entry.is_expired():
|
if entry is None or entry.is_expired():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return JsonKeyValueCodec().decode(entry.value)
|
return JsonKeyValueCodec().decode(entry.value)
|
||||||
|
|
||||||
|
|
||||||
def _get_other_session() -> Session:
|
|
||||||
# This session is used to simulate what another worker will find in the metastore
|
|
||||||
# during the locking process.
|
|
||||||
from superset import db
|
|
||||||
|
|
||||||
bind = db.session.get_bind()
|
|
||||||
SessionMaker = sessionmaker(bind=bind)
|
|
||||||
return SessionMaker()
|
|
||||||
|
|
||||||
|
|
||||||
def test_key_value_distributed_lock_happy_path() -> None:
|
def test_key_value_distributed_lock_happy_path() -> None:
|
||||||
"""
|
"""
|
||||||
Test successfully acquiring and returning the distributed lock.
|
Test successfully acquiring and returning the distributed lock.
|
||||||
|
|
||||||
|
Note we use a nested transaction to ensure that the cleanup from the outer context
|
||||||
|
manager is correctly invoked, otherwise a partial rollback would occur leaving the
|
||||||
|
database in a fractured state.
|
||||||
"""
|
"""
|
||||||
session = _get_other_session()
|
|
||||||
|
|
||||||
with freeze_time("2021-01-01"):
|
with freeze_time("2021-01-01"):
|
||||||
assert _get_lock(MAIN_KEY, session) is None
|
assert _get_lock(MAIN_KEY) is None
|
||||||
|
|
||||||
with KeyValueDistributedLock("ns", a=1, b=2) as key:
|
with KeyValueDistributedLock("ns", a=1, b=2) as key:
|
||||||
assert key == MAIN_KEY
|
assert key == MAIN_KEY
|
||||||
assert _get_lock(key, session) is True
|
assert _get_lock(key) is True
|
||||||
assert _get_lock(OTHER_KEY, session) is None
|
assert _get_lock(OTHER_KEY) is None
|
||||||
with pytest.raises(CreateKeyValueDistributedLockFailedException):
|
|
||||||
with KeyValueDistributedLock("ns", a=1, b=2):
|
|
||||||
pass
|
|
||||||
|
|
||||||
assert _get_lock(MAIN_KEY, session) is None
|
with db.session.begin_nested():
|
||||||
|
with pytest.raises(CreateKeyValueDistributedLockFailedException):
|
||||||
|
with KeyValueDistributedLock("ns", a=1, b=2):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert _get_lock(MAIN_KEY) is None
|
||||||
|
|
||||||
|
|
||||||
def test_key_value_distributed_lock_expired() -> None:
|
def test_key_value_distributed_lock_expired() -> None:
|
||||||
"""
|
"""
|
||||||
Test expiration of the distributed lock
|
Test expiration of the distributed lock
|
||||||
"""
|
"""
|
||||||
session = _get_other_session()
|
|
||||||
|
|
||||||
with freeze_time("2021-01-01T"):
|
with freeze_time("2021-01-01"):
|
||||||
assert _get_lock(MAIN_KEY, session) is None
|
assert _get_lock(MAIN_KEY) is None
|
||||||
with KeyValueDistributedLock("ns", a=1, b=2):
|
with KeyValueDistributedLock("ns", a=1, b=2):
|
||||||
assert _get_lock(MAIN_KEY, session) is True
|
assert _get_lock(MAIN_KEY) is True
|
||||||
with freeze_time("2022-01-01T"):
|
with freeze_time("2022-01-01"):
|
||||||
assert _get_lock(MAIN_KEY, session) is None
|
assert _get_lock(MAIN_KEY) is None
|
||||||
|
|
||||||
assert _get_lock(MAIN_KEY, session) is None
|
assert _get_lock(MAIN_KEY) is None
|
||||||
|
|
Loading…
Reference in New Issue