# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
from unittest import mock

import pytest

from superset.db_engine_specs import engines
from superset.db_engine_specs.base import BaseEngineSpec, builtin_time_grains
from superset.db_engine_specs.sqlite import SqliteEngineSpec
from superset.sql_parse import ParsedQuery
from superset.utils.core import get_example_database
from tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.test_app import app

from ..fixtures.energy_dashboard import load_energy_table_with_slice
from ..fixtures.pyodbcRow import Row


class TestDbEngineSpecs(TestDbEngineSpec):
    """Tests for ``BaseEngineSpec`` behaviour shared by all engine specs.

    The LIMIT-rewriting tests go through the inherited ``sql_limit_regex``
    helper, passing the input SQL and the SQL expected after the limit has
    been applied.
    """

    def test_extract_limit_from_query(self, engine_spec_class=BaseEngineSpec):
        """``get_limit_from_sql`` returns the outer integer LIMIT or ``None``.

        :param engine_spec_class: engine spec whose ``get_limit_from_sql``
            is exercised; defaults to ``BaseEngineSpec``.
        """
        q0 = "select * from table"
        q1 = "select * from mytable limit 10"
        q2 = "select * from (select * from my_subquery limit 10) where col=1 limit 20"
        q3 = "select * from (select * from my_subquery limit 10);"
        q4 = "select * from (select * from my_subquery limit 10) where col=1 limit 20;"
        q5 = "select * from mytable limit 20, 10"
        q6 = "select * from mytable limit 10 offset 20"
        q7 = "select * from mytable limit"
        q8 = "select * from mytable limit 10.0"
        q9 = "select * from mytable limit x"
        q10 = "select * from mytable limit 20, x"
        q11 = "select * from mytable limit x offset 20"

        self.assertEqual(engine_spec_class.get_limit_from_sql(q0), None)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q1), 10)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q2), 20)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q3), None)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q4), 20)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q5), 10)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q6), 10)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q7), None)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q8), None)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q9), None)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q10), None)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q11), None)

    def test_wrapped_semi_tabs(self):
        """Trailing whitespace and semicolon are dropped before adding LIMIT."""
        self.sql_limit_regex(
            "SELECT * FROM a  \t  \n   ; \t  \n  ", "SELECT * FROM a\nLIMIT 1000"
        )

    def test_simple_limit_query(self):
        """A query without a LIMIT gets one appended on a new line."""
        self.sql_limit_regex("SELECT * FROM a", "SELECT * FROM a\nLIMIT 1000")

    def test_modify_limit_query(self):
        """An existing LIMIT value is replaced in place."""
        self.sql_limit_regex("SELECT * FROM a LIMIT 9999", "SELECT * FROM a LIMIT 1000")

    def test_limit_query_with_limit_subquery(self):  # pylint: disable=invalid-name
        """Only the outer LIMIT is rewritten; the subquery LIMIT is kept."""
        self.sql_limit_regex(
            "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999",
            "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000",
        )

    def test_limit_with_expr(self):
        """A 'LIMIT ...' string literal in the select list is left untouched."""
        self.sql_limit_regex(
            """ SELECT 'LIMIT 777' AS a , b FROM table LIMIT 99990""",
            """SELECT 'LIMIT 777' AS a , b FROM table LIMIT 1000""",
        )

    def test_limit_expr_and_semicolon(self):
        """Same as ``test_limit_with_expr`` but with a trailing semicolon."""
        self.sql_limit_regex(
            """ SELECT 'LIMIT 777' AS a , b FROM table LIMIT 99990 ;""",
            """SELECT 'LIMIT 777' AS a , b FROM table LIMIT 1000""",
        )

    def test_get_datatype(self):
        """A string type name is returned unchanged by ``get_datatype``."""
        self.assertEqual("VARCHAR", BaseEngineSpec.get_datatype("VARCHAR"))

    def test_limit_with_implicit_offset(self):
        """With ``LIMIT <offset>, <count>`` only the count is rewritten."""
        self.sql_limit_regex(
            """ SELECT 'LIMIT 777' AS a , b FROM table LIMIT 99990, 999999""",
            """SELECT 'LIMIT 777' AS a , b FROM table LIMIT 99990, 1000""",
        )

    def test_limit_with_explicit_offset(self):
        """With ``LIMIT <count> OFFSET <n>`` the count changes, OFFSET stays."""
        self.sql_limit_regex(
            """ SELECT 'LIMIT 777' AS a , b FROM table LIMIT 99990 OFFSET 999999""",
            """SELECT 'LIMIT 777' AS a , b FROM table LIMIT 1000 OFFSET 999999""",
        )

    def test_limit_with_non_token_limit(self):
        """'LIMIT' appearing only inside a string literal is not a clause."""
        self.sql_limit_regex(
            """SELECT 'LIMIT 777'""", """SELECT 'LIMIT 777'\nLIMIT 1000"""
        )

    def test_time_grain_denylist(self):
        """A grain listed in ``TIME_GRAIN_DENYLIST`` is not exposed."""
        with app.app_context():
            # patch.dict restores the previous config on exit, even when the
            # assertion fails, so the denylist cannot leak into other tests.
            with mock.patch.dict(app.config, {"TIME_GRAIN_DENYLIST": ["PT1M"]}):
                time_grain_functions = SqliteEngineSpec.get_time_grain_expressions()
                self.assertNotIn("PT1M", time_grain_functions)

    def test_time_grain_addons(self):
        """Grains added through the add-on config keys show up last."""
        addon_config = {
            "TIME_GRAIN_ADDONS": {"PTXM": "x seconds"},
            "TIME_GRAIN_ADDON_EXPRESSIONS": {"sqlite": {"PTXM": "ABC({col})"}},
        }
        with app.app_context():
            # patch.dict puts back whatever the two keys held before, also on
            # assertion failure, instead of unconditionally resetting to {}.
            with mock.patch.dict(app.config, addon_config):
                time_grains = SqliteEngineSpec.get_time_grains()
                time_grain_addon = time_grains[-1]
                self.assertEqual("PTXM", time_grain_addon.duration)
                self.assertEqual("x seconds", time_grain_addon.label)

    def test_engine_time_grain_validity(self):
        """Every engine defines grains, and only builtin ones."""
        time_grains = set(builtin_time_grains)
        # loop over all subclasses of BaseEngineSpec
        for engine in engines.values():
            if engine is not BaseEngineSpec:
                # make sure time grain functions have been defined
                self.assertGreater(len(engine.get_time_grain_expressions()), 0)
                # make sure all defined time grains are supported
                defined_grains = {grain.duration for grain in engine.get_time_grains()}
                intersection = time_grains.intersection(defined_grains)
                self.assertSetEqual(defined_grains, intersection, engine)

    def test_get_table_names(self):
        """``get_table_names`` strips the schema prefix from table names."""
        inspector = mock.Mock()
        inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"])
        inspector.get_foreign_table_names = mock.Mock(return_value=["table_3"])

        # Make sure base engine spec removes schema name from table name
        # ie. when try_remove_schema_from_table_name == True.
        base_result_expected = ["table", "table_2"]
        base_result = BaseEngineSpec.get_table_names(
            database=mock.ANY, schema="schema", inspector=inspector
        )
        self.assertListEqual(base_result_expected, base_result)

    @pytest.mark.usefixtures("load_energy_table_with_slice")
    def test_column_datatype_to_string(self):
        """Column types of ``energy_usage`` render as backend-specific strings."""
        example_db = get_example_database()
        sqla_table = example_db.get_table("energy_usage")
        dialect = example_db.get_dialect()

        # TODO: fix column type conversion for presto.
        if example_db.backend == "presto":
            return

        col_names = [
            example_db.db_engine_spec.column_datatype_to_string(c.type, dialect)
            for c in sqla_table.columns
        ]
        if example_db.backend == "postgresql":
            expected = ["VARCHAR(255)", "VARCHAR(255)", "DOUBLE PRECISION"]
        elif example_db.backend == "hive":
            expected = ["STRING", "STRING", "FLOAT"]
        else:
            expected = ["VARCHAR(255)", "VARCHAR(255)", "FLOAT"]
        self.assertEqual(col_names, expected)

    def test_convert_dttm(self):
        """``BaseEngineSpec.convert_dttm`` returns ``None`` for an empty type."""
        dttm = self.get_dttm()
        self.assertIsNone(BaseEngineSpec.convert_dttm("", dttm))

    def test_pyodbc_rows_to_tuples(self):
        """``Row`` objects are converted to plain tuples."""
        # Test for case when pyodbc.Row is returned (odbc driver)
        data = [
            Row((1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000))),
            Row((2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000))),
        ]
        expected = [
            (1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000)),
            (2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
        ]
        result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
        self.assertListEqual(result, expected)

    def test_pyodbc_rows_to_tuples_passthrough(self):
        """Data that already consists of tuples is returned unchanged."""
        # Test for case when tuples are returned
        data = [
            (1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000)),
            (2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
        ]
        result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
        self.assertListEqual(result, data)


def test_is_readonly():
    """``is_readonly_query`` accepts SELECT/EXPLAIN/WITH and rejects the rest."""

    def is_readonly(sql: str) -> bool:
        return BaseEngineSpec.is_readonly_query(ParsedQuery(sql))

    assert not is_readonly("SHOW LOCKS test EXTENDED")
    assert not is_readonly("SET hivevar:desc='Legislators'")
    assert not is_readonly("UPDATE t1 SET col1 = NULL")
    assert is_readonly("EXPLAIN SELECT 1")
    assert is_readonly("SELECT 1")
    assert is_readonly("WITH (SELECT 1) bla SELECT * from bla")