chore(tests): Remove unnecessary/problematic app contexts (#28159)

This commit is contained in:
John Bodley 2024-04-24 13:46:35 -07:00 committed by GitHub
parent a9075fdb1f
commit bc65c245fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 880 additions and 980 deletions

View File

@ -81,36 +81,6 @@ DB_ACCESS_ROLE = "db_access_role"
SCHEMA_ACCESS_ROLE = "schema_access_role" SCHEMA_ACCESS_ROLE = "schema_access_role"
class TestRequestAccess(SupersetTestCase):
@classmethod
def setUpClass(cls):
with app.app_context():
security_manager.add_role("override_me")
security_manager.add_role(TEST_ROLE_1)
security_manager.add_role(TEST_ROLE_2)
security_manager.add_role(DB_ACCESS_ROLE)
security_manager.add_role(SCHEMA_ACCESS_ROLE)
db.session.commit()
@classmethod
def tearDownClass(cls):
with app.app_context():
override_me = security_manager.find_role("override_me")
db.session.delete(override_me)
db.session.delete(security_manager.find_role(TEST_ROLE_1))
db.session.delete(security_manager.find_role(TEST_ROLE_2))
db.session.delete(security_manager.find_role(DB_ACCESS_ROLE))
db.session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE))
db.session.commit()
def tearDown(self):
override_me = security_manager.find_role("override_me")
override_me.permissions = []
db.session.commit()
db.session.close()
super().tearDown()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"username,user_id", "username,user_id",
[ [

View File

@ -14,18 +14,16 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# isort:skip_file
import pytest
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
import pytest
from flask.ctx import AppContext
from superset import db from superset import db
from superset.models.annotations import Annotation, AnnotationLayer from superset.models.annotations import Annotation, AnnotationLayer
from tests.integration_tests.test_app import app from tests.integration_tests.test_app import app
ANNOTATION_LAYERS_COUNT = 10 ANNOTATION_LAYERS_COUNT = 10
ANNOTATIONS_COUNT = 5 ANNOTATIONS_COUNT = 5
@ -70,36 +68,35 @@ def _insert_annotation(
@pytest.fixture() @pytest.fixture()
def create_annotation_layers(): def create_annotation_layers(app_context: AppContext):
""" """
Creates ANNOTATION_LAYERS_COUNT-1 layers with no annotations Creates ANNOTATION_LAYERS_COUNT-1 layers with no annotations
and a final one with ANNOTATION_COUNT children and a final one with ANNOTATION_COUNT children
:return: :return:
""" """
with app.app_context(): annotation_layers = []
annotation_layers = [] annotations = []
annotations = [] for cx in range(ANNOTATION_LAYERS_COUNT - 1):
for cx in range(ANNOTATION_LAYERS_COUNT - 1): annotation_layers.append(
annotation_layers.append( _insert_annotation_layer(name=f"name{cx}", descr=f"descr{cx}")
_insert_annotation_layer(name=f"name{cx}", descr=f"descr{cx}") )
layer_with_annotations = _insert_annotation_layer("layer_with_annotations")
annotation_layers.append(layer_with_annotations)
for cx in range(ANNOTATIONS_COUNT):
annotations.append(
_insert_annotation(
layer_with_annotations,
short_descr=f"short_descr{cx}",
long_descr=f"long_descr{cx}",
start_dttm=get_start_dttm(cx),
end_dttm=get_end_dttm(cx),
) )
layer_with_annotations = _insert_annotation_layer("layer_with_annotations") )
annotation_layers.append(layer_with_annotations) yield annotation_layers
for cx in range(ANNOTATIONS_COUNT):
annotations.append(
_insert_annotation(
layer_with_annotations,
short_descr=f"short_descr{cx}",
long_descr=f"long_descr{cx}",
start_dttm=get_start_dttm(cx),
end_dttm=get_end_dttm(cx),
)
)
yield annotation_layers
# rollback changes # rollback changes
for annotation_layer in annotation_layers: for annotation_layer in annotation_layers:
db.session.delete(annotation_layer) db.session.delete(annotation_layer)
for annotation in annotations: for annotation in annotations:
db.session.delete(annotation) db.session.delete(annotation)
db.session.commit() db.session.commit()

View File

@ -23,6 +23,7 @@ from zipfile import is_zipfile, ZipFile
import prison import prison
import pytest import pytest
import yaml import yaml
from flask.ctx import AppContext
from flask_babel import lazy_gettext as _ from flask_babel import lazy_gettext as _
from parameterized import parameterized from parameterized import parameterized
from sqlalchemy import and_ from sqlalchemy import and_
@ -82,121 +83,115 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
resource_name = "chart" resource_name = "chart"
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def clear_data_cache(self): def clear_data_cache(self, app_context: AppContext):
with app.app_context(): cache_manager.data_cache.clear()
cache_manager.data_cache.clear() yield
yield
@pytest.fixture() @pytest.fixture()
def create_charts(self): def create_charts(self):
with self.create_app().app_context(): charts = []
charts = [] admin = self.get_user("admin")
admin = self.get_user("admin") for cx in range(CHARTS_FIXTURE_COUNT - 1):
for cx in range(CHARTS_FIXTURE_COUNT - 1): charts.append(self.insert_chart(f"name{cx}", [admin.id], 1))
charts.append(self.insert_chart(f"name{cx}", [admin.id], 1)) fav_charts = []
fav_charts = [] for cx in range(round(CHARTS_FIXTURE_COUNT / 2)):
for cx in range(round(CHARTS_FIXTURE_COUNT / 2)): fav_star = FavStar(
fav_star = FavStar( user_id=admin.id, class_name="slice", obj_id=charts[cx].id
user_id=admin.id, class_name="slice", obj_id=charts[cx].id )
) db.session.add(fav_star)
db.session.add(fav_star)
db.session.commit()
fav_charts.append(fav_star)
yield charts
# rollback changes
for chart in charts:
db.session.delete(chart)
for fav_chart in fav_charts:
db.session.delete(fav_chart)
db.session.commit() db.session.commit()
fav_charts.append(fav_star)
yield charts
# rollback changes
for chart in charts:
db.session.delete(chart)
for fav_chart in fav_charts:
db.session.delete(fav_chart)
db.session.commit()
@pytest.fixture() @pytest.fixture()
def create_charts_created_by_gamma(self): def create_charts_created_by_gamma(self):
with self.create_app().app_context(): charts = []
charts = [] user = self.get_user("gamma")
user = self.get_user("gamma") for cx in range(CHARTS_FIXTURE_COUNT - 1):
for cx in range(CHARTS_FIXTURE_COUNT - 1): charts.append(self.insert_chart(f"gamma{cx}", [user.id], 1))
charts.append(self.insert_chart(f"gamma{cx}", [user.id], 1)) yield charts
yield charts # rollback changes
# rollback changes for chart in charts:
for chart in charts: db.session.delete(chart)
db.session.delete(chart) db.session.commit()
db.session.commit()
@pytest.fixture() @pytest.fixture()
def create_certified_charts(self): def create_certified_charts(self):
with self.create_app().app_context(): certified_charts = []
certified_charts = [] admin = self.get_user("admin")
admin = self.get_user("admin") for cx in range(CHARTS_FIXTURE_COUNT):
for cx in range(CHARTS_FIXTURE_COUNT): certified_charts.append(
certified_charts.append( self.insert_chart(
self.insert_chart( f"certified{cx}",
f"certified{cx}", [admin.id],
[admin.id], 1,
1, certified_by="John Doe",
certified_by="John Doe", certification_details="Sample certification",
certification_details="Sample certification",
)
) )
)
yield certified_charts yield certified_charts
# rollback changes # rollback changes
for chart in certified_charts: for chart in certified_charts:
db.session.delete(chart) db.session.delete(chart)
db.session.commit() db.session.commit()
@pytest.fixture() @pytest.fixture()
def create_chart_with_report(self): def create_chart_with_report(self):
with self.create_app().app_context(): admin = self.get_user("admin")
admin = self.get_user("admin") chart = self.insert_chart(f"chart_report", [admin.id], 1)
chart = self.insert_chart(f"chart_report", [admin.id], 1) report_schedule = ReportSchedule(
report_schedule = ReportSchedule( type=ReportScheduleType.REPORT,
type=ReportScheduleType.REPORT, name="report_with_chart",
name="report_with_chart", crontab="* * * * *",
crontab="* * * * *", chart=chart,
chart=chart, )
) db.session.commit()
db.session.commit()
yield chart yield chart
# rollback changes # rollback changes
db.session.delete(report_schedule) db.session.delete(report_schedule)
db.session.delete(chart) db.session.delete(chart)
db.session.commit() db.session.commit()
@pytest.fixture() @pytest.fixture()
def add_dashboard_to_chart(self): def add_dashboard_to_chart(self):
with self.create_app().app_context(): admin = self.get_user("admin")
admin = self.get_user("admin")
self.chart = self.insert_chart("My chart", [admin.id], 1) self.chart = self.insert_chart("My chart", [admin.id], 1)
self.original_dashboard = Dashboard() self.original_dashboard = Dashboard()
self.original_dashboard.dashboard_title = "Original Dashboard" self.original_dashboard.dashboard_title = "Original Dashboard"
self.original_dashboard.slug = "slug" self.original_dashboard.slug = "slug"
self.original_dashboard.owners = [admin] self.original_dashboard.owners = [admin]
self.original_dashboard.slices = [self.chart] self.original_dashboard.slices = [self.chart]
self.original_dashboard.published = False self.original_dashboard.published = False
db.session.add(self.original_dashboard) db.session.add(self.original_dashboard)
self.new_dashboard = Dashboard() self.new_dashboard = Dashboard()
self.new_dashboard.dashboard_title = "New Dashboard" self.new_dashboard.dashboard_title = "New Dashboard"
self.new_dashboard.slug = "new_slug" self.new_dashboard.slug = "new_slug"
self.new_dashboard.owners = [admin] self.new_dashboard.owners = [admin]
self.new_dashboard.published = False self.new_dashboard.published = False
db.session.add(self.new_dashboard) db.session.add(self.new_dashboard)
db.session.commit() db.session.commit()
yield self.chart yield self.chart
db.session.delete(self.original_dashboard) db.session.delete(self.original_dashboard)
db.session.delete(self.new_dashboard) db.session.delete(self.new_dashboard)
db.session.delete(self.chart) db.session.delete(self.chart)
db.session.commit() db.session.commit()
def test_info_security_chart(self): def test_info_security_chart(self):
""" """
@ -1127,40 +1122,39 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
@pytest.fixture() @pytest.fixture()
def load_energy_charts(self): def load_energy_charts(self):
with app.app_context(): admin = self.get_user("admin")
admin = self.get_user("admin") energy_table = (
energy_table = ( db.session.query(SqlaTable)
db.session.query(SqlaTable) .filter_by(table_name="energy_usage")
.filter_by(table_name="energy_usage") .one_or_none()
.one_or_none() )
) energy_table_id = 1
energy_table_id = 1 if energy_table:
if energy_table: energy_table_id = energy_table.id
energy_table_id = energy_table.id chart1 = self.insert_chart(
chart1 = self.insert_chart( "foo_a", [admin.id], energy_table_id, description="ZY_bar"
"foo_a", [admin.id], energy_table_id, description="ZY_bar" )
) chart2 = self.insert_chart(
chart2 = self.insert_chart( "zy_foo", [admin.id], energy_table_id, description="desc1"
"zy_foo", [admin.id], energy_table_id, description="desc1" )
) chart3 = self.insert_chart(
chart3 = self.insert_chart( "foo_b", [admin.id], energy_table_id, description="desc1zy_"
"foo_b", [admin.id], energy_table_id, description="desc1zy_" )
) chart4 = self.insert_chart(
chart4 = self.insert_chart( "foo_c", [admin.id], energy_table_id, viz_type="viz_zy_"
"foo_c", [admin.id], energy_table_id, viz_type="viz_zy_" )
) chart5 = self.insert_chart(
chart5 = self.insert_chart( "bar", [admin.id], energy_table_id, description="foo"
"bar", [admin.id], energy_table_id, description="foo" )
)
yield yield
# rollback changes # rollback changes
db.session.delete(chart1) db.session.delete(chart1)
db.session.delete(chart2) db.session.delete(chart2)
db.session.delete(chart3) db.session.delete(chart3)
db.session.delete(chart4) db.session.delete(chart4)
db.session.delete(chart5) db.session.delete(chart5)
db.session.commit() db.session.commit()
@pytest.mark.usefixtures("load_energy_charts") @pytest.mark.usefixtures("load_energy_charts")
def test_get_charts_custom_filter(self): def test_get_charts_custom_filter(self):

View File

@ -27,6 +27,7 @@ from unittest import mock
from zipfile import ZipFile from zipfile import ZipFile
from flask import Response from flask import Response
from flask.ctx import AppContext
from tests.integration_tests.conftest import with_feature_flags from tests.integration_tests.conftest import with_feature_flags
from superset.charts.data.api import ChartDataRestApi from superset.charts.data.api import ChartDataRestApi
from superset.models.sql_lab import Query from superset.models.sql_lab import Query
@ -88,10 +89,9 @@ INCOMPATIBLE_ADHOC_COLUMN_FIXTURE: AdhocColumn = {
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def skip_by_backend(): def skip_by_backend(app_context: AppContext):
with app.app_context(): if backend() == "hive":
if backend() == "hive": pytest.skip("Skipping tests for Hive backend")
pytest.skip("Skipping tests for Hive backend")
class BaseTestChartDataApi(SupersetTestCase): class BaseTestChartDataApi(SupersetTestCase):

View File

@ -118,8 +118,8 @@ def get_or_create_user(get_user, create_user) -> ab_models.User:
@pytest.fixture(autouse=True, scope="session") @pytest.fixture(autouse=True, scope="session")
def setup_sample_data() -> Any: def setup_sample_data() -> Any:
# TODO(john-bodley): Determine a cleaner way of setting up the sample data without # TODO(john-bodley): Determine a cleaner way of setting up the sample data without
# relying on `tests.integration_tests.test_app.app` leveraging an `app` fixture which is purposely # relying on `tests.integration_tests.test_app.app` leveraging an `app` fixture
# scoped to the function level to ensure tests remain idempotent. # which is purposely scoped to the function level to ensure tests remain idempotent.
with app.app_context(): with app.app_context():
setup_presto_if_needed() setup_presto_if_needed()
@ -135,7 +135,6 @@ def setup_sample_data() -> Any:
with app.app_context(): with app.app_context():
# drop sqlalchemy tables # drop sqlalchemy tables
db.session.commit() db.session.commit()
from sqlalchemy.ext import declarative from sqlalchemy.ext import declarative
@ -163,12 +162,12 @@ def example_db_provider() -> Callable[[], Database]: # type: ignore
_db: Database | None = None _db: Database | None = None
def __call__(self) -> Database: def __call__(self) -> Database:
with app.app_context(): if self._db is None:
if self._db is None: with app.app_context():
self._db = get_example_database() self._db = get_example_database()
self._load_lazy_data_to_decouple_from_session() self._load_lazy_data_to_decouple_from_session()
return self._db return self._db
def _load_lazy_data_to_decouple_from_session(self) -> None: def _load_lazy_data_to_decouple_from_session(self) -> None:
self._db._get_sqla_engine() # type: ignore self._db._get_sqla_engine() # type: ignore

View File

@ -58,39 +58,36 @@ from .base_tests import SupersetTestCase
class TestDashboard(SupersetTestCase): class TestDashboard(SupersetTestCase):
@pytest.fixture @pytest.fixture
def load_dashboard(self): def load_dashboard(self):
with app.app_context(): table = db.session.query(SqlaTable).filter_by(table_name="energy_usage").one()
table = ( # get a slice from the allowed table
db.session.query(SqlaTable).filter_by(table_name="energy_usage").one() slice = db.session.query(Slice).filter_by(slice_name="Energy Sankey").one()
)
# get a slice from the allowed table
slice = db.session.query(Slice).filter_by(slice_name="Energy Sankey").one()
self.grant_public_access_to_table(table) self.grant_public_access_to_table(table)
pytest.hidden_dash_slug = f"hidden_dash_{random()}" pytest.hidden_dash_slug = f"hidden_dash_{random()}"
pytest.published_dash_slug = f"published_dash_{random()}" pytest.published_dash_slug = f"published_dash_{random()}"
# Create a published and hidden dashboard and add them to the database # Create a published and hidden dashboard and add them to the database
published_dash = Dashboard() published_dash = Dashboard()
published_dash.dashboard_title = "Published Dashboard" published_dash.dashboard_title = "Published Dashboard"
published_dash.slug = pytest.published_dash_slug published_dash.slug = pytest.published_dash_slug
published_dash.slices = [slice] published_dash.slices = [slice]
published_dash.published = True published_dash.published = True
hidden_dash = Dashboard() hidden_dash = Dashboard()
hidden_dash.dashboard_title = "Hidden Dashboard" hidden_dash.dashboard_title = "Hidden Dashboard"
hidden_dash.slug = pytest.hidden_dash_slug hidden_dash.slug = pytest.hidden_dash_slug
hidden_dash.slices = [slice] hidden_dash.slices = [slice]
hidden_dash.published = False hidden_dash.published = False
db.session.add(published_dash) db.session.add(published_dash)
db.session.add(hidden_dash) db.session.add(hidden_dash)
yield db.session.commit() yield db.session.commit()
self.revoke_public_access_to_table(table) self.revoke_public_access_to_table(table)
db.session.delete(published_dash) db.session.delete(published_dash)
db.session.delete(hidden_dash) db.session.delete(hidden_dash)
db.session.commit() db.session.commit()
def get_mock_positions(self, dash): def get_mock_positions(self, dash):
positions = {"DASHBOARD_VERSION_KEY": "v2"} positions = {"DASHBOARD_VERSION_KEY": "v2"}

View File

@ -2088,8 +2088,6 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
self.assertNotEqual(result["uuid"], "") self.assertNotEqual(result["uuid"], "")
self.assertEqual(result["allowed_domains"], allowed_domains) self.assertEqual(result["allowed_domains"], allowed_domains)
db.session.expire_all()
# get returns value # get returns value
resp = self.get_assert_metric(uri, "get_embedded") resp = self.get_assert_metric(uri, "get_embedded")
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
@ -2110,8 +2108,6 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
self.assertNotEqual(result["uuid"], "") self.assertNotEqual(result["uuid"], "")
self.assertEqual(result["allowed_domains"], []) self.assertEqual(result["allowed_domains"], [])
db.session.expire_all()
# get returns changed value # get returns changed value
resp = self.get_assert_metric(uri, "get_embedded") resp = self.get_assert_metric(uri, "get_embedded")
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
@ -2123,8 +2119,6 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
resp = self.delete_assert_metric(uri, "delete_embedded") resp = self.delete_assert_metric(uri, "delete_embedded")
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
db.session.expire_all()
# get returns 404 # get returns 404
resp = self.get_assert_metric(uri, "get_embedded") resp = self.get_assert_metric(uri, "get_embedded")
self.assertEqual(resp.status_code, 404) self.assertEqual(resp.status_code, 404)

View File

@ -37,39 +37,36 @@ from tests.integration_tests.fixtures.energy_dashboard import (
class TestDashboardDatasetSecurity(DashboardTestCase): class TestDashboardDatasetSecurity(DashboardTestCase):
@pytest.fixture @pytest.fixture
def load_dashboard(self): def load_dashboard(self):
with app.app_context(): table = db.session.query(SqlaTable).filter_by(table_name="energy_usage").one()
table = ( # get a slice from the allowed table
db.session.query(SqlaTable).filter_by(table_name="energy_usage").one() slice = db.session.query(Slice).filter_by(slice_name="Energy Sankey").one()
)
# get a slice from the allowed table
slice = db.session.query(Slice).filter_by(slice_name="Energy Sankey").one()
self.grant_public_access_to_table(table) self.grant_public_access_to_table(table)
pytest.hidden_dash_slug = f"hidden_dash_{random_slug()}" pytest.hidden_dash_slug = f"hidden_dash_{random_slug()}"
pytest.published_dash_slug = f"published_dash_{random_slug()}" pytest.published_dash_slug = f"published_dash_{random_slug()}"
# Create a published and hidden dashboard and add them to the database # Create a published and hidden dashboard and add them to the database
published_dash = Dashboard() published_dash = Dashboard()
published_dash.dashboard_title = "Published Dashboard" published_dash.dashboard_title = "Published Dashboard"
published_dash.slug = pytest.published_dash_slug published_dash.slug = pytest.published_dash_slug
published_dash.slices = [slice] published_dash.slices = [slice]
published_dash.published = True published_dash.published = True
hidden_dash = Dashboard() hidden_dash = Dashboard()
hidden_dash.dashboard_title = "Hidden Dashboard" hidden_dash.dashboard_title = "Hidden Dashboard"
hidden_dash.slug = pytest.hidden_dash_slug hidden_dash.slug = pytest.hidden_dash_slug
hidden_dash.slices = [slice] hidden_dash.slices = [slice]
hidden_dash.published = False hidden_dash.published = False
db.session.add(published_dash) db.session.add(published_dash)
db.session.add(hidden_dash) db.session.add(hidden_dash)
yield db.session.commit() yield db.session.commit()
self.revoke_public_access_to_table(table) self.revoke_public_access_to_table(table)
db.session.delete(published_dash) db.session.delete(published_dash)
db.session.delete(hidden_dash) db.session.delete(hidden_dash)
db.session.commit() db.session.commit()
def test_dashboard_access__admin_can_access_all(self): def test_dashboard_access__admin_can_access_all(self):
# arrange # arrange

View File

@ -20,6 +20,7 @@ from __future__ import annotations
import json import json
import pytest import pytest
from flask.ctx import AppContext
from superset import db, security_manager from superset import db, security_manager
from superset.commands.database.exceptions import ( from superset.commands.database.exceptions import (
@ -84,16 +85,14 @@ def get_upload_db():
return db.session.query(Database).filter_by(database_name=CSV_UPLOAD_DATABASE).one() return db.session.query(Database).filter_by(database_name=CSV_UPLOAD_DATABASE).one()
@pytest.fixture(scope="function") @pytest.fixture()
def setup_csv_upload_with_context(): def setup_csv_upload_with_context(app_context: AppContext):
with app.app_context(): yield from _setup_csv_upload()
yield from _setup_csv_upload()
@pytest.fixture(scope="function") @pytest.fixture()
def setup_csv_upload_with_context_schema(): def setup_csv_upload_with_context_schema(app_context: AppContext):
with app.app_context(): yield from _setup_csv_upload(["public"])
yield from _setup_csv_upload(["public"])
@pytest.mark.usefixtures("setup_csv_upload_with_context") @pytest.mark.usefixtures("setup_csv_upload_with_context")

View File

@ -46,6 +46,5 @@ class TestEmbeddedDashboardDAO(SupersetTestCase):
def test_get_by_uuid(self): def test_get_by_uuid(self):
dash = db.session.query(Dashboard).filter_by(slug="world_health").first() dash = db.session.query(Dashboard).filter_by(slug="world_health").first()
uuid = str(EmbeddedDashboardDAO.upsert(dash, ["test.example.com"]).uuid) uuid = str(EmbeddedDashboardDAO.upsert(dash, ["test.example.com"]).uuid)
db.session.expire_all()
embedded = EmbeddedDashboardDAO.find_by_id(uuid) embedded = EmbeddedDashboardDAO.find_by_id(uuid)
self.assertIsNotNone(embedded) self.assertIsNotNone(embedded)

View File

@ -173,39 +173,38 @@ def get_datasource_post() -> dict[str, Any]:
@pytest.fixture() @pytest.fixture()
@pytest.mark.usefixtures("app_conntext")
def load_dataset_with_columns() -> Generator[SqlaTable, None, None]: def load_dataset_with_columns() -> Generator[SqlaTable, None, None]:
with app.app_context(): engine = create_engine(app.config["SQLALCHEMY_DATABASE_URI"], echo=True)
engine = create_engine(app.config["SQLALCHEMY_DATABASE_URI"], echo=True) meta = MetaData()
meta = MetaData()
students = Table( students = Table(
"students", "students",
meta, meta,
Column("id", Integer, primary_key=True), Column("id", Integer, primary_key=True),
Column("name", String(255)), Column("name", String(255)),
Column("lastname", String(255)), Column("lastname", String(255)),
Column("ds", Date), Column("ds", Date),
) )
meta.create_all(engine) meta.create_all(engine)
students.insert().values(name="George", ds="2021-01-01") students.insert().values(name="George", ds="2021-01-01")
dataset = SqlaTable( dataset = SqlaTable(
database_id=db.session.query(Database).first().id, table_name="students" database_id=db.session.query(Database).first().id, table_name="students"
) )
column = TableColumn(table_id=dataset.id, column_name="name") column = TableColumn(table_id=dataset.id, column_name="name")
dataset.columns = [column] dataset.columns = [column]
db.session.add(dataset) db.session.add(dataset)
db.session.commit() db.session.commit()
yield dataset yield dataset
# cleanup # cleanup
students_table = meta.tables.get("students") if (students_table := meta.tables.get("students")) is not None:
if students_table is not None: base = declarative_base()
base = declarative_base() # needed for sqlite
# needed for sqlite
db.session.commit()
base.metadata.drop_all(engine, [students_table], checkfirst=True)
db.session.delete(dataset)
db.session.delete(column)
db.session.commit() db.session.commit()
base.metadata.drop_all(engine, [students_table], checkfirst=True)
db.session.delete(dataset)
db.session.delete(column)
db.session.commit()

View File

@ -15,30 +15,29 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import pytest import pytest
from flask.ctx import AppContext
from superset.extensions import db, security_manager from superset.extensions import db, security_manager
from tests.integration_tests.test_app import app from tests.integration_tests.test_app import app
@pytest.fixture() @pytest.fixture()
def public_role_like_gamma(): def public_role_like_gamma(app_context: AppContext):
with app.app_context(): app.config["PUBLIC_ROLE_LIKE"] = "Gamma"
app.config["PUBLIC_ROLE_LIKE"] = "Gamma" security_manager.sync_role_definitions()
security_manager.sync_role_definitions()
yield yield
security_manager.get_public_role().permissions = [] security_manager.get_public_role().permissions = []
db.session.commit() db.session.commit()
@pytest.fixture() @pytest.fixture()
def public_role_like_test_role(): def public_role_like_test_role(app_context: AppContext):
with app.app_context(): app.config["PUBLIC_ROLE_LIKE"] = "TestRole"
app.config["PUBLIC_ROLE_LIKE"] = "TestRole" security_manager.sync_role_definitions()
security_manager.sync_role_definitions()
yield yield
security_manager.get_public_role().permissions = [] security_manager.get_public_role().permissions = []
db.session.commit() db.session.commit()

View File

@ -22,12 +22,12 @@ from tests.integration_tests.test_app import app
@pytest.fixture @pytest.fixture
@pytest.mark.usefixtures("app_context")
def with_tagging_system_feature(): def with_tagging_system_feature():
with app.app_context(): is_enabled = app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"]
is_enabled = app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] if not is_enabled:
if not is_enabled: app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] = True
app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] = True register_sqla_event_listeners()
register_sqla_event_listeners() yield
yield app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] = False
app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] = False clear_sqla_event_listeners()
clear_sqla_event_listeners()

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import pytest import pytest
from flask.ctx import AppContext
from flask_appbuilder.security.sqla.models import Role, User from flask_appbuilder.security.sqla.models import Role, User
from superset import db, security_manager from superset import db, security_manager
@ -23,27 +24,24 @@ from tests.integration_tests.test_app import app
@pytest.fixture() @pytest.fixture()
def create_gamma_sqllab_no_data(): def create_gamma_sqllab_no_data(app_context: AppContext):
with app.app_context(): gamma_role = db.session.query(Role).filter(Role.name == "Gamma").one_or_none()
gamma_role = db.session.query(Role).filter(Role.name == "Gamma").one_or_none() sqllab_role = db.session.query(Role).filter(Role.name == "sql_lab").one_or_none()
sqllab_role = (
db.session.query(Role).filter(Role.name == "sql_lab").one_or_none()
)
security_manager.add_user( security_manager.add_user(
GAMMA_SQLLAB_NO_DATA_USERNAME, GAMMA_SQLLAB_NO_DATA_USERNAME,
"gamma_sqllab_no_data", "gamma_sqllab_no_data",
"gamma_sqllab_no_data", "gamma_sqllab_no_data",
"gamma_sqllab_no_data@apache.org", "gamma_sqllab_no_data@apache.org",
[gamma_role, sqllab_role], [gamma_role, sqllab_role],
password="general", password="general",
) )
yield yield
user = ( user = (
db.session.query(User) db.session.query(User)
.filter(User.username == GAMMA_SQLLAB_NO_DATA_USERNAME) .filter(User.username == GAMMA_SQLLAB_NO_DATA_USERNAME)
.one_or_none() .one_or_none()
) )
db.session.delete(user) db.session.delete(user)
db.session.commit() db.session.commit()

View File

@ -93,13 +93,12 @@ def load_world_bank_dashboard_with_slices_class_scope(load_world_bank_data):
def create_dashboard_for_loaded_data(): def create_dashboard_for_loaded_data():
with app.app_context(): table = create_table_metadata(WB_HEALTH_POPULATION, get_example_database())
table = create_table_metadata(WB_HEALTH_POPULATION, get_example_database()) slices = _create_world_bank_slices(table)
slices = _create_world_bank_slices(table) dash = _create_world_bank_dashboard(table)
dash = _create_world_bank_dashboard(table) slices_ids_to_delete = [slice.id for slice in slices]
slices_ids_to_delete = [slice.id for slice in slices] dash_id_to_delete = dash.id
dash_id_to_delete = dash.id return dash_id_to_delete, slices_ids_to_delete
return dash_id_to_delete, slices_ids_to_delete
def _create_world_bank_slices(table: SqlaTable) -> list[Slice]: def _create_world_bank_slices(table: SqlaTable) -> list[Slice]:

View File

@ -16,6 +16,8 @@
# under the License. # under the License.
from importlib import import_module from importlib import import_module
import pytest
from superset import db from superset import db
from superset.migrations.shared.security_converge import ( from superset.migrations.shared.security_converge import (
_find_pvm, _find_pvm,
@ -34,28 +36,28 @@ upgrade = migration_module.do_upgrade
downgrade = migration_module.do_downgrade downgrade = migration_module.do_downgrade
@pytest.mark.usefixtures("app_context")
def test_migration_upgrade(): def test_migration_upgrade():
with app.app_context(): pre_perm = PermissionView(
pre_perm = PermissionView( permission=Permission(name="can_view_and_drill"),
permission=Permission(name="can_view_and_drill"), view_menu=db.session.query(ViewMenu).filter_by(name="Dashboard").one(),
view_menu=db.session.query(ViewMenu).filter_by(name="Dashboard").one(), )
) db.session.add(pre_perm)
db.session.add(pre_perm) db.session.commit()
db.session.commit()
assert _find_pvm(db.session, "Dashboard", "can_view_and_drill") is not None assert _find_pvm(db.session, "Dashboard", "can_view_and_drill") is not None
upgrade(db.session) upgrade(db.session)
assert _find_pvm(db.session, "Dashboard", "can_view_chart_as_table") is not None assert _find_pvm(db.session, "Dashboard", "can_view_chart_as_table") is not None
assert _find_pvm(db.session, "Dashboard", "can_view_query") is not None assert _find_pvm(db.session, "Dashboard", "can_view_query") is not None
assert _find_pvm(db.session, "Dashboard", "can_view_and_drill") is None assert _find_pvm(db.session, "Dashboard", "can_view_and_drill") is None
@pytest.mark.usefixtures("app_context")
def test_migration_downgrade(): def test_migration_downgrade():
with app.app_context(): downgrade(db.session)
downgrade(db.session)
assert _find_pvm(db.session, "Dashboard", "can_view_chart_as_table") is None assert _find_pvm(db.session, "Dashboard", "can_view_chart_as_table") is None
assert _find_pvm(db.session, "Dashboard", "can_view_query") is None assert _find_pvm(db.session, "Dashboard", "can_view_query") is None
assert _find_pvm(db.session, "Dashboard", "can_view_and_drill") is not None assert _find_pvm(db.session, "Dashboard", "can_view_and_drill") is not None

View File

@ -20,6 +20,7 @@ from typing import Optional, Union
import pandas as pd import pandas as pd
import pytest import pytest
from flask.ctx import AppContext
from pytest_mock import MockFixture from pytest_mock import MockFixture
from superset.commands.report.exceptions import AlertQueryError from superset.commands.report.exceptions import AlertQueryError
@ -61,43 +62,40 @@ def test_execute_query_as_report_executor(
config: list[ExecutorType], config: list[ExecutorType],
expected_result: Union[tuple[ExecutorType, str], Exception], expected_result: Union[tuple[ExecutorType, str], Exception],
mocker: MockFixture, mocker: MockFixture,
app_context: None, app_context: AppContext,
get_user, get_user,
) -> None: ) -> None:
from superset.commands.report.alert import AlertCommand from superset.commands.report.alert import AlertCommand
from superset.reports.models import ReportSchedule from superset.reports.models import ReportSchedule
with app.app_context(): original_config = app.config["ALERT_REPORTS_EXECUTE_AS"]
original_config = app.config["ALERT_REPORTS_EXECUTE_AS"] app.config["ALERT_REPORTS_EXECUTE_AS"] = config
app.config["ALERT_REPORTS_EXECUTE_AS"] = config owners = [get_user(owner_name) for owner_name in owner_names]
owners = [get_user(owner_name) for owner_name in owner_names] report_schedule = ReportSchedule(
report_schedule = ReportSchedule( created_by=get_user(creator_name) if creator_name else None,
created_by=get_user(creator_name) if creator_name else None, owners=owners,
owners=owners, type=ReportScheduleType.ALERT,
type=ReportScheduleType.ALERT, description="description",
description="description", crontab="0 9 * * *",
crontab="0 9 * * *", creation_method=ReportCreationMethod.ALERTS_REPORTS,
creation_method=ReportCreationMethod.ALERTS_REPORTS, sql="SELECT 1",
sql="SELECT 1", grace_period=14400,
grace_period=14400, working_timeout=3600,
working_timeout=3600, database=get_example_database(),
database=get_example_database(), validator_config_json='{"op": "==", "threshold": 1}',
validator_config_json='{"op": "==", "threshold": 1}', )
) command = AlertCommand(report_schedule=report_schedule)
command = AlertCommand(report_schedule=report_schedule) override_user_mock = mocker.patch("superset.commands.report.alert.override_user")
override_user_mock = mocker.patch( cm = (
"superset.commands.report.alert.override_user" pytest.raises(type(expected_result))
) if isinstance(expected_result, Exception)
cm = ( else nullcontext()
pytest.raises(type(expected_result)) )
if isinstance(expected_result, Exception) with cm:
else nullcontext() command.run()
) assert override_user_mock.call_args[0][0].username == expected_result
with cm:
command.run()
assert override_user_mock.call_args[0][0].username == expected_result
app.config["ALERT_REPORTS_EXECUTE_AS"] = original_config app.config["ALERT_REPORTS_EXECUTE_AS"] = original_config
def test_execute_query_succeeded_no_retry( def test_execute_query_succeeded_no_retry(

View File

@ -23,6 +23,7 @@ from uuid import uuid4
import pytest import pytest
from flask import current_app from flask import current_app
from flask.ctx import AppContext
from flask_appbuilder.security.sqla.models import User from flask_appbuilder.security.sqla.models import User
from flask_sqlalchemy import BaseQuery from flask_sqlalchemy import BaseQuery
from freezegun import freeze_time from freezegun import freeze_time
@ -162,204 +163,191 @@ def create_test_table_context(database: Database):
@pytest.fixture() @pytest.fixture()
def create_report_email_chart(): def create_report_email_chart():
with app.app_context(): chart = db.session.query(Slice).first()
chart = db.session.query(Slice).first() report_schedule = create_report_notification(
report_schedule = create_report_notification( email_target="target@email.com", chart=chart
email_target="target@email.com", chart=chart )
) yield report_schedule
yield report_schedule
cleanup_report_schedule(report_schedule) cleanup_report_schedule(report_schedule)
@pytest.fixture() @pytest.fixture()
def create_report_email_chart_alpha_owner(get_user): def create_report_email_chart_alpha_owner(get_user):
with app.app_context(): owners = [get_user("alpha")]
owners = [get_user("alpha")] chart = db.session.query(Slice).first()
chart = db.session.query(Slice).first() report_schedule = create_report_notification(
report_schedule = create_report_notification( email_target="target@email.com", chart=chart, owners=owners
email_target="target@email.com", chart=chart, owners=owners )
) yield report_schedule
yield report_schedule
cleanup_report_schedule(report_schedule) cleanup_report_schedule(report_schedule)
@pytest.fixture() @pytest.fixture()
def create_report_email_chart_force_screenshot(): def create_report_email_chart_force_screenshot():
with app.app_context(): chart = db.session.query(Slice).first()
chart = db.session.query(Slice).first() report_schedule = create_report_notification(
report_schedule = create_report_notification( email_target="target@email.com", chart=chart, force_screenshot=True
email_target="target@email.com", chart=chart, force_screenshot=True )
) yield report_schedule
yield report_schedule
cleanup_report_schedule(report_schedule) cleanup_report_schedule(report_schedule)
@pytest.fixture() @pytest.fixture()
def create_report_email_chart_with_csv(): def create_report_email_chart_with_csv():
with app.app_context(): chart = db.session.query(Slice).first()
chart = db.session.query(Slice).first() chart.query_context = '{"mock": "query_context"}'
chart.query_context = '{"mock": "query_context"}' report_schedule = create_report_notification(
report_schedule = create_report_notification( email_target="target@email.com",
email_target="target@email.com", chart=chart,
chart=chart, report_format=ReportDataFormat.CSV,
report_format=ReportDataFormat.CSV, )
) yield report_schedule
yield report_schedule cleanup_report_schedule(report_schedule)
cleanup_report_schedule(report_schedule)
@pytest.fixture() @pytest.fixture()
def create_report_email_chart_with_text(): def create_report_email_chart_with_text():
with app.app_context(): chart = db.session.query(Slice).first()
chart = db.session.query(Slice).first() chart.query_context = '{"mock": "query_context"}'
chart.query_context = '{"mock": "query_context"}' report_schedule = create_report_notification(
report_schedule = create_report_notification( email_target="target@email.com",
email_target="target@email.com", chart=chart,
chart=chart, report_format=ReportDataFormat.TEXT,
report_format=ReportDataFormat.TEXT, )
) yield report_schedule
yield report_schedule cleanup_report_schedule(report_schedule)
cleanup_report_schedule(report_schedule)
@pytest.fixture() @pytest.fixture()
def create_report_email_chart_with_csv_no_query_context(): def create_report_email_chart_with_csv_no_query_context():
with app.app_context(): chart = db.session.query(Slice).first()
chart = db.session.query(Slice).first() chart.query_context = None
chart.query_context = None report_schedule = create_report_notification(
report_schedule = create_report_notification( email_target="target@email.com",
email_target="target@email.com", chart=chart,
chart=chart, report_format=ReportDataFormat.CSV,
report_format=ReportDataFormat.CSV, name="report_csv_no_query_context",
name="report_csv_no_query_context", )
) yield report_schedule
yield report_schedule cleanup_report_schedule(report_schedule)
cleanup_report_schedule(report_schedule)
@pytest.fixture() @pytest.fixture()
def create_report_email_dashboard(): def create_report_email_dashboard():
with app.app_context(): dashboard = db.session.query(Dashboard).first()
dashboard = db.session.query(Dashboard).first() report_schedule = create_report_notification(
report_schedule = create_report_notification( email_target="target@email.com", dashboard=dashboard
email_target="target@email.com", dashboard=dashboard )
) yield report_schedule
yield report_schedule
cleanup_report_schedule(report_schedule) cleanup_report_schedule(report_schedule)
@pytest.fixture() @pytest.fixture()
def create_report_email_dashboard_force_screenshot(): def create_report_email_dashboard_force_screenshot():
with app.app_context(): dashboard = db.session.query(Dashboard).first()
dashboard = db.session.query(Dashboard).first() report_schedule = create_report_notification(
report_schedule = create_report_notification( email_target="target@email.com", dashboard=dashboard, force_screenshot=True
email_target="target@email.com", dashboard=dashboard, force_screenshot=True )
) yield report_schedule
yield report_schedule
cleanup_report_schedule(report_schedule) cleanup_report_schedule(report_schedule)
@pytest.fixture() @pytest.fixture()
def create_report_slack_chart(): def create_report_slack_chart():
with app.app_context(): chart = db.session.query(Slice).first()
chart = db.session.query(Slice).first() report_schedule = create_report_notification(
report_schedule = create_report_notification( slack_channel="slack_channel", chart=chart
slack_channel="slack_channel", chart=chart )
) yield report_schedule
yield report_schedule
cleanup_report_schedule(report_schedule) cleanup_report_schedule(report_schedule)
@pytest.fixture() @pytest.fixture()
def create_report_slack_chart_with_csv(): def create_report_slack_chart_with_csv():
with app.app_context(): chart = db.session.query(Slice).first()
chart = db.session.query(Slice).first() chart.query_context = '{"mock": "query_context"}'
chart.query_context = '{"mock": "query_context"}' report_schedule = create_report_notification(
report_schedule = create_report_notification( slack_channel="slack_channel",
slack_channel="slack_channel", chart=chart,
chart=chart, report_format=ReportDataFormat.CSV,
report_format=ReportDataFormat.CSV, )
) yield report_schedule
yield report_schedule
cleanup_report_schedule(report_schedule) cleanup_report_schedule(report_schedule)
@pytest.fixture() @pytest.fixture()
def create_report_slack_chart_with_text(): def create_report_slack_chart_with_text():
with app.app_context(): chart = db.session.query(Slice).first()
chart = db.session.query(Slice).first() chart.query_context = '{"mock": "query_context"}'
chart.query_context = '{"mock": "query_context"}' report_schedule = create_report_notification(
report_schedule = create_report_notification( slack_channel="slack_channel",
slack_channel="slack_channel", chart=chart,
chart=chart, report_format=ReportDataFormat.TEXT,
report_format=ReportDataFormat.TEXT, )
) yield report_schedule
yield report_schedule
cleanup_report_schedule(report_schedule) cleanup_report_schedule(report_schedule)
@pytest.fixture() @pytest.fixture()
def create_report_slack_chart_working(): def create_report_slack_chart_working():
with app.app_context(): chart = db.session.query(Slice).first()
chart = db.session.query(Slice).first() report_schedule = create_report_notification(
report_schedule = create_report_notification( slack_channel="slack_channel", chart=chart
slack_channel="slack_channel", chart=chart )
) report_schedule.last_state = ReportState.WORKING
report_schedule.last_state = ReportState.WORKING report_schedule.last_eval_dttm = datetime(2020, 1, 1, 0, 0)
report_schedule.last_eval_dttm = datetime(2020, 1, 1, 0, 0) report_schedule.last_value = None
report_schedule.last_value = None report_schedule.last_value_row_json = None
report_schedule.last_value_row_json = None db.session.commit()
db.session.commit() log = ReportExecutionLog(
log = ReportExecutionLog( scheduled_dttm=report_schedule.last_eval_dttm,
scheduled_dttm=report_schedule.last_eval_dttm, start_dttm=report_schedule.last_eval_dttm,
start_dttm=report_schedule.last_eval_dttm, end_dttm=report_schedule.last_eval_dttm,
end_dttm=report_schedule.last_eval_dttm, value=report_schedule.last_value,
value=report_schedule.last_value, value_row_json=report_schedule.last_value_row_json,
value_row_json=report_schedule.last_value_row_json, state=ReportState.WORKING,
state=ReportState.WORKING, report_schedule=report_schedule,
report_schedule=report_schedule, uuid=uuid4(),
uuid=uuid4(), )
) db.session.add(log)
db.session.add(log) db.session.commit()
db.session.commit()
yield report_schedule yield report_schedule
cleanup_report_schedule(report_schedule) cleanup_report_schedule(report_schedule)
@pytest.fixture() @pytest.fixture()
def create_alert_slack_chart_success(): def create_alert_slack_chart_success():
with app.app_context(): chart = db.session.query(Slice).first()
chart = db.session.query(Slice).first() report_schedule = create_report_notification(
report_schedule = create_report_notification( slack_channel="slack_channel",
slack_channel="slack_channel", chart=chart,
chart=chart, report_type=ReportScheduleType.ALERT,
report_type=ReportScheduleType.ALERT, )
) report_schedule.last_state = ReportState.SUCCESS
report_schedule.last_state = ReportState.SUCCESS report_schedule.last_eval_dttm = datetime(2020, 1, 1, 0, 0)
report_schedule.last_eval_dttm = datetime(2020, 1, 1, 0, 0)
log = ReportExecutionLog( log = ReportExecutionLog(
report_schedule=report_schedule, report_schedule=report_schedule,
state=ReportState.SUCCESS, state=ReportState.SUCCESS,
start_dttm=report_schedule.last_eval_dttm, start_dttm=report_schedule.last_eval_dttm,
end_dttm=report_schedule.last_eval_dttm, end_dttm=report_schedule.last_eval_dttm,
scheduled_dttm=report_schedule.last_eval_dttm, scheduled_dttm=report_schedule.last_eval_dttm,
) )
db.session.add(log) db.session.add(log)
db.session.commit() db.session.commit()
yield report_schedule yield report_schedule
cleanup_report_schedule(report_schedule) cleanup_report_schedule(report_schedule)
@pytest.fixture( @pytest.fixture(
@ -375,36 +363,33 @@ def create_alert_slack_chart_grace(request):
"validator_config_json": '{"op": "<", "threshold": 10}', "validator_config_json": '{"op": "<", "threshold": 10}',
}, },
} }
with app.app_context(): chart = db.session.query(Slice).first()
chart = db.session.query(Slice).first() example_database = get_example_database()
example_database = get_example_database() with create_test_table_context(example_database):
with create_test_table_context(example_database): report_schedule = create_report_notification(
report_schedule = create_report_notification( slack_channel="slack_channel",
slack_channel="slack_channel", chart=chart,
chart=chart, report_type=ReportScheduleType.ALERT,
report_type=ReportScheduleType.ALERT, database=example_database,
database=example_database, sql=param_config[request.param]["sql"],
sql=param_config[request.param]["sql"], validator_type=param_config[request.param]["validator_type"],
validator_type=param_config[request.param]["validator_type"], validator_config_json=param_config[request.param]["validator_config_json"],
validator_config_json=param_config[request.param][ )
"validator_config_json" report_schedule.last_state = ReportState.GRACE
], report_schedule.last_eval_dttm = datetime(2020, 1, 1, 0, 0)
)
report_schedule.last_state = ReportState.GRACE
report_schedule.last_eval_dttm = datetime(2020, 1, 1, 0, 0)
log = ReportExecutionLog( log = ReportExecutionLog(
report_schedule=report_schedule, report_schedule=report_schedule,
state=ReportState.SUCCESS, state=ReportState.SUCCESS,
start_dttm=report_schedule.last_eval_dttm, start_dttm=report_schedule.last_eval_dttm,
end_dttm=report_schedule.last_eval_dttm, end_dttm=report_schedule.last_eval_dttm,
scheduled_dttm=report_schedule.last_eval_dttm, scheduled_dttm=report_schedule.last_eval_dttm,
) )
db.session.add(log) db.session.add(log)
db.session.commit() db.session.commit()
yield report_schedule yield report_schedule
cleanup_report_schedule(report_schedule) cleanup_report_schedule(report_schedule)
@pytest.fixture( @pytest.fixture(
@ -462,25 +447,22 @@ def create_alert_email_chart(request):
"validator_config_json": '{"op": ">", "threshold": 54.999}', "validator_config_json": '{"op": ">", "threshold": 54.999}',
}, },
} }
with app.app_context(): chart = db.session.query(Slice).first()
chart = db.session.query(Slice).first() example_database = get_example_database()
example_database = get_example_database() with create_test_table_context(example_database):
with create_test_table_context(example_database): report_schedule = create_report_notification(
report_schedule = create_report_notification( email_target="target@email.com",
email_target="target@email.com", chart=chart,
chart=chart, report_type=ReportScheduleType.ALERT,
report_type=ReportScheduleType.ALERT, database=example_database,
database=example_database, sql=param_config[request.param]["sql"],
sql=param_config[request.param]["sql"], validator_type=param_config[request.param]["validator_type"],
validator_type=param_config[request.param]["validator_type"], validator_config_json=param_config[request.param]["validator_config_json"],
validator_config_json=param_config[request.param][ force_screenshot=True,
"validator_config_json" )
], yield report_schedule
force_screenshot=True,
)
yield report_schedule
cleanup_report_schedule(report_schedule) cleanup_report_schedule(report_schedule)
@pytest.fixture( @pytest.fixture(
@ -544,24 +526,21 @@ def create_no_alert_email_chart(request):
"validator_config_json": '{"op": ">", "threshold": 0}', "validator_config_json": '{"op": ">", "threshold": 0}',
}, },
} }
with app.app_context(): chart = db.session.query(Slice).first()
chart = db.session.query(Slice).first() example_database = get_example_database()
example_database = get_example_database() with create_test_table_context(example_database):
with create_test_table_context(example_database): report_schedule = create_report_notification(
report_schedule = create_report_notification( email_target="target@email.com",
email_target="target@email.com", chart=chart,
chart=chart, report_type=ReportScheduleType.ALERT,
report_type=ReportScheduleType.ALERT, database=example_database,
database=example_database, sql=param_config[request.param]["sql"],
sql=param_config[request.param]["sql"], validator_type=param_config[request.param]["validator_type"],
validator_type=param_config[request.param]["validator_type"], validator_config_json=param_config[request.param]["validator_config_json"],
validator_config_json=param_config[request.param][ )
"validator_config_json" yield report_schedule
],
)
yield report_schedule
cleanup_report_schedule(report_schedule) cleanup_report_schedule(report_schedule)
@pytest.fixture(params=["alert1", "alert2"]) @pytest.fixture(params=["alert1", "alert2"])
@ -578,28 +557,25 @@ def create_mul_alert_email_chart(request):
"validator_config_json": '{"op": "<", "threshold": 10}', "validator_config_json": '{"op": "<", "threshold": 10}',
}, },
} }
with app.app_context(): chart = db.session.query(Slice).first()
chart = db.session.query(Slice).first() example_database = get_example_database()
example_database = get_example_database() with create_test_table_context(example_database):
with create_test_table_context(example_database): report_schedule = create_report_notification(
report_schedule = create_report_notification( email_target="target@email.com",
email_target="target@email.com", chart=chart,
chart=chart, report_type=ReportScheduleType.ALERT,
report_type=ReportScheduleType.ALERT, database=example_database,
database=example_database, sql=param_config[request.param]["sql"],
sql=param_config[request.param]["sql"], validator_type=param_config[request.param]["validator_type"],
validator_type=param_config[request.param]["validator_type"], validator_config_json=param_config[request.param]["validator_config_json"],
validator_config_json=param_config[request.param][ )
"validator_config_json" yield report_schedule
],
)
yield report_schedule
cleanup_report_schedule(report_schedule) cleanup_report_schedule(report_schedule)
@pytest.fixture(params=["alert1", "alert2"]) @pytest.fixture(params=["alert1", "alert2"])
def create_invalid_sql_alert_email_chart(request): def create_invalid_sql_alert_email_chart(request, app_context: AppContext):
param_config = { param_config = {
"alert1": { "alert1": {
"sql": "SELECT 'string' ", "sql": "SELECT 'string' ",
@ -612,25 +588,22 @@ def create_invalid_sql_alert_email_chart(request):
"validator_config_json": '{"op": "<", "threshold": 10}', "validator_config_json": '{"op": "<", "threshold": 10}',
}, },
} }
with app.app_context(): chart = db.session.query(Slice).first()
chart = db.session.query(Slice).first() example_database = get_example_database()
example_database = get_example_database() with create_test_table_context(example_database):
with create_test_table_context(example_database): report_schedule = create_report_notification(
report_schedule = create_report_notification( email_target="target@email.com",
email_target="target@email.com", chart=chart,
chart=chart, report_type=ReportScheduleType.ALERT,
report_type=ReportScheduleType.ALERT, database=example_database,
database=example_database, sql=param_config[request.param]["sql"],
sql=param_config[request.param]["sql"], validator_type=param_config[request.param]["validator_type"],
validator_type=param_config[request.param]["validator_type"], validator_config_json=param_config[request.param]["validator_config_json"],
validator_config_json=param_config[request.param][ grace_period=60 * 60,
"validator_config_json" )
], yield report_schedule
grace_period=60 * 60,
)
yield report_schedule
cleanup_report_schedule(report_schedule) cleanup_report_schedule(report_schedule)
@pytest.mark.usefixtures( @pytest.mark.usefixtures(
@ -835,7 +808,8 @@ def test_email_chart_report_dry_run(
@pytest.mark.usefixtures( @pytest.mark.usefixtures(
"load_birth_names_dashboard_with_slices", "create_report_email_chart_with_csv" "load_birth_names_dashboard_with_slices",
"create_report_email_chart_with_csv",
) )
@patch("superset.utils.csv.urllib.request.urlopen") @patch("superset.utils.csv.urllib.request.urlopen")
@patch("superset.utils.csv.urllib.request.OpenerDirector.open") @patch("superset.utils.csv.urllib.request.OpenerDirector.open")
@ -923,7 +897,8 @@ def test_email_chart_report_schedule_with_csv_no_query_context(
@pytest.mark.usefixtures( @pytest.mark.usefixtures(
"load_birth_names_dashboard_with_slices", "create_report_email_chart_with_text" "load_birth_names_dashboard_with_slices",
"create_report_email_chart_with_text",
) )
@patch("superset.utils.csv.urllib.request.urlopen") @patch("superset.utils.csv.urllib.request.urlopen")
@patch("superset.utils.csv.urllib.request.OpenerDirector.open") @patch("superset.utils.csv.urllib.request.OpenerDirector.open")
@ -1545,7 +1520,8 @@ def test_slack_chart_alert_no_attachment(email_mock, create_alert_email_chart):
@pytest.mark.usefixtures( @pytest.mark.usefixtures(
"load_birth_names_dashboard_with_slices", "create_report_slack_chart" "load_birth_names_dashboard_with_slices",
"create_report_slack_chart",
) )
@patch("superset.utils.slack.WebClient") @patch("superset.utils.slack.WebClient")
@patch("superset.utils.screenshots.ChartScreenshot.get_screenshot") @patch("superset.utils.screenshots.ChartScreenshot.get_screenshot")
@ -1571,7 +1547,7 @@ def test_slack_token_callable_chart_report(
assert_log(ReportState.SUCCESS) assert_log(ReportState.SUCCESS)
@pytest.mark.usefixtures("create_no_alert_email_chart") @pytest.mark.usefixtures("app_context")
def test_email_chart_no_alert(create_no_alert_email_chart): def test_email_chart_no_alert(create_no_alert_email_chart):
""" """
ExecuteReport Command: Test chart email no alert ExecuteReport Command: Test chart email no alert
@ -1583,7 +1559,7 @@ def test_email_chart_no_alert(create_no_alert_email_chart):
assert_log(ReportState.NOOP) assert_log(ReportState.NOOP)
@pytest.mark.usefixtures("create_mul_alert_email_chart") @pytest.mark.usefixtures("app_context")
def test_email_mul_alert(create_mul_alert_email_chart): def test_email_mul_alert(create_mul_alert_email_chart):
""" """
ExecuteReport Command: Test chart email multiple rows ExecuteReport Command: Test chart email multiple rows
@ -1824,7 +1800,6 @@ def test_email_disable_screenshot(email_mock, create_alert_email_chart):
assert_log(ReportState.SUCCESS) assert_log(ReportState.SUCCESS)
@pytest.mark.usefixtures("create_invalid_sql_alert_email_chart")
@patch("superset.reports.notifications.email.send_email_smtp") @patch("superset.reports.notifications.email.send_email_smtp")
def test_invalid_sql_alert(email_mock, create_invalid_sql_alert_email_chart): def test_invalid_sql_alert(email_mock, create_invalid_sql_alert_email_chart):
""" """
@ -1841,7 +1816,6 @@ def test_invalid_sql_alert(email_mock, create_invalid_sql_alert_email_chart):
assert_log(ReportState.ERROR) assert_log(ReportState.ERROR)
@pytest.mark.usefixtures("create_invalid_sql_alert_email_chart")
@patch("superset.reports.notifications.email.send_email_smtp") @patch("superset.reports.notifications.email.send_email_smtp")
def test_grace_period_error(email_mock, create_invalid_sql_alert_email_chart): def test_grace_period_error(email_mock, create_invalid_sql_alert_email_chart):
""" """
@ -1884,7 +1858,6 @@ def test_grace_period_error(email_mock, create_invalid_sql_alert_email_chart):
) )
@pytest.mark.usefixtures("create_invalid_sql_alert_email_chart")
@patch("superset.reports.notifications.email.send_email_smtp") @patch("superset.reports.notifications.email.send_email_smtp")
@patch("superset.utils.screenshots.ChartScreenshot.get_screenshot") @patch("superset.utils.screenshots.ChartScreenshot.get_screenshot")
def test_grace_period_error_flap( def test_grace_period_error_flap(

View File

@ -35,150 +35,144 @@ def owners(get_user) -> list[User]:
return [get_user("admin")] return [get_user("admin")]
@pytest.mark.usefixtures("owners") @pytest.mark.usefixtures("app_context")
@patch("superset.tasks.scheduler.execute.apply_async") @patch("superset.tasks.scheduler.execute.apply_async")
def test_scheduler_celery_timeout_ny(execute_mock, owners): def test_scheduler_celery_timeout_ny(execute_mock, owners):
""" """
Reports scheduler: Test scheduler setting celery soft and hard timeout Reports scheduler: Test scheduler setting celery soft and hard timeout
""" """
with app.app_context(): report_schedule = insert_report_schedule(
report_schedule = insert_report_schedule( type=ReportScheduleType.ALERT,
type=ReportScheduleType.ALERT, name="report",
name="report", crontab="0 4 * * *",
crontab="0 4 * * *", timezone="America/New_York",
timezone="America/New_York", owners=owners,
owners=owners, )
)
with freeze_time("2020-01-01T09:00:00Z"): with freeze_time("2020-01-01T09:00:00Z"):
scheduler() scheduler()
assert execute_mock.call_args[1]["soft_time_limit"] == 3601 assert execute_mock.call_args[1]["soft_time_limit"] == 3601
assert execute_mock.call_args[1]["time_limit"] == 3610 assert execute_mock.call_args[1]["time_limit"] == 3610
db.session.delete(report_schedule) db.session.delete(report_schedule)
db.session.commit() db.session.commit()
@pytest.mark.usefixtures("owners") @pytest.mark.usefixtures("app_context")
@patch("superset.tasks.scheduler.execute.apply_async") @patch("superset.tasks.scheduler.execute.apply_async")
def test_scheduler_celery_no_timeout_ny(execute_mock, owners): def test_scheduler_celery_no_timeout_ny(execute_mock, owners):
""" """
Reports scheduler: Test scheduler setting celery soft and hard timeout Reports scheduler: Test scheduler setting celery soft and hard timeout
""" """
with app.app_context(): app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = False
app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = False report_schedule = insert_report_schedule(
report_schedule = insert_report_schedule( type=ReportScheduleType.ALERT,
type=ReportScheduleType.ALERT, name="report",
name="report", crontab="0 4 * * *",
crontab="0 4 * * *", timezone="America/New_York",
timezone="America/New_York", owners=owners,
owners=owners, )
)
with freeze_time("2020-01-01T09:00:00Z"): with freeze_time("2020-01-01T09:00:00Z"):
scheduler() scheduler()
assert execute_mock.call_args[1] == {"eta": FakeDatetime(2020, 1, 1, 9, 0)} assert execute_mock.call_args[1] == {"eta": FakeDatetime(2020, 1, 1, 9, 0)}
db.session.delete(report_schedule) db.session.delete(report_schedule)
db.session.commit() db.session.commit()
app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = True app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = True
@pytest.mark.usefixtures("owners") @pytest.mark.usefixtures("app_context")
@patch("superset.tasks.scheduler.execute.apply_async") @patch("superset.tasks.scheduler.execute.apply_async")
def test_scheduler_celery_timeout_utc(execute_mock, owners): def test_scheduler_celery_timeout_utc(execute_mock, owners):
""" """
Reports scheduler: Test scheduler setting celery soft and hard timeout Reports scheduler: Test scheduler setting celery soft and hard timeout
""" """
with app.app_context(): report_schedule = insert_report_schedule(
report_schedule = insert_report_schedule( type=ReportScheduleType.ALERT,
type=ReportScheduleType.ALERT, name="report",
name="report", crontab="0 9 * * *",
crontab="0 9 * * *", timezone="UTC",
timezone="UTC", owners=owners,
owners=owners, )
)
with freeze_time("2020-01-01T09:00:00Z"): with freeze_time("2020-01-01T09:00:00Z"):
scheduler() scheduler()
assert execute_mock.call_args[1]["soft_time_limit"] == 3601 assert execute_mock.call_args[1]["soft_time_limit"] == 3601
assert execute_mock.call_args[1]["time_limit"] == 3610 assert execute_mock.call_args[1]["time_limit"] == 3610
db.session.delete(report_schedule) db.session.delete(report_schedule)
db.session.commit() db.session.commit()
@pytest.mark.usefixtures("owners") @pytest.mark.usefixtures("app_context")
@patch("superset.tasks.scheduler.execute.apply_async") @patch("superset.tasks.scheduler.execute.apply_async")
def test_scheduler_celery_no_timeout_utc(execute_mock, owners): def test_scheduler_celery_no_timeout_utc(execute_mock, owners):
""" """
Reports scheduler: Test scheduler setting celery soft and hard timeout Reports scheduler: Test scheduler setting celery soft and hard timeout
""" """
with app.app_context(): app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = False
app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = False report_schedule = insert_report_schedule(
report_schedule = insert_report_schedule( type=ReportScheduleType.ALERT,
type=ReportScheduleType.ALERT, name="report",
name="report", crontab="0 9 * * *",
crontab="0 9 * * *", timezone="UTC",
timezone="UTC", owners=owners,
owners=owners, )
)
with freeze_time("2020-01-01T09:00:00Z"): with freeze_time("2020-01-01T09:00:00Z"):
scheduler() scheduler()
assert execute_mock.call_args[1] == {"eta": FakeDatetime(2020, 1, 1, 9, 0)} assert execute_mock.call_args[1] == {"eta": FakeDatetime(2020, 1, 1, 9, 0)}
db.session.delete(report_schedule) db.session.delete(report_schedule)
db.session.commit() db.session.commit()
app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = True app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = True
@pytest.mark.usefixtures("owners") @pytest.mark.usefixtures("app_context")
@patch("superset.tasks.scheduler.is_feature_enabled") @patch("superset.tasks.scheduler.is_feature_enabled")
@patch("superset.tasks.scheduler.execute.apply_async") @patch("superset.tasks.scheduler.execute.apply_async")
def test_scheduler_feature_flag_off(execute_mock, is_feature_enabled, owners): def test_scheduler_feature_flag_off(execute_mock, is_feature_enabled, owners):
""" """
Reports scheduler: Test scheduler with feature flag off Reports scheduler: Test scheduler with feature flag off
""" """
with app.app_context(): is_feature_enabled.return_value = False
is_feature_enabled.return_value = False report_schedule = insert_report_schedule(
report_schedule = insert_report_schedule( type=ReportScheduleType.ALERT,
type=ReportScheduleType.ALERT, name="report",
name="report", crontab="0 9 * * *",
crontab="0 9 * * *", timezone="UTC",
timezone="UTC", owners=owners,
owners=owners, )
)
with freeze_time("2020-01-01T09:00:00Z"): with freeze_time("2020-01-01T09:00:00Z"):
scheduler() scheduler()
execute_mock.assert_not_called() execute_mock.assert_not_called()
db.session.delete(report_schedule) db.session.delete(report_schedule)
db.session.commit() db.session.commit()
@pytest.mark.usefixtures("owners") @pytest.mark.usefixtures("app_context")
@patch("superset.commands.report.execute.AsyncExecuteReportScheduleCommand.__init__") @patch("superset.commands.report.execute.AsyncExecuteReportScheduleCommand.__init__")
@patch("superset.commands.report.execute.AsyncExecuteReportScheduleCommand.run") @patch("superset.commands.report.execute.AsyncExecuteReportScheduleCommand.run")
@patch("superset.tasks.scheduler.execute.update_state") @patch("superset.tasks.scheduler.execute.update_state")
def test_execute_task(update_state_mock, command_mock, init_mock, owners): def test_execute_task(update_state_mock, command_mock, init_mock, owners):
from superset.commands.report.exceptions import ReportScheduleUnexpectedError from superset.commands.report.exceptions import ReportScheduleUnexpectedError
with app.app_context(): report_schedule = insert_report_schedule(
report_schedule = insert_report_schedule( type=ReportScheduleType.ALERT,
type=ReportScheduleType.ALERT, name=f"report-{randint(0,1000)}",
name=f"report-{randint(0,1000)}", crontab="0 4 * * *",
crontab="0 4 * * *", timezone="America/New_York",
timezone="America/New_York", owners=owners,
owners=owners, )
) init_mock.return_value = None
init_mock.return_value = None command_mock.side_effect = ReportScheduleUnexpectedError("Unexpected error")
command_mock.side_effect = ReportScheduleUnexpectedError("Unexpected error") with freeze_time("2020-01-01T09:00:00Z"):
with freeze_time("2020-01-01T09:00:00Z"): execute(report_schedule.id)
execute(report_schedule.id) update_state_mock.assert_called_with(state="FAILURE")
update_state_mock.assert_called_with(state="FAILURE")
db.session.delete(report_schedule) db.session.delete(report_schedule)
db.session.commit() db.session.commit()
@pytest.mark.usefixtures("owners") @pytest.mark.usefixtures("app_context")
@patch("superset.commands.report.execute.AsyncExecuteReportScheduleCommand.__init__") @patch("superset.commands.report.execute.AsyncExecuteReportScheduleCommand.__init__")
@patch("superset.commands.report.execute.AsyncExecuteReportScheduleCommand.run") @patch("superset.commands.report.execute.AsyncExecuteReportScheduleCommand.run")
@patch("superset.tasks.scheduler.execute.update_state") @patch("superset.tasks.scheduler.execute.update_state")
@ -188,23 +182,22 @@ def test_execute_task_with_command_exception(
): ):
from superset.commands.exceptions import CommandException from superset.commands.exceptions import CommandException
with app.app_context(): report_schedule = insert_report_schedule(
report_schedule = insert_report_schedule( type=ReportScheduleType.ALERT,
type=ReportScheduleType.ALERT, name=f"report-{randint(0,1000)}",
name=f"report-{randint(0,1000)}", crontab="0 4 * * *",
crontab="0 4 * * *", timezone="America/New_York",
timezone="America/New_York", owners=owners,
owners=owners, )
init_mock.return_value = None
command_mock.side_effect = CommandException("Unexpected error")
with freeze_time("2020-01-01T09:00:00Z"):
execute(report_schedule.id)
update_state_mock.assert_called_with(state="FAILURE")
logger_mock.exception.assert_called_with(
"A downstream exception occurred while generating a report: None. Unexpected error",
exc_info=True,
) )
init_mock.return_value = None
command_mock.side_effect = CommandException("Unexpected error")
with freeze_time("2020-01-01T09:00:00Z"):
execute(report_schedule.id)
update_state_mock.assert_called_with(state="FAILURE")
logger_mock.exception.assert_called_with(
"A downstream exception occurred while generating a report: None. Unexpected error",
exc_info=True,
)
db.session.delete(report_schedule) db.session.delete(report_schedule)
db.session.commit() db.session.commit()

View File

@ -84,10 +84,9 @@ from tests.integration_tests.test_app import app
def test_check_sqlalchemy_uri( def test_check_sqlalchemy_uri(
sqlalchemy_uri: str, error: bool, error_message: Optional[str] sqlalchemy_uri: str, error: bool, error_message: Optional[str]
): ):
with app.app_context(): if error:
if error: with pytest.raises(SupersetSecurityException) as excinfo:
with pytest.raises(SupersetSecurityException) as excinfo:
check_sqlalchemy_uri(make_url(sqlalchemy_uri))
assert str(excinfo.value) == error_message
else:
check_sqlalchemy_uri(make_url(sqlalchemy_uri)) check_sqlalchemy_uri(make_url(sqlalchemy_uri))
assert str(excinfo.value) == error_message
else:
check_sqlalchemy_uri(make_url(sqlalchemy_uri))

View File

@ -24,7 +24,7 @@ import pytest
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from flask import Flask from flask.ctx import AppContext
from pytest_mock import MockFixture from pytest_mock import MockFixture
from sqlalchemy.sql import text from sqlalchemy.sql import text
from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.elements import TextClause
@ -598,26 +598,25 @@ class TestDatabaseModel(SupersetTestCase):
db.session.commit() db.session.commit()
@pytest.fixture @pytest.fixture()
def text_column_table(): def text_column_table(app_context: AppContext):
with app.app_context(): table = SqlaTable(
table = SqlaTable( table_name="text_column_table",
table_name="text_column_table", sql=(
sql=( "SELECT 'foo' as foo "
"SELECT 'foo' as foo " "UNION SELECT '' "
"UNION SELECT '' " "UNION SELECT NULL "
"UNION SELECT NULL " "UNION SELECT 'null' "
"UNION SELECT 'null' " "UNION SELECT '\"text in double quotes\"' "
"UNION SELECT '\"text in double quotes\"' " "UNION SELECT '''text in single quotes''' "
"UNION SELECT '''text in single quotes''' " "UNION SELECT 'double quotes \" in text' "
"UNION SELECT 'double quotes \" in text' " "UNION SELECT 'single quotes '' in text' "
"UNION SELECT 'single quotes '' in text' " ),
), database=get_example_database(),
database=get_example_database(), )
) TableColumn(column_name="foo", type="VARCHAR(255)", table=table)
TableColumn(column_name="foo", type="VARCHAR(255)", table=table) SqlMetric(metric_name="count", expression="count(*)", table=table)
SqlMetric(metric_name="count", expression="count(*)", table=table) yield table
yield table
def test_values_for_column_on_text_column(text_column_table): def test_values_for_column_on_text_column(text_column_table):
@ -836,6 +835,7 @@ def test_none_operand_in_filter(login_as_admin, physical_dataset):
) )
@pytest.mark.usefixtures("app_context")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"row,dimension,result", "row,dimension,result",
[ [
@ -857,7 +857,6 @@ def test_none_operand_in_filter(login_as_admin, physical_dataset):
], ],
) )
def test__normalize_prequery_result_type( def test__normalize_prequery_result_type(
app_context: Flask,
mocker: MockFixture, mocker: MockFixture,
row: pd.Series, row: pd.Series,
dimension: str, dimension: str,
@ -927,7 +926,8 @@ def test__normalize_prequery_result_type(
assert normalized == result assert normalized == result
def test__temporal_range_operator_in_adhoc_filter(app_context, physical_dataset): @pytest.mark.usefixtures("app_context")
def test__temporal_range_operator_in_adhoc_filter(physical_dataset):
result = physical_dataset.query( result = physical_dataset.query(
{ {
"columns": ["col1", "col2"], "columns": ["col1", "col2"],

View File

@ -82,5 +82,4 @@ def test_form_data_to_adhoc_incorrect_clause_type():
form_data = {"where": "1 = 1", "having": "count(*) > 1"} form_data = {"where": "1 = 1", "having": "count(*) > 1"}
with pytest.raises(ValueError): with pytest.raises(ValueError):
with app.app_context(): form_data_to_adhoc(form_data, "foobar")
form_data_to_adhoc(form_data, "foobar")

View File

@ -685,20 +685,19 @@ class TestUtils(SupersetTestCase):
self.assertIsNotNone(parse_js_uri_path_item("item")) self.assertIsNotNone(parse_js_uri_path_item("item"))
def test_get_stacktrace(self): def test_get_stacktrace(self):
with app.app_context(): app.config["SHOW_STACKTRACE"] = True
app.config["SHOW_STACKTRACE"] = True try:
try: raise Exception("NONONO!")
raise Exception("NONONO!") except Exception:
except Exception: stacktrace = get_stacktrace()
stacktrace = get_stacktrace() self.assertIn("NONONO", stacktrace)
self.assertIn("NONONO", stacktrace)
app.config["SHOW_STACKTRACE"] = False app.config["SHOW_STACKTRACE"] = False
try: try:
raise Exception("NONONO!") raise Exception("NONONO!")
except Exception: except Exception:
stacktrace = get_stacktrace() stacktrace = get_stacktrace()
assert stacktrace is None assert stacktrace is None
def test_split(self): def test_split(self):
self.assertEqual(list(split("a b")), ["a", "b"]) self.assertEqual(list(split("a b")), ["a", "b"])
@ -839,9 +838,8 @@ class TestUtils(SupersetTestCase):
) )
def test_get_form_data_default(self) -> None: def test_get_form_data_default(self) -> None:
with app.test_request_context(): form_data, slc = get_form_data()
form_data, slc = get_form_data() self.assertEqual(slc, None)
self.assertEqual(slc, None)
def test_get_form_data_request_args(self) -> None: def test_get_form_data_request_args(self) -> None:
with app.test_request_context( with app.test_request_context(

View File

@ -411,72 +411,71 @@ def test_delete_ssh_tunnel(
""" """
Test that we can delete SSH Tunnel Test that we can delete SSH Tunnel
""" """
with app.app_context(): from superset.daos.database import DatabaseDAO
from superset.daos.database import DatabaseDAO from superset.databases.api import DatabaseRestApi
from superset.databases.api import DatabaseRestApi from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.ssh_tunnel.models import SSHTunnel from superset.models.core import Database
from superset.models.core import Database
DatabaseRestApi.datamodel.session = session DatabaseRestApi.datamodel.session = session
# create table for databases # create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
# Create our Database # Create our Database
database = Database( database = Database(
database_name="my_database", database_name="my_database",
sqlalchemy_uri="gsheets://", sqlalchemy_uri="gsheets://",
encrypted_extra=json.dumps( encrypted_extra=json.dumps(
{ {
"service_account_info": { "service_account_info": {
"type": "service_account", "type": "service_account",
"project_id": "black-sanctum-314419", "project_id": "black-sanctum-314419",
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173", "private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
"private_key": "SECRET", "private_key": "SECRET",
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com", "client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com",
"client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT", "client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT",
"auth_uri": "https://accounts.google.com/o/oauth2/auth", "auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token", "token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com", "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com",
}, },
} }
), ),
) )
db.session.add(database) db.session.add(database)
db.session.commit() db.session.commit()
# mock the lookup so that we don't need to include the driver # mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
mocker.patch("superset.utils.log.DBEventLogger.log") mocker.patch("superset.utils.log.DBEventLogger.log")
mocker.patch( mocker.patch(
"superset.commands.database.ssh_tunnel.delete.is_feature_enabled", "superset.commands.database.ssh_tunnel.delete.is_feature_enabled",
return_value=True, return_value=True,
) )
# Create our SSHTunnel # Create our SSHTunnel
tunnel = SSHTunnel( tunnel = SSHTunnel(
database_id=1, database_id=1,
database=database, database=database,
) )
db.session.add(tunnel) db.session.add(tunnel)
db.session.commit() db.session.commit()
# Get our recently created SSHTunnel # Get our recently created SSHTunnel
response_tunnel = DatabaseDAO.get_ssh_tunnel(1) response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
assert response_tunnel assert response_tunnel
assert isinstance(response_tunnel, SSHTunnel) assert isinstance(response_tunnel, SSHTunnel)
assert 1 == response_tunnel.database_id assert 1 == response_tunnel.database_id
# Delete the recently created SSHTunnel # Delete the recently created SSHTunnel
response_delete_tunnel = client.delete( response_delete_tunnel = client.delete(
f"/api/v1/database/{database.id}/ssh_tunnel/" f"/api/v1/database/{database.id}/ssh_tunnel/"
) )
assert response_delete_tunnel.json["message"] == "OK" assert response_delete_tunnel.json["message"] == "OK"
response_tunnel = DatabaseDAO.get_ssh_tunnel(1) response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
assert response_tunnel is None assert response_tunnel is None
def test_delete_ssh_tunnel_not_found( def test_delete_ssh_tunnel_not_found(
@ -489,70 +488,69 @@ def test_delete_ssh_tunnel_not_found(
""" """
Test that we cannot delete a tunnel that does not exist Test that we cannot delete a tunnel that does not exist
""" """
with app.app_context(): from superset.daos.database import DatabaseDAO
from superset.daos.database import DatabaseDAO from superset.databases.api import DatabaseRestApi
from superset.databases.api import DatabaseRestApi from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.ssh_tunnel.models import SSHTunnel from superset.models.core import Database
from superset.models.core import Database
DatabaseRestApi.datamodel.session = session DatabaseRestApi.datamodel.session = session
# create table for databases # create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
# Create our Database # Create our Database
database = Database( database = Database(
database_name="my_database", database_name="my_database",
sqlalchemy_uri="gsheets://", sqlalchemy_uri="gsheets://",
encrypted_extra=json.dumps( encrypted_extra=json.dumps(
{ {
"service_account_info": { "service_account_info": {
"type": "service_account", "type": "service_account",
"project_id": "black-sanctum-314419", "project_id": "black-sanctum-314419",
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173", "private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
"private_key": "SECRET", "private_key": "SECRET",
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com", "client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com",
"client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT", "client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT",
"auth_uri": "https://accounts.google.com/o/oauth2/auth", "auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token", "token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com", "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com",
}, },
} }
), ),
) )
db.session.add(database) db.session.add(database)
db.session.commit() db.session.commit()
# mock the lookup so that we don't need to include the driver # mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
mocker.patch("superset.utils.log.DBEventLogger.log") mocker.patch("superset.utils.log.DBEventLogger.log")
mocker.patch( mocker.patch(
"superset.commands.database.ssh_tunnel.delete.is_feature_enabled", "superset.commands.database.ssh_tunnel.delete.is_feature_enabled",
return_value=True, return_value=True,
) )
# Create our SSHTunnel # Create our SSHTunnel
tunnel = SSHTunnel( tunnel = SSHTunnel(
database_id=1, database_id=1,
database=database, database=database,
) )
db.session.add(tunnel) db.session.add(tunnel)
db.session.commit() db.session.commit()
# Delete the recently created SSHTunnel # Delete the recently created SSHTunnel
response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/") response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/")
assert response_delete_tunnel.json["message"] == "Not found" assert response_delete_tunnel.json["message"] == "Not found"
# Get our recently created SSHTunnel # Get our recently created SSHTunnel
response_tunnel = DatabaseDAO.get_ssh_tunnel(1) response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
assert response_tunnel assert response_tunnel
assert isinstance(response_tunnel, SSHTunnel) assert isinstance(response_tunnel, SSHTunnel)
assert 1 == response_tunnel.database_id assert 1 == response_tunnel.database_id
response_tunnel = DatabaseDAO.get_ssh_tunnel(2) response_tunnel = DatabaseDAO.get_ssh_tunnel(2)
assert response_tunnel is None assert response_tunnel is None
def test_apply_dynamic_database_filter( def test_apply_dynamic_database_filter(
@ -568,88 +566,87 @@ def test_apply_dynamic_database_filter(
defining a filter function and patching the config to get defining a filter function and patching the config to get
the filtered results. the filtered results.
""" """
with app.app_context(): from superset.daos.database import DatabaseDAO
from superset.daos.database import DatabaseDAO from superset.databases.api import DatabaseRestApi
from superset.databases.api import DatabaseRestApi from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.ssh_tunnel.models import SSHTunnel from superset.models.core import Database
DatabaseRestApi.datamodel.session = session
# create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
# Create our First Database
database = Database(
database_name="first-database",
sqlalchemy_uri="gsheets://",
encrypted_extra=json.dumps(
{
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": [],
}
),
)
db.session.add(database)
db.session.commit()
# Create our Second Database
database = Database(
database_name="second-database",
sqlalchemy_uri="gsheets://",
encrypted_extra=json.dumps(
{
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": [],
}
),
)
db.session.add(database)
db.session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
mocker.patch("superset.utils.log.DBEventLogger.log")
mocker.patch(
"superset.commands.database.ssh_tunnel.delete.is_feature_enabled",
return_value=False,
)
def _base_filter(query):
from superset.models.core import Database from superset.models.core import Database
DatabaseRestApi.datamodel.session = session return query.filter(Database.database_name.startswith("second"))
# create table for databases # Create a mock object
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member base_filter_mock = Mock(side_effect=_base_filter)
# Create our First Database # Get our recently created Databases
database = Database( response_databases = DatabaseDAO.find_all()
database_name="first-database", assert response_databases
sqlalchemy_uri="gsheets://", expected_db_names = ["first-database", "second-database"]
encrypted_extra=json.dumps( actual_db_names = [db.database_name for db in response_databases]
{ assert actual_db_names == expected_db_names
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": [],
}
),
)
db.session.add(database)
db.session.commit()
# Create our Second Database # Ensure that the filter has not been called because it's not in our config
database = Database( assert base_filter_mock.call_count == 0
database_name="second-database",
sqlalchemy_uri="gsheets://",
encrypted_extra=json.dumps(
{
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": [],
}
),
)
db.session.add(database)
db.session.commit()
# mock the lookup so that we don't need to include the driver original_config = current_app.config.copy()
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") original_config["EXTRA_DYNAMIC_QUERY_FILTERS"] = {"databases": base_filter_mock}
mocker.patch("superset.utils.log.DBEventLogger.log")
mocker.patch(
"superset.commands.database.ssh_tunnel.delete.is_feature_enabled",
return_value=False,
)
def _base_filter(query): mocker.patch("superset.views.filters.current_app.config", new=original_config)
from superset.models.core import Database # Get filtered list
response_databases = DatabaseDAO.find_all()
assert response_databases
expected_db_names = ["second-database"]
actual_db_names = [db.database_name for db in response_databases]
assert actual_db_names == expected_db_names
return query.filter(Database.database_name.startswith("second")) # Ensure that the filter has been called once
assert base_filter_mock.call_count == 1
# Create a mock object
base_filter_mock = Mock(side_effect=_base_filter)
# Get our recently created Databases
response_databases = DatabaseDAO.find_all()
assert response_databases
expected_db_names = ["first-database", "second-database"]
actual_db_names = [db.database_name for db in response_databases]
assert actual_db_names == expected_db_names
# Ensure that the filter has not been called because it's not in our config
assert base_filter_mock.call_count == 0
original_config = current_app.config.copy()
original_config["EXTRA_DYNAMIC_QUERY_FILTERS"] = {"databases": base_filter_mock}
mocker.patch("superset.views.filters.current_app.config", new=original_config)
# Get filtered list
response_databases = DatabaseDAO.find_all()
assert response_databases
expected_db_names = ["second-database"]
actual_db_names = [db.database_name for db in response_databases]
assert actual_db_names == expected_db_names
# Ensure that the filter has been called once
assert base_filter_mock.call_count == 1
def test_oauth2_happy_path( def test_oauth2_happy_path(