diff --git a/docs/docs/installation/sql-templating.mdx b/docs/docs/installation/sql-templating.mdx index 311132032a..cdf9afef88 100644 --- a/docs/docs/installation/sql-templating.mdx +++ b/docs/docs/installation/sql-templating.mdx @@ -369,3 +369,17 @@ Since metrics are aggregations, the resulting SQL expression will be grouped by ``` SELECT * FROM {{ dataset(42, include_metrics=True, columns=["ds", "category"]) }} LIMIT 10 ``` + +**Metrics** + +The `{{ metric('metric_key', dataset_id) }}` macro can be used to retrieve the metric SQL syntax from a dataset. This can be useful for different purposes: + +- Override the metric label in the chart level +- Combine multiple metrics in a calculation +- Retrieve a metric syntax in SQL lab +- Re-use metrics across datasets + +This macro avoids copy/paste, allowing users to centralize the metric definition in the dataset layer. + +The `dataset_id` parameter is optional, and if not provided Superset will use the current dataset from context (for example, when using this macro in the Chart Builder, by default the `macro_key` will be searched in the dataset powering the chart). +The parameter can be used in SQL Lab, or when fetching a metric from another dataset. diff --git a/superset/jinja_context.py b/superset/jinja_context.py index 0ee7667811..23949cca11 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -554,6 +554,7 @@ class JinjaTemplateProcessor(BaseTemplateProcessor): "filter_values": partial(safe_proxy, extra_cache.filter_values), "get_filters": partial(safe_proxy, extra_cache.get_filters), "dataset": partial(safe_proxy, dataset_macro_with_context), + "metric": partial(safe_proxy, metric_macro), } ) @@ -722,3 +723,72 @@ def dataset_macro( sqla_query = dataset.get_query_str_extended(query_obj, mutate=False) sql = sqla_query.sql return f"(\n{sql}\n) AS dataset_{dataset_id}" + + +def get_dataset_id_from_context(metric_key: str) -> int: + """ + Retrives the Dataset ID from the request context. + + :param metric_key: the metric key. + :returns: the dataset ID. + """ + # pylint: disable=import-outside-toplevel + from superset.daos.chart import ChartDAO + from superset.views.utils import get_form_data + + exc_message = _( + "Please specify the Dataset ID for the ``%(name)s`` metric in the Jinja macro.", + name=metric_key, + ) + + form_data, chart = get_form_data() + if not (form_data or chart): + raise SupersetTemplateException(exc_message) + + if chart and chart.datasource_id: + return chart.datasource_id + if dataset_id := form_data.get("url_params", {}).get("datasource_id"): + return dataset_id + if chart_id := ( + form_data.get("slice_id") or form_data.get("url_params", {}).get("slice_id") + ): + chart_data = ChartDAO.find_by_id(chart_id) + if not chart_data: + raise SupersetTemplateException(exc_message) + return chart_data.datasource_id + raise SupersetTemplateException(exc_message) + + +def metric_macro(metric_key: str, dataset_id: Optional[int] = None) -> str: + """ + Given a metric key, returns its syntax. + + The ``dataset_id`` is optional and if not specified, will be retrieved + from the request context (if available). + + :param metric_key: the metric key. + :param dataset_id: the ID for the dataset the metric is associated with. + :returns: the macro SQL syntax. + """ + # pylint: disable=import-outside-toplevel + from superset.daos.dataset import DatasetDAO + + if not dataset_id: + dataset_id = get_dataset_id_from_context(metric_key) + + dataset = DatasetDAO.find_by_id(dataset_id) + if not dataset: + raise DatasetNotFoundError(f"Dataset ID {dataset_id} not found.") + metrics: dict[str, str] = { + metric.metric_name: metric.expression for metric in dataset.metrics + } + dataset_name = dataset.table_name + if metric := metrics.get(metric_key): + return metric + raise SupersetTemplateException( + _( + "Metric ``%(metric_name)s`` not found in %(dataset_name)s.", + metric_name=metric_key, + dataset_name=dataset_name, + ) + ) diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index 1a66903da7..6cae6f6a14 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -267,6 +267,61 @@ class TestDatabaseModel(SupersetTestCase): db.session.delete(table) db.session.commit() + @patch("superset.views.utils.get_form_data") + def test_jinja_metric_macro(self, mock_form_data_context): + self.login(username="admin") + table = self.get_table(name="birth_names") + metric = SqlMetric( + metric_name="count_jinja_metric", expression="count(*)", table=table + ) + db.session.commit() + + base_query_obj = { + "granularity": None, + "from_dttm": None, + "to_dttm": None, + "columns": [], + "metrics": [ + { + "hasCustomLabel": True, + "label": "Metric using Jinja macro", + "expressionType": AdhocMetricExpressionType.SQL, + "sqlExpression": "{{ metric('count_jinja_metric') }}", + }, + { + "hasCustomLabel": True, + "label": "Same but different", + "expressionType": AdhocMetricExpressionType.SQL, + "sqlExpression": "{{ metric('count_jinja_metric', " + + str(table.id) + + ") }}", + }, + ], + "is_timeseries": False, + "filter": [], + "extras": {"time_grain_sqla": "P1D"}, + } + mock_form_data_context.return_value = [ + { + "url_params": { + "datasource_id": table.id, + } + }, + None, + ] + sqla_query = table.get_sqla_query(**base_query_obj) + query = table.database.compile_sqla_query(sqla_query.sqla_query) + + database = table.database + with database.get_sqla_engine_with_context() as engine: + quote = engine.dialect.identifier_preparer.quote_identifier + + for metric_label in {"metric using jinja macro", "same but different"}: + assert f"count(*) as {quote(metric_label)}" in query.lower() + + db.session.delete(metric) + db.session.commit() + def test_adhoc_metrics_and_calc_columns(self): base_query_obj = { "granularity": None, diff --git a/tests/unit_tests/jinja_context_test.py b/tests/unit_tests/jinja_context_test.py index d880151af5..15fe81aeb0 100644 --- a/tests/unit_tests/jinja_context_test.py +++ b/tests/unit_tests/jinja_context_test.py @@ -17,13 +17,384 @@ # pylint: disable=invalid-name, unused-argument import json +from typing import Any import pytest from pytest_mock import MockFixture from sqlalchemy.dialects import mysql +from sqlalchemy.dialects.postgresql import dialect +from superset import app from superset.commands.dataset.exceptions import DatasetNotFoundError -from superset.jinja_context import dataset_macro, WhereInMacro +from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn +from superset.exceptions import SupersetTemplateException +from superset.jinja_context import ( + dataset_macro, + ExtraCache, + metric_macro, + safe_proxy, + WhereInMacro, +) +from superset.models.core import Database +from superset.models.slice import Slice + + +def test_filter_values_adhoc_filters() -> None: + """ + Test the ``filter_values`` macro with ``adhoc_filters``. + """ + with app.test_request_context( + data={ + "form_data": json.dumps( + { + "adhoc_filters": [ + { + "clause": "WHERE", + "comparator": "foo", + "expressionType": "SIMPLE", + "operator": "in", + "subject": "name", + } + ], + } + ) + } + ): + cache = ExtraCache() + assert cache.filter_values("name") == ["foo"] + assert cache.applied_filters == ["name"] + + with app.test_request_context( + data={ + "form_data": json.dumps( + { + "adhoc_filters": [ + { + "clause": "WHERE", + "comparator": ["foo", "bar"], + "expressionType": "SIMPLE", + "operator": "in", + "subject": "name", + } + ], + } + ) + } + ): + cache = ExtraCache() + assert cache.filter_values("name") == ["foo", "bar"] + assert cache.applied_filters == ["name"] + + +def test_filter_values_extra_filters() -> None: + """ + Test the ``filter_values`` macro with ``extra_filters``. + """ + with app.test_request_context( + data={ + "form_data": json.dumps( + {"extra_filters": [{"col": "name", "op": "in", "val": "foo"}]} + ) + } + ): + cache = ExtraCache() + assert cache.filter_values("name") == ["foo"] + assert cache.applied_filters == ["name"] + + +def test_filter_values_default() -> None: + """ + Test the ``filter_values`` macro with a default value. + """ + cache = ExtraCache() + assert cache.filter_values("name", "foo") == ["foo"] + assert cache.removed_filters == [] + + +def test_filter_values_remove_not_present() -> None: + """ + Test the ``filter_values`` macro without a match and ``remove_filter`` set to True. + """ + cache = ExtraCache() + assert cache.filter_values("name", remove_filter=True) == [] + assert cache.removed_filters == [] + + +def test_filter_values_no_default() -> None: + """ + Test calling the ``filter_values`` macro without a match. + """ + cache = ExtraCache() + assert cache.filter_values("name") == [] + + +def test_get_filters_adhoc_filters() -> None: + """ + Test the ``get_filters`` macro. + """ + with app.test_request_context( + data={ + "form_data": json.dumps( + { + "adhoc_filters": [ + { + "clause": "WHERE", + "comparator": "foo", + "expressionType": "SIMPLE", + "operator": "in", + "subject": "name", + } + ], + } + ) + } + ): + cache = ExtraCache() + assert cache.get_filters("name") == [ + {"op": "IN", "col": "name", "val": ["foo"]} + ] + + assert cache.removed_filters == [] + assert cache.applied_filters == ["name"] + + with app.test_request_context( + data={ + "form_data": json.dumps( + { + "adhoc_filters": [ + { + "clause": "WHERE", + "comparator": ["foo", "bar"], + "expressionType": "SIMPLE", + "operator": "in", + "subject": "name", + } + ], + } + ) + } + ): + cache = ExtraCache() + assert cache.get_filters("name") == [ + {"op": "IN", "col": "name", "val": ["foo", "bar"]} + ] + assert cache.removed_filters == [] + + with app.test_request_context( + data={ + "form_data": json.dumps( + { + "adhoc_filters": [ + { + "clause": "WHERE", + "comparator": ["foo", "bar"], + "expressionType": "SIMPLE", + "operator": "in", + "subject": "name", + } + ], + } + ) + } + ): + cache = ExtraCache() + assert cache.get_filters("name", remove_filter=True) == [ + {"op": "IN", "col": "name", "val": ["foo", "bar"]} + ] + assert cache.removed_filters == ["name"] + assert cache.applied_filters == ["name"] + + +def test_get_filters_remove_not_present() -> None: + """ + Test the ``get_filters`` macro without a match and ``remove_filter`` set to True. + """ + cache = ExtraCache() + assert cache.get_filters("name", remove_filter=True) == [] + assert cache.removed_filters == [] + + +def test_url_param_query() -> None: + """ + Test the ``url_param`` macro. + """ + with app.test_request_context(query_string={"foo": "bar"}): + cache = ExtraCache() + assert cache.url_param("foo") == "bar" + + +def test_url_param_default() -> None: + """ + Test the ``url_param`` macro with a default value. + """ + with app.test_request_context(): + cache = ExtraCache() + assert cache.url_param("foo", "bar") == "bar" + + +def test_url_param_no_default() -> None: + """ + Test the ``url_param`` macro without a match. + """ + with app.test_request_context(): + cache = ExtraCache() + assert cache.url_param("foo") is None + + +def test_url_param_form_data() -> None: + """ + Test the ``url_param`` with ``url_params`` in ``form_data``. + """ + with app.test_request_context( + query_string={"form_data": json.dumps({"url_params": {"foo": "bar"}})} + ): + cache = ExtraCache() + assert cache.url_param("foo") == "bar" + + +def test_url_param_escaped_form_data() -> None: + """ + Test the ``url_param`` with ``url_params`` in ``form_data`` returning + an escaped value with a quote. + """ + with app.test_request_context( + query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})} + ): + cache = ExtraCache(dialect=dialect()) + assert cache.url_param("foo") == "O''Brien" + + +def test_url_param_escaped_default_form_data() -> None: + """ + Test the ``url_param`` with default value containing an escaped quote. + """ + with app.test_request_context( + query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})} + ): + cache = ExtraCache(dialect=dialect()) + assert cache.url_param("bar", "O'Malley") == "O''Malley" + + +def test_url_param_unescaped_form_data() -> None: + """ + Test the ``url_param`` with ``url_params`` in ``form_data`` returning + an un-escaped value with a quote. + """ + with app.test_request_context( + query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})} + ): + cache = ExtraCache(dialect=dialect()) + assert cache.url_param("foo", escape_result=False) == "O'Brien" + + +def test_url_param_unescaped_default_form_data() -> None: + """ + Test the ``url_param`` with default value containing an un-escaped quote. + """ + with app.test_request_context( + query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})} + ): + cache = ExtraCache(dialect=dialect()) + assert cache.url_param("bar", "O'Malley", escape_result=False) == "O'Malley" + + +def test_safe_proxy_primitive() -> None: + """ + Test the ``safe_proxy`` helper with a function returning a ``str``. + """ + + def func(input_: Any) -> Any: + return input_ + + assert safe_proxy(func, "foo") == "foo" + + +def test_safe_proxy_dict() -> None: + """ + Test the ``safe_proxy`` helper with a function returning a ``dict``. + """ + + def func(input_: Any) -> Any: + return input_ + + assert safe_proxy(func, {"foo": "bar"}) == {"foo": "bar"} + + +def test_safe_proxy_lambda() -> None: + """ + Test the ``safe_proxy`` helper with a function returning a ``lambda``. + Should raise ``SupersetTemplateException``. + """ + + def func(input_: Any) -> Any: + return input_ + + with pytest.raises(SupersetTemplateException): + safe_proxy(func, lambda: "bar") + + +def test_safe_proxy_nested_lambda() -> None: + """ + Test the ``safe_proxy`` helper with a function returning a ``dict`` + containing ``lambda`` value. Should raise ``SupersetTemplateException``. + """ + + def func(input_: Any) -> Any: + return input_ + + with pytest.raises(SupersetTemplateException): + safe_proxy(func, {"foo": lambda: "bar"}) + + +def test_user_macros(mocker: MockFixture): + """ + Test all user macros: + - ``current_user_id`` + - ``current_username`` + - ``current_user_email`` + """ + mock_g = mocker.patch("superset.utils.core.g") + mock_cache_key_wrapper = mocker.patch( + "superset.jinja_context.ExtraCache.cache_key_wrapper" + ) + mock_g.user.id = 1 + mock_g.user.username = "my_username" + mock_g.user.email = "my_email@test.com" + cache = ExtraCache() + assert cache.current_user_id() == 1 + assert cache.current_username() == "my_username" + assert cache.current_user_email() == "my_email@test.com" + assert mock_cache_key_wrapper.call_count == 3 + + +def test_user_macros_without_cache_key_inclusion(mocker: MockFixture): + """ + Test all user macros with ``add_to_cache_keys`` set to ``False``. + """ + mock_g = mocker.patch("superset.utils.core.g") + mock_cache_key_wrapper = mocker.patch( + "superset.jinja_context.ExtraCache.cache_key_wrapper" + ) + mock_g.user.id = 1 + mock_g.user.username = "my_username" + mock_g.user.email = "my_email@test.com" + cache = ExtraCache() + assert cache.current_user_id(False) == 1 + assert cache.current_username(False) == "my_username" + assert cache.current_user_email(False) == "my_email@test.com" + assert mock_cache_key_wrapper.call_count == 0 + + +def test_user_macros_without_user_info(mocker: MockFixture): + """ + Test all user macros when no user info is available. + """ + mock_g = mocker.patch("superset.utils.core.g") + mock_g.user = None + cache = ExtraCache() + assert cache.current_user_id() == None + assert cache.current_username() == None + assert cache.current_user_email() == None def test_where_in() -> None: @@ -43,10 +414,6 @@ def test_dataset_macro(mocker: MockFixture) -> None: """ Test the ``dataset_macro`` macro. """ - # pylint: disable=import-outside-toplevel - from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn - from superset.models.core import Database - mocker.patch( "superset.connectors.sqla.models.security_manager.get_guest_rls_filters", return_value=[], @@ -180,3 +547,298 @@ SELECT 1 -- end ) AS dataset_1""" ) + + +def test_metric_macro_with_dataset_id(mocker: MockFixture) -> None: + """ + Test the ``metric_macro`` when passing a dataset ID. + """ + mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") + DatasetDAO.find_by_id.return_value = SqlaTable( + table_name="test_dataset", + metrics=[ + SqlMetric(metric_name="count", expression="COUNT(*)"), + ], + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + schema="my_schema", + sql=None, + ) + assert metric_macro("count", 1) == "COUNT(*)" + mock_get_form_data.assert_not_called() + + +def test_metric_macro_with_dataset_id_invalid_key(mocker: MockFixture) -> None: + """ + Test the ``metric_macro`` when passing a dataset ID and an invalid key. + """ + mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") + DatasetDAO.find_by_id.return_value = SqlaTable( + table_name="test_dataset", + metrics=[ + SqlMetric(metric_name="count", expression="COUNT(*)"), + ], + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + schema="my_schema", + sql=None, + ) + with pytest.raises(SupersetTemplateException) as excinfo: + metric_macro("blah", 1) + assert str(excinfo.value) == "Metric ``blah`` not found in test_dataset." + mock_get_form_data.assert_not_called() + + +def test_metric_macro_invalid_dataset_id(mocker: MockFixture) -> None: + """ + Test the ``metric_macro`` when specifying a dataset that doesn't exist. + """ + mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") + DatasetDAO.find_by_id.return_value = None + with pytest.raises(DatasetNotFoundError) as excinfo: + metric_macro("macro_key", 100) + assert str(excinfo.value) == "Dataset ID 100 not found." + mock_get_form_data.assert_not_called() + + +def test_metric_macro_no_dataset_id_no_context(mocker: MockFixture) -> None: + """ + Test the ``metric_macro`` when not specifying a dataset ID and it's + not available in the context. + """ + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") + mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") + mock_get_form_data.return_value = [None, None] + with pytest.raises(SupersetTemplateException) as excinfo: + metric_macro("macro_key") + assert str(excinfo.value) == ( + "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." + ) + mock_get_form_data.assert_called_once() + DatasetDAO.find_by_id.assert_not_called() + + +def test_metric_macro_no_dataset_id_with_context_missing_info( + mocker: MockFixture, +) -> None: + """ + Test the ``metric_macro`` when not specifying a dataset ID and request + has context but no dataset/chart ID. + """ + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") + mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") + mock_get_form_data.return_value = [ + { + "url_params": {}, + }, + None, + ] + with pytest.raises(SupersetTemplateException) as excinfo: + metric_macro("macro_key") + assert str(excinfo.value) == ( + "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." + ) + mock_get_form_data.assert_called_once() + DatasetDAO.find_by_id.assert_not_called() + + +def test_metric_macro_no_dataset_id_with_context_datasource_id( + mocker: MockFixture, +) -> None: + """ + Test the ``metric_macro`` when not specifying a dataset ID and it's + available in the context (url_params.datasource_id). + """ + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") + DatasetDAO.find_by_id.return_value = SqlaTable( + table_name="test_dataset", + metrics=[ + SqlMetric(metric_name="macro_key", expression="COUNT(*)"), + ], + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + schema="my_schema", + sql=None, + ) + mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") + mock_get_form_data.return_value = [ + { + "url_params": { + "datasource_id": 1, + } + }, + None, + ] + assert metric_macro("macro_key") == "COUNT(*)" + mock_get_form_data.assert_called_once() + DatasetDAO.find_by_id.assert_called_once_with(1) + + +def test_metric_macro_no_dataset_id_with_context_datasource_id_none( + mocker: MockFixture, +) -> None: + """ + Test the ``metric_macro`` when not specifying a dataset ID and it's + set to None in the context (url_params.datasource_id). + """ + ChartDAO = mocker.patch("superset.daos.chart.ChartDAO") + ChartDAO.find_by_id.return_value = None + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") + mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") + mock_get_form_data.return_value = [ + { + "url_params": { + "datasource_id": None, + } + }, + None, + ] + + with pytest.raises(SupersetTemplateException) as excinfo: + metric_macro("macro_key") + assert str(excinfo.value) == ( + "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." + ) + mock_get_form_data.assert_called_once() + DatasetDAO.find_by_id.assert_not_called() + + +def test_metric_macro_no_dataset_id_with_context_chart_id(mocker: MockFixture) -> None: + """ + Test the ``metric_macro`` when not specifying a dataset ID and context + includes an existing chart ID (url_params.slice_id). + """ + ChartDAO = mocker.patch("superset.daos.chart.ChartDAO") + ChartDAO.find_by_id.return_value = Slice( + datasource_id=1, + ) + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") + DatasetDAO.find_by_id.return_value = SqlaTable( + table_name="test_dataset", + metrics=[ + SqlMetric(metric_name="macro_key", expression="COUNT(*)"), + ], + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + schema="my_schema", + sql=None, + ) + mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") + mock_get_form_data.return_value = [ + { + "slice_id": 1, + }, + None, + ] + assert metric_macro("macro_key") == "COUNT(*)" + mock_get_form_data.assert_called_once() + DatasetDAO.find_by_id.assert_called_once_with(1) + + +def test_metric_macro_no_dataset_id_with_context_slice_id_none( + mocker: MockFixture, +) -> None: + """ + Test the ``metric_macro`` when not specifying a dataset ID and context + includes slice_id set to None (url_params.slice_id). + """ + ChartDAO = mocker.patch("superset.daos.chart.ChartDAO") + ChartDAO.find_by_id.return_value = None + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") + mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") + mock_get_form_data.return_value = [ + { + "slice_id": None, + }, + None, + ] + + with pytest.raises(SupersetTemplateException) as excinfo: + metric_macro("macro_key") + assert str(excinfo.value) == ( + "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." + ) + mock_get_form_data.assert_called_once() + DatasetDAO.find_by_id.assert_not_called() + + +def test_metric_macro_no_dataset_id_with_context_chart(mocker: MockFixture) -> None: + """ + Test the ``metric_macro`` when not specifying a dataset ID and context + includes an existing chart (get_form_data()[1]). + """ + ChartDAO = mocker.patch("superset.daos.chart.ChartDAO") + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") + DatasetDAO.find_by_id.return_value = SqlaTable( + table_name="test_dataset", + metrics=[ + SqlMetric(metric_name="macro_key", expression="COUNT(*)"), + ], + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + schema="my_schema", + sql=None, + ) + mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") + mock_get_form_data.return_value = [ + { + "slice_id": 1, + }, + Slice(datasource_id=1), + ] + assert metric_macro("macro_key") == "COUNT(*)" + mock_get_form_data.assert_called_once() + DatasetDAO.find_by_id.assert_called_once_with(1) + ChartDAO.find_by_id.assert_not_called() + + +def test_metric_macro_no_dataset_id_with_context_deleted_chart( + mocker: MockFixture, +) -> None: + """ + Test the ``metric_macro`` when not specifying a dataset ID and context + includes a deleted chart ID. + """ + ChartDAO = mocker.patch("superset.daos.chart.ChartDAO") + ChartDAO.find_by_id.return_value = None + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") + mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") + mock_get_form_data.return_value = [ + { + "slice_id": 1, + }, + None, + ] + + with pytest.raises(SupersetTemplateException) as excinfo: + metric_macro("macro_key") + assert str(excinfo.value) == ( + "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." + ) + mock_get_form_data.assert_called_once() + DatasetDAO.find_by_id.assert_not_called() + + +def test_metric_macro_no_dataset_id_with_context_chart_no_datasource_id( + mocker: MockFixture, +) -> None: + """ + Test the ``metric_macro`` when not specifying a dataset ID and context + includes an existing chart (get_form_data()[1]) with no dataset ID. + """ + ChartDAO = mocker.patch("superset.daos.chart.ChartDAO") + ChartDAO.find_by_id.return_value = None + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") + mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") + mock_get_form_data.return_value = [ + {}, + Slice( + datasource_id=None, + ), + ] + + with pytest.raises(SupersetTemplateException) as excinfo: + metric_macro("macro_key") + assert str(excinfo.value) == ( + "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." + ) + mock_get_form_data.assert_called_once() + DatasetDAO.find_by_id.assert_not_called() diff --git a/tests/unit_tests/test_jinja_context.py b/tests/unit_tests/test_jinja_context.py deleted file mode 100644 index 70dcbe56be..0000000000 --- a/tests/unit_tests/test_jinja_context.py +++ /dev/null @@ -1,305 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import json -from typing import Any -from unittest.mock import patch - -import pytest -from sqlalchemy.dialects.postgresql import dialect - -from superset import app -from superset.exceptions import SupersetTemplateException -from superset.jinja_context import ExtraCache, safe_proxy - - -def test_filter_values_default() -> None: - cache = ExtraCache() - assert cache.filter_values("name", "foo") == ["foo"] - assert cache.removed_filters == [] - - -def test_filter_values_remove_not_present() -> None: - cache = ExtraCache() - assert cache.filter_values("name", remove_filter=True) == [] - assert cache.removed_filters == [] - - -def test_get_filters_remove_not_present() -> None: - cache = ExtraCache() - assert cache.get_filters("name", remove_filter=True) == [] - assert cache.removed_filters == [] - - -def test_filter_values_no_default() -> None: - cache = ExtraCache() - assert cache.filter_values("name") == [] - - -def test_filter_values_adhoc_filters() -> None: - with app.test_request_context( - data={ - "form_data": json.dumps( - { - "adhoc_filters": [ - { - "clause": "WHERE", - "comparator": "foo", - "expressionType": "SIMPLE", - "operator": "in", - "subject": "name", - } - ], - } - ) - } - ): - cache = ExtraCache() - assert cache.filter_values("name") == ["foo"] - assert cache.applied_filters == ["name"] - - with app.test_request_context( - data={ - "form_data": json.dumps( - { - "adhoc_filters": [ - { - "clause": "WHERE", - "comparator": ["foo", "bar"], - "expressionType": "SIMPLE", - "operator": "in", - "subject": "name", - } - ], - } - ) - } - ): - cache = ExtraCache() - assert cache.filter_values("name") == ["foo", "bar"] - assert cache.applied_filters == ["name"] - - -def test_get_filters_adhoc_filters() -> None: - with app.test_request_context( - data={ - "form_data": json.dumps( - { - "adhoc_filters": [ - { - "clause": "WHERE", - "comparator": "foo", - "expressionType": "SIMPLE", - "operator": "in", - "subject": "name", - } - ], - } - ) - } - ): - cache = ExtraCache() - assert cache.get_filters("name") == [ - {"op": "IN", "col": "name", "val": ["foo"]} - ] - - assert cache.removed_filters == [] - assert cache.applied_filters == ["name"] - - with app.test_request_context( - data={ - "form_data": json.dumps( - { - "adhoc_filters": [ - { - "clause": "WHERE", - "comparator": ["foo", "bar"], - "expressionType": "SIMPLE", - "operator": "in", - "subject": "name", - } - ], - } - ) - } - ): - cache = ExtraCache() - assert cache.get_filters("name") == [ - {"op": "IN", "col": "name", "val": ["foo", "bar"]} - ] - assert cache.removed_filters == [] - - with app.test_request_context( - data={ - "form_data": json.dumps( - { - "adhoc_filters": [ - { - "clause": "WHERE", - "comparator": ["foo", "bar"], - "expressionType": "SIMPLE", - "operator": "in", - "subject": "name", - } - ], - } - ) - } - ): - cache = ExtraCache() - assert cache.get_filters("name", remove_filter=True) == [ - {"op": "IN", "col": "name", "val": ["foo", "bar"]} - ] - assert cache.removed_filters == ["name"] - assert cache.applied_filters == ["name"] - - -def test_filter_values_extra_filters() -> None: - with app.test_request_context( - data={ - "form_data": json.dumps( - {"extra_filters": [{"col": "name", "op": "in", "val": "foo"}]} - ) - } - ): - cache = ExtraCache() - assert cache.filter_values("name") == ["foo"] - assert cache.applied_filters == ["name"] - - -def test_url_param_default() -> None: - with app.test_request_context(): - cache = ExtraCache() - assert cache.url_param("foo", "bar") == "bar" - - -def test_url_param_no_default() -> None: - with app.test_request_context(): - cache = ExtraCache() - assert cache.url_param("foo") is None - - -def test_url_param_query() -> None: - with app.test_request_context(query_string={"foo": "bar"}): - cache = ExtraCache() - assert cache.url_param("foo") == "bar" - - -def test_url_param_form_data() -> None: - with app.test_request_context( - query_string={"form_data": json.dumps({"url_params": {"foo": "bar"}})} - ): - cache = ExtraCache() - assert cache.url_param("foo") == "bar" - - -def test_url_param_escaped_form_data() -> None: - with app.test_request_context( - query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})} - ): - cache = ExtraCache(dialect=dialect()) - assert cache.url_param("foo") == "O''Brien" - - -def test_url_param_escaped_default_form_data() -> None: - with app.test_request_context( - query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})} - ): - cache = ExtraCache(dialect=dialect()) - assert cache.url_param("bar", "O'Malley") == "O''Malley" - - -def test_url_param_unescaped_form_data() -> None: - with app.test_request_context( - query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})} - ): - cache = ExtraCache(dialect=dialect()) - assert cache.url_param("foo", escape_result=False) == "O'Brien" - - -def test_url_param_unescaped_default_form_data() -> None: - with app.test_request_context( - query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})} - ): - cache = ExtraCache(dialect=dialect()) - assert cache.url_param("bar", "O'Malley", escape_result=False) == "O'Malley" - - -def test_safe_proxy_primitive() -> None: - def func(input_: Any) -> Any: - return input_ - - assert safe_proxy(func, "foo") == "foo" - - -def test_safe_proxy_dict() -> None: - def func(input_: Any) -> Any: - return input_ - - assert safe_proxy(func, {"foo": "bar"}) == {"foo": "bar"} - - -def test_safe_proxy_lambda() -> None: - def func(input_: Any) -> Any: - return input_ - - with pytest.raises(SupersetTemplateException): - safe_proxy(func, lambda: "bar") - - -def test_safe_proxy_nested_lambda() -> None: - def func(input_: Any) -> Any: - return input_ - - with pytest.raises(SupersetTemplateException): - safe_proxy(func, {"foo": lambda: "bar"}) - - -@patch("superset.jinja_context.ExtraCache.cache_key_wrapper") -@patch("superset.utils.core.g") -def test_user_macros(mock_flask_user, mock_cache_key_wrapper): - mock_flask_user.user.id = 1 - mock_flask_user.user.username = "my_username" - mock_flask_user.user.email = "my_email@test.com" - cache = ExtraCache() - assert cache.current_user_id() == 1 - assert cache.current_username() == "my_username" - assert cache.current_user_email() == "my_email@test.com" - assert mock_cache_key_wrapper.call_count == 3 - - -@patch("superset.jinja_context.ExtraCache.cache_key_wrapper") -@patch("superset.utils.core.g") -def test_user_macros_without_cache_key_inclusion( - mock_flask_user, mock_cache_key_wrapper -): - mock_flask_user.user.id = 1 - mock_flask_user.user.username = "my_username" - mock_flask_user.user.email = "my_email@test.com" - cache = ExtraCache() - assert cache.current_user_id(False) == 1 - assert cache.current_username(False) == "my_username" - assert cache.current_user_email(False) == "my_email@test.com" - assert mock_cache_key_wrapper.call_count == 0 - - -@patch("superset.utils.core.g") -def test_user_macros_without_user_info(mock_flask_user): - mock_flask_user.user = None - cache = ExtraCache() - assert cache.current_user_id() == None - assert cache.current_username() == None - assert cache.current_user_email() == None