mirror of https://github.com/apache/superset.git
feat: support for KQL in `SQLScript` (#27522)
This commit is contained in:
parent
5083ca0e81
commit
cd7972d05b
|
@ -19,12 +19,13 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import re
|
||||
import urllib.parse
|
||||
from collections.abc import Iterable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, Generic, TypeVar
|
||||
from unittest.mock import Mock
|
||||
|
||||
import sqlglot
|
||||
|
@ -334,89 +335,175 @@ def is_cte(source: exp.Table, scope: Scope) -> bool:
|
|||
return source.name in ctes_in_scope
|
||||
|
||||
|
||||
class SQLScript:
|
||||
# To avoid unnecessary parsing/formatting of queries, the statement has the concept of
|
||||
# an "internal representation", which is the AST of the SQL statement. For most of the
|
||||
# engines supported by Superset this is `sqlglot.exp.Expression`, but there is a special
|
||||
# case: KustoKQL uses a different syntax and there are no Python parsers for it, so we
|
||||
# store the AST as a string (the original query), and manipulate it with regular
|
||||
# expressions.
|
||||
InternalRepresentation = TypeVar("InternalRepresentation")
|
||||
|
||||
# The base type. This helps type checking the `split_query` method correctly, since each
|
||||
# derived class has a more specific return type (the class itself). This will no longer
|
||||
# be needed once Python 3.11 is the lowest version supported. See PEP 673 for more
|
||||
# information: https://peps.python.org/pep-0673/
|
||||
TBaseSQLStatement = TypeVar("TBaseSQLStatement") # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class BaseSQLStatement(Generic[InternalRepresentation]):
|
||||
"""
|
||||
A SQL script, with 0+ statements.
|
||||
Base class for SQL statements.
|
||||
|
||||
The class can be instantiated with a string representation of the query or, for
|
||||
efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`,
|
||||
which will split a query in multiple already parsed statements.
|
||||
|
||||
The `engine` parameter comes from the `engine` attribute in a Superset DB engine
|
||||
spec.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query: str,
|
||||
engine: str | None = None,
|
||||
statement: str | InternalRepresentation,
|
||||
engine: str,
|
||||
):
|
||||
dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
|
||||
self._parsed: InternalRepresentation = (
|
||||
self._parse_statement(statement, engine)
|
||||
if isinstance(statement, str)
|
||||
else statement
|
||||
)
|
||||
self.engine = engine
|
||||
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
|
||||
|
||||
self.statements = [
|
||||
SQLStatement(statement, engine=engine)
|
||||
for statement in parse(query, dialect=dialect)
|
||||
if statement
|
||||
]
|
||||
@classmethod
|
||||
def split_query(
|
||||
cls: type[TBaseSQLStatement],
|
||||
query: str,
|
||||
engine: str,
|
||||
) -> list[TBaseSQLStatement]:
|
||||
"""
|
||||
Split a query into multiple instantiated statements.
|
||||
|
||||
This is a helper function to split a full SQL query into multiple
|
||||
`BaseSQLStatement` instances. It's used by `SQLScript` when instantiating the
|
||||
statements within a query.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def _parse_statement(
|
||||
cls,
|
||||
statement: str,
|
||||
engine: str,
|
||||
) -> InternalRepresentation:
|
||||
"""
|
||||
Parse a string containing a single SQL statement, and return the parsed AST.
|
||||
|
||||
Derived classes should not assume that `statement` contains a single statement,
|
||||
and MUST explicitly validate that. Since this validation is parser dependent the
|
||||
responsibility is left to the children classes.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def _extract_tables_from_statement(
|
||||
cls,
|
||||
parsed: InternalRepresentation,
|
||||
engine: str,
|
||||
) -> set[Table]:
|
||||
"""
|
||||
Extract all table references in a given statement.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def format(self, comments: bool = True) -> str:
|
||||
"""
|
||||
Pretty-format the SQL query.
|
||||
Format the statement, optionally omitting comments.
|
||||
"""
|
||||
return ";\n".join(statement.format(comments) for statement in self.statements)
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_settings(self) -> dict[str, str]:
|
||||
def get_settings(self) -> dict[str, str | bool]:
|
||||
"""
|
||||
Return the settings for the SQL query.
|
||||
Return any settings set by the statement.
|
||||
|
||||
>>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'")
|
||||
>>> statement.get_settings()
|
||||
{"foo": "'baz'"}
|
||||
For example, for this statement:
|
||||
|
||||
sql> SET foo = 'bar';
|
||||
|
||||
The method should return `{"foo": "'bar'"}`. Note the single quotes.
|
||||
"""
|
||||
settings: dict[str, str] = {}
|
||||
for statement in self.statements:
|
||||
settings.update(statement.get_settings())
|
||||
raise NotImplementedError()
|
||||
|
||||
return settings
|
||||
def __str__(self) -> str:
|
||||
return self.format()
|
||||
|
||||
|
||||
class SQLStatement:
|
||||
class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
"""
|
||||
A SQL statement.
|
||||
|
||||
This class provides helper methods to manipulate and introspect SQL.
|
||||
This class is used for all engines with dialects that can be parsed using sqlglot.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
statement: str | exp.Expression,
|
||||
engine: str | None = None,
|
||||
engine: str,
|
||||
):
|
||||
dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
|
||||
self._dialect = SQLGLOT_DIALECTS.get(engine)
|
||||
super().__init__(statement, engine)
|
||||
|
||||
@classmethod
|
||||
def split_query(
|
||||
cls,
|
||||
query: str,
|
||||
engine: str,
|
||||
) -> list[SQLStatement]:
|
||||
dialect = SQLGLOT_DIALECTS.get(engine)
|
||||
|
||||
if isinstance(statement, str):
|
||||
try:
|
||||
self._parsed = self._parse_statement(statement, dialect)
|
||||
except ParseError as ex:
|
||||
raise SupersetParseError(statement, engine) from ex
|
||||
else:
|
||||
self._parsed = statement
|
||||
statements = sqlglot.parse(query, dialect=dialect)
|
||||
except sqlglot.errors.ParseError as ex:
|
||||
raise SupersetParseError("Unable to split query") from ex
|
||||
|
||||
self._dialect = dialect
|
||||
self.tables = extract_tables_from_statement(self._parsed, dialect)
|
||||
return [cls(statement, engine) for statement in statements if statement]
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def _parse_statement(
|
||||
sql_statement: str,
|
||||
dialect: Dialects | None,
|
||||
cls,
|
||||
statement: str,
|
||||
engine: str,
|
||||
) -> exp.Expression:
|
||||
"""
|
||||
Parse a single SQL statement.
|
||||
"""
|
||||
statements = [
|
||||
statement
|
||||
for statement in sqlglot.parse(sql_statement, dialect=dialect)
|
||||
if statement
|
||||
]
|
||||
dialect = SQLGLOT_DIALECTS.get(engine)
|
||||
|
||||
# We could parse with `sqlglot.parse_one` to get a single statement, but we need
|
||||
# to verify that the string contains exactly one statement.
|
||||
try:
|
||||
statements = sqlglot.parse(statement, dialect=dialect)
|
||||
except sqlglot.errors.ParseError as ex:
|
||||
raise SupersetParseError("Unable to split query") from ex
|
||||
|
||||
statements = [statement for statement in statements if statement]
|
||||
if len(statements) != 1:
|
||||
raise ValueError("SQLStatement should have exactly one statement")
|
||||
raise SupersetParseError("SQLStatement should have exactly one statement")
|
||||
|
||||
return statements[0]
|
||||
|
||||
@classmethod
|
||||
def _extract_tables_from_statement(
|
||||
cls,
|
||||
parsed: exp.Expression,
|
||||
engine: str,
|
||||
) -> set[Table]:
|
||||
"""
|
||||
Find all referenced tables.
|
||||
"""
|
||||
dialect = SQLGLOT_DIALECTS.get(engine)
|
||||
return extract_tables_from_statement(parsed, dialect)
|
||||
|
||||
def format(self, comments: bool = True) -> str:
|
||||
"""
|
||||
Pretty-format the SQL statement.
|
||||
|
@ -424,7 +511,7 @@ class SQLStatement:
|
|||
write = Dialect.get_or_raise(self._dialect)
|
||||
return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
|
||||
|
||||
def get_settings(self) -> dict[str, str]:
|
||||
def get_settings(self) -> dict[str, str | bool]:
|
||||
"""
|
||||
Return the settings for the SQL statement.
|
||||
|
||||
|
@ -440,6 +527,192 @@ class SQLStatement:
|
|||
}
|
||||
|
||||
|
||||
class KQLSplitState(enum.Enum):
    """
    State machine for splitting a KQL query.

    The state machine keeps track of whether we're inside a string or not, so we
    don't split the query in a semi-colon that's part of a string.
    """

    OUTSIDE_STRING = enum.auto()
    INSIDE_SINGLE_QUOTED_STRING = enum.auto()
    INSIDE_DOUBLE_QUOTED_STRING = enum.auto()
    INSIDE_MULTILINE_STRING = enum.auto()


def split_kql(kql: str) -> list[str]:
    """
    Custom function for splitting KQL statements.

    The query is split at every semi-colon that is not part of a string literal.
    The following literal forms are recognized:

    - single- and double-quoted strings, where a backslash escapes the next
      character (so ``'a\\\\'`` ends at the second quote);
    - verbatim strings (opening quote preceded by ``@``), where the backslash is
      a regular character and a quote is escaped by doubling it;
    - multi-line strings delimited by triple backticks.

    Statements are returned verbatim (no stripping), without the separating
    semi-colon. A trailing semi-colon does not produce an extra empty statement,
    and any text after the last separator -- including an unterminated string
    literal -- is returned as the final statement instead of being dropped.

    :param kql: the KQL script to split
    :returns: the list of statements, in order of appearance
    """
    fence = "```"
    statements: list[str] = []
    state = KQLSplitState.OUTSIDE_STRING
    # Only meaningful while inside a single- or double-quoted string.
    verbatim = False
    statement_start = 0

    i = 0
    while i < len(kql):
        character = kql[i]
        step = 1

        if state == KQLSplitState.OUTSIDE_STRING:
            if character == ";":
                statements.append(kql[statement_start:i])
                statement_start = i + 1
            elif kql.startswith(fence, i):
                state = KQLSplitState.INSIDE_MULTILINE_STRING
                # Consume the whole fence so that its backticks can never be
                # mistaken for the closing fence.
                step = len(fence)
            elif character == "'":
                state = KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
                verbatim = i > 0 and kql[i - 1] == "@"
            elif character == '"':
                state = KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING
                verbatim = i > 0 and kql[i - 1] == "@"

        elif state == KQLSplitState.INSIDE_MULTILINE_STRING:
            if kql.startswith(fence, i):
                state = KQLSplitState.OUTSIDE_STRING
                step = len(fence)

        else:
            quote = (
                "'" if state == KQLSplitState.INSIDE_SINGLE_QUOTED_STRING else '"'
            )
            if character == "\\" and not verbatim:
                # Skip the escaped character, whatever it is. This handles both
                # escaped quotes and escaped backslashes.
                step = 2
            elif character == quote:
                if verbatim and kql.startswith(quote, i + 1):
                    # A doubled quote inside a verbatim string is an escaped quote.
                    step = 2
                else:
                    state = KQLSplitState.OUTSIDE_STRING

        i += step

    # Keep whatever follows the last separator. An empty input still yields a
    # single empty statement, as it always did.
    if statement_start < len(kql) or not kql:
        statements.append(kql[statement_start:])

    return statements
|
||||
|
||||
|
||||
class KustoKQLStatement(BaseSQLStatement[str]):
    """
    Special class for Kusto KQL.

    Kusto KQL is a SQL-like language, but it's not supported by sqlglot. Queries look
    like this:

        StormEvents
        | summarize PropertyDamage = sum(DamageProperty) by State
        | join kind=innerunique PopulationData on State
        | project State, PropertyDamagePerCapita = PropertyDamage / Population
        | sort by PropertyDamagePerCapita

    See https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/ for more
    details about it.

    The internal representation of a statement is the stripped statement string.
    """

    @classmethod
    def split_query(
        cls,
        query: str,
        engine: str,
    ) -> list[KustoKQLStatement]:
        """
        Split a query at semi-colons.

        Since we don't have a parser, we use a simple state machine based function. See
        https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
        for more information.

        Fragments that contain only whitespace (for example, what follows a trailing
        semi-colon and newline) are skipped, so no empty statements are created.
        """
        return [
            cls(statement, engine)
            for statement in split_kql(query)
            if statement.strip()
        ]

    @classmethod
    def _parse_statement(
        cls,
        statement: str,
        engine: str,
    ) -> str:
        """
        Validate a single KQL statement and return it stripped of whitespace.

        Raises `SupersetParseError` if `engine` is not "kustokql", or if `statement`
        does not split into exactly one statement.
        """
        if engine != "kustokql":
            raise SupersetParseError(f"Invalid engine: {engine}")

        statements = split_kql(statement)
        if len(statements) != 1:
            raise SupersetParseError("SQLStatement should have exactly one statement")

        return statements[0].strip()

    @classmethod
    def _extract_tables_from_statement(cls, parsed: str, engine: str) -> set[Table]:
        """
        Return the tables referenced in the statement.

        Table extraction is not implemented for KQL: the method logs a warning and
        always returns an empty set, regardless of the statement. For reference, a
        statement referencing two tables looks like this:

            StormEvents
            | where InjuriesDirect + InjuriesIndirect > 50
            | join (PopulationData) on State
            | project State, Population, TotalInjuries = InjuriesDirect + InjuriesIndirect

        """
        logger.warning(
            "Kusto KQL doesn't support table extraction. This means that data access "
            "roles will not be enforced by Superset in the database."
        )
        return set()

    def format(self, comments: bool = True) -> str:
        """
        Return the statement as stored, without any reformatting.

        There is no KQL formatter, so the `comments` argument is ignored.
        """
        return self._parsed

    def get_settings(self) -> dict[str, str | bool]:
        """
        Return the settings for the SQL statement.

            >>> statement = KustoKQLStatement("set querytrace;", "kustokql")
            >>> statement.get_settings()
            {"querytrace": True}

        A setting without an explicit value is reported as `True`; otherwise the
        value is returned as a string.
        """
        set_regex = r"^set\s+(?P<name>\w+)(?:\s*=\s*(?P<value>\w+))?$"
        if match := re.match(set_regex, self._parsed, re.IGNORECASE):
            return {match.group("name"): match.group("value") or True}

        return {}
|
||||
|
||||
|
||||
class SQLScript:
|
||||
"""
|
||||
A SQL script, with 0+ statements.
|
||||
"""
|
||||
|
||||
# Special engines that can't be parsed using sqlglot. Supporting non-SQL engines
|
||||
# adds a lot of complexity to Superset, so we should avoid adding new engines to
|
||||
# this data structure.
|
||||
special_engines = {
|
||||
"kustokql": KustoKQLStatement,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query: str,
|
||||
engine: str,
|
||||
):
|
||||
statement_class = self.special_engines.get(engine, SQLStatement)
|
||||
self.statements = statement_class.split_query(query, engine)
|
||||
|
||||
def format(self, comments: bool = True) -> str:
|
||||
"""
|
||||
Pretty-format the SQL query.
|
||||
"""
|
||||
return ";\n".join(statement.format(comments) for statement in self.statements)
|
||||
|
||||
def get_settings(self) -> dict[str, str | bool]:
|
||||
"""
|
||||
Return the settings for the SQL query.
|
||||
|
||||
>>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'", "sqlite")
|
||||
>>> statement.get_settings()
|
||||
{"foo": "'baz'"}
|
||||
|
||||
"""
|
||||
settings: dict[str, str | bool] = {}
|
||||
for statement in self.statements:
|
||||
settings.update(statement.get_settings())
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
class ParsedQuery:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -37,8 +37,10 @@ from superset.sql_parse import (
|
|||
has_table_query,
|
||||
insert_rls_as_subquery,
|
||||
insert_rls_in_predicate,
|
||||
KustoKQLStatement,
|
||||
ParsedQuery,
|
||||
sanitize_clause,
|
||||
split_kql,
|
||||
SQLScript,
|
||||
SQLStatement,
|
||||
strip_comments_from_sql,
|
||||
|
@ -1883,21 +1885,31 @@ def test_sqlquery() -> None:
|
|||
"""
|
||||
Test the `SQLScript` class.
|
||||
"""
|
||||
script = SQLScript("SELECT 1; SELECT 2;")
|
||||
script = SQLScript("SELECT 1; SELECT 2;", "sqlite")
|
||||
|
||||
assert len(script.statements) == 2
|
||||
assert script.format() == "SELECT\n 1;\nSELECT\n 2"
|
||||
assert script.statements[0].format() == "SELECT\n 1"
|
||||
|
||||
script = SQLScript("SET a=1; SET a=2; SELECT 3;")
|
||||
script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
|
||||
assert script.get_settings() == {"a": "2"}
|
||||
|
||||
query = SQLScript(
|
||||
"""set querytrace;
|
||||
Events | take 100""",
|
||||
"kustokql",
|
||||
)
|
||||
assert query.get_settings() == {"querytrace": True}
|
||||
|
||||
|
||||
def test_sqlstatement() -> None:
|
||||
"""
|
||||
Test the `SQLStatement` class.
|
||||
"""
|
||||
statement = SQLStatement("SELECT * FROM table1 UNION ALL SELECT * FROM table2")
|
||||
statement = SQLStatement(
|
||||
"SELECT * FROM table1 UNION ALL SELECT * FROM table2",
|
||||
"sqlite",
|
||||
)
|
||||
|
||||
assert statement.tables == {
|
||||
Table(table="table1", schema=None, catalog=None),
|
||||
|
@ -1908,7 +1920,7 @@ def test_sqlstatement() -> None:
|
|||
== "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2"
|
||||
)
|
||||
|
||||
statement = SQLStatement("SET a=1")
|
||||
statement = SQLStatement("SET a=1", "sqlite")
|
||||
assert statement.get_settings() == {"a": "1"}
|
||||
|
||||
|
||||
|
@ -1950,3 +1962,137 @@ def test_extract_tables_from_jinja_sql(
|
|||
extract_tables_from_jinja_sql(sql.format(engine=engine, macro=macro), engine)
|
||||
== expected
|
||||
)
|
||||
|
||||
|
||||
def test_kustokqlstatement_split_query() -> None:
|
||||
"""
|
||||
Test the `KustoKQLStatement` split method.
|
||||
"""
|
||||
statements = KustoKQLStatement.split_query(
|
||||
"""
|
||||
let totalPagesPerDay = PageViews
|
||||
| summarize by Page, Day = startofday(Timestamp)
|
||||
| summarize count() by Day;
|
||||
let materializedScope = PageViews
|
||||
| summarize by Page, Day = startofday(Timestamp);
|
||||
let cachedResult = materialize(materializedScope);
|
||||
cachedResult
|
||||
| project Page, Day1 = Day
|
||||
| join kind = inner
|
||||
(
|
||||
cachedResult
|
||||
| project Page, Day2 = Day
|
||||
)
|
||||
on Page
|
||||
| where Day2 > Day1
|
||||
| summarize count() by Day1, Day2
|
||||
| join kind = inner
|
||||
totalPagesPerDay
|
||||
on $left.Day1 == $right.Day
|
||||
| project Day1, Day2, Percentage = count_*100.0/count_1
|
||||
""",
|
||||
"kustokql",
|
||||
)
|
||||
assert len(statements) == 4
|
||||
|
||||
|
||||
def test_kustokqlstatement_with_program() -> None:
|
||||
"""
|
||||
Test the `KustoKQLStatement` split method when the KQL has a program.
|
||||
"""
|
||||
statements = KustoKQLStatement.split_query(
|
||||
"""
|
||||
print program = ```
|
||||
public class Program {
|
||||
public static void Main() {
|
||||
System.Console.WriteLine("Hello!");
|
||||
}
|
||||
}```
|
||||
""",
|
||||
"kustokql",
|
||||
)
|
||||
assert len(statements) == 1
|
||||
|
||||
|
||||
def test_kustokqlstatement_with_set() -> None:
|
||||
"""
|
||||
Test the `KustoKQLStatement` split method when the KQL has a set command.
|
||||
"""
|
||||
statements = KustoKQLStatement.split_query(
|
||||
"""
|
||||
set querytrace;
|
||||
Events | take 100
|
||||
""",
|
||||
"kustokql",
|
||||
)
|
||||
assert len(statements) == 2
|
||||
assert statements[0].format() == "set querytrace"
|
||||
assert statements[1].format() == "Events | take 100"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kql,statements",
|
||||
[
|
||||
('print banner=strcat("Hello", ", ", "World!")', 1),
|
||||
(r"print 'O\'Malley\'s'", 1),
|
||||
(r"print 'O\'Mal;ley\'s'", 1),
|
||||
("print ```foo;\nbar;\nbaz;```\n", 1),
|
||||
],
|
||||
)
|
||||
def test_kustokql_statement_split_special(kql: str, statements: int) -> None:
|
||||
assert len(KustoKQLStatement.split_query(kql, "kustokql")) == statements
|
||||
|
||||
|
||||
def test_split_kql() -> None:
|
||||
"""
|
||||
Test the `split_kql` function.
|
||||
"""
|
||||
kql = """
|
||||
let totalPagesPerDay = PageViews
|
||||
| summarize by Page, Day = startofday(Timestamp)
|
||||
| summarize count() by Day;
|
||||
let materializedScope = PageViews
|
||||
| summarize by Page, Day = startofday(Timestamp);
|
||||
let cachedResult = materialize(materializedScope);
|
||||
cachedResult
|
||||
| project Page, Day1 = Day
|
||||
| join kind = inner
|
||||
(
|
||||
cachedResult
|
||||
| project Page, Day2 = Day
|
||||
)
|
||||
on Page
|
||||
| where Day2 > Day1
|
||||
| summarize count() by Day1, Day2
|
||||
| join kind = inner
|
||||
totalPagesPerDay
|
||||
on $left.Day1 == $right.Day
|
||||
| project Day1, Day2, Percentage = count_*100.0/count_1
|
||||
"""
|
||||
assert split_kql(kql) == [
|
||||
"""
|
||||
let totalPagesPerDay = PageViews
|
||||
| summarize by Page, Day = startofday(Timestamp)
|
||||
| summarize count() by Day""",
|
||||
"""
|
||||
let materializedScope = PageViews
|
||||
| summarize by Page, Day = startofday(Timestamp)""",
|
||||
"""
|
||||
let cachedResult = materialize(materializedScope)""",
|
||||
"""
|
||||
cachedResult
|
||||
| project Page, Day1 = Day
|
||||
| join kind = inner
|
||||
(
|
||||
cachedResult
|
||||
| project Page, Day2 = Day
|
||||
)
|
||||
on Page
|
||||
| where Day2 > Day1
|
||||
| summarize count() by Day1, Day2
|
||||
| join kind = inner
|
||||
totalPagesPerDay
|
||||
on $left.Day1 == $right.Day
|
||||
| project Day1, Day2, Percentage = count_*100.0/count_1
|
||||
""",
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue