# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
2019-12-31 07:45:33 -05:00
|
|
|
import datetime
|
2019-10-24 23:46:45 -04:00
|
|
|
from unittest import mock
|
|
|
|
|
|
|
|
from superset.db_engine_specs import engines
|
|
|
|
from superset.db_engine_specs.base import BaseEngineSpec, builtin_time_grains
|
|
|
|
from superset.db_engine_specs.sqlite import SqliteEngineSpec
|
2020-12-03 13:44:11 -05:00
|
|
|
from superset.sql_parse import ParsedQuery
|
2019-10-24 23:46:45 -04:00
|
|
|
from superset.utils.core import get_example_database
|
2020-06-29 18:36:06 -04:00
|
|
|
from tests.db_engine_specs.base_tests import TestDbEngineSpec
|
2019-10-24 23:46:45 -04:00
|
|
|
|
2019-12-31 07:45:33 -05:00
|
|
|
from ..fixtures.pyodbcRow import Row
|
|
|
|
|
2020-12-03 13:44:11 -05:00
|
|
|
from tests.test_app import app # isort:skip
|
|
|
|
|
2019-10-24 23:46:45 -04:00
|
|
|
|
2020-06-29 18:36:06 -04:00
|
|
|
class TestDbEngineSpecs(TestDbEngineSpec):
    """Tests for behavior that ``BaseEngineSpec`` provides to engine specs."""

    def test_extract_limit_from_query(self, engine_spec_class=BaseEngineSpec):
        """Check ``get_limit_from_sql`` against a table of queries.

        ``engine_spec_class`` defaults to ``BaseEngineSpec``; a caller may pass
        another spec class to run the same cases against it.
        """
        # (query, expected limit). Per these expectations only the outermost
        # LIMIT counts, and a missing or non-integer limit yields None.
        cases = [
            ("select * from table", None),
            ("select * from mytable limit 10", 10),
            (
                "select * from (select * from my_subquery limit 10) "
                "where col=1 limit 20",
                20,
            ),
            ("select * from (select * from my_subquery limit 10);", None),
            (
                "select * from (select * from my_subquery limit 10) "
                "where col=1 limit 20;",
                20,
            ),
            ("select * from mytable limit 20, 10", 10),
            ("select * from mytable limit 10 offset 20", 10),
            ("select * from mytable limit", None),
            ("select * from mytable limit 10.0", None),
            ("select * from mytable limit x", None),
            ("select * from mytable limit 20, x", None),
            ("select * from mytable limit x offset 20", None),
        ]
        for sql, expected in cases:
            # Pass the query as the message so a failure names the offending case.
            self.assertEqual(
                engine_spec_class.get_limit_from_sql(sql), expected, msg=sql
            )

    def test_wrapped_semi_tabs(self):
        """A trailing semicolon and surrounding whitespace are dropped."""
        self.sql_limit_regex(
            "SELECT * FROM a \t \n ; \t \n ", "SELECT * FROM a\nLIMIT 1000"
        )

    def test_simple_limit_query(self):
        """A query without LIMIT gets one appended on a new line."""
        self.sql_limit_regex("SELECT * FROM a", "SELECT * FROM a\nLIMIT 1000")

    def test_modify_limit_query(self):
        """An existing outer LIMIT is replaced in place."""
        self.sql_limit_regex("SELECT * FROM a LIMIT 9999", "SELECT * FROM a LIMIT 1000")

    def test_limit_query_with_limit_subquery(self):  # pylint: disable=invalid-name
        """Only the outer LIMIT is rewritten; the subquery's LIMIT is kept."""
        self.sql_limit_regex(
            "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999",
            "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000",
        )

    def test_limit_with_expr(self):
        """'LIMIT 777' inside a string literal is not mistaken for the LIMIT."""
        self.sql_limit_regex(
            """
            SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 99990""",
            """SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 1000""",
        )

    def test_limit_expr_and_semicolon(self):
        """Same as above, with a trailing semicolon that is dropped."""
        self.sql_limit_regex(
            """
            SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 99990 ;""",
            """SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 1000""",
        )

    def test_get_datatype(self):
        """The base spec passes a string type name through unchanged."""
        self.assertEqual("VARCHAR", BaseEngineSpec.get_datatype("VARCHAR"))

    def test_limit_with_implicit_offset(self):
        """In ``LIMIT offset, count`` only the count is rewritten."""
        self.sql_limit_regex(
            """
            SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 99990, 999999""",
            """SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 99990, 1000""",
        )

    def test_limit_with_explicit_offset(self):
        """With ``LIMIT n OFFSET m`` the OFFSET clause is preserved."""
        self.sql_limit_regex(
            """
            SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 99990
            OFFSET 999999""",
            """SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 1000
            OFFSET 999999""",
        )

    def test_limit_with_non_token_limit(self):
        """A query whose only 'LIMIT' is inside a string gets a LIMIT appended."""
        self.sql_limit_regex(
            """SELECT 'LIMIT 777'""", """SELECT 'LIMIT 777'\nLIMIT 1000"""
        )

    def test_time_grain_denylist(self):
        """Grains listed in TIME_GRAIN_DENYLIST are removed from the expressions."""
        # patch.dict restores app.config on exit so the denylist does not leak
        # into other tests that share this app (the old code mutated it for good).
        with app.app_context(), mock.patch.dict(
            app.config, {"TIME_GRAIN_DENYLIST": ["PT1M"]}
        ):
            time_grain_functions = SqliteEngineSpec.get_time_grain_expressions()
            self.assertNotIn("PT1M", time_grain_functions)

    def test_time_grain_addons(self):
        """An add-on grain from config shows up as the last time grain."""
        # Restore app.config afterwards; see test_time_grain_denylist.
        with app.app_context(), mock.patch.dict(
            app.config,
            {
                "TIME_GRAIN_ADDONS": {"PTXM": "x seconds"},
                "TIME_GRAIN_ADDON_EXPRESSIONS": {"sqlite": {"PTXM": "ABC({col})"}},
            },
        ):
            time_grains = SqliteEngineSpec.get_time_grains()
            time_grain_addon = time_grains[-1]
            self.assertEqual("PTXM", time_grain_addon.duration)
            self.assertEqual("x seconds", time_grain_addon.label)

    def test_engine_time_grain_validity(self):
        """Every registered engine defines only built-in time grains."""
        time_grains = set(builtin_time_grains.keys())
        # loop over all subclasses of BaseEngineSpec
        for engine in engines.values():
            if engine is not BaseEngineSpec:
                # make sure time grain functions have been defined
                self.assertGreater(len(engine.get_time_grain_expressions()), 0)
                # make sure all defined time grains are supported
                defined_grains = {grain.duration for grain in engine.get_time_grains()}
                intersection = time_grains.intersection(defined_grains)
                self.assertSetEqual(defined_grains, intersection, engine)

    def test_get_table_names(self):
        """The base spec strips the schema prefix from table names."""
        inspector = mock.Mock()
        inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"])
        inspector.get_foreign_table_names = mock.Mock(return_value=["table_3"])

        # Make sure base engine spec removes schema name from table name,
        # ie. when try_remove_schema_from_table_name == True. Foreign table
        # names are not part of the expected base result.
        base_result_expected = ["table", "table_2"]
        base_result = BaseEngineSpec.get_table_names(
            database=mock.ANY, schema="schema", inspector=inspector
        )
        self.assertListEqual(base_result_expected, base_result)

    def test_column_datatype_to_string(self):
        """Column types of the example ``energy_usage`` table render per backend."""
        example_db = get_example_database()

        # TODO: fix column type conversion for presto.
        # Report a skip instead of a silent pass, and check before doing any
        # work whose result would be thrown away.
        if example_db.backend == "presto":
            self.skipTest("column type conversion is not yet fixed for presto")

        sqla_table = example_db.get_table("energy_usage")
        dialect = example_db.get_dialect()
        col_names = [
            example_db.db_engine_spec.column_datatype_to_string(c.type, dialect)
            for c in sqla_table.columns
        ]
        if example_db.backend == "postgresql":
            expected = ["VARCHAR(255)", "VARCHAR(255)", "DOUBLE PRECISION"]
        elif example_db.backend == "hive":
            expected = ["STRING", "STRING", "FLOAT"]
        else:
            expected = ["VARCHAR(255)", "VARCHAR(255)", "FLOAT"]
        self.assertEqual(col_names, expected)

    def test_convert_dttm(self):
        """The base spec has no datetime conversion for an empty target type."""
        dttm = self.get_dttm()
        self.assertIsNone(BaseEngineSpec.convert_dttm("", dttm))

    def test_pyodbc_rows_to_tuples(self):
        """pyodbc ``Row`` objects (odbc driver) are converted to plain tuples."""
        data = [
            Row((1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000))),
            Row((2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000))),
        ]
        expected = [
            (1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000)),
            (2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
        ]
        result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
        self.assertListEqual(result, expected)

    def test_pyodbc_rows_to_tuples_passthrough(self):
        """Rows that are already tuples are returned unchanged."""
        data = [
            (1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000)),
            (2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
        ]
        result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
        self.assertListEqual(result, data)
|
2020-12-03 13:44:11 -05:00
|
|
|
|
|
|
|
|
|
|
|
def test_is_readonly():
    """Statements are classified as read-only or not by ``BaseEngineSpec``."""

    def is_readonly(sql: str) -> bool:
        # Parse once, then delegate to the base engine spec's classifier.
        parsed = ParsedQuery(sql)
        return BaseEngineSpec.is_readonly_query(parsed)

    # Statements the classifier must NOT treat as read-only.
    not_readonly = (
        "SHOW LOCKS test EXTENDED",
        "SET hivevar:desc='Legislators'",
        "UPDATE t1 SET col1 = NULL",
    )
    # Statements the classifier must treat as read-only.
    readonly = (
        "EXPLAIN SELECT 1",
        "SELECT 1",
        "WITH (SELECT 1) bla SELECT * from bla",
    )

    for statement in not_readonly:
        assert not is_readonly(statement)
    for statement in readonly:
        assert is_readonly(statement)
|