chore: improve SQL parsing (#26767)

This commit is contained in:
Beto Dealmeida 2024-03-13 18:27:01 -04:00 committed by GitHub
parent a75bb7685d
commit 26d8077e97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 393 additions and 195 deletions

View File

@ -25,7 +25,7 @@ describe('AdhocMetrics', () => {
});
it('Clear metric and set simple adhoc metric', () => {
const metric = 'sum(num_girls)';
const metric = 'SUM(num_girls)';
const metricName = 'Sum Girls';
cy.get('[data-test=metrics]')
.find('[data-test="remove-control-button"]')

View File

@ -100,7 +100,7 @@ describe('Visualization > Table', () => {
});
cy.verifySliceSuccess({
waitAlias: '@chartData',
querySubstring: /group by.*name/i,
querySubstring: /group by\n.*name/i,
chartSelector: 'table',
});
});
@ -246,7 +246,7 @@ describe('Visualization > Table', () => {
cy.visitChartByParams(formData);
cy.verifySliceSuccess({
waitAlias: '@chartData',
querySubstring: /group by.*state/i,
querySubstring: /group by\n.*state/i,
chartSelector: 'table',
});
cy.get('td').contains(/\d*%/);

View File

@ -921,6 +921,7 @@ export function formatQuery(queryEditor) {
const { sql } = getUpToDateQuery(getState(), queryEditor);
return SupersetClient.post({
endpoint: `/api/v1/sqllab/format_sql/`,
// TODO (betodealmeida): pass engine as a parameter for better formatting
body: JSON.stringify({ sql }),
headers: { 'Content-Type': 'application/json' },
}).then(({ json }) => {

View File

@ -33,7 +33,6 @@ import dateutil.parser
import numpy as np
import pandas as pd
import sqlalchemy as sa
import sqlparse
from flask import escape, Markup
from flask_appbuilder import Model
from flask_appbuilder.security.sqla.models import User
@ -104,7 +103,6 @@ from superset.models.helpers import (
ExploreMixin,
ImportExportMixin,
QueryResult,
QueryStringExtended,
validate_adhoc_subquery,
)
from superset.models.slice import Slice
@ -1099,7 +1097,9 @@ def _process_sql_expression(
class SqlaTable(
Model, BaseDatasource, ExploreMixin
Model,
BaseDatasource,
ExploreMixin,
): # pylint: disable=too-many-public-methods
"""An ORM object for SqlAlchemy table references"""
@ -1413,26 +1413,6 @@ class SqlaTable(
def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
return get_template_processor(table=self, database=self.database, **kwargs)
def get_query_str_extended(
self,
query_obj: QueryObjectDict,
mutate: bool = True,
) -> QueryStringExtended:
sqlaq = self.get_sqla_query(**query_obj)
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
sql = self._apply_cte(sql, sqlaq.cte)
sql = sqlparse.format(sql, reindent=True)
if mutate:
sql = self.mutate_query_from_config(sql)
return QueryStringExtended(
applied_template_filters=sqlaq.applied_template_filters,
applied_filter_columns=sqlaq.applied_filter_columns,
rejected_filter_columns=sqlaq.rejected_filter_columns,
labels_expected=sqlaq.labels_expected,
prequeries=sqlaq.prequeries,
sql=sql,
)
def get_query_str(self, query_obj: QueryObjectDict) -> str:
query_str_ext = self.get_query_str_extended(query_obj)
all_queries = query_str_ext.prequeries + [query_str_ext.sql]
@ -1474,33 +1454,6 @@ class SqlaTable(
return from_clause, cte
def get_rendered_sql(
self, template_processor: BaseTemplateProcessor | None = None
) -> str:
"""
Render sql with template engine (Jinja).
"""
sql = self.sql
if template_processor:
try:
sql = template_processor.process_template(sql)
except TemplateError as ex:
raise QueryObjectValidationError(
_(
"Error while rendering virtual dataset query: %(msg)s",
msg=ex.message,
)
) from ex
sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
if not sql:
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
if len(sqlparse.split(sql)) > 1:
raise QueryObjectValidationError(
_("Virtual dataset query cannot consist of multiple statements")
)
return sql
def adhoc_metric_to_sqla(
self,
metric: AdhocMetric,

View File

@ -59,7 +59,7 @@ from superset import security_manager, sql_parse
from superset.constants import TimeGrain as TimeGrainConstants
from superset.databases.utils import make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import ParsedQuery, Table
from superset.sql_parse import ParsedQuery, SQLScript, Table
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
from superset.utils import core as utils
from superset.utils.core import ColumnSpec, GenericDataType
@ -1448,7 +1448,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
qry = partition_query
sql = database.compile_sqla_query(qry)
if indent:
sql = sqlparse.format(sql, reindent=True)
sql = SQLScript(sql, engine=cls.engine).format()
return sql
@classmethod

View File

@ -24,7 +24,6 @@ from datetime import datetime
from re import Pattern
from typing import Any, TYPE_CHECKING
import sqlparse
from flask_babel import gettext as __
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON
from sqlalchemy.dialects.postgresql.base import PGInspector
@ -37,6 +36,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetException, SupersetSecurityException
from superset.models.sql_lab import Query
from superset.sql_parse import SQLScript
from superset.utils import core as utils
from superset.utils.core import GenericDataType
@ -281,8 +281,9 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
This method simply uses the parent method after checking that there are no
malicious path setting in the query.
"""
sql = sqlparse.format(query.sql, strip_comments=True)
if re.search(r"set\s+search_path\s*=", sql, re.IGNORECASE):
script = SQLScript(query.sql, engine=cls.engine)
settings = script.get_settings()
if "search_path" in settings:
raise SupersetSecurityException(
SupersetError(
error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR,

View File

@ -83,6 +83,7 @@ class SupersetErrorType(StrEnum):
RESULTS_BACKEND_ERROR = "RESULTS_BACKEND_ERROR"
ASYNC_WORKERS_ERROR = "ASYNC_WORKERS_ERROR"
ADHOC_SUBQUERY_NOT_ALLOWED_ERROR = "ADHOC_SUBQUERY_NOT_ALLOWED_ERROR"
INVALID_SQL_ERROR = "INVALID_SQL_ERROR"
# Generic errors
GENERIC_COMMAND_ERROR = "GENERIC_COMMAND_ERROR"
@ -176,6 +177,7 @@ ERROR_TYPES_TO_ISSUE_CODES_MAPPING = {
SupersetErrorType.INVALID_PAYLOAD_SCHEMA_ERROR: [1020],
SupersetErrorType.INVALID_CTAS_QUERY_ERROR: [1023],
SupersetErrorType.INVALID_CVAS_QUERY_ERROR: [1024, 1025],
SupersetErrorType.INVALID_SQL_ERROR: [1003],
SupersetErrorType.SQLLAB_TIMEOUT_ERROR: [1026, 1027],
SupersetErrorType.OBJECT_DOES_NOT_EXIST_ERROR: [1029],
SupersetErrorType.SYNTAX_ERROR: [1030],

View File

@ -295,3 +295,20 @@ class SupersetMarshmallowValidationError(SupersetErrorException):
extra={"messages": exc.messages, "payload": payload},
)
super().__init__(error)
class SupersetParseError(SupersetErrorException):
    """
    Exception to be raised when we fail to parse SQL.

    Carries the offending SQL and the (optional) engine name in the error
    ``extra`` payload so the client can surface useful context.
    """

    # 422 Unprocessable Entity: the request was well-formed but the SQL inside
    # it could not be parsed.
    status = 422

    def __init__(self, sql: str, engine: Optional[str] = None):
        super().__init__(
            SupersetError(
                message=_("The SQL is invalid and cannot be parsed."),
                error_type=SupersetErrorType.INVALID_SQL_ERROR,
                level=ErrorLevel.ERROR,
                extra={"sql": sql, "engine": engine},
            )
        )

View File

@ -64,6 +64,7 @@ from superset.exceptions import (
ColumnNotFoundException,
QueryClauseValidationException,
QueryObjectValidationError,
SupersetParseError,
SupersetSecurityException,
)
from superset.extensions import feature_flag_manager
@ -73,6 +74,8 @@ from superset.sql_parse import (
insert_rls_in_predicate,
ParsedQuery,
sanitize_clause,
SQLScript,
SQLStatement,
)
from superset.superset_typing import (
AdhocMetric,
@ -901,12 +904,18 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
return sql
def get_query_str_extended(
self, query_obj: QueryObjectDict, mutate: bool = True
self,
query_obj: QueryObjectDict,
mutate: bool = True,
) -> QueryStringExtended:
sqlaq = self.get_sqla_query(**query_obj)
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
sql = self._apply_cte(sql, sqlaq.cte)
sql = sqlparse.format(sql, reindent=True)
try:
sql = SQLStatement(sql, engine=self.db_engine_spec.engine).format()
except SupersetParseError:
logger.warning("Unable to parse SQL to format it, passing it as-is")
if mutate:
sql = self.mutate_query_from_config(sql)
return QueryStringExtended(
@ -1054,7 +1063,8 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
)
def get_rendered_sql(
self, template_processor: Optional[BaseTemplateProcessor] = None
self,
template_processor: Optional[BaseTemplateProcessor] = None,
) -> str:
"""
Render sql with template engine (Jinja).
@ -1071,13 +1081,16 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
msg=ex.message,
)
) from ex
sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
if not sql:
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
if len(sqlparse.split(sql)) > 1:
script = SQLScript(sql.strip("\t\r\n; "), engine=self.db_engine_spec.engine)
if len(script.statements) > 1:
raise QueryObjectValidationError(
_("Virtual dataset query cannot consist of multiple statements")
)
sql = script.statements[0].format(comments=False)
if not sql:
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
return sql
def text(self, clause: str) -> TextClause:

View File

@ -22,13 +22,14 @@ import re
import urllib.parse
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any, cast, Optional
from typing import Any, cast, Optional, Union
import sqlglot
import sqlparse
from sqlalchemy import and_
from sqlglot import exp, parse, parse_one
from sqlglot.dialects import Dialects
from sqlglot.errors import SqlglotError
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.errors import ParseError, SqlglotError
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
from sqlparse import keywords
from sqlparse.lexer import Lexer
@ -55,7 +56,7 @@ from sqlparse.tokens import (
)
from sqlparse.utils import imt
from superset.exceptions import QueryClauseValidationException
from superset.exceptions import QueryClauseValidationException, SupersetParseError
from superset.utils.backports import StrEnum
try:
@ -252,6 +253,185 @@ class Table:
return str(self) == str(__o)
def extract_tables_from_statement(
    statement: exp.Expression,
    dialect: Optional[Dialects],
) -> set[Table]:
    """
    Extract all table references in a single statement.

    Please note that this is not trivial; consider the following queries:

        DESCRIBE some_table;
        SHOW PARTITIONS FROM some_table;
        WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;

    See the unit tests for other tricky cases.

    :param statement: a single parsed sqlglot expression
    :param dialect: sqlglot dialect used when re-parsing command arguments
    :return: the set of tables referenced by the statement (CTE names excluded)
    """
    sources: Iterable[exp.Table]

    if isinstance(statement, exp.Describe):
        # A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly
        # query for all tables.
        sources = statement.find_all(exp.Table)
    elif isinstance(statement, exp.Command):
        # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
        # `SELECT` statement in order to extract tables.
        literal = statement.find(exp.Literal)
        if not literal:
            return set()

        try:
            pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect)
        except ParseError:
            # best-effort only: if the command argument is not parseable as a
            # SELECT we report no tables rather than failing
            return set()
        sources = pseudo_query.find_all(exp.Table)
    else:
        # for regular statements, walk every scope and keep only real tables,
        # filtering out CTE references that merely look like tables
        sources = [
            source
            for scope in traverse_scope(statement)
            for source in scope.sources.values()
            if isinstance(source, exp.Table) and not is_cte(source, scope)
        ]

    # sqlglot uses "" for a missing schema/catalog; normalize to None
    return {
        Table(
            source.name,
            source.db if source.db != "" else None,
            source.catalog if source.catalog != "" else None,
        )
        for source in sources
    }
def is_cte(source: exp.Table, scope: Scope) -> bool:
    """
    Is the source a CTE?

    CTEs in the parent scope look like tables (and are represented by
    exp.Table objects), but should not be considered as such;
    otherwise a user with access to table `foo` could access any table
    with a query like this:

        WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo

    """
    parent_sources = scope.parent.sources if scope.parent else {}

    # the source is a CTE when the parent scope declares a CTE scope
    # under the same name
    for name, candidate in parent_sources.items():
        if (
            name == source.name
            and isinstance(candidate, Scope)
            and candidate.scope_type == ScopeType.CTE
        ):
            return True

    return False
class SQLScript:
    """
    A SQL script, with 0+ statements.

    A script is split into individual ``SQLStatement`` objects, which can be
    formatted and introspected independently.
    """

    def __init__(
        self,
        query: str,
        engine: Optional[str] = None,
    ):
        # map the Superset engine name to a sqlglot dialect, when known
        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None

        self.statements = []
        for statement in parse(query, dialect=dialect):
            # sqlglot yields `None` for empty statements (eg, trailing `;`)
            if statement:
                self.statements.append(SQLStatement(statement, engine=engine))

    def format(self, comments: bool = True) -> str:
        """
        Pretty-format the SQL query.

        :param comments: whether to keep comments in the formatted output
        """
        formatted = [statement.format(comments) for statement in self.statements]
        return ";\n".join(formatted)

    def get_settings(self) -> dict[str, str]:
        """
        Return the settings for the SQL query.

            >>> script = SQLScript("SET foo = 'bar'; SET foo = 'baz'")
            >>> script.get_settings()
            {'foo': "'baz'"}

        Later statements override earlier ones for the same key.
        """
        combined: dict[str, str] = {}
        for statement in self.statements:
            combined.update(statement.get_settings())

        return combined
class SQLStatement:
    """
    A SQL statement.

    This class provides helper methods to manipulate and introspect SQL.
    """

    def __init__(
        self,
        statement: Union[str, exp.Expression],
        engine: Optional[str] = None,
    ):
        # map the Superset engine name to a sqlglot dialect, when known
        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None

        if not isinstance(statement, str):
            # already parsed; use the expression as-is
            self._parsed = statement
        else:
            try:
                self._parsed = self._parse_statement(statement, dialect)
            except ParseError as ex:
                raise SupersetParseError(statement, engine) from ex

        self._dialect = dialect
        self.tables = extract_tables_from_statement(self._parsed, dialect)

    @staticmethod
    def _parse_statement(
        sql_statement: str,
        dialect: Optional[Dialects],
    ) -> exp.Expression:
        """
        Parse a single SQL statement.

        :raises ValueError: if the string contains zero or multiple statements
        """
        parsed = sqlglot.parse(sql_statement, dialect=dialect)
        statements = [expression for expression in parsed if expression]
        if len(statements) != 1:
            raise ValueError("SQLStatement should have exactly one statement")

        return statements[0]

    def format(self, comments: bool = True) -> str:
        """
        Pretty-format the SQL statement.

        :param comments: whether to keep comments in the formatted output
        """
        dialect_class = Dialect.get_or_raise(self._dialect)
        return dialect_class.generate(
            self._parsed,
            copy=False,
            comments=comments,
            pretty=True,
        )

    def get_settings(self) -> dict[str, str]:
        """
        Return the settings for the SQL statement.

            >>> statement = SQLStatement("SET foo = 'bar'")
            >>> statement.get_settings()
            {'foo': "'bar'"}

        """
        settings: dict[str, str] = {}
        for set_item in self._parsed.find_all(exp.SetItem):
            for eq in set_item.find_all(exp.EQ):
                settings[eq.this.sql()] = eq.expression.sql()

        return settings
class ParsedQuery:
def __init__(
self,
@ -294,7 +474,7 @@ class ParsedQuery:
return {
table
for statement in statements
for table in self._extract_tables_from_statement(statement)
for table in extract_tables_from_statement(statement, self._dialect)
if statement
}

View File

@ -19,7 +19,6 @@ from typing import Any, cast, Optional
from urllib import parse
import simplejson as json
import sqlparse
from flask import request, Response
from flask_appbuilder import permission_name
from flask_appbuilder.api import expose, protect, rison, safe
@ -38,6 +37,7 @@ from superset.extensions import event_logger
from superset.jinja_context import get_template_processor
from superset.models.sql_lab import Query
from superset.sql_lab import get_sql_results
from superset.sql_parse import SQLScript
from superset.sqllab.command_status import SqlJsonExecutionStatus
from superset.sqllab.exceptions import (
QueryIsForbiddenToAccessException,
@ -230,7 +230,7 @@ class SqlLabRestApi(BaseSupersetApi):
"""
try:
model = self.format_model_schema.load(request.json)
result = sqlparse.format(model["sql"], reindent=True, keyword_case="upper")
result = SQLScript(model["sql"], model.get("engine")).format()
return self.response(200, result=result)
except ValidationError as error:
return self.response_400(message=error.messages)

View File

@ -44,6 +44,7 @@ class EstimateQueryCostSchema(Schema):
class FormatQueryPayloadSchema(Schema):
    """
    Request payload for the SQL formatting endpoint.
    """

    # raw SQL to be pretty-formatted
    sql = fields.String(required=True)
    # optional engine name used to select the dialect when formatting
    engine = fields.String(required=False, allow_none=True)
class ExecutePayloadSchema(Schema):

View File

@ -138,8 +138,8 @@ def get_select_star(table: str, limit: int, schema: Optional[str] = None):
schema = quote_f(schema)
table = quote_f(table)
if schema:
return f"SELECT *\nFROM {schema}.{table}\nLIMIT {limit}"
return f"SELECT *\nFROM {table}\nLIMIT {limit}"
return f"SELECT\n *\nFROM {schema}.{table}\nLIMIT {limit}"
return f"SELECT\n *\nFROM {table}\nLIMIT {limit}"
@pytest.mark.usefixtures("login_as_admin")
@ -333,9 +333,9 @@ def test_run_async_cta_query_with_lower_limit(test_client, ctas_method):
query = wait_for_success(result)
assert QueryStatus.SUCCESS == query.status
sqllite_select_sql = f"SELECT *\nFROM {tmp_table}\nLIMIT {query.limit}\nOFFSET 0"
sqlite_select_sql = f"SELECT\n *\nFROM {tmp_table}\nLIMIT {query.limit}\nOFFSET 0"
assert query.select_sql == (
sqllite_select_sql
sqlite_select_sql
if backend() == "sqlite"
else get_select_star(tmp_table, query.limit)
)

View File

@ -694,7 +694,7 @@ class TestPostChartDataApi(BaseTestChartDataApi):
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
result = rv.json["result"][0]["query"]
if get_example_database().backend != "presto":
assert "('boy' = 'boy')" in result
assert "(\n 'boy' = 'boy'\n )" in result
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@ -1319,13 +1319,13 @@ def test_time_filter_with_grain(test_client, login_as_admin, physical_query_cont
backend = get_example_database().backend
if backend == "sqlite":
assert (
"DATETIME(col5, 'start of day', -strftime('%w', col5) || ' days') >="
"DATETIME(col5, 'start of day', -STRFTIME('%w', col5) || ' days') >="
in query
)
elif backend == "mysql":
assert "DATE(DATE_SUB(col5, INTERVAL DAYOFWEEK(col5) - 1 DAY)) >=" in query
assert "DATE(DATE_SUB(col5, INTERVAL (DAYOFWEEK(col5) - 1) DAY)) >=" in query
elif backend == "postgresql":
assert "DATE_TRUNC('week', col5) >=" in query
assert "DATE_TRUNC('WEEK', col5) >=" in query
elif backend == "presto":
assert "date_trunc('week', CAST(col5 AS TIMESTAMP)) >=" in query

View File

@ -531,7 +531,7 @@ class TestCore(SupersetTestCase):
)
def test_comments_in_sqlatable_query(self):
clean_query = "SELECT '/* val 1 */' as c1, '-- val 2' as c2 FROM tbl"
clean_query = "SELECT\n '/* val 1 */' AS c1,\n '-- val 2' AS c2\nFROM tbl"
commented_query = "/* comment 1 */" + clean_query + "-- comment 2"
table = SqlaTable(
table_name="test_comments_in_sqlatable_query_table",

View File

@ -674,7 +674,7 @@ def test_get_samples_with_multiple_filters(
assert "2000-01-02" in rv.json["result"]["query"]
assert "2000-01-04" in rv.json["result"]["query"]
assert "col3 = 1.2" in rv.json["result"]["query"]
assert "col4 is null" in rv.json["result"]["query"]
assert "col4 IS NULL" in rv.json["result"]["query"]
assert "col2 = 'c'" in rv.json["result"]["query"]

View File

@ -308,10 +308,7 @@ class TestDbEngineSpecs(TestDbEngineSpec):
}
sql = table.get_query_str(query_obj)
assert (
"""ORDER BY case
when gender='boy' then 'male'
else 'female'
end ASC;"""
"ORDER BY\n CASE WHEN gender = 'boy' THEN 'male' ELSE 'female' END ASC"
in sql
)

View File

@ -381,4 +381,4 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
"orderby": [["gender_cc", True]],
}
sql = table.get_query_str(query_obj)
assert "ORDER BY `gender_cc` ASC" in sql
assert "ORDER BY\n `gender_cc` ASC" in sql

View File

@ -17,7 +17,6 @@
# isort:skip_file
import json
from superset.utils.core import DatasourceType
import textwrap
import unittest
from unittest import mock
@ -298,58 +297,25 @@ class TestDatabaseModel(SupersetTestCase):
sql = db.select_star(table_name, show_cols=False, latest_partition=False)
with db.get_sqla_engine_with_context() as engine:
quote = engine.dialect.identifier_preparer.quote_identifier
expected = (
textwrap.dedent(
f"""\
SELECT *
FROM {quote(table_name)}
LIMIT 100"""
)
if db.backend in {"presto", "hive"}
else textwrap.dedent(
f"""\
SELECT *
FROM {table_name}
LIMIT 100"""
)
)
source = quote(table_name) if db.backend in {"presto", "hive"} else table_name
expected = f"SELECT\n *\nFROM {source}\nLIMIT 100"
assert expected in sql
sql = db.select_star(table_name, show_cols=True, latest_partition=False)
# TODO(bkyryliuk): unify sql generation
if db.backend == "presto":
assert (
textwrap.dedent(
"""\
SELECT "source" AS "source",
"target" AS "target",
"value" AS "value"
FROM "energy_usage"
LIMIT 100"""
)
== sql
'SELECT\n "source" AS "source",\n "target" AS "target",\n "value" AS "value"\nFROM "energy_usage"\nLIMIT 100'
in sql
)
elif db.backend == "hive":
assert (
textwrap.dedent(
"""\
SELECT `source`,
`target`,
`value`
FROM `energy_usage`
LIMIT 100"""
)
== sql
"SELECT\n `source`,\n `target`,\n `value`\nFROM `energy_usage`\nLIMIT 100"
in sql
)
else:
assert (
textwrap.dedent(
"""\
SELECT source,
target,
value
FROM energy_usage
LIMIT 100"""
)
"SELECT\n source,\n target,\n value\nFROM energy_usage\nLIMIT 100"
in sql
)
@ -367,12 +333,7 @@ class TestDatabaseModel(SupersetTestCase):
}
fully_qualified_name = fully_qualified_names.get(db.db_engine_spec.engine)
if fully_qualified_name:
expected = textwrap.dedent(
f"""\
SELECT *
FROM {fully_qualified_name}
LIMIT 100"""
)
expected = f"SELECT\n *\nFROM {fully_qualified_name}\nLIMIT 100"
assert sql.startswith(expected)
def test_single_statement(self):

View File

@ -373,11 +373,12 @@ class TestQueryContext(SupersetTestCase):
self.login(username="admin")
payload = get_query_context("birth_names")
sql_text = get_sql_text(payload)
assert "SELECT" in sql_text
assert re.search(r'[`"\[]?num[`"\]]? IS NOT NULL', sql_text)
assert re.search(r'NOT [`"\[]?num[`"\]]? IS NULL', sql_text)
assert re.search(
r"""NOT \([`"\[]?name[`"\]]? IS NULL[\s\n]* """
r"""OR [`"\[]?name[`"\]]? IN \('"abc"'\)\)""",
r"""NOT \([\s\n]*[`"\[]?name[`"\]]? IS NULL[\s\n]* """
r"""OR [`"\[]?name[`"\]]? IN \('"abc"'\)[\s\n]*\)""",
sql_text,
)
@ -396,7 +397,7 @@ class TestQueryContext(SupersetTestCase):
# the alias should be in ORDER BY
assert "ORDER BY `sum__num` DESC" in sql_text
else:
assert re.search(r'ORDER BY [`"\[]?sum__num[`"\]]? DESC', sql_text)
assert re.search(r'ORDER BY[\s\n]* [`"\[]?sum__num[`"\]]? DESC', sql_text)
sql_text = get_sql_text(
get_query_context("birth_names:only_orderby_has_metric")
@ -407,7 +408,9 @@ class TestQueryContext(SupersetTestCase):
assert "ORDER BY `sum__num` DESC" in sql_text
else:
assert re.search(
r'ORDER BY SUM\([`"\[]?num[`"\]]?\) DESC', sql_text, re.IGNORECASE
r'ORDER BY[\s\n]* SUM\([`"\[]?num[`"\]]?\) DESC',
sql_text,
re.IGNORECASE,
)
sql_text = get_sql_text(get_query_context("birth_names:orderby_dup_alias"))
@ -438,7 +441,7 @@ class TestQueryContext(SupersetTestCase):
assert "sum(`num_girls`) AS `SUM(num_girls)`" not in sql_text
# Should reference all ORDER BY columns by aliases
assert "ORDER BY `num_girls` DESC," in sql_text
assert "ORDER BY[\\s\n]* `num_girls` DESC," in sql_text
assert "`AVG(num_boys)` DESC," in sql_text
assert "`MAX(CASE WHEN...` ASC" in sql_text
else:
@ -446,14 +449,14 @@ class TestQueryContext(SupersetTestCase):
# since the selected `num_boys` is renamed to `num_boys__`
# it must be references as expression
assert re.search(
r'ORDER BY SUM\([`"\[]?num_girls[`"\]]?\) DESC',
r'ORDER BY[\s\n]* SUM\([`"\[]?num_girls[`"\]]?\) DESC',
sql_text,
re.IGNORECASE,
)
else:
# Should reference the adhoc metric by alias when possible
assert re.search(
r'ORDER BY [`"\[]?num_girls[`"\]]? DESC',
r'ORDER BY[\s\n]* [`"\[]?num_girls[`"\]]? DESC',
sql_text,
re.IGNORECASE,
)
@ -1075,27 +1078,41 @@ def test_time_offset_with_temporal_range_filter(app_context, physical_dataset):
sqls = query_payload["query"].split(";")
"""
SELECT DATE_TRUNC('quarter', col6) AS col6,
SUM(col1) AS "SUM(col1)"
FROM physical_dataset
WHERE col6 >= TO_TIMESTAMP('2002-01-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
AND col6 < TO_TIMESTAMP('2003-01-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
GROUP BY DATE_TRUNC('quarter', col6)
LIMIT 10000;
SELECT
DATETIME(col6, 'start of month', PRINTF('-%d month', (
STRFTIME('%m', col6) - 1
) % 3)) AS col6,
SUM(col1) AS "SUM(col1)"
FROM physical_dataset
WHERE
col6 >= '2002-01-01 00:00:00' AND col6 < '2003-01-01 00:00:00'
GROUP BY
DATETIME(col6, 'start of month', PRINTF('-%d month', (
STRFTIME('%m', col6) - 1
) % 3))
LIMIT 10000
OFFSET 0
SELECT DATE_TRUNC('quarter', col6) AS col6,
SUM(col1) AS "SUM(col1)"
FROM physical_dataset
WHERE col6 >= TO_TIMESTAMP('2001-10-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
AND col6 < TO_TIMESTAMP('2002-10-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
GROUP BY DATE_TRUNC('quarter', col6)
LIMIT 10000;
SELECT
DATETIME(col6, 'start of month', PRINTF('-%d month', (
STRFTIME('%m', col6) - 1
) % 3)) AS col6,
SUM(col1) AS "SUM(col1)"
FROM physical_dataset
WHERE
col6 >= '2001-10-01 00:00:00' AND col6 < '2002-10-01 00:00:00'
GROUP BY
DATETIME(col6, 'start of month', PRINTF('-%d month', (
STRFTIME('%m', col6) - 1
) % 3))
LIMIT 10000
OFFSET 0
"""
assert (
re.search(r"WHERE col6 >= .*2002-01-01", sqls[0])
re.search(r"WHERE\n col6 >= .*2002-01-01", sqls[0])
and re.search(r"AND col6 < .*2003-01-01", sqls[0])
) is not None
assert (
re.search(r"WHERE col6 >= .*2001-10-01", sqls[1])
re.search(r"WHERE\n col6 >= .*2001-10-01", sqls[1])
and re.search(r"AND col6 < .*2002-10-01", sqls[1])
) is not None

View File

@ -273,7 +273,7 @@ class TestRowLevelSecurity(SupersetTestCase):
# establish that the filters are grouped together correctly with
# ANDs, ORs and parens in the correct place
assert (
"WHERE ((name like 'A%'\n or name like 'B%')\n OR (name like 'Q%'))\n AND (gender = 'boy');"
"WHERE\n (\n (\n name LIKE 'A%' OR name LIKE 'B%'\n ) OR (\n name LIKE 'Q%'\n )\n )\n AND (\n gender = 'boy'\n )"
in sql
)
@ -619,7 +619,7 @@ class TestRowLevelSecurityWithRelatedAPI(SupersetTestCase):
RLS_ALICE_REGEX = re.compile(r"name = 'Alice'")
RLS_GENDER_REGEX = re.compile(r"AND \(gender = 'girl'\)")
RLS_GENDER_REGEX = re.compile(r"AND \([\s\n]*gender = 'girl'[\s\n]*\)")
@mock.patch.dict(

View File

@ -290,7 +290,7 @@ class TestSqlLabApi(SupersetTestCase):
"/api/v1/sqllab/format_sql/",
json=data,
)
success_resp = {"result": "SELECT 1\nFROM my_table"}
success_resp = {"result": "SELECT\n 1\nFROM my_table"}
resp_data = json.loads(rv.data.decode("utf-8"))
self.assertDictEqual(resp_data, success_resp)
self.assertEqual(rv.status_code, 200)

View File

@ -150,9 +150,10 @@ class TestDatabaseModel(SupersetTestCase):
table1 = SqlaTable(
table_name="test_has_extra_cache_keys_table",
sql="""
SELECT '{{ current_user_id() }}' as id,
SELECT '{{ current_username() }}' as username,
SELECT '{{ current_user_email() }}' as email,
SELECT
'{{ current_user_id() }}' as id,
'{{ current_username() }}' as username,
'{{ current_user_email() }}' as email
""",
database=get_example_database(),
)
@ -166,9 +167,10 @@ class TestDatabaseModel(SupersetTestCase):
table2 = SqlaTable(
table_name="test_has_extra_cache_keys_disabled_table",
sql="""
SELECT '{{ current_user_id(False) }}' as id,
SELECT '{{ current_username(False) }}' as username,
SELECT '{{ current_user_email(False) }}' as email,
SELECT
'{{ current_user_id(False) }}' as id,
'{{ current_username(False) }}' as username,
'{{ current_user_email(False) }}' as email,
""",
database=get_example_database(),
)
@ -250,10 +252,11 @@ class TestDatabaseModel(SupersetTestCase):
sqla_query = table.get_sqla_query(**base_query_obj)
query = table.database.compile_sqla_query(sqla_query.sqla_query)
# assert virtual dataset
assert "SELECT 'user_abc' as user, 'xyz_P1D' as time_grain" in query
assert "SELECT\n 'user_abc' AS user,\n 'xyz_P1D' AS time_grain" in query
# assert dataset calculated column
assert "case when 'abc' = 'abc' then 'yes' else 'no' end AS expr" in query
assert "case when 'abc' = 'abc' then 'yes' else 'no' end" in query
# assert adhoc column
assert "'foo_P1D'" in query
# assert dataset saved metric
@ -746,7 +749,7 @@ def test_none_operand_in_filter(login_as_admin, physical_dataset):
{
"operator": FilterOperator.NOT_EQUALS.value,
"count": 0,
"sql_should_contain": "COL4 IS NOT NULL",
"sql_should_contain": "NOT COL4 IS NULL",
},
]
for expected in expected_results:

View File

@ -219,7 +219,8 @@ def test_select_star(mocker: MockFixture) -> None:
)
assert (
sql
== """SELECT a
== """SELECT
a
FROM my_table
LIMIT ?
OFFSET ?"""
@ -238,6 +239,7 @@ OFFSET ?"""
)
assert (
sql
== """SELECT a
== """SELECT
a
FROM my_table"""
)

View File

@ -148,7 +148,7 @@ def test_select_star(mocker: MockFixture) -> None:
# mock the database so we can compile the query
database = mocker.MagicMock()
database.compile_sqla_query = lambda query: str(
query.compile(dialect=BigQueryDialect())
query.compile(dialect=BigQueryDialect(), compile_kwargs={"literal_binds": True})
)
engine = mocker.MagicMock()
@ -167,9 +167,10 @@ def test_select_star(mocker: MockFixture) -> None:
)
assert (
sql
== """SELECT `trailer` AS `trailer`
== """SELECT
`trailer` AS `trailer`
FROM `my_table`
LIMIT :param_1"""
LIMIT 100"""
)

View File

@ -47,6 +47,15 @@ def test_dataset_macro(mocker: MockFixture) -> None:
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.models.core import Database
mocker.patch(
"superset.connectors.sqla.models.security_manager.get_guest_rls_filters",
return_value=[],
)
mocker.patch(
"superset.models.helpers.security_manager.get_rls_filters",
return_value=[],
)
columns = [
TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
TableColumn(column_name="num_boys", type="INTEGER"),
@ -94,11 +103,12 @@ def test_dataset_macro(mocker: MockFixture) -> None:
assert (
dataset_macro(1)
== """(
SELECT ds AS ds,
num_boys AS num_boys,
revenue AS revenue,
expenses AS expenses,
revenue-expenses AS profit
SELECT
ds AS ds,
num_boys AS num_boys,
revenue AS revenue,
expenses AS expenses,
revenue - expenses AS profit
FROM my_schema.old_dataset
) AS dataset_1"""
)
@ -106,28 +116,32 @@ FROM my_schema.old_dataset
assert (
dataset_macro(1, include_metrics=True)
== """(
SELECT ds AS ds,
num_boys AS num_boys,
revenue AS revenue,
expenses AS expenses,
revenue-expenses AS profit,
COUNT(*) AS cnt
SELECT
ds AS ds,
num_boys AS num_boys,
revenue AS revenue,
expenses AS expenses,
revenue - expenses AS profit,
COUNT(*) AS cnt
FROM my_schema.old_dataset
GROUP BY ds,
num_boys,
revenue,
expenses,
revenue-expenses
GROUP BY
ds,
num_boys,
revenue,
expenses,
revenue - expenses
) AS dataset_1"""
)
assert (
dataset_macro(1, include_metrics=True, columns=["ds"])
== """(
SELECT ds AS ds,
COUNT(*) AS cnt
SELECT
ds AS ds,
COUNT(*) AS cnt
FROM my_schema.old_dataset
GROUP BY ds
GROUP BY
ds
) AS dataset_1"""
)

View File

@ -35,6 +35,8 @@ from superset.sql_parse import (
insert_rls_in_predicate,
ParsedQuery,
sanitize_clause,
SQLScript,
SQLStatement,
strip_comments_from_sql,
Table,
)
@ -1850,3 +1852,36 @@ WITH t AS (
)
SELECT * FROM t"""
).is_select()
def test_sqlquery() -> None:
    """
    Test the `SQLScript` class.
    """
    multi = SQLScript("SELECT 1; SELECT 2;")
    assert len(multi.statements) == 2
    assert multi.format() == "SELECT\n  1;\nSELECT\n  2"
    # individual statements can be formatted independently
    assert multi.statements[0].format() == "SELECT\n  1"

    with_settings = SQLScript("SET a=1; SET a=2; SELECT 3;")
    # the last `SET` for a given key wins
    assert with_settings.get_settings() == {"a": "2"}
def test_sqlstatement() -> None:
    """
    Test the `SQLStatement` class.
    """
    union = SQLStatement("SELECT * FROM table1 UNION ALL SELECT * FROM table2")
    # both sides of the UNION are extracted as tables
    assert union.tables == {
        Table(table="table1", schema=None, catalog=None),
        Table(table="table2", schema=None, catalog=None),
    }
    expected = "SELECT\n  *\nFROM table1\nUNION ALL\nSELECT\n  *\nFROM table2"
    assert union.format() == expected

    setting = SQLStatement("SET a=1")
    assert setting.get_settings() == {"a": "1"}