# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import inspect

import mock
from sqlalchemy import column, select, table
from sqlalchemy.dialects.mssql import pymssql
from sqlalchemy.types import String, UnicodeText

from superset import db_engine_specs
from superset.db_engine_specs import (
    BaseEngineSpec, BQEngineSpec, HiveEngineSpec, MssqlEngineSpec,
    MySQLEngineSpec, OracleEngineSpec, PrestoEngineSpec,
)
from superset.models.core import Database
from .base_tests import SupersetTestCase


class DbEngineSpecsTestCase(SupersetTestCase):
    """Tests for the engine spec classes exposed by ``db_engine_specs``.

    Covers ``HiveEngineSpec.progress`` / ``extract_error_message``, the
    ``apply_limit_to_sql`` / ``get_limit_from_sql`` helpers, datatype and
    time-grain lookups, label sanitising for BigQuery and Oracle, and the
    MSSQL column type mapping.

    The Hive log fixtures are triple-quoted strings holding one log record
    per line; they are split on ``'\\n'`` into a list of lines before being
    handed to ``HiveEngineSpec.progress``.
    """

    def test_0_progress(self):
        """A log without any job information reports 0 progress."""
        log = """
            17/02/07 18:26:27 INFO log.PerfLogger:
            17/02/07 18:26:27 INFO log.PerfLogger:
        """.split('\n')  # noqa ignore: E501
        self.assertEqual(0, HiveEngineSpec.progress(log))

    def test_number_of_jobs_progress(self):
        """Knowing only the total number of jobs reports 0 progress."""
        log = """
            17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
        """.split('\n')
        self.assertEqual(0, HiveEngineSpec.progress(log))

    def test_job_1_launched_progress(self):
        """A launched job with no stage output yet reports 0 progress."""
        log = """
            17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
            17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
        """.split('\n')
        self.assertEqual(0, HiveEngineSpec.progress(log))

    def test_job_1_launched_stage_1_0_progress(self):
        """A stage at map 0% / reduce 0% reports 0 progress."""
        log = """
            17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
            17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
        """.split('\n')  # noqa ignore: E501
        self.assertEqual(0, HiveEngineSpec.progress(log))

    def test_job_1_launched_stage_1_map_40_progress(self):
        """Job 1 of 2 with Stage-1 at map 40% / reduce 0% reports 10."""
        log = """
            17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
            17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
        """.split('\n')  # noqa ignore: E501
        self.assertEqual(10, HiveEngineSpec.progress(log))

    def test_job_1_launched_stage_1_map_80_reduce_40_progress(self):
        """Job 1 of 2 with Stage-1 at map 80% / reduce 40% reports 30."""
        log = """
            17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
            17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%, reduce = 40%
        """.split('\n')  # noqa ignore: E501
        self.assertEqual(30, HiveEngineSpec.progress(log))

    def test_job_1_launched_stage_2_stages_progress(self):
        """Job 1 of 2 with two stages reported in the log yields 12."""
        log = """
            17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
            17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%, reduce = 40%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-2 map = 0%, reduce = 0%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%, reduce = 0%
        """.split('\n')  # noqa ignore: E501
        self.assertEqual(12, HiveEngineSpec.progress(log))

    def test_job_2_launched_stage_2_stages_progress(self):
        """Job 2 of 2 with Stage-1 at map 40% / reduce 0% reports 60."""
        log = """
            17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
            17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%, reduce = 0%
            17/02/07 19:15:55 INFO ql.Driver: Launching Job 2 out of 2
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
        """.split('\n')  # noqa ignore: E501
        self.assertEqual(60, HiveEngineSpec.progress(log))

    def test_hive_error_msg(self):
        """``extract_error_message`` pulls ``errorMessage`` out of the text.

        When the exception text carries no ``errorMessage="..."`` part, the
        plain ``str()`` of the exception is expected back.
        """
        msg = (
            '{...} errorMessage="Error while compiling statement: FAILED: '
            'SemanticException [Error 10001]: Line 4'
            ':5 Table not found \'fact_ridesfdslakj\'", statusCode=3, '
            'sqlState=\'42S02\', errorCode=10001)){...}')
        expected = (
            'Error while compiling statement: FAILED: '
            'SemanticException [Error 10001]: Line 4:5 '
            "Table not found 'fact_ridesfdslakj'")
        self.assertEqual(
            expected, HiveEngineSpec.extract_error_message(Exception(msg)))

        e = Exception("Some string that doesn't match the regex")
        self.assertEqual(str(e), HiveEngineSpec.extract_error_message(e))

        msg = (
            'errorCode=10001, '
            'errorMessage="Error while compiling statement"), operationHandle'
            '=None)"'
        )
        self.assertEqual(
            'Error while compiling statement',
            HiveEngineSpec.extract_error_message(Exception(msg)))

    def get_generic_database(self):
        """Return an unsaved ``Database`` pointing at a local MySQL URI."""
        return Database(sqlalchemy_uri='mysql://localhost')

    def sql_limit_regex(
            self, sql, expected_sql,
            engine_spec_class=MySQLEngineSpec,
            limit=1000):
        """Assert that applying ``limit`` to ``sql`` yields ``expected_sql``.

        :param sql: statement handed to ``apply_limit_to_sql``
        :param expected_sql: exact statement expected back
        :param engine_spec_class: engine spec whose limit logic is exercised
        :param limit: row limit to apply
        """
        main = self.get_generic_database()
        limited = engine_spec_class.apply_limit_to_sql(sql, limit, main)
        self.assertEqual(expected_sql, limited)

    def test_extract_limit_from_query(self, engine_spec_class=MySQLEngineSpec):
        """``get_limit_from_sql`` returns the outer integer limit or None."""
        cases = [
            ('select * from table', None),
            ('select * from mytable limit 10', 10),
            ('select * from (select * from my_subquery limit 10) '
             'where col=1 limit 20', 20),
            ('select * from (select * from my_subquery limit 10);', None),
            ('select * from (select * from my_subquery limit 10) '
             'where col=1 limit 20;', 20),
            ('select * from mytable limit 10, 20', 10),
            ('select * from mytable limit 10 offset 20', 10),
            ('select * from mytable limit', None),
            ('select * from mytable limit 10.0', None),
            ('select * from mytable limit x', None),
            ('select * from mytable limit x, 20', None),
            ('select * from mytable limit x offset 20', None),
        ]
        for sql, expected in cases:
            # the statement is passed as the failure message so a failing
            # case can be identified without numbered variables
            self.assertEqual(
                engine_spec_class.get_limit_from_sql(sql), expected, sql)

    def test_wrapped_query(self):
        """MSSQL spec wraps the statement in an outer limited SELECT."""
        self.sql_limit_regex(
            'SELECT * FROM a',
            'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000',
            MssqlEngineSpec,
        )

    def test_wrapped_semi(self):
        """A trailing semicolon is dropped before wrapping."""
        self.sql_limit_regex(
            'SELECT * FROM a;',
            'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000',
            MssqlEngineSpec,
        )

    def test_wrapped_semi_tabs(self):
        """Trailing tabs, newlines, spaces and ';' are dropped too."""
        self.sql_limit_regex(
            'SELECT * FROM a \t \n ; \t \n ',
            'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000',
            MssqlEngineSpec,
        )

    def test_simple_limit_query(self):
        """A statement without LIMIT gets one appended."""
        self.sql_limit_regex(
            'SELECT * FROM a',
            'SELECT * FROM a LIMIT 1000',
        )

    def test_modify_limit_query(self):
        """An existing LIMIT value is replaced."""
        self.sql_limit_regex(
            'SELECT * FROM a LIMIT 9999',
            'SELECT * FROM a LIMIT 1000',
        )

    def test_limit_query_with_limit_subquery(self):
        """Only the outer LIMIT is replaced; the subquery LIMIT is kept."""
        self.sql_limit_regex(
            'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999',
            'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000',
        )

    def test_limit_with_expr(self):
        """'LIMIT' inside a string literal is left untouched."""
        self.sql_limit_regex(
            " SELECT 'LIMIT 777' AS a , b FROM table LIMIT 99990",
            "SELECT 'LIMIT 777' AS a , b FROM table LIMIT 1000",
        )

    def test_limit_expr_and_semicolon(self):
        """Same as above, with a trailing ' ;' that must not survive."""
        self.sql_limit_regex(
            " SELECT 'LIMIT 777' AS a , b FROM table LIMIT 99990 ;",
            "SELECT 'LIMIT 777' AS a , b FROM table LIMIT 1000",
        )

    def test_get_datatype(self):
        """``get_datatype`` maps driver type codes/names to type names."""
        self.assertEqual('STRING', PrestoEngineSpec.get_datatype('string'))
        self.assertEqual('TINY', MySQLEngineSpec.get_datatype(1))
        self.assertEqual('VARCHAR', MySQLEngineSpec.get_datatype(15))
        self.assertEqual('VARCHAR', BaseEngineSpec.get_datatype('VARCHAR'))

    def test_limit_with_implicit_offset(self):
        """With 'LIMIT a, b' only the second number is replaced."""
        self.sql_limit_regex(
            " SELECT 'LIMIT 777' AS a , b FROM table "
            'LIMIT 99990, 999999',
            "SELECT 'LIMIT 777' AS a , b FROM table "
            'LIMIT 99990, 1000',
        )

    def test_limit_with_explicit_offset(self):
        """With 'LIMIT a OFFSET b' the limit changes, the offset is kept."""
        self.sql_limit_regex(
            " SELECT 'LIMIT 777' AS a , b FROM table "
            'LIMIT 99990 OFFSET 999999',
            "SELECT 'LIMIT 777' AS a , b FROM table "
            'LIMIT 1000 OFFSET 999999',
        )

    def test_limit_with_non_token_limit(self):
        """'LIMIT' appearing only in a literal still gets a LIMIT appended."""
        self.sql_limit_regex(
            " SELECT 'LIMIT 777'",
            " SELECT 'LIMIT 777' LIMIT 1000",
        )

    def test_time_grain_blacklist(self):
        """Blacklisted durations are excluded from the time grain tuple."""
        blacklist = ['PT1M']
        time_grains = {
            'PT1S': 'second',
            'PT1M': 'minute',
        }
        time_grain_functions = {
            'PT1S': '{col}',
            'PT1M': '{col}',
        }
        time_grains = db_engine_specs._create_time_grains_tuple(
            time_grains, time_grain_functions, blacklist)
        self.assertEqual(1, len(time_grains))
        self.assertEqual('PT1S', time_grains[0].duration)

    def test_engine_time_grain_validity(self):
        """Every engine spec defines only built-in time grains."""
        time_grains = set(db_engine_specs.builtin_time_grains.keys())
        # loop over all subclasses of BaseEngineSpec
        for cls_name, cls in inspect.getmembers(db_engine_specs):
            if (inspect.isclass(cls) and issubclass(cls, BaseEngineSpec)
                    and cls is not BaseEngineSpec):
                # make sure time grain functions have been defined
                self.assertGreater(len(cls.time_grain_functions), 0)
                # make sure that all defined time grains are supported
                defined_time_grains = {
                    grain.duration for grain in cls.get_time_grains()}
                intersection = time_grains.intersection(defined_time_grains)
                self.assertSetEqual(
                    defined_time_grains, intersection, cls_name)

    def test_presto_get_view_names_return_empty_list(self):
        """``PrestoEngineSpec.get_view_names`` returns an empty list."""
        self.assertEqual(
            [], PrestoEngineSpec.get_view_names(mock.ANY, mock.ANY))

    def test_hive_get_view_names_return_empty_list(self):
        """``HiveEngineSpec.get_view_names`` returns an empty list."""
        self.assertEqual(
            [], HiveEngineSpec.get_view_names(mock.ANY, mock.ANY))

    def test_bigquery_sqla_column_label(self):
        """``BQEngineSpec.make_label_compatible`` returns the pinned labels."""
        cases = [
            ('Col', 'Col'),
            ('SUM(x)', 'SUM_x__5f110b965a993675bc4953bb3e03c4a5'),
            ('SUM[x]', 'SUM_x__7ebe14a3f9534aeee125449b0bc083a8'),
            ('12345_col', '_12345_col_8d3906e2ea99332eb185f7f8ecb2ffd6'),
        ]
        for name, label_expected in cases:
            label = BQEngineSpec.make_label_compatible(column(name).name)
            self.assertEqual(label, label_expected)

    def test_oracle_sqla_column_name_length_exceeded(self):
        """A 32-character Oracle column name maps to a quoted label."""
        col = column('This_Is_32_Character_Column_Name')
        label = OracleEngineSpec.make_label_compatible(col.name)
        self.assertEqual(label.quote, True)
        label_expected = '3b26974078683be078219674eeb8f5'
        self.assertEqual(label, label_expected)

    def test_mssql_column_types(self):
        """MSSQL type strings map to ``String`` / ``UnicodeText`` or None."""
        def assert_type(type_string, type_expected):
            type_assigned = MssqlEngineSpec.get_sqla_column_type(type_string)
            if type_expected is None:
                self.assertIsNone(type_assigned)
            else:
                self.assertIsInstance(type_assigned, type_expected)

        assert_type('INT', None)
        assert_type('STRING', String)
        assert_type('CHAR(10)', String)
        assert_type('VARCHAR(10)', String)
        assert_type('TEXT', String)
        assert_type('NCHAR(10)', UnicodeText)
        assert_type('NVARCHAR(10)', UnicodeText)
        assert_type('NTEXT', UnicodeText)

    def test_mssql_where_clause_n_prefix(self):
        """Literals compared to an NTEXT column compile with an N prefix."""
        dialect = pymssql.dialect()
        spec = MssqlEngineSpec
        str_col = column(
            'col', type_=spec.get_sqla_column_type('VARCHAR(10)'))
        unicode_col = column(
            'unicode_col', type_=spec.get_sqla_column_type('NTEXT'))
        tbl = table('tbl')
        sel = (
            select([str_col, unicode_col])
            .select_from(tbl)
            .where(str_col == 'abc')
            .where(unicode_col == 'abc')
        )
        query = str(sel.compile(
            dialect=dialect, compile_kwargs={'literal_binds': True}))
        query_expected = (
            'SELECT col, unicode_col \n'
            'FROM tbl \n'
            "WHERE col = 'abc' AND unicode_col = N'abc'"
        )
        self.assertEqual(query, query_expected)