mirror of
https://github.com/apache/superset.git
synced 2024-09-18 19:49:37 -04:00
78ede590ee
Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> Co-authored-by: Josh Soref <jsoref@users.noreply.github.com>
523 lines
17 KiB
Python
523 lines
17 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
import datetime
|
|
from unittest import mock
|
|
|
|
import pytest
|
|
|
|
from superset.connectors.sqla.models import TableColumn
|
|
from superset.db_engine_specs import load_engine_specs
|
|
from superset.db_engine_specs.base import (
|
|
BaseEngineSpec,
|
|
BasicParametersMixin,
|
|
builtin_time_grains,
|
|
LimitMethod,
|
|
)
|
|
from superset.db_engine_specs.mysql import MySQLEngineSpec
|
|
from superset.db_engine_specs.sqlite import SqliteEngineSpec
|
|
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
|
from superset.sql_parse import ParsedQuery
|
|
from superset.utils.database import get_example_database
|
|
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
|
from tests.integration_tests.test_app import app
|
|
|
|
from ..fixtures.birth_names_dashboard import (
|
|
load_birth_names_dashboard_with_slices,
|
|
load_birth_names_data,
|
|
)
|
|
from ..fixtures.energy_dashboard import (
|
|
load_energy_table_data,
|
|
load_energy_table_with_slice,
|
|
)
|
|
from ..fixtures.pyodbcRow import Row
|
|
|
|
|
|
class TestDbEngineSpecs(TestDbEngineSpec):
    """Integration tests for helpers implemented on ``BaseEngineSpec``.

    Covers LIMIT extraction and rewriting, time grain definitions, table name
    listing, column type rendering, datetime conversion, pyodbc row
    normalisation and ORDER BY generation for calculated columns.

    NOTE: continuation lines of the triple-quoted SQL literals below are kept
    flush-left on purpose. Whitespace inside a literal is part of the query
    text being compared, so it must not be re-indented with the code.
    """

    def test_extract_limit_from_query(self, engine_spec_class=BaseEngineSpec):
        """Check ``get_limit_from_sql`` for valid, nested and malformed LIMITs.

        Each query is paired with the limit expected from the outermost
        statement; ``None`` is expected when there is no usable integer LIMIT.
        """
        cases = [
            ("select * from table", None),
            ("select * from mytable limit 10", 10),
            (
                "select * from (select * from my_subquery limit 10) "
                "where col=1 limit 20",
                20,
            ),
            ("select * from (select * from my_subquery limit 10);", None),
            (
                "select * from (select * from my_subquery limit 10) "
                "where col=1 limit 20;",
                20,
            ),
            ("select * from mytable limit 20, 10", 10),
            ("select * from mytable limit 10 offset 20", 10),
            ("select * from mytable limit", None),
            ("select * from mytable limit 10.0", None),
            ("select * from mytable limit x", None),
            ("select * from mytable limit 20, x", None),
            ("select * from mytable limit x offset 20", None),
        ]

        for sql, expected in cases:
            # subTest labels a failure with the query that produced it
            with self.subTest(sql=sql):
                self.assertEqual(engine_spec_class.get_limit_from_sql(sql), expected)

    def test_wrapped_semi_tabs(self):
        """Trailing whitespace and semicolon are dropped when LIMIT is appended."""
        self.sql_limit_regex(
            "SELECT * FROM a \t \n ; \t \n ", "SELECT * FROM a\nLIMIT 1000"
        )

    def test_simple_limit_query(self):
        """A query without LIMIT gets the default limit appended."""
        self.sql_limit_regex("SELECT * FROM a", "SELECT * FROM a\nLIMIT 1000")

    def test_modify_limit_query(self):
        """A LIMIT above the requested limit is lowered to it."""
        self.sql_limit_regex("SELECT * FROM a LIMIT 9999", "SELECT * FROM a LIMIT 1000")

    def test_limit_query_with_limit_subquery(self):  # pylint: disable=invalid-name
        """Only the outer LIMIT is rewritten; the subquery LIMIT is untouched."""
        self.sql_limit_regex(
            "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999",
            "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000",
        )

    def test_limit_query_without_force(self):
        """Without ``force``, an existing LIMIT below the requested one is kept."""
        self.sql_limit_regex(
            "SELECT * FROM a LIMIT 10",
            "SELECT * FROM a LIMIT 10",
            limit=11,
        )

    def test_limit_query_with_force(self):
        """With ``force=True``, the existing LIMIT is replaced even when lower."""
        self.sql_limit_regex(
            "SELECT * FROM a LIMIT 10",
            "SELECT * FROM a LIMIT 11",
            limit=11,
            force=True,
        )

    def test_limit_with_expr(self):
        """``LIMIT`` inside a string literal is ignored; the real clause is capped."""
        self.sql_limit_regex(
            """
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990""",
            """SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 1000""",
        )

    def test_limit_expr_and_semicolon(self):
        """Same as ``test_limit_with_expr`` but with a trailing semicolon."""
        self.sql_limit_regex(
            """
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990 ;""",
            """SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 1000""",
        )

    def test_get_datatype(self):
        """``get_datatype`` returns a plain string type name unchanged."""
        self.assertEqual("VARCHAR", BaseEngineSpec.get_datatype("VARCHAR"))

    def test_limit_with_implicit_offset(self):
        """In ``LIMIT offset, count`` only the count is capped."""
        self.sql_limit_regex(
            """
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990, 999999""",
            """SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990, 1000""",
        )

    def test_limit_with_explicit_offset(self):
        """With an explicit ``OFFSET`` clause the LIMIT is capped, OFFSET is kept."""
        self.sql_limit_regex(
            """
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990
OFFSET 999999""",
            """SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 1000
OFFSET 999999""",
        )

    def test_limit_with_non_token_limit(self):
        """A query whose only ``LIMIT`` is inside a literal still gets one appended."""
        self.sql_limit_regex(
            """SELECT 'LIMIT 777'""", """SELECT 'LIMIT 777'\nLIMIT 1000"""
        )

    def test_limit_with_fetch_many(self):
        """Engines using ``LimitMethod.FETCH_MANY`` leave the SQL unmodified."""

        class DummyEngineSpec(BaseEngineSpec):
            limit_method = LimitMethod.FETCH_MANY

        self.sql_limit_regex(
            "SELECT * FROM table", "SELECT * FROM table", DummyEngineSpec
        )

    def test_engine_time_grain_validity(self):
        """Every loaded engine spec defines grains, all of them builtin durations."""
        known_durations = set(builtin_time_grains)
        # loop over all subclasses of BaseEngineSpec
        for engine in load_engine_specs():
            if engine is BaseEngineSpec:
                continue
            # make sure time grain functions have been defined
            self.assertGreater(len(engine.get_time_grain_expressions()), 0)
            # make sure all defined time grains are supported
            defined_grains = {grain.duration for grain in engine.get_time_grains()}
            intersection = known_durations.intersection(defined_grains)
            self.assertSetEqual(defined_grains, intersection, engine)

    def test_get_time_grain_expressions(self):
        """MySQL exposes its time grain keys in the expected order."""
        time_grains = MySQLEngineSpec.get_time_grain_expressions()
        self.assertEqual(
            list(time_grains),
            [
                None,
                "PT1S",
                "PT1M",
                "PT1H",
                "P1D",
                "P1W",
                "P1M",
                "P3M",
                "P1Y",
                "1969-12-29T00:00:00Z/P1W",
            ],
        )

    def test_get_table_names(self):
        """``get_table_names`` strips the schema prefix from inspector results."""
        inspector = mock.Mock()
        inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"])
        inspector.get_foreign_table_names = mock.Mock(return_value=["table_3"])

        # Make sure base engine spec removes schema name from table name
        # ie. when try_remove_schema_from_table_name == True.
        base_result_expected = {"table", "table_2"}
        base_result = BaseEngineSpec.get_table_names(
            database=mock.ANY, schema="schema", inspector=inspector
        )
        assert base_result_expected == base_result

    @pytest.mark.usefixtures("load_energy_table_with_slice")
    def test_column_datatype_to_string(self):
        """Column types of ``energy_usage`` render to the backend's type strings."""
        example_db = get_example_database()
        sqla_table = example_db.get_table("energy_usage")
        dialect = example_db.get_dialect()

        # TODO: fix column type conversion for presto.
        if example_db.backend == "presto":
            # report a skip instead of a silent pass that verified nothing
            self.skipTest("column type conversion is not fixed for presto yet")

        col_types = [
            example_db.db_engine_spec.column_datatype_to_string(c.type, dialect)
            for c in sqla_table.columns
        ]
        expected_by_backend = {
            "postgresql": ["VARCHAR(255)", "VARCHAR(255)", "DOUBLE PRECISION"],
            "hive": ["STRING", "STRING", "FLOAT"],
        }
        expected = expected_by_backend.get(
            example_db.backend, ["VARCHAR(255)", "VARCHAR(255)", "FLOAT"]
        )
        self.assertEqual(col_types, expected)

    def test_convert_dttm(self):
        """The base spec has no datetime conversion and returns ``None``."""
        dttm = self.get_dttm()
        self.assertIsNone(BaseEngineSpec.convert_dttm("", dttm, db_extra=None))

    def test_pyodbc_rows_to_tuples(self):
        """``Row`` objects (as returned by the odbc driver) become plain tuples."""
        # Test for case when pyodbc.Row is returned (odbc driver)
        data = [
            Row((1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000))),
            Row((2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000))),
        ]
        expected = [
            (1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000)),
            (2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
        ]
        result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
        self.assertListEqual(result, expected)

    def test_pyodbc_rows_to_tuples_passthrough(self):
        """Rows that are already tuples come back unchanged."""
        # Test for case when tuples are returned
        data = [
            (1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000)),
            (2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
        ]
        result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
        self.assertListEqual(result, data)

    @mock.patch("superset.models.core.Database.db_engine_spec", BaseEngineSpec)
    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
    def test_calculated_column_in_order_by_base_engine_spec(self):
        """Ordering by a calculated column emits its expression in ORDER BY."""
        table = self.get_table(name="birth_names")
        # The column is not bound to a name; passing ``table=`` is what links it.
        TableColumn(
            column_name="gender_cc",
            type="VARCHAR(255)",
            table=table,
            expression="""
case
when gender='boy' then 'male'
else 'female'
end
""",
        )

        table.database.sqlalchemy_uri = "sqlite://"
        query_obj = {
            "groupby": ["gender_cc"],
            "is_timeseries": False,
            "filter": [],
            "orderby": [["gender_cc", True]],
        }
        sql = table.get_query_str(query_obj)
        assert (
            """ORDER BY case
when gender='boy' then 'male'
else 'female'
end ASC;"""
            in sql
        )
|
|
|
|
|
|
def test_is_readonly():
    """Check ``BaseEngineSpec.is_readonly_query`` on read-only and mutating SQL."""
    # (statement, expected read-only verdict), evaluated in the listed order
    expectations = (
        ("SHOW LOCKS test EXTENDED", True),
        ("SET hivevar:desc='Legislators'", False),
        ("UPDATE t1 SET col1 = NULL", False),
        ("EXPLAIN SELECT 1", True),
        ("SELECT 1", True),
        ("WITH (SELECT 1) bla SELECT * from bla", True),
        ("SHOW CATALOGS", True),
        ("SHOW TABLES", True),
    )

    for statement, expected in expectations:
        verdict = BaseEngineSpec.is_readonly_query(ParsedQuery(statement))
        # compare truthiness, as the plain ``assert`` / ``assert not`` form would
        assert bool(verdict) is expected, statement
|
|
|
|
|
|
def test_time_grain_denylist():
    """A grain listed in ``TIME_GRAIN_DENYLIST`` is absent from SQLite's grains."""
    # patch.dict mutates app.config in place and restores its previous contents
    # on exit, including when the assertion fails, so the override cannot leak
    # into other tests and app.config keeps its original object.
    with mock.patch.dict(app.config, {"TIME_GRAIN_DENYLIST": ["PT1M"]}):
        with app.app_context():
            time_grain_functions = SqliteEngineSpec.get_time_grain_expressions()
            assert "PT1M" not in time_grain_functions
|
|
|
|
|
|
def test_time_grain_addons():
    """A grain added through the addon config keys is the last SQLite time grain."""
    overrides = {
        "TIME_GRAIN_ADDONS": {"PTXM": "x seconds"},
        "TIME_GRAIN_ADDON_EXPRESSIONS": {"sqlite": {"PTXM": "ABC({col})"}},
    }

    # patch.dict restores app.config on exit even if an assertion fails, and
    # leaves the config object itself in place instead of swapping in a copy.
    with mock.patch.dict(app.config, overrides):
        with app.app_context():
            time_grains = SqliteEngineSpec.get_time_grains()
            time_grain_addon = time_grains[-1]
            assert time_grain_addon.duration == "PTXM"
            assert time_grain_addon.label == "x seconds"
|
|
|
|
|
|
def test_get_time_grain_with_config():
    """Should concatenate from configs and then sort in the proper order"""
    addon_expressions = {
        "mysql": {
            "PT2H": "foo",
            "PT4H": "foo",
            "PT6H": "foo",
            "PT8H": "foo",
            "PT10H": "foo",
            "PT12H": "foo",
            "PT1S": "foo",
        }
    }

    # patch.dict restores app.config on exit even if the assertion fails, and
    # leaves the config object itself in place instead of swapping in a copy.
    with mock.patch.dict(
        app.config, {"TIME_GRAIN_ADDON_EXPRESSIONS": addon_expressions}
    ):
        with app.app_context():
            time_grains = MySQLEngineSpec.get_time_grain_expressions()
            # NOTE(review): only membership is asserted here; ordering is not.
            assert set(time_grains.keys()) == {
                None,
                "PT1S",
                "PT1M",
                "PT1H",
                "PT2H",
                "PT4H",
                "PT6H",
                "PT8H",
                "PT10H",
                "PT12H",
                "P1D",
                "P1W",
                "P1M",
                "P3M",
                "P1Y",
                "1969-12-29T00:00:00Z/P1W",
            }
|
|
|
|
|
|
def test_get_time_grain_with_unknown_values():
    """Should concatenate from configs and then sort in the proper order
    putting unknown patterns at the end"""
    addon_expressions = {
        "mysql": {
            "PT2H": "foo",
            "weird": "foo",
            "PT12H": "foo",
        }
    }

    # patch.dict restores app.config on exit even if the assertion fails, and
    # leaves the config object itself in place instead of swapping in a copy.
    with mock.patch.dict(
        app.config, {"TIME_GRAIN_ADDON_EXPRESSIONS": addon_expressions}
    ):
        with app.app_context():
            time_grains = MySQLEngineSpec.get_time_grain_expressions()
            assert list(time_grains)[-1] == "weird"
|
|
|
|
|
|
@mock.patch("superset.db_engine_specs.base.is_hostname_valid")
@mock.patch("superset.db_engine_specs.base.is_port_open")
def test_validate(is_port_open, is_hostname_valid):
    """Complete parameters with a valid host and an open port yield no errors."""
    # both network probes are mocked to succeed
    is_port_open.return_value = True
    is_hostname_valid.return_value = True

    parameters = {
        "host": "localhost",
        "port": 5432,
        "username": "username",
        "password": "password",
        "database": "dbname",
        "query": {"sslmode": "verify-full"},
    }
    errors = BasicParametersMixin.validate_parameters({"parameters": parameters})

    assert errors == []
|
|
|
|
|
|
def test_validate_parameters_missing():
    """Blank parameters produce one warning naming the missing ones."""
    blank_parameters = {
        "host": "",
        "port": None,
        "username": "",
        "password": "",
        "database": "",
        "query": {},
    }
    expected_error = SupersetError(
        message="One or more parameters are missing: database, host, port, username",
        error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR,
        level=ErrorLevel.WARNING,
        extra={"missing": ["database", "host", "port", "username"]},
    )

    errors = BasicParametersMixin.validate_parameters(
        {"parameters": blank_parameters}
    )

    assert errors == [expected_error]
|
|
|
|
|
|
@mock.patch("superset.db_engine_specs.base.is_hostname_valid")
def test_validate_parameters_invalid_host(is_hostname_valid):
    """A missing port and an invalid hostname are both reported, in that order."""
    is_hostname_valid.return_value = False

    parameters = {
        "host": "localhost",
        "port": None,
        "username": "username",
        "password": "password",
        "database": "dbname",
        "query": {"sslmode": "verify-full"},
    }
    missing_port = SupersetError(
        message="One or more parameters are missing: port",
        error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR,
        level=ErrorLevel.WARNING,
        extra={"missing": ["port"]},
    )
    unresolved_host = SupersetError(
        message="The hostname provided can't be resolved.",
        error_type=SupersetErrorType.CONNECTION_INVALID_HOSTNAME_ERROR,
        level=ErrorLevel.ERROR,
        extra={"invalid": ["host"]},
    )

    errors = BasicParametersMixin.validate_parameters({"parameters": parameters})

    assert errors == [missing_port, unresolved_host]
|
|
|
|
|
|
@mock.patch("superset.db_engine_specs.base.is_hostname_valid")
@mock.patch("superset.db_engine_specs.base.is_port_open")
def test_validate_parameters_port_closed(is_port_open, is_hostname_valid):
    """A valid host with a closed port yields a single port-closed error."""
    is_port_open.return_value = False
    is_hostname_valid.return_value = True

    parameters = {
        "host": "localhost",
        "port": 5432,
        "username": "username",
        "password": "password",
        "database": "dbname",
        "query": {"sslmode": "verify-full"},
    }
    port_closed = SupersetError(
        message="The port is closed.",
        error_type=SupersetErrorType.CONNECTION_PORT_CLOSED_ERROR,
        level=ErrorLevel.ERROR,
        extra={
            "invalid": ["port"],
            "issue_codes": [
                {"code": 1008, "message": "Issue 1008 - The port is closed."}
            ],
        },
    )

    errors = BasicParametersMixin.validate_parameters({"parameters": parameters})

    assert errors == [port_closed]
|