superset/tests/charts/api_tests.py

924 lines
35 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
import json
from typing import List, Optional
from datetime import datetime
from unittest import mock
import humanize
import prison
import pytest
from sqlalchemy.sql import func
from superset.utils.core import get_example_database
from tests.test_app import app
from superset.connectors.connector_registry import ConnectorRegistry
from superset.extensions import db, security_manager
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils import core as utils
from tests.base_api_tests import ApiOwnersTestCaseMixin
from tests.base_tests import SupersetTestCase
from tests.fixtures.query_context import get_query_context
CHART_DATA_URI = "api/v1/chart/data"
class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
resource_name = "chart"
def insert_chart(
self,
slice_name: str,
owners: List[int],
datasource_id: int,
datasource_type: str = "table",
description: Optional[str] = None,
viz_type: Optional[str] = None,
params: Optional[str] = None,
cache_timeout: Optional[int] = None,
) -> Slice:
obj_owners = list()
for owner in owners:
user = db.session.query(security_manager.user_model).get(owner)
obj_owners.append(user)
datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
)
slice = Slice(
slice_name=slice_name,
datasource_id=datasource.id,
datasource_name=datasource.name,
datasource_type=datasource.type,
owners=obj_owners,
description=description,
viz_type=viz_type,
params=params,
cache_timeout=cache_timeout,
)
db.session.add(slice)
db.session.commit()
return slice
def test_delete_chart(self):
"""
Chart API: Test delete
"""
admin_id = self.get_user("admin").id
chart_id = self.insert_chart("name", [admin_id], 1).id
self.login(username="admin")
uri = f"api/v1/chart/{chart_id}"
rv = self.delete_assert_metric(uri, "delete")
self.assertEqual(rv.status_code, 200)
model = db.session.query(Slice).get(chart_id)
self.assertEqual(model, None)
def test_delete_bulk_charts(self):
"""
Chart API: Test delete bulk
"""
admin_id = self.get_user("admin").id
chart_count = 4
chart_ids = list()
for chart_name_index in range(chart_count):
chart_ids.append(
self.insert_chart(f"title{chart_name_index}", [admin_id], 1).id
)
self.login(username="admin")
argument = chart_ids
uri = f"api/v1/chart/?q={prison.dumps(argument)}"
rv = self.delete_assert_metric(uri, "bulk_delete")
self.assertEqual(rv.status_code, 200)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": f"Deleted {chart_count} charts"}
self.assertEqual(response, expected_response)
for chart_id in chart_ids:
model = db.session.query(Slice).get(chart_id)
self.assertEqual(model, None)
def test_delete_bulk_chart_bad_request(self):
"""
Chart API: Test delete bulk bad request
"""
chart_ids = [1, "a"]
self.login(username="admin")
argument = chart_ids
uri = f"api/v1/chart/?q={prison.dumps(argument)}"
rv = self.delete_assert_metric(uri, "bulk_delete")
self.assertEqual(rv.status_code, 400)
def test_delete_not_found_chart(self):
"""
Chart API: Test not found delete
"""
self.login(username="admin")
chart_id = 1000
uri = f"api/v1/chart/{chart_id}"
rv = self.delete_assert_metric(uri, "delete")
self.assertEqual(rv.status_code, 404)
def test_delete_bulk_charts_not_found(self):
"""
Chart API: Test delete bulk not found
"""
max_id = db.session.query(func.max(Slice.id)).scalar()
chart_ids = [max_id + 1, max_id + 2]
self.login(username="admin")
argument = chart_ids
uri = f"api/v1/chart/?q={prison.dumps(argument)}"
rv = self.delete_assert_metric(uri, "bulk_delete")
self.assertEqual(rv.status_code, 404)
def test_delete_chart_admin_not_owned(self):
"""
Chart API: Test admin delete not owned
"""
gamma_id = self.get_user("gamma").id
chart_id = self.insert_chart("title", [gamma_id], 1).id
self.login(username="admin")
uri = f"api/v1/chart/{chart_id}"
rv = self.delete_assert_metric(uri, "delete")
self.assertEqual(rv.status_code, 200)
model = db.session.query(Slice).get(chart_id)
self.assertEqual(model, None)
def test_delete_bulk_chart_admin_not_owned(self):
"""
Chart API: Test admin delete bulk not owned
"""
gamma_id = self.get_user("gamma").id
chart_count = 4
chart_ids = list()
for chart_name_index in range(chart_count):
chart_ids.append(
self.insert_chart(f"title{chart_name_index}", [gamma_id], 1).id
)
self.login(username="admin")
argument = chart_ids
uri = f"api/v1/chart/?q={prison.dumps(argument)}"
rv = self.delete_assert_metric(uri, "bulk_delete")
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
expected_response = {"message": f"Deleted {chart_count} charts"}
self.assertEqual(response, expected_response)
for chart_id in chart_ids:
model = db.session.query(Slice).get(chart_id)
self.assertEqual(model, None)
def test_delete_chart_not_owned(self):
"""
Chart API: Test delete try not owned
"""
user_alpha1 = self.create_user(
"alpha1", "password", "Alpha", email="alpha1@superset.org"
)
user_alpha2 = self.create_user(
"alpha2", "password", "Alpha", email="alpha2@superset.org"
)
chart = self.insert_chart("title", [user_alpha1.id], 1)
self.login(username="alpha2", password="password")
uri = f"api/v1/chart/{chart.id}"
rv = self.delete_assert_metric(uri, "delete")
self.assertEqual(rv.status_code, 403)
db.session.delete(chart)
db.session.delete(user_alpha1)
db.session.delete(user_alpha2)
db.session.commit()
def test_delete_bulk_chart_not_owned(self):
"""
Chart API: Test delete bulk try not owned
"""
user_alpha1 = self.create_user(
"alpha1", "password", "Alpha", email="alpha1@superset.org"
)
user_alpha2 = self.create_user(
"alpha2", "password", "Alpha", email="alpha2@superset.org"
)
chart_count = 4
charts = list()
for chart_name_index in range(chart_count):
charts.append(
self.insert_chart(f"title{chart_name_index}", [user_alpha1.id], 1)
)
owned_chart = self.insert_chart("title_owned", [user_alpha2.id], 1)
self.login(username="alpha2", password="password")
# verify we can't delete not owned charts
arguments = [chart.id for chart in charts]
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.delete_assert_metric(uri, "bulk_delete")
self.assertEqual(rv.status_code, 403)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": "Forbidden"}
self.assertEqual(response, expected_response)
# # nothing is deleted in bulk with a list of owned and not owned charts
arguments = [chart.id for chart in charts] + [owned_chart.id]
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.delete_assert_metric(uri, "bulk_delete")
self.assertEqual(rv.status_code, 403)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": "Forbidden"}
self.assertEqual(response, expected_response)
for chart in charts:
db.session.delete(chart)
db.session.delete(owned_chart)
db.session.delete(user_alpha1)
db.session.delete(user_alpha2)
db.session.commit()
def test_create_chart(self):
"""
Chart API: Test create chart
"""
admin_id = self.get_user("admin").id
chart_data = {
"slice_name": "name1",
"description": "description1",
"owners": [admin_id],
"viz_type": "viz_type1",
"params": "1234",
"cache_timeout": 1000,
"datasource_id": 1,
"datasource_type": "table",
"dashboards": [1, 2],
}
self.login(username="admin")
uri = f"api/v1/chart/"
rv = self.post_assert_metric(uri, chart_data, "post")
self.assertEqual(rv.status_code, 201)
data = json.loads(rv.data.decode("utf-8"))
model = db.session.query(Slice).get(data.get("id"))
db.session.delete(model)
db.session.commit()
def test_create_simple_chart(self):
"""
Chart API: Test create simple chart
"""
chart_data = {
"slice_name": "title1",
"datasource_id": 1,
"datasource_type": "table",
}
self.login(username="admin")
uri = f"api/v1/chart/"
rv = self.post_assert_metric(uri, chart_data, "post")
self.assertEqual(rv.status_code, 201)
data = json.loads(rv.data.decode("utf-8"))
model = db.session.query(Slice).get(data.get("id"))
db.session.delete(model)
db.session.commit()
def test_create_chart_validate_owners(self):
"""
Chart API: Test create validate owners
"""
chart_data = {
"slice_name": "title1",
"datasource_id": 1,
"datasource_type": "table",
"owners": [1000],
}
self.login(username="admin")
uri = f"api/v1/chart/"
rv = self.post_assert_metric(uri, chart_data, "post")
self.assertEqual(rv.status_code, 422)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": {"owners": ["Owners are invalid"]}}
self.assertEqual(response, expected_response)
def test_create_chart_validate_params(self):
"""
Chart API: Test create validate params json
"""
chart_data = {
"slice_name": "title1",
"datasource_id": 1,
"datasource_type": "table",
"params": '{"A:"a"}',
}
self.login(username="admin")
uri = f"api/v1/chart/"
rv = self.post_assert_metric(uri, chart_data, "post")
self.assertEqual(rv.status_code, 400)
def test_create_chart_validate_datasource(self):
"""
Chart API: Test create validate datasource
"""
self.login(username="admin")
chart_data = {
"slice_name": "title1",
"datasource_id": 1,
"datasource_type": "unknown",
}
uri = f"api/v1/chart/"
rv = self.post_assert_metric(uri, chart_data, "post")
self.assertEqual(rv.status_code, 400)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
response,
{"message": {"datasource_type": ["Must be one of: druid, table, view."]}},
)
chart_data = {
"slice_name": "title1",
"datasource_id": 0,
"datasource_type": "table",
}
uri = f"api/v1/chart/"
rv = self.post_assert_metric(uri, chart_data, "post")
self.assertEqual(rv.status_code, 422)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
response, {"message": {"datasource_id": ["Datasource does not exist"]}}
)
def test_update_chart(self):
"""
Chart API: Test update
"""
admin = self.get_user("admin")
gamma = self.get_user("gamma")
chart_id = self.insert_chart("title", [admin.id], 1).id
birth_names_table_id = SupersetTestCase.get_table_by_name("birth_names").id
chart_data = {
"slice_name": "title1_changed",
"description": "description1",
"owners": [gamma.id],
"viz_type": "viz_type1",
"params": """{"a": 1}""",
"cache_timeout": 1000,
"datasource_id": birth_names_table_id,
"datasource_type": "table",
"dashboards": [1],
}
self.login(username="admin")
uri = f"api/v1/chart/{chart_id}"
rv = self.put_assert_metric(uri, chart_data, "put")
self.assertEqual(rv.status_code, 200)
model = db.session.query(Slice).get(chart_id)
related_dashboard = db.session.query(Dashboard).get(1)
self.assertEqual(model.slice_name, "title1_changed")
self.assertEqual(model.description, "description1")
self.assertIn(admin, model.owners)
self.assertIn(gamma, model.owners)
self.assertEqual(model.viz_type, "viz_type1")
self.assertEqual(model.params, """{"a": 1}""")
self.assertEqual(model.cache_timeout, 1000)
self.assertEqual(model.datasource_id, birth_names_table_id)
self.assertEqual(model.datasource_type, "table")
self.assertEqual(model.datasource_name, "birth_names")
self.assertIn(related_dashboard, model.dashboards)
db.session.delete(model)
db.session.commit()
def test_update_chart_new_owner(self):
"""
Chart API: Test update set new owner to current user
"""
gamma = self.get_user("gamma")
admin = self.get_user("admin")
chart_id = self.insert_chart("title", [gamma.id], 1).id
chart_data = {"slice_name": "title1_changed"}
self.login(username="admin")
uri = f"api/v1/chart/{chart_id}"
rv = self.put_assert_metric(uri, chart_data, "put")
self.assertEqual(rv.status_code, 200)
model = db.session.query(Slice).get(chart_id)
self.assertIn(admin, model.owners)
db.session.delete(model)
db.session.commit()
def test_update_chart_not_owned(self):
"""
Chart API: Test update not owned
"""
user_alpha1 = self.create_user(
"alpha1", "password", "Alpha", email="alpha1@superset.org"
)
user_alpha2 = self.create_user(
"alpha2", "password", "Alpha", email="alpha2@superset.org"
)
chart = self.insert_chart("title", [user_alpha1.id], 1)
self.login(username="alpha2", password="password")
chart_data = {"slice_name": "title1_changed"}
uri = f"api/v1/chart/{chart.id}"
rv = self.put_assert_metric(uri, chart_data, "put")
self.assertEqual(rv.status_code, 403)
db.session.delete(chart)
db.session.delete(user_alpha1)
db.session.delete(user_alpha2)
db.session.commit()
def test_update_chart_validate_datasource(self):
"""
Chart API: Test update validate datasource
"""
admin = self.get_user("admin")
chart = self.insert_chart("title", [admin.id], 1)
self.login(username="admin")
chart_data = {"datasource_id": 1, "datasource_type": "unknown"}
uri = f"api/v1/chart/{chart.id}"
rv = self.put_assert_metric(uri, chart_data, "put")
self.assertEqual(rv.status_code, 400)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
response,
{"message": {"datasource_type": ["Must be one of: druid, table, view."]}},
)
chart_data = {"datasource_id": 0, "datasource_type": "table"}
uri = f"api/v1/chart/{chart.id}"
rv = self.put_assert_metric(uri, chart_data, "put")
self.assertEqual(rv.status_code, 422)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
response, {"message": {"datasource_id": ["Datasource does not exist"]}}
)
db.session.delete(chart)
db.session.commit()
def test_update_chart_validate_owners(self):
"""
Chart API: Test update validate owners
"""
chart_data = {
"slice_name": "title1",
"datasource_id": 1,
"datasource_type": "table",
"owners": [1000],
}
self.login(username="admin")
uri = f"api/v1/chart/"
rv = self.client.post(uri, json=chart_data)
self.assertEqual(rv.status_code, 422)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": {"owners": ["Owners are invalid"]}}
self.assertEqual(response, expected_response)
def test_get_chart(self):
"""
Chart API: Test get chart
"""
admin = self.get_user("admin")
chart = self.insert_chart("title", [admin.id], 1)
self.login(username="admin")
uri = f"api/v1/chart/{chart.id}"
rv = self.get_assert_metric(uri, "get")
self.assertEqual(rv.status_code, 200)
expected_result = {
"cache_timeout": None,
"dashboards": [],
"description": None,
"owners": [
{
"id": 1,
"username": "admin",
"first_name": "admin",
"last_name": "user",
}
],
"params": None,
"slice_name": "title",
"viz_type": None,
}
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["result"], expected_result)
db.session.delete(chart)
db.session.commit()
def test_get_chart_not_found(self):
"""
Chart API: Test get chart not found
"""
chart_id = 1000
self.login(username="admin")
uri = f"api/v1/chart/{chart_id}"
rv = self.get_assert_metric(uri, "get")
self.assertEqual(rv.status_code, 404)
def test_get_chart_no_data_access(self):
"""
Chart API: Test get chart without data access
"""
self.login(username="gamma")
chart_no_access = (
db.session.query(Slice)
.filter_by(slice_name="Girl Name Cloud")
.one_or_none()
)
uri = f"api/v1/chart/{chart_no_access.id}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
def test_get_charts(self):
"""
Chart API: Test get charts
"""
self.login(username="admin")
uri = f"api/v1/chart/"
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 33)
def test_get_charts_changed_on(self):
"""
Dashboard API: Test get charts changed on
"""
admin = self.get_user("admin")
start_changed_on = datetime.now()
chart = self.insert_chart("foo_a", [admin.id], 1, description="ZY_bar")
self.login(username="admin")
arguments = {
"order_column": "changed_on_delta_humanized",
"order_direction": "desc",
}
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
data["result"][0]["changed_on_delta_humanized"],
humanize.naturaltime(datetime.now() - start_changed_on),
)
# rollback changes
db.session.delete(chart)
db.session.commit()
def test_get_charts_filter(self):
"""
Chart API: Test get charts filter
"""
self.login(username="admin")
arguments = {"filters": [{"col": "slice_name", "opr": "sw", "value": "G"}]}
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 5)
def test_get_charts_custom_filter(self):
"""
Chart API: Test get charts custom filter
"""
admin = self.get_user("admin")
chart1 = self.insert_chart("foo_a", [admin.id], 1, description="ZY_bar")
chart2 = self.insert_chart("zy_foo", [admin.id], 1, description="desc1")
chart3 = self.insert_chart("foo_b", [admin.id], 1, description="desc1zy_")
chart4 = self.insert_chart("bar", [admin.id], 1, description="foo")
arguments = {
"filters": [
{"col": "slice_name", "opr": "name_or_description", "value": "zy_"}
],
"order_column": "slice_name",
"order_direction": "asc",
"keys": ["none"],
"columns": ["slice_name", "description"],
}
self.login(username="admin")
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 3)
expected_response = [
{"description": "ZY_bar", "slice_name": "foo_a",},
{"description": "desc1zy_", "slice_name": "foo_b",},
{"description": "desc1", "slice_name": "zy_foo",},
]
for index, item in enumerate(data["result"]):
self.assertEqual(
item["description"], expected_response[index]["description"]
)
self.assertEqual(item["slice_name"], expected_response[index]["slice_name"])
self.logout()
self.login(username="gamma")
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 0)
# rollback changes
db.session.delete(chart1)
db.session.delete(chart2)
db.session.delete(chart3)
db.session.delete(chart4)
db.session.commit()
def test_get_charts_page(self):
"""
Chart API: Test get charts filter
"""
# Assuming we have 33 sample charts
self.login(username="admin")
arguments = {"page_size": 10, "page": 0}
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(len(data["result"]), 10)
arguments = {"page_size": 10, "page": 3}
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(len(data["result"]), 3)
def test_get_charts_no_data_access(self):
"""
Chart API: Test get charts no data access
"""
self.login(username="gamma")
uri = f"api/v1/chart/"
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 0)
def test_chart_data_simple(self):
"""
Chart data API: Test chart data query
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["result"][0]["rowcount"], 45)
def test_chart_data_limit_offset(self):
"""
Chart data API: Test chart data query with limit and offset
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload["queries"][0]["row_limit"] = 5
request_payload["queries"][0]["row_offset"] = 0
request_payload["queries"][0]["orderby"] = [["name", True]]
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 5)
# TODO: fix offset for presto DB
if get_example_database().backend == "presto":
return
# ensure that offset works properly
offset = 2
expected_name = result["data"][offset]["name"]
request_payload["queries"][0]["row_offset"] = offset
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 5)
self.assertEqual(result["data"][0]["name"], expected_name)
@mock.patch(
"superset.common.query_object.config", {**app.config, "ROW_LIMIT": 7},
)
def test_chart_data_default_row_limit(self):
"""
Chart data API: Ensure row count doesn't exceed default limit
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
del request_payload["queries"][0]["row_limit"]
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 7)
@mock.patch(
"superset.common.query_context.config", {**app.config, "SAMPLES_ROW_LIMIT": 5},
)
def test_chart_data_default_sample_limit(self):
"""
Chart data API: Ensure sample response row count doesn't exceed default limit
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
request_payload["queries"][0]["row_limit"] = 10
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 5)
def test_chart_data_incorrect_result_type(self):
"""
Chart data API: Test chart data with unsupported result type
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload["result_type"] = "qwerty"
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 400)
def test_chart_data_incorrect_result_format(self):
"""
Chart data API: Test chart data with unsupported result format
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload["result_format"] = "qwerty"
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 400)
def test_chart_data_query_result_type(self):
"""
Chart data API: Test chart data with query result format
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload["result_type"] = utils.ChartDataResultType.QUERY
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
def test_chart_data_csv_result_format(self):
"""
Chart data API: Test chart data with CSV result format
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload["result_format"] = "csv"
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
def test_chart_data_mixed_case_filter_op(self):
"""
Chart data API: Ensure mixed case filter operator generates valid result
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload["queries"][0]["filters"][0]["op"] = "In"
request_payload["queries"][0]["row_limit"] = 10
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 10)
def test_chart_data_prophet(self):
"""
Chart data API: Ensure prophet post transformation works
"""
pytest.importorskip("fbprophet")
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
time_grain = "P1Y"
request_payload["queries"][0]["is_timeseries"] = True
request_payload["queries"][0]["groupby"] = []
request_payload["queries"][0]["extras"] = {"time_grain_sqla": time_grain}
request_payload["queries"][0]["granularity"] = "ds"
request_payload["queries"][0]["post_processing"] = [
{
"operation": "prophet",
"options": {
"time_grain": time_grain,
"periods": 3,
"confidence_interval": 0.9,
},
}
]
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
print(rv.data)
self.assertEqual(rv.status_code, 200)
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
row = result["data"][0]
self.assertIn("__timestamp", row)
self.assertIn("sum__num", row)
self.assertIn("sum__num__yhat", row)
self.assertIn("sum__num__yhat_upper", row)
self.assertIn("sum__num__yhat_lower", row)
self.assertEqual(result["rowcount"], 47)
def test_chart_data_no_data(self):
"""
Chart data API: Test chart data with empty result
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload["queries"][0]["filters"] = [
{"col": "gender", "op": "==", "val": "foo"}
]
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 0)
self.assertEqual(result["data"], [])
def test_chart_data_incorrect_request(self):
"""
Chart data API: Test chart data with invalid SQL
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload["queries"][0]["filters"] = []
# erroneus WHERE-clause
request_payload["queries"][0]["extras"]["where"] = "(gender abc def)"
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 400)
def test_chart_data_with_invalid_datasource(self):
"""Chart data API: Test chart data query with invalid schema
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
payload = get_query_context(table.name, table.id, table.type)
payload["datasource"] = "abc"
rv = self.post_assert_metric(CHART_DATA_URI, payload, "data")
self.assertEqual(rv.status_code, 400)
def test_chart_data_with_invalid_enum_value(self):
"""Chart data API: Test chart data query with invalid enum value
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
payload = get_query_context(table.name, table.id, table.type)
payload["queries"][0]["extras"]["time_range_endpoints"] = [
"abc",
"EXCLUSIVE",
]
rv = self.client.post(CHART_DATA_URI, json=payload)
self.assertEqual(rv.status_code, 400)
def test_query_exec_not_allowed(self):
"""
Chart data API: Test chart data query not allowed
"""
self.login(username="gamma")
table = self.get_table_by_name("birth_names")
payload = get_query_context(table.name, table.id, table.type)
rv = self.post_assert_metric(CHART_DATA_URI, payload, "data")
self.assertEqual(rv.status_code, 401)
def test_chart_data_jinja_filter_request(self):
"""
Chart data API: Ensure request referencing filters via jinja renders a correct query
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload["result_type"] = utils.ChartDataResultType.QUERY
request_payload["queries"][0]["filters"] = [
{"col": "gender", "op": "==", "val": "boy"}
]
request_payload["queries"][0]["extras"][
"where"
] = "('boy' = '{{ filter_values('gender', 'xyz' )[0] }}')"
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]["query"]
if get_example_database().backend != "presto":
assert "('boy' = 'boy')" in result