mirror of
https://github.com/apache/superset.git
synced 2024-09-18 11:39:49 -04:00
d8be0a7dd5
* Break line before LIMIT statement to prevent trailing comment issue. This may not be a perfect solution, but it addresses the issue in #7483. Closes https://github.com/apache/incubator-superset/issues/7483 * Fix tests
467 lines
20 KiB
Python
467 lines
20 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
import inspect
|
|
from unittest import mock
|
|
|
|
from sqlalchemy import column, select, table
|
|
from sqlalchemy.dialects.mssql import pymssql
|
|
from sqlalchemy.engine.result import RowProxy
|
|
from sqlalchemy.types import String, UnicodeText
|
|
|
|
from superset import db_engine_specs
|
|
from superset.db_engine_specs import (
|
|
BaseEngineSpec, BQEngineSpec, HiveEngineSpec, MssqlEngineSpec,
|
|
MySQLEngineSpec, OracleEngineSpec, PrestoEngineSpec,
|
|
)
|
|
from superset.models.core import Database
|
|
from .base_tests import SupersetTestCase
|
|
|
|
|
|
class DbEngineSpecsTestCase(SupersetTestCase):
    """Tests for the engine-specific helpers exposed by ``db_engine_specs``.

    Covers Hive progress/error parsing, LIMIT extraction and rewriting,
    time grain definitions, Presto column introspection, and label/type
    handling for BigQuery, Oracle and MSSQL.
    """

    # ------------------------------------------------------------------
    # Hive: progress parsing
    # ------------------------------------------------------------------

    def test_0_progress(self):
        """A log without any job information reports 0 progress."""
        log = """
            17/02/07 18:26:27 INFO log.PerfLogger: <PERFLOG method=compile from=org.apache.hadoop.hive.ql.Driver>
            17/02/07 18:26:27 INFO log.PerfLogger: <PERFLOG method=parse from=org.apache.hadoop.hive.ql.Driver>
        """.split('\n')  # noqa ignore: E501
        self.assertEqual(
            0, HiveEngineSpec.progress(log))

    def test_number_of_jobs_progress(self):
        """Knowing only the total number of jobs still reports 0 progress."""
        log = """
            17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
        """.split('\n')
        self.assertEqual(0, HiveEngineSpec.progress(log))

    def test_job_1_launched_progress(self):
        """Launching the first job without stage output reports 0 progress."""
        log = """
            17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
            17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
        """.split('\n')
        self.assertEqual(0, HiveEngineSpec.progress(log))

    def test_job_1_launched_stage_1_0_progress(self):
        """A stage at map = 0%, reduce = 0% reports 0 progress."""
        log = """
            17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
            17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%,  reduce = 0%
        """.split('\n')  # noqa ignore: E501
        self.assertEqual(0, HiveEngineSpec.progress(log))

    def test_job_1_launched_stage_1_map_40_progress(self):
        """Job 1 of 2 with Stage-1 at map = 40%, reduce = 0% reports 10."""
        log = """
            17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
            17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%,  reduce = 0%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%,  reduce = 0%
        """.split('\n')  # noqa ignore: E501
        self.assertEqual(10, HiveEngineSpec.progress(log))

    def test_job_1_launched_stage_1_map_80_reduce_40_progress(self):
        """Job 1 of 2 with Stage-1 at map = 80%, reduce = 40% reports 30."""
        log = """
            17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
            17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%,  reduce = 0%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%,  reduce = 0%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%,  reduce = 40%
        """.split('\n')  # noqa ignore: E501
        self.assertEqual(30, HiveEngineSpec.progress(log))

    def test_job_1_launched_stage_2_stages_progress(self):
        """Two stages reported for job 1 of 2 yield a progress of 12."""
        log = """
            17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
            17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%,  reduce = 0%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%,  reduce = 0%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%,  reduce = 40%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-2 map = 0%,  reduce = 0%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%,  reduce = 0%
        """.split('\n')  # noqa ignore: E501
        self.assertEqual(12, HiveEngineSpec.progress(log))

    def test_job_2_launched_stage_2_stages_progress(self):
        """Job 2 of 2 with its Stage-1 at map = 40% reports 60."""
        log = """
            17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
            17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%,  reduce = 0%
            17/02/07 19:15:55 INFO ql.Driver: Launching Job 2 out of 2
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%,  reduce = 0%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%,  reduce = 0%
        """.split('\n')  # noqa ignore: E501
        self.assertEqual(60, HiveEngineSpec.progress(log))

    # ------------------------------------------------------------------
    # Hive: error message extraction
    # ------------------------------------------------------------------

    def test_hive_error_msg(self):
        """``extract_error_message`` returns the quoted ``errorMessage`` text.

        When the exception text carries no ``errorMessage="..."`` section the
        exception's own string form is returned instead.
        """
        msg = (
            '{...} errorMessage="Error while compiling statement: FAILED: '
            'SemanticException [Error 10001]: Line 4'
            ':5 Table not found \'fact_ridesfdslakj\'", statusCode=3, '
            'sqlState=\'42S02\', errorCode=10001)){...}')
        self.assertEqual((
            'Error while compiling statement: FAILED: '
            'SemanticException [Error 10001]: Line 4:5 '
            "Table not found 'fact_ridesfdslakj'"),
            HiveEngineSpec.extract_error_message(Exception(msg)))

        # No errorMessage section: the raw exception text is passed through.
        e = Exception("Some string that doesn't match the regex")
        self.assertEqual(
            str(e), HiveEngineSpec.extract_error_message(e))

        # errorMessage appearing after other fields is still picked up.
        msg = (
            'errorCode=10001, '
            'errorMessage="Error while compiling statement"), operationHandle'
            '=None)"'
        )
        self.assertEqual((
            'Error while compiling statement'),
            HiveEngineSpec.extract_error_message(Exception(msg)))

    # ------------------------------------------------------------------
    # LIMIT handling
    # ------------------------------------------------------------------

    def get_generic_database(self):
        """Return an unsaved ``Database`` pointing at ``mysql://localhost``."""
        return Database(sqlalchemy_uri='mysql://localhost')

    def sql_limit_regex(
            self, sql, expected_sql,
            engine_spec_class=MySQLEngineSpec,
            limit=1000):
        """Assert that applying ``limit`` to ``sql`` produces ``expected_sql``.

        :param sql: the query handed to ``apply_limit_to_sql``
        :param expected_sql: the exact query text expected back
        :param engine_spec_class: engine spec whose ``apply_limit_to_sql``
            is exercised (``MySQLEngineSpec`` by default)
        :param limit: the row limit to apply (1000 by default)
        """
        main = self.get_generic_database()
        limited = engine_spec_class.apply_limit_to_sql(sql, limit, main)
        self.assertEqual(expected_sql, limited)

    def test_extract_limit_from_query(self, engine_spec_class=MySQLEngineSpec):
        """``get_limit_from_sql`` returns the outermost integer LIMIT, or None."""
        cases = [
            ('select * from table', None),
            ('select * from mytable limit 10', 10),
            ('select * from (select * from my_subquery limit 10) '
             'where col=1 limit 20', 20),
            # A LIMIT that only appears inside a subquery does not count.
            ('select * from (select * from my_subquery limit 10);', None),
            ('select * from (select * from my_subquery limit 10) '
             'where col=1 limit 20;', 20),
            ('select * from mytable limit 10, 20', 10),
            ('select * from mytable limit 10 offset 20', 10),
            # Missing, non-integer or non-literal limits are reported as None.
            ('select * from mytable limit', None),
            ('select * from mytable limit 10.0', None),
            ('select * from mytable limit x', None),
            ('select * from mytable limit x, 20', None),
            ('select * from mytable limit x offset 20', None),
        ]
        for sql, expected_limit in cases:
            # subTest keeps checking the remaining queries after a failure.
            with self.subTest(sql=sql):
                self.assertEqual(
                    expected_limit, engine_spec_class.get_limit_from_sql(sql))

    def test_wrapped_query(self):
        """MSSQL applies the limit by wrapping the query in a subselect."""
        self.sql_limit_regex(
            'SELECT * FROM a',
            'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000',
            MssqlEngineSpec,
        )

    def test_wrapped_semi(self):
        """A trailing semicolon is dropped before the query is wrapped."""
        self.sql_limit_regex(
            'SELECT * FROM a;',
            'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000',
            MssqlEngineSpec,
        )

    def test_wrapped_semi_tabs(self):
        """Whitespace around the trailing semicolon is dropped as well."""
        self.sql_limit_regex(
            'SELECT * FROM a \t \n ; \t \n ',
            'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000',
            MssqlEngineSpec,
        )

    def test_simple_limit_query(self):
        """A query without LIMIT gets one appended on its own line."""
        self.sql_limit_regex(
            'SELECT * FROM a',
            'SELECT * FROM a\nLIMIT 1000',
        )

    def test_modify_limit_query(self):
        """An existing LIMIT value is replaced in place."""
        self.sql_limit_regex(
            'SELECT * FROM a LIMIT 9999',
            'SELECT * FROM a LIMIT 1000',
        )

    def test_limit_query_with_limit_subquery(self):
        """Only the outer LIMIT is rewritten; the subquery LIMIT is kept."""
        self.sql_limit_regex(
            'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999',
            'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000',
        )

    def test_limit_with_expr(self):
        """A string literal containing 'LIMIT' is not mistaken for the clause."""
        # Input and expected output share the same inner indentation: apart
        # from the leading whitespace only the LIMIT value differs.
        self.sql_limit_regex(
            """
            SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 99990""",
            """SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 1000""",
        )

    def test_limit_expr_and_semicolon(self):
        """Same as ``test_limit_with_expr`` with a trailing ' ;' on the input."""
        # Input and expected output share the same inner indentation.
        self.sql_limit_regex(
            """
            SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT         99990            ;""",
            """SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT         1000""",
        )

    def test_get_datatype(self):
        """``get_datatype`` maps driver type codes/names to type names."""
        self.assertEqual('STRING', PrestoEngineSpec.get_datatype('string'))
        self.assertEqual('TINY', MySQLEngineSpec.get_datatype(1))
        self.assertEqual('VARCHAR', MySQLEngineSpec.get_datatype(15))
        self.assertEqual('VARCHAR', BaseEngineSpec.get_datatype('VARCHAR'))

    def test_limit_with_implicit_offset(self):
        """In ``LIMIT <offset>, <count>`` only the count is rewritten."""
        # Input and expected output share the same inner indentation.
        self.sql_limit_regex(
            """
            SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 99990, 999999""",
            """SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 99990, 1000""",
        )

    def test_limit_with_explicit_offset(self):
        """In ``LIMIT <count> OFFSET <offset>`` only the count is rewritten."""
        # Input and expected output share the same inner indentation.
        self.sql_limit_regex(
            """
            SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 99990
            OFFSET 999999""",
            """SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 1000
            OFFSET 999999""",
        )

    def test_limit_with_non_token_limit(self):
        """'LIMIT' inside a string literal still gets a real LIMIT appended."""
        # The expected text keeps the input's leading whitespace verbatim and
        # only gains the trailing "\nLIMIT 1000".
        self.sql_limit_regex(
            """
            SELECT
                'LIMIT 777'""",
            """
            SELECT
                'LIMIT 777'\nLIMIT 1000""",
        )

    # ------------------------------------------------------------------
    # Time grains
    # ------------------------------------------------------------------

    def test_time_grain_blacklist(self):
        """Blacklisted durations are left out of the time grain tuple."""
        blacklist = ['PT1M']
        time_grains = {
            'PT1S': 'second',
            'PT1M': 'minute',
        }
        time_grain_functions = {
            'PT1S': '{col}',
            'PT1M': '{col}',
        }
        remaining = db_engine_specs._create_time_grains_tuple(
            time_grains, time_grain_functions, blacklist)
        self.assertEqual(1, len(remaining))
        self.assertEqual('PT1S', remaining[0].duration)

    def test_engine_time_grain_validity(self):
        """Every engine spec defines time grains, all of them builtin ones."""
        time_grains = set(db_engine_specs.builtin_time_grains.keys())
        # loop over all subclasses of BaseEngineSpec
        for cls_name, cls in inspect.getmembers(db_engine_specs):
            is_engine_spec = (
                inspect.isclass(cls) and
                issubclass(cls, BaseEngineSpec) and
                cls is not BaseEngineSpec
            )
            if not is_engine_spec:
                continue
            # make sure time grain functions have been defined
            self.assertGreater(len(cls.time_grain_functions), 0)
            # make sure that all defined time grains are supported
            defined_time_grains = {
                grain.duration for grain in cls.get_time_grains()}
            intersection = time_grains.intersection(defined_time_grains)
            self.assertSetEqual(defined_time_grains, intersection, cls_name)

    # ------------------------------------------------------------------
    # Presto
    # ------------------------------------------------------------------

    def test_presto_get_view_names_return_empty_list(self):
        """``PrestoEngineSpec.get_view_names`` returns an empty list."""
        self.assertEqual(
            [], PrestoEngineSpec.get_view_names(mock.ANY, mock.ANY))

    def verify_presto_column(self, column, expected_results):
        """Check the columns ``PrestoEngineSpec.get_columns`` derives from a row.

        :param column: ``(name, type, null)`` tuple used as the single row
            returned by the mocked inspector; note that inside this method the
            name shadows the module-level ``sqlalchemy.column`` import
        :param expected_results: ``(name, type string)`` pairs, in the order
            ``get_columns`` is expected to return them
        """
        inspector = mock.Mock()
        inspector.engine.dialect.identifier_preparer.quote_identifier = mock.Mock()
        keymap = {
            'Column': (None, None, 0),
            'Type': (None, None, 1),
            'Null': (None, None, 2),
        }
        row = RowProxy(mock.Mock(), column, [None, None, None, None], keymap)
        inspector.bind.execute = mock.Mock(return_value=[row])
        results = PrestoEngineSpec.get_columns(inspector, '', '')
        # Guard the zip below against silently ignoring extra/missing columns.
        self.assertEqual(len(expected_results), len(results))
        for expected_result, result in zip(expected_results, results):
            self.assertEqual(expected_result[0], result['name'])
            self.assertEqual(expected_result[1], str(result['type']))

    def test_presto_get_column(self):
        """A scalar column yields a single entry."""
        presto_column = ('column_name', 'boolean', '')
        expected_results = [('column_name', 'BOOLEAN')]
        self.verify_presto_column(presto_column, expected_results)

    def test_presto_get_simple_row_column(self):
        """A row column yields the row itself plus one entry per field."""
        presto_column = ('column_name', 'row(nested_obj double)', '')
        expected_results = [
            ('column_name', 'ROW'),
            ('column_name.nested_obj', 'FLOAT')]
        self.verify_presto_column(presto_column, expected_results)

    def test_presto_get_simple_row_column_with_tricky_name(self):
        """Quoted field names containing parentheses and commas are kept intact."""
        presto_column = (
            'column_name', 'row("Field Name(Tricky, Name)" double)', '')
        expected_results = [
            ('column_name', 'ROW'),
            ('column_name."Field Name(Tricky, Name)"', 'FLOAT')]
        self.verify_presto_column(presto_column, expected_results)

    def test_presto_get_simple_array_column(self):
        """An array of scalars yields a single ARRAY entry."""
        presto_column = ('column_name', 'array(double)', '')
        expected_results = [('column_name', 'ARRAY')]
        self.verify_presto_column(presto_column, expected_results)

    def test_presto_get_row_within_array_within_row_column(self):
        """Nested row > array > row structures are flattened with dotted names."""
        presto_column = (
            'column_name',
            'row(nested_array array(row(nested_row double)), nested_obj double)',
            '')
        expected_results = [
            ('column_name', 'ROW'),
            ('column_name.nested_array', 'ARRAY'),
            ('column_name.nested_array.nested_row', 'FLOAT'),
            ('column_name.nested_obj', 'FLOAT'),
        ]
        self.verify_presto_column(presto_column, expected_results)

    def test_presto_get_array_within_row_within_array_column(self):
        """Nested array > row > array structures are flattened with dotted names."""
        presto_column = (
            'column_name',
            'array(row(nested_array array(double), nested_obj double))', '')
        expected_results = [
            ('column_name', 'ARRAY'),
            ('column_name.nested_array', 'ARRAY'),
            ('column_name.nested_obj', 'FLOAT')]
        self.verify_presto_column(presto_column, expected_results)

    def test_presto_get_fields(self):
        """``_get_fields`` quotes each dotted path segment and labels the column."""
        cols = [
            {'name': 'column'},
            {'name': 'column.nested_obj'},
            {'name': 'column."quoted.nested obj"'}]
        actual_results = PrestoEngineSpec._get_fields(cols)
        expected_results = [
            {'name': '"column"', 'label': 'column'},
            {'name': '"column"."nested_obj"', 'label': 'column.nested_obj'},
            {'name': '"column"."quoted.nested obj"',
             'label': 'column."quoted.nested obj"'}]
        # zip() stops at the shorter sequence, so check the lengths first or a
        # missing field would go unnoticed.
        self.assertEqual(len(expected_results), len(actual_results))
        for actual_result, expected_result in zip(actual_results, expected_results):
            self.assertEqual(actual_result.element.name, expected_result['name'])
            self.assertEqual(actual_result.name, expected_result['label'])

    def test_presto_filter_presto_cols(self):
        """``_filter_presto_cols`` keeps the parent column and drops the nested one."""
        cols = [
            {'name': 'column', 'type': 'ARRAY'},
            {'name': 'column.nested_obj', 'type': 'FLOAT'}]
        actual_results = PrestoEngineSpec._filter_presto_cols(cols)
        expected_results = [cols[0]]
        self.assertEqual(actual_results, expected_results)

    # ------------------------------------------------------------------
    # Hive / BigQuery / Oracle / MSSQL
    # ------------------------------------------------------------------

    def test_hive_get_view_names_return_empty_list(self):
        """``HiveEngineSpec.get_view_names`` returns an empty list."""
        self.assertEqual([], HiveEngineSpec.get_view_names(mock.ANY, mock.ANY))

    def test_bigquery_sqla_column_label(self):
        """BigQuery labels are left alone when valid, otherwise rewritten."""
        label = BQEngineSpec.make_label_compatible(column('Col').name)
        label_expected = 'Col'
        self.assertEqual(label, label_expected)

        label = BQEngineSpec.make_label_compatible(column('SUM(x)').name)
        label_expected = 'SUM_x__5f110b965a993675bc4953bb3e03c4a5'
        self.assertEqual(label, label_expected)

        label = BQEngineSpec.make_label_compatible(column('SUM[x]').name)
        label_expected = 'SUM_x__7ebe14a3f9534aeee125449b0bc083a8'
        self.assertEqual(label, label_expected)

        label = BQEngineSpec.make_label_compatible(column('12345_col').name)
        label_expected = '_12345_col_8d3906e2ea99332eb185f7f8ecb2ffd6'
        self.assertEqual(label, label_expected)

    def test_oracle_sqla_column_name_length_exceeded(self):
        """A 32-character Oracle column name is replaced by a quoted short label."""
        col = column('This_Is_32_Character_Column_Name')
        label = OracleEngineSpec.make_label_compatible(col.name)
        self.assertEqual(label.quote, True)
        label_expected = '3b26974078683be078219674eeb8f5'
        self.assertEqual(label, label_expected)

    def test_mssql_column_types(self):
        """``MssqlEngineSpec.get_sqla_column_type`` maps string types only."""
        def assert_type(type_string, type_expected):
            # type_expected=None means no SQLAlchemy type should be assigned.
            type_assigned = MssqlEngineSpec.get_sqla_column_type(type_string)
            if type_expected is None:
                self.assertIsNone(type_assigned)
            else:
                self.assertIsInstance(type_assigned, type_expected)

        assert_type('INT', None)
        assert_type('STRING', String)
        assert_type('CHAR(10)', String)
        assert_type('VARCHAR(10)', String)
        assert_type('TEXT', String)
        assert_type('NCHAR(10)', UnicodeText)
        assert_type('NVARCHAR(10)', UnicodeText)
        assert_type('NTEXT', UnicodeText)

    def test_mssql_where_clause_n_prefix(self):
        """Literals compared to unicode columns are rendered with an N prefix."""
        dialect = pymssql.dialect()
        spec = MssqlEngineSpec
        str_col = column('col', type_=spec.get_sqla_column_type('VARCHAR(10)'))
        unicode_col = column(
            'unicode_col', type_=spec.get_sqla_column_type('NTEXT'))
        tbl = table('tbl')
        sel = (
            select([str_col, unicode_col])
            .select_from(tbl)
            .where(str_col == 'abc')
            .where(unicode_col == 'abc')
        )

        # literal_binds inlines the bound values so the N prefix is visible.
        query = str(sel.compile(
            dialect=dialect, compile_kwargs={'literal_binds': True}))
        query_expected = "SELECT col, unicode_col \nFROM tbl \nWHERE col = 'abc' AND unicode_col = N'abc'"  # noqa
        self.assertEqual(query, query_expected)
|