# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. # isort:skip_file """Unit tests for Superset""" import json import unittest import copy from datetime import datetime from io import BytesIO from typing import Any, Dict, Optional, List from unittest import mock from zipfile import ZipFile from flask import Response from tests.integration_tests.conftest import with_feature_flags from superset.models.sql_lab import Query from tests.integration_tests.base_tests import ( SupersetTestCase, test_client, ) from tests.integration_tests.annotation_layers.fixtures import create_annotation_layers from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, load_birth_names_data, ) from tests.integration_tests.test_app import app from tests.integration_tests.fixtures.energy_dashboard import ( load_energy_table_with_slice, load_energy_table_data, ) import pytest from superset.models.slice import Slice from superset.charts.data.commands.get_data_command import ChartDataCommand from superset.connectors.sqla.models import TableColumn, SqlaTable from superset.errors import SupersetErrorType from superset.extensions import async_query_manager, db from superset.models.annotations import AnnotationLayer from superset.models.slice import Slice from superset.superset_typing import AdhocColumn from superset.utils.core import ( AnnotationType, get_example_default_schema, AdhocMetricExpressionType, ExtraFiltersReasonType, ) from superset.utils.database import get_example_database, get_main_database from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from tests.common.query_context_generator import ANNOTATION_LAYERS from tests.integration_tests.fixtures.query_context import get_query_context CHART_DATA_URI = "api/v1/chart/data" CHARTS_FIXTURE_COUNT = 10 ADHOC_COLUMN_FIXTURE: AdhocColumn = { "hasCustomLabel": True, "label": "male_or_female", "sqlExpression": "case when gender = 'boy' then 'male' " "when gender = 'girl' then 'female' else 'other' end", } INCOMPATIBLE_ADHOC_COLUMN_FIXTURE: AdhocColumn = { "hasCustomLabel": True, "label": "exciting_or_boring", "sqlExpression": "case when genre = 'Action' then 'Exciting' else 'Boring' end", } class BaseTestChartDataApi(SupersetTestCase): query_context_payload_template = None def setUp(self) -> None: self.login("admin") if self.query_context_payload_template is None: BaseTestChartDataApi.query_context_payload_template = get_query_context( "birth_names" ) self.query_context_payload = copy.deepcopy(self.query_context_payload_template) def get_expected_row_count(self, client_id: str) -> int: start_date = datetime.now() start_date = start_date.replace( year=start_date.year - 100, hour=0, minute=0, second=0 ) quoted_table_name = self.quote_name("birth_names") sql = f""" SELECT COUNT(*) AS rows_count FROM ( SELECT name AS name, SUM(num) AS sum__num FROM {quoted_table_name} WHERE ds >= '{start_date.strftime("%Y-%m-%d %H:%M:%S")}' AND gender = 'boy' GROUP BY name ORDER BY sum__num DESC LIMIT 100) AS inner__query """ resp = self.run_sql(sql, client_id, raise_on_error=True) db.session.query(Query).delete() db.session.commit() return resp["data"][0]["rows_count"] def quote_name(self, name: str): if get_main_database().backend in {"presto", "hive"}: return get_example_database().inspector.engine.dialect.identifier_preparer.quote_identifier( name ) return name @pytest.mark.chart_data_flow class TestPostChartDataApi(BaseTestChartDataApi): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_valid_qc__data_is_returned(self): # arrange expected_row_count = self.get_expected_row_count("client_id_1") # act rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") # assert assert rv.status_code == 200 self.assert_row_count(rv, expected_row_count) @staticmethod def assert_row_count(rv: Response, expected_row_count: int): assert rv.json["result"][0]["rowcount"] == expected_row_count @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch( "superset.common.query_context_factory.config", {**app.config, "ROW_LIMIT": 7}, ) def test_without_row_limit__row_count_as_default_row_limit(self): # arrange expected_row_count = 7 del self.query_context_payload["queries"][0]["row_limit"] # act rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") # assert self.assert_row_count(rv, expected_row_count) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch( "superset.common.query_context_factory.config", {**app.config, "SAMPLES_ROW_LIMIT": 5}, ) def test_as_samples_without_row_limit__row_count_as_default_samples_row_limit(self): # arrange expected_row_count = 5 app.config["SAMPLES_ROW_LIMIT"] = expected_row_count self.query_context_payload["result_type"] = ChartDataResultType.SAMPLES del self.query_context_payload["queries"][0]["row_limit"] # act rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") # assert self.assert_row_count(rv, expected_row_count) assert "GROUP BY" not in rv.json["result"][0]["query"] @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch( "superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 10}, ) def test_with_row_limit_bigger_then_sql_max_row__rowcount_as_sql_max_row(self): # arrange expected_row_count = 10 self.query_context_payload["queries"][0]["row_limit"] = 10000000 # act rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") # assert self.assert_row_count(rv, expected_row_count) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch( "superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 5}, ) def test_as_samples_with_row_limit_bigger_then_sql_max_row_rowcount_as_sql_max_row( self, ): expected_row_count = app.config["SQL_MAX_ROW"] self.query_context_payload["result_type"] = ChartDataResultType.SAMPLES self.query_context_payload["queries"][0]["row_limit"] = 10000000 rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") # assert self.assert_row_count(rv, expected_row_count) assert "GROUP BY" not in rv.json["result"][0]["query"] @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch( "superset.common.query_actions.config", {**app.config, "SAMPLES_ROW_LIMIT": 5, "SQL_MAX_ROW": 15}, ) def test_with_row_limit_as_samples__rowcount_as_row_limit(self): expected_row_count = 10 self.query_context_payload["result_type"] = ChartDataResultType.SAMPLES self.query_context_payload["queries"][0]["row_limit"] = expected_row_count rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") # assert self.assert_row_count(rv, expected_row_count) assert "GROUP BY" not in rv.json["result"][0]["query"] def test_with_incorrect_result_type__400(self): self.query_context_payload["result_type"] = "qwerty" rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 400 def test_with_incorrect_result_format__400(self): self.query_context_payload["result_format"] = "qwerty" rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 400 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_invalid_payload__400(self): invalid_query_context = {"form_data": "NOT VALID JSON"} rv = self.client.post( CHART_DATA_URI, data=invalid_query_context, content_type="multipart/form-data", ) assert rv.status_code == 400 assert rv.json["message"] == "Request is not JSON" @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_query_result_type__200(self): self.query_context_payload["result_type"] = ChartDataResultType.QUERY rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 200 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_empty_request_with_csv_result_format(self): """ Chart data API: Test empty chart data with CSV result format """ self.query_context_payload["result_format"] = "csv" self.query_context_payload["queries"] = [] rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 400 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_empty_request_with_excel_result_format(self): """ Chart data API: Test empty chart data with Excel result format """ self.query_context_payload["result_format"] = "xlsx" self.query_context_payload["queries"] = [] rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 400 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_csv_result_format(self): """ Chart data API: Test chart data with CSV result format """ self.query_context_payload["result_format"] = "csv" rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 200 assert rv.mimetype == "text/csv" @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_excel_result_format(self): """ Chart data API: Test chart data with Excel result format """ self.query_context_payload["result_format"] = "xlsx" mimetype = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 200 assert rv.mimetype == mimetype @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_multi_query_csv_result_format(self): """ Chart data API: Test chart data with multi-query CSV result format """ self.query_context_payload["result_format"] = "csv" self.query_context_payload["queries"].append( self.query_context_payload["queries"][0] ) rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 200 assert rv.mimetype == "application/zip" zipfile = ZipFile(BytesIO(rv.data), "r") assert zipfile.namelist() == ["query_1.csv", "query_2.csv"] @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_multi_query_excel_result_format(self): """ Chart data API: Test chart data with multi-query Excel result format """ self.query_context_payload["result_format"] = "xlsx" self.query_context_payload["queries"].append( self.query_context_payload["queries"][0] ) rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 200 assert rv.mimetype == "application/zip" zipfile = ZipFile(BytesIO(rv.data), "r") assert zipfile.namelist() == ["query_1.xlsx", "query_2.xlsx"] @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_csv_result_format_when_actor_not_permitted_for_csv__403(self): """ Chart data API: Test chart data with CSV result format """ self.logout() self.login(username="gamma_no_csv") self.query_context_payload["result_format"] = "csv" rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 403 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_excel_result_format_when_actor_not_permitted_for_excel__403(self): """ Chart data API: Test chart data with Excel result format """ self.logout() self.login(username="gamma_no_csv") self.query_context_payload["result_format"] = "xlsx" rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 403 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_row_limit_and_offset__row_limit_and_offset_were_applied(self): """ Chart data API: Test chart data query with limit and offset """ self.query_context_payload["queries"][0]["row_limit"] = 5 self.query_context_payload["queries"][0]["row_offset"] = 0 self.query_context_payload["queries"][0]["orderby"] = [["name", True]] rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") self.assert_row_count(rv, 5) result = rv.json["result"][0] # TODO: fix offset for presto DB if get_example_database().backend == "presto": return # ensure that offset works properly offset = 2 expected_name = result["data"][offset]["name"] self.query_context_payload["queries"][0]["row_offset"] = offset rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") result = rv.json["result"][0] assert result["rowcount"] == 5 assert result["data"][0]["name"] == expected_name @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_applied_time_extras(self): """ Chart data API: Test chart data query with applied time extras """ self.query_context_payload["queries"][0]["applied_time_extras"] = { "__time_range": "100 years ago : now", "__time_origin": "now", } rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") self.assertEqual(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) self.assertEqual( data["result"][0]["applied_filters"], [ {"column": "gender"}, {"column": "num"}, {"column": "name"}, {"column": "__time_range"}, ], ) expected_row_count = self.get_expected_row_count("client_id_2") self.assertEqual(data["result"][0]["rowcount"], expected_row_count) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_in_op_filter__data_is_returned(self): """ Chart data API: Ensure mixed case filter operator generates valid result """ expected_row_count = 10 self.query_context_payload["queries"][0]["filters"][0]["op"] = "In" self.query_context_payload["queries"][0]["row_limit"] = expected_row_count rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") self.assert_row_count(rv, expected_row_count) @unittest.skip("Failing due to timezone difference") @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_dttm_filter(self): """ Chart data API: Ensure temporal column filter converts epoch to dttm expression """ table = self.get_birth_names_dataset() if table.database.backend == "presto": # TODO: date handling on Presto not fully in line with other engine specs return self.query_context_payload["queries"][0]["time_range"] = "" dttm = self.get_dttm() ms_epoch = dttm.timestamp() * 1000 self.query_context_payload["queries"][0]["filters"][0] = { "col": "ds", "op": "!=", "val": ms_epoch, } rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") response_payload = json.loads(rv.data.decode("utf-8")) result = response_payload["result"][0] # assert that unconverted timestamp is not present in query assert str(ms_epoch) not in result["query"] # assert that converted timestamp is present in query where supported dttm_col: Optional[TableColumn] = None for col in table.columns: if col.column_name == table.main_dttm_col: dttm_col = col if dttm_col: dttm_expression = table.database.db_engine_spec.convert_dttm( dttm_col.type, dttm, ) self.assertIn(dttm_expression, result["query"]) else: raise Exception("ds column not found") def test_chart_data_prophet(self): """ Chart data API: Ensure prophet post transformation works """ pytest.importorskip("prophet") time_grain = "P1Y" self.query_context_payload["queries"][0]["is_timeseries"] = True self.query_context_payload["queries"][0]["groupby"] = [] self.query_context_payload["queries"][0]["extras"] = { "time_grain_sqla": time_grain } self.query_context_payload["queries"][0]["granularity"] = "ds" self.query_context_payload["queries"][0]["post_processing"] = [ { "operation": "prophet", "options": { "time_grain": time_grain, "periods": 3, "confidence_interval": 0.9, }, } ] rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") self.assertEqual(rv.status_code, 200) response_payload = json.loads(rv.data.decode("utf-8")) result = response_payload["result"][0] row = result["data"][0] self.assertIn("__timestamp", row) self.assertIn("sum__num", row) self.assertIn("sum__num__yhat", row) self.assertIn("sum__num__yhat_upper", row) self.assertIn("sum__num__yhat_lower", row) self.assertEqual(result["rowcount"], 47) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_query_result_type_and_non_existent_filter__filter_omitted(self): self.query_context_payload["queries"][0]["filters"] = [ {"col": "non_existent_filter", "op": "==", "val": "foo"}, ] self.query_context_payload["result_type"] = ChartDataResultType.QUERY rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 200 assert "non_existent_filter" not in rv.json["result"][0]["query"] @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_filter_suppose_to_return_empty_data__no_data_returned(self): self.query_context_payload["queries"][0]["filters"] = [ {"col": "gender", "op": "==", "val": "foo"} ] rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 200 assert rv.json["result"][0]["data"] == [] self.assert_row_count(rv, 0) def test_with_invalid_where_parameter__400(self): self.query_context_payload["queries"][0]["filters"] = [] # erroneous WHERE-clause self.query_context_payload["queries"][0]["extras"]["where"] = "(gender abc def)" rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 400 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_invalid_where_parameter_closing_unclosed__400(self): self.query_context_payload["queries"][0]["filters"] = [] self.query_context_payload["queries"][0]["extras"][ "where" ] = "state = 'CA') OR (state = 'NY'" rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 400 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_where_parameter_including_comment___200(self): self.query_context_payload["queries"][0]["filters"] = [] self.query_context_payload["queries"][0]["extras"]["where"] = "1 = 1 -- abc" rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 200 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_orderby_parameter_with_second_query__400(self): self.query_context_payload["queries"][0]["filters"] = [] self.query_context_payload["queries"][0]["orderby"] = [ [ { "expressionType": "SQL", "sqlExpression": "sum__num; select 1, 1", }, True, ], ] rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 400 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_invalid_having_parameter_closing_and_comment__400(self): self.query_context_payload["queries"][0]["filters"] = [] self.query_context_payload["queries"][0]["extras"][ "having" ] = "COUNT(1) = 0) UNION ALL SELECT 'abc', 1--comment" rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 400 def test_with_invalid_datasource__400(self): self.query_context_payload["datasource"] = "abc" rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 400 def test_with_not_permitted_actor__403(self): """ Chart data API: Test chart data query not allowed """ self.logout() self.login(username="gamma") rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 403 assert ( rv.json["errors"][0]["error_type"] == SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR ) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_when_where_parameter_is_template_and_query_result_type__query_is_templated( self, ): self.query_context_payload["result_type"] = ChartDataResultType.QUERY self.query_context_payload["queries"][0]["filters"] = [ {"col": "gender", "op": "==", "val": "boy"} ] self.query_context_payload["queries"][0]["extras"][ "where" ] = "('boy' = '{{ filter_values('gender', 'xyz' )[0] }}')" rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") result = rv.json["result"][0]["query"] if get_example_database().backend != "presto": assert "('boy' = 'boy')" in result @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_async(self): self.logout() async_query_manager.init_app(app) self.login("admin") rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") self.assertEqual(rv.status_code, 202) data = json.loads(rv.data.decode("utf-8")) keys = list(data.keys()) self.assertCountEqual( keys, ["channel_id", "job_id", "user_id", "status", "errors", "result_url"] ) @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_async_cached_sync_response(self): """ Chart data API: Test chart data query returns results synchronously when results are already cached. """ async_query_manager.init_app(app) class QueryContext: result_format = ChartDataResultFormat.JSON result_type = ChartDataResultType.FULL cmd_run_val = { "query_context": QueryContext(), "queries": [{"query": "select * from foo"}], } with mock.patch.object( ChartDataCommand, "run", return_value=cmd_run_val ) as patched_run: self.query_context_payload["result_type"] = ChartDataResultType.FULL rv = self.post_assert_metric( CHART_DATA_URI, self.query_context_payload, "data" ) self.assertEqual(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) patched_run.assert_called_once_with(force_cached=True) self.assertEqual(data, {"result": [{"query": "select * from foo"}]}) @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_async_results_type(self): """ Chart data API: Test chart data query non-JSON format (async) """ async_query_manager.init_app(app) self.query_context_payload["result_type"] = "results" rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") self.assertEqual(rv.status_code, 200) @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_async_invalid_token(self): """ Chart data API: Test chart data query (async) """ async_query_manager.init_app(app) test_client.set_cookie( "localhost", app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"], "foo" ) rv = test_client.post(CHART_DATA_URI, json=self.query_context_payload) self.assertEqual(rv.status_code, 401) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_rowcount(self): """ Chart data API: Query total rows """ expected_row_count = self.get_expected_row_count("client_id_4") self.query_context_payload["queries"][0]["is_rowcount"] = True self.query_context_payload["queries"][0]["groupby"] = ["name"] rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.json["result"][0]["data"][0]["rowcount"] == expected_row_count @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_timegrains_and_columns_result_types(self): """ Chart data API: Query timegrains and columns """ self.query_context_payload["queries"] = [ {"result_type": ChartDataResultType.TIMEGRAINS}, {"result_type": ChartDataResultType.COLUMNS}, ] result = self.post_assert_metric( CHART_DATA_URI, self.query_context_payload, "data" ).json["result"] timegrain_data_keys = result[0]["data"][0].keys() column_data_keys = result[1]["data"][0].keys() assert list(timegrain_data_keys) == [ "name", "function", "duration", ] assert list(column_data_keys) == [ "column_name", "verbose_name", "dtype", ] @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_series_limit(self): SERIES_LIMIT = 5 self.query_context_payload["queries"][0]["columns"] = ["state", "name"] self.query_context_payload["queries"][0]["series_columns"] = ["name"] self.query_context_payload["queries"][0]["series_limit"] = SERIES_LIMIT rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") data = rv.json["result"][0]["data"] unique_names = set(row["name"] for row in data) self.maxDiff = None self.assertEqual(len(unique_names), SERIES_LIMIT) self.assertEqual( set(column for column in data[0].keys()), {"state", "name", "sum__num"} ) @pytest.mark.usefixtures( "create_annotation_layers", "load_birth_names_dashboard_with_slices" ) def test_with_annotations_layers__annotations_data_returned(self): """ Chart data API: Test chart data query """ annotation_layers = [] self.query_context_payload["queries"][0][ "annotation_layers" ] = annotation_layers # formula annotation_layers.append(ANNOTATION_LAYERS[AnnotationType.FORMULA]) # interval interval_layer = ( db.session.query(AnnotationLayer) .filter(AnnotationLayer.name == "name1") .one() ) interval = ANNOTATION_LAYERS[AnnotationType.INTERVAL] interval["value"] = interval_layer.id annotation_layers.append(interval) # event event_layer = ( db.session.query(AnnotationLayer) .filter(AnnotationLayer.name == "name2") .one() ) event = ANNOTATION_LAYERS[AnnotationType.EVENT] event["value"] = event_layer.id annotation_layers.append(event) rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") self.assertEqual(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) # response should only contain interval and event data, not formula self.assertEqual(len(data["result"][0]["annotation_data"]), 2) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_virtual_table_with_colons_as_datasource(self): """ Chart data API: test query with literal colon characters in query, metrics, where clause and filters """ owner = self.get_user("admin") table = SqlaTable( table_name="virtual_table_1", schema=get_example_default_schema(), owners=[owner], database=get_example_database(), sql="select ':foo' as foo, ':bar:' as bar, state, num from birth_names", ) db.session.add(table) db.session.commit() table.fetch_metadata() request_payload = self.query_context_payload request_payload["datasource"] = { "type": "table", "id": table.id, } request_payload["queries"][0]["columns"] = ["foo", "bar", "state"] request_payload["queries"][0]["where"] = "':abc' != ':xyz:qwerty'" request_payload["queries"][0]["orderby"] = None request_payload["queries"][0]["metrics"] = [ { "expressionType": AdhocMetricExpressionType.SQL, "sqlExpression": "sum(case when state = ':asdf' then 0 else 1 end)", "label": "count", } ] request_payload["queries"][0]["filters"] = [ { "col": "foo", "op": "!=", "val": ":qwerty:", } ] rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") db.session.delete(table) db.session.commit() assert rv.status_code == 200 result = rv.json["result"][0] data = result["data"] assert {col for col in data[0].keys()} == {"foo", "bar", "state", "count"} # make sure results and query parameters are unescaped assert {row["foo"] for row in data} == {":foo"} assert {row["bar"] for row in data} == {":bar:"} assert "':asdf'" in result["query"] assert "':xyz:qwerty'" in result["query"] assert "':qwerty:'" in result["query"] @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_table_columns_without_metrics(self): request_payload = self.query_context_payload request_payload["queries"][0]["columns"] = ["name", "gender"] request_payload["queries"][0]["metrics"] = None request_payload["queries"][0]["orderby"] = [] rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") result = rv.json["result"][0] assert rv.status_code == 200 assert "name" in result["colnames"] assert "gender" in result["colnames"] assert "name" in result["query"] assert "gender" in result["query"] assert list(result["data"][0].keys()) == ["name", "gender"] @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_adhoc_column_without_metrics(self): request_payload = self.query_context_payload request_payload["queries"][0]["columns"] = [ "name", { "label": "num divide by 10", "sqlExpression": "num/10", "expressionType": "SQL", }, ] request_payload["queries"][0]["metrics"] = None request_payload["queries"][0]["orderby"] = [] rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") result = rv.json["result"][0] assert rv.status_code == 200 assert "num divide by 10" in result["colnames"] assert "name" in result["colnames"] assert "num divide by 10" in result["query"] assert "name" in result["query"] assert list(result["data"][0].keys()) == ["name", "num divide by 10"] @pytest.mark.chart_data_flow class TestGetChartDataApi(BaseTestChartDataApi): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_get_data_when_query_context_is_null(self): """ Chart data API: Test GET endpoint when query context is null """ chart = db.session.query(Slice).filter_by(slice_name="Genders").one() rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data") data = json.loads(rv.data.decode("utf-8")) assert data == { "message": "Chart has no query context saved. Please save the chart again." } @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_get(self): """ Chart data API: Test GET endpoint """ chart = db.session.query(Slice).filter_by(slice_name="Genders").one() chart.query_context = json.dumps( { "datasource": {"id": chart.table.id, "type": "table"}, "force": False, "queries": [ { "time_range": "1900-01-01T00:00:00 : 2000-01-01T00:00:00", "granularity": "ds", "filters": [], "extras": { "having": "", "where": "", }, "applied_time_extras": {}, "columns": ["gender"], "metrics": ["sum__num"], "orderby": [["sum__num", False]], "annotation_layers": [], "row_limit": 50000, "timeseries_limit": 0, "order_desc": True, "url_params": {}, "custom_params": {}, "custom_form_data": {}, } ], "result_format": "json", "result_type": "full", } ) rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data") assert rv.mimetype == "application/json" data = json.loads(rv.data.decode("utf-8")) assert data["result"][0]["status"] == "success" assert data["result"][0]["rowcount"] == 2 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_get_forced(self): """ Chart data API: Test GET endpoint with force cache parameter """ chart = db.session.query(Slice).filter_by(slice_name="Genders").one() chart.query_context = json.dumps( { "datasource": {"id": chart.table.id, "type": "table"}, "force": False, "queries": [ { "time_range": "1900-01-01T00:00:00 : 2000-01-01T00:00:00", "granularity": "ds", "filters": [], "extras": { "having": "", "having_druid": [], "where": "", }, "applied_time_extras": {}, "columns": ["gender"], "metrics": ["sum__num"], "orderby": [["sum__num", False]], "annotation_layers": [], "row_limit": 50000, "timeseries_limit": 0, "order_desc": True, "url_params": {}, "custom_params": {}, "custom_form_data": {}, } ], "result_format": "json", "result_type": "full", } ) self.get_assert_metric(f"api/v1/chart/{chart.id}/data/?force=true", "get_data") # should burst cache rv = self.get_assert_metric( f"api/v1/chart/{chart.id}/data/?force=true", "get_data" ) assert rv.json["result"][0]["is_cached"] is None # should get response from the cache rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data") assert rv.json["result"][0]["is_cached"] @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) @mock.patch("superset.charts.data.api.QueryContextCacheLoader") def test_chart_data_cache(self, cache_loader): """ Chart data cache API: Test chart data async cache request """ async_query_manager.init_app(app) cache_loader.load.return_value = self.query_context_payload orig_run = ChartDataCommand.run def mock_run(self, **kwargs): assert kwargs["force_cached"] == True # override force_cached to get result from DB return orig_run(self, force_cached=False) with mock.patch.object(ChartDataCommand, "run", new=mock_run): rv = self.get_assert_metric( f"{CHART_DATA_URI}/test-cache-key", "data_from_cache" ) data = json.loads(rv.data.decode("utf-8")) expected_row_count = self.get_expected_row_count("client_id_3") self.assertEqual(rv.status_code, 200) self.assertEqual(data["result"][0]["rowcount"], expected_row_count) @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) @mock.patch("superset.charts.data.api.QueryContextCacheLoader") @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_cache_run_failed(self, cache_loader): """ Chart data cache API: Test chart data async cache request with run failure """ async_query_manager.init_app(app) cache_loader.load.return_value = self.query_context_payload rv = self.get_assert_metric( f"{CHART_DATA_URI}/test-cache-key", "data_from_cache" ) data = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 422) self.assertEqual(data["message"], "Error loading data from cache") @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) @mock.patch("superset.charts.data.api.QueryContextCacheLoader") @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_cache_no_login(self, cache_loader): """ Chart data cache API: Test chart data async cache request (no login) """ self.logout() async_query_manager.init_app(app) cache_loader.load.return_value = self.query_context_payload orig_run = ChartDataCommand.run def mock_run(self, **kwargs): assert kwargs["force_cached"] == True # override force_cached to get result from DB return orig_run(self, force_cached=False) with mock.patch.object(ChartDataCommand, "run", new=mock_run): rv = self.client.get( f"{CHART_DATA_URI}/test-cache-key", ) self.assertEqual(rv.status_code, 401) @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) def test_chart_data_cache_key_error(self): """ Chart data cache API: Test chart data async cache request with invalid cache key """ async_query_manager.init_app(app) rv = self.get_assert_metric( f"{CHART_DATA_URI}/test-cache-key", "data_from_cache" ) self.assertEqual(rv.status_code, 404) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_with_adhoc_column(self): """ Chart data API: Test query with adhoc column in both select and where clause """ self.login(username="admin") request_payload = get_query_context("birth_names") request_payload["queries"][0]["columns"] = [ADHOC_COLUMN_FIXTURE] request_payload["queries"][0]["filters"] = [ {"col": ADHOC_COLUMN_FIXTURE, "op": "IN", "val": ["male", "female"]} ] rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") response_payload = json.loads(rv.data.decode("utf-8")) result = response_payload["result"][0] data = result["data"] assert {column for column in data[0].keys()} == {"male_or_female", "sum__num"} unique_genders = {row["male_or_female"] for row in data} assert unique_genders == {"male", "female"} assert result["applied_filters"] == [{"column": "male_or_female"}] @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_with_incompatible_adhoc_column(self): """ Chart data API: Test query with adhoc column that fails to run on this dataset """ self.login(username="admin") request_payload = get_query_context("birth_names") request_payload["queries"][0]["columns"] = [ADHOC_COLUMN_FIXTURE] request_payload["queries"][0]["filters"] = [ {"col": INCOMPATIBLE_ADHOC_COLUMN_FIXTURE, "op": "IN", "val": ["Exciting"]}, {"col": ADHOC_COLUMN_FIXTURE, "op": "IN", "val": ["male", "female"]}, ] rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") response_payload = json.loads(rv.data.decode("utf-8")) result = response_payload["result"][0] data = result["data"] assert {column for column in data[0].keys()} == {"male_or_female", "sum__num"} unique_genders = {row["male_or_female"] for row in data} assert unique_genders == {"male", "female"} assert result["applied_filters"] == [{"column": "male_or_female"}] assert result["rejected_filters"] == [ { "column": "exciting_or_boring", "reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE, } ] @pytest.fixture() def physical_query_context(physical_dataset) -> Dict[str, Any]: return { "datasource": { "type": physical_dataset.type, "id": physical_dataset.id, }, "queries": [ { "columns": ["col1"], "metrics": ["count"], "orderby": [["col1", True]], } ], "result_type": ChartDataResultType.FULL, "force": True, } @mock.patch( "superset.common.query_context_processor.config", { **app.config, "CACHE_DEFAULT_TIMEOUT": 1234, "DATA_CACHE_CONFIG": { **app.config["DATA_CACHE_CONFIG"], "CACHE_DEFAULT_TIMEOUT": None, }, }, ) def test_cache_default_timeout(test_client, login_as_admin, physical_query_context): rv = test_client.post(CHART_DATA_URI, json=physical_query_context) assert rv.json["result"][0]["cache_timeout"] == 1234 def test_custom_cache_timeout(test_client, login_as_admin, physical_query_context): physical_query_context["custom_cache_timeout"] = 5678 rv = test_client.post(CHART_DATA_URI, json=physical_query_context) assert rv.json["result"][0]["cache_timeout"] == 5678 @mock.patch( "superset.common.query_context_processor.config", { **app.config, "CACHE_DEFAULT_TIMEOUT": 100000, "DATA_CACHE_CONFIG": { **app.config["DATA_CACHE_CONFIG"], "CACHE_DEFAULT_TIMEOUT": 3456, }, }, ) def test_data_cache_default_timeout( test_client, login_as_admin, physical_query_context, ): rv = test_client.post(CHART_DATA_URI, json=physical_query_context) assert rv.json["result"][0]["cache_timeout"] == 3456 def test_chart_cache_timeout( test_client, login_as_admin, physical_query_context, load_energy_table_with_slice: List[Slice], ): # should override datasource cache timeout slice_with_cache_timeout = load_energy_table_with_slice[0] slice_with_cache_timeout.cache_timeout = 20 db.session.merge(slice_with_cache_timeout) datasource: SqlaTable = ( db.session.query(SqlaTable) .filter(SqlaTable.id == physical_query_context["datasource"]["id"]) .first() ) datasource.cache_timeout = 1254 db.session.merge(datasource) db.session.commit() physical_query_context["form_data"] = {"slice_id": slice_with_cache_timeout.id} rv = test_client.post(CHART_DATA_URI, json=physical_query_context) assert rv.json["result"][0]["cache_timeout"] == 20 @mock.patch( "superset.common.query_context_processor.config", { **app.config, "DATA_CACHE_CONFIG": { **app.config["DATA_CACHE_CONFIG"], "CACHE_DEFAULT_TIMEOUT": 1010, }, }, ) def test_chart_cache_timeout_not_present( test_client, login_as_admin, physical_query_context ): # should use datasource cache, if it's present datasource: SqlaTable = ( db.session.query(SqlaTable) .filter(SqlaTable.id == physical_query_context["datasource"]["id"]) .first() ) datasource.cache_timeout = 1980 db.session.merge(datasource) db.session.commit() rv = test_client.post(CHART_DATA_URI, json=physical_query_context) assert rv.json["result"][0]["cache_timeout"] == 1980 @mock.patch( "superset.common.query_context_processor.config", { **app.config, "DATA_CACHE_CONFIG": { **app.config["DATA_CACHE_CONFIG"], "CACHE_DEFAULT_TIMEOUT": 1010, }, }, ) def test_chart_cache_timeout_chart_not_found( test_client, login_as_admin, physical_query_context ): # should use default timeout physical_query_context["form_data"] = {"slice_id": 0} rv = test_client.post(CHART_DATA_URI, json=physical_query_context) assert rv.json["result"][0]["cache_timeout"] == 1010 @pytest.mark.parametrize( "status_code,extras", [ (200, {"where": "1 = 1"}), (200, {"having": "count(*) > 0"}), (400, {"where": "col1 in (select distinct col1 from physical_dataset)"}), (400, {"having": "count(*) > (select count(*) from physical_dataset)"}), ], ) @with_feature_flags(ALLOW_ADHOC_SUBQUERY=False) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_subquery_not_allowed( test_client, login_as_admin, physical_dataset, physical_query_context, status_code, extras, ): physical_query_context["queries"][0]["extras"] = extras rv = test_client.post(CHART_DATA_URI, json=physical_query_context) assert rv.status_code == status_code @pytest.mark.parametrize( "status_code,extras", [ (200, {"where": "1 = 1"}), (200, {"having": "count(*) > 0"}), (200, {"where": "col1 in (select distinct col1 from physical_dataset)"}), (200, {"having": "count(*) > (select count(*) from physical_dataset)"}), ], ) @with_feature_flags(ALLOW_ADHOC_SUBQUERY=True) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_subquery_allowed( test_client, login_as_admin, physical_dataset, physical_query_context, status_code, extras, ): physical_query_context["queries"][0]["extras"] = extras rv = test_client.post(CHART_DATA_URI, json=physical_query_context) assert rv.status_code == status_code