mirror of https://github.com/apache/superset.git
Make schema name for the CTA queries and limit configurable (#8867)
* Make schema name configurable Fixing unit tests Fix table quoting Mypy Split tests out for sqlite Grant more permissions for mysql user Postgres doesn't support if not exists More logging Commit for table creation Privileges for postgres Update tests Resolve comments Lint No limits for the CTA queries if configured * CTA -> CTAS and dict -> {} * Move database creation to the .travis file * Black * Move tweaks to travis db setup * Remove left over version * Address comments * Quote table names in the CTAS queries * Pass tmp_schema_name for the query execution * Rebase alembic migration * Switch to python3 mypy * SQLLAB_CTA_SCHEMA_NAME_FUNC -> SQLLAB_CTAS_SCHEMA_NAME_FUNC * Black
This commit is contained in:
parent
26e916e46b
commit
4e1fa95035
11
.travis.yml
11
.travis.yml
|
@ -64,8 +64,10 @@ jobs:
|
||||||
- redis-server
|
- redis-server
|
||||||
before_script:
|
before_script:
|
||||||
- mysql -u root -e "DROP DATABASE IF EXISTS superset; CREATE DATABASE superset DEFAULT CHARACTER SET utf8 COLLATE utf8_unicode_ci"
|
- mysql -u root -e "DROP DATABASE IF EXISTS superset; CREATE DATABASE superset DEFAULT CHARACTER SET utf8 COLLATE utf8_unicode_ci"
|
||||||
|
- mysql -u root -e "DROP DATABASE IF EXISTS sqllab_test_db; CREATE DATABASE sqllab_test_db DEFAULT CHARACTER SET utf8 COLLATE utf8_unicode_ci"
|
||||||
|
- mysql -u root -e "DROP DATABASE IF EXISTS admin_database; CREATE DATABASE admin_database DEFAULT CHARACTER SET utf8 COLLATE utf8_unicode_ci"
|
||||||
- mysql -u root -e "CREATE USER 'mysqluser'@'localhost' IDENTIFIED BY 'mysqluserpassword';"
|
- mysql -u root -e "CREATE USER 'mysqluser'@'localhost' IDENTIFIED BY 'mysqluserpassword';"
|
||||||
- mysql -u root -e "GRANT ALL ON superset.* TO 'mysqluser'@'localhost';"
|
- mysql -u root -e "GRANT ALL ON *.* TO 'mysqluser'@'localhost';"
|
||||||
- language: python
|
- language: python
|
||||||
env: TOXENV=javascript
|
env: TOXENV=javascript
|
||||||
before_install:
|
before_install:
|
||||||
|
@ -91,8 +93,15 @@ jobs:
|
||||||
- postgresql
|
- postgresql
|
||||||
- redis-server
|
- redis-server
|
||||||
before_script:
|
before_script:
|
||||||
|
- psql -U postgres -c "DROP DATABASE IF EXISTS superset;"
|
||||||
- psql -U postgres -c "CREATE DATABASE superset;"
|
- psql -U postgres -c "CREATE DATABASE superset;"
|
||||||
|
- psql -U postgres superset -c "DROP SCHEMA IF EXISTS sqllab_test_db;"
|
||||||
|
- psql -U postgres superset -c "CREATE SCHEMA sqllab_test_db;"
|
||||||
|
- psql -U postgres superset -c "DROP SCHEMA IF EXISTS admin_database;"
|
||||||
|
- psql -U postgres superset -c "CREATE SCHEMA admin_database;"
|
||||||
- psql -U postgres -c "CREATE USER postgresuser WITH PASSWORD 'pguserpassword';"
|
- psql -U postgres -c "CREATE USER postgresuser WITH PASSWORD 'pguserpassword';"
|
||||||
|
- psql -U postgres superset -c "GRANT ALL PRIVILEGES ON SCHEMA sqllab_test_db to postgresuser";
|
||||||
|
- psql -U postgres superset -c "GRANT ALL PRIVILEGES ON SCHEMA admin_database to postgresuser";
|
||||||
- language: python
|
- language: python
|
||||||
python: 3.6
|
python: 3.6
|
||||||
env: TOXENV=pylint
|
env: TOXENV=pylint
|
||||||
|
|
|
@ -28,7 +28,7 @@ import os
|
||||||
import sys
|
import sys
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from datetime import date
|
from datetime import date
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
from celery.schedules import crontab
|
from celery.schedules import crontab
|
||||||
from dateutil import tz
|
from dateutil import tz
|
||||||
|
@ -41,6 +41,9 @@ from superset.utils.logging_configurator import DefaultLoggingConfigurator
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from flask_appbuilder.security.sqla import models # pylint: disable=unused-import
|
||||||
|
from superset.models.core import Database # pylint: disable=unused-import
|
||||||
|
|
||||||
# Realtime stats logger, a StatsD implementation exists
|
# Realtime stats logger, a StatsD implementation exists
|
||||||
STATS_LOGGER = DummyStatsLogger()
|
STATS_LOGGER = DummyStatsLogger()
|
||||||
|
@ -523,6 +526,32 @@ SQLLAB_ASYNC_TIME_LIMIT_SEC = 60 * 60 * 6
|
||||||
# timeout.
|
# timeout.
|
||||||
SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = 10 # seconds
|
SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = 10 # seconds
|
||||||
|
|
||||||
|
# Flag that controls whether the limit is enforced on CTAS (create table as) queries.
|
||||||
|
SQLLAB_CTAS_NO_LIMIT = False
|
||||||
|
|
||||||
|
# This allows you to define custom logic around the "CREATE TABLE AS" or CTAS feature
|
||||||
|
# in SQL Lab that defines where the target schema should be for a given user.
|
||||||
|
# The database's `CTAS Schema` setting takes precedence over this setting.
|
||||||
|
# The example below returns the username, so CTAS queries will write tables into the schema
|
||||||
|
# name `username`
|
||||||
|
# SQLLAB_CTAS_SCHEMA_NAME_FUNC = lambda database, user, schema, sql: user.username
|
||||||
|
# This is a more involved example where, depending on the database, you can leverage data
|
||||||
|
# available to assign schema for the CTA query:
|
||||||
|
# def compute_schema_name(database: Database, user: User, schema: str, sql: str) -> str:
|
||||||
|
# if database.name == 'mysql_payments_slave':
|
||||||
|
# return 'tmp_superset_schema'
|
||||||
|
# if database.name == 'presto_gold':
|
||||||
|
# return user.username
|
||||||
|
# if database.name == 'analytics':
|
||||||
|
# if 'analytics' in [r.name for r in user.roles]:
|
||||||
|
# return 'analytics_cta'
|
||||||
|
# else:
|
||||||
|
# return f'tmp_{schema}'
|
||||||
|
# Function accepts database object, user object, schema name and sql that will be run.
|
||||||
|
SQLLAB_CTAS_SCHEMA_NAME_FUNC: Optional[
|
||||||
|
Callable[["Database", "models.User", str, str], str]
|
||||||
|
] = None
|
||||||
|
|
||||||
# An instantiated derivative of werkzeug.contrib.cache.BaseCache
|
# An instantiated derivative of werkzeug.contrib.cache.BaseCache
|
||||||
# if enabled, it can be used to store the results of long-running queries
|
# if enabled, it can be used to store the results of long-running queries
|
||||||
# in SQL Lab by using the "Run Async" button/feature
|
# in SQL Lab by using the "Run Async" button/feature
|
||||||
|
|
|
@ -336,7 +336,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
return database.compile_sqla_query(qry)
|
return database.compile_sqla_query(qry)
|
||||||
elif LimitMethod.FORCE_LIMIT:
|
elif LimitMethod.FORCE_LIMIT:
|
||||||
parsed_query = sql_parse.ParsedQuery(sql)
|
parsed_query = sql_parse.ParsedQuery(sql)
|
||||||
sql = parsed_query.get_query_with_new_limit(limit)
|
sql = parsed_query.set_or_update_query_limit(limit)
|
||||||
return sql
|
return sql
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -351,7 +351,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
return parsed_query.limit
|
return parsed_query.limit
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_query_with_new_limit(cls, sql: str, limit: int) -> str:
|
def set_or_update_query_limit(cls, sql: str, limit: int) -> str:
|
||||||
"""
|
"""
|
||||||
Create a query based on original query but with new limit clause
|
Create a query based on original query but with new limit clause
|
||||||
|
|
||||||
|
@ -360,7 +360,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
:return: Query with new limit
|
:return: Query with new limit
|
||||||
"""
|
"""
|
||||||
parsed_query = sql_parse.ParsedQuery(sql)
|
parsed_query = sql_parse.ParsedQuery(sql)
|
||||||
return parsed_query.get_query_with_new_limit(limit)
|
return parsed_query.set_or_update_query_limit(limit)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def csv_to_df(**kwargs: Any) -> pd.DataFrame:
|
def csv_to_df(**kwargs: Any) -> pd.DataFrame:
|
||||||
|
@ -632,10 +632,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
"""
|
"""
|
||||||
Generate a "SELECT * from [schema.]table_name" query with appropriate limit.
|
Generate a "SELECT * from [schema.]table_name" query with appropriate limit.
|
||||||
|
|
||||||
|
WARNING: expects only unquoted table and schema names.
|
||||||
|
|
||||||
:param database: Database instance
|
:param database: Database instance
|
||||||
:param table_name: Table name
|
:param table_name: Table name, unquoted
|
||||||
:param engine: SqlALchemy Engine instance
|
:param engine: SqlALchemy Engine instance
|
||||||
:param schema: Schema
|
:param schema: Schema, unquoted
|
||||||
:param limit: limit to impose on query
|
:param limit: limit to impose on query
|
||||||
:param show_cols: Show columns in query; otherwise use "*"
|
:param show_cols: Show columns in query; otherwise use "*"
|
||||||
:param indent: Add indentation to query
|
:param indent: Add indentation to query
|
||||||
|
|
|
@ -0,0 +1,44 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""Add tmp_schema_name to the query object.
|
||||||
|
|
||||||
|
Revision ID: 72428d1ea401
|
||||||
|
Revises: 0a6f12f60c73
|
||||||
|
Create Date: 2020-02-20 08:52:22.877902
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "72428d1ea401"
|
||||||
|
down_revision = "0a6f12f60c73"
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
op.add_column(
|
||||||
|
"query", sa.Column("tmp_schema_name", sa.String(length=256), nullable=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
try:
|
||||||
|
# sqlite doesn't like dropping the columns
|
||||||
|
op.drop_column("query", "tmp_schema_name")
|
||||||
|
except Exception:
|
||||||
|
pass
|
|
@ -55,6 +55,7 @@ class Query(Model, ExtraJSONMixin):
|
||||||
|
|
||||||
# Store the tmp table into the DB only if the user asks for it.
|
# Store the tmp table into the DB only if the user asks for it.
|
||||||
tmp_table_name = Column(String(256))
|
tmp_table_name = Column(String(256))
|
||||||
|
tmp_schema_name = Column(String(256))
|
||||||
user_id = Column(Integer, ForeignKey("ab_user.id"), nullable=True)
|
user_id = Column(Integer, ForeignKey("ab_user.id"), nullable=True)
|
||||||
status = Column(String(16), default=QueryStatus.PENDING)
|
status = Column(String(16), default=QueryStatus.PENDING)
|
||||||
tab_name = Column(String(256))
|
tab_name = Column(String(256))
|
||||||
|
|
|
@ -59,6 +59,7 @@ stats_logger = config["STATS_LOGGER"]
|
||||||
SQLLAB_TIMEOUT = config["SQLLAB_ASYNC_TIME_LIMIT_SEC"]
|
SQLLAB_TIMEOUT = config["SQLLAB_ASYNC_TIME_LIMIT_SEC"]
|
||||||
SQLLAB_HARD_TIMEOUT = SQLLAB_TIMEOUT + 60
|
SQLLAB_HARD_TIMEOUT = SQLLAB_TIMEOUT + 60
|
||||||
SQL_MAX_ROW = config["SQL_MAX_ROW"]
|
SQL_MAX_ROW = config["SQL_MAX_ROW"]
|
||||||
|
SQLLAB_CTAS_NO_LIMIT = config["SQLLAB_CTAS_NO_LIMIT"]
|
||||||
SQL_QUERY_MUTATOR = config["SQL_QUERY_MUTATOR"]
|
SQL_QUERY_MUTATOR = config["SQL_QUERY_MUTATOR"]
|
||||||
log_query = config["QUERY_LOGGER"]
|
log_query = config["QUERY_LOGGER"]
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -207,9 +208,15 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor, log_
|
||||||
query.tmp_table_name = "tmp_{}_table_{}".format(
|
query.tmp_table_name = "tmp_{}_table_{}".format(
|
||||||
query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S")
|
query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S")
|
||||||
)
|
)
|
||||||
sql = parsed_query.as_create_table(query.tmp_table_name)
|
sql = parsed_query.as_create_table(
|
||||||
|
query.tmp_table_name, schema_name=query.tmp_schema_name
|
||||||
|
)
|
||||||
query.select_as_cta_used = True
|
query.select_as_cta_used = True
|
||||||
if parsed_query.is_select():
|
|
||||||
|
# Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
|
||||||
|
if parsed_query.is_select() and not (
|
||||||
|
query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT
|
||||||
|
):
|
||||||
if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
|
if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
|
||||||
query.limit = SQL_MAX_ROW
|
query.limit = SQL_MAX_ROW
|
||||||
if query.limit:
|
if query.limit:
|
||||||
|
@ -378,6 +385,9 @@ def execute_sql_statements(
|
||||||
payload = handle_query_error(msg, query, session, payload)
|
payload = handle_query_error(msg, query, session, payload)
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
# Commit the connection so CTA queries will create the table.
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
# Success, updating the query entry in database
|
# Success, updating the query entry in database
|
||||||
query.rows = result_set.size
|
query.rows = result_set.size
|
||||||
query.progress = 100
|
query.progress = 100
|
||||||
|
@ -385,8 +395,8 @@ def execute_sql_statements(
|
||||||
if query.select_as_cta:
|
if query.select_as_cta:
|
||||||
query.select_sql = database.select_star(
|
query.select_sql = database.select_star(
|
||||||
query.tmp_table_name,
|
query.tmp_table_name,
|
||||||
|
schema=query.tmp_schema_name,
|
||||||
limit=query.limit,
|
limit=query.limit,
|
||||||
schema=database.force_ctas_schema,
|
|
||||||
show_cols=False,
|
show_cols=False,
|
||||||
latest_partition=False,
|
latest_partition=False,
|
||||||
)
|
)
|
||||||
|
|
|
@ -151,20 +151,28 @@ class ParsedQuery:
|
||||||
self._alias_names.add(token_list.tokens[0].value)
|
self._alias_names.add(token_list.tokens[0].value)
|
||||||
self.__extract_from_token(token_list)
|
self.__extract_from_token(token_list)
|
||||||
|
|
||||||
def as_create_table(self, table_name: str, overwrite: bool = False) -> str:
|
def as_create_table(
|
||||||
|
self,
|
||||||
|
table_name: str,
|
||||||
|
schema_name: Optional[str] = None,
|
||||||
|
overwrite: bool = False,
|
||||||
|
) -> str:
|
||||||
"""Reformats the query into the create table as query.
|
"""Reformats the query into the create table as query.
|
||||||
|
|
||||||
Works only for the single select SQL statements, in all other cases
|
Works only for the single select SQL statements, in all other cases
|
||||||
the sql query is not modified.
|
the sql query is not modified.
|
||||||
:param table_name: Table that will contain the results of the query execution
|
:param table_name: table that will contain the results of the query execution
|
||||||
|
:param schema_name: schema name for the target table
|
||||||
:param overwrite: table_name will be dropped if true
|
:param overwrite: table_name will be dropped if true
|
||||||
:return: Create table as query
|
:return: Create table as query
|
||||||
"""
|
"""
|
||||||
exec_sql = ""
|
exec_sql = ""
|
||||||
sql = self.stripped()
|
sql = self.stripped()
|
||||||
|
# TODO(bkyryliuk): quote full_table_name
|
||||||
|
full_table_name = f"{schema_name}.{table_name}" if schema_name else table_name
|
||||||
if overwrite:
|
if overwrite:
|
||||||
exec_sql = f"DROP TABLE IF EXISTS {table_name};\n"
|
exec_sql = f"DROP TABLE IF EXISTS {full_table_name};\n"
|
||||||
exec_sql += f"CREATE TABLE {table_name} AS \n{sql}"
|
exec_sql += f"CREATE TABLE {full_table_name} AS \n{sql}"
|
||||||
return exec_sql
|
return exec_sql
|
||||||
|
|
||||||
def __extract_from_token(self, token: Token): # pylint: disable=too-many-branches
|
def __extract_from_token(self, token: Token): # pylint: disable=too-many-branches
|
||||||
|
@ -205,10 +213,12 @@ class ParsedQuery:
|
||||||
if not self.__is_identifier(token2):
|
if not self.__is_identifier(token2):
|
||||||
self.__extract_from_token(item)
|
self.__extract_from_token(item)
|
||||||
|
|
||||||
def get_query_with_new_limit(self, new_limit: int) -> str:
|
def set_or_update_query_limit(self, new_limit: int) -> str:
|
||||||
"""
|
"""Returns the query with the specified limit.
|
||||||
returns the query with the specified limit.
|
|
||||||
Does not change the underlying query
|
Does not change the underlying query if user did not apply the limit,
|
||||||
|
otherwise replaces the limit with the lower value between existing limit
|
||||||
|
in the query and new_limit.
|
||||||
|
|
||||||
:param new_limit: Limit to be incorporated into returned query
|
:param new_limit: Limit to be incorporated into returned query
|
||||||
:return: The original query with new limit
|
:return: The original query with new limit
|
||||||
|
@ -223,7 +233,10 @@ class ParsedQuery:
|
||||||
limit_pos = pos
|
limit_pos = pos
|
||||||
break
|
break
|
||||||
_, limit = statement.token_next(idx=limit_pos)
|
_, limit = statement.token_next(idx=limit_pos)
|
||||||
if limit.ttype == sqlparse.tokens.Literal.Number.Integer:
|
# Override the limit only when it exceeds the configured value.
|
||||||
|
if limit.ttype == sqlparse.tokens.Literal.Number.Integer and new_limit < int(
|
||||||
|
limit.value
|
||||||
|
):
|
||||||
limit.value = new_limit
|
limit.value = new_limit
|
||||||
elif limit.is_group:
|
elif limit.is_group:
|
||||||
limit.value = f"{next(limit.get_identifiers())}, {new_limit}"
|
limit.value = f"{next(limit.get_identifiers())}, {new_limit}"
|
||||||
|
|
|
@ -19,7 +19,7 @@ import logging
|
||||||
import re
|
import re
|
||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, cast, Dict, List, Optional, Union
|
from typing import Any, Callable, cast, Dict, List, Optional, Union
|
||||||
from urllib import parse
|
from urllib import parse
|
||||||
|
|
||||||
import backoff
|
import backoff
|
||||||
|
@ -73,6 +73,7 @@ from superset.exceptions import (
|
||||||
SupersetTimeoutException,
|
SupersetTimeoutException,
|
||||||
)
|
)
|
||||||
from superset.jinja_context import get_template_processor
|
from superset.jinja_context import get_template_processor
|
||||||
|
from superset.models.core import Database
|
||||||
from superset.models.dashboard import Dashboard
|
from superset.models.dashboard import Dashboard
|
||||||
from superset.models.datasource_access_request import DatasourceAccessRequest
|
from superset.models.datasource_access_request import DatasourceAccessRequest
|
||||||
from superset.models.slice import Slice
|
from superset.models.slice import Slice
|
||||||
|
@ -247,6 +248,17 @@ def _deserialize_results_payload(
|
||||||
return json.loads(payload) # type: ignore
|
return json.loads(payload) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def get_cta_schema_name(
|
||||||
|
database: Database, user: ab_models.User, schema: str, sql: str
|
||||||
|
) -> Optional[str]:
|
||||||
|
func: Optional[Callable[[Database, ab_models.User, str, str], str]] = config[
|
||||||
|
"SQLLAB_CTAS_SCHEMA_NAME_FUNC"
|
||||||
|
]
|
||||||
|
if not func:
|
||||||
|
return None
|
||||||
|
return func(database, user, schema, sql)
|
||||||
|
|
||||||
|
|
||||||
class AccessRequestsModelView(SupersetModelView, DeleteMixin):
|
class AccessRequestsModelView(SupersetModelView, DeleteMixin):
|
||||||
datamodel = SQLAInterface(DAR)
|
datamodel = SQLAInterface(DAR)
|
||||||
include_route_methods = RouteMethod.CRUD_SET
|
include_route_methods = RouteMethod.CRUD_SET
|
||||||
|
@ -2351,9 +2363,14 @@ class Superset(BaseSupersetView):
|
||||||
if not mydb:
|
if not mydb:
|
||||||
return json_error_response(f"Database with id {database_id} is missing.")
|
return json_error_response(f"Database with id {database_id} is missing.")
|
||||||
|
|
||||||
# Set tmp_table_name for CTA
|
# Set tmp_schema_name for CTA
|
||||||
|
# TODO(bkyryliuk): consider parsing, splitting tmp_schema_name from tmp_table_name if user enters
|
||||||
|
# <schema_name>.<table_name>
|
||||||
|
tmp_schema_name: Optional[str] = schema
|
||||||
if select_as_cta and mydb.force_ctas_schema:
|
if select_as_cta and mydb.force_ctas_schema:
|
||||||
tmp_table_name = f"{mydb.force_ctas_schema}.{tmp_table_name}"
|
tmp_schema_name = mydb.force_ctas_schema
|
||||||
|
elif select_as_cta:
|
||||||
|
tmp_schema_name = get_cta_schema_name(mydb, g.user, schema, sql)
|
||||||
|
|
||||||
# Save current query
|
# Save current query
|
||||||
query = Query(
|
query = Query(
|
||||||
|
@ -2366,6 +2383,7 @@ class Superset(BaseSupersetView):
|
||||||
status=status,
|
status=status,
|
||||||
sql_editor_id=sql_editor_id,
|
sql_editor_id=sql_editor_id,
|
||||||
tmp_table_name=tmp_table_name,
|
tmp_table_name=tmp_table_name,
|
||||||
|
tmp_schema_name=tmp_schema_name,
|
||||||
user_id=g.user.get_id() if g.user else None,
|
user_id=g.user.get_id() if g.user else None,
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
)
|
)
|
||||||
|
@ -2406,6 +2424,8 @@ class Superset(BaseSupersetView):
|
||||||
f"Query {query_id}: Template rendering failed: {error_msg}"
|
f"Query {query_id}: Template rendering failed: {error_msg}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Limit is not applied to the CTA queries if SQLLAB_CTAS_NO_LIMIT flag is set to True.
|
||||||
|
if not (config.get("SQLLAB_CTAS_NO_LIMIT") and select_as_cta):
|
||||||
# set LIMIT after template processing
|
# set LIMIT after template processing
|
||||||
limits = [mydb.db_engine_spec.get_limit_from_sql(rendered_query), limit]
|
limits = [mydb.db_engine_spec.get_limit_from_sql(rendered_query), limit]
|
||||||
query.limit = min(lim for lim in limits if lim is not None)
|
query.limit = min(lim for lim in limits if lim is not None)
|
||||||
|
|
|
@ -229,22 +229,27 @@ class SupersetTestCase(TestCase):
|
||||||
query_limit=None,
|
query_limit=None,
|
||||||
database_name="examples",
|
database_name="examples",
|
||||||
sql_editor_id=None,
|
sql_editor_id=None,
|
||||||
|
select_as_cta=False,
|
||||||
|
tmp_table_name=None,
|
||||||
):
|
):
|
||||||
if user_name:
|
if user_name:
|
||||||
self.logout()
|
self.logout()
|
||||||
self.login(username=(user_name or "admin"))
|
self.login(username=(user_name or "admin"))
|
||||||
dbid = self._get_database_by_name(database_name).id
|
dbid = self._get_database_by_name(database_name).id
|
||||||
|
json_payload = {
|
||||||
|
"database_id": dbid,
|
||||||
|
"sql": sql,
|
||||||
|
"client_id": client_id,
|
||||||
|
"queryLimit": query_limit,
|
||||||
|
"sql_editor_id": sql_editor_id,
|
||||||
|
}
|
||||||
|
if tmp_table_name:
|
||||||
|
json_payload["tmp_table_name"] = tmp_table_name
|
||||||
|
if select_as_cta:
|
||||||
|
json_payload["select_as_cta"] = select_as_cta
|
||||||
|
|
||||||
resp = self.get_json_resp(
|
resp = self.get_json_resp(
|
||||||
"/superset/sql_json/",
|
"/superset/sql_json/", raise_on_error=False, json_=json_payload
|
||||||
raise_on_error=False,
|
|
||||||
json_=dict(
|
|
||||||
database_id=dbid,
|
|
||||||
sql=sql,
|
|
||||||
select_as_create_as=False,
|
|
||||||
client_id=client_id,
|
|
||||||
queryLimit=query_limit,
|
|
||||||
sql_editor_id=sql_editor_id,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
if raise_on_error and "error" in resp:
|
if raise_on_error and "error" in resp:
|
||||||
raise Exception("run_sql failed")
|
raise Exception("run_sql failed")
|
||||||
|
|
|
@ -17,14 +17,20 @@
|
||||||
# isort:skip_file
|
# isort:skip_file
|
||||||
"""Unit tests for Superset Celery worker"""
|
"""Unit tests for Superset Celery worker"""
|
||||||
import datetime
|
import datetime
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
import unittest.mock as mock
|
import unittest.mock as mock
|
||||||
|
|
||||||
import flask
|
import flask
|
||||||
|
import sqlalchemy
|
||||||
|
from contextlib2 import contextmanager
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from sqlalchemy.pool import NullPool
|
||||||
|
|
||||||
from tests.test_app import app
|
from tests.test_app import app
|
||||||
from superset import db, sql_lab
|
from superset import db, sql_lab
|
||||||
|
@ -38,11 +44,11 @@ from superset.utils.core import get_example_database
|
||||||
|
|
||||||
from .base_tests import SupersetTestCase
|
from .base_tests import SupersetTestCase
|
||||||
|
|
||||||
|
CELERY_SHORT_SLEEP_TIME = 2
|
||||||
CELERY_SLEEP_TIME = 5
|
CELERY_SLEEP_TIME = 5
|
||||||
|
|
||||||
|
|
||||||
class UtilityFunctionTests(SupersetTestCase):
|
class UtilityFunctionTests(SupersetTestCase):
|
||||||
|
|
||||||
# TODO(bkyryliuk): support more cases in CTA function.
|
# TODO(bkyryliuk): support more cases in CTA function.
|
||||||
def test_create_table_as(self):
|
def test_create_table_as(self):
|
||||||
q = ParsedQuery("SELECT * FROM outer_space;")
|
q = ParsedQuery("SELECT * FROM outer_space;")
|
||||||
|
@ -90,6 +96,9 @@ class AppContextTests(SupersetTestCase):
|
||||||
flask._app_ctx_stack.push(popped_app)
|
flask._app_ctx_stack.push(popped_app)
|
||||||
|
|
||||||
|
|
||||||
|
CTAS_SCHEMA_NAME = "sqllab_test_db"
|
||||||
|
|
||||||
|
|
||||||
class CeleryTestCase(SupersetTestCase):
|
class CeleryTestCase(SupersetTestCase):
|
||||||
def get_query_by_name(self, sql):
|
def get_query_by_name(self, sql):
|
||||||
session = db.session
|
session = db.session
|
||||||
|
@ -159,7 +168,6 @@ class CeleryTestCase(SupersetTestCase):
|
||||||
|
|
||||||
def test_run_sync_query_cta(self):
|
def test_run_sync_query_cta(self):
|
||||||
main_db = get_example_database()
|
main_db = get_example_database()
|
||||||
backend = main_db.backend
|
|
||||||
db_id = main_db.id
|
db_id = main_db.id
|
||||||
tmp_table_name = "tmp_async_22"
|
tmp_table_name = "tmp_async_22"
|
||||||
self.drop_table_if_exists(tmp_table_name, main_db)
|
self.drop_table_if_exists(tmp_table_name, main_db)
|
||||||
|
@ -172,12 +180,13 @@ class CeleryTestCase(SupersetTestCase):
|
||||||
query2 = self.get_query_by_id(result["query"]["serverId"])
|
query2 = self.get_query_by_id(result["query"]["serverId"])
|
||||||
|
|
||||||
# Check the data in the tmp table.
|
# Check the data in the tmp table.
|
||||||
if backend != "postgresql":
|
|
||||||
# TODO This test won't work in Postgres
|
|
||||||
results = self.run_sql(db_id, query2.select_sql, "sdf2134")
|
results = self.run_sql(db_id, query2.select_sql, "sdf2134")
|
||||||
self.assertEqual(results["status"], "success")
|
self.assertEqual(results["status"], "success")
|
||||||
self.assertGreater(len(results["data"]), 0)
|
self.assertGreater(len(results["data"]), 0)
|
||||||
|
|
||||||
|
# cleanup tmp table
|
||||||
|
self.drop_table_if_exists(tmp_table_name, get_example_database())
|
||||||
|
|
||||||
def test_run_sync_query_cta_no_data(self):
|
def test_run_sync_query_cta_no_data(self):
|
||||||
main_db = get_example_database()
|
main_db = get_example_database()
|
||||||
db_id = main_db.id
|
db_id = main_db.id
|
||||||
|
@ -199,15 +208,89 @@ class CeleryTestCase(SupersetTestCase):
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
return self.run_sql(db_id, sql)
|
return self.run_sql(db_id, sql)
|
||||||
|
|
||||||
def test_run_async_query(self):
|
@mock.patch(
|
||||||
|
"superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME
|
||||||
|
)
|
||||||
|
def test_run_sync_query_cta_config(self):
|
||||||
|
main_db = get_example_database()
|
||||||
|
db_id = main_db.id
|
||||||
|
if main_db.backend == "sqlite":
|
||||||
|
# sqlite doesn't support schemas
|
||||||
|
return
|
||||||
|
tmp_table_name = "tmp_async_22"
|
||||||
|
expected_full_table_name = f"{CTAS_SCHEMA_NAME}.{tmp_table_name}"
|
||||||
|
self.drop_table_if_exists(expected_full_table_name, main_db)
|
||||||
|
name = "James"
|
||||||
|
sql_where = f"SELECT name FROM birth_names WHERE name='{name}'"
|
||||||
|
result = self.run_sql(
|
||||||
|
db_id, sql_where, "cid2", tmp_table=tmp_table_name, cta=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(QueryStatus.SUCCESS, result["query"]["state"])
|
||||||
|
self.assertEqual([], result["data"])
|
||||||
|
self.assertEqual([], result["columns"])
|
||||||
|
query = self.get_query_by_id(result["query"]["serverId"])
|
||||||
|
self.assertEqual(
|
||||||
|
f"CREATE TABLE {expected_full_table_name} AS \n"
|
||||||
|
"SELECT name FROM birth_names "
|
||||||
|
"WHERE name='James'",
|
||||||
|
query.executed_sql,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
"SELECT *\n" f"FROM {expected_full_table_name}", query.select_sql
|
||||||
|
)
|
||||||
|
time.sleep(CELERY_SHORT_SLEEP_TIME)
|
||||||
|
results = self.run_sql(db_id, query.select_sql)
|
||||||
|
self.assertEqual(results["status"], "success")
|
||||||
|
self.drop_table_if_exists(expected_full_table_name, get_example_database())
|
||||||
|
|
||||||
|
@mock.patch(
|
||||||
|
"superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME
|
||||||
|
)
|
||||||
|
def test_run_async_query_cta_config(self):
|
||||||
|
main_db = get_example_database()
|
||||||
|
db_id = main_db.id
|
||||||
|
if main_db.backend == "sqlite":
|
||||||
|
# sqlite doesn't support schemas
|
||||||
|
return
|
||||||
|
tmp_table_name = "sqllab_test_table_async_1"
|
||||||
|
expected_full_table_name = f"{CTAS_SCHEMA_NAME}.{tmp_table_name}"
|
||||||
|
self.drop_table_if_exists(expected_full_table_name, main_db)
|
||||||
|
sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10"
|
||||||
|
result = self.run_sql(
|
||||||
|
db_id,
|
||||||
|
sql_where,
|
||||||
|
"cid3",
|
||||||
|
async_=True,
|
||||||
|
tmp_table="sqllab_test_table_async_1",
|
||||||
|
cta=True,
|
||||||
|
)
|
||||||
|
db.session.close()
|
||||||
|
time.sleep(CELERY_SLEEP_TIME)
|
||||||
|
|
||||||
|
query = self.get_query_by_id(result["query"]["serverId"])
|
||||||
|
self.assertEqual(QueryStatus.SUCCESS, query.status)
|
||||||
|
self.assertTrue(f"FROM {expected_full_table_name}" in query.select_sql)
|
||||||
|
self.assertEqual(
|
||||||
|
f"CREATE TABLE {expected_full_table_name} AS \n"
|
||||||
|
"SELECT name FROM birth_names "
|
||||||
|
"WHERE name='James' "
|
||||||
|
"LIMIT 10",
|
||||||
|
query.executed_sql,
|
||||||
|
)
|
||||||
|
self.drop_table_if_exists(expected_full_table_name, get_example_database())
|
||||||
|
|
||||||
|
def test_run_async_cta_query(self):
|
||||||
main_db = get_example_database()
|
main_db = get_example_database()
|
||||||
db_id = main_db.id
|
db_id = main_db.id
|
||||||
|
|
||||||
self.drop_table_if_exists("tmp_async_1", main_db)
|
table_name = "tmp_async_4"
|
||||||
|
self.drop_table_if_exists(table_name, main_db)
|
||||||
|
time.sleep(CELERY_SLEEP_TIME)
|
||||||
|
|
||||||
sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10"
|
sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10"
|
||||||
result = self.run_sql(
|
result = self.run_sql(
|
||||||
db_id, sql_where, "4", async_=True, tmp_table="tmp_async_1", cta=True
|
db_id, sql_where, "cid4", async_=True, tmp_table="tmp_async_4", cta=True
|
||||||
)
|
)
|
||||||
db.session.close()
|
db.session.close()
|
||||||
assert result["query"]["state"] in (
|
assert result["query"]["state"] in (
|
||||||
|
@ -221,9 +304,9 @@ class CeleryTestCase(SupersetTestCase):
|
||||||
query = self.get_query_by_id(result["query"]["serverId"])
|
query = self.get_query_by_id(result["query"]["serverId"])
|
||||||
self.assertEqual(QueryStatus.SUCCESS, query.status)
|
self.assertEqual(QueryStatus.SUCCESS, query.status)
|
||||||
|
|
||||||
self.assertTrue("FROM tmp_async_1" in query.select_sql)
|
self.assertTrue(f"FROM {table_name}" in query.select_sql)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
"CREATE TABLE tmp_async_1 AS \n"
|
f"CREATE TABLE {table_name} AS \n"
|
||||||
"SELECT name FROM birth_names "
|
"SELECT name FROM birth_names "
|
||||||
"WHERE name='James' "
|
"WHERE name='James' "
|
||||||
"LIMIT 10",
|
"LIMIT 10",
|
||||||
|
@ -234,7 +317,7 @@ class CeleryTestCase(SupersetTestCase):
|
||||||
self.assertEqual(True, query.select_as_cta)
|
self.assertEqual(True, query.select_as_cta)
|
||||||
self.assertEqual(True, query.select_as_cta_used)
|
self.assertEqual(True, query.select_as_cta_used)
|
||||||
|
|
||||||
def test_run_async_query_with_lower_limit(self):
|
def test_run_async_cta_query_with_lower_limit(self):
|
||||||
main_db = get_example_database()
|
main_db = get_example_database()
|
||||||
db_id = main_db.id
|
db_id = main_db.id
|
||||||
tmp_table = "tmp_async_2"
|
tmp_table = "tmp_async_2"
|
||||||
|
@ -242,7 +325,7 @@ class CeleryTestCase(SupersetTestCase):
|
||||||
|
|
||||||
sql_where = "SELECT name FROM birth_names LIMIT 1"
|
sql_where = "SELECT name FROM birth_names LIMIT 1"
|
||||||
result = self.run_sql(
|
result = self.run_sql(
|
||||||
db_id, sql_where, "5", async_=True, tmp_table=tmp_table, cta=True
|
db_id, sql_where, "id1", async_=True, tmp_table=tmp_table, cta=True
|
||||||
)
|
)
|
||||||
db.session.close()
|
db.session.close()
|
||||||
assert result["query"]["state"] in (
|
assert result["query"]["state"] in (
|
||||||
|
@ -255,14 +338,15 @@ class CeleryTestCase(SupersetTestCase):
|
||||||
|
|
||||||
query = self.get_query_by_id(result["query"]["serverId"])
|
query = self.get_query_by_id(result["query"]["serverId"])
|
||||||
self.assertEqual(QueryStatus.SUCCESS, query.status)
|
self.assertEqual(QueryStatus.SUCCESS, query.status)
|
||||||
self.assertTrue(f"FROM {tmp_table}" in query.select_sql)
|
|
||||||
|
self.assertIn(f"FROM {tmp_table}", query.select_sql)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
f"CREATE TABLE {tmp_table} AS \n" "SELECT name FROM birth_names LIMIT 1",
|
f"CREATE TABLE {tmp_table} AS \n" "SELECT name FROM birth_names LIMIT 1",
|
||||||
query.executed_sql,
|
query.executed_sql,
|
||||||
)
|
)
|
||||||
self.assertEqual(sql_where, query.sql)
|
self.assertEqual(sql_where, query.sql)
|
||||||
self.assertEqual(0, query.rows)
|
self.assertEqual(0, query.rows)
|
||||||
self.assertEqual(1, query.limit)
|
self.assertEqual(None, query.limit)
|
||||||
self.assertEqual(True, query.select_as_cta)
|
self.assertEqual(True, query.select_as_cta)
|
||||||
self.assertEqual(True, query.select_as_cta_used)
|
self.assertEqual(True, query.select_as_cta_used)
|
||||||
|
|
||||||
|
@ -280,9 +364,12 @@ class CeleryTestCase(SupersetTestCase):
|
||||||
with mock.patch.object(
|
with mock.patch.object(
|
||||||
db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data
|
db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data
|
||||||
) as expand_data:
|
) as expand_data:
|
||||||
data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
|
(
|
||||||
results, db_engine_spec, False, True
|
data,
|
||||||
)
|
selected_columns,
|
||||||
|
all_columns,
|
||||||
|
expanded_columns,
|
||||||
|
) = sql_lab._serialize_and_expand_data(results, db_engine_spec, False, True)
|
||||||
expand_data.assert_called_once()
|
expand_data.assert_called_once()
|
||||||
|
|
||||||
self.assertIsInstance(data, list)
|
self.assertIsInstance(data, list)
|
||||||
|
@ -301,9 +388,12 @@ class CeleryTestCase(SupersetTestCase):
|
||||||
with mock.patch.object(
|
with mock.patch.object(
|
||||||
db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data
|
db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data
|
||||||
) as expand_data:
|
) as expand_data:
|
||||||
data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
|
(
|
||||||
results, db_engine_spec, True
|
data,
|
||||||
)
|
selected_columns,
|
||||||
|
all_columns,
|
||||||
|
expanded_columns,
|
||||||
|
) = sql_lab._serialize_and_expand_data(results, db_engine_spec, True)
|
||||||
expand_data.assert_not_called()
|
expand_data.assert_not_called()
|
||||||
|
|
||||||
self.assertIsInstance(data, bytes)
|
self.assertIsInstance(data, bytes)
|
||||||
|
@ -324,7 +414,12 @@ class CeleryTestCase(SupersetTestCase):
|
||||||
"sql": "SELECT * FROM birth_names LIMIT 100",
|
"sql": "SELECT * FROM birth_names LIMIT 100",
|
||||||
"status": QueryStatus.PENDING,
|
"status": QueryStatus.PENDING,
|
||||||
}
|
}
|
||||||
serialized_data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
|
(
|
||||||
|
serialized_data,
|
||||||
|
selected_columns,
|
||||||
|
all_columns,
|
||||||
|
expanded_columns,
|
||||||
|
) = sql_lab._serialize_and_expand_data(
|
||||||
results, db_engine_spec, use_new_deserialization
|
results, db_engine_spec, use_new_deserialization
|
||||||
)
|
)
|
||||||
payload = {
|
payload = {
|
||||||
|
@ -357,7 +452,12 @@ class CeleryTestCase(SupersetTestCase):
|
||||||
"sql": "SELECT * FROM birth_names LIMIT 100",
|
"sql": "SELECT * FROM birth_names LIMIT 100",
|
||||||
"status": QueryStatus.PENDING,
|
"status": QueryStatus.PENDING,
|
||||||
}
|
}
|
||||||
serialized_data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
|
(
|
||||||
|
serialized_data,
|
||||||
|
selected_columns,
|
||||||
|
all_columns,
|
||||||
|
expanded_columns,
|
||||||
|
) = sql_lab._serialize_and_expand_data(
|
||||||
results, db_engine_spec, use_new_deserialization
|
results, db_engine_spec, use_new_deserialization
|
||||||
)
|
)
|
||||||
payload = {
|
payload = {
|
||||||
|
|
|
@ -451,19 +451,28 @@ class SupersetTestCase(unittest.TestCase):
|
||||||
def test_get_query_with_new_limit_comment(self):
|
def test_get_query_with_new_limit_comment(self):
|
||||||
sql = "SELECT * FROM birth_names -- SOME COMMENT"
|
sql = "SELECT * FROM birth_names -- SOME COMMENT"
|
||||||
parsed = sql_parse.ParsedQuery(sql)
|
parsed = sql_parse.ParsedQuery(sql)
|
||||||
newsql = parsed.get_query_with_new_limit(1000)
|
newsql = parsed.set_or_update_query_limit(1000)
|
||||||
self.assertEqual(newsql, sql + "\nLIMIT 1000")
|
self.assertEqual(newsql, sql + "\nLIMIT 1000")
|
||||||
|
|
||||||
def test_get_query_with_new_limit_comment_with_limit(self):
|
def test_get_query_with_new_limit_comment_with_limit(self):
|
||||||
sql = "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555"
|
sql = "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555"
|
||||||
parsed = sql_parse.ParsedQuery(sql)
|
parsed = sql_parse.ParsedQuery(sql)
|
||||||
newsql = parsed.get_query_with_new_limit(1000)
|
newsql = parsed.set_or_update_query_limit(1000)
|
||||||
self.assertEqual(newsql, sql + "\nLIMIT 1000")
|
self.assertEqual(newsql, sql + "\nLIMIT 1000")
|
||||||
|
|
||||||
def test_get_query_with_new_limit(self):
|
def test_get_query_with_new_limit_lower(self):
|
||||||
sql = "SELECT * FROM birth_names LIMIT 555"
|
sql = "SELECT * FROM birth_names LIMIT 555"
|
||||||
parsed = sql_parse.ParsedQuery(sql)
|
parsed = sql_parse.ParsedQuery(sql)
|
||||||
newsql = parsed.get_query_with_new_limit(1000)
|
newsql = parsed.set_or_update_query_limit(1000)
|
||||||
|
# not applied as new limit is higher
|
||||||
|
expected = "SELECT * FROM birth_names LIMIT 555"
|
||||||
|
self.assertEqual(newsql, expected)
|
||||||
|
|
||||||
|
def test_get_query_with_new_limit_upper(self):
|
||||||
|
sql = "SELECT * FROM birth_names LIMIT 1555"
|
||||||
|
parsed = sql_parse.ParsedQuery(sql)
|
||||||
|
newsql = parsed.set_or_update_query_limit(1000)
|
||||||
|
# applied as new limit is lower
|
||||||
expected = "SELECT * FROM birth_names LIMIT 1000"
|
expected = "SELECT * FROM birth_names LIMIT 1000"
|
||||||
self.assertEqual(newsql, expected)
|
self.assertEqual(newsql, expected)
|
||||||
|
|
||||||
|
|
|
@ -19,11 +19,12 @@
|
||||||
import json
|
import json
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from random import random
|
from random import random
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
import prison
|
import prison
|
||||||
|
|
||||||
import tests.test_app
|
import tests.test_app
|
||||||
from superset import db, security_manager
|
from superset import config, db, security_manager
|
||||||
from superset.connectors.sqla.models import SqlaTable
|
from superset.connectors.sqla.models import SqlaTable
|
||||||
from superset.dataframe import df_to_records
|
from superset.dataframe import df_to_records
|
||||||
from superset.db_engine_specs import BaseEngineSpec
|
from superset.db_engine_specs import BaseEngineSpec
|
||||||
|
@ -67,6 +68,39 @@ class SqlLabTests(SupersetTestCase):
|
||||||
data = self.run_sql("SELECT * FROM unexistant_table", "2")
|
data = self.run_sql("SELECT * FROM unexistant_table", "2")
|
||||||
self.assertLess(0, len(data["error"]))
|
self.assertLess(0, len(data["error"]))
|
||||||
|
|
||||||
|
@mock.patch(
|
||||||
|
"superset.views.core.get_cta_schema_name",
|
||||||
|
lambda d, u, s, sql: f"{u.username}_database",
|
||||||
|
)
|
||||||
|
def test_sql_json_cta_dynamic_db(self):
|
||||||
|
main_db = get_example_database()
|
||||||
|
if main_db.backend == "sqlite":
|
||||||
|
# sqlite doesn't support database creation
|
||||||
|
return
|
||||||
|
|
||||||
|
old_allow_ctas = main_db.allow_ctas
|
||||||
|
main_db.allow_ctas = True # enable cta
|
||||||
|
|
||||||
|
self.login("admin")
|
||||||
|
self.run_sql(
|
||||||
|
"SELECT * FROM birth_names",
|
||||||
|
"1",
|
||||||
|
database_name="examples",
|
||||||
|
tmp_table_name="test_target",
|
||||||
|
select_as_cta=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# assertions
|
||||||
|
data = db.session.execute("SELECT * FROM admin_database.test_target").fetchall()
|
||||||
|
self.assertEqual(
|
||||||
|
75691, len(data)
|
||||||
|
) # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True
|
||||||
|
|
||||||
|
# cleanup
|
||||||
|
db.session.execute("DROP TABLE admin_database.test_target")
|
||||||
|
main_db.allow_ctas = old_allow_ctas
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
def test_multi_sql(self):
|
def test_multi_sql(self):
|
||||||
self.login("admin")
|
self.login("admin")
|
||||||
|
|
||||||
|
|
|
@ -28,8 +28,8 @@ SUPERSET_WEBSERVER_PORT = 8081
|
||||||
if "SUPERSET__SQLALCHEMY_DATABASE_URI" in os.environ:
|
if "SUPERSET__SQLALCHEMY_DATABASE_URI" in os.environ:
|
||||||
SQLALCHEMY_DATABASE_URI = os.environ["SUPERSET__SQLALCHEMY_DATABASE_URI"]
|
SQLALCHEMY_DATABASE_URI = os.environ["SUPERSET__SQLALCHEMY_DATABASE_URI"]
|
||||||
|
|
||||||
SQL_SELECT_AS_CTA = True
|
|
||||||
SQL_MAX_ROW = 666
|
SQL_MAX_ROW = 666
|
||||||
|
SQLLAB_CTAS_NO_LIMIT = True # SQL_MAX_ROW will not take effect for the CTA queries
|
||||||
FEATURE_FLAGS = {"foo": "bar", "KV_STORE": True, "SHARE_QUERIES_VIA_KV_STORE": True}
|
FEATURE_FLAGS = {"foo": "bar", "KV_STORE": True, "SHARE_QUERIES_VIA_KV_STORE": True}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -30,8 +30,8 @@ SUPERSET_WEBSERVER_PORT = 8081
|
||||||
if "SUPERSET__SQLALCHEMY_DATABASE_URI" in os.environ:
|
if "SUPERSET__SQLALCHEMY_DATABASE_URI" in os.environ:
|
||||||
SQLALCHEMY_DATABASE_URI = os.environ["SUPERSET__SQLALCHEMY_DATABASE_URI"]
|
SQLALCHEMY_DATABASE_URI = os.environ["SUPERSET__SQLALCHEMY_DATABASE_URI"]
|
||||||
|
|
||||||
SQL_SELECT_AS_CTA = True
|
|
||||||
SQL_MAX_ROW = 666
|
SQL_MAX_ROW = 666
|
||||||
|
SQLLAB_CTAS_NO_LIMIT = True # SQL_MAX_ROW will not take effect for the CTA queries
|
||||||
FEATURE_FLAGS = {"foo": "bar"}
|
FEATURE_FLAGS = {"foo": "bar"}
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue