From 2c564817f1978e34770e02034a7a4c02e1bfdc9f Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Sat, 24 Feb 2024 08:47:36 +1300 Subject: [PATCH] fix(sqlglot): Address regressions introduced in #26476 (#27217) --- superset/sql_parse.py | 17 +++++++++++------ tests/unit_tests/sql_parse_tests.py | 10 ++++++---- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 7b89ab8f0e..c85afc9460 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -28,7 +28,7 @@ import sqlparse from sqlalchemy import and_ from sqlglot import exp, parse, parse_one from sqlglot.dialects import Dialects -from sqlglot.errors import ParseError +from sqlglot.errors import SqlglotError from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope from sqlparse import keywords from sqlparse.lexer import Lexer @@ -287,7 +287,7 @@ class ParsedQuery: """ try: statements = parse(self.stripped(), dialect=self._dialect) - except ParseError: + except SqlglotError: logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql) return set() @@ -319,12 +319,17 @@ class ParsedQuery: elif isinstance(statement, exp.Command): # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a # `SELECT` statetement in order to extract tables. - literal = statement.find(exp.Literal) - if not literal: + if not (literal := statement.find(exp.Literal)): return set() - pseudo_query = parse_one(f"SELECT {literal.this}", dialect=self._dialect) - sources = pseudo_query.find_all(exp.Table) + try: + pseudo_query = parse_one( + f"SELECT {literal.this}", + dialect=self._dialect, + ) + sources = pseudo_query.find_all(exp.Table) + except SqlglotError: + return set() else: sources = [ source diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index f05e16ae85..2fd23f7e8e 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -271,6 +271,7 @@ def test_extract_tables_illdefined() -> None: assert extract_tables("SELECT * FROM catalogname..tbname") == { Table(table="tbname", schema=None, catalog="catalogname") } + assert extract_tables('SELECT * FROM "tbname') == set() def test_extract_tables_show_tables_from() -> None: @@ -558,6 +559,10 @@ def test_extract_tables_multistatement() -> None: Table("t1"), Table("t2"), } + assert extract_tables( + "ADD JAR file:///hive.jar; SELECT * FROM t1;", + engine="hive", + ) == {Table("t1")} def test_extract_tables_complex() -> None: @@ -1815,10 +1820,7 @@ def test_extract_table_references(mocker: MockerFixture) -> None: # test falling back to sqlparse logger = mocker.patch("superset.sql_parse.logger") sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table" - assert extract_table_references( - sql, - "trino", - ) == { + assert extract_table_references(sql, "trino") == { Table(table="table", schema=None, catalog=None), Table(table="other_table", schema=None, catalog=None), }