Centralizing logic

This commit is contained in:
Vitor Avila 2024-06-30 02:25:50 -03:00
parent db5b878c51
commit eb042c31fd
5 changed files with 59 additions and 72 deletions

View File

@ -24,6 +24,7 @@ from typing import Any, Union, Optional
from unittest.mock import Mock, patch, MagicMock from unittest.mock import Mock, patch, MagicMock
import pandas as pd import pandas as pd
import prison
from flask import Response, g from flask import Response, g
from flask_appbuilder.security.sqla import models as ab_models from flask_appbuilder.security.sqla import models as ab_models
from flask_testing import TestCase from flask_testing import TestCase
@ -33,6 +34,7 @@ from sqlalchemy.orm import Session # noqa: F401
from sqlalchemy.sql import func from sqlalchemy.sql import func
from sqlalchemy.dialects.mysql import dialect from sqlalchemy.dialects.mysql import dialect
from tests.integration_tests.constants import ADMIN_USERNAME
from tests.integration_tests.test_app import app, login from tests.integration_tests.test_app import app, login
from superset.sql_parse import CtasMethod from superset.sql_parse import CtasMethod
from superset import db, security_manager from superset import db, security_manager
@ -590,6 +592,20 @@ class SupersetTestCase(TestCase):
db.session.commit() db.session.commit()
return dashboard return dashboard
def get_list(
self,
asset_type: str,
filter: dict[str, Any] = {},
username: str = ADMIN_USERNAME,
) -> Response:
"""
Get list of assets, by default using admin account. Can be filtered.
"""
self.login(username)
uri = f"api/v1/{asset_type}/?q={prison.dumps(filter)}"
response = self.get_assert_metric(uri, "get_list")
return response
@contextmanager @contextmanager
def db_insert_temp_object(obj: DeclarativeMeta): def db_insert_temp_object(obj: DeclarativeMeta):

View File

@ -61,7 +61,10 @@ from tests.integration_tests.fixtures.importexport import (
dataset_config, dataset_config,
dataset_metadata_config, dataset_metadata_config,
) )
from tests.integration_tests.fixtures.tags import create_custom_tags # noqa: F401 from tests.integration_tests.fixtures.tags import (
create_custom_tags, # noqa: F401
get_filter_params,
)
from tests.integration_tests.fixtures.unicode_dashboard import ( from tests.integration_tests.fixtures.unicode_dashboard import (
load_unicode_dashboard_with_slice, # noqa: F401 load_unicode_dashboard_with_slice, # noqa: F401
load_unicode_data, # noqa: F401 load_unicode_data, # noqa: F401
@ -1195,38 +1198,21 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
for tag in tags.values() for tag in tags.values()
} }
# Helper function to return filter parameters
def get_filter_params(opr, value):
return {
"filters": [
{
"col": "tags",
"opr": opr,
"value": value,
}
]
}
# Helper function to test chart filtering by tag
def get_charts_filtered_list(filter):
self.login(ADMIN_USERNAME)
uri = f"api/v1/chart/?q={prison.dumps(filter)}"
response = self.get_assert_metric(uri, "get_list")
self.assertEqual(response.status_code, 200)
data = json.loads(response.data.decode("utf-8"))
return data
# Validate API results for each tag # Validate API results for each tag
for tag_name, tag in tags.items(): for tag_name, tag in tags.items():
expected_charts = chart_tag_relationship[tag_name] expected_charts = chart_tag_relationship[tag_name]
# Filter by tag ID # Filter by tag ID
filter_params = get_filter_params("chart_tag_id", tag.id) filter_params = get_filter_params("chart_tag_id", tag.id)
data_by_id = get_charts_filtered_list(filter_params) response_by_id = self.get_list("chart", filter_params)
self.assertEqual(response_by_id.status_code, 200)
data_by_id = json.loads(response_by_id.data.decode("utf-8"))
# Filter by tag name # Filter by tag name
filter_params = get_filter_params("chart_tags", tag.name) filter_params = get_filter_params("chart_tags", tag.name)
data_by_name = get_charts_filtered_list(filter_params) response_by_name = self.get_list("chart", filter_params)
self.assertEqual(response_by_name.status_code, 200)
data_by_name = json.loads(response_by_name.data.decode("utf-8"))
# Compare results # Compare results
self.assertEqual( self.assertEqual(

View File

@ -55,7 +55,10 @@ from tests.integration_tests.fixtures.importexport import (
dataset_config, dataset_config,
dataset_metadata_config, dataset_metadata_config,
) )
from tests.integration_tests.fixtures.tags import create_custom_tags # noqa: F401 from tests.integration_tests.fixtures.tags import (
create_custom_tags, # noqa: F401
get_filter_params,
)
from tests.integration_tests.utils.get_dashboards import get_dashboards_ids from tests.integration_tests.utils.get_dashboards import get_dashboards_ids
from tests.integration_tests.fixtures.birth_names_dashboard import ( from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices, # noqa: F401 load_birth_names_dashboard_with_slices, # noqa: F401
@ -781,38 +784,21 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
for tag in tags.values() for tag in tags.values()
} }
# Helper function to return filter parameters
def get_filter_params(opr, value):
return {
"filters": [
{
"col": "tags",
"opr": opr,
"value": value,
}
]
}
# Helper function to test chart filtering by tag
def get_charts_filtered_list(filter):
self.login(ADMIN_USERNAME)
uri = f"api/v1/dashboard/?q={prison.dumps(filter)}"
response = self.get_assert_metric(uri, "get_list")
self.assertEqual(response.status_code, 200)
data = json.loads(response.data.decode("utf-8"))
return data
# Validate API results for each tag # Validate API results for each tag
for tag_name, tag in tags.items(): for tag_name, tag in tags.items():
expected_dashboards = dashboard_tag_relationship[tag_name] expected_dashboards = dashboard_tag_relationship[tag_name]
# Filter by tag ID # Filter by tag ID
filter_params = get_filter_params("dashboard_tag_id", tag.id) filter_params = get_filter_params("dashboard_tag_id", tag.id)
data_by_id = get_charts_filtered_list(filter_params) response_by_id = self.get_list("dashboard", filter_params)
self.assertEqual(response_by_id.status_code, 200)
data_by_id = json.loads(response_by_id.data.decode("utf-8"))
# Filter by tag name # Filter by tag name
filter_params = get_filter_params("dashboard_tags", tag.name) filter_params = get_filter_params("dashboard_tags", tag.name)
data_by_name = get_charts_filtered_list(filter_params) response_by_name = self.get_list("dashboard", filter_params)
self.assertEqual(response_by_name.status_code, 200)
data_by_name = json.loads(response_by_name.data.decode("utf-8"))
# Compare results # Compare results
self.assertEqual( self.assertEqual(

View File

@ -53,3 +53,16 @@ def create_custom_tags():
for tags in tags: for tags in tags:
db.session.delete(tags) db.session.delete(tags)
db.session.commit() db.session.commit()
# Helper function to return filter parameters
def get_filter_params(opr, value):
return {
"filters": [
{
"col": "tags",
"opr": opr,
"value": value,
}
]
}

View File

@ -43,7 +43,10 @@ from tests.integration_tests.fixtures.importexport import (
saved_queries_config, saved_queries_config,
saved_queries_metadata_config, saved_queries_metadata_config,
) )
from tests.integration_tests.fixtures.tags import create_custom_tags # noqa: F401 from tests.integration_tests.fixtures.tags import (
create_custom_tags, # noqa: F401
get_filter_params,
)
SAVED_QUERIES_FIXTURE_COUNT = 10 SAVED_QUERIES_FIXTURE_COUNT = 10
@ -456,38 +459,21 @@ class TestSavedQueryApi(SupersetTestCase):
for tag in tags.values() for tag in tags.values()
} }
# Helper function to return filter parameters
def get_filter_params(opr, value):
return {
"filters": [
{
"col": "tags",
"opr": opr,
"value": value,
}
]
}
# Helper function to test chart filtering by tag
def get_saved_queries_filtered_list(filter):
self.login(ADMIN_USERNAME)
uri = f"api/v1/saved_query/?q={prison.dumps(filter)}"
response = self.get_assert_metric(uri, "get_list")
self.assertEqual(response.status_code, 200)
data = json.loads(response.data.decode("utf-8"))
return data
# Validate API results for each tag # Validate API results for each tag
for tag_name, tag in tags.items(): for tag_name, tag in tags.items():
expected_saved_queries = saved_queries_tag_relationship[tag_name] expected_saved_queries = saved_queries_tag_relationship[tag_name]
# Filter by tag ID # Filter by tag ID
filter_params = get_filter_params("saved_query_tag_id", tag.id) filter_params = get_filter_params("saved_query_tag_id", tag.id)
data_by_id = get_saved_queries_filtered_list(filter_params) response_by_id = self.get_list("saved_query", filter_params)
self.assertEqual(response_by_id.status_code, 200)
data_by_id = json.loads(response_by_id.data.decode("utf-8"))
# Filter by tag name # Filter by tag name
filter_params = get_filter_params("saved_query_tags", tag.name) filter_params = get_filter_params("saved_query_tags", tag.name)
data_by_name = get_saved_queries_filtered_list(filter_params) response_by_name = self.get_list("saved_query", filter_params)
self.assertEqual(response_by_name.status_code, 200)
data_by_name = json.loads(response_by_name.data.decode("utf-8"))
# Compare results # Compare results
self.assertEqual( self.assertEqual(