From 3db76c6fdc8e74419e61dc99542cd6f2c81cb23a Mon Sep 17 00:00:00 2001 From: Bogdan Date: Wed, 24 Jun 2020 09:50:41 -0700 Subject: [PATCH] Implement create view as functionality (#9794) Implement create view as button in sqllab Make CVAS configurable Co-authored-by: bogdan kyryliuk --- UPDATING.md | 4 +- requirements-dev.txt | 1 + .../src/SqlLab/actions/sqlLab.js | 6 + .../src/SqlLab/components/ResultSet.jsx | 7 +- .../src/SqlLab/components/SqlEditor.jsx | 44 +++- .../ea396d202291_ctas_method_in_query.py | 42 +++ superset/models/core.py | 2 + superset/models/sql_lab.py | 3 +- superset/sql_lab.py | 4 +- superset/sql_parse.py | 12 +- superset/views/core.py | 7 +- superset/views/database/api.py | 1 + superset/views/database/mixins.py | 3 + tests/base_tests.py | 3 + tests/celery_tests.py | 245 +++++++++++------- tests/database_api_tests.py | 1 + tests/sqllab_tests.py | 56 ++-- 17 files changed, 304 insertions(+), 137 deletions(-) create mode 100644 superset/migrations/versions/ea396d202291_ctas_method_in_query.py diff --git a/UPDATING.md b/UPDATING.md index f5a3644b84..b00f157a5f 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -32,7 +32,9 @@ assists people when migrating to a new version. * [9786](https://github.com/apache/incubator-superset/pull/9786): with the upgrade of `werkzeug` from version `0.16.0` to `1.0.1`, the `werkzeug.contrib.cache` module has been moved to a standalone package [cachelib](https://pypi.org/project/cachelib/). For example, to import the `RedisCache` class, please use the following import: `from cachelib.redis import RedisCache`. -* [9572](https://github.com/apache/incubator-superset/pull/9572): a change which by defau;t means that the Jinja `current_user_id`, `current_username`, and `url_param` context calls no longer need to be wrapped via `cache_key_wrapper` in order to be included in the cache key. The `cache_key_wrapper` function should only be required for Jinja add-ons. +* [9794](https://github.com/apache/incubator-superset/pull/9794): introduces `create view as` functionality in the sqllab. This change will require the `query` table migration and potential service downtime as that table has quite some traffic. + +* [9572](https://github.com/apache/incubator-superset/pull/9572): a change which by default means that the Jinja `current_user_id`, `current_username`, and `url_param` context calls no longer need to be wrapped via `cache_key_wrapper` in order to be included in the cache key. The `cache_key_wrapper` function should only be required for Jinja add-ons. * [8867](https://github.com/apache/incubator-superset/pull/8867): a change which adds the `tmp_schema_name` column to the `query` table which requires locking the table. Given the `query` table is heavily used performance may be degraded during the migration. Scheduled downtime may be advised. diff --git a/requirements-dev.txt b/requirements-dev.txt index 5325ea698a..3687aff937 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -22,6 +22,7 @@ ipdb==0.12 isort==4.3.21 mypy==0.770 nose==1.3.7 +parameterized==0.7.4 pip-tools==5.1.2 pre-commit==1.17.0 psycopg2-binary==2.8.5 diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.js b/superset-frontend/src/SqlLab/actions/sqlLab.js index 778dcb3b95..122d9ecf67 100644 --- a/superset-frontend/src/SqlLab/actions/sqlLab.js +++ b/superset-frontend/src/SqlLab/actions/sqlLab.js @@ -97,6 +97,11 @@ export const addSuccessToast = addSuccessToastAction; export const addDangerToast = addDangerToastAction; export const addWarningToast = addWarningToastAction; +export const CtasEnum = { + TABLE: 'TABLE', + VIEW: 'VIEW', +}; + // a map of SavedQuery field names to the different names used client-side, // because for now making the names consistent is too complicated // so it might as well only happen in one place @@ -346,6 +351,7 @@ export function runQuery(query) { tab: query.tab, tmp_table_name: query.tempTableName, select_as_cta: query.ctas, + ctas_method: query.ctas_method, templateParams: query.templateParams, queryLimit: query.queryLimit, expand_data: true, diff --git a/superset-frontend/src/SqlLab/components/ResultSet.jsx b/superset-frontend/src/SqlLab/components/ResultSet.jsx index f3f0ea89d2..5ccc089544 100644 --- a/superset-frontend/src/SqlLab/components/ResultSet.jsx +++ b/superset-frontend/src/SqlLab/components/ResultSet.jsx @@ -30,6 +30,7 @@ import FilterableTable from '../../components/FilterableTable/FilterableTable'; import QueryStateLabel from './QueryStateLabel'; import CopyToClipboard from '../../components/CopyToClipboard'; import { prepareCopyToClipboardTabularData } from '../../utils/common'; +import { CtasEnum } from '../actions/sqlLab'; const propTypes = { actions: PropTypes.object, @@ -219,10 +220,14 @@ export default class ResultSet extends React.PureComponent { tmpTable = query.results.query.tempTable; tmpSchema = query.results.query.tempSchema; } + let object = 'Table'; + if (query.ctas_method === CtasEnum.VIEW) { + object = 'View'; + } return (
- {t('Table')} [ + {t(object)} [ {tmpSchema}.{tmpTable} diff --git a/superset-frontend/src/SqlLab/components/SqlEditor.jsx b/superset-frontend/src/SqlLab/components/SqlEditor.jsx index 7f1123c132..54f729cacc 100644 --- a/superset-frontend/src/SqlLab/components/SqlEditor.jsx +++ b/superset-frontend/src/SqlLab/components/SqlEditor.jsx @@ -54,6 +54,7 @@ import { } from '../constants'; import RunQueryActionButton from './RunQueryActionButton'; import { FeatureFlag, isFeatureEnabled } from '../../featureFlags'; +import { CtasEnum } from '../actions/sqlLab'; const SQL_EDITOR_PADDING = 10; const INITIAL_NORTH_PERCENT = 30; @@ -284,7 +285,7 @@ class SqlEditor extends React.PureComponent { this.startQuery(); } } - startQuery(ctas = false) { + startQuery(ctas = false, ctas_method = CtasEnum.TABLE) { const qe = this.props.queryEditor; const query = { dbId: qe.dbId, @@ -299,6 +300,7 @@ class SqlEditor extends React.PureComponent { ? this.props.database.allow_run_async : false, ctas, + ctas_method, updateTabState: !qe.selectedText, }; this.props.actions.runQuery(query); @@ -313,7 +315,10 @@ class SqlEditor extends React.PureComponent { } } createTableAs() { - this.startQuery(true); + this.startQuery(true, CtasEnum.TABLE); + } + createViewAs() { + this.startQuery(true, CtasEnum.VIEW); } ctasChanged(event) { this.setState({ ctas: event.target.value }); @@ -372,8 +377,13 @@ class SqlEditor extends React.PureComponent { } renderEditorBottomBar(hotkeys) { let ctasControls; - if (this.props.database && this.props.database.allow_ctas) { + if ( + this.props.database && + (this.props.database.allow_ctas || this.props.database.allow_cvas) + ) { const ctasToolTip = t('Create table as with query results'); + const cvasToolTip = t('Create view as with query results'); + ctasControls = ( @@ -385,14 +395,26 @@ class SqlEditor extends React.PureComponent { onChange={this.ctasChanged.bind(this)} /> - + {this.props.database.allow_ctas && ( + + )} + {this.props.database.allow_cvas && ( + + )} diff --git a/superset/migrations/versions/ea396d202291_ctas_method_in_query.py b/superset/migrations/versions/ea396d202291_ctas_method_in_query.py new file mode 100644 index 0000000000..6dd0b24cfa --- /dev/null +++ b/superset/migrations/versions/ea396d202291_ctas_method_in_query.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Add ctas_method to the Query object + +Revision ID: ea396d202291 +Revises: e557699a813e +Create Date: 2020-05-12 12:59:26.583276 + +""" + +# revision identifiers, used by Alembic. +revision = "ea396d202291" +down_revision = "e557699a813e" + +import sqlalchemy as sa +from alembic import op + + +def upgrade(): + op.add_column( + "query", sa.Column("ctas_method", sa.String(length=16), nullable=True) + ) + op.add_column("dbs", sa.Column("allow_cvas", sa.Boolean(), nullable=True)) + + +def downgrade(): + op.drop_column("query", "ctas_method") + op.drop_column("dbs", "allow_cvas") diff --git a/superset/models/core.py b/superset/models/core.py index 3fe92a3caf..562e9c523d 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -118,6 +118,7 @@ class Database( allow_run_async = Column(Boolean, default=False) allow_csv_upload = Column(Boolean, default=False) allow_ctas = Column(Boolean, default=False) + allow_cvas = Column(Boolean, default=False) allow_dml = Column(Boolean, default=False) force_ctas_schema = Column(String(250)) allow_multi_schema_metadata_fetch = Column( # pylint: disable=invalid-name @@ -147,6 +148,7 @@ class Database( "expose_in_sqllab", "allow_run_async", "allow_ctas", + "allow_cvas", "allow_csv_upload", "extra", ] diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 8a90e9abb1..886d0c51c4 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -19,7 +19,6 @@ import re from datetime import datetime from typing import Any, Dict -# pylint: disable=ungrouped-imports import simplejson as json import sqlalchemy as sqla from flask import Markup @@ -40,6 +39,7 @@ from sqlalchemy.orm import backref, relationship from superset import security_manager from superset.models.helpers import AuditMixinNullable, ExtraJSONMixin from superset.models.tags import QueryUpdater +from superset.sql_parse import CtasMethod from superset.utils.core import QueryStatus, user_label @@ -72,6 +72,7 @@ class Query(Model, ExtraJSONMixin): limit = Column(Integer) select_as_cta = Column(Boolean) select_as_cta_used = Column(Boolean, default=False) + ctas_method = Column(String(16), default=CtasMethod.TABLE) progress = Column(Integer, default=0) # 1..100 # # of rows in the result set or rows modified. diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 4d3597c43e..8c3f24fc57 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -223,7 +223,9 @@ def execute_sql_statement( query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S") ) sql = parsed_query.as_create_table( - query.tmp_table_name, schema_name=query.tmp_schema_name + query.tmp_table_name, + schema_name=query.tmp_schema_name, + method=query.ctas_method, ) query.select_as_cta_used = True diff --git a/superset/sql_parse.py b/superset/sql_parse.py index b43b113ee5..e532a5eef6 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -16,6 +16,7 @@ # under the License. import logging from dataclasses import dataclass +from enum import Enum from typing import List, Optional, Set from urllib import parse @@ -31,6 +32,11 @@ CTE_PREFIX = "CTE__" logger = logging.getLogger(__name__) +class CtasMethod(str, Enum): + TABLE = "TABLE" + VIEW = "VIEW" + + def _extract_limit_from_query(statement: TokenList) -> Optional[int]: """ Extract limit clause from SQL statement. @@ -185,6 +191,7 @@ class ParsedQuery: table_name: str, schema_name: Optional[str] = None, overwrite: bool = False, + method: CtasMethod = CtasMethod.TABLE, ) -> str: """Reformats the query into the create table as query. @@ -193,6 +200,7 @@ class ParsedQuery: :param table_name: table that will contain the results of the query execution :param schema_name: schema name for the target table :param overwrite: table_name will be dropped if true + :param method: method for the CTA query, currently view or table creation :return: Create table as query """ exec_sql = "" @@ -200,8 +208,8 @@ class ParsedQuery: # TODO(bkyryliuk): quote full_table_name full_table_name = f"{schema_name}.{table_name}" if schema_name else table_name if overwrite: - exec_sql = f"DROP TABLE IF EXISTS {full_table_name};\n" - exec_sql += f"CREATE TABLE {full_table_name} AS \n{sql}" + exec_sql = f"DROP {method} IF EXISTS {full_table_name};\n" + exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}" return exec_sql def _extract_from_token( # pylint: disable=too-many-branches diff --git a/superset/views/core.py b/superset/views/core.py index 4146914dc5..9767a1e274 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -82,7 +82,7 @@ from superset.security.analytics_db_safety import ( check_sqlalchemy_uri, DBSecurityException, ) -from superset.sql_parse import ParsedQuery, Table +from superset.sql_parse import CtasMethod, ParsedQuery, Table from superset.sql_validators import get_validator_by_name from superset.typing import FlaskResponse from superset.utils import core as utils, dashboard_import_export @@ -133,6 +133,7 @@ logger = logging.getLogger(__name__) DATABASE_KEYS = [ "allow_csv_upload", "allow_ctas", + "allow_cvas", "allow_dml", "allow_multi_schema_metadata_fetch", "allow_run_async", @@ -2239,6 +2240,9 @@ class Superset(BaseSupersetView): ) limit = 0 select_as_cta: bool = cast(bool, query_params.get("select_as_cta")) + ctas_method: CtasMethod = cast( + CtasMethod, query_params.get("ctas_method", CtasMethod.TABLE) + ) tmp_table_name: str = cast(str, query_params.get("tmp_table_name")) client_id: str = cast( str, query_params.get("client_id") or utils.shortid()[:10] @@ -2267,6 +2271,7 @@ class Superset(BaseSupersetView): sql=sql, schema=schema, select_as_cta=select_as_cta, + ctas_method=ctas_method, start_time=now_as_float(), tab_name=tab_name, status=status, diff --git a/superset/views/database/api.py b/superset/views/database/api.py index fe328d65ef..aec49f32b8 100644 --- a/superset/views/database/api.py +++ b/superset/views/database/api.py @@ -128,6 +128,7 @@ class DatabaseRestApi(DatabaseMixin, BaseSupersetModelRestApi): "database_name", "expose_in_sqllab", "allow_ctas", + "allow_cvas", "force_ctas_schema", "allow_run_async", "allow_dml", diff --git a/superset/views/database/mixins.py b/superset/views/database/mixins.py index a10345a6a7..77a2b7d115 100644 --- a/superset/views/database/mixins.py +++ b/superset/views/database/mixins.py @@ -60,6 +60,7 @@ class DatabaseMixin: "allow_run_async", "allow_csv_upload", "allow_ctas", + "allow_cvas", "allow_dml", "force_ctas_schema", "impersonate_user", @@ -111,6 +112,7 @@ class DatabaseMixin: "for more information." ), "allow_ctas": _("Allow CREATE TABLE AS option in SQL Lab"), + "allow_cvas": _("Allow CREATE VIEW AS option in SQL Lab"), "allow_dml": _( "Allow users to run non-SELECT statements " "(UPDATE, DELETE, CREATE, ...) " @@ -182,6 +184,7 @@ class DatabaseMixin: label_columns = { "expose_in_sqllab": _("Expose in SQL Lab"), "allow_ctas": _("Allow CREATE TABLE AS"), + "allow_cvas": _("Allow CREATE VIEW AS"), "allow_dml": _("Allow DML"), "force_ctas_schema": _("CTAS Schema"), "database_name": _("Database"), diff --git a/tests/base_tests.py b/tests/base_tests.py index d6d6516afd..d41a2401f0 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -27,6 +27,7 @@ from flask_appbuilder.security.sqla import models as ab_models from flask_testing import TestCase from sqlalchemy.orm import Session +from superset.sql_parse import CtasMethod from tests.test_app import app # isort:skip from superset import db, security_manager from superset.connectors.base.models import BaseDatasource @@ -259,6 +260,7 @@ class SupersetTestCase(TestCase): select_as_cta=False, tmp_table_name=None, schema=None, + ctas_method=CtasMethod.TABLE, ): if user_name: self.logout() @@ -270,6 +272,7 @@ class SupersetTestCase(TestCase): "client_id": client_id, "queryLimit": query_limit, "sql_editor_id": sql_editor_id, + "ctas_method": ctas_method, } if tmp_table_name: json_payload["tmp_table_name"] = tmp_table_name diff --git a/tests/celery_tests.py b/tests/celery_tests.py index f31dcf78e1..e30d7807f9 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -17,20 +17,15 @@ # isort:skip_file """Unit tests for Superset Celery worker""" import datetime -import io import json -import logging +from parameterized import parameterized import subprocess import time import unittest import unittest.mock as mock import flask -import sqlalchemy -from contextlib2 import contextmanager from flask import current_app -from sqlalchemy.orm import sessionmaker -from sqlalchemy.pool import NullPool from tests.test_app import app from superset import db, sql_lab @@ -39,7 +34,7 @@ from superset.db_engine_specs.base import BaseEngineSpec from superset.extensions import celery_app from superset.models.helpers import QueryStatus from superset.models.sql_lab import Query -from superset.sql_parse import ParsedQuery +from superset.sql_parse import ParsedQuery, CtasMethod from superset.utils.core import get_example_database from .base_tests import SupersetTestCase @@ -120,7 +115,14 @@ class CeleryTestCase(SupersetTestCase): db.session.commit() def run_sql( - self, db_id, sql, client_id=None, cta=False, tmp_table="tmp", async_=False + self, + db_id, + sql, + client_id=None, + cta=False, + tmp_table="tmp", + async_=False, + ctas_method=CtasMethod.TABLE, ): self.login() resp = self.client.post( @@ -132,34 +134,55 @@ class CeleryTestCase(SupersetTestCase): select_as_cta=cta, tmp_table_name=tmp_table, client_id=client_id, + ctas_method=ctas_method, ), ) self.logout() return json.loads(resp.data) - def test_run_sync_query_dont_exist(self): + @parameterized.expand( + [CtasMethod.TABLE,] + ) + def test_run_sync_query_dont_exist(self, ctas_method): main_db = get_example_database() db_id = main_db.id sql_dont_exist = "SELECT name FROM table_dont_exist" - result1 = self.run_sql(db_id, sql_dont_exist, "1", cta=True) - self.assertTrue("error" in result1) + result = self.run_sql( + db_id, sql_dont_exist, f"1_{ctas_method}", cta=True, ctas_method=ctas_method + ) + if ( + get_example_database().backend != "sqlite" + and ctas_method == CtasMethod.VIEW + ): + self.assertEqual(QueryStatus.SUCCESS, result["status"], msg=result) + else: + self.assertEqual(QueryStatus.FAILED, result["status"], msg=result) - def test_run_sync_query_cta(self): + @parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW]) + def test_run_sync_query_cta(self, ctas_method): main_db = get_example_database() db_id = main_db.id - tmp_table_name = "tmp_async_22" + tmp_table_name = f"tmp_sync_23_{ctas_method.lower()}" self.drop_table_if_exists(tmp_table_name, main_db) name = "James" sql_where = f"SELECT name FROM birth_names WHERE name='{name}' LIMIT 1" - result = self.run_sql(db_id, sql_where, "2", tmp_table=tmp_table_name, cta=True) - self.assertEqual(QueryStatus.SUCCESS, result["query"]["state"]) + result = self.run_sql( + db_id, + sql_where, + f"2_{ctas_method}", + tmp_table=tmp_table_name, + cta=True, + ctas_method=ctas_method, + ) + # provide better error message + self.assertEqual(QueryStatus.SUCCESS, result["query"]["state"], msg=result) self.assertEqual([], result["data"]) self.assertEqual([], result["columns"]) query2 = self.get_query_by_id(result["query"]["serverId"]) # Check the data in the tmp table. - results = self.run_sql(db_id, query2.select_sql, "sdf2134") - self.assertEqual(results["status"], "success") + results = self.run_sql(db_id, query2.select_sql, f"7_{ctas_method}") + self.assertEqual(QueryStatus.SUCCESS, results["status"], msg=results) self.assertGreater(len(results["data"]), 0) # cleanup tmp table @@ -186,89 +209,113 @@ class CeleryTestCase(SupersetTestCase): db.session.flush() return self.run_sql(db_id, sql) - @mock.patch( - "superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME - ) - def test_run_sync_query_cta_config(self): - main_db = get_example_database() - db_id = main_db.id - if main_db.backend == "sqlite": - # sqlite doesn't support schemas - return - tmp_table_name = "tmp_async_22" - expected_full_table_name = f"{CTAS_SCHEMA_NAME}.{tmp_table_name}" - self.drop_table_if_exists(expected_full_table_name, main_db) - name = "James" - sql_where = f"SELECT name FROM birth_names WHERE name='{name}'" - result = self.run_sql( - db_id, sql_where, "cid2", tmp_table=tmp_table_name, cta=True - ) + @parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW]) + def test_run_sync_query_cta_config(self, ctas_method): + with mock.patch( + "superset.views.core.get_cta_schema_name", + lambda d, u, s, sql: CTAS_SCHEMA_NAME, + ): + main_db = get_example_database() + db_id = main_db.id + if main_db.backend == "sqlite": + # sqlite doesn't support schemas + return + tmp_table_name = f"tmp_async_22_{ctas_method.lower()}" + quote = ( + main_db.inspector.engine.dialect.identifier_preparer.quote_identifier + ) + expected_full_table_name = f"{CTAS_SCHEMA_NAME}.{quote(tmp_table_name)}" + self.drop_table_if_exists(expected_full_table_name, main_db) + name = "James" + sql_where = f"SELECT name FROM birth_names WHERE name='{name}'" + result = self.run_sql( + db_id, + sql_where, + f"3_{ctas_method}", + tmp_table=tmp_table_name, + cta=True, + ctas_method=ctas_method, + ) - self.assertEqual(QueryStatus.SUCCESS, result["query"]["state"]) - self.assertEqual([], result["data"]) - self.assertEqual([], result["columns"]) - query = self.get_query_by_id(result["query"]["serverId"]) - self.assertEqual( - f"CREATE TABLE {expected_full_table_name} AS \n" - "SELECT name FROM birth_names " - "WHERE name='James'", - query.executed_sql, - ) - self.assertEqual( - "SELECT *\n" f"FROM {expected_full_table_name}", query.select_sql - ) - time.sleep(CELERY_SHORT_SLEEP_TIME) - results = self.run_sql(db_id, query.select_sql) - self.assertEqual(results["status"], "success") - self.drop_table_if_exists(expected_full_table_name, get_example_database()) + self.assertEqual(QueryStatus.SUCCESS, result["query"]["state"], msg=result) + self.assertEqual([], result["data"]) + self.assertEqual([], result["columns"]) + query = self.get_query_by_id(result["query"]["serverId"]) + self.assertEqual( + f"CREATE {ctas_method} {CTAS_SCHEMA_NAME}.{tmp_table_name} AS \n" + "SELECT name FROM birth_names " + "WHERE name='James'", + query.executed_sql, + ) + self.assertEqual( + "SELECT *\n" f"FROM {CTAS_SCHEMA_NAME}.{tmp_table_name}", + query.select_sql, + ) + time.sleep(CELERY_SHORT_SLEEP_TIME) + results = self.run_sql(db_id, query.select_sql) + self.assertEqual(QueryStatus.SUCCESS, results["status"], msg=result) + self.drop_table_if_exists(expected_full_table_name, get_example_database()) - @mock.patch( - "superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME - ) - def test_run_async_query_cta_config(self): - main_db = get_example_database() - db_id = main_db.id - if main_db.backend == "sqlite": - # sqlite doesn't support schemas - return - tmp_table_name = "sqllab_test_table_async_1" - expected_full_table_name = f"{CTAS_SCHEMA_NAME}.{tmp_table_name}" - self.drop_table_if_exists(expected_full_table_name, main_db) - sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10" - result = self.run_sql( - db_id, - sql_where, - "cid3", - async_=True, - tmp_table="sqllab_test_table_async_1", - cta=True, - ) - db.session.close() - time.sleep(CELERY_SLEEP_TIME) + @parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW]) + def test_run_async_query_cta_config(self, ctas_method): + with mock.patch( + "superset.views.core.get_cta_schema_name", + lambda d, u, s, sql: CTAS_SCHEMA_NAME, + ): + main_db = get_example_database() + db_id = main_db.id + if main_db.backend == "sqlite": + # sqlite doesn't support schemas + return + tmp_table_name = f"sqllab_test_table_async_1_{ctas_method}" + quote = ( + main_db.inspector.engine.dialect.identifier_preparer.quote_identifier + ) + expected_full_table_name = f"{CTAS_SCHEMA_NAME}.{quote(tmp_table_name)}" + self.drop_table_if_exists(expected_full_table_name, main_db) + sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10" + result = self.run_sql( + db_id, + sql_where, + f"4_{ctas_method}", + async_=True, + tmp_table=tmp_table_name, + cta=True, + ctas_method=ctas_method, + ) + db.session.close() + time.sleep(CELERY_SLEEP_TIME) - query = self.get_query_by_id(result["query"]["serverId"]) - self.assertEqual(QueryStatus.SUCCESS, query.status) - self.assertTrue(f"FROM {expected_full_table_name}" in query.select_sql) - self.assertEqual( - f"CREATE TABLE {expected_full_table_name} AS \n" - "SELECT name FROM birth_names " - "WHERE name='James' " - "LIMIT 10", - query.executed_sql, - ) - self.drop_table_if_exists(expected_full_table_name, get_example_database()) + query = self.get_query_by_id(result["query"]["serverId"]) + self.assertEqual(QueryStatus.SUCCESS, query.status) + self.assertIn(expected_full_table_name, query.select_sql) + self.assertEqual( + f"CREATE {ctas_method} {CTAS_SCHEMA_NAME}.{tmp_table_name} AS \n" + "SELECT name FROM birth_names " + "WHERE name='James' " + "LIMIT 10", + query.executed_sql, + ) + self.drop_table_if_exists(expected_full_table_name, get_example_database()) - def test_run_async_cta_query(self): + @parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW]) + def test_run_async_cta_query(self, ctas_method): main_db = get_example_database() db_id = main_db.id - table_name = "tmp_async_4" + table_name = f"tmp_async_4_{ctas_method}" self.drop_table_if_exists(table_name, main_db) time.sleep(DROP_TABLE_SLEEP_TIME) sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10" result = self.run_sql( - db_id, sql_where, "cid4", async_=True, tmp_table="tmp_async_4", cta=True + db_id, + sql_where, + f"5_{ctas_method}", + async_=True, + tmp_table=table_name, + cta=True, + ctas_method=ctas_method, ) db.session.close() assert result["query"]["state"] in ( @@ -282,10 +329,10 @@ class CeleryTestCase(SupersetTestCase): query = self.get_query_by_id(result["query"]["serverId"]) self.assertEqual(QueryStatus.SUCCESS, query.status) - self.assertTrue(f"FROM {table_name}" in query.select_sql) + self.assertIn(table_name, query.select_sql) self.assertEqual( - f"CREATE TABLE {table_name} AS \n" + f"CREATE {ctas_method} {table_name} AS \n" "SELECT name FROM birth_names " "WHERE name='James' " "LIMIT 10", @@ -296,15 +343,22 @@ class CeleryTestCase(SupersetTestCase): self.assertEqual(True, query.select_as_cta) self.assertEqual(True, query.select_as_cta_used) - def test_run_async_cta_query_with_lower_limit(self): + @parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW]) + def test_run_async_cta_query_with_lower_limit(self, ctas_method): main_db = get_example_database() db_id = main_db.id - tmp_table = "tmp_async_2" + tmp_table = f"tmp_async_2_{ctas_method}" self.drop_table_if_exists(tmp_table, main_db) sql_where = "SELECT name FROM birth_names LIMIT 1" result = self.run_sql( - db_id, sql_where, "id1", async_=True, tmp_table=tmp_table, cta=True + db_id, + sql_where, + f"6_{ctas_method}", + async_=True, + tmp_table=tmp_table, + cta=True, + ctas_method=ctas_method, ) db.session.close() assert result["query"]["state"] in ( @@ -318,9 +372,10 @@ class CeleryTestCase(SupersetTestCase): query = self.get_query_by_id(result["query"]["serverId"]) self.assertEqual(QueryStatus.SUCCESS, query.status) - self.assertIn(f"FROM {tmp_table}", query.select_sql) + self.assertIn(tmp_table, query.select_sql) self.assertEqual( - f"CREATE TABLE {tmp_table} AS \n" "SELECT name FROM birth_names LIMIT 1", + f"CREATE {ctas_method} {tmp_table} AS \n" + "SELECT name FROM birth_names LIMIT 1", query.executed_sql, ) self.assertEqual(sql_where, query.sql) diff --git a/tests/database_api_tests.py b/tests/database_api_tests.py index f060a2c4e3..f636a29c93 100644 --- a/tests/database_api_tests.py +++ b/tests/database_api_tests.py @@ -43,6 +43,7 @@ class DatabaseApiTests(SupersetTestCase): expected_columns = [ "allow_csv_upload", "allow_ctas", + "allow_cvas", "allow_dml", "allow_multi_schema_metadata_fetch", "allow_run_async", diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py index 7c1b0e69d6..426649a13a 100644 --- a/tests/sqllab_tests.py +++ b/tests/sqllab_tests.py @@ -18,6 +18,7 @@ """Unit tests for Sql Lab""" import json from datetime import datetime, timedelta +from parameterized import parameterized from random import random from unittest import mock @@ -29,6 +30,7 @@ from superset.connectors.sqla.models import SqlaTable from superset.db_engine_specs import BaseEngineSpec from superset.models.sql_lab import Query from superset.result_set import SupersetResultSet +from superset.sql_parse import CtasMethod from superset.utils.core import datetime_to_epoch, get_example_database from .base_tests import SupersetTestCase @@ -67,38 +69,44 @@ class SqlLabTests(SupersetTestCase): data = self.run_sql("SELECT * FROM unexistant_table", "2") self.assertLess(0, len(data["error"])) - @mock.patch( - "superset.views.core.get_cta_schema_name", - lambda d, u, s, sql: f"{u.username}_database", - ) - def test_sql_json_cta_dynamic_db(self): + @parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW]) + def test_sql_json_cta_dynamic_db(self, ctas_method): main_db = get_example_database() if main_db.backend == "sqlite": # sqlite doesn't support database creation return - old_allow_ctas = main_db.allow_ctas - main_db.allow_ctas = True # enable cta + with mock.patch( + "superset.views.core.get_cta_schema_name", + lambda d, u, s, sql: f"{u.username}_database", + ): + old_allow_ctas = main_db.allow_ctas + main_db.allow_ctas = True # enable cta - self.login("admin") - self.run_sql( - "SELECT * FROM birth_names", - "1", - database_name="examples", - tmp_table_name="test_target", - select_as_cta=True, - ) + self.login("admin") + tmp_table_name = f"test_target_{ctas_method.lower()}" + self.run_sql( + "SELECT * FROM birth_names", + "1", + database_name="examples", + tmp_table_name=tmp_table_name, + select_as_cta=True, + ctas_method=ctas_method, + ) - # assertions - data = db.session.execute("SELECT * FROM admin_database.test_target").fetchall() - self.assertEqual( - 75691, len(data) - ) # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True + # assertions + db.session.commit() + data = db.session.execute( + f"SELECT * FROM admin_database.{tmp_table_name}" + ).fetchall() + self.assertEqual( + 75691, len(data) + ) # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True - # cleanup - db.session.execute("DROP TABLE admin_database.test_target") - main_db.allow_ctas = old_allow_ctas - db.session.commit() + # cleanup + db.session.execute(f"DROP {ctas_method} admin_database.{tmp_table_name}") + main_db.allow_ctas = old_allow_ctas + db.session.commit() def test_multi_sql(self): self.login("admin")