chore(tests): Remove unnecessary/problematic app contexts (#28159)

This commit is contained in:
John Bodley 2024-04-24 13:46:35 -07:00 committed by GitHub
parent a9075fdb1f
commit bc65c245fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 880 additions and 980 deletions

View File

@ -81,36 +81,6 @@ DB_ACCESS_ROLE = "db_access_role"
SCHEMA_ACCESS_ROLE = "schema_access_role"
class TestRequestAccess(SupersetTestCase):
@classmethod
def setUpClass(cls):
with app.app_context():
security_manager.add_role("override_me")
security_manager.add_role(TEST_ROLE_1)
security_manager.add_role(TEST_ROLE_2)
security_manager.add_role(DB_ACCESS_ROLE)
security_manager.add_role(SCHEMA_ACCESS_ROLE)
db.session.commit()
@classmethod
def tearDownClass(cls):
with app.app_context():
override_me = security_manager.find_role("override_me")
db.session.delete(override_me)
db.session.delete(security_manager.find_role(TEST_ROLE_1))
db.session.delete(security_manager.find_role(TEST_ROLE_2))
db.session.delete(security_manager.find_role(DB_ACCESS_ROLE))
db.session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE))
db.session.commit()
def tearDown(self):
override_me = security_manager.find_role("override_me")
override_me.permissions = []
db.session.commit()
db.session.close()
super().tearDown()
@pytest.mark.parametrize(
"username,user_id",
[

View File

@ -14,18 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
import pytest
from datetime import datetime
from typing import Optional
import pytest
from flask.ctx import AppContext
from superset import db
from superset.models.annotations import Annotation, AnnotationLayer
from tests.integration_tests.test_app import app
ANNOTATION_LAYERS_COUNT = 10
ANNOTATIONS_COUNT = 5
@ -70,36 +68,35 @@ def _insert_annotation(
@pytest.fixture()
def create_annotation_layers():
def create_annotation_layers(app_context: AppContext):
"""
Creates ANNOTATION_LAYERS_COUNT-1 layers with no annotations
and a final one with ANNOTATION_COUNT children
:return:
"""
with app.app_context():
annotation_layers = []
annotations = []
for cx in range(ANNOTATION_LAYERS_COUNT - 1):
annotation_layers.append(
_insert_annotation_layer(name=f"name{cx}", descr=f"descr{cx}")
annotation_layers = []
annotations = []
for cx in range(ANNOTATION_LAYERS_COUNT - 1):
annotation_layers.append(
_insert_annotation_layer(name=f"name{cx}", descr=f"descr{cx}")
)
layer_with_annotations = _insert_annotation_layer("layer_with_annotations")
annotation_layers.append(layer_with_annotations)
for cx in range(ANNOTATIONS_COUNT):
annotations.append(
_insert_annotation(
layer_with_annotations,
short_descr=f"short_descr{cx}",
long_descr=f"long_descr{cx}",
start_dttm=get_start_dttm(cx),
end_dttm=get_end_dttm(cx),
)
layer_with_annotations = _insert_annotation_layer("layer_with_annotations")
annotation_layers.append(layer_with_annotations)
for cx in range(ANNOTATIONS_COUNT):
annotations.append(
_insert_annotation(
layer_with_annotations,
short_descr=f"short_descr{cx}",
long_descr=f"long_descr{cx}",
start_dttm=get_start_dttm(cx),
end_dttm=get_end_dttm(cx),
)
)
yield annotation_layers
)
yield annotation_layers
# rollback changes
for annotation_layer in annotation_layers:
db.session.delete(annotation_layer)
for annotation in annotations:
db.session.delete(annotation)
db.session.commit()
# rollback changes
for annotation_layer in annotation_layers:
db.session.delete(annotation_layer)
for annotation in annotations:
db.session.delete(annotation)
db.session.commit()

View File

@ -23,6 +23,7 @@ from zipfile import is_zipfile, ZipFile
import prison
import pytest
import yaml
from flask.ctx import AppContext
from flask_babel import lazy_gettext as _
from parameterized import parameterized
from sqlalchemy import and_
@ -82,121 +83,115 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
resource_name = "chart"
@pytest.fixture(autouse=True)
def clear_data_cache(self):
with app.app_context():
cache_manager.data_cache.clear()
yield
def clear_data_cache(self, app_context: AppContext):
cache_manager.data_cache.clear()
yield
@pytest.fixture()
def create_charts(self):
with self.create_app().app_context():
charts = []
admin = self.get_user("admin")
for cx in range(CHARTS_FIXTURE_COUNT - 1):
charts.append(self.insert_chart(f"name{cx}", [admin.id], 1))
fav_charts = []
for cx in range(round(CHARTS_FIXTURE_COUNT / 2)):
fav_star = FavStar(
user_id=admin.id, class_name="slice", obj_id=charts[cx].id
)
db.session.add(fav_star)
db.session.commit()
fav_charts.append(fav_star)
yield charts
# rollback changes
for chart in charts:
db.session.delete(chart)
for fav_chart in fav_charts:
db.session.delete(fav_chart)
charts = []
admin = self.get_user("admin")
for cx in range(CHARTS_FIXTURE_COUNT - 1):
charts.append(self.insert_chart(f"name{cx}", [admin.id], 1))
fav_charts = []
for cx in range(round(CHARTS_FIXTURE_COUNT / 2)):
fav_star = FavStar(
user_id=admin.id, class_name="slice", obj_id=charts[cx].id
)
db.session.add(fav_star)
db.session.commit()
fav_charts.append(fav_star)
yield charts
# rollback changes
for chart in charts:
db.session.delete(chart)
for fav_chart in fav_charts:
db.session.delete(fav_chart)
db.session.commit()
@pytest.fixture()
def create_charts_created_by_gamma(self):
with self.create_app().app_context():
charts = []
user = self.get_user("gamma")
for cx in range(CHARTS_FIXTURE_COUNT - 1):
charts.append(self.insert_chart(f"gamma{cx}", [user.id], 1))
yield charts
# rollback changes
for chart in charts:
db.session.delete(chart)
db.session.commit()
charts = []
user = self.get_user("gamma")
for cx in range(CHARTS_FIXTURE_COUNT - 1):
charts.append(self.insert_chart(f"gamma{cx}", [user.id], 1))
yield charts
# rollback changes
for chart in charts:
db.session.delete(chart)
db.session.commit()
@pytest.fixture()
def create_certified_charts(self):
with self.create_app().app_context():
certified_charts = []
admin = self.get_user("admin")
for cx in range(CHARTS_FIXTURE_COUNT):
certified_charts.append(
self.insert_chart(
f"certified{cx}",
[admin.id],
1,
certified_by="John Doe",
certification_details="Sample certification",
)
certified_charts = []
admin = self.get_user("admin")
for cx in range(CHARTS_FIXTURE_COUNT):
certified_charts.append(
self.insert_chart(
f"certified{cx}",
[admin.id],
1,
certified_by="John Doe",
certification_details="Sample certification",
)
)
yield certified_charts
yield certified_charts
# rollback changes
for chart in certified_charts:
db.session.delete(chart)
db.session.commit()
# rollback changes
for chart in certified_charts:
db.session.delete(chart)
db.session.commit()
@pytest.fixture()
def create_chart_with_report(self):
with self.create_app().app_context():
admin = self.get_user("admin")
chart = self.insert_chart(f"chart_report", [admin.id], 1)
report_schedule = ReportSchedule(
type=ReportScheduleType.REPORT,
name="report_with_chart",
crontab="* * * * *",
chart=chart,
)
db.session.commit()
admin = self.get_user("admin")
chart = self.insert_chart(f"chart_report", [admin.id], 1)
report_schedule = ReportSchedule(
type=ReportScheduleType.REPORT,
name="report_with_chart",
crontab="* * * * *",
chart=chart,
)
db.session.commit()
yield chart
yield chart
# rollback changes
db.session.delete(report_schedule)
db.session.delete(chart)
db.session.commit()
# rollback changes
db.session.delete(report_schedule)
db.session.delete(chart)
db.session.commit()
@pytest.fixture()
def add_dashboard_to_chart(self):
with self.create_app().app_context():
admin = self.get_user("admin")
admin = self.get_user("admin")
self.chart = self.insert_chart("My chart", [admin.id], 1)
self.chart = self.insert_chart("My chart", [admin.id], 1)
self.original_dashboard = Dashboard()
self.original_dashboard.dashboard_title = "Original Dashboard"
self.original_dashboard.slug = "slug"
self.original_dashboard.owners = [admin]
self.original_dashboard.slices = [self.chart]
self.original_dashboard.published = False
db.session.add(self.original_dashboard)
self.original_dashboard = Dashboard()
self.original_dashboard.dashboard_title = "Original Dashboard"
self.original_dashboard.slug = "slug"
self.original_dashboard.owners = [admin]
self.original_dashboard.slices = [self.chart]
self.original_dashboard.published = False
db.session.add(self.original_dashboard)
self.new_dashboard = Dashboard()
self.new_dashboard.dashboard_title = "New Dashboard"
self.new_dashboard.slug = "new_slug"
self.new_dashboard.owners = [admin]
self.new_dashboard.published = False
db.session.add(self.new_dashboard)
self.new_dashboard = Dashboard()
self.new_dashboard.dashboard_title = "New Dashboard"
self.new_dashboard.slug = "new_slug"
self.new_dashboard.owners = [admin]
self.new_dashboard.published = False
db.session.add(self.new_dashboard)
db.session.commit()
db.session.commit()
yield self.chart
yield self.chart
db.session.delete(self.original_dashboard)
db.session.delete(self.new_dashboard)
db.session.delete(self.chart)
db.session.commit()
db.session.delete(self.original_dashboard)
db.session.delete(self.new_dashboard)
db.session.delete(self.chart)
db.session.commit()
def test_info_security_chart(self):
"""
@ -1127,40 +1122,39 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
@pytest.fixture()
def load_energy_charts(self):
with app.app_context():
admin = self.get_user("admin")
energy_table = (
db.session.query(SqlaTable)
.filter_by(table_name="energy_usage")
.one_or_none()
)
energy_table_id = 1
if energy_table:
energy_table_id = energy_table.id
chart1 = self.insert_chart(
"foo_a", [admin.id], energy_table_id, description="ZY_bar"
)
chart2 = self.insert_chart(
"zy_foo", [admin.id], energy_table_id, description="desc1"
)
chart3 = self.insert_chart(
"foo_b", [admin.id], energy_table_id, description="desc1zy_"
)
chart4 = self.insert_chart(
"foo_c", [admin.id], energy_table_id, viz_type="viz_zy_"
)
chart5 = self.insert_chart(
"bar", [admin.id], energy_table_id, description="foo"
)
admin = self.get_user("admin")
energy_table = (
db.session.query(SqlaTable)
.filter_by(table_name="energy_usage")
.one_or_none()
)
energy_table_id = 1
if energy_table:
energy_table_id = energy_table.id
chart1 = self.insert_chart(
"foo_a", [admin.id], energy_table_id, description="ZY_bar"
)
chart2 = self.insert_chart(
"zy_foo", [admin.id], energy_table_id, description="desc1"
)
chart3 = self.insert_chart(
"foo_b", [admin.id], energy_table_id, description="desc1zy_"
)
chart4 = self.insert_chart(
"foo_c", [admin.id], energy_table_id, viz_type="viz_zy_"
)
chart5 = self.insert_chart(
"bar", [admin.id], energy_table_id, description="foo"
)
yield
# rollback changes
db.session.delete(chart1)
db.session.delete(chart2)
db.session.delete(chart3)
db.session.delete(chart4)
db.session.delete(chart5)
db.session.commit()
yield
# rollback changes
db.session.delete(chart1)
db.session.delete(chart2)
db.session.delete(chart3)
db.session.delete(chart4)
db.session.delete(chart5)
db.session.commit()
@pytest.mark.usefixtures("load_energy_charts")
def test_get_charts_custom_filter(self):

View File

@ -27,6 +27,7 @@ from unittest import mock
from zipfile import ZipFile
from flask import Response
from flask.ctx import AppContext
from tests.integration_tests.conftest import with_feature_flags
from superset.charts.data.api import ChartDataRestApi
from superset.models.sql_lab import Query
@ -88,10 +89,9 @@ INCOMPATIBLE_ADHOC_COLUMN_FIXTURE: AdhocColumn = {
@pytest.fixture(autouse=True)
def skip_by_backend():
with app.app_context():
if backend() == "hive":
pytest.skip("Skipping tests for Hive backend")
def skip_by_backend(app_context: AppContext):
if backend() == "hive":
pytest.skip("Skipping tests for Hive backend")
class BaseTestChartDataApi(SupersetTestCase):

View File

@ -118,8 +118,8 @@ def get_or_create_user(get_user, create_user) -> ab_models.User:
@pytest.fixture(autouse=True, scope="session")
def setup_sample_data() -> Any:
# TODO(john-bodley): Determine a cleaner way of setting up the sample data without
# relying on `tests.integration_tests.test_app.app` leveraging an `app` fixture which is purposely
# scoped to the function level to ensure tests remain idempotent.
# relying on `tests.integration_tests.test_app.app` leveraging an `app` fixture
# which is purposely scoped to the function level to ensure tests remain idempotent.
with app.app_context():
setup_presto_if_needed()
@ -135,7 +135,6 @@ def setup_sample_data() -> Any:
with app.app_context():
# drop sqlalchemy tables
db.session.commit()
from sqlalchemy.ext import declarative
@ -163,12 +162,12 @@ def example_db_provider() -> Callable[[], Database]: # type: ignore
_db: Database | None = None
def __call__(self) -> Database:
with app.app_context():
if self._db is None:
if self._db is None:
with app.app_context():
self._db = get_example_database()
self._load_lazy_data_to_decouple_from_session()
return self._db
return self._db
def _load_lazy_data_to_decouple_from_session(self) -> None:
self._db._get_sqla_engine() # type: ignore

View File

@ -58,39 +58,36 @@ from .base_tests import SupersetTestCase
class TestDashboard(SupersetTestCase):
@pytest.fixture
def load_dashboard(self):
with app.app_context():
table = (
db.session.query(SqlaTable).filter_by(table_name="energy_usage").one()
)
# get a slice from the allowed table
slice = db.session.query(Slice).filter_by(slice_name="Energy Sankey").one()
table = db.session.query(SqlaTable).filter_by(table_name="energy_usage").one()
# get a slice from the allowed table
slice = db.session.query(Slice).filter_by(slice_name="Energy Sankey").one()
self.grant_public_access_to_table(table)
self.grant_public_access_to_table(table)
pytest.hidden_dash_slug = f"hidden_dash_{random()}"
pytest.published_dash_slug = f"published_dash_{random()}"
pytest.hidden_dash_slug = f"hidden_dash_{random()}"
pytest.published_dash_slug = f"published_dash_{random()}"
# Create a published and hidden dashboard and add them to the database
published_dash = Dashboard()
published_dash.dashboard_title = "Published Dashboard"
published_dash.slug = pytest.published_dash_slug
published_dash.slices = [slice]
published_dash.published = True
# Create a published and hidden dashboard and add them to the database
published_dash = Dashboard()
published_dash.dashboard_title = "Published Dashboard"
published_dash.slug = pytest.published_dash_slug
published_dash.slices = [slice]
published_dash.published = True
hidden_dash = Dashboard()
hidden_dash.dashboard_title = "Hidden Dashboard"
hidden_dash.slug = pytest.hidden_dash_slug
hidden_dash.slices = [slice]
hidden_dash.published = False
hidden_dash = Dashboard()
hidden_dash.dashboard_title = "Hidden Dashboard"
hidden_dash.slug = pytest.hidden_dash_slug
hidden_dash.slices = [slice]
hidden_dash.published = False
db.session.add(published_dash)
db.session.add(hidden_dash)
yield db.session.commit()
db.session.add(published_dash)
db.session.add(hidden_dash)
yield db.session.commit()
self.revoke_public_access_to_table(table)
db.session.delete(published_dash)
db.session.delete(hidden_dash)
db.session.commit()
self.revoke_public_access_to_table(table)
db.session.delete(published_dash)
db.session.delete(hidden_dash)
db.session.commit()
def get_mock_positions(self, dash):
positions = {"DASHBOARD_VERSION_KEY": "v2"}

View File

@ -2088,8 +2088,6 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
self.assertNotEqual(result["uuid"], "")
self.assertEqual(result["allowed_domains"], allowed_domains)
db.session.expire_all()
# get returns value
resp = self.get_assert_metric(uri, "get_embedded")
self.assertEqual(resp.status_code, 200)
@ -2110,8 +2108,6 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
self.assertNotEqual(result["uuid"], "")
self.assertEqual(result["allowed_domains"], [])
db.session.expire_all()
# get returns changed value
resp = self.get_assert_metric(uri, "get_embedded")
self.assertEqual(resp.status_code, 200)
@ -2123,8 +2119,6 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
resp = self.delete_assert_metric(uri, "delete_embedded")
self.assertEqual(resp.status_code, 200)
db.session.expire_all()
# get returns 404
resp = self.get_assert_metric(uri, "get_embedded")
self.assertEqual(resp.status_code, 404)

View File

@ -37,39 +37,36 @@ from tests.integration_tests.fixtures.energy_dashboard import (
class TestDashboardDatasetSecurity(DashboardTestCase):
@pytest.fixture
def load_dashboard(self):
with app.app_context():
table = (
db.session.query(SqlaTable).filter_by(table_name="energy_usage").one()
)
# get a slice from the allowed table
slice = db.session.query(Slice).filter_by(slice_name="Energy Sankey").one()
table = db.session.query(SqlaTable).filter_by(table_name="energy_usage").one()
# get a slice from the allowed table
slice = db.session.query(Slice).filter_by(slice_name="Energy Sankey").one()
self.grant_public_access_to_table(table)
self.grant_public_access_to_table(table)
pytest.hidden_dash_slug = f"hidden_dash_{random_slug()}"
pytest.published_dash_slug = f"published_dash_{random_slug()}"
pytest.hidden_dash_slug = f"hidden_dash_{random_slug()}"
pytest.published_dash_slug = f"published_dash_{random_slug()}"
# Create a published and hidden dashboard and add them to the database
published_dash = Dashboard()
published_dash.dashboard_title = "Published Dashboard"
published_dash.slug = pytest.published_dash_slug
published_dash.slices = [slice]
published_dash.published = True
# Create a published and hidden dashboard and add them to the database
published_dash = Dashboard()
published_dash.dashboard_title = "Published Dashboard"
published_dash.slug = pytest.published_dash_slug
published_dash.slices = [slice]
published_dash.published = True
hidden_dash = Dashboard()
hidden_dash.dashboard_title = "Hidden Dashboard"
hidden_dash.slug = pytest.hidden_dash_slug
hidden_dash.slices = [slice]
hidden_dash.published = False
hidden_dash = Dashboard()
hidden_dash.dashboard_title = "Hidden Dashboard"
hidden_dash.slug = pytest.hidden_dash_slug
hidden_dash.slices = [slice]
hidden_dash.published = False
db.session.add(published_dash)
db.session.add(hidden_dash)
yield db.session.commit()
db.session.add(published_dash)
db.session.add(hidden_dash)
yield db.session.commit()
self.revoke_public_access_to_table(table)
db.session.delete(published_dash)
db.session.delete(hidden_dash)
db.session.commit()
self.revoke_public_access_to_table(table)
db.session.delete(published_dash)
db.session.delete(hidden_dash)
db.session.commit()
def test_dashboard_access__admin_can_access_all(self):
# arrange

View File

@ -20,6 +20,7 @@ from __future__ import annotations
import json
import pytest
from flask.ctx import AppContext
from superset import db, security_manager
from superset.commands.database.exceptions import (
@ -84,16 +85,14 @@ def get_upload_db():
return db.session.query(Database).filter_by(database_name=CSV_UPLOAD_DATABASE).one()
@pytest.fixture(scope="function")
def setup_csv_upload_with_context():
with app.app_context():
yield from _setup_csv_upload()
@pytest.fixture()
def setup_csv_upload_with_context(app_context: AppContext):
yield from _setup_csv_upload()
@pytest.fixture(scope="function")
def setup_csv_upload_with_context_schema():
with app.app_context():
yield from _setup_csv_upload(["public"])
@pytest.fixture()
def setup_csv_upload_with_context_schema(app_context: AppContext):
yield from _setup_csv_upload(["public"])
@pytest.mark.usefixtures("setup_csv_upload_with_context")

View File

@ -46,6 +46,5 @@ class TestEmbeddedDashboardDAO(SupersetTestCase):
def test_get_by_uuid(self):
dash = db.session.query(Dashboard).filter_by(slug="world_health").first()
uuid = str(EmbeddedDashboardDAO.upsert(dash, ["test.example.com"]).uuid)
db.session.expire_all()
embedded = EmbeddedDashboardDAO.find_by_id(uuid)
self.assertIsNotNone(embedded)

View File

@ -173,39 +173,38 @@ def get_datasource_post() -> dict[str, Any]:
@pytest.fixture()
@pytest.mark.usefixtures("app_conntext")
def load_dataset_with_columns() -> Generator[SqlaTable, None, None]:
with app.app_context():
engine = create_engine(app.config["SQLALCHEMY_DATABASE_URI"], echo=True)
meta = MetaData()
engine = create_engine(app.config["SQLALCHEMY_DATABASE_URI"], echo=True)
meta = MetaData()
students = Table(
"students",
meta,
Column("id", Integer, primary_key=True),
Column("name", String(255)),
Column("lastname", String(255)),
Column("ds", Date),
)
meta.create_all(engine)
students = Table(
"students",
meta,
Column("id", Integer, primary_key=True),
Column("name", String(255)),
Column("lastname", String(255)),
Column("ds", Date),
)
meta.create_all(engine)
students.insert().values(name="George", ds="2021-01-01")
students.insert().values(name="George", ds="2021-01-01")
dataset = SqlaTable(
database_id=db.session.query(Database).first().id, table_name="students"
)
column = TableColumn(table_id=dataset.id, column_name="name")
dataset.columns = [column]
db.session.add(dataset)
db.session.commit()
yield dataset
# cleanup
students_table = meta.tables.get("students")
if students_table is not None:
base = declarative_base()
# needed for sqlite
db.session.commit()
base.metadata.drop_all(engine, [students_table], checkfirst=True)
db.session.delete(dataset)
db.session.delete(column)
dataset = SqlaTable(
database_id=db.session.query(Database).first().id, table_name="students"
)
column = TableColumn(table_id=dataset.id, column_name="name")
dataset.columns = [column]
db.session.add(dataset)
db.session.commit()
yield dataset
# cleanup
if (students_table := meta.tables.get("students")) is not None:
base = declarative_base()
# needed for sqlite
db.session.commit()
base.metadata.drop_all(engine, [students_table], checkfirst=True)
db.session.delete(dataset)
db.session.delete(column)
db.session.commit()

View File

@ -15,30 +15,29 @@
# specific language governing permissions and limitations
# under the License.
import pytest
from flask.ctx import AppContext
from superset.extensions import db, security_manager
from tests.integration_tests.test_app import app
@pytest.fixture()
def public_role_like_gamma():
with app.app_context():
app.config["PUBLIC_ROLE_LIKE"] = "Gamma"
security_manager.sync_role_definitions()
def public_role_like_gamma(app_context: AppContext):
app.config["PUBLIC_ROLE_LIKE"] = "Gamma"
security_manager.sync_role_definitions()
yield
yield
security_manager.get_public_role().permissions = []
db.session.commit()
security_manager.get_public_role().permissions = []
db.session.commit()
@pytest.fixture()
def public_role_like_test_role():
with app.app_context():
app.config["PUBLIC_ROLE_LIKE"] = "TestRole"
security_manager.sync_role_definitions()
def public_role_like_test_role(app_context: AppContext):
app.config["PUBLIC_ROLE_LIKE"] = "TestRole"
security_manager.sync_role_definitions()
yield
yield
security_manager.get_public_role().permissions = []
db.session.commit()
security_manager.get_public_role().permissions = []
db.session.commit()

View File

@ -22,12 +22,12 @@ from tests.integration_tests.test_app import app
@pytest.fixture
@pytest.mark.usefixtures("app_context")
def with_tagging_system_feature():
with app.app_context():
is_enabled = app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"]
if not is_enabled:
app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] = True
register_sqla_event_listeners()
yield
app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] = False
clear_sqla_event_listeners()
is_enabled = app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"]
if not is_enabled:
app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] = True
register_sqla_event_listeners()
yield
app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] = False
clear_sqla_event_listeners()

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import pytest
from flask.ctx import AppContext
from flask_appbuilder.security.sqla.models import Role, User
from superset import db, security_manager
@ -23,27 +24,24 @@ from tests.integration_tests.test_app import app
@pytest.fixture()
def create_gamma_sqllab_no_data():
with app.app_context():
gamma_role = db.session.query(Role).filter(Role.name == "Gamma").one_or_none()
sqllab_role = (
db.session.query(Role).filter(Role.name == "sql_lab").one_or_none()
)
def create_gamma_sqllab_no_data(app_context: AppContext):
gamma_role = db.session.query(Role).filter(Role.name == "Gamma").one_or_none()
sqllab_role = db.session.query(Role).filter(Role.name == "sql_lab").one_or_none()
security_manager.add_user(
GAMMA_SQLLAB_NO_DATA_USERNAME,
"gamma_sqllab_no_data",
"gamma_sqllab_no_data",
"gamma_sqllab_no_data@apache.org",
[gamma_role, sqllab_role],
password="general",
)
security_manager.add_user(
GAMMA_SQLLAB_NO_DATA_USERNAME,
"gamma_sqllab_no_data",
"gamma_sqllab_no_data",
"gamma_sqllab_no_data@apache.org",
[gamma_role, sqllab_role],
password="general",
)
yield
user = (
db.session.query(User)
.filter(User.username == GAMMA_SQLLAB_NO_DATA_USERNAME)
.one_or_none()
)
db.session.delete(user)
db.session.commit()
yield
user = (
db.session.query(User)
.filter(User.username == GAMMA_SQLLAB_NO_DATA_USERNAME)
.one_or_none()
)
db.session.delete(user)
db.session.commit()

View File

@ -93,13 +93,12 @@ def load_world_bank_dashboard_with_slices_class_scope(load_world_bank_data):
def create_dashboard_for_loaded_data():
with app.app_context():
table = create_table_metadata(WB_HEALTH_POPULATION, get_example_database())
slices = _create_world_bank_slices(table)
dash = _create_world_bank_dashboard(table)
slices_ids_to_delete = [slice.id for slice in slices]
dash_id_to_delete = dash.id
return dash_id_to_delete, slices_ids_to_delete
table = create_table_metadata(WB_HEALTH_POPULATION, get_example_database())
slices = _create_world_bank_slices(table)
dash = _create_world_bank_dashboard(table)
slices_ids_to_delete = [slice.id for slice in slices]
dash_id_to_delete = dash.id
return dash_id_to_delete, slices_ids_to_delete
def _create_world_bank_slices(table: SqlaTable) -> list[Slice]:

View File

@ -16,6 +16,8 @@
# under the License.
from importlib import import_module
import pytest
from superset import db
from superset.migrations.shared.security_converge import (
_find_pvm,
@ -34,28 +36,28 @@ upgrade = migration_module.do_upgrade
downgrade = migration_module.do_downgrade
@pytest.mark.usefixtures("app_context")
def test_migration_upgrade():
with app.app_context():
pre_perm = PermissionView(
permission=Permission(name="can_view_and_drill"),
view_menu=db.session.query(ViewMenu).filter_by(name="Dashboard").one(),
)
db.session.add(pre_perm)
db.session.commit()
pre_perm = PermissionView(
permission=Permission(name="can_view_and_drill"),
view_menu=db.session.query(ViewMenu).filter_by(name="Dashboard").one(),
)
db.session.add(pre_perm)
db.session.commit()
assert _find_pvm(db.session, "Dashboard", "can_view_and_drill") is not None
assert _find_pvm(db.session, "Dashboard", "can_view_and_drill") is not None
upgrade(db.session)
upgrade(db.session)
assert _find_pvm(db.session, "Dashboard", "can_view_chart_as_table") is not None
assert _find_pvm(db.session, "Dashboard", "can_view_query") is not None
assert _find_pvm(db.session, "Dashboard", "can_view_and_drill") is None
assert _find_pvm(db.session, "Dashboard", "can_view_chart_as_table") is not None
assert _find_pvm(db.session, "Dashboard", "can_view_query") is not None
assert _find_pvm(db.session, "Dashboard", "can_view_and_drill") is None
@pytest.mark.usefixtures("app_context")
def test_migration_downgrade():
with app.app_context():
downgrade(db.session)
downgrade(db.session)
assert _find_pvm(db.session, "Dashboard", "can_view_chart_as_table") is None
assert _find_pvm(db.session, "Dashboard", "can_view_query") is None
assert _find_pvm(db.session, "Dashboard", "can_view_and_drill") is not None
assert _find_pvm(db.session, "Dashboard", "can_view_chart_as_table") is None
assert _find_pvm(db.session, "Dashboard", "can_view_query") is None
assert _find_pvm(db.session, "Dashboard", "can_view_and_drill") is not None

View File

@ -20,6 +20,7 @@ from typing import Optional, Union
import pandas as pd
import pytest
from flask.ctx import AppContext
from pytest_mock import MockFixture
from superset.commands.report.exceptions import AlertQueryError
@ -61,43 +62,40 @@ def test_execute_query_as_report_executor(
config: list[ExecutorType],
expected_result: Union[tuple[ExecutorType, str], Exception],
mocker: MockFixture,
app_context: None,
app_context: AppContext,
get_user,
) -> None:
from superset.commands.report.alert import AlertCommand
from superset.reports.models import ReportSchedule
with app.app_context():
original_config = app.config["ALERT_REPORTS_EXECUTE_AS"]
app.config["ALERT_REPORTS_EXECUTE_AS"] = config
owners = [get_user(owner_name) for owner_name in owner_names]
report_schedule = ReportSchedule(
created_by=get_user(creator_name) if creator_name else None,
owners=owners,
type=ReportScheduleType.ALERT,
description="description",
crontab="0 9 * * *",
creation_method=ReportCreationMethod.ALERTS_REPORTS,
sql="SELECT 1",
grace_period=14400,
working_timeout=3600,
database=get_example_database(),
validator_config_json='{"op": "==", "threshold": 1}',
)
command = AlertCommand(report_schedule=report_schedule)
override_user_mock = mocker.patch(
"superset.commands.report.alert.override_user"
)
cm = (
pytest.raises(type(expected_result))
if isinstance(expected_result, Exception)
else nullcontext()
)
with cm:
command.run()
assert override_user_mock.call_args[0][0].username == expected_result
original_config = app.config["ALERT_REPORTS_EXECUTE_AS"]
app.config["ALERT_REPORTS_EXECUTE_AS"] = config
owners = [get_user(owner_name) for owner_name in owner_names]
report_schedule = ReportSchedule(
created_by=get_user(creator_name) if creator_name else None,
owners=owners,
type=ReportScheduleType.ALERT,
description="description",
crontab="0 9 * * *",
creation_method=ReportCreationMethod.ALERTS_REPORTS,
sql="SELECT 1",
grace_period=14400,
working_timeout=3600,
database=get_example_database(),
validator_config_json='{"op": "==", "threshold": 1}',
)
command = AlertCommand(report_schedule=report_schedule)
override_user_mock = mocker.patch("superset.commands.report.alert.override_user")
cm = (
pytest.raises(type(expected_result))
if isinstance(expected_result, Exception)
else nullcontext()
)
with cm:
command.run()
assert override_user_mock.call_args[0][0].username == expected_result
app.config["ALERT_REPORTS_EXECUTE_AS"] = original_config
app.config["ALERT_REPORTS_EXECUTE_AS"] = original_config
def test_execute_query_succeeded_no_retry(

View File

@ -23,6 +23,7 @@ from uuid import uuid4
import pytest
from flask import current_app
from flask.ctx import AppContext
from flask_appbuilder.security.sqla.models import User
from flask_sqlalchemy import BaseQuery
from freezegun import freeze_time
@ -162,204 +163,191 @@ def create_test_table_context(database: Database):
@pytest.fixture()
def create_report_email_chart():
with app.app_context():
chart = db.session.query(Slice).first()
report_schedule = create_report_notification(
email_target="target@email.com", chart=chart
)
yield report_schedule
chart = db.session.query(Slice).first()
report_schedule = create_report_notification(
email_target="target@email.com", chart=chart
)
yield report_schedule
cleanup_report_schedule(report_schedule)
cleanup_report_schedule(report_schedule)
@pytest.fixture()
def create_report_email_chart_alpha_owner(get_user):
with app.app_context():
owners = [get_user("alpha")]
chart = db.session.query(Slice).first()
report_schedule = create_report_notification(
email_target="target@email.com", chart=chart, owners=owners
)
yield report_schedule
owners = [get_user("alpha")]
chart = db.session.query(Slice).first()
report_schedule = create_report_notification(
email_target="target@email.com", chart=chart, owners=owners
)
yield report_schedule
cleanup_report_schedule(report_schedule)
cleanup_report_schedule(report_schedule)
@pytest.fixture()
def create_report_email_chart_force_screenshot():
with app.app_context():
chart = db.session.query(Slice).first()
report_schedule = create_report_notification(
email_target="target@email.com", chart=chart, force_screenshot=True
)
yield report_schedule
chart = db.session.query(Slice).first()
report_schedule = create_report_notification(
email_target="target@email.com", chart=chart, force_screenshot=True
)
yield report_schedule
cleanup_report_schedule(report_schedule)
cleanup_report_schedule(report_schedule)
@pytest.fixture()
def create_report_email_chart_with_csv():
with app.app_context():
chart = db.session.query(Slice).first()
chart.query_context = '{"mock": "query_context"}'
report_schedule = create_report_notification(
email_target="target@email.com",
chart=chart,
report_format=ReportDataFormat.CSV,
)
yield report_schedule
cleanup_report_schedule(report_schedule)
chart = db.session.query(Slice).first()
chart.query_context = '{"mock": "query_context"}'
report_schedule = create_report_notification(
email_target="target@email.com",
chart=chart,
report_format=ReportDataFormat.CSV,
)
yield report_schedule
cleanup_report_schedule(report_schedule)
@pytest.fixture()
def create_report_email_chart_with_text():
with app.app_context():
chart = db.session.query(Slice).first()
chart.query_context = '{"mock": "query_context"}'
report_schedule = create_report_notification(
email_target="target@email.com",
chart=chart,
report_format=ReportDataFormat.TEXT,
)
yield report_schedule
cleanup_report_schedule(report_schedule)
chart = db.session.query(Slice).first()
chart.query_context = '{"mock": "query_context"}'
report_schedule = create_report_notification(
email_target="target@email.com",
chart=chart,
report_format=ReportDataFormat.TEXT,
)
yield report_schedule
cleanup_report_schedule(report_schedule)
@pytest.fixture()
def create_report_email_chart_with_csv_no_query_context():
with app.app_context():
chart = db.session.query(Slice).first()
chart.query_context = None
report_schedule = create_report_notification(
email_target="target@email.com",
chart=chart,
report_format=ReportDataFormat.CSV,
name="report_csv_no_query_context",
)
yield report_schedule
cleanup_report_schedule(report_schedule)
chart = db.session.query(Slice).first()
chart.query_context = None
report_schedule = create_report_notification(
email_target="target@email.com",
chart=chart,
report_format=ReportDataFormat.CSV,
name="report_csv_no_query_context",
)
yield report_schedule
cleanup_report_schedule(report_schedule)
@pytest.fixture()
def create_report_email_dashboard():
with app.app_context():
dashboard = db.session.query(Dashboard).first()
report_schedule = create_report_notification(
email_target="target@email.com", dashboard=dashboard
)
yield report_schedule
dashboard = db.session.query(Dashboard).first()
report_schedule = create_report_notification(
email_target="target@email.com", dashboard=dashboard
)
yield report_schedule
cleanup_report_schedule(report_schedule)
cleanup_report_schedule(report_schedule)
@pytest.fixture()
def create_report_email_dashboard_force_screenshot():
with app.app_context():
dashboard = db.session.query(Dashboard).first()
report_schedule = create_report_notification(
email_target="target@email.com", dashboard=dashboard, force_screenshot=True
)
yield report_schedule
dashboard = db.session.query(Dashboard).first()
report_schedule = create_report_notification(
email_target="target@email.com", dashboard=dashboard, force_screenshot=True
)
yield report_schedule
cleanup_report_schedule(report_schedule)
cleanup_report_schedule(report_schedule)
@pytest.fixture()
def create_report_slack_chart():
with app.app_context():
chart = db.session.query(Slice).first()
report_schedule = create_report_notification(
slack_channel="slack_channel", chart=chart
)
yield report_schedule
chart = db.session.query(Slice).first()
report_schedule = create_report_notification(
slack_channel="slack_channel", chart=chart
)
yield report_schedule
cleanup_report_schedule(report_schedule)
cleanup_report_schedule(report_schedule)
@pytest.fixture()
def create_report_slack_chart_with_csv():
with app.app_context():
chart = db.session.query(Slice).first()
chart.query_context = '{"mock": "query_context"}'
report_schedule = create_report_notification(
slack_channel="slack_channel",
chart=chart,
report_format=ReportDataFormat.CSV,
)
yield report_schedule
chart = db.session.query(Slice).first()
chart.query_context = '{"mock": "query_context"}'
report_schedule = create_report_notification(
slack_channel="slack_channel",
chart=chart,
report_format=ReportDataFormat.CSV,
)
yield report_schedule
cleanup_report_schedule(report_schedule)
cleanup_report_schedule(report_schedule)
@pytest.fixture()
def create_report_slack_chart_with_text():
with app.app_context():
chart = db.session.query(Slice).first()
chart.query_context = '{"mock": "query_context"}'
report_schedule = create_report_notification(
slack_channel="slack_channel",
chart=chart,
report_format=ReportDataFormat.TEXT,
)
yield report_schedule
chart = db.session.query(Slice).first()
chart.query_context = '{"mock": "query_context"}'
report_schedule = create_report_notification(
slack_channel="slack_channel",
chart=chart,
report_format=ReportDataFormat.TEXT,
)
yield report_schedule
cleanup_report_schedule(report_schedule)
cleanup_report_schedule(report_schedule)
@pytest.fixture()
def create_report_slack_chart_working():
with app.app_context():
chart = db.session.query(Slice).first()
report_schedule = create_report_notification(
slack_channel="slack_channel", chart=chart
)
report_schedule.last_state = ReportState.WORKING
report_schedule.last_eval_dttm = datetime(2020, 1, 1, 0, 0)
report_schedule.last_value = None
report_schedule.last_value_row_json = None
db.session.commit()
log = ReportExecutionLog(
scheduled_dttm=report_schedule.last_eval_dttm,
start_dttm=report_schedule.last_eval_dttm,
end_dttm=report_schedule.last_eval_dttm,
value=report_schedule.last_value,
value_row_json=report_schedule.last_value_row_json,
state=ReportState.WORKING,
report_schedule=report_schedule,
uuid=uuid4(),
)
db.session.add(log)
db.session.commit()
chart = db.session.query(Slice).first()
report_schedule = create_report_notification(
slack_channel="slack_channel", chart=chart
)
report_schedule.last_state = ReportState.WORKING
report_schedule.last_eval_dttm = datetime(2020, 1, 1, 0, 0)
report_schedule.last_value = None
report_schedule.last_value_row_json = None
db.session.commit()
log = ReportExecutionLog(
scheduled_dttm=report_schedule.last_eval_dttm,
start_dttm=report_schedule.last_eval_dttm,
end_dttm=report_schedule.last_eval_dttm,
value=report_schedule.last_value,
value_row_json=report_schedule.last_value_row_json,
state=ReportState.WORKING,
report_schedule=report_schedule,
uuid=uuid4(),
)
db.session.add(log)
db.session.commit()
yield report_schedule
yield report_schedule
cleanup_report_schedule(report_schedule)
cleanup_report_schedule(report_schedule)
@pytest.fixture()
def create_alert_slack_chart_success():
with app.app_context():
chart = db.session.query(Slice).first()
report_schedule = create_report_notification(
slack_channel="slack_channel",
chart=chart,
report_type=ReportScheduleType.ALERT,
)
report_schedule.last_state = ReportState.SUCCESS
report_schedule.last_eval_dttm = datetime(2020, 1, 1, 0, 0)
chart = db.session.query(Slice).first()
report_schedule = create_report_notification(
slack_channel="slack_channel",
chart=chart,
report_type=ReportScheduleType.ALERT,
)
report_schedule.last_state = ReportState.SUCCESS
report_schedule.last_eval_dttm = datetime(2020, 1, 1, 0, 0)
log = ReportExecutionLog(
report_schedule=report_schedule,
state=ReportState.SUCCESS,
start_dttm=report_schedule.last_eval_dttm,
end_dttm=report_schedule.last_eval_dttm,
scheduled_dttm=report_schedule.last_eval_dttm,
)
db.session.add(log)
db.session.commit()
yield report_schedule
log = ReportExecutionLog(
report_schedule=report_schedule,
state=ReportState.SUCCESS,
start_dttm=report_schedule.last_eval_dttm,
end_dttm=report_schedule.last_eval_dttm,
scheduled_dttm=report_schedule.last_eval_dttm,
)
db.session.add(log)
db.session.commit()
yield report_schedule
cleanup_report_schedule(report_schedule)
cleanup_report_schedule(report_schedule)
@pytest.fixture(
@ -375,36 +363,33 @@ def create_alert_slack_chart_grace(request):
"validator_config_json": '{"op": "<", "threshold": 10}',
},
}
with app.app_context():
chart = db.session.query(Slice).first()
example_database = get_example_database()
with create_test_table_context(example_database):
report_schedule = create_report_notification(
slack_channel="slack_channel",
chart=chart,
report_type=ReportScheduleType.ALERT,
database=example_database,
sql=param_config[request.param]["sql"],
validator_type=param_config[request.param]["validator_type"],
validator_config_json=param_config[request.param][
"validator_config_json"
],
)
report_schedule.last_state = ReportState.GRACE
report_schedule.last_eval_dttm = datetime(2020, 1, 1, 0, 0)
chart = db.session.query(Slice).first()
example_database = get_example_database()
with create_test_table_context(example_database):
report_schedule = create_report_notification(
slack_channel="slack_channel",
chart=chart,
report_type=ReportScheduleType.ALERT,
database=example_database,
sql=param_config[request.param]["sql"],
validator_type=param_config[request.param]["validator_type"],
validator_config_json=param_config[request.param]["validator_config_json"],
)
report_schedule.last_state = ReportState.GRACE
report_schedule.last_eval_dttm = datetime(2020, 1, 1, 0, 0)
log = ReportExecutionLog(
report_schedule=report_schedule,
state=ReportState.SUCCESS,
start_dttm=report_schedule.last_eval_dttm,
end_dttm=report_schedule.last_eval_dttm,
scheduled_dttm=report_schedule.last_eval_dttm,
)
db.session.add(log)
db.session.commit()
yield report_schedule
log = ReportExecutionLog(
report_schedule=report_schedule,
state=ReportState.SUCCESS,
start_dttm=report_schedule.last_eval_dttm,
end_dttm=report_schedule.last_eval_dttm,
scheduled_dttm=report_schedule.last_eval_dttm,
)
db.session.add(log)
db.session.commit()
yield report_schedule
cleanup_report_schedule(report_schedule)
cleanup_report_schedule(report_schedule)
@pytest.fixture(
@ -462,25 +447,22 @@ def create_alert_email_chart(request):
"validator_config_json": '{"op": ">", "threshold": 54.999}',
},
}
with app.app_context():
chart = db.session.query(Slice).first()
example_database = get_example_database()
with create_test_table_context(example_database):
report_schedule = create_report_notification(
email_target="target@email.com",
chart=chart,
report_type=ReportScheduleType.ALERT,
database=example_database,
sql=param_config[request.param]["sql"],
validator_type=param_config[request.param]["validator_type"],
validator_config_json=param_config[request.param][
"validator_config_json"
],
force_screenshot=True,
)
yield report_schedule
chart = db.session.query(Slice).first()
example_database = get_example_database()
with create_test_table_context(example_database):
report_schedule = create_report_notification(
email_target="target@email.com",
chart=chart,
report_type=ReportScheduleType.ALERT,
database=example_database,
sql=param_config[request.param]["sql"],
validator_type=param_config[request.param]["validator_type"],
validator_config_json=param_config[request.param]["validator_config_json"],
force_screenshot=True,
)
yield report_schedule
cleanup_report_schedule(report_schedule)
cleanup_report_schedule(report_schedule)
@pytest.fixture(
@ -544,24 +526,21 @@ def create_no_alert_email_chart(request):
"validator_config_json": '{"op": ">", "threshold": 0}',
},
}
with app.app_context():
chart = db.session.query(Slice).first()
example_database = get_example_database()
with create_test_table_context(example_database):
report_schedule = create_report_notification(
email_target="target@email.com",
chart=chart,
report_type=ReportScheduleType.ALERT,
database=example_database,
sql=param_config[request.param]["sql"],
validator_type=param_config[request.param]["validator_type"],
validator_config_json=param_config[request.param][
"validator_config_json"
],
)
yield report_schedule
chart = db.session.query(Slice).first()
example_database = get_example_database()
with create_test_table_context(example_database):
report_schedule = create_report_notification(
email_target="target@email.com",
chart=chart,
report_type=ReportScheduleType.ALERT,
database=example_database,
sql=param_config[request.param]["sql"],
validator_type=param_config[request.param]["validator_type"],
validator_config_json=param_config[request.param]["validator_config_json"],
)
yield report_schedule
cleanup_report_schedule(report_schedule)
cleanup_report_schedule(report_schedule)
@pytest.fixture(params=["alert1", "alert2"])
@ -578,28 +557,25 @@ def create_mul_alert_email_chart(request):
"validator_config_json": '{"op": "<", "threshold": 10}',
},
}
with app.app_context():
chart = db.session.query(Slice).first()
example_database = get_example_database()
with create_test_table_context(example_database):
report_schedule = create_report_notification(
email_target="target@email.com",
chart=chart,
report_type=ReportScheduleType.ALERT,
database=example_database,
sql=param_config[request.param]["sql"],
validator_type=param_config[request.param]["validator_type"],
validator_config_json=param_config[request.param][
"validator_config_json"
],
)
yield report_schedule
chart = db.session.query(Slice).first()
example_database = get_example_database()
with create_test_table_context(example_database):
report_schedule = create_report_notification(
email_target="target@email.com",
chart=chart,
report_type=ReportScheduleType.ALERT,
database=example_database,
sql=param_config[request.param]["sql"],
validator_type=param_config[request.param]["validator_type"],
validator_config_json=param_config[request.param]["validator_config_json"],
)
yield report_schedule
cleanup_report_schedule(report_schedule)
cleanup_report_schedule(report_schedule)
@pytest.fixture(params=["alert1", "alert2"])
def create_invalid_sql_alert_email_chart(request):
def create_invalid_sql_alert_email_chart(request, app_context: AppContext):
param_config = {
"alert1": {
"sql": "SELECT 'string' ",
@ -612,25 +588,22 @@ def create_invalid_sql_alert_email_chart(request):
"validator_config_json": '{"op": "<", "threshold": 10}',
},
}
with app.app_context():
chart = db.session.query(Slice).first()
example_database = get_example_database()
with create_test_table_context(example_database):
report_schedule = create_report_notification(
email_target="target@email.com",
chart=chart,
report_type=ReportScheduleType.ALERT,
database=example_database,
sql=param_config[request.param]["sql"],
validator_type=param_config[request.param]["validator_type"],
validator_config_json=param_config[request.param][
"validator_config_json"
],
grace_period=60 * 60,
)
yield report_schedule
chart = db.session.query(Slice).first()
example_database = get_example_database()
with create_test_table_context(example_database):
report_schedule = create_report_notification(
email_target="target@email.com",
chart=chart,
report_type=ReportScheduleType.ALERT,
database=example_database,
sql=param_config[request.param]["sql"],
validator_type=param_config[request.param]["validator_type"],
validator_config_json=param_config[request.param]["validator_config_json"],
grace_period=60 * 60,
)
yield report_schedule
cleanup_report_schedule(report_schedule)
cleanup_report_schedule(report_schedule)
@pytest.mark.usefixtures(
@ -835,7 +808,8 @@ def test_email_chart_report_dry_run(
@pytest.mark.usefixtures(
"load_birth_names_dashboard_with_slices", "create_report_email_chart_with_csv"
"load_birth_names_dashboard_with_slices",
"create_report_email_chart_with_csv",
)
@patch("superset.utils.csv.urllib.request.urlopen")
@patch("superset.utils.csv.urllib.request.OpenerDirector.open")
@ -923,7 +897,8 @@ def test_email_chart_report_schedule_with_csv_no_query_context(
@pytest.mark.usefixtures(
"load_birth_names_dashboard_with_slices", "create_report_email_chart_with_text"
"load_birth_names_dashboard_with_slices",
"create_report_email_chart_with_text",
)
@patch("superset.utils.csv.urllib.request.urlopen")
@patch("superset.utils.csv.urllib.request.OpenerDirector.open")
@ -1545,7 +1520,8 @@ def test_slack_chart_alert_no_attachment(email_mock, create_alert_email_chart):
@pytest.mark.usefixtures(
"load_birth_names_dashboard_with_slices", "create_report_slack_chart"
"load_birth_names_dashboard_with_slices",
"create_report_slack_chart",
)
@patch("superset.utils.slack.WebClient")
@patch("superset.utils.screenshots.ChartScreenshot.get_screenshot")
@ -1571,7 +1547,7 @@ def test_slack_token_callable_chart_report(
assert_log(ReportState.SUCCESS)
@pytest.mark.usefixtures("create_no_alert_email_chart")
@pytest.mark.usefixtures("app_context")
def test_email_chart_no_alert(create_no_alert_email_chart):
"""
ExecuteReport Command: Test chart email no alert
@ -1583,7 +1559,7 @@ def test_email_chart_no_alert(create_no_alert_email_chart):
assert_log(ReportState.NOOP)
@pytest.mark.usefixtures("create_mul_alert_email_chart")
@pytest.mark.usefixtures("app_context")
def test_email_mul_alert(create_mul_alert_email_chart):
"""
ExecuteReport Command: Test chart email multiple rows
@ -1824,7 +1800,6 @@ def test_email_disable_screenshot(email_mock, create_alert_email_chart):
assert_log(ReportState.SUCCESS)
@pytest.mark.usefixtures("create_invalid_sql_alert_email_chart")
@patch("superset.reports.notifications.email.send_email_smtp")
def test_invalid_sql_alert(email_mock, create_invalid_sql_alert_email_chart):
"""
@ -1841,7 +1816,6 @@ def test_invalid_sql_alert(email_mock, create_invalid_sql_alert_email_chart):
assert_log(ReportState.ERROR)
@pytest.mark.usefixtures("create_invalid_sql_alert_email_chart")
@patch("superset.reports.notifications.email.send_email_smtp")
def test_grace_period_error(email_mock, create_invalid_sql_alert_email_chart):
"""
@ -1884,7 +1858,6 @@ def test_grace_period_error(email_mock, create_invalid_sql_alert_email_chart):
)
@pytest.mark.usefixtures("create_invalid_sql_alert_email_chart")
@patch("superset.reports.notifications.email.send_email_smtp")
@patch("superset.utils.screenshots.ChartScreenshot.get_screenshot")
def test_grace_period_error_flap(

View File

@ -35,150 +35,144 @@ def owners(get_user) -> list[User]:
return [get_user("admin")]
@pytest.mark.usefixtures("owners")
@pytest.mark.usefixtures("app_context")
@patch("superset.tasks.scheduler.execute.apply_async")
def test_scheduler_celery_timeout_ny(execute_mock, owners):
"""
Reports scheduler: Test scheduler setting celery soft and hard timeout
"""
with app.app_context():
report_schedule = insert_report_schedule(
type=ReportScheduleType.ALERT,
name="report",
crontab="0 4 * * *",
timezone="America/New_York",
owners=owners,
)
report_schedule = insert_report_schedule(
type=ReportScheduleType.ALERT,
name="report",
crontab="0 4 * * *",
timezone="America/New_York",
owners=owners,
)
with freeze_time("2020-01-01T09:00:00Z"):
scheduler()
assert execute_mock.call_args[1]["soft_time_limit"] == 3601
assert execute_mock.call_args[1]["time_limit"] == 3610
db.session.delete(report_schedule)
db.session.commit()
with freeze_time("2020-01-01T09:00:00Z"):
scheduler()
assert execute_mock.call_args[1]["soft_time_limit"] == 3601
assert execute_mock.call_args[1]["time_limit"] == 3610
db.session.delete(report_schedule)
db.session.commit()
@pytest.mark.usefixtures("owners")
@pytest.mark.usefixtures("app_context")
@patch("superset.tasks.scheduler.execute.apply_async")
def test_scheduler_celery_no_timeout_ny(execute_mock, owners):
"""
Reports scheduler: Test scheduler setting celery soft and hard timeout
"""
with app.app_context():
app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = False
report_schedule = insert_report_schedule(
type=ReportScheduleType.ALERT,
name="report",
crontab="0 4 * * *",
timezone="America/New_York",
owners=owners,
)
app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = False
report_schedule = insert_report_schedule(
type=ReportScheduleType.ALERT,
name="report",
crontab="0 4 * * *",
timezone="America/New_York",
owners=owners,
)
with freeze_time("2020-01-01T09:00:00Z"):
scheduler()
assert execute_mock.call_args[1] == {"eta": FakeDatetime(2020, 1, 1, 9, 0)}
db.session.delete(report_schedule)
db.session.commit()
app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = True
with freeze_time("2020-01-01T09:00:00Z"):
scheduler()
assert execute_mock.call_args[1] == {"eta": FakeDatetime(2020, 1, 1, 9, 0)}
db.session.delete(report_schedule)
db.session.commit()
app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = True
@pytest.mark.usefixtures("owners")
@pytest.mark.usefixtures("app_context")
@patch("superset.tasks.scheduler.execute.apply_async")
def test_scheduler_celery_timeout_utc(execute_mock, owners):
"""
Reports scheduler: Test scheduler setting celery soft and hard timeout
"""
with app.app_context():
report_schedule = insert_report_schedule(
type=ReportScheduleType.ALERT,
name="report",
crontab="0 9 * * *",
timezone="UTC",
owners=owners,
)
report_schedule = insert_report_schedule(
type=ReportScheduleType.ALERT,
name="report",
crontab="0 9 * * *",
timezone="UTC",
owners=owners,
)
with freeze_time("2020-01-01T09:00:00Z"):
scheduler()
assert execute_mock.call_args[1]["soft_time_limit"] == 3601
assert execute_mock.call_args[1]["time_limit"] == 3610
db.session.delete(report_schedule)
db.session.commit()
with freeze_time("2020-01-01T09:00:00Z"):
scheduler()
assert execute_mock.call_args[1]["soft_time_limit"] == 3601
assert execute_mock.call_args[1]["time_limit"] == 3610
db.session.delete(report_schedule)
db.session.commit()
@pytest.mark.usefixtures("owners")
@pytest.mark.usefixtures("app_context")
@patch("superset.tasks.scheduler.execute.apply_async")
def test_scheduler_celery_no_timeout_utc(execute_mock, owners):
"""
Reports scheduler: Test scheduler setting celery soft and hard timeout
"""
with app.app_context():
app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = False
report_schedule = insert_report_schedule(
type=ReportScheduleType.ALERT,
name="report",
crontab="0 9 * * *",
timezone="UTC",
owners=owners,
)
app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = False
report_schedule = insert_report_schedule(
type=ReportScheduleType.ALERT,
name="report",
crontab="0 9 * * *",
timezone="UTC",
owners=owners,
)
with freeze_time("2020-01-01T09:00:00Z"):
scheduler()
assert execute_mock.call_args[1] == {"eta": FakeDatetime(2020, 1, 1, 9, 0)}
db.session.delete(report_schedule)
db.session.commit()
app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = True
with freeze_time("2020-01-01T09:00:00Z"):
scheduler()
assert execute_mock.call_args[1] == {"eta": FakeDatetime(2020, 1, 1, 9, 0)}
db.session.delete(report_schedule)
db.session.commit()
app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = True
@pytest.mark.usefixtures("owners")
@pytest.mark.usefixtures("app_context")
@patch("superset.tasks.scheduler.is_feature_enabled")
@patch("superset.tasks.scheduler.execute.apply_async")
def test_scheduler_feature_flag_off(execute_mock, is_feature_enabled, owners):
"""
Reports scheduler: Test scheduler with feature flag off
"""
with app.app_context():
is_feature_enabled.return_value = False
report_schedule = insert_report_schedule(
type=ReportScheduleType.ALERT,
name="report",
crontab="0 9 * * *",
timezone="UTC",
owners=owners,
)
is_feature_enabled.return_value = False
report_schedule = insert_report_schedule(
type=ReportScheduleType.ALERT,
name="report",
crontab="0 9 * * *",
timezone="UTC",
owners=owners,
)
with freeze_time("2020-01-01T09:00:00Z"):
scheduler()
execute_mock.assert_not_called()
db.session.delete(report_schedule)
db.session.commit()
with freeze_time("2020-01-01T09:00:00Z"):
scheduler()
execute_mock.assert_not_called()
db.session.delete(report_schedule)
db.session.commit()
@pytest.mark.usefixtures("owners")
@pytest.mark.usefixtures("app_context")
@patch("superset.commands.report.execute.AsyncExecuteReportScheduleCommand.__init__")
@patch("superset.commands.report.execute.AsyncExecuteReportScheduleCommand.run")
@patch("superset.tasks.scheduler.execute.update_state")
def test_execute_task(update_state_mock, command_mock, init_mock, owners):
from superset.commands.report.exceptions import ReportScheduleUnexpectedError
with app.app_context():
report_schedule = insert_report_schedule(
type=ReportScheduleType.ALERT,
name=f"report-{randint(0,1000)}",
crontab="0 4 * * *",
timezone="America/New_York",
owners=owners,
)
init_mock.return_value = None
command_mock.side_effect = ReportScheduleUnexpectedError("Unexpected error")
with freeze_time("2020-01-01T09:00:00Z"):
execute(report_schedule.id)
update_state_mock.assert_called_with(state="FAILURE")
report_schedule = insert_report_schedule(
type=ReportScheduleType.ALERT,
name=f"report-{randint(0,1000)}",
crontab="0 4 * * *",
timezone="America/New_York",
owners=owners,
)
init_mock.return_value = None
command_mock.side_effect = ReportScheduleUnexpectedError("Unexpected error")
with freeze_time("2020-01-01T09:00:00Z"):
execute(report_schedule.id)
update_state_mock.assert_called_with(state="FAILURE")
db.session.delete(report_schedule)
db.session.commit()
db.session.delete(report_schedule)
db.session.commit()
@pytest.mark.usefixtures("owners")
@pytest.mark.usefixtures("app_context")
@patch("superset.commands.report.execute.AsyncExecuteReportScheduleCommand.__init__")
@patch("superset.commands.report.execute.AsyncExecuteReportScheduleCommand.run")
@patch("superset.tasks.scheduler.execute.update_state")
@ -188,23 +182,22 @@ def test_execute_task_with_command_exception(
):
from superset.commands.exceptions import CommandException
with app.app_context():
report_schedule = insert_report_schedule(
type=ReportScheduleType.ALERT,
name=f"report-{randint(0,1000)}",
crontab="0 4 * * *",
timezone="America/New_York",
owners=owners,
report_schedule = insert_report_schedule(
type=ReportScheduleType.ALERT,
name=f"report-{randint(0,1000)}",
crontab="0 4 * * *",
timezone="America/New_York",
owners=owners,
)
init_mock.return_value = None
command_mock.side_effect = CommandException("Unexpected error")
with freeze_time("2020-01-01T09:00:00Z"):
execute(report_schedule.id)
update_state_mock.assert_called_with(state="FAILURE")
logger_mock.exception.assert_called_with(
"A downstream exception occurred while generating a report: None. Unexpected error",
exc_info=True,
)
init_mock.return_value = None
command_mock.side_effect = CommandException("Unexpected error")
with freeze_time("2020-01-01T09:00:00Z"):
execute(report_schedule.id)
update_state_mock.assert_called_with(state="FAILURE")
logger_mock.exception.assert_called_with(
"A downstream exception occurred while generating a report: None. Unexpected error",
exc_info=True,
)
db.session.delete(report_schedule)
db.session.commit()
db.session.delete(report_schedule)
db.session.commit()

View File

@ -84,10 +84,9 @@ from tests.integration_tests.test_app import app
def test_check_sqlalchemy_uri(
sqlalchemy_uri: str, error: bool, error_message: Optional[str]
):
with app.app_context():
if error:
with pytest.raises(SupersetSecurityException) as excinfo:
check_sqlalchemy_uri(make_url(sqlalchemy_uri))
assert str(excinfo.value) == error_message
else:
if error:
with pytest.raises(SupersetSecurityException) as excinfo:
check_sqlalchemy_uri(make_url(sqlalchemy_uri))
assert str(excinfo.value) == error_message
else:
check_sqlalchemy_uri(make_url(sqlalchemy_uri))

View File

@ -24,7 +24,7 @@ import pytest
import numpy as np
import pandas as pd
from flask import Flask
from flask.ctx import AppContext
from pytest_mock import MockFixture
from sqlalchemy.sql import text
from sqlalchemy.sql.elements import TextClause
@ -598,26 +598,25 @@ class TestDatabaseModel(SupersetTestCase):
db.session.commit()
@pytest.fixture
def text_column_table():
with app.app_context():
table = SqlaTable(
table_name="text_column_table",
sql=(
"SELECT 'foo' as foo "
"UNION SELECT '' "
"UNION SELECT NULL "
"UNION SELECT 'null' "
"UNION SELECT '\"text in double quotes\"' "
"UNION SELECT '''text in single quotes''' "
"UNION SELECT 'double quotes \" in text' "
"UNION SELECT 'single quotes '' in text' "
),
database=get_example_database(),
)
TableColumn(column_name="foo", type="VARCHAR(255)", table=table)
SqlMetric(metric_name="count", expression="count(*)", table=table)
yield table
@pytest.fixture()
def text_column_table(app_context: AppContext):
table = SqlaTable(
table_name="text_column_table",
sql=(
"SELECT 'foo' as foo "
"UNION SELECT '' "
"UNION SELECT NULL "
"UNION SELECT 'null' "
"UNION SELECT '\"text in double quotes\"' "
"UNION SELECT '''text in single quotes''' "
"UNION SELECT 'double quotes \" in text' "
"UNION SELECT 'single quotes '' in text' "
),
database=get_example_database(),
)
TableColumn(column_name="foo", type="VARCHAR(255)", table=table)
SqlMetric(metric_name="count", expression="count(*)", table=table)
yield table
def test_values_for_column_on_text_column(text_column_table):
@ -836,6 +835,7 @@ def test_none_operand_in_filter(login_as_admin, physical_dataset):
)
@pytest.mark.usefixtures("app_context")
@pytest.mark.parametrize(
"row,dimension,result",
[
@ -857,7 +857,6 @@ def test_none_operand_in_filter(login_as_admin, physical_dataset):
],
)
def test__normalize_prequery_result_type(
app_context: Flask,
mocker: MockFixture,
row: pd.Series,
dimension: str,
@ -927,7 +926,8 @@ def test__normalize_prequery_result_type(
assert normalized == result
def test__temporal_range_operator_in_adhoc_filter(app_context, physical_dataset):
@pytest.mark.usefixtures("app_context")
def test__temporal_range_operator_in_adhoc_filter(physical_dataset):
result = physical_dataset.query(
{
"columns": ["col1", "col2"],

View File

@ -82,5 +82,4 @@ def test_form_data_to_adhoc_incorrect_clause_type():
form_data = {"where": "1 = 1", "having": "count(*) > 1"}
with pytest.raises(ValueError):
with app.app_context():
form_data_to_adhoc(form_data, "foobar")
form_data_to_adhoc(form_data, "foobar")

View File

@ -685,20 +685,19 @@ class TestUtils(SupersetTestCase):
self.assertIsNotNone(parse_js_uri_path_item("item"))
def test_get_stacktrace(self):
with app.app_context():
app.config["SHOW_STACKTRACE"] = True
try:
raise Exception("NONONO!")
except Exception:
stacktrace = get_stacktrace()
self.assertIn("NONONO", stacktrace)
app.config["SHOW_STACKTRACE"] = True
try:
raise Exception("NONONO!")
except Exception:
stacktrace = get_stacktrace()
self.assertIn("NONONO", stacktrace)
app.config["SHOW_STACKTRACE"] = False
try:
raise Exception("NONONO!")
except Exception:
stacktrace = get_stacktrace()
assert stacktrace is None
app.config["SHOW_STACKTRACE"] = False
try:
raise Exception("NONONO!")
except Exception:
stacktrace = get_stacktrace()
assert stacktrace is None
def test_split(self):
self.assertEqual(list(split("a b")), ["a", "b"])
@ -839,9 +838,8 @@ class TestUtils(SupersetTestCase):
)
def test_get_form_data_default(self) -> None:
with app.test_request_context():
form_data, slc = get_form_data()
self.assertEqual(slc, None)
form_data, slc = get_form_data()
self.assertEqual(slc, None)
def test_get_form_data_request_args(self) -> None:
with app.test_request_context(

View File

@ -411,72 +411,71 @@ def test_delete_ssh_tunnel(
"""
Test that we can delete SSH Tunnel
"""
with app.app_context():
from superset.daos.database import DatabaseDAO
from superset.databases.api import DatabaseRestApi
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
from superset.daos.database import DatabaseDAO
from superset.databases.api import DatabaseRestApi
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
DatabaseRestApi.datamodel.session = session
DatabaseRestApi.datamodel.session = session
# create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
# create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
# Create our Database
database = Database(
database_name="my_database",
sqlalchemy_uri="gsheets://",
encrypted_extra=json.dumps(
{
"service_account_info": {
"type": "service_account",
"project_id": "black-sanctum-314419",
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
"private_key": "SECRET",
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com",
"client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com",
},
}
),
)
db.session.add(database)
db.session.commit()
# Create our Database
database = Database(
database_name="my_database",
sqlalchemy_uri="gsheets://",
encrypted_extra=json.dumps(
{
"service_account_info": {
"type": "service_account",
"project_id": "black-sanctum-314419",
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
"private_key": "SECRET",
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com",
"client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com",
},
}
),
)
db.session.add(database)
db.session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
mocker.patch("superset.utils.log.DBEventLogger.log")
mocker.patch(
"superset.commands.database.ssh_tunnel.delete.is_feature_enabled",
return_value=True,
)
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
mocker.patch("superset.utils.log.DBEventLogger.log")
mocker.patch(
"superset.commands.database.ssh_tunnel.delete.is_feature_enabled",
return_value=True,
)
# Create our SSHTunnel
tunnel = SSHTunnel(
database_id=1,
database=database,
)
# Create our SSHTunnel
tunnel = SSHTunnel(
database_id=1,
database=database,
)
db.session.add(tunnel)
db.session.commit()
db.session.add(tunnel)
db.session.commit()
# Get our recently created SSHTunnel
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
assert response_tunnel
assert isinstance(response_tunnel, SSHTunnel)
assert 1 == response_tunnel.database_id
# Get our recently created SSHTunnel
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
assert response_tunnel
assert isinstance(response_tunnel, SSHTunnel)
assert 1 == response_tunnel.database_id
# Delete the recently created SSHTunnel
response_delete_tunnel = client.delete(
f"/api/v1/database/{database.id}/ssh_tunnel/"
)
assert response_delete_tunnel.json["message"] == "OK"
# Delete the recently created SSHTunnel
response_delete_tunnel = client.delete(
f"/api/v1/database/{database.id}/ssh_tunnel/"
)
assert response_delete_tunnel.json["message"] == "OK"
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
assert response_tunnel is None
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
assert response_tunnel is None
def test_delete_ssh_tunnel_not_found(
@ -489,70 +488,69 @@ def test_delete_ssh_tunnel_not_found(
"""
Test that we cannot delete a tunnel that does not exist
"""
with app.app_context():
from superset.daos.database import DatabaseDAO
from superset.databases.api import DatabaseRestApi
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
from superset.daos.database import DatabaseDAO
from superset.databases.api import DatabaseRestApi
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
DatabaseRestApi.datamodel.session = session
DatabaseRestApi.datamodel.session = session
# create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
# create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
# Create our Database
database = Database(
database_name="my_database",
sqlalchemy_uri="gsheets://",
encrypted_extra=json.dumps(
{
"service_account_info": {
"type": "service_account",
"project_id": "black-sanctum-314419",
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
"private_key": "SECRET",
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com",
"client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com",
},
}
),
)
db.session.add(database)
db.session.commit()
# Create our Database
database = Database(
database_name="my_database",
sqlalchemy_uri="gsheets://",
encrypted_extra=json.dumps(
{
"service_account_info": {
"type": "service_account",
"project_id": "black-sanctum-314419",
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
"private_key": "SECRET",
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com",
"client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com",
},
}
),
)
db.session.add(database)
db.session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
mocker.patch("superset.utils.log.DBEventLogger.log")
mocker.patch(
"superset.commands.database.ssh_tunnel.delete.is_feature_enabled",
return_value=True,
)
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
mocker.patch("superset.utils.log.DBEventLogger.log")
mocker.patch(
"superset.commands.database.ssh_tunnel.delete.is_feature_enabled",
return_value=True,
)
# Create our SSHTunnel
tunnel = SSHTunnel(
database_id=1,
database=database,
)
# Create our SSHTunnel
tunnel = SSHTunnel(
database_id=1,
database=database,
)
db.session.add(tunnel)
db.session.commit()
db.session.add(tunnel)
db.session.commit()
# Delete the recently created SSHTunnel
response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/")
assert response_delete_tunnel.json["message"] == "Not found"
# Delete the recently created SSHTunnel
response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/")
assert response_delete_tunnel.json["message"] == "Not found"
# Get our recently created SSHTunnel
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
assert response_tunnel
assert isinstance(response_tunnel, SSHTunnel)
assert 1 == response_tunnel.database_id
# Get our recently created SSHTunnel
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
assert response_tunnel
assert isinstance(response_tunnel, SSHTunnel)
assert 1 == response_tunnel.database_id
response_tunnel = DatabaseDAO.get_ssh_tunnel(2)
assert response_tunnel is None
response_tunnel = DatabaseDAO.get_ssh_tunnel(2)
assert response_tunnel is None
def test_apply_dynamic_database_filter(
@ -568,88 +566,87 @@ def test_apply_dynamic_database_filter(
defining a filter function and patching the config to get
the filtered results.
"""
with app.app_context():
from superset.daos.database import DatabaseDAO
from superset.databases.api import DatabaseRestApi
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.daos.database import DatabaseDAO
from superset.databases.api import DatabaseRestApi
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
DatabaseRestApi.datamodel.session = session
# create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
# Create our First Database
database = Database(
database_name="first-database",
sqlalchemy_uri="gsheets://",
encrypted_extra=json.dumps(
{
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": [],
}
),
)
db.session.add(database)
db.session.commit()
# Create our Second Database
database = Database(
database_name="second-database",
sqlalchemy_uri="gsheets://",
encrypted_extra=json.dumps(
{
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": [],
}
),
)
db.session.add(database)
db.session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
mocker.patch("superset.utils.log.DBEventLogger.log")
mocker.patch(
"superset.commands.database.ssh_tunnel.delete.is_feature_enabled",
return_value=False,
)
def _base_filter(query):
from superset.models.core import Database
DatabaseRestApi.datamodel.session = session
return query.filter(Database.database_name.startswith("second"))
# create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
# Create a mock object
base_filter_mock = Mock(side_effect=_base_filter)
# Create our First Database
database = Database(
database_name="first-database",
sqlalchemy_uri="gsheets://",
encrypted_extra=json.dumps(
{
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": [],
}
),
)
db.session.add(database)
db.session.commit()
# Get our recently created Databases
response_databases = DatabaseDAO.find_all()
assert response_databases
expected_db_names = ["first-database", "second-database"]
actual_db_names = [db.database_name for db in response_databases]
assert actual_db_names == expected_db_names
# Create our Second Database
database = Database(
database_name="second-database",
sqlalchemy_uri="gsheets://",
encrypted_extra=json.dumps(
{
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": [],
}
),
)
db.session.add(database)
db.session.commit()
# Ensure that the filter has not been called because it's not in our config
assert base_filter_mock.call_count == 0
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
mocker.patch("superset.utils.log.DBEventLogger.log")
mocker.patch(
"superset.commands.database.ssh_tunnel.delete.is_feature_enabled",
return_value=False,
)
original_config = current_app.config.copy()
original_config["EXTRA_DYNAMIC_QUERY_FILTERS"] = {"databases": base_filter_mock}
def _base_filter(query):
from superset.models.core import Database
mocker.patch("superset.views.filters.current_app.config", new=original_config)
# Get filtered list
response_databases = DatabaseDAO.find_all()
assert response_databases
expected_db_names = ["second-database"]
actual_db_names = [db.database_name for db in response_databases]
assert actual_db_names == expected_db_names
return query.filter(Database.database_name.startswith("second"))
# Create a mock object
base_filter_mock = Mock(side_effect=_base_filter)
# Get our recently created Databases
response_databases = DatabaseDAO.find_all()
assert response_databases
expected_db_names = ["first-database", "second-database"]
actual_db_names = [db.database_name for db in response_databases]
assert actual_db_names == expected_db_names
# Ensure that the filter has not been called because it's not in our config
assert base_filter_mock.call_count == 0
original_config = current_app.config.copy()
original_config["EXTRA_DYNAMIC_QUERY_FILTERS"] = {"databases": base_filter_mock}
mocker.patch("superset.views.filters.current_app.config", new=original_config)
# Get filtered list
response_databases = DatabaseDAO.find_all()
assert response_databases
expected_db_names = ["second-database"]
actual_db_names = [db.database_name for db in response_databases]
assert actual_db_names == expected_db_names
# Ensure that the filter has been called once
assert base_filter_mock.call_count == 1
# Ensure that the filter has been called once
assert base_filter_mock.call_count == 1
def test_oauth2_happy_path(