mirror of https://github.com/apache/superset.git
chore: improve SQL parsing (#26767)
This commit is contained in:
parent
a75bb7685d
commit
26d8077e97
|
@ -25,7 +25,7 @@ describe('AdhocMetrics', () => {
|
|||
});
|
||||
|
||||
it('Clear metric and set simple adhoc metric', () => {
|
||||
const metric = 'sum(num_girls)';
|
||||
const metric = 'SUM(num_girls)';
|
||||
const metricName = 'Sum Girls';
|
||||
cy.get('[data-test=metrics]')
|
||||
.find('[data-test="remove-control-button"]')
|
||||
|
|
|
@ -100,7 +100,7 @@ describe('Visualization > Table', () => {
|
|||
});
|
||||
cy.verifySliceSuccess({
|
||||
waitAlias: '@chartData',
|
||||
querySubstring: /group by.*name/i,
|
||||
querySubstring: /group by\n.*name/i,
|
||||
chartSelector: 'table',
|
||||
});
|
||||
});
|
||||
|
@ -246,7 +246,7 @@ describe('Visualization > Table', () => {
|
|||
cy.visitChartByParams(formData);
|
||||
cy.verifySliceSuccess({
|
||||
waitAlias: '@chartData',
|
||||
querySubstring: /group by.*state/i,
|
||||
querySubstring: /group by\n.*state/i,
|
||||
chartSelector: 'table',
|
||||
});
|
||||
cy.get('td').contains(/\d*%/);
|
||||
|
|
|
@ -921,6 +921,7 @@ export function formatQuery(queryEditor) {
|
|||
const { sql } = getUpToDateQuery(getState(), queryEditor);
|
||||
return SupersetClient.post({
|
||||
endpoint: `/api/v1/sqllab/format_sql/`,
|
||||
// TODO (betodealmeida): pass engine as a parameter for better formatting
|
||||
body: JSON.stringify({ sql }),
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
}).then(({ json }) => {
|
||||
|
|
|
@ -33,7 +33,6 @@ import dateutil.parser
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
import sqlalchemy as sa
|
||||
import sqlparse
|
||||
from flask import escape, Markup
|
||||
from flask_appbuilder import Model
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
@ -104,7 +103,6 @@ from superset.models.helpers import (
|
|||
ExploreMixin,
|
||||
ImportExportMixin,
|
||||
QueryResult,
|
||||
QueryStringExtended,
|
||||
validate_adhoc_subquery,
|
||||
)
|
||||
from superset.models.slice import Slice
|
||||
|
@ -1099,7 +1097,9 @@ def _process_sql_expression(
|
|||
|
||||
|
||||
class SqlaTable(
|
||||
Model, BaseDatasource, ExploreMixin
|
||||
Model,
|
||||
BaseDatasource,
|
||||
ExploreMixin,
|
||||
): # pylint: disable=too-many-public-methods
|
||||
"""An ORM object for SqlAlchemy table references"""
|
||||
|
||||
|
@ -1413,26 +1413,6 @@ class SqlaTable(
|
|||
def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
|
||||
return get_template_processor(table=self, database=self.database, **kwargs)
|
||||
|
||||
def get_query_str_extended(
|
||||
self,
|
||||
query_obj: QueryObjectDict,
|
||||
mutate: bool = True,
|
||||
) -> QueryStringExtended:
|
||||
sqlaq = self.get_sqla_query(**query_obj)
|
||||
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
|
||||
sql = self._apply_cte(sql, sqlaq.cte)
|
||||
sql = sqlparse.format(sql, reindent=True)
|
||||
if mutate:
|
||||
sql = self.mutate_query_from_config(sql)
|
||||
return QueryStringExtended(
|
||||
applied_template_filters=sqlaq.applied_template_filters,
|
||||
applied_filter_columns=sqlaq.applied_filter_columns,
|
||||
rejected_filter_columns=sqlaq.rejected_filter_columns,
|
||||
labels_expected=sqlaq.labels_expected,
|
||||
prequeries=sqlaq.prequeries,
|
||||
sql=sql,
|
||||
)
|
||||
|
||||
def get_query_str(self, query_obj: QueryObjectDict) -> str:
|
||||
query_str_ext = self.get_query_str_extended(query_obj)
|
||||
all_queries = query_str_ext.prequeries + [query_str_ext.sql]
|
||||
|
@ -1474,33 +1454,6 @@ class SqlaTable(
|
|||
|
||||
return from_clause, cte
|
||||
|
||||
def get_rendered_sql(
|
||||
self, template_processor: BaseTemplateProcessor | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Render sql with template engine (Jinja).
|
||||
"""
|
||||
|
||||
sql = self.sql
|
||||
if template_processor:
|
||||
try:
|
||||
sql = template_processor.process_template(sql)
|
||||
except TemplateError as ex:
|
||||
raise QueryObjectValidationError(
|
||||
_(
|
||||
"Error while rendering virtual dataset query: %(msg)s",
|
||||
msg=ex.message,
|
||||
)
|
||||
) from ex
|
||||
sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
|
||||
if not sql:
|
||||
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
|
||||
if len(sqlparse.split(sql)) > 1:
|
||||
raise QueryObjectValidationError(
|
||||
_("Virtual dataset query cannot consist of multiple statements")
|
||||
)
|
||||
return sql
|
||||
|
||||
def adhoc_metric_to_sqla(
|
||||
self,
|
||||
metric: AdhocMetric,
|
||||
|
|
|
@ -59,7 +59,7 @@ from superset import security_manager, sql_parse
|
|||
from superset.constants import TimeGrain as TimeGrainConstants
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.sql_parse import ParsedQuery, Table
|
||||
from superset.sql_parse import ParsedQuery, SQLScript, Table
|
||||
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import ColumnSpec, GenericDataType
|
||||
|
@ -1448,7 +1448,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
qry = partition_query
|
||||
sql = database.compile_sqla_query(qry)
|
||||
if indent:
|
||||
sql = sqlparse.format(sql, reindent=True)
|
||||
sql = SQLScript(sql, engine=cls.engine).format()
|
||||
return sql
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -24,7 +24,6 @@ from datetime import datetime
|
|||
from re import Pattern
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
import sqlparse
|
||||
from flask_babel import gettext as __
|
||||
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON
|
||||
from sqlalchemy.dialects.postgresql.base import PGInspector
|
||||
|
@ -37,6 +36,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin
|
|||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetException, SupersetSecurityException
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.sql_parse import SQLScript
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import GenericDataType
|
||||
|
||||
|
@ -281,8 +281,9 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
|
|||
This method simply uses the parent method after checking that there are no
|
||||
malicious path setting in the query.
|
||||
"""
|
||||
sql = sqlparse.format(query.sql, strip_comments=True)
|
||||
if re.search(r"set\s+search_path\s*=", sql, re.IGNORECASE):
|
||||
script = SQLScript(query.sql, engine=cls.engine)
|
||||
settings = script.get_settings()
|
||||
if "search_path" in settings:
|
||||
raise SupersetSecurityException(
|
||||
SupersetError(
|
||||
error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR,
|
||||
|
|
|
@ -83,6 +83,7 @@ class SupersetErrorType(StrEnum):
|
|||
RESULTS_BACKEND_ERROR = "RESULTS_BACKEND_ERROR"
|
||||
ASYNC_WORKERS_ERROR = "ASYNC_WORKERS_ERROR"
|
||||
ADHOC_SUBQUERY_NOT_ALLOWED_ERROR = "ADHOC_SUBQUERY_NOT_ALLOWED_ERROR"
|
||||
INVALID_SQL_ERROR = "INVALID_SQL_ERROR"
|
||||
|
||||
# Generic errors
|
||||
GENERIC_COMMAND_ERROR = "GENERIC_COMMAND_ERROR"
|
||||
|
@ -176,6 +177,7 @@ ERROR_TYPES_TO_ISSUE_CODES_MAPPING = {
|
|||
SupersetErrorType.INVALID_PAYLOAD_SCHEMA_ERROR: [1020],
|
||||
SupersetErrorType.INVALID_CTAS_QUERY_ERROR: [1023],
|
||||
SupersetErrorType.INVALID_CVAS_QUERY_ERROR: [1024, 1025],
|
||||
SupersetErrorType.INVALID_SQL_ERROR: [1003],
|
||||
SupersetErrorType.SQLLAB_TIMEOUT_ERROR: [1026, 1027],
|
||||
SupersetErrorType.OBJECT_DOES_NOT_EXIST_ERROR: [1029],
|
||||
SupersetErrorType.SYNTAX_ERROR: [1030],
|
||||
|
|
|
@ -295,3 +295,20 @@ class SupersetMarshmallowValidationError(SupersetErrorException):
|
|||
extra={"messages": exc.messages, "payload": payload},
|
||||
)
|
||||
super().__init__(error)
|
||||
|
||||
|
||||
class SupersetParseError(SupersetErrorException):
|
||||
"""
|
||||
Exception to be raised when we fail to parse SQL.
|
||||
"""
|
||||
|
||||
status = 422
|
||||
|
||||
def __init__(self, sql: str, engine: Optional[str] = None):
|
||||
error = SupersetError(
|
||||
message=_("The SQL is invalid and cannot be parsed."),
|
||||
error_type=SupersetErrorType.INVALID_SQL_ERROR,
|
||||
level=ErrorLevel.ERROR,
|
||||
extra={"sql": sql, "engine": engine},
|
||||
)
|
||||
super().__init__(error)
|
||||
|
|
|
@ -64,6 +64,7 @@ from superset.exceptions import (
|
|||
ColumnNotFoundException,
|
||||
QueryClauseValidationException,
|
||||
QueryObjectValidationError,
|
||||
SupersetParseError,
|
||||
SupersetSecurityException,
|
||||
)
|
||||
from superset.extensions import feature_flag_manager
|
||||
|
@ -73,6 +74,8 @@ from superset.sql_parse import (
|
|||
insert_rls_in_predicate,
|
||||
ParsedQuery,
|
||||
sanitize_clause,
|
||||
SQLScript,
|
||||
SQLStatement,
|
||||
)
|
||||
from superset.superset_typing import (
|
||||
AdhocMetric,
|
||||
|
@ -901,12 +904,18 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
return sql
|
||||
|
||||
def get_query_str_extended(
|
||||
self, query_obj: QueryObjectDict, mutate: bool = True
|
||||
self,
|
||||
query_obj: QueryObjectDict,
|
||||
mutate: bool = True,
|
||||
) -> QueryStringExtended:
|
||||
sqlaq = self.get_sqla_query(**query_obj)
|
||||
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
|
||||
sql = self._apply_cte(sql, sqlaq.cte)
|
||||
sql = sqlparse.format(sql, reindent=True)
|
||||
try:
|
||||
sql = SQLStatement(sql, engine=self.db_engine_spec.engine).format()
|
||||
except SupersetParseError:
|
||||
logger.warning("Unable to parse SQL to format it, passing it as-is")
|
||||
|
||||
if mutate:
|
||||
sql = self.mutate_query_from_config(sql)
|
||||
return QueryStringExtended(
|
||||
|
@ -1054,7 +1063,8 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
)
|
||||
|
||||
def get_rendered_sql(
|
||||
self, template_processor: Optional[BaseTemplateProcessor] = None
|
||||
self,
|
||||
template_processor: Optional[BaseTemplateProcessor] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Render sql with template engine (Jinja).
|
||||
|
@ -1071,13 +1081,16 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
msg=ex.message,
|
||||
)
|
||||
) from ex
|
||||
sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
|
||||
if not sql:
|
||||
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
|
||||
if len(sqlparse.split(sql)) > 1:
|
||||
|
||||
script = SQLScript(sql.strip("\t\r\n; "), engine=self.db_engine_spec.engine)
|
||||
if len(script.statements) > 1:
|
||||
raise QueryObjectValidationError(
|
||||
_("Virtual dataset query cannot consist of multiple statements")
|
||||
)
|
||||
|
||||
sql = script.statements[0].format(comments=False)
|
||||
if not sql:
|
||||
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
|
||||
return sql
|
||||
|
||||
def text(self, clause: str) -> TextClause:
|
||||
|
|
|
@ -22,13 +22,14 @@ import re
|
|||
import urllib.parse
|
||||
from collections.abc import Iterable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast, Optional
|
||||
from typing import Any, cast, Optional, Union
|
||||
|
||||
import sqlglot
|
||||
import sqlparse
|
||||
from sqlalchemy import and_
|
||||
from sqlglot import exp, parse, parse_one
|
||||
from sqlglot.dialects import Dialects
|
||||
from sqlglot.errors import SqlglotError
|
||||
from sqlglot.dialects.dialect import Dialect, Dialects
|
||||
from sqlglot.errors import ParseError, SqlglotError
|
||||
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
|
||||
from sqlparse import keywords
|
||||
from sqlparse.lexer import Lexer
|
||||
|
@ -55,7 +56,7 @@ from sqlparse.tokens import (
|
|||
)
|
||||
from sqlparse.utils import imt
|
||||
|
||||
from superset.exceptions import QueryClauseValidationException
|
||||
from superset.exceptions import QueryClauseValidationException, SupersetParseError
|
||||
from superset.utils.backports import StrEnum
|
||||
|
||||
try:
|
||||
|
@ -252,6 +253,185 @@ class Table:
|
|||
return str(self) == str(__o)
|
||||
|
||||
|
||||
def extract_tables_from_statement(
|
||||
statement: exp.Expression,
|
||||
dialect: Optional[Dialects],
|
||||
) -> set[Table]:
|
||||
"""
|
||||
Extract all table references in a single statement.
|
||||
|
||||
Please not that this is not trivial; consider the following queries:
|
||||
|
||||
DESCRIBE some_table;
|
||||
SHOW PARTITIONS FROM some_table;
|
||||
WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;
|
||||
|
||||
See the unit tests for other tricky cases.
|
||||
"""
|
||||
sources: Iterable[exp.Table]
|
||||
|
||||
if isinstance(statement, exp.Describe):
|
||||
# A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly
|
||||
# query for all tables.
|
||||
sources = statement.find_all(exp.Table)
|
||||
elif isinstance(statement, exp.Command):
|
||||
# Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
|
||||
# `SELECT` statetement in order to extract tables.
|
||||
literal = statement.find(exp.Literal)
|
||||
if not literal:
|
||||
return set()
|
||||
|
||||
try:
|
||||
pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect)
|
||||
except ParseError:
|
||||
return set()
|
||||
sources = pseudo_query.find_all(exp.Table)
|
||||
else:
|
||||
sources = [
|
||||
source
|
||||
for scope in traverse_scope(statement)
|
||||
for source in scope.sources.values()
|
||||
if isinstance(source, exp.Table) and not is_cte(source, scope)
|
||||
]
|
||||
|
||||
return {
|
||||
Table(
|
||||
source.name,
|
||||
source.db if source.db != "" else None,
|
||||
source.catalog if source.catalog != "" else None,
|
||||
)
|
||||
for source in sources
|
||||
}
|
||||
|
||||
|
||||
def is_cte(source: exp.Table, scope: Scope) -> bool:
|
||||
"""
|
||||
Is the source a CTE?
|
||||
|
||||
CTEs in the parent scope look like tables (and are represented by
|
||||
exp.Table objects), but should not be considered as such;
|
||||
otherwise a user with access to table `foo` could access any table
|
||||
with a query like this:
|
||||
|
||||
WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo
|
||||
|
||||
"""
|
||||
parent_sources = scope.parent.sources if scope.parent else {}
|
||||
ctes_in_scope = {
|
||||
name
|
||||
for name, parent_scope in parent_sources.items()
|
||||
if isinstance(parent_scope, Scope) and parent_scope.scope_type == ScopeType.CTE
|
||||
}
|
||||
|
||||
return source.name in ctes_in_scope
|
||||
|
||||
|
||||
class SQLScript:
|
||||
"""
|
||||
A SQL script, with 0+ statements.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query: str,
|
||||
engine: Optional[str] = None,
|
||||
):
|
||||
dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
|
||||
|
||||
self.statements = [
|
||||
SQLStatement(statement, engine=engine)
|
||||
for statement in parse(query, dialect=dialect)
|
||||
if statement
|
||||
]
|
||||
|
||||
def format(self, comments: bool = True) -> str:
|
||||
"""
|
||||
Pretty-format the SQL query.
|
||||
"""
|
||||
return ";\n".join(statement.format(comments) for statement in self.statements)
|
||||
|
||||
def get_settings(self) -> dict[str, str]:
|
||||
"""
|
||||
Return the settings for the SQL query.
|
||||
|
||||
>>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'")
|
||||
>>> statement.get_settings()
|
||||
{"foo": "'baz'"}
|
||||
|
||||
"""
|
||||
settings: dict[str, str] = {}
|
||||
for statement in self.statements:
|
||||
settings.update(statement.get_settings())
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
class SQLStatement:
|
||||
"""
|
||||
A SQL statement.
|
||||
|
||||
This class provides helper methods to manipulate and introspect SQL.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
statement: Union[str, exp.Expression],
|
||||
engine: Optional[str] = None,
|
||||
):
|
||||
dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
|
||||
|
||||
if isinstance(statement, str):
|
||||
try:
|
||||
self._parsed = self._parse_statement(statement, dialect)
|
||||
except ParseError as ex:
|
||||
raise SupersetParseError(statement, engine) from ex
|
||||
else:
|
||||
self._parsed = statement
|
||||
|
||||
self._dialect = dialect
|
||||
self.tables = extract_tables_from_statement(self._parsed, dialect)
|
||||
|
||||
@staticmethod
|
||||
def _parse_statement(
|
||||
sql_statement: str,
|
||||
dialect: Optional[Dialects],
|
||||
) -> exp.Expression:
|
||||
"""
|
||||
Parse a single SQL statement.
|
||||
"""
|
||||
statements = [
|
||||
statement
|
||||
for statement in sqlglot.parse(sql_statement, dialect=dialect)
|
||||
if statement
|
||||
]
|
||||
if len(statements) != 1:
|
||||
raise ValueError("SQLStatement should have exactly one statement")
|
||||
|
||||
return statements[0]
|
||||
|
||||
def format(self, comments: bool = True) -> str:
|
||||
"""
|
||||
Pretty-format the SQL statement.
|
||||
"""
|
||||
write = Dialect.get_or_raise(self._dialect)
|
||||
return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
|
||||
|
||||
def get_settings(self) -> dict[str, str]:
|
||||
"""
|
||||
Return the settings for the SQL statement.
|
||||
|
||||
>>> statement = SQLStatement("SET foo = 'bar'")
|
||||
>>> statement.get_settings()
|
||||
{"foo": "'bar'"}
|
||||
|
||||
"""
|
||||
return {
|
||||
eq.this.sql(): eq.expression.sql()
|
||||
for set_item in self._parsed.find_all(exp.SetItem)
|
||||
for eq in set_item.find_all(exp.EQ)
|
||||
}
|
||||
|
||||
|
||||
class ParsedQuery:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -294,7 +474,7 @@ class ParsedQuery:
|
|||
return {
|
||||
table
|
||||
for statement in statements
|
||||
for table in self._extract_tables_from_statement(statement)
|
||||
for table in extract_tables_from_statement(statement, self._dialect)
|
||||
if statement
|
||||
}
|
||||
|
||||
|
|
|
@ -19,7 +19,6 @@ from typing import Any, cast, Optional
|
|||
from urllib import parse
|
||||
|
||||
import simplejson as json
|
||||
import sqlparse
|
||||
from flask import request, Response
|
||||
from flask_appbuilder import permission_name
|
||||
from flask_appbuilder.api import expose, protect, rison, safe
|
||||
|
@ -38,6 +37,7 @@ from superset.extensions import event_logger
|
|||
from superset.jinja_context import get_template_processor
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.sql_lab import get_sql_results
|
||||
from superset.sql_parse import SQLScript
|
||||
from superset.sqllab.command_status import SqlJsonExecutionStatus
|
||||
from superset.sqllab.exceptions import (
|
||||
QueryIsForbiddenToAccessException,
|
||||
|
@ -230,7 +230,7 @@ class SqlLabRestApi(BaseSupersetApi):
|
|||
"""
|
||||
try:
|
||||
model = self.format_model_schema.load(request.json)
|
||||
result = sqlparse.format(model["sql"], reindent=True, keyword_case="upper")
|
||||
result = SQLScript(model["sql"], model.get("engine")).format()
|
||||
return self.response(200, result=result)
|
||||
except ValidationError as error:
|
||||
return self.response_400(message=error.messages)
|
||||
|
|
|
@ -44,6 +44,7 @@ class EstimateQueryCostSchema(Schema):
|
|||
|
||||
class FormatQueryPayloadSchema(Schema):
|
||||
sql = fields.String(required=True)
|
||||
engine = fields.String(required=False, allow_none=True)
|
||||
|
||||
|
||||
class ExecutePayloadSchema(Schema):
|
||||
|
|
|
@ -138,8 +138,8 @@ def get_select_star(table: str, limit: int, schema: Optional[str] = None):
|
|||
schema = quote_f(schema)
|
||||
table = quote_f(table)
|
||||
if schema:
|
||||
return f"SELECT *\nFROM {schema}.{table}\nLIMIT {limit}"
|
||||
return f"SELECT *\nFROM {table}\nLIMIT {limit}"
|
||||
return f"SELECT\n *\nFROM {schema}.{table}\nLIMIT {limit}"
|
||||
return f"SELECT\n *\nFROM {table}\nLIMIT {limit}"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("login_as_admin")
|
||||
|
@ -333,9 +333,9 @@ def test_run_async_cta_query_with_lower_limit(test_client, ctas_method):
|
|||
query = wait_for_success(result)
|
||||
assert QueryStatus.SUCCESS == query.status
|
||||
|
||||
sqllite_select_sql = f"SELECT *\nFROM {tmp_table}\nLIMIT {query.limit}\nOFFSET 0"
|
||||
sqlite_select_sql = f"SELECT\n *\nFROM {tmp_table}\nLIMIT {query.limit}\nOFFSET 0"
|
||||
assert query.select_sql == (
|
||||
sqllite_select_sql
|
||||
sqlite_select_sql
|
||||
if backend() == "sqlite"
|
||||
else get_select_star(tmp_table, query.limit)
|
||||
)
|
||||
|
|
|
@ -694,7 +694,7 @@ class TestPostChartDataApi(BaseTestChartDataApi):
|
|||
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
|
||||
result = rv.json["result"][0]["query"]
|
||||
if get_example_database().backend != "presto":
|
||||
assert "('boy' = 'boy')" in result
|
||||
assert "(\n 'boy' = 'boy'\n )" in result
|
||||
|
||||
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
|
@ -1319,13 +1319,13 @@ def test_time_filter_with_grain(test_client, login_as_admin, physical_query_cont
|
|||
backend = get_example_database().backend
|
||||
if backend == "sqlite":
|
||||
assert (
|
||||
"DATETIME(col5, 'start of day', -strftime('%w', col5) || ' days') >="
|
||||
"DATETIME(col5, 'start of day', -STRFTIME('%w', col5) || ' days') >="
|
||||
in query
|
||||
)
|
||||
elif backend == "mysql":
|
||||
assert "DATE(DATE_SUB(col5, INTERVAL DAYOFWEEK(col5) - 1 DAY)) >=" in query
|
||||
assert "DATE(DATE_SUB(col5, INTERVAL (DAYOFWEEK(col5) - 1) DAY)) >=" in query
|
||||
elif backend == "postgresql":
|
||||
assert "DATE_TRUNC('week', col5) >=" in query
|
||||
assert "DATE_TRUNC('WEEK', col5) >=" in query
|
||||
elif backend == "presto":
|
||||
assert "date_trunc('week', CAST(col5 AS TIMESTAMP)) >=" in query
|
||||
|
||||
|
|
|
@ -531,7 +531,7 @@ class TestCore(SupersetTestCase):
|
|||
)
|
||||
|
||||
def test_comments_in_sqlatable_query(self):
|
||||
clean_query = "SELECT '/* val 1 */' as c1, '-- val 2' as c2 FROM tbl"
|
||||
clean_query = "SELECT\n '/* val 1 */' AS c1,\n '-- val 2' AS c2\nFROM tbl"
|
||||
commented_query = "/* comment 1 */" + clean_query + "-- comment 2"
|
||||
table = SqlaTable(
|
||||
table_name="test_comments_in_sqlatable_query_table",
|
||||
|
|
|
@ -674,7 +674,7 @@ def test_get_samples_with_multiple_filters(
|
|||
assert "2000-01-02" in rv.json["result"]["query"]
|
||||
assert "2000-01-04" in rv.json["result"]["query"]
|
||||
assert "col3 = 1.2" in rv.json["result"]["query"]
|
||||
assert "col4 is null" in rv.json["result"]["query"]
|
||||
assert "col4 IS NULL" in rv.json["result"]["query"]
|
||||
assert "col2 = 'c'" in rv.json["result"]["query"]
|
||||
|
||||
|
||||
|
|
|
@ -308,10 +308,7 @@ class TestDbEngineSpecs(TestDbEngineSpec):
|
|||
}
|
||||
sql = table.get_query_str(query_obj)
|
||||
assert (
|
||||
"""ORDER BY case
|
||||
when gender='boy' then 'male'
|
||||
else 'female'
|
||||
end ASC;"""
|
||||
"ORDER BY\n CASE WHEN gender = 'boy' THEN 'male' ELSE 'female' END ASC"
|
||||
in sql
|
||||
)
|
||||
|
||||
|
|
|
@ -381,4 +381,4 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
|
|||
"orderby": [["gender_cc", True]],
|
||||
}
|
||||
sql = table.get_query_str(query_obj)
|
||||
assert "ORDER BY `gender_cc` ASC" in sql
|
||||
assert "ORDER BY\n `gender_cc` ASC" in sql
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
# isort:skip_file
|
||||
import json
|
||||
from superset.utils.core import DatasourceType
|
||||
import textwrap
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
|
@ -298,58 +297,25 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
sql = db.select_star(table_name, show_cols=False, latest_partition=False)
|
||||
with db.get_sqla_engine_with_context() as engine:
|
||||
quote = engine.dialect.identifier_preparer.quote_identifier
|
||||
expected = (
|
||||
textwrap.dedent(
|
||||
f"""\
|
||||
SELECT *
|
||||
FROM {quote(table_name)}
|
||||
LIMIT 100"""
|
||||
)
|
||||
if db.backend in {"presto", "hive"}
|
||||
else textwrap.dedent(
|
||||
f"""\
|
||||
SELECT *
|
||||
FROM {table_name}
|
||||
LIMIT 100"""
|
||||
)
|
||||
)
|
||||
|
||||
source = quote(table_name) if db.backend in {"presto", "hive"} else table_name
|
||||
expected = f"SELECT\n *\nFROM {source}\nLIMIT 100"
|
||||
assert expected in sql
|
||||
sql = db.select_star(table_name, show_cols=True, latest_partition=False)
|
||||
# TODO(bkyryliuk): unify sql generation
|
||||
if db.backend == "presto":
|
||||
assert (
|
||||
textwrap.dedent(
|
||||
"""\
|
||||
SELECT "source" AS "source",
|
||||
"target" AS "target",
|
||||
"value" AS "value"
|
||||
FROM "energy_usage"
|
||||
LIMIT 100"""
|
||||
)
|
||||
== sql
|
||||
'SELECT\n "source" AS "source",\n "target" AS "target",\n "value" AS "value"\nFROM "energy_usage"\nLIMIT 100'
|
||||
in sql
|
||||
)
|
||||
elif db.backend == "hive":
|
||||
assert (
|
||||
textwrap.dedent(
|
||||
"""\
|
||||
SELECT `source`,
|
||||
`target`,
|
||||
`value`
|
||||
FROM `energy_usage`
|
||||
LIMIT 100"""
|
||||
)
|
||||
== sql
|
||||
"SELECT\n `source`,\n `target`,\n `value`\nFROM `energy_usage`\nLIMIT 100"
|
||||
in sql
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
textwrap.dedent(
|
||||
"""\
|
||||
SELECT source,
|
||||
target,
|
||||
value
|
||||
FROM energy_usage
|
||||
LIMIT 100"""
|
||||
)
|
||||
"SELECT\n source,\n target,\n value\nFROM energy_usage\nLIMIT 100"
|
||||
in sql
|
||||
)
|
||||
|
||||
|
@ -367,12 +333,7 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
}
|
||||
fully_qualified_name = fully_qualified_names.get(db.db_engine_spec.engine)
|
||||
if fully_qualified_name:
|
||||
expected = textwrap.dedent(
|
||||
f"""\
|
||||
SELECT *
|
||||
FROM {fully_qualified_name}
|
||||
LIMIT 100"""
|
||||
)
|
||||
expected = f"SELECT\n *\nFROM {fully_qualified_name}\nLIMIT 100"
|
||||
assert sql.startswith(expected)
|
||||
|
||||
def test_single_statement(self):
|
||||
|
|
|
@ -373,11 +373,12 @@ class TestQueryContext(SupersetTestCase):
|
|||
self.login(username="admin")
|
||||
payload = get_query_context("birth_names")
|
||||
sql_text = get_sql_text(payload)
|
||||
|
||||
assert "SELECT" in sql_text
|
||||
assert re.search(r'[`"\[]?num[`"\]]? IS NOT NULL', sql_text)
|
||||
assert re.search(r'NOT [`"\[]?num[`"\]]? IS NULL', sql_text)
|
||||
assert re.search(
|
||||
r"""NOT \([`"\[]?name[`"\]]? IS NULL[\s\n]* """
|
||||
r"""OR [`"\[]?name[`"\]]? IN \('"abc"'\)\)""",
|
||||
r"""NOT \([\s\n]*[`"\[]?name[`"\]]? IS NULL[\s\n]* """
|
||||
r"""OR [`"\[]?name[`"\]]? IN \('"abc"'\)[\s\n]*\)""",
|
||||
sql_text,
|
||||
)
|
||||
|
||||
|
@ -396,7 +397,7 @@ class TestQueryContext(SupersetTestCase):
|
|||
# the alias should be in ORDER BY
|
||||
assert "ORDER BY `sum__num` DESC" in sql_text
|
||||
else:
|
||||
assert re.search(r'ORDER BY [`"\[]?sum__num[`"\]]? DESC', sql_text)
|
||||
assert re.search(r'ORDER BY[\s\n]* [`"\[]?sum__num[`"\]]? DESC', sql_text)
|
||||
|
||||
sql_text = get_sql_text(
|
||||
get_query_context("birth_names:only_orderby_has_metric")
|
||||
|
@ -407,7 +408,9 @@ class TestQueryContext(SupersetTestCase):
|
|||
assert "ORDER BY `sum__num` DESC" in sql_text
|
||||
else:
|
||||
assert re.search(
|
||||
r'ORDER BY SUM\([`"\[]?num[`"\]]?\) DESC', sql_text, re.IGNORECASE
|
||||
r'ORDER BY[\s\n]* SUM\([`"\[]?num[`"\]]?\) DESC',
|
||||
sql_text,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
sql_text = get_sql_text(get_query_context("birth_names:orderby_dup_alias"))
|
||||
|
@ -438,7 +441,7 @@ class TestQueryContext(SupersetTestCase):
|
|||
assert "sum(`num_girls`) AS `SUM(num_girls)`" not in sql_text
|
||||
|
||||
# Should reference all ORDER BY columns by aliases
|
||||
assert "ORDER BY `num_girls` DESC," in sql_text
|
||||
assert "ORDER BY[\\s\n]* `num_girls` DESC," in sql_text
|
||||
assert "`AVG(num_boys)` DESC," in sql_text
|
||||
assert "`MAX(CASE WHEN...` ASC" in sql_text
|
||||
else:
|
||||
|
@ -446,14 +449,14 @@ class TestQueryContext(SupersetTestCase):
|
|||
# since the selected `num_boys` is renamed to `num_boys__`
|
||||
# it must be references as expression
|
||||
assert re.search(
|
||||
r'ORDER BY SUM\([`"\[]?num_girls[`"\]]?\) DESC',
|
||||
r'ORDER BY[\s\n]* SUM\([`"\[]?num_girls[`"\]]?\) DESC',
|
||||
sql_text,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
else:
|
||||
# Should reference the adhoc metric by alias when possible
|
||||
assert re.search(
|
||||
r'ORDER BY [`"\[]?num_girls[`"\]]? DESC',
|
||||
r'ORDER BY[\s\n]* [`"\[]?num_girls[`"\]]? DESC',
|
||||
sql_text,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
@ -1075,27 +1078,41 @@ def test_time_offset_with_temporal_range_filter(app_context, physical_dataset):
|
|||
|
||||
sqls = query_payload["query"].split(";")
|
||||
"""
|
||||
SELECT DATE_TRUNC('quarter', col6) AS col6,
|
||||
SUM(col1) AS "SUM(col1)"
|
||||
FROM physical_dataset
|
||||
WHERE col6 >= TO_TIMESTAMP('2002-01-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
|
||||
AND col6 < TO_TIMESTAMP('2003-01-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
|
||||
GROUP BY DATE_TRUNC('quarter', col6)
|
||||
LIMIT 10000;
|
||||
SELECT
|
||||
DATETIME(col6, 'start of month', PRINTF('-%d month', (
|
||||
STRFTIME('%m', col6) - 1
|
||||
) % 3)) AS col6,
|
||||
SUM(col1) AS "SUM(col1)"
|
||||
FROM physical_dataset
|
||||
WHERE
|
||||
col6 >= '2002-01-01 00:00:00' AND col6 < '2003-01-01 00:00:00'
|
||||
GROUP BY
|
||||
DATETIME(col6, 'start of month', PRINTF('-%d month', (
|
||||
STRFTIME('%m', col6) - 1
|
||||
) % 3))
|
||||
LIMIT 10000
|
||||
OFFSET 0
|
||||
|
||||
SELECT DATE_TRUNC('quarter', col6) AS col6,
|
||||
SUM(col1) AS "SUM(col1)"
|
||||
FROM physical_dataset
|
||||
WHERE col6 >= TO_TIMESTAMP('2001-10-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
|
||||
AND col6 < TO_TIMESTAMP('2002-10-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
|
||||
GROUP BY DATE_TRUNC('quarter', col6)
|
||||
LIMIT 10000;
|
||||
SELECT
|
||||
DATETIME(col6, 'start of month', PRINTF('-%d month', (
|
||||
STRFTIME('%m', col6) - 1
|
||||
) % 3)) AS col6,
|
||||
SUM(col1) AS "SUM(col1)"
|
||||
FROM physical_dataset
|
||||
WHERE
|
||||
col6 >= '2001-10-01 00:00:00' AND col6 < '2002-10-01 00:00:00'
|
||||
GROUP BY
|
||||
DATETIME(col6, 'start of month', PRINTF('-%d month', (
|
||||
STRFTIME('%m', col6) - 1
|
||||
) % 3))
|
||||
LIMIT 10000
|
||||
OFFSET 0
|
||||
"""
|
||||
assert (
|
||||
re.search(r"WHERE col6 >= .*2002-01-01", sqls[0])
|
||||
re.search(r"WHERE\n col6 >= .*2002-01-01", sqls[0])
|
||||
and re.search(r"AND col6 < .*2003-01-01", sqls[0])
|
||||
) is not None
|
||||
assert (
|
||||
re.search(r"WHERE col6 >= .*2001-10-01", sqls[1])
|
||||
re.search(r"WHERE\n col6 >= .*2001-10-01", sqls[1])
|
||||
and re.search(r"AND col6 < .*2002-10-01", sqls[1])
|
||||
) is not None
|
||||
|
|
|
@ -273,7 +273,7 @@ class TestRowLevelSecurity(SupersetTestCase):
|
|||
# establish that the filters are grouped together correctly with
|
||||
# ANDs, ORs and parens in the correct place
|
||||
assert (
|
||||
"WHERE ((name like 'A%'\n or name like 'B%')\n OR (name like 'Q%'))\n AND (gender = 'boy');"
|
||||
"WHERE\n (\n (\n name LIKE 'A%' OR name LIKE 'B%'\n ) OR (\n name LIKE 'Q%'\n )\n )\n AND (\n gender = 'boy'\n )"
|
||||
in sql
|
||||
)
|
||||
|
||||
|
@ -619,7 +619,7 @@ class TestRowLevelSecurityWithRelatedAPI(SupersetTestCase):
|
|||
|
||||
|
||||
RLS_ALICE_REGEX = re.compile(r"name = 'Alice'")
|
||||
RLS_GENDER_REGEX = re.compile(r"AND \(gender = 'girl'\)")
|
||||
RLS_GENDER_REGEX = re.compile(r"AND \([\s\n]*gender = 'girl'[\s\n]*\)")
|
||||
|
||||
|
||||
@mock.patch.dict(
|
||||
|
|
|
@ -290,7 +290,7 @@ class TestSqlLabApi(SupersetTestCase):
|
|||
"/api/v1/sqllab/format_sql/",
|
||||
json=data,
|
||||
)
|
||||
success_resp = {"result": "SELECT 1\nFROM my_table"}
|
||||
success_resp = {"result": "SELECT\n 1\nFROM my_table"}
|
||||
resp_data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertDictEqual(resp_data, success_resp)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
|
|
|
@ -150,9 +150,10 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
table1 = SqlaTable(
|
||||
table_name="test_has_extra_cache_keys_table",
|
||||
sql="""
|
||||
SELECT '{{ current_user_id() }}' as id,
|
||||
SELECT '{{ current_username() }}' as username,
|
||||
SELECT '{{ current_user_email() }}' as email,
|
||||
SELECT
|
||||
'{{ current_user_id() }}' as id,
|
||||
'{{ current_username() }}' as username,
|
||||
'{{ current_user_email() }}' as email
|
||||
""",
|
||||
database=get_example_database(),
|
||||
)
|
||||
|
@ -166,9 +167,10 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
table2 = SqlaTable(
|
||||
table_name="test_has_extra_cache_keys_disabled_table",
|
||||
sql="""
|
||||
SELECT '{{ current_user_id(False) }}' as id,
|
||||
SELECT '{{ current_username(False) }}' as username,
|
||||
SELECT '{{ current_user_email(False) }}' as email,
|
||||
SELECT
|
||||
'{{ current_user_id(False) }}' as id,
|
||||
'{{ current_username(False) }}' as username,
|
||||
'{{ current_user_email(False) }}' as email,
|
||||
""",
|
||||
database=get_example_database(),
|
||||
)
|
||||
|
@ -250,10 +252,11 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
|
||||
sqla_query = table.get_sqla_query(**base_query_obj)
|
||||
query = table.database.compile_sqla_query(sqla_query.sqla_query)
|
||||
|
||||
# assert virtual dataset
|
||||
assert "SELECT 'user_abc' as user, 'xyz_P1D' as time_grain" in query
|
||||
assert "SELECT\n 'user_abc' AS user,\n 'xyz_P1D' AS time_grain" in query
|
||||
# assert dataset calculated column
|
||||
assert "case when 'abc' = 'abc' then 'yes' else 'no' end AS expr" in query
|
||||
assert "case when 'abc' = 'abc' then 'yes' else 'no' end" in query
|
||||
# assert adhoc column
|
||||
assert "'foo_P1D'" in query
|
||||
# assert dataset saved metric
|
||||
|
@ -746,7 +749,7 @@ def test_none_operand_in_filter(login_as_admin, physical_dataset):
|
|||
{
|
||||
"operator": FilterOperator.NOT_EQUALS.value,
|
||||
"count": 0,
|
||||
"sql_should_contain": "COL4 IS NOT NULL",
|
||||
"sql_should_contain": "NOT COL4 IS NULL",
|
||||
},
|
||||
]
|
||||
for expected in expected_results:
|
||||
|
|
|
@ -219,7 +219,8 @@ def test_select_star(mocker: MockFixture) -> None:
|
|||
)
|
||||
assert (
|
||||
sql
|
||||
== """SELECT a
|
||||
== """SELECT
|
||||
a
|
||||
FROM my_table
|
||||
LIMIT ?
|
||||
OFFSET ?"""
|
||||
|
@ -238,6 +239,7 @@ OFFSET ?"""
|
|||
)
|
||||
assert (
|
||||
sql
|
||||
== """SELECT a
|
||||
== """SELECT
|
||||
a
|
||||
FROM my_table"""
|
||||
)
|
||||
|
|
|
@ -148,7 +148,7 @@ def test_select_star(mocker: MockFixture) -> None:
|
|||
# mock the database so we can compile the query
|
||||
database = mocker.MagicMock()
|
||||
database.compile_sqla_query = lambda query: str(
|
||||
query.compile(dialect=BigQueryDialect())
|
||||
query.compile(dialect=BigQueryDialect(), compile_kwargs={"literal_binds": True})
|
||||
)
|
||||
|
||||
engine = mocker.MagicMock()
|
||||
|
@ -167,9 +167,10 @@ def test_select_star(mocker: MockFixture) -> None:
|
|||
)
|
||||
assert (
|
||||
sql
|
||||
== """SELECT `trailer` AS `trailer`
|
||||
== """SELECT
|
||||
`trailer` AS `trailer`
|
||||
FROM `my_table`
|
||||
LIMIT :param_1"""
|
||||
LIMIT 100"""
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -47,6 +47,15 @@ def test_dataset_macro(mocker: MockFixture) -> None:
|
|||
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
|
||||
from superset.models.core import Database
|
||||
|
||||
mocker.patch(
|
||||
"superset.connectors.sqla.models.security_manager.get_guest_rls_filters",
|
||||
return_value=[],
|
||||
)
|
||||
mocker.patch(
|
||||
"superset.models.helpers.security_manager.get_rls_filters",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
columns = [
|
||||
TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
|
||||
TableColumn(column_name="num_boys", type="INTEGER"),
|
||||
|
@ -94,11 +103,12 @@ def test_dataset_macro(mocker: MockFixture) -> None:
|
|||
assert (
|
||||
dataset_macro(1)
|
||||
== """(
|
||||
SELECT ds AS ds,
|
||||
num_boys AS num_boys,
|
||||
revenue AS revenue,
|
||||
expenses AS expenses,
|
||||
revenue-expenses AS profit
|
||||
SELECT
|
||||
ds AS ds,
|
||||
num_boys AS num_boys,
|
||||
revenue AS revenue,
|
||||
expenses AS expenses,
|
||||
revenue - expenses AS profit
|
||||
FROM my_schema.old_dataset
|
||||
) AS dataset_1"""
|
||||
)
|
||||
|
@ -106,28 +116,32 @@ FROM my_schema.old_dataset
|
|||
assert (
|
||||
dataset_macro(1, include_metrics=True)
|
||||
== """(
|
||||
SELECT ds AS ds,
|
||||
num_boys AS num_boys,
|
||||
revenue AS revenue,
|
||||
expenses AS expenses,
|
||||
revenue-expenses AS profit,
|
||||
COUNT(*) AS cnt
|
||||
SELECT
|
||||
ds AS ds,
|
||||
num_boys AS num_boys,
|
||||
revenue AS revenue,
|
||||
expenses AS expenses,
|
||||
revenue - expenses AS profit,
|
||||
COUNT(*) AS cnt
|
||||
FROM my_schema.old_dataset
|
||||
GROUP BY ds,
|
||||
num_boys,
|
||||
revenue,
|
||||
expenses,
|
||||
revenue-expenses
|
||||
GROUP BY
|
||||
ds,
|
||||
num_boys,
|
||||
revenue,
|
||||
expenses,
|
||||
revenue - expenses
|
||||
) AS dataset_1"""
|
||||
)
|
||||
|
||||
assert (
|
||||
dataset_macro(1, include_metrics=True, columns=["ds"])
|
||||
== """(
|
||||
SELECT ds AS ds,
|
||||
COUNT(*) AS cnt
|
||||
SELECT
|
||||
ds AS ds,
|
||||
COUNT(*) AS cnt
|
||||
FROM my_schema.old_dataset
|
||||
GROUP BY ds
|
||||
GROUP BY
|
||||
ds
|
||||
) AS dataset_1"""
|
||||
)
|
||||
|
||||
|
|
|
@ -35,6 +35,8 @@ from superset.sql_parse import (
|
|||
insert_rls_in_predicate,
|
||||
ParsedQuery,
|
||||
sanitize_clause,
|
||||
SQLScript,
|
||||
SQLStatement,
|
||||
strip_comments_from_sql,
|
||||
Table,
|
||||
)
|
||||
|
@ -1850,3 +1852,36 @@ WITH t AS (
|
|||
)
|
||||
SELECT * FROM t"""
|
||||
).is_select()
|
||||
|
||||
|
||||
def test_sqlquery() -> None:
|
||||
"""
|
||||
Test the `SQLScript` class.
|
||||
"""
|
||||
script = SQLScript("SELECT 1; SELECT 2;")
|
||||
|
||||
assert len(script.statements) == 2
|
||||
assert script.format() == "SELECT\n 1;\nSELECT\n 2"
|
||||
assert script.statements[0].format() == "SELECT\n 1"
|
||||
|
||||
script = SQLScript("SET a=1; SET a=2; SELECT 3;")
|
||||
assert script.get_settings() == {"a": "2"}
|
||||
|
||||
|
||||
def test_sqlstatement() -> None:
|
||||
"""
|
||||
Test the `SQLStatement` class.
|
||||
"""
|
||||
statement = SQLStatement("SELECT * FROM table1 UNION ALL SELECT * FROM table2")
|
||||
|
||||
assert statement.tables == {
|
||||
Table(table="table1", schema=None, catalog=None),
|
||||
Table(table="table2", schema=None, catalog=None),
|
||||
}
|
||||
assert (
|
||||
statement.format()
|
||||
== "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2"
|
||||
)
|
||||
|
||||
statement = SQLStatement("SET a=1")
|
||||
assert statement.get_settings() == {"a": "1"}
|
||||
|
|
Loading…
Reference in New Issue