Implement create view as functionality (#9794)

Implement create view as button in sqllab

Make CVAS configurable

Co-authored-by: bogdan kyryliuk <bogdankyryliuk@dropbox.com>
This commit is contained in:
Bogdan 2020-06-24 09:50:41 -07:00 committed by GitHub
parent 38667b72b1
commit 3db76c6fdc
17 changed files with 304 additions and 137 deletions

View File

@ -32,7 +32,9 @@ assists people when migrating to a new version.
* [9786](https://github.com/apache/incubator-superset/pull/9786): with the upgrade of `werkzeug` from version `0.16.0` to `1.0.1`, the `werkzeug.contrib.cache` module has been moved to a standalone package [cachelib](https://pypi.org/project/cachelib/). For example, to import the `RedisCache` class, please use the following import: `from cachelib.redis import RedisCache`.
* [9794](https://github.com/apache/incubator-superset/pull/9794): introduces `create view as` functionality in SQL Lab. This change requires a migration of the `query` table and may involve service downtime, as that table receives significant traffic.
* [9572](https://github.com/apache/incubator-superset/pull/9572): a change which by default means that the Jinja `current_user_id`, `current_username`, and `url_param` context calls no longer need to be wrapped via `cache_key_wrapper` in order to be included in the cache key. The `cache_key_wrapper` function should only be required for Jinja add-ons.
* [8867](https://github.com/apache/incubator-superset/pull/8867): a change which adds the `tmp_schema_name` column to the `query` table which requires locking the table. Given the `query` table is heavily used performance may be degraded during the migration. Scheduled downtime may be advised.

View File

@ -22,6 +22,7 @@ ipdb==0.12
isort==4.3.21
mypy==0.770
nose==1.3.7
parameterized==0.7.4
pip-tools==5.1.2
pre-commit==1.17.0
psycopg2-binary==2.8.5

View File

@ -97,6 +97,11 @@ export const addSuccessToast = addSuccessToastAction;
export const addDangerToast = addDangerToastAction;
export const addWarningToast = addWarningToastAction;
export const CtasEnum = {
TABLE: 'TABLE',
VIEW: 'VIEW',
};
// a map of SavedQuery field names to the different names used client-side,
// because for now making the names consistent is too complicated
// so it might as well only happen in one place
@ -346,6 +351,7 @@ export function runQuery(query) {
tab: query.tab,
tmp_table_name: query.tempTableName,
select_as_cta: query.ctas,
ctas_method: query.ctas_method,
templateParams: query.templateParams,
queryLimit: query.queryLimit,
expand_data: true,

View File

@ -30,6 +30,7 @@ import FilterableTable from '../../components/FilterableTable/FilterableTable';
import QueryStateLabel from './QueryStateLabel';
import CopyToClipboard from '../../components/CopyToClipboard';
import { prepareCopyToClipboardTabularData } from '../../utils/common';
import { CtasEnum } from '../actions/sqlLab';
const propTypes = {
actions: PropTypes.object,
@ -219,10 +220,14 @@ export default class ResultSet extends React.PureComponent {
tmpTable = query.results.query.tempTable;
tmpSchema = query.results.query.tempSchema;
}
let object = 'Table';
if (query.ctas_method === CtasEnum.VIEW) {
object = 'View';
}
return (
<div>
<Alert bsStyle="info">
{t('Table')} [
{t(object)} [
<strong>
{tmpSchema}.{tmpTable}
</strong>

View File

@ -54,6 +54,7 @@ import {
} from '../constants';
import RunQueryActionButton from './RunQueryActionButton';
import { FeatureFlag, isFeatureEnabled } from '../../featureFlags';
import { CtasEnum } from '../actions/sqlLab';
const SQL_EDITOR_PADDING = 10;
const INITIAL_NORTH_PERCENT = 30;
@ -284,7 +285,7 @@ class SqlEditor extends React.PureComponent {
this.startQuery();
}
}
startQuery(ctas = false) {
startQuery(ctas = false, ctas_method = CtasEnum.TABLE) {
const qe = this.props.queryEditor;
const query = {
dbId: qe.dbId,
@ -299,6 +300,7 @@ class SqlEditor extends React.PureComponent {
? this.props.database.allow_run_async
: false,
ctas,
ctas_method,
updateTabState: !qe.selectedText,
};
this.props.actions.runQuery(query);
@ -313,7 +315,10 @@ class SqlEditor extends React.PureComponent {
}
}
createTableAs() {
this.startQuery(true);
this.startQuery(true, CtasEnum.TABLE);
}
createViewAs() {
this.startQuery(true, CtasEnum.VIEW);
}
ctasChanged(event) {
this.setState({ ctas: event.target.value });
@ -372,8 +377,13 @@ class SqlEditor extends React.PureComponent {
}
renderEditorBottomBar(hotkeys) {
let ctasControls;
if (this.props.database && this.props.database.allow_ctas) {
if (
this.props.database &&
(this.props.database.allow_ctas || this.props.database.allow_cvas)
) {
const ctasToolTip = t('Create table as with query results');
const cvasToolTip = t('Create view as with query results');
ctasControls = (
<FormGroup>
<InputGroup>
@ -385,14 +395,26 @@ class SqlEditor extends React.PureComponent {
onChange={this.ctasChanged.bind(this)}
/>
<InputGroup.Button>
<Button
bsSize="small"
disabled={this.state.ctas.length === 0}
onClick={this.createTableAs.bind(this)}
tooltip={ctasToolTip}
>
<i className="fa fa-table" /> CTAS
</Button>
{this.props.database.allow_ctas && (
<Button
bsSize="small"
disabled={this.state.ctas.length === 0}
onClick={this.createTableAs.bind(this)}
tooltip={ctasToolTip}
>
<i className="fa fa-table" /> CTAS
</Button>
)}
{this.props.database.allow_cvas && (
<Button
bsSize="small"
disabled={this.state.ctas.length === 0}
onClick={this.createViewAs.bind(this)}
tooltip={cvasToolTip}
>
<i className="fa fa-table" /> CVAS
</Button>
)}
</InputGroup.Button>
</InputGroup>
</FormGroup>

View File

@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Add ctas_method to the Query object
Revision ID: ea396d202291
Revises: e557699a813e
Create Date: 2020-05-12 12:59:26.583276
"""
# revision identifiers, used by Alembic.
revision = "ea396d202291"
down_revision = "e557699a813e"
import sqlalchemy as sa
from alembic import op
def upgrade():
op.add_column(
"query", sa.Column("ctas_method", sa.String(length=16), nullable=True)
)
op.add_column("dbs", sa.Column("allow_cvas", sa.Boolean(), nullable=True))
def downgrade():
op.drop_column("query", "ctas_method")
op.drop_column("dbs", "allow_cvas")

View File

@ -118,6 +118,7 @@ class Database(
allow_run_async = Column(Boolean, default=False)
allow_csv_upload = Column(Boolean, default=False)
allow_ctas = Column(Boolean, default=False)
allow_cvas = Column(Boolean, default=False)
allow_dml = Column(Boolean, default=False)
force_ctas_schema = Column(String(250))
allow_multi_schema_metadata_fetch = Column( # pylint: disable=invalid-name
@ -147,6 +148,7 @@ class Database(
"expose_in_sqllab",
"allow_run_async",
"allow_ctas",
"allow_cvas",
"allow_csv_upload",
"extra",
]
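
A minimal sketch (not part of this diff) of how the new flag can be toggled programmatically, assuming an initialized Superset app context and an existing "examples" database; in practice the same setting is exposed as a checkbox on the database edit form:

```python
# Minimal sketch (assumes an app context and that an "examples" database exists);
# the flag appears as "Allow CREATE VIEW AS" in the database edit form.
from superset import db
from superset.models.core import Database

examples = db.session.query(Database).filter_by(database_name="examples").one()
examples.allow_cvas = True  # enables the CVAS button in SQL Lab for this database
db.session.commit()
```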

View File

@ -19,7 +19,6 @@ import re
from datetime import datetime
from typing import Any, Dict
# pylint: disable=ungrouped-imports
import simplejson as json
import sqlalchemy as sqla
from flask import Markup
@ -40,6 +39,7 @@ from sqlalchemy.orm import backref, relationship
from superset import security_manager
from superset.models.helpers import AuditMixinNullable, ExtraJSONMixin
from superset.models.tags import QueryUpdater
from superset.sql_parse import CtasMethod
from superset.utils.core import QueryStatus, user_label
@ -72,6 +72,7 @@ class Query(Model, ExtraJSONMixin):
limit = Column(Integer)
select_as_cta = Column(Boolean)
select_as_cta_used = Column(Boolean, default=False)
ctas_method = Column(String(16), default=CtasMethod.TABLE)
progress = Column(Integer, default=0) # 1..100
# # of rows in the result set or rows modified.

View File

@ -223,7 +223,9 @@ def execute_sql_statement(
query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S")
)
sql = parsed_query.as_create_table(
query.tmp_table_name, schema_name=query.tmp_schema_name
query.tmp_table_name,
schema_name=query.tmp_schema_name,
method=query.ctas_method,
)
query.select_as_cta_used = True

View File

@ -16,6 +16,7 @@
# under the License.
import logging
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Set
from urllib import parse
@ -31,6 +32,11 @@ CTE_PREFIX = "CTE__"
logger = logging.getLogger(__name__)
class CtasMethod(str, Enum):
TABLE = "TABLE"
VIEW = "VIEW"
def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
"""
Extract limit clause from SQL statement.
@ -185,6 +191,7 @@ class ParsedQuery:
table_name: str,
schema_name: Optional[str] = None,
overwrite: bool = False,
method: CtasMethod = CtasMethod.TABLE,
) -> str:
"""Reformats the query into the create table as query.
@ -193,6 +200,7 @@ class ParsedQuery:
:param table_name: table that will contain the results of the query execution
:param schema_name: schema name for the target table
:param overwrite: table_name will be dropped if true
:param method: method for the CTA query, currently view or table creation
:return: Create table as query
"""
exec_sql = ""
@ -200,8 +208,8 @@ class ParsedQuery:
# TODO(bkyryliuk): quote full_table_name
full_table_name = f"{schema_name}.{table_name}" if schema_name else table_name
if overwrite:
exec_sql = f"DROP TABLE IF EXISTS {full_table_name};\n"
exec_sql += f"CREATE TABLE {full_table_name} AS \n{sql}"
exec_sql = f"DROP {method} IF EXISTS {full_table_name};\n"
exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}"
return exec_sql
def _extract_from_token( # pylint: disable=too-many-branches
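
For reference, a minimal sketch (not part of this diff, illustrative query and table names) of what `as_create_table` produces for the two `CtasMethod` values:

```python
# Minimal sketch: how the new `method` argument changes the generated statement.
from superset.sql_parse import CtasMethod, ParsedQuery

parsed = ParsedQuery("SELECT name FROM birth_names LIMIT 10")

print(parsed.as_create_table("tmp_table", schema_name="tmp_schema"))
# CREATE TABLE tmp_schema.tmp_table AS
# SELECT name FROM birth_names LIMIT 10

print(parsed.as_create_table("tmp_view", method=CtasMethod.VIEW, overwrite=True))
# DROP VIEW IF EXISTS tmp_view;
# CREATE VIEW tmp_view AS
# SELECT name FROM birth_names LIMIT 10
```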

View File

@ -82,7 +82,7 @@ from superset.security.analytics_db_safety import (
check_sqlalchemy_uri,
DBSecurityException,
)
from superset.sql_parse import ParsedQuery, Table
from superset.sql_parse import CtasMethod, ParsedQuery, Table
from superset.sql_validators import get_validator_by_name
from superset.typing import FlaskResponse
from superset.utils import core as utils, dashboard_import_export
@ -133,6 +133,7 @@ logger = logging.getLogger(__name__)
DATABASE_KEYS = [
"allow_csv_upload",
"allow_ctas",
"allow_cvas",
"allow_dml",
"allow_multi_schema_metadata_fetch",
"allow_run_async",
@ -2239,6 +2240,9 @@ class Superset(BaseSupersetView):
)
limit = 0
select_as_cta: bool = cast(bool, query_params.get("select_as_cta"))
ctas_method: CtasMethod = cast(
CtasMethod, query_params.get("ctas_method", CtasMethod.TABLE)
)
tmp_table_name: str = cast(str, query_params.get("tmp_table_name"))
client_id: str = cast(
str, query_params.get("client_id") or utils.shortid()[:10]
@ -2267,6 +2271,7 @@ class Superset(BaseSupersetView):
sql=sql,
schema=schema,
select_as_cta=select_as_cta,
ctas_method=ctas_method,
start_time=now_as_float(),
tab_name=tab_name,
status=status,
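
A sketch of the request body the SQL Lab JSON endpoint now accepts; `ctas_method` comes from the handler above, while the remaining keys are assumed from the existing payload (see the frontend `runQuery` change and the test helper below) and may differ slightly:

```python
# Sketch of a SQL Lab JSON payload including the new ctas_method parameter.
# "database_id" and "sql" are assumed pre-existing fields of the endpoint;
# only "ctas_method" is introduced by this change ("TABLE" is the default).
payload = {
    "database_id": 1,
    "sql": "SELECT name FROM birth_names LIMIT 10",
    "select_as_cta": True,
    "ctas_method": "VIEW",  # or "TABLE"
    "tmp_table_name": "tmp_names_view",
    "client_id": "abc123def0",
}
# e.g. in a test, using the Flask test client as the test cases below do:
# resp = self.client.post("/superset/sql_json/", json=payload)
```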

View File

@ -128,6 +128,7 @@ class DatabaseRestApi(DatabaseMixin, BaseSupersetModelRestApi):
"database_name",
"expose_in_sqllab",
"allow_ctas",
"allow_cvas",
"force_ctas_schema",
"allow_run_async",
"allow_dml",

View File

@ -60,6 +60,7 @@ class DatabaseMixin:
"allow_run_async",
"allow_csv_upload",
"allow_ctas",
"allow_cvas",
"allow_dml",
"force_ctas_schema",
"impersonate_user",
@ -111,6 +112,7 @@ class DatabaseMixin:
"for more information."
),
"allow_ctas": _("Allow CREATE TABLE AS option in SQL Lab"),
"allow_cvas": _("Allow CREATE VIEW AS option in SQL Lab"),
"allow_dml": _(
"Allow users to run non-SELECT statements "
"(UPDATE, DELETE, CREATE, ...) "
@ -182,6 +184,7 @@ class DatabaseMixin:
label_columns = {
"expose_in_sqllab": _("Expose in SQL Lab"),
"allow_ctas": _("Allow CREATE TABLE AS"),
"allow_cvas": _("Allow CREATE VIEW AS"),
"allow_dml": _("Allow DML"),
"force_ctas_schema": _("CTAS Schema"),
"database_name": _("Database"),

View File

@ -27,6 +27,7 @@ from flask_appbuilder.security.sqla import models as ab_models
from flask_testing import TestCase
from sqlalchemy.orm import Session
from superset.sql_parse import CtasMethod
from tests.test_app import app # isort:skip
from superset import db, security_manager
from superset.connectors.base.models import BaseDatasource
@ -259,6 +260,7 @@ class SupersetTestCase(TestCase):
select_as_cta=False,
tmp_table_name=None,
schema=None,
ctas_method=CtasMethod.TABLE,
):
if user_name:
self.logout()
@ -270,6 +272,7 @@ class SupersetTestCase(TestCase):
"client_id": client_id,
"queryLimit": query_limit,
"sql_editor_id": sql_editor_id,
"ctas_method": ctas_method,
}
if tmp_table_name:
json_payload["tmp_table_name"] = tmp_table_name

View File

@ -17,20 +17,15 @@
# isort:skip_file
"""Unit tests for Superset Celery worker"""
import datetime
import io
import json
import logging
from parameterized import parameterized
import subprocess
import time
import unittest
import unittest.mock as mock
import flask
import sqlalchemy
from contextlib2 import contextmanager
from flask import current_app
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
from tests.test_app import app
from superset import db, sql_lab
@ -39,7 +34,7 @@ from superset.db_engine_specs.base import BaseEngineSpec
from superset.extensions import celery_app
from superset.models.helpers import QueryStatus
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery
from superset.sql_parse import ParsedQuery, CtasMethod
from superset.utils.core import get_example_database
from .base_tests import SupersetTestCase
@ -120,7 +115,14 @@ class CeleryTestCase(SupersetTestCase):
db.session.commit()
def run_sql(
self, db_id, sql, client_id=None, cta=False, tmp_table="tmp", async_=False
self,
db_id,
sql,
client_id=None,
cta=False,
tmp_table="tmp",
async_=False,
ctas_method=CtasMethod.TABLE,
):
self.login()
resp = self.client.post(
@ -132,34 +134,55 @@ class CeleryTestCase(SupersetTestCase):
select_as_cta=cta,
tmp_table_name=tmp_table,
client_id=client_id,
ctas_method=ctas_method,
),
)
self.logout()
return json.loads(resp.data)
def test_run_sync_query_dont_exist(self):
@parameterized.expand(
[CtasMethod.TABLE,]
)
def test_run_sync_query_dont_exist(self, ctas_method):
main_db = get_example_database()
db_id = main_db.id
sql_dont_exist = "SELECT name FROM table_dont_exist"
result1 = self.run_sql(db_id, sql_dont_exist, "1", cta=True)
self.assertTrue("error" in result1)
result = self.run_sql(
db_id, sql_dont_exist, f"1_{ctas_method}", cta=True, ctas_method=ctas_method
)
if (
get_example_database().backend != "sqlite"
and ctas_method == CtasMethod.VIEW
):
self.assertEqual(QueryStatus.SUCCESS, result["status"], msg=result)
else:
self.assertEqual(QueryStatus.FAILED, result["status"], msg=result)
def test_run_sync_query_cta(self):
@parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW])
def test_run_sync_query_cta(self, ctas_method):
main_db = get_example_database()
db_id = main_db.id
tmp_table_name = "tmp_async_22"
tmp_table_name = f"tmp_sync_23_{ctas_method.lower()}"
self.drop_table_if_exists(tmp_table_name, main_db)
name = "James"
sql_where = f"SELECT name FROM birth_names WHERE name='{name}' LIMIT 1"
result = self.run_sql(db_id, sql_where, "2", tmp_table=tmp_table_name, cta=True)
self.assertEqual(QueryStatus.SUCCESS, result["query"]["state"])
result = self.run_sql(
db_id,
sql_where,
f"2_{ctas_method}",
tmp_table=tmp_table_name,
cta=True,
ctas_method=ctas_method,
)
# provide better error message
self.assertEqual(QueryStatus.SUCCESS, result["query"]["state"], msg=result)
self.assertEqual([], result["data"])
self.assertEqual([], result["columns"])
query2 = self.get_query_by_id(result["query"]["serverId"])
# Check the data in the tmp table.
results = self.run_sql(db_id, query2.select_sql, "sdf2134")
self.assertEqual(results["status"], "success")
results = self.run_sql(db_id, query2.select_sql, f"7_{ctas_method}")
self.assertEqual(QueryStatus.SUCCESS, results["status"], msg=results)
self.assertGreater(len(results["data"]), 0)
# cleanup tmp table
@ -186,89 +209,113 @@ class CeleryTestCase(SupersetTestCase):
db.session.flush()
return self.run_sql(db_id, sql)
@mock.patch(
"superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME
)
def test_run_sync_query_cta_config(self):
main_db = get_example_database()
db_id = main_db.id
if main_db.backend == "sqlite":
# sqlite doesn't support schemas
return
tmp_table_name = "tmp_async_22"
expected_full_table_name = f"{CTAS_SCHEMA_NAME}.{tmp_table_name}"
self.drop_table_if_exists(expected_full_table_name, main_db)
name = "James"
sql_where = f"SELECT name FROM birth_names WHERE name='{name}'"
result = self.run_sql(
db_id, sql_where, "cid2", tmp_table=tmp_table_name, cta=True
)
@parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW])
def test_run_sync_query_cta_config(self, ctas_method):
with mock.patch(
"superset.views.core.get_cta_schema_name",
lambda d, u, s, sql: CTAS_SCHEMA_NAME,
):
main_db = get_example_database()
db_id = main_db.id
if main_db.backend == "sqlite":
# sqlite doesn't support schemas
return
tmp_table_name = f"tmp_async_22_{ctas_method.lower()}"
quote = (
main_db.inspector.engine.dialect.identifier_preparer.quote_identifier
)
expected_full_table_name = f"{CTAS_SCHEMA_NAME}.{quote(tmp_table_name)}"
self.drop_table_if_exists(expected_full_table_name, main_db)
name = "James"
sql_where = f"SELECT name FROM birth_names WHERE name='{name}'"
result = self.run_sql(
db_id,
sql_where,
f"3_{ctas_method}",
tmp_table=tmp_table_name,
cta=True,
ctas_method=ctas_method,
)
self.assertEqual(QueryStatus.SUCCESS, result["query"]["state"])
self.assertEqual([], result["data"])
self.assertEqual([], result["columns"])
query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(
f"CREATE TABLE {expected_full_table_name} AS \n"
"SELECT name FROM birth_names "
"WHERE name='James'",
query.executed_sql,
)
self.assertEqual(
"SELECT *\n" f"FROM {expected_full_table_name}", query.select_sql
)
time.sleep(CELERY_SHORT_SLEEP_TIME)
results = self.run_sql(db_id, query.select_sql)
self.assertEqual(results["status"], "success")
self.drop_table_if_exists(expected_full_table_name, get_example_database())
self.assertEqual(QueryStatus.SUCCESS, result["query"]["state"], msg=result)
self.assertEqual([], result["data"])
self.assertEqual([], result["columns"])
query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(
f"CREATE {ctas_method} {CTAS_SCHEMA_NAME}.{tmp_table_name} AS \n"
"SELECT name FROM birth_names "
"WHERE name='James'",
query.executed_sql,
)
self.assertEqual(
"SELECT *\n" f"FROM {CTAS_SCHEMA_NAME}.{tmp_table_name}",
query.select_sql,
)
time.sleep(CELERY_SHORT_SLEEP_TIME)
results = self.run_sql(db_id, query.select_sql)
self.assertEqual(QueryStatus.SUCCESS, results["status"], msg=result)
self.drop_table_if_exists(expected_full_table_name, get_example_database())
@mock.patch(
"superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME
)
def test_run_async_query_cta_config(self):
main_db = get_example_database()
db_id = main_db.id
if main_db.backend == "sqlite":
# sqlite doesn't support schemas
return
tmp_table_name = "sqllab_test_table_async_1"
expected_full_table_name = f"{CTAS_SCHEMA_NAME}.{tmp_table_name}"
self.drop_table_if_exists(expected_full_table_name, main_db)
sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10"
result = self.run_sql(
db_id,
sql_where,
"cid3",
async_=True,
tmp_table="sqllab_test_table_async_1",
cta=True,
)
db.session.close()
time.sleep(CELERY_SLEEP_TIME)
@parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW])
def test_run_async_query_cta_config(self, ctas_method):
with mock.patch(
"superset.views.core.get_cta_schema_name",
lambda d, u, s, sql: CTAS_SCHEMA_NAME,
):
main_db = get_example_database()
db_id = main_db.id
if main_db.backend == "sqlite":
# sqlite doesn't support schemas
return
tmp_table_name = f"sqllab_test_table_async_1_{ctas_method}"
quote = (
main_db.inspector.engine.dialect.identifier_preparer.quote_identifier
)
expected_full_table_name = f"{CTAS_SCHEMA_NAME}.{quote(tmp_table_name)}"
self.drop_table_if_exists(expected_full_table_name, main_db)
sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10"
result = self.run_sql(
db_id,
sql_where,
f"4_{ctas_method}",
async_=True,
tmp_table=tmp_table_name,
cta=True,
ctas_method=ctas_method,
)
db.session.close()
time.sleep(CELERY_SLEEP_TIME)
query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertTrue(f"FROM {expected_full_table_name}" in query.select_sql)
self.assertEqual(
f"CREATE TABLE {expected_full_table_name} AS \n"
"SELECT name FROM birth_names "
"WHERE name='James' "
"LIMIT 10",
query.executed_sql,
)
self.drop_table_if_exists(expected_full_table_name, get_example_database())
query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertIn(expected_full_table_name, query.select_sql)
self.assertEqual(
f"CREATE {ctas_method} {CTAS_SCHEMA_NAME}.{tmp_table_name} AS \n"
"SELECT name FROM birth_names "
"WHERE name='James' "
"LIMIT 10",
query.executed_sql,
)
self.drop_table_if_exists(expected_full_table_name, get_example_database())
def test_run_async_cta_query(self):
@parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW])
def test_run_async_cta_query(self, ctas_method):
main_db = get_example_database()
db_id = main_db.id
table_name = "tmp_async_4"
table_name = f"tmp_async_4_{ctas_method}"
self.drop_table_if_exists(table_name, main_db)
time.sleep(DROP_TABLE_SLEEP_TIME)
sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10"
result = self.run_sql(
db_id, sql_where, "cid4", async_=True, tmp_table="tmp_async_4", cta=True
db_id,
sql_where,
f"5_{ctas_method}",
async_=True,
tmp_table=table_name,
cta=True,
ctas_method=ctas_method,
)
db.session.close()
assert result["query"]["state"] in (
@ -282,10 +329,10 @@ class CeleryTestCase(SupersetTestCase):
query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertTrue(f"FROM {table_name}" in query.select_sql)
self.assertIn(table_name, query.select_sql)
self.assertEqual(
f"CREATE TABLE {table_name} AS \n"
f"CREATE {ctas_method} {table_name} AS \n"
"SELECT name FROM birth_names "
"WHERE name='James' "
"LIMIT 10",
@ -296,15 +343,22 @@ class CeleryTestCase(SupersetTestCase):
self.assertEqual(True, query.select_as_cta)
self.assertEqual(True, query.select_as_cta_used)
def test_run_async_cta_query_with_lower_limit(self):
@parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW])
def test_run_async_cta_query_with_lower_limit(self, ctas_method):
main_db = get_example_database()
db_id = main_db.id
tmp_table = "tmp_async_2"
tmp_table = f"tmp_async_2_{ctas_method}"
self.drop_table_if_exists(tmp_table, main_db)
sql_where = "SELECT name FROM birth_names LIMIT 1"
result = self.run_sql(
db_id, sql_where, "id1", async_=True, tmp_table=tmp_table, cta=True
db_id,
sql_where,
f"6_{ctas_method}",
async_=True,
tmp_table=tmp_table,
cta=True,
ctas_method=ctas_method,
)
db.session.close()
assert result["query"]["state"] in (
@ -318,9 +372,10 @@ class CeleryTestCase(SupersetTestCase):
query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertIn(f"FROM {tmp_table}", query.select_sql)
self.assertIn(tmp_table, query.select_sql)
self.assertEqual(
f"CREATE TABLE {tmp_table} AS \n" "SELECT name FROM birth_names LIMIT 1",
f"CREATE {ctas_method} {tmp_table} AS \n"
"SELECT name FROM birth_names LIMIT 1",
query.executed_sql,
)
self.assertEqual(sql_where, query.sql)

View File

@ -43,6 +43,7 @@ class DatabaseApiTests(SupersetTestCase):
expected_columns = [
"allow_csv_upload",
"allow_ctas",
"allow_cvas",
"allow_dml",
"allow_multi_schema_metadata_fetch",
"allow_run_async",

View File

@ -18,6 +18,7 @@
"""Unit tests for Sql Lab"""
import json
from datetime import datetime, timedelta
from parameterized import parameterized
from random import random
from unittest import mock
@ -29,6 +30,7 @@ from superset.connectors.sqla.models import SqlaTable
from superset.db_engine_specs import BaseEngineSpec
from superset.models.sql_lab import Query
from superset.result_set import SupersetResultSet
from superset.sql_parse import CtasMethod
from superset.utils.core import datetime_to_epoch, get_example_database
from .base_tests import SupersetTestCase
@ -67,38 +69,44 @@ class SqlLabTests(SupersetTestCase):
data = self.run_sql("SELECT * FROM unexistant_table", "2")
self.assertLess(0, len(data["error"]))
@mock.patch(
"superset.views.core.get_cta_schema_name",
lambda d, u, s, sql: f"{u.username}_database",
)
def test_sql_json_cta_dynamic_db(self):
@parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW])
def test_sql_json_cta_dynamic_db(self, ctas_method):
main_db = get_example_database()
if main_db.backend == "sqlite":
# sqlite doesn't support database creation
return
old_allow_ctas = main_db.allow_ctas
main_db.allow_ctas = True # enable cta
with mock.patch(
"superset.views.core.get_cta_schema_name",
lambda d, u, s, sql: f"{u.username}_database",
):
old_allow_ctas = main_db.allow_ctas
main_db.allow_ctas = True # enable cta
self.login("admin")
self.run_sql(
"SELECT * FROM birth_names",
"1",
database_name="examples",
tmp_table_name="test_target",
select_as_cta=True,
)
self.login("admin")
tmp_table_name = f"test_target_{ctas_method.lower()}"
self.run_sql(
"SELECT * FROM birth_names",
"1",
database_name="examples",
tmp_table_name=tmp_table_name,
select_as_cta=True,
ctas_method=ctas_method,
)
# assertions
data = db.session.execute("SELECT * FROM admin_database.test_target").fetchall()
self.assertEqual(
75691, len(data)
) # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True
# assertions
db.session.commit()
data = db.session.execute(
f"SELECT * FROM admin_database.{tmp_table_name}"
).fetchall()
self.assertEqual(
75691, len(data)
) # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True
# cleanup
db.session.execute("DROP TABLE admin_database.test_target")
main_db.allow_ctas = old_allow_ctas
db.session.commit()
# cleanup
db.session.execute(f"DROP {ctas_method} admin_database.{tmp_table_name}")
main_db.allow_ctas = old_allow_ctas
db.session.commit()
def test_multi_sql(self):
self.login("admin")