diff --git a/.travis.yml b/.travis.yml
index 8729632075..623107c565 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -64,8 +64,10 @@ jobs:
         - redis-server
       before_script:
         - mysql -u root -e "DROP DATABASE IF EXISTS superset; CREATE DATABASE superset DEFAULT CHARACTER SET utf8 COLLATE utf8_unicode_ci"
+        - mysql -u root -e "DROP DATABASE IF EXISTS sqllab_test_db; CREATE DATABASE sqllab_test_db DEFAULT CHARACTER SET utf8 COLLATE utf8_unicode_ci"
+        - mysql -u root -e "DROP DATABASE IF EXISTS admin_database; CREATE DATABASE admin_database DEFAULT CHARACTER SET utf8 COLLATE utf8_unicode_ci"
         - mysql -u root -e "CREATE USER 'mysqluser'@'localhost' IDENTIFIED BY 'mysqluserpassword';"
-        - mysql -u root -e "GRANT ALL ON superset.* TO 'mysqluser'@'localhost';"
+        - mysql -u root -e "GRANT ALL ON *.* TO 'mysqluser'@'localhost';"
     - language: python
       env: TOXENV=javascript
       before_install:
@@ -91,8 +93,15 @@ jobs:
         - postgresql
         - redis-server
       before_script:
+        - psql -U postgres -c "DROP DATABASE IF EXISTS superset;"
         - psql -U postgres -c "CREATE DATABASE superset;"
+        - psql -U postgres superset -c "DROP SCHEMA IF EXISTS sqllab_test_db;"
+        - psql -U postgres superset -c "CREATE SCHEMA sqllab_test_db;"
+        - psql -U postgres superset -c "DROP SCHEMA IF EXISTS admin_database;"
+        - psql -U postgres superset -c "CREATE SCHEMA admin_database;"
         - psql -U postgres -c "CREATE USER postgresuser WITH PASSWORD 'pguserpassword';"
+        - psql -U postgres superset -c "GRANT ALL PRIVILEGES ON SCHEMA sqllab_test_db to postgresuser";
+        - psql -U postgres superset -c "GRANT ALL PRIVILEGES ON SCHEMA admin_database to postgresuser";
     - language: python
       python: 3.6
       env: TOXENV=pylint
diff --git a/superset/assets/version_info.json b/superset/assets/version_info.json
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/superset/config.py b/superset/config.py
index 6168c51f13..4481a9558e 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -28,7 +28,7 @@ import os
 import sys
 from collections import OrderedDict
 from datetime import date
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
 
 from celery.schedules import crontab
 from dateutil import tz
@@ -41,6 +41,9 @@ from superset.utils.logging_configurator import DefaultLoggingConfigurator
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    from flask_appbuilder.security.sqla import models  # pylint: disable=unused-import
+    from superset.models.core import Database  # pylint: disable=unused-import
 
 # Realtime stats logger, a StatsD implementation exists
 STATS_LOGGER = DummyStatsLogger()
@@ -523,6 +526,32 @@ SQLLAB_ASYNC_TIME_LIMIT_SEC = 60 * 60 * 6
 # timeout.
 SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = 10  # seconds
 
+# Flag that controls whether a limit should be enforced on CTA ("CREATE TABLE AS") queries.
+SQLLAB_CTAS_NO_LIMIT = False
+
+# This allows you to define custom logic around the "CREATE TABLE AS" or CTAS feature
+# in SQL Lab that defines where the target schema should be for a given user.
+# The database-level `CTAS Schema` setting takes precedence over this one.
+# The example below returns the username, so CTA queries will write tables into a
+# schema named after the user:
+# SQLLAB_CTAS_SCHEMA_NAME_FUNC = lambda database, user, schema, sql: user.username
+# A more involved example, where the schema for the CTA query is assigned based on
+# the database and the data available on the user object:
+# def compute_schema_name(database: Database, user: User, schema: str, sql: str) -> str:
+#     if database.name == 'mysql_payments_slave':
+#         return 'tmp_superset_schema'
+#     if database.name == 'presto_gold':
+#         return user.username
+#     if database.name == 'analytics':
+#         if 'analytics' in [r.name for r in user.roles]:
+#             return 'analytics_cta'
+#         else:
+#             return f'tmp_{schema}'
+# The function accepts the database object, the user object, the schema name, and
+# the SQL that will be run.
+SQLLAB_CTAS_SCHEMA_NAME_FUNC: Optional[
+    Callable[["Database", "models.User", str, str], str]
+] = None
+
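For deployments picking this up, the hook is wired in `superset_config.py`. A minimal sketch under assumed names (`sql_power_user` and `etl_sandbox` are illustrative, not part of this change):

```python
# superset_config.py -- illustrative only; role and schema names are made up.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from flask_appbuilder.security.sqla.models import User
    from superset.models.core import Database


def compute_cta_schema(database: "Database", user: "User", schema: str, sql: str) -> str:
    # Power users share a scratch schema; everyone else gets a per-user one,
    # echoing the lambda example in the comments above.
    if "sql_power_user" in {role.name for role in user.roles}:
        return "etl_sandbox"
    return user.username


SQLLAB_CTAS_SCHEMA_NAME_FUNC = compute_cta_schema
```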
 # An instantiated derivative of werkzeug.contrib.cache.BaseCache
 # if enabled, it can be used to store the results of long-running queries
 # in SQL Lab by using the "Run Async" button/feature
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index ff40b2a19d..7c63e931c6 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -336,7 +336,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             return database.compile_sqla_query(qry)
         elif LimitMethod.FORCE_LIMIT:
             parsed_query = sql_parse.ParsedQuery(sql)
-            sql = parsed_query.get_query_with_new_limit(limit)
+            sql = parsed_query.set_or_update_query_limit(limit)
         return sql
 
     @classmethod
@@ -351,7 +351,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         return parsed_query.limit
 
     @classmethod
-    def get_query_with_new_limit(cls, sql: str, limit: int) -> str:
+    def set_or_update_query_limit(cls, sql: str, limit: int) -> str:
         """
         Create a query based on original query but with new limit clause
 
@@ -360,7 +360,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         :return: Query with new limit
         """
         parsed_query = sql_parse.ParsedQuery(sql)
-        return parsed_query.get_query_with_new_limit(limit)
+        return parsed_query.set_or_update_query_limit(limit)
 
     @staticmethod
     def csv_to_df(**kwargs: Any) -> pd.DataFrame:
@@ -632,10 +632,12 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         """
         Generate a "SELECT * from [schema.]table_name" query with appropriate limit.
 
+        WARNING: expects only unquoted table and schema names.
+
        :param database: Database instance
-        :param table_name: Table name
+        :param table_name: Table name, unquoted
        :param engine: SqlALchemy Engine instance
-        :param schema: Schema
+        :param schema: Schema, unquoted
        :param limit: limit to impose on query
        :param show_cols: Show columns in query; otherwise use "*"
        :param indent: Add indentation to query
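A note on the new `select_star` warning: the contract is unquoted names because quoting happens downstream. A small standalone illustration of why pre-quoted input would break, using SQLAlchemy's identifier preparer directly (not Superset API):

```python
from sqlalchemy.dialects import postgresql

preparer = postgresql.dialect().identifier_preparer

# Unquoted inputs are quoted at most once, and only when needed...
print(preparer.quote_schema("sqllab_test_db"), preparer.quote("tmp_async_22"))
# -> sqllab_test_db tmp_async_22

# ...while a name that arrives pre-quoted gets quoted again and mangled:
print(preparer.quote('"tmp_async_22"'))
# -> """tmp_async_22"""
```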
diff --git a/superset/migrations/versions/72428d1ea401_add_tmp_schema_name_to_the_query_object.py b/superset/migrations/versions/72428d1ea401_add_tmp_schema_name_to_the_query_object.py
new file mode 100644
index 0000000000..d50db62b52
--- /dev/null
+++ b/superset/migrations/versions/72428d1ea401_add_tmp_schema_name_to_the_query_object.py
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Add tmp_schema_name to the query object.
+
+Revision ID: 72428d1ea401
+Revises: 0a6f12f60c73
+Create Date: 2020-02-20 08:52:22.877902
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = "72428d1ea401"
+down_revision = "0a6f12f60c73"
+
+import sqlalchemy as sa
+from alembic import op
+
+
+def upgrade():
+    op.add_column(
+        "query", sa.Column("tmp_schema_name", sa.String(length=256), nullable=True)
+    )
+
+
+def downgrade():
+    try:
+        # sqlite doesn't like dropping the columns
+        op.drop_column("query", "tmp_schema_name")
+    except Exception:
+        pass
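The bare `except Exception` in `downgrade()` papers over SQLite's lack of `DROP COLUMN`. If a later cleanup wants the drop to succeed everywhere, Alembic's batch mode is the usual route; a hedged alternative sketch, not part of this migration:

```python
from alembic import op


def downgrade():
    # Batch mode recreates the table on SQLite, so the column drop works
    # everywhere; on other backends it degrades to a plain ALTER TABLE.
    with op.batch_alter_table("query") as batch_op:
        batch_op.drop_column("tmp_schema_name")
```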
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index 820ded8065..3dad0da31c 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -55,6 +55,7 @@ class Query(Model, ExtraJSONMixin):
     # Store the tmp table into the DB only if the user asks for it.
     tmp_table_name = Column(String(256))
+    tmp_schema_name = Column(String(256))
     user_id = Column(Integer, ForeignKey("ab_user.id"), nullable=True)
     status = Column(String(16), default=QueryStatus.PENDING)
     tab_name = Column(String(256))
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 58c61fc72d..01dd2e5bb1 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -59,6 +59,7 @@ stats_logger = config["STATS_LOGGER"]
 SQLLAB_TIMEOUT = config["SQLLAB_ASYNC_TIME_LIMIT_SEC"]
 SQLLAB_HARD_TIMEOUT = SQLLAB_TIMEOUT + 60
 SQL_MAX_ROW = config["SQL_MAX_ROW"]
+SQLLAB_CTAS_NO_LIMIT = config["SQLLAB_CTAS_NO_LIMIT"]
 SQL_QUERY_MUTATOR = config["SQL_QUERY_MUTATOR"]
 log_query = config["QUERY_LOGGER"]
 logger = logging.getLogger(__name__)
@@ -207,9 +208,15 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor, log_
         query.tmp_table_name = "tmp_{}_table_{}".format(
             query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S")
         )
-        sql = parsed_query.as_create_table(query.tmp_table_name)
+        sql = parsed_query.as_create_table(
+            query.tmp_table_name, schema_name=query.tmp_schema_name
+        )
         query.select_as_cta_used = True
-    if parsed_query.is_select():
+
+    # Do not apply a limit to CTA queries when SQLLAB_CTAS_NO_LIMIT is set to True
+    if parsed_query.is_select() and not (
+        query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT
+    ):
         if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
             query.limit = SQL_MAX_ROW
         if query.limit:
@@ -378,6 +385,9 @@ def execute_sql_statements(
                 payload = handle_query_error(msg, query, session, payload)
                 return payload
 
+        # Commit the connection so CTA queries will create the table.
+        conn.commit()
+
         # Success, updating the query entry in database
         query.rows = result_set.size
         query.progress = 100
@@ -385,8 +395,8 @@ def execute_sql_statements(
     if query.select_as_cta:
         query.select_sql = database.select_star(
             query.tmp_table_name,
+            schema=query.tmp_schema_name,
             limit=query.limit,
-            schema=database.force_ctas_schema,
             show_cols=False,
             latest_partition=False,
         )
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 3dc7ec9dd8..35c6e9f8ce 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -151,20 +151,28 @@ class ParsedQuery:
             self._alias_names.add(token_list.tokens[0].value)
         self.__extract_from_token(token_list)
 
-    def as_create_table(self, table_name: str, overwrite: bool = False) -> str:
+    def as_create_table(
+        self,
+        table_name: str,
+        schema_name: Optional[str] = None,
+        overwrite: bool = False,
+    ) -> str:
         """Reformats the query into the create table as query.
 
         Works only for the single select SQL statements, in all other cases
         the sql query is not modified.
-        :param table_name: Table that will contain the results of the query execution
+        :param table_name: table that will contain the results of the query execution
+        :param schema_name: schema name for the target table
         :param overwrite: table_name will be dropped if true
         :return: Create table as query
         """
         exec_sql = ""
         sql = self.stripped()
+        # TODO(bkyryliuk): quote full_table_name
+        full_table_name = f"{schema_name}.{table_name}" if schema_name else table_name
         if overwrite:
-            exec_sql = f"DROP TABLE IF EXISTS {table_name};\n"
-        exec_sql += f"CREATE TABLE {table_name} AS \n{sql}"
+            exec_sql = f"DROP TABLE IF EXISTS {full_table_name};\n"
+        exec_sql += f"CREATE TABLE {full_table_name} AS \n{sql}"
         return exec_sql
 
     def __extract_from_token(self, token: Token):  # pylint: disable=too-many-branches
@@ -205,10 +213,12 @@ class ParsedQuery:
                 if not self.__is_identifier(token2):
                     self.__extract_from_token(item)
 
-    def get_query_with_new_limit(self, new_limit: int) -> str:
-        """
-        returns the query with the specified limit.
-        Does not change the underlying query
+    def set_or_update_query_limit(self, new_limit: int) -> str:
+        """Returns the query with the specified limit.
+
+        Does not change the underlying query if the user did not apply a limit;
+        otherwise replaces the existing limit with the lower of the existing
+        value and new_limit.
 
         :param new_limit: Limit to be incorporated into returned query
         :return: The original query with new limit
@@ -223,7 +233,10 @@ class ParsedQuery:
                 limit_pos = pos
                 break
         _, limit = statement.token_next(idx=limit_pos)
-        if limit.ttype == sqlparse.tokens.Literal.Number.Integer:
+        # Override the limit only when it exceeds the configured value.
+        if limit.ttype == sqlparse.tokens.Literal.Number.Integer and new_limit < int(
+            limit.value
+        ):
             limit.value = new_limit
         elif limit.is_group:
             limit.value = f"{next(limit.get_identifiers())}, {new_limit}"
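The net effect of the two `ParsedQuery` changes: CTAS statements can target a schema, and an existing LIMIT is only ever lowered, never raised. A REPL-style sketch (outputs inferred from this diff and the tests below; fresh instances are used because the method mutates the parsed tokens):

```python
from superset.sql_parse import ParsedQuery

# A lower limit replaces the existing one...
q1 = ParsedQuery("SELECT name FROM birth_names LIMIT 555")
print(q1.set_or_update_query_limit(100))   # SELECT name FROM birth_names LIMIT 100

# ...while a higher one leaves the query untouched.
q2 = ParsedQuery("SELECT name FROM birth_names LIMIT 555")
print(q2.set_or_update_query_limit(1000))  # SELECT name FROM birth_names LIMIT 555

# schema_name is prepended to the target table (unquoted, per the TODO above).
print(ParsedQuery("SELECT 1").as_create_table("t", schema_name="tmp"))
# CREATE TABLE tmp.t AS
# SELECT 1
```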
diff --git a/superset/views/core.py b/superset/views/core.py
index c249187a13..4ecc097289 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -19,7 +19,7 @@ import logging
 import re
 from contextlib import closing
 from datetime import datetime, timedelta
-from typing import Any, cast, Dict, List, Optional, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Union
 from urllib import parse
 
 import backoff
@@ -73,6 +73,7 @@ from superset.exceptions import (
     SupersetTimeoutException,
 )
 from superset.jinja_context import get_template_processor
+from superset.models.core import Database
 from superset.models.dashboard import Dashboard
 from superset.models.datasource_access_request import DatasourceAccessRequest
 from superset.models.slice import Slice
@@ -247,6 +248,17 @@ def _deserialize_results_payload(
         return json.loads(payload)  # type: ignore
 
 
+def get_cta_schema_name(
+    database: Database, user: ab_models.User, schema: str, sql: str
+) -> Optional[str]:
+    func: Optional[Callable[[Database, ab_models.User, str, str], str]] = config[
+        "SQLLAB_CTAS_SCHEMA_NAME_FUNC"
+    ]
+    if not func:
+        return None
+    return func(database, user, schema, sql)
+
+
 class AccessRequestsModelView(SupersetModelView, DeleteMixin):
     datamodel = SQLAInterface(DAR)
     include_route_methods = RouteMethod.CRUD_SET
@@ -2351,9 +2363,14 @@ class Superset(BaseSupersetView):
         if not mydb:
             return json_error_response(f"Database with id {database_id} is missing.")
 
-        # Set tmp_table_name for CTA
+        # Set tmp_schema_name for CTA
+        # TODO(bkyryliuk): consider parsing, splitting tmp_schema_name from tmp_table_name if user enters
+        # <schema_name>.<table_name>
+        tmp_schema_name: Optional[str] = schema
         if select_as_cta and mydb.force_ctas_schema:
-            tmp_table_name = f"{mydb.force_ctas_schema}.{tmp_table_name}"
+            tmp_schema_name = mydb.force_ctas_schema
+        elif select_as_cta:
+            tmp_schema_name = get_cta_schema_name(mydb, g.user, schema, sql)
 
         # Save current query
         query = Query(
@@ -2366,6 +2383,7 @@ class Superset(BaseSupersetView):
             status=status,
             sql_editor_id=sql_editor_id,
             tmp_table_name=tmp_table_name,
+            tmp_schema_name=tmp_schema_name,
             user_id=g.user.get_id() if g.user else None,
             client_id=client_id,
         )
@@ -2406,9 +2424,11 @@ class Superset(BaseSupersetView):
                 f"Query {query_id}: Template rendering failed: {error_msg}"
             )
 
-        # set LIMIT after template processing
-        limits = [mydb.db_engine_spec.get_limit_from_sql(rendered_query), limit]
-        query.limit = min(lim for lim in limits if lim is not None)
+        # The limit is not applied to CTA queries when the SQLLAB_CTAS_NO_LIMIT flag is set to True.
+        if not (config.get("SQLLAB_CTAS_NO_LIMIT") and select_as_cta):
+            # set LIMIT after template processing
+            limits = [mydb.db_engine_spec.get_limit_from_sql(rendered_query), limit]
+            query.limit = min(lim for lim in limits if lim is not None)
 
         # Flag for whether or not to expand data
         # (feature that will expand Presto row objects and arrays)
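Precedence in `sql_json`: the database's `force_ctas_schema` wins, then the config hook, then the schema already selected in the tab. Distilled into a standalone helper for illustration (`resolve_tmp_schema` is hypothetical, mirroring the branch above):

```python
from typing import Optional


def resolve_tmp_schema(
    select_as_cta: bool,
    force_ctas_schema: Optional[str],
    hook_schema: Optional[str],
    current_schema: Optional[str],
) -> Optional[str]:
    # Mirrors the branch added to sql_json/: the database-level setting wins,
    # then the SQLLAB_CTAS_SCHEMA_NAME_FUNC result (which may be None when no
    # hook is configured); non-CTA queries keep the tab's schema.
    tmp_schema_name = current_schema
    if select_as_cta and force_ctas_schema:
        tmp_schema_name = force_ctas_schema
    elif select_as_cta:
        tmp_schema_name = hook_schema
    return tmp_schema_name


assert resolve_tmp_schema(True, "admin_database", "user_db", "public") == "admin_database"
assert resolve_tmp_schema(True, None, "user_db", "public") == "user_db"
assert resolve_tmp_schema(False, None, None, "public") == "public"
```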
diff --git a/tests/base_tests.py b/tests/base_tests.py
index 5728dd8868..09b183d98d 100644
--- a/tests/base_tests.py
+++ b/tests/base_tests.py
@@ -229,22 +229,27 @@ class SupersetTestCase(TestCase):
         query_limit=None,
         database_name="examples",
         sql_editor_id=None,
+        select_as_cta=False,
+        tmp_table_name=None,
     ):
         if user_name:
             self.logout()
             self.login(username=(user_name or "admin"))
         dbid = self._get_database_by_name(database_name).id
+        json_payload = {
+            "database_id": dbid,
+            "sql": sql,
+            "client_id": client_id,
+            "queryLimit": query_limit,
+            "sql_editor_id": sql_editor_id,
+        }
+        if tmp_table_name:
+            json_payload["tmp_table_name"] = tmp_table_name
+        if select_as_cta:
+            json_payload["select_as_cta"] = select_as_cta
+
         resp = self.get_json_resp(
-            "/superset/sql_json/",
-            raise_on_error=False,
-            json_=dict(
-                database_id=dbid,
-                sql=sql,
-                select_as_create_as=False,
-                client_id=client_id,
-                queryLimit=query_limit,
-                sql_editor_id=sql_editor_id,
-            ),
+            "/superset/sql_json/", raise_on_error=False, json_=json_payload
         )
         if raise_on_error and "error" in resp:
             raise Exception("run_sql failed")
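With the new keyword arguments, downstream tests can hit the CTA path without hand-building the payload. A sketch of a caller (test class and table name are illustrative; the empty-`data` assertion mirrors the celery tests below):

```python
from tests.base_tests import SupersetTestCase


class CtasSmokeTest(SupersetTestCase):  # hypothetical test, not part of this PR
    def test_ctas_roundtrip(self):
        self.login("admin")
        # select_as_cta / tmp_table_name flow straight into the sql_json
        # payload, so the endpoint issues CREATE TABLE AS for us.
        result = self.run_sql(
            "SELECT name FROM birth_names LIMIT 10",
            "ctas_smoke_1",
            select_as_cta=True,
            tmp_table_name="tmp_ctas_smoke",
        )
        # CTA responses carry no rows; the data lands in tmp_ctas_smoke.
        self.assertEqual([], result["data"])
```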
diff --git a/tests/celery_tests.py b/tests/celery_tests.py
index f89e7871eb..07950c11e3 100644
--- a/tests/celery_tests.py
+++ b/tests/celery_tests.py
@@ -17,14 +17,20 @@
 # isort:skip_file
 """Unit tests for Superset Celery worker"""
 import datetime
+import io
 import json
+import logging
 import subprocess
 import time
 import unittest
 import unittest.mock as mock
 
 import flask
+import sqlalchemy
+from contextlib2 import contextmanager
 from flask import current_app
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.pool import NullPool
 
 from tests.test_app import app
 from superset import db, sql_lab
@@ -38,11 +44,11 @@ from superset.utils.core import get_example_database
 
 from .base_tests import SupersetTestCase
 
+CELERY_SHORT_SLEEP_TIME = 2
 CELERY_SLEEP_TIME = 5
 
 
 class UtilityFunctionTests(SupersetTestCase):
-
     # TODO(bkyryliuk): support more cases in CTA function.
     def test_create_table_as(self):
         q = ParsedQuery("SELECT * FROM outer_space;")
@@ -90,6 +96,9 @@ class AppContextTests(SupersetTestCase):
         flask._app_ctx_stack.push(popped_app)
 
 
+CTAS_SCHEMA_NAME = "sqllab_test_db"
+
+
 class CeleryTestCase(SupersetTestCase):
     def get_query_by_name(self, sql):
         session = db.session
@@ -159,7 +168,6 @@ class CeleryTestCase(SupersetTestCase):
 
     def test_run_sync_query_cta(self):
         main_db = get_example_database()
-        backend = main_db.backend
         db_id = main_db.id
         tmp_table_name = "tmp_async_22"
         self.drop_table_if_exists(tmp_table_name, main_db)
@@ -172,11 +180,12 @@ class CeleryTestCase(SupersetTestCase):
         query2 = self.get_query_by_id(result["query"]["serverId"])
 
         # Check the data in the tmp table.
-        if backend != "postgresql":
-            # TODO This test won't work in Postgres
-            results = self.run_sql(db_id, query2.select_sql, "sdf2134")
-            self.assertEqual(results["status"], "success")
-            self.assertGreater(len(results["data"]), 0)
+        results = self.run_sql(db_id, query2.select_sql, "sdf2134")
+        self.assertEqual(results["status"], "success")
+        self.assertGreater(len(results["data"]), 0)
+
+        # cleanup tmp table
+        self.drop_table_if_exists(tmp_table_name, get_example_database())
 
     def test_run_sync_query_cta_no_data(self):
         main_db = get_example_database()
@@ -199,15 +208,89 @@ class CeleryTestCase(SupersetTestCase):
         db.session.flush()
         return self.run_sql(db_id, sql)
 
-    def test_run_async_query(self):
+    @mock.patch(
+        "superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME
+    )
+    def test_run_sync_query_cta_config(self):
+        main_db = get_example_database()
+        db_id = main_db.id
+        if main_db.backend == "sqlite":
+            # sqlite doesn't support schemas
+            return
+        tmp_table_name = "tmp_async_22"
+        expected_full_table_name = f"{CTAS_SCHEMA_NAME}.{tmp_table_name}"
+        self.drop_table_if_exists(expected_full_table_name, main_db)
+        name = "James"
+        sql_where = f"SELECT name FROM birth_names WHERE name='{name}'"
+        result = self.run_sql(
+            db_id, sql_where, "cid2", tmp_table=tmp_table_name, cta=True
+        )
+
+        self.assertEqual(QueryStatus.SUCCESS, result["query"]["state"])
+        self.assertEqual([], result["data"])
+        self.assertEqual([], result["columns"])
+        query = self.get_query_by_id(result["query"]["serverId"])
+        self.assertEqual(
+            f"CREATE TABLE {expected_full_table_name} AS \n"
+            "SELECT name FROM birth_names "
+            "WHERE name='James'",
+            query.executed_sql,
+        )
+        self.assertEqual(
+            "SELECT *\n" f"FROM {expected_full_table_name}", query.select_sql
+        )
+        time.sleep(CELERY_SHORT_SLEEP_TIME)
+        results = self.run_sql(db_id, query.select_sql)
+        self.assertEqual(results["status"], "success")
+        self.drop_table_if_exists(expected_full_table_name, get_example_database())
+
+    @mock.patch(
+        "superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME
+    )
+    def test_run_async_query_cta_config(self):
+        main_db = get_example_database()
+        db_id = main_db.id
+        if main_db.backend == "sqlite":
+            # sqlite doesn't support schemas
+            return
+        tmp_table_name = "sqllab_test_table_async_1"
+        expected_full_table_name = f"{CTAS_SCHEMA_NAME}.{tmp_table_name}"
+        self.drop_table_if_exists(expected_full_table_name, main_db)
+        sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10"
+        result = self.run_sql(
+            db_id,
+            sql_where,
+            "cid3",
+            async_=True,
+            tmp_table="sqllab_test_table_async_1",
+            cta=True,
+        )
+        db.session.close()
+        time.sleep(CELERY_SLEEP_TIME)
+
+        query = self.get_query_by_id(result["query"]["serverId"])
+        self.assertEqual(QueryStatus.SUCCESS, query.status)
+        self.assertTrue(f"FROM {expected_full_table_name}" in query.select_sql)
+        self.assertEqual(
+            f"CREATE TABLE {expected_full_table_name} AS \n"
+            "SELECT name FROM birth_names "
+            "WHERE name='James' "
+            "LIMIT 10",
+            query.executed_sql,
+        )
+        self.drop_table_if_exists(expected_full_table_name, get_example_database())
+
+    def test_run_async_cta_query(self):
         main_db = get_example_database()
         db_id = main_db.id
-        self.drop_table_if_exists("tmp_async_1", main_db)
+        table_name = "tmp_async_4"
+        self.drop_table_if_exists(table_name, main_db)
+        time.sleep(CELERY_SLEEP_TIME)
 
         sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10"
         result = self.run_sql(
-            db_id, sql_where, "4", async_=True, tmp_table="tmp_async_1", cta=True
tmp_table="tmp_async_1", cta=True + db_id, sql_where, "cid4", async_=True, tmp_table="tmp_async_4", cta=True ) db.session.close() assert result["query"]["state"] in ( @@ -221,9 +304,9 @@ class CeleryTestCase(SupersetTestCase): query = self.get_query_by_id(result["query"]["serverId"]) self.assertEqual(QueryStatus.SUCCESS, query.status) - self.assertTrue("FROM tmp_async_1" in query.select_sql) + self.assertTrue(f"FROM {table_name}" in query.select_sql) self.assertEqual( - "CREATE TABLE tmp_async_1 AS \n" + f"CREATE TABLE {table_name} AS \n" "SELECT name FROM birth_names " "WHERE name='James' " "LIMIT 10", @@ -234,7 +317,7 @@ class CeleryTestCase(SupersetTestCase): self.assertEqual(True, query.select_as_cta) self.assertEqual(True, query.select_as_cta_used) - def test_run_async_query_with_lower_limit(self): + def test_run_async_cta_query_with_lower_limit(self): main_db = get_example_database() db_id = main_db.id tmp_table = "tmp_async_2" @@ -242,7 +325,7 @@ class CeleryTestCase(SupersetTestCase): sql_where = "SELECT name FROM birth_names LIMIT 1" result = self.run_sql( - db_id, sql_where, "5", async_=True, tmp_table=tmp_table, cta=True + db_id, sql_where, "id1", async_=True, tmp_table=tmp_table, cta=True ) db.session.close() assert result["query"]["state"] in ( @@ -255,14 +338,15 @@ class CeleryTestCase(SupersetTestCase): query = self.get_query_by_id(result["query"]["serverId"]) self.assertEqual(QueryStatus.SUCCESS, query.status) - self.assertTrue(f"FROM {tmp_table}" in query.select_sql) + + self.assertIn(f"FROM {tmp_table}", query.select_sql) self.assertEqual( f"CREATE TABLE {tmp_table} AS \n" "SELECT name FROM birth_names LIMIT 1", query.executed_sql, ) self.assertEqual(sql_where, query.sql) self.assertEqual(0, query.rows) - self.assertEqual(1, query.limit) + self.assertEqual(None, query.limit) self.assertEqual(True, query.select_as_cta) self.assertEqual(True, query.select_as_cta_used) @@ -280,9 +364,12 @@ class CeleryTestCase(SupersetTestCase): with mock.patch.object( db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data ) as expand_data: - data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data( - results, db_engine_spec, False, True - ) + ( + data, + selected_columns, + all_columns, + expanded_columns, + ) = sql_lab._serialize_and_expand_data(results, db_engine_spec, False, True) expand_data.assert_called_once() self.assertIsInstance(data, list) @@ -301,9 +388,12 @@ class CeleryTestCase(SupersetTestCase): with mock.patch.object( db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data ) as expand_data: - data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data( - results, db_engine_spec, True - ) + ( + data, + selected_columns, + all_columns, + expanded_columns, + ) = sql_lab._serialize_and_expand_data(results, db_engine_spec, True) expand_data.assert_not_called() self.assertIsInstance(data, bytes) @@ -324,7 +414,12 @@ class CeleryTestCase(SupersetTestCase): "sql": "SELECT * FROM birth_names LIMIT 100", "status": QueryStatus.PENDING, } - serialized_data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data( + ( + serialized_data, + selected_columns, + all_columns, + expanded_columns, + ) = sql_lab._serialize_and_expand_data( results, db_engine_spec, use_new_deserialization ) payload = { @@ -357,7 +452,12 @@ class CeleryTestCase(SupersetTestCase): "sql": "SELECT * FROM birth_names LIMIT 100", "status": QueryStatus.PENDING, } - serialized_data, selected_columns, all_columns, 
+        (
+            serialized_data,
+            selected_columns,
+            all_columns,
+            expanded_columns,
+        ) = sql_lab._serialize_and_expand_data(
             results, db_engine_spec, use_new_deserialization
         )
         payload = {
diff --git a/tests/sql_parse_tests.py b/tests/sql_parse_tests.py
index 2991b47d57..46e54ffaff 100644
--- a/tests/sql_parse_tests.py
+++ b/tests/sql_parse_tests.py
@@ -451,19 +451,28 @@ class SupersetTestCase(unittest.TestCase):
     def test_get_query_with_new_limit_comment(self):
         sql = "SELECT * FROM birth_names -- SOME COMMENT"
         parsed = sql_parse.ParsedQuery(sql)
-        newsql = parsed.get_query_with_new_limit(1000)
+        newsql = parsed.set_or_update_query_limit(1000)
         self.assertEqual(newsql, sql + "\nLIMIT 1000")
 
     def test_get_query_with_new_limit_comment_with_limit(self):
         sql = "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555"
         parsed = sql_parse.ParsedQuery(sql)
-        newsql = parsed.get_query_with_new_limit(1000)
+        newsql = parsed.set_or_update_query_limit(1000)
         self.assertEqual(newsql, sql + "\nLIMIT 1000")
 
-    def test_get_query_with_new_limit(self):
+    def test_get_query_with_new_limit_lower(self):
         sql = "SELECT * FROM birth_names LIMIT 555"
         parsed = sql_parse.ParsedQuery(sql)
-        newsql = parsed.get_query_with_new_limit(1000)
+        newsql = parsed.set_or_update_query_limit(1000)
+        # not applied as new limit is higher
+        expected = "SELECT * FROM birth_names LIMIT 555"
+        self.assertEqual(newsql, expected)
+
+    def test_get_query_with_new_limit_upper(self):
+        sql = "SELECT * FROM birth_names LIMIT 1555"
+        parsed = sql_parse.ParsedQuery(sql)
+        newsql = parsed.set_or_update_query_limit(1000)
+        # applied as new limit is lower
         expected = "SELECT * FROM birth_names LIMIT 1000"
         self.assertEqual(newsql, expected)
 
diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py
index 27be7f1578..ad130a8ee2 100644
--- a/tests/sqllab_tests.py
+++ b/tests/sqllab_tests.py
@@ -19,11 +19,12 @@ import json
 from datetime import datetime, timedelta
 from random import random
+from unittest import mock
 
 import prison
 import tests.test_app
-from superset import db, security_manager
+from superset import config, db, security_manager
 from superset.connectors.sqla.models import SqlaTable
 from superset.dataframe import df_to_records
 from superset.db_engine_specs import BaseEngineSpec
@@ -67,6 +68,39 @@ class SqlLabTests(SupersetTestCase):
         data = self.run_sql("SELECT * FROM unexistant_table", "2")
         self.assertLess(0, len(data["error"]))
 
+    @mock.patch(
+        "superset.views.core.get_cta_schema_name",
+        lambda d, u, s, sql: f"{u.username}_database",
+    )
+    def test_sql_json_cta_dynamic_db(self):
+        main_db = get_example_database()
+        if main_db.backend == "sqlite":
+            # sqlite doesn't support database creation
+            return
+
+        old_allow_ctas = main_db.allow_ctas
+        main_db.allow_ctas = True  # enable cta
+
+        self.login("admin")
+        self.run_sql(
+            "SELECT * FROM birth_names",
+            "1",
+            database_name="examples",
+            tmp_table_name="test_target",
+            select_as_cta=True,
+        )
+
+        # assertions
+        data = db.session.execute("SELECT * FROM admin_database.test_target").fetchall()
+        self.assertEqual(
+            75691, len(data)
+        )  # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True
+
+        # cleanup
+        db.session.execute("DROP TABLE admin_database.test_target")
+        main_db.allow_ctas = old_allow_ctas
+        db.session.commit()
+
     def test_multi_sql(self):
         self.login("admin")
 
diff --git a/tests/superset_test_config.py b/tests/superset_test_config.py
index e46df4619f..d95b91a089 100644
--- a/tests/superset_test_config.py
+++ b/tests/superset_test_config.py
@@ -28,8 +28,8 @@ SUPERSET_WEBSERVER_PORT = 8081
 if "SUPERSET__SQLALCHEMY_DATABASE_URI" in os.environ:
     SQLALCHEMY_DATABASE_URI = os.environ["SUPERSET__SQLALCHEMY_DATABASE_URI"]
 
-SQL_SELECT_AS_CTA = True
 SQL_MAX_ROW = 666
+SQLLAB_CTAS_NO_LIMIT = True  # SQL_MAX_ROW will not take effect for the CTA queries
 
 FEATURE_FLAGS = {"foo": "bar", "KV_STORE": True, "SHARE_QUERIES_VIA_KV_STORE": True}
diff --git a/tests/superset_test_config_sqllab_backend_persist.py b/tests/superset_test_config_sqllab_backend_persist.py
index ace73b85b8..86619a2ff7 100644
--- a/tests/superset_test_config_sqllab_backend_persist.py
+++ b/tests/superset_test_config_sqllab_backend_persist.py
@@ -30,8 +30,8 @@ SUPERSET_WEBSERVER_PORT = 8081
 if "SUPERSET__SQLALCHEMY_DATABASE_URI" in os.environ:
     SQLALCHEMY_DATABASE_URI = os.environ["SUPERSET__SQLALCHEMY_DATABASE_URI"]
 
-SQL_SELECT_AS_CTA = True
 SQL_MAX_ROW = 666
+SQLLAB_CTAS_NO_LIMIT = True  # SQL_MAX_ROW will not take effect for the CTA queries
 
 FEATURE_FLAGS = {"foo": "bar"}
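Putting the test configs together: with `SQLLAB_CTAS_NO_LIMIT = True`, `SQL_MAX_ROW = 666` is bypassed for CTA runs, which is what lets `test_sql_json_cta_dynamic_db` above materialize all 75691 `birth_names` rows. A sketch of the statement shape that reaches the engine under these settings:

```python
from superset.sql_parse import ParsedQuery

# No LIMIT wrapping happens for CTA when SQLLAB_CTAS_NO_LIMIT is set, so the
# rewritten statement is just CREATE TABLE ... AS over the untouched SELECT.
sql = "SELECT * FROM birth_names"
print(ParsedQuery(sql).as_create_table("test_target", schema_name="admin_database"))
# CREATE TABLE admin_database.test_target AS
# SELECT * FROM birth_names
```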