feat: support for KQL in `SQLScript` (#27522)

This commit is contained in:
Beto Dealmeida 2024-03-22 12:48:20 -04:00 committed by GitHub
parent 5083ca0e81
commit cd7972d05b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 468 additions and 49 deletions

View File

@ -19,12 +19,13 @@
from __future__ import annotations
import enum
import logging
import re
import urllib.parse
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any, cast
from typing import Any, cast, Generic, TypeVar
from unittest.mock import Mock
import sqlglot
@ -334,89 +335,175 @@ def is_cte(source: exp.Table, scope: Scope) -> bool:
return source.name in ctes_in_scope
class SQLScript:
# To avoid unnecessary parsing/formatting of queries, the statement has the concept of
# an "internal representation", which is the AST of the SQL statement. For most of the
# engines supported by Superset this is `sqlglot.exp.Expression`, but there is a special
# case: KustoKQL uses a different syntax and there are no Python parsers for it, so we
# store the AST as a string (the original query), and manipulate it with regular
# expressions.
InternalRepresentation = TypeVar("InternalRepresentation")
# The base type. This helps type checking the `split_query` method correctly, since each
# derived class has a more specific return type (the class itself). This will no longer
# be needed once Python 3.11 is the lowest version supported. See PEP 673 for more
# information: https://peps.python.org/pep-0673/
TBaseSQLStatement = TypeVar("TBaseSQLStatement") # pylint: disable=invalid-name
class BaseSQLStatement(Generic[InternalRepresentation]):
"""
A SQL script, with 0+ statements.
Base class for SQL statements.
The class can be instantiated with a string representation of the query or, for
efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`,
which will split a query in multiple already parsed statements.
The `engine` parameter comes from the `engine` attribute in a Superset DB engine
spec.
"""
def __init__(
self,
query: str,
engine: str | None = None,
statement: str | InternalRepresentation,
engine: str,
):
dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
self._parsed: InternalRepresentation = (
self._parse_statement(statement, engine)
if isinstance(statement, str)
else statement
)
self.engine = engine
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
self.statements = [
SQLStatement(statement, engine=engine)
for statement in parse(query, dialect=dialect)
if statement
]
@classmethod
def split_query(
    cls: type[TBaseSQLStatement],
    query: str,
    engine: str,
) -> list[TBaseSQLStatement]:
    """
    Split a query into multiple instantiated statements.

    This is a helper function to split a full SQL query into multiple
    `BaseSQLStatement` instances. It's used by `SQLScript` when instantiating the
    statements within a query.

    The base implementation only raises `NotImplementedError`; derived classes
    provide the actual splitting logic and return instances of themselves.
    """
    raise NotImplementedError()
@classmethod
def _parse_statement(
    cls,
    statement: str,
    engine: str,
) -> InternalRepresentation:
    """
    Parse a string containing a single SQL statement, and return the parsed AST.

    Derived classes should not assume that `statement` contains a single statement,
    and MUST explicitly validate that. Since this validation is parser dependent the
    responsibility is left to the children classes.

    The base implementation only raises `NotImplementedError`.
    """
    raise NotImplementedError()
@classmethod
def _extract_tables_from_statement(
    cls,
    parsed: InternalRepresentation,
    engine: str,
) -> set[Table]:
    """
    Extract all table references in a given statement.

    `parsed` is the internal representation of the statement (the value produced
    by `_parse_statement`, or a pre-parsed AST passed to the constructor).

    The base implementation only raises `NotImplementedError`.
    """
    raise NotImplementedError()
def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL query.
Format the statement, optionally omitting comments.
"""
return ";\n".join(statement.format(comments) for statement in self.statements)
raise NotImplementedError()
def get_settings(self) -> dict[str, str]:
def get_settings(self) -> dict[str, str | bool]:
"""
Return the settings for the SQL query.
Return any settings set by the statement.
>>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'")
>>> statement.get_settings()
{"foo": "'baz'"}
For example, for this statement:
sql> SET foo = 'bar';
The method should return `{"foo": "'bar'"}`. Note the single quotes.
"""
settings: dict[str, str] = {}
for statement in self.statements:
settings.update(statement.get_settings())
raise NotImplementedError()
return settings
def __str__(self) -> str:
return self.format()
class SQLStatement:
class SQLStatement(BaseSQLStatement[exp.Expression]):
"""
A SQL statement.
This class provides helper methods to manipulate and introspect SQL.
This class is used for all engines with dialects that can be parsed using sqlglot.
"""
def __init__(
self,
statement: str | exp.Expression,
engine: str | None = None,
engine: str,
):
dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
self._dialect = SQLGLOT_DIALECTS.get(engine)
super().__init__(statement, engine)
if isinstance(statement, str):
try:
self._parsed = self._parse_statement(statement, dialect)
except ParseError as ex:
raise SupersetParseError(statement, engine) from ex
else:
self._parsed = statement
@classmethod
def split_query(
cls,
query: str,
engine: str,
) -> list[SQLStatement]:
dialect = SQLGLOT_DIALECTS.get(engine)
self._dialect = dialect
self.tables = extract_tables_from_statement(self._parsed, dialect)
try:
statements = sqlglot.parse(query, dialect=dialect)
except sqlglot.errors.ParseError as ex:
raise SupersetParseError("Unable to split query") from ex
@staticmethod
return [cls(statement, engine) for statement in statements if statement]
@classmethod
def _parse_statement(
sql_statement: str,
dialect: Dialects | None,
cls,
statement: str,
engine: str,
) -> exp.Expression:
"""
Parse a single SQL statement.
"""
statements = [
statement
for statement in sqlglot.parse(sql_statement, dialect=dialect)
if statement
]
dialect = SQLGLOT_DIALECTS.get(engine)
# We could parse with `sqlglot.parse_one` to get a single statement, but we need
# to verify that the string contains exactly one statement.
try:
statements = sqlglot.parse(statement, dialect=dialect)
except sqlglot.errors.ParseError as ex:
raise SupersetParseError("Unable to split query") from ex
statements = [statement for statement in statements if statement]
if len(statements) != 1:
raise ValueError("SQLStatement should have exactly one statement")
raise SupersetParseError("SQLStatement should have exactly one statement")
return statements[0]
@classmethod
def _extract_tables_from_statement(
    cls,
    parsed: exp.Expression,
    engine: str,
) -> set[Table]:
    """
    Return the set of tables referenced by an already parsed statement.

    The sqlglot dialect used for the extraction is looked up from the Superset
    engine name.
    """
    return extract_tables_from_statement(parsed, SQLGLOT_DIALECTS.get(engine))
def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL statement.
@ -424,7 +511,7 @@ class SQLStatement:
write = Dialect.get_or_raise(self._dialect)
return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
def get_settings(self) -> dict[str, str]:
def get_settings(self) -> dict[str, str | bool]:
"""
Return the settings for the SQL statement.
@ -440,6 +527,192 @@ class SQLStatement:
}
class KQLSplitState(enum.Enum):
    """
    State machine for splitting a KQL query.

    The state machine keeps track of whether we're inside a string or not, so we
    don't split the query in a semi-colon that's part of a string.
    """

    OUTSIDE_STRING = enum.auto()
    INSIDE_SINGLE_QUOTED_STRING = enum.auto()
    INSIDE_DOUBLE_QUOTED_STRING = enum.auto()
    INSIDE_MULTILINE_STRING = enum.auto()


def split_kql(kql: str) -> list[str]:
    """
    Split a KQL script into statements at top-level semi-colons.

    Semi-colons are ignored when they appear inside:

    - single- or double-quoted string literals, where a backslash escapes the
      next character (so `'a\\\\'` ends at its last quote);
    - verbatim literals (`@'...'` / `@"..."`), where the backslash stands for
      itself and a quote is escaped by doubling it;
    - multi-line literals delimited by triple backticks, which have no escapes;
    - `//` comments, which run until the end of the line.

    The returned statements are the raw text between separators: surrounding
    whitespace is preserved and the separating semi-colons are removed. Text
    after the last separator is returned as a final statement, even when it is
    only whitespace or ends inside an unterminated literal, so nothing from the
    input is silently dropped. An empty input produces a single empty statement.

    :param kql: The KQL script to split.
    :returns: The list of statements, in order of appearance.
    """
    fence = "```"
    opening_states = {
        "'": KQLSplitState.INSIDE_SINGLE_QUOTED_STRING,
        '"': KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING,
    }
    closing_quotes = {state: quote for quote, state in opening_states.items()}

    statements: list[str] = []
    state = KQLSplitState.OUTSIDE_STRING
    verbatim = False
    statement_start = 0
    length = len(kql)

    i = 0
    while i < length:
        character = kql[i]
        # Number of characters consumed by this iteration; tokens longer than one
        # character (fences, escapes, comments) are consumed as a whole so that
        # none of their characters is interpreted twice.
        step = 1

        if state == KQLSplitState.OUTSIDE_STRING:
            if character == ";":
                statements.append(kql[statement_start:i])
                statement_start = i + 1
            elif character in opening_states:
                state = opening_states[character]
                verbatim = i > 0 and kql[i - 1] == "@"
            elif kql.startswith(fence, i):
                state = KQLSplitState.INSIDE_MULTILINE_STRING
                step = len(fence)
            elif kql.startswith("//", i):
                # Skip the comment, leaving the newline to be handled normally.
                newline = kql.find("\n", i)
                step = (length if newline == -1 else newline) - i
        elif state == KQLSplitState.INSIDE_MULTILINE_STRING:
            if kql.startswith(fence, i):
                state = KQLSplitState.OUTSIDE_STRING
                step = len(fence)
        elif character == closing_quotes[state]:
            if verbatim and kql.startswith(character, i + 1):
                # A doubled quote is an escaped quote inside a verbatim literal.
                step = 2
            else:
                state = KQLSplitState.OUTSIDE_STRING
        elif character == "\\" and not verbatim:
            # Consume the escaped character together with the backslash, so an
            # escaped backslash can't be mistaken for the escape of a quote.
            step = 2

        i += step

    remainder = kql[statement_start:]
    if remainder or not kql:
        statements.append(remainder)

    return statements
class KustoKQLStatement(BaseSQLStatement[str]):
    """
    Special class for Kusto KQL.

    Kusto KQL is a SQL-like language, but it's not supported by sqlglot. Queries look
    like this:

        StormEvents
        | summarize PropertyDamage = sum(DamageProperty) by State
        | join kind=innerunique PopulationData on State
        | project State, PropertyDamagePerCapita = PropertyDamage / Population
        | sort by PropertyDamagePerCapita

    See https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/ for more
    details about it.

    Since there is no parser, the "parsed" representation of a statement is simply
    its text, stripped of surrounding whitespace.
    """

    @classmethod
    def split_query(
        cls,
        query: str,
        engine: str,
    ) -> list[KustoKQLStatement]:
        """
        Split a query at semi-colons.

        Since we don't have a parser, we use a simple state machine based function. See
        https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
        for more information.

        Blank fragments (for example the whitespace after a trailing semi-colon) are
        not turned into statements, mirroring how `SQLStatement.split_query` skips
        empty statements.
        """
        return [
            cls(statement, engine)
            for statement in split_kql(query)
            if statement.strip()
        ]

    @classmethod
    def _parse_statement(
        cls,
        statement: str,
        engine: str,
    ) -> str:
        """
        Validate that `statement` holds a single KQL statement and return it stripped.

        Blank fragments produced by the split (such as the whitespace following a
        trailing semi-colon) are ignored, so `"set querytrace;\\n"` is accepted as a
        single statement. An input with no content at all yields an empty string.

        Raises `SupersetParseError` if `engine` is not `kustokql`, or if the string
        contains more than one statement.
        """
        if engine != "kustokql":
            raise SupersetParseError(f"Invalid engine: {engine}")

        statements = [
            fragment for fragment in split_kql(statement) if fragment.strip()
        ]
        if len(statements) > 1:
            raise SupersetParseError("SQLStatement should have exactly one statement")

        return statements[0].strip() if statements else ""

    @classmethod
    def _extract_tables_from_statement(cls, parsed: str, engine: str) -> set[Table]:
        """
        Return the tables referenced in the statement -- currently always empty.

        Table extraction is not implemented for KQL. For a statement like the one
        below no tables are reported; a warning is logged instead and an empty set
        is returned:

            StormEvents
            | where InjuriesDirect + InjuriesIndirect > 50
            | join (PopulationData) on State
            | project State, Population, TotalInjuries = InjuriesDirect + InjuriesIndirect

        """
        logger.warning(
            "Kusto KQL doesn't support table extraction. This means that data access "
            "roles will not be enforced by Superset in the database."
        )
        return set()

    def format(self, comments: bool = True) -> str:
        """
        Return the statement text.

        No pretty-printing is done, and the `comments` argument has no effect: the
        stored statement is returned unchanged.
        """
        return self._parsed

    def get_settings(self) -> dict[str, str | bool]:
        """
        Return the settings for the SQL statement.

            >>> statement = KustoKQLStatement("set querytrace;", "kustokql")
            >>> statement.get_settings()
            {"querytrace": True}

        A `set name = value` statement maps the name to the value as a string; a
        bare `set name` maps the name to `True`. Any other statement returns an
        empty dictionary.
        """
        set_regex = r"^set\s+(?P<name>\w+)(?:\s*=\s*(?P<value>\w+))?$"
        if match := re.match(set_regex, self._parsed, re.IGNORECASE):
            return {match.group("name"): match.group("value") or True}

        return {}
class SQLScript:
    """
    A script made of zero or more statements, all targeting the same engine.
    """

    # Engines whose language can't be parsed with sqlglot, mapped to the statement
    # class that handles them. Supporting non-SQL engines adds a lot of complexity
    # to Superset, so adding new entries here should be avoided.
    special_engines = {
        "kustokql": KustoKQLStatement,
    }

    def __init__(
        self,
        query: str,
        engine: str,
    ):
        # Every engine without a dedicated class is handled by `SQLStatement`.
        try:
            handler = self.special_engines[engine]
        except KeyError:
            handler = SQLStatement

        self.statements = handler.split_query(query, engine)

    def format(self, comments: bool = True) -> str:
        """
        Format every statement and join them with a semi-colon and a newline.
        """
        rendered = (statement.format(comments) for statement in self.statements)
        return ";\n".join(rendered)

    def get_settings(self) -> dict[str, str | bool]:
        """
        Collect the settings of all statements; later statements win.

            >>> script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
            >>> script.get_settings()
            {"a": "2"}

        """
        merged: dict[str, str | bool] = {}
        for statement in self.statements:
            for name, value in statement.get_settings().items():
                merged[name] = value

        return merged
class ParsedQuery:
def __init__(
self,

View File

@ -37,8 +37,10 @@ from superset.sql_parse import (
has_table_query,
insert_rls_as_subquery,
insert_rls_in_predicate,
KustoKQLStatement,
ParsedQuery,
sanitize_clause,
split_kql,
SQLScript,
SQLStatement,
strip_comments_from_sql,
@ -1883,21 +1885,31 @@ def test_sqlquery() -> None:
"""
Test the `SQLScript` class.
"""
script = SQLScript("SELECT 1; SELECT 2;")
script = SQLScript("SELECT 1; SELECT 2;", "sqlite")
assert len(script.statements) == 2
assert script.format() == "SELECT\n 1;\nSELECT\n 2"
assert script.statements[0].format() == "SELECT\n 1"
script = SQLScript("SET a=1; SET a=2; SELECT 3;")
script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
assert script.get_settings() == {"a": "2"}
query = SQLScript(
"""set querytrace;
Events | take 100""",
"kustokql",
)
assert query.get_settings() == {"querytrace": True}
def test_sqlstatement() -> None:
"""
Test the `SQLStatement` class.
"""
statement = SQLStatement("SELECT * FROM table1 UNION ALL SELECT * FROM table2")
statement = SQLStatement(
"SELECT * FROM table1 UNION ALL SELECT * FROM table2",
"sqlite",
)
assert statement.tables == {
Table(table="table1", schema=None, catalog=None),
@ -1908,7 +1920,7 @@ def test_sqlstatement() -> None:
== "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2"
)
statement = SQLStatement("SET a=1")
statement = SQLStatement("SET a=1", "sqlite")
assert statement.get_settings() == {"a": "1"}
@ -1950,3 +1962,137 @@ def test_extract_tables_from_jinja_sql(
extract_tables_from_jinja_sql(sql.format(engine=engine, macro=macro), engine)
== expected
)
def test_kustokqlstatement_split_query() -> None:
    """
    Test the `KustoKQLStatement` split method.

    The script has three `let` statements followed by a tabular expression, each
    separated by a semi-colon, so four statements are expected.
    """
    statements = KustoKQLStatement.split_query(
        """
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day;
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp);
let cachedResult = materialize(materializedScope);
cachedResult
| project Page, Day1 = Day
| join kind = inner
(
cachedResult
| project Page, Day2 = Day
)
on Page
| where Day2 > Day1
| summarize count() by Day1, Day2
| join kind = inner
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
""",
        "kustokql",
    )
    assert len(statements) == 4
def test_kustokqlstatement_with_program() -> None:
    """
    Test the `KustoKQLStatement` split method when the KQL has a program.

    The semi-colon inside the triple-backtick literal must not split the query.
    """
    statements = KustoKQLStatement.split_query(
        """
print program = ```
public class Program {
public static void Main() {
System.Console.WriteLine("Hello!");
}
}```
""",
        "kustokql",
    )
    assert len(statements) == 1
def test_kustokqlstatement_with_set() -> None:
    """
    Test the `KustoKQLStatement` split method when the KQL has a set command.

    The `set` command and the query that follows it are separate statements, and
    each one is formatted without its surrounding whitespace.
    """
    statements = KustoKQLStatement.split_query(
        """
set querytrace;
Events | take 100
""",
        "kustokql",
    )
    assert len(statements) == 2
    assert statements[0].format() == "set querytrace"
    assert statements[1].format() == "Events | take 100"
@pytest.mark.parametrize(
    "kql,statements",
    [
        ('print banner=strcat("Hello", ", ", "World!")', 1),
        (r"print 'O\'Malley\'s'", 1),
        (r"print 'O\'Mal;ley\'s'", 1),
        ("print ```foo;\nbar;\nbaz;```\n", 1),
    ],
)
def test_kustokql_statement_split_special(kql: str, statements: int) -> None:
    """
    Test splitting queries that contain double-quoted, escaped single-quoted and
    multi-line string literals.
    """
    split = KustoKQLStatement.split_query(kql, "kustokql")
    assert len(split) == statements
def test_split_kql() -> None:
    """
    Test the `split_kql` function.

    The statements are compared verbatim: the separating semi-colons are removed,
    while the whitespace around each statement is preserved.
    """
    kql = """
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day;
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp);
let cachedResult = materialize(materializedScope);
cachedResult
| project Page, Day1 = Day
| join kind = inner
(
cachedResult
| project Page, Day2 = Day
)
on Page
| where Day2 > Day1
| summarize count() by Day1, Day2
| join kind = inner
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
"""
    assert split_kql(kql) == [
        """
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day""",
        """
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp)""",
        """
let cachedResult = materialize(materializedScope)""",
        """
cachedResult
| project Page, Day1 = Day
| join kind = inner
(
cachedResult
| project Page, Day2 = Day
)
on Page
| where Day2 > Day1
| summarize count() by Day1, Day2
| join kind = inner
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
""",
    ]