# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""

from datetime import datetime, timedelta
from unittest import mock
import json
import random
import string

import pytest
import prison
from sqlalchemy.sql import func

import tests.integration_tests.test_app
from superset import db, security_manager
from superset.common.db_query_status import QueryStatus
from superset.models.core import Database
from superset.utils.database import get_example_database, get_main_database
from superset.models.sql_lab import Query

from tests.integration_tests.base_tests import SupersetTestCase

QUERIES_FIXTURE_COUNT = 10

# Response fields that are skipped when comparing an API payload against an
# expected dict: timestamps and the auto-generated primary key.
NON_DETERMINISTIC_FIELDS = (
    "changed_on",
    "end_time",
    "start_running_time",
    "start_time",
    "id",
)


class TestQueryApi(SupersetTestCase):
    """Integration tests for the ``api/v1/query`` REST endpoints."""

    def insert_query(
        self,
        database_id: int,
        user_id: int,
        client_id: str,
        sql: str = "",
        select_sql: str = "",
        executed_sql: str = "",
        limit: int = 100,
        progress: int = 100,
        rows: int = 100,
        tab_name: str = "",
        status: str = "success",
        changed_on: datetime = datetime(2020, 1, 1),
    ) -> Query:
        """
        Create, commit and return a ``Query`` row.

        The database and user are looked up by primary key through the
        session; the remaining arguments are passed straight to the
        ``Query`` constructor.

        :param database_id: primary key of the ``Database`` to attach
        :param user_id: primary key of the owning user
        :param client_id: client-side identifier stored on the query
        :returns: the committed ``Query`` instance
        """
        database = db.session.query(Database).get(database_id)
        user = db.session.query(security_manager.user_model).get(user_id)
        query = Query(
            database=database,
            user=user,
            client_id=client_id,
            sql=sql,
            select_sql=select_sql,
            executed_sql=executed_sql,
            limit=limit,
            progress=progress,
            rows=rows,
            tab_name=tab_name,
            status=status,
            changed_on=changed_on,
        )
        db.session.add(query)
        db.session.commit()
        return query

    @pytest.fixture()
    def create_queries(self):
        """
        Insert ``QUERIES_FIXTURE_COUNT`` queries and delete them afterwards.

        All but one are owned by ``admin`` on the example database, with
        status alternating between SUCCESS and RUNNING. The last one is owned
        by ``alpha`` on the main database.
        """
        with self.create_app().app_context():
            queries = []
            admin_id = self.get_user("admin").id
            alpha_id = self.get_user("alpha").id
            example_database_id = get_example_database().id
            main_database_id = get_main_database().id
            for cx in range(QUERIES_FIXTURE_COUNT - 1):
                queries.append(
                    self.insert_query(
                        example_database_id,
                        admin_id,
                        self.get_random_string(),
                        sql=f"SELECT col1, col2 from table{cx}",
                        rows=cx,
                        status=QueryStatus.SUCCESS
                        if (cx % 2) == 0
                        else QueryStatus.RUNNING,
                    )
                )
            queries.append(
                self.insert_query(
                    main_database_id,
                    alpha_id,
                    self.get_random_string(),
                    sql=f"SELECT col1, col2 from table{QUERIES_FIXTURE_COUNT}",
                    rows=QUERIES_FIXTURE_COUNT,
                    status=QueryStatus.SUCCESS,
                )
            )

            yield queries

            # rollback changes
            for query in queries:
                db.session.delete(query)
            db.session.commit()

    @staticmethod
    def get_random_string(length: int = 10) -> str:
        """Return a random string of ``length`` ASCII letters."""
        return "".join(random.choices(string.ascii_letters, k=length))

    def test_get_query(self):
        """
        Query API: Test get query
        """
        admin = self.get_user("admin")
        client_id = self.get_random_string()
        example_db = get_example_database()
        query = self.insert_query(
            example_db.id,
            admin.id,
            client_id,
            sql="SELECT col1, col2 from table1",
            select_sql="SELECT col1, col2 from table1",
            executed_sql="SELECT col1, col2 from table1 LIMIT 100",
        )
        try:
            self.login(username="admin")
            uri = f"api/v1/query/{query.id}"
            rv = self.client.get(uri)
            self.assertEqual(rv.status_code, 200)

            expected_result = {
                "database": {"id": example_db.id},
                "client_id": client_id,
                "end_result_backend_time": None,
                "error_message": None,
                "executed_sql": "SELECT col1, col2 from table1 LIMIT 100",
                "limit": 100,
                "progress": 100,
                "results_key": None,
                "rows": 100,
                "schema": None,
                "select_as_cta": None,
                "select_as_cta_used": False,
                "select_sql": "SELECT col1, col2 from table1",
                "sql": "SELECT col1, col2 from table1",
                "sql_editor_id": None,
                "status": "success",
                "tab_name": "",
                "tmp_schema_name": None,
                "tmp_table_name": None,
                "tracking_url": None,
            }
            data = json.loads(rv.data.decode("utf-8"))
            self.assertIn("changed_on", data["result"])
            for key, value in data["result"].items():
                # We can't assert timestamp
                if key not in NON_DETERMINISTIC_FIELDS:
                    self.assertEqual(value, expected_result[key])
        finally:
            # rollback changes, even when an assertion above failed
            db.session.delete(query)
            db.session.commit()

    def test_get_query_not_found(self):
        """
        Query API: Test get query not found
        """
        admin = self.get_user("admin")
        client_id = self.get_random_string()
        query = self.insert_query(get_example_database().id, admin.id, client_id)
        try:
            # One past the current maximum id cannot belong to an existing row.
            max_id = db.session.query(func.max(Query.id)).scalar()
            self.login(username="admin")
            uri = f"api/v1/query/{max_id + 1}"
            rv = self.client.get(uri)
            self.assertEqual(rv.status_code, 404)
        finally:
            # rollback changes, even when an assertion above failed
            db.session.delete(query)
            db.session.commit()

    def test_get_query_no_data_access(self):
        """
        Query API: Test get query without data access
        """
        gamma1 = self.create_user(
            "gamma_1", "password", "Gamma", email="gamma1@superset.org"
        )
        gamma2 = self.create_user(
            "gamma_2", "password", "Gamma", email="gamma2@superset.org"
        )
        # Add SQLLab role to these gamma users, so they have access to queries
        sqllab_role = self.get_role("sql_lab")
        gamma1.roles.append(sqllab_role)
        gamma2.roles.append(sqllab_role)

        gamma1_client_id = self.get_random_string()
        gamma2_client_id = self.get_random_string()
        query_gamma1 = self.insert_query(
            get_example_database().id, gamma1.id, gamma1_client_id
        )
        query_gamma2 = self.insert_query(
            get_example_database().id, gamma2.id, gamma2_client_id
        )

        try:
            # Gamma1 user, only sees their own queries
            self.login(username="gamma_1", password="password")
            uri = f"api/v1/query/{query_gamma2.id}"
            rv = self.client.get(uri)
            self.assertEqual(rv.status_code, 404)
            uri = f"api/v1/query/{query_gamma1.id}"
            rv = self.client.get(uri)
            self.assertEqual(rv.status_code, 200)

            # Gamma2 user, only sees their own queries
            self.logout()
            self.login(username="gamma_2", password="password")
            uri = f"api/v1/query/{query_gamma1.id}"
            rv = self.client.get(uri)
            self.assertEqual(rv.status_code, 404)
            uri = f"api/v1/query/{query_gamma2.id}"
            rv = self.client.get(uri)
            self.assertEqual(rv.status_code, 200)

            # Admin's have the "all query access" permission
            self.logout()
            self.login(username="admin")
            uri = f"api/v1/query/{query_gamma1.id}"
            rv = self.client.get(uri)
            self.assertEqual(rv.status_code, 200)
            uri = f"api/v1/query/{query_gamma2.id}"
            rv = self.client.get(uri)
            self.assertEqual(rv.status_code, 200)
        finally:
            # rollback changes, even when an assertion above failed
            db.session.delete(query_gamma1)
            db.session.delete(query_gamma2)
            db.session.delete(gamma1)
            db.session.delete(gamma2)
            db.session.commit()

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query(self):
        """
        Query API: Test get list query
        """
        self.login(username="admin")
        uri = "api/v1/query/"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 200)
        data = json.loads(rv.data.decode("utf-8"))
        assert data["count"] == QUERIES_FIXTURE_COUNT
        # check expected columns
        assert sorted(list(data["result"][0].keys())) == [
            "changed_on",
            "database",
            "end_time",
            "executed_sql",
            "id",
            "rows",
            "schema",
            "sql",
            "sql_tables",
            "start_time",
            "status",
            "tab_name",
            "tmp_table_name",
            "tracking_url",
            "user",
        ]
        assert sorted(list(data["result"][0]["user"].keys())) == [
            "first_name",
            "id",
            "last_name",
        ]
        assert list(data["result"][0]["database"].keys()) == [
            "database_name",
        ]

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query_filter_sql(self):
        """
        Query API: Test get list query filter
        """
        self.login(username="admin")
        arguments = {"filters": [{"col": "sql", "opr": "ct", "value": "table2"}]}
        uri = f"api/v1/query/?q={prison.dumps(arguments)}"
        rv = self.client.get(uri)
        assert rv.status_code == 200
        data = json.loads(rv.data.decode("utf-8"))
        assert data["count"] == 1

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query_filter_database(self):
        """
        Query API: Test get list query filter database
        """
        self.login(username="admin")
        database_id = get_main_database().id
        arguments = {
            "filters": [{"col": "database", "opr": "rel_o_m", "value": database_id}]
        }
        uri = f"api/v1/query/?q={prison.dumps(arguments)}"
        rv = self.client.get(uri)
        assert rv.status_code == 200
        data = json.loads(rv.data.decode("utf-8"))
        assert data["count"] == 1

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query_filter_user(self):
        """
        Query API: Test get list query filter user
        """
        self.login(username="admin")
        alpha_id = self.get_user("alpha").id
        arguments = {"filters": [{"col": "user", "opr": "rel_o_m", "value": alpha_id}]}
        uri = f"api/v1/query/?q={prison.dumps(arguments)}"
        rv = self.client.get(uri)
        assert rv.status_code == 200
        data = json.loads(rv.data.decode("utf-8"))
        assert data["count"] == 1

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query_filter_changed_on(self):
        """
        Query API: Test get list query filter changed_on
        """
        self.login(username="admin")
        arguments = {
            "filters": [
                {"col": "changed_on", "opr": "lt", "value": "2020-02-01T00:00:00Z"},
                {"col": "changed_on", "opr": "gt", "value": "2019-12-30T00:00:00Z"},
            ]
        }
        uri = f"api/v1/query/?q={prison.dumps(arguments)}"
        rv = self.client.get(uri)
        assert rv.status_code == 200
        data = json.loads(rv.data.decode("utf-8"))
        assert data["count"] == QUERIES_FIXTURE_COUNT

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query_order(self):
        """
        Query API: Test get list query ordered by each supported column
        """
        self.login(username="admin")
        order_columns = [
            "changed_on",
            "database.database_name",
            "rows",
            "schema",
            "sql",
            "tab_name",
            "user.first_name",
        ]

        for order_column in order_columns:
            arguments = {"order_column": order_column, "order_direction": "asc"}
            uri = f"api/v1/query/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            assert rv.status_code == 200

    def test_get_list_query_no_data_access(self):
        """
        Query API: Test get queries no data access
        """
        admin = self.get_user("admin")
        client_id = self.get_random_string()
        query = self.insert_query(
            get_example_database().id,
            admin.id,
            client_id,
            sql="SELECT col1, col2 from table1",
        )

        try:
            self.login(username="gamma_sqllab")
            arguments = {
                "filters": [{"col": "sql", "opr": "sw", "value": "SELECT col1"}]
            }
            uri = f"api/v1/query/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            assert rv.status_code == 200
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] == 0
        finally:
            # rollback changes, even when an assertion above failed
            db.session.delete(query)
            db.session.commit()

    def test_get_updated_since(self):
        """
        Query API: Test get queries updated since timestamp
        """
        now = datetime.utcnow()
        client_id = self.get_random_string()

        admin = self.get_user("admin")
        example_db = get_example_database()

        old_query = self.insert_query(
            example_db.id,
            admin.id,
            self.get_random_string(),
            sql="SELECT col1, col2 from table1",
            select_sql="SELECT col1, col2 from table1",
            executed_sql="SELECT col1, col2 from table1 LIMIT 100",
            changed_on=now - timedelta(days=3),
        )
        updated_query = self.insert_query(
            example_db.id,
            admin.id,
            client_id,
            sql="SELECT col1, col2 from table1",
            select_sql="SELECT col1, col2 from table1",
            executed_sql="SELECT col1, col2 from table1 LIMIT 100",
            changed_on=now - timedelta(days=1),
        )

        try:
            self.login(username="admin")
            # NOTE(review): `now` is a naive UTC datetime, and `timestamp()`
            # treats naive datetimes as local time. The cutoff sits a full day
            # away from both fixtures, which absorbs any UTC offset.
            timestamp = datetime.timestamp(now - timedelta(days=2)) * 1000
            uri = f"api/v1/query/updated_since?q={prison.dumps({'last_updated_ms': timestamp})}"
            rv = self.client.get(uri)
            self.assertEqual(rv.status_code, 200)

            expected_result = updated_query.to_dict()
            data = json.loads(rv.data.decode("utf-8"))
            self.assertEqual(len(data["result"]), 1)
            for key, value in data["result"][0].items():
                # We can't assert timestamp
                if key not in NON_DETERMINISTIC_FIELDS:
                    self.assertEqual(value, expected_result[key])
        finally:
            # rollback changes, even when an assertion above failed
            db.session.delete(old_query)
            db.session.delete(updated_query)
            db.session.commit()

    @mock.patch("superset.sql_lab.cancel_query")
    @mock.patch("superset.views.core.db.session")
    def test_stop_query_not_found(
        self, mock_superset_db_session, mock_sql_lab_cancel_query
    ):
        """
        Handles stop query when the DB engine spec does not
        have a cancel query method (with invalid client_id).
        """
        form_data = {"client_id": "foo2"}
        self.login(username="admin")
        # The lookup chain query().filter_by().one_or_none() finds nothing.
        mock_superset_db_session.query().filter_by().one_or_none.return_value = None
        mock_sql_lab_cancel_query.return_value = True
        rv = self.client.post(
            "/api/v1/query/stop",
            data=json.dumps(form_data),
            content_type="application/json",
        )

        assert rv.status_code == 404
        data = json.loads(rv.data.decode("utf-8"))
        assert data["message"] == "Query with client_id foo2 not found"

    @mock.patch("superset.sql_lab.cancel_query")
    @mock.patch("superset.views.core.db.session")
    def test_stop_query(self, mock_superset_db_session, mock_sql_lab_cancel_query):
        """
        Handles stop query when the DB engine spec does not
        have a cancel query method.
        """
        form_data = {"client_id": "foo"}
        query_mock = mock.Mock()
        query_mock.client_id = "foo"
        query_mock.status = QueryStatus.RUNNING
        self.login(username="admin")
        # Configure the return value of one_or_none itself, so the lookup chain
        # query().filter_by().one_or_none() yields the prepared RUNNING query
        # instead of an auto-generated child mock.
        mock_superset_db_session.query().filter_by().one_or_none.return_value = (
            query_mock
        )
        mock_sql_lab_cancel_query.return_value = True
        rv = self.client.post(
            "/api/v1/query/stop",
            data=json.dumps(form_data),
            content_type="application/json",
        )

        assert rv.status_code == 200
        data = json.loads(rv.data.decode("utf-8"))
        assert data["result"] == "OK"