Implement create view as functionality (#9794)

Implement create view as button in sqllab

Make CVAS configurable

Co-authored-by: bogdan kyryliuk <bogdankyryliuk@dropbox.com>
This commit is contained in:
Bogdan 2020-06-24 09:50:41 -07:00 committed by GitHub
parent 38667b72b1
commit 3db76c6fdc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 304 additions and 137 deletions

View File

@ -32,7 +32,9 @@ assists people when migrating to a new version.
* [9786](https://github.com/apache/incubator-superset/pull/9786): with the upgrade of `werkzeug` from version `0.16.0` to `1.0.1`, the `werkzeug.contrib.cache` module has been moved to a standalone package [cachelib](https://pypi.org/project/cachelib/). For example, to import the `RedisCache` class, please use the following import: `from cachelib.redis import RedisCache`. * [9786](https://github.com/apache/incubator-superset/pull/9786): with the upgrade of `werkzeug` from version `0.16.0` to `1.0.1`, the `werkzeug.contrib.cache` module has been moved to a standalone package [cachelib](https://pypi.org/project/cachelib/). For example, to import the `RedisCache` class, please use the following import: `from cachelib.redis import RedisCache`.
* [9572](https://github.com/apache/incubator-superset/pull/9572): a change which by default means that the Jinja `current_user_id`, `current_username`, and `url_param` context calls no longer need to be wrapped via `cache_key_wrapper` in order to be included in the cache key. The `cache_key_wrapper` function should only be required for Jinja add-ons. * [9794](https://github.com/apache/incubator-superset/pull/9794): introduces `create view as` functionality in the sqllab. This change will require the `query` table migration and potential service downtime as that table has quite some traffic.
* [9572](https://github.com/apache/incubator-superset/pull/9572): a change which by default means that the Jinja `current_user_id`, `current_username`, and `url_param` context calls no longer need to be wrapped via `cache_key_wrapper` in order to be included in the cache key. The `cache_key_wrapper` function should only be required for Jinja add-ons.
* [8867](https://github.com/apache/incubator-superset/pull/8867): a change which adds the `tmp_schema_name` column to the `query` table which requires locking the table. Given the `query` table is heavily used performance may be degraded during the migration. Scheduled downtime may be advised. * [8867](https://github.com/apache/incubator-superset/pull/8867): a change which adds the `tmp_schema_name` column to the `query` table which requires locking the table. Given the `query` table is heavily used performance may be degraded during the migration. Scheduled downtime may be advised.

View File

@ -22,6 +22,7 @@ ipdb==0.12
isort==4.3.21 isort==4.3.21
mypy==0.770 mypy==0.770
nose==1.3.7 nose==1.3.7
parameterized==0.7.4
pip-tools==5.1.2 pip-tools==5.1.2
pre-commit==1.17.0 pre-commit==1.17.0
psycopg2-binary==2.8.5 psycopg2-binary==2.8.5

View File

@ -97,6 +97,11 @@ export const addSuccessToast = addSuccessToastAction;
export const addDangerToast = addDangerToastAction; export const addDangerToast = addDangerToastAction;
export const addWarningToast = addWarningToastAction; export const addWarningToast = addWarningToastAction;
// Methods for persisting SQL Lab query results: CREATE TABLE AS (CTAS)
// or CREATE VIEW AS (CVAS). String values match the backend
// CtasMethod enum defined in superset/sql_parse.py.
export const CtasEnum = {
  TABLE: 'TABLE',
  VIEW: 'VIEW',
};
// a map of SavedQuery field names to the different names used client-side, // a map of SavedQuery field names to the different names used client-side,
// because for now making the names consistent is too complicated // because for now making the names consistent is too complicated
// so it might as well only happen in one place // so it might as well only happen in one place
@ -346,6 +351,7 @@ export function runQuery(query) {
tab: query.tab, tab: query.tab,
tmp_table_name: query.tempTableName, tmp_table_name: query.tempTableName,
select_as_cta: query.ctas, select_as_cta: query.ctas,
ctas_method: query.ctas_method,
templateParams: query.templateParams, templateParams: query.templateParams,
queryLimit: query.queryLimit, queryLimit: query.queryLimit,
expand_data: true, expand_data: true,

View File

@ -30,6 +30,7 @@ import FilterableTable from '../../components/FilterableTable/FilterableTable';
import QueryStateLabel from './QueryStateLabel'; import QueryStateLabel from './QueryStateLabel';
import CopyToClipboard from '../../components/CopyToClipboard'; import CopyToClipboard from '../../components/CopyToClipboard';
import { prepareCopyToClipboardTabularData } from '../../utils/common'; import { prepareCopyToClipboardTabularData } from '../../utils/common';
import { CtasEnum } from '../actions/sqlLab';
const propTypes = { const propTypes = {
actions: PropTypes.object, actions: PropTypes.object,
@ -219,10 +220,14 @@ export default class ResultSet extends React.PureComponent {
tmpTable = query.results.query.tempTable; tmpTable = query.results.query.tempTable;
tmpSchema = query.results.query.tempSchema; tmpSchema = query.results.query.tempSchema;
} }
let object = 'Table';
if (query.ctas_method === CtasEnum.VIEW) {
object = 'View';
}
return ( return (
<div> <div>
<Alert bsStyle="info"> <Alert bsStyle="info">
{t('Table')} [ {t(object)} [
<strong> <strong>
{tmpSchema}.{tmpTable} {tmpSchema}.{tmpTable}
</strong> </strong>

View File

@ -54,6 +54,7 @@ import {
} from '../constants'; } from '../constants';
import RunQueryActionButton from './RunQueryActionButton'; import RunQueryActionButton from './RunQueryActionButton';
import { FeatureFlag, isFeatureEnabled } from '../../featureFlags'; import { FeatureFlag, isFeatureEnabled } from '../../featureFlags';
import { CtasEnum } from '../actions/sqlLab';
const SQL_EDITOR_PADDING = 10; const SQL_EDITOR_PADDING = 10;
const INITIAL_NORTH_PERCENT = 30; const INITIAL_NORTH_PERCENT = 30;
@ -284,7 +285,7 @@ class SqlEditor extends React.PureComponent {
this.startQuery(); this.startQuery();
} }
} }
startQuery(ctas = false) { startQuery(ctas = false, ctas_method = CtasEnum.TABLE) {
const qe = this.props.queryEditor; const qe = this.props.queryEditor;
const query = { const query = {
dbId: qe.dbId, dbId: qe.dbId,
@ -299,6 +300,7 @@ class SqlEditor extends React.PureComponent {
? this.props.database.allow_run_async ? this.props.database.allow_run_async
: false, : false,
ctas, ctas,
ctas_method,
updateTabState: !qe.selectedText, updateTabState: !qe.selectedText,
}; };
this.props.actions.runQuery(query); this.props.actions.runQuery(query);
@ -313,7 +315,10 @@ class SqlEditor extends React.PureComponent {
} }
} }
createTableAs() { createTableAs() {
this.startQuery(true); this.startQuery(true, CtasEnum.TABLE);
}
createViewAs() {
this.startQuery(true, CtasEnum.VIEW);
} }
ctasChanged(event) { ctasChanged(event) {
this.setState({ ctas: event.target.value }); this.setState({ ctas: event.target.value });
@ -372,8 +377,13 @@ class SqlEditor extends React.PureComponent {
} }
renderEditorBottomBar(hotkeys) { renderEditorBottomBar(hotkeys) {
let ctasControls; let ctasControls;
if (this.props.database && this.props.database.allow_ctas) { if (
this.props.database &&
(this.props.database.allow_ctas || this.props.database.allow_cvas)
) {
const ctasToolTip = t('Create table as with query results'); const ctasToolTip = t('Create table as with query results');
const cvasToolTip = t('Create view as with query results');
ctasControls = ( ctasControls = (
<FormGroup> <FormGroup>
<InputGroup> <InputGroup>
@ -385,6 +395,7 @@ class SqlEditor extends React.PureComponent {
onChange={this.ctasChanged.bind(this)} onChange={this.ctasChanged.bind(this)}
/> />
<InputGroup.Button> <InputGroup.Button>
{this.props.database.allow_ctas && (
<Button <Button
bsSize="small" bsSize="small"
disabled={this.state.ctas.length === 0} disabled={this.state.ctas.length === 0}
@ -393,6 +404,17 @@ class SqlEditor extends React.PureComponent {
> >
<i className="fa fa-table" /> CTAS <i className="fa fa-table" /> CTAS
</Button> </Button>
)}
{this.props.database.allow_cvas && (
<Button
bsSize="small"
disabled={this.state.ctas.length === 0}
onClick={this.createViewAs.bind(this)}
tooltip={cvasToolTip}
>
<i className="fa fa-table" /> CVAS
</Button>
)}
</InputGroup.Button> </InputGroup.Button>
</InputGroup> </InputGroup>
</FormGroup> </FormGroup>

View File

@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Add ctas_method to the Query object
Revision ID: ea396d202291
Revises: e557699a813e
Create Date: 2020-05-12 12:59:26.583276
"""
# revision identifiers, used by Alembic.
revision = "ea396d202291"
down_revision = "e557699a813e"
import sqlalchemy as sa
from alembic import op
def upgrade():
    """Add CTAS-method tracking and the per-database CVAS permission flag."""
    # Record which "create as" method (TABLE or VIEW) a query used.
    # Nullable so existing rows need no backfill during the migration.
    op.add_column(
        "query", sa.Column("ctas_method", sa.String(length=16), nullable=True)
    )
    # Per-database toggle that enables "CREATE VIEW AS" in SQL Lab.
    op.add_column("dbs", sa.Column("allow_cvas", sa.Boolean(), nullable=True))
def downgrade():
    """Drop the ctas_method and allow_cvas columns added by this revision."""
    op.drop_column("query", "ctas_method")
    op.drop_column("dbs", "allow_cvas")

View File

@ -118,6 +118,7 @@ class Database(
allow_run_async = Column(Boolean, default=False) allow_run_async = Column(Boolean, default=False)
allow_csv_upload = Column(Boolean, default=False) allow_csv_upload = Column(Boolean, default=False)
allow_ctas = Column(Boolean, default=False) allow_ctas = Column(Boolean, default=False)
allow_cvas = Column(Boolean, default=False)
allow_dml = Column(Boolean, default=False) allow_dml = Column(Boolean, default=False)
force_ctas_schema = Column(String(250)) force_ctas_schema = Column(String(250))
allow_multi_schema_metadata_fetch = Column( # pylint: disable=invalid-name allow_multi_schema_metadata_fetch = Column( # pylint: disable=invalid-name
@ -147,6 +148,7 @@ class Database(
"expose_in_sqllab", "expose_in_sqllab",
"allow_run_async", "allow_run_async",
"allow_ctas", "allow_ctas",
"allow_cvas",
"allow_csv_upload", "allow_csv_upload",
"extra", "extra",
] ]

View File

@ -19,7 +19,6 @@ import re
from datetime import datetime from datetime import datetime
from typing import Any, Dict from typing import Any, Dict
# pylint: disable=ungrouped-imports
import simplejson as json import simplejson as json
import sqlalchemy as sqla import sqlalchemy as sqla
from flask import Markup from flask import Markup
@ -40,6 +39,7 @@ from sqlalchemy.orm import backref, relationship
from superset import security_manager from superset import security_manager
from superset.models.helpers import AuditMixinNullable, ExtraJSONMixin from superset.models.helpers import AuditMixinNullable, ExtraJSONMixin
from superset.models.tags import QueryUpdater from superset.models.tags import QueryUpdater
from superset.sql_parse import CtasMethod
from superset.utils.core import QueryStatus, user_label from superset.utils.core import QueryStatus, user_label
@ -72,6 +72,7 @@ class Query(Model, ExtraJSONMixin):
limit = Column(Integer) limit = Column(Integer)
select_as_cta = Column(Boolean) select_as_cta = Column(Boolean)
select_as_cta_used = Column(Boolean, default=False) select_as_cta_used = Column(Boolean, default=False)
ctas_method = Column(String(16), default=CtasMethod.TABLE)
progress = Column(Integer, default=0) # 1..100 progress = Column(Integer, default=0) # 1..100
# # of rows in the result set or rows modified. # # of rows in the result set or rows modified.

View File

@ -223,7 +223,9 @@ def execute_sql_statement(
query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S") query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S")
) )
sql = parsed_query.as_create_table( sql = parsed_query.as_create_table(
query.tmp_table_name, schema_name=query.tmp_schema_name query.tmp_table_name,
schema_name=query.tmp_schema_name,
method=query.ctas_method,
) )
query.select_as_cta_used = True query.select_as_cta_used = True

View File

@ -16,6 +16,7 @@
# under the License. # under the License.
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Set from typing import List, Optional, Set
from urllib import parse from urllib import parse
@ -31,6 +32,11 @@ CTE_PREFIX = "CTE__"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CtasMethod(str, Enum):
    """Supported "create as" methods for persisting query results.

    Subclasses ``str`` so members interpolate as plain strings, e.g. in the
    f-strings that build ``CREATE {method} ... AS`` DDL statements.
    """

    TABLE = "TABLE"
    VIEW = "VIEW"
def _extract_limit_from_query(statement: TokenList) -> Optional[int]: def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
""" """
Extract limit clause from SQL statement. Extract limit clause from SQL statement.
@ -185,6 +191,7 @@ class ParsedQuery:
table_name: str, table_name: str,
schema_name: Optional[str] = None, schema_name: Optional[str] = None,
overwrite: bool = False, overwrite: bool = False,
method: CtasMethod = CtasMethod.TABLE,
) -> str: ) -> str:
"""Reformats the query into the create table as query. """Reformats the query into the create table as query.
@ -193,6 +200,7 @@ class ParsedQuery:
:param table_name: table that will contain the results of the query execution :param table_name: table that will contain the results of the query execution
:param schema_name: schema name for the target table :param schema_name: schema name for the target table
:param overwrite: table_name will be dropped if true :param overwrite: table_name will be dropped if true
:param method: method for the CTA query, currently view or table creation
:return: Create table as query :return: Create table as query
""" """
exec_sql = "" exec_sql = ""
@ -200,8 +208,8 @@ class ParsedQuery:
# TODO(bkyryliuk): quote full_table_name # TODO(bkyryliuk): quote full_table_name
full_table_name = f"{schema_name}.{table_name}" if schema_name else table_name full_table_name = f"{schema_name}.{table_name}" if schema_name else table_name
if overwrite: if overwrite:
exec_sql = f"DROP TABLE IF EXISTS {full_table_name};\n" exec_sql = f"DROP {method} IF EXISTS {full_table_name};\n"
exec_sql += f"CREATE TABLE {full_table_name} AS \n{sql}" exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}"
return exec_sql return exec_sql
def _extract_from_token( # pylint: disable=too-many-branches def _extract_from_token( # pylint: disable=too-many-branches

View File

@ -82,7 +82,7 @@ from superset.security.analytics_db_safety import (
check_sqlalchemy_uri, check_sqlalchemy_uri,
DBSecurityException, DBSecurityException,
) )
from superset.sql_parse import ParsedQuery, Table from superset.sql_parse import CtasMethod, ParsedQuery, Table
from superset.sql_validators import get_validator_by_name from superset.sql_validators import get_validator_by_name
from superset.typing import FlaskResponse from superset.typing import FlaskResponse
from superset.utils import core as utils, dashboard_import_export from superset.utils import core as utils, dashboard_import_export
@ -133,6 +133,7 @@ logger = logging.getLogger(__name__)
DATABASE_KEYS = [ DATABASE_KEYS = [
"allow_csv_upload", "allow_csv_upload",
"allow_ctas", "allow_ctas",
"allow_cvas",
"allow_dml", "allow_dml",
"allow_multi_schema_metadata_fetch", "allow_multi_schema_metadata_fetch",
"allow_run_async", "allow_run_async",
@ -2239,6 +2240,9 @@ class Superset(BaseSupersetView):
) )
limit = 0 limit = 0
select_as_cta: bool = cast(bool, query_params.get("select_as_cta")) select_as_cta: bool = cast(bool, query_params.get("select_as_cta"))
ctas_method: CtasMethod = cast(
CtasMethod, query_params.get("ctas_method", CtasMethod.TABLE)
)
tmp_table_name: str = cast(str, query_params.get("tmp_table_name")) tmp_table_name: str = cast(str, query_params.get("tmp_table_name"))
client_id: str = cast( client_id: str = cast(
str, query_params.get("client_id") or utils.shortid()[:10] str, query_params.get("client_id") or utils.shortid()[:10]
@ -2267,6 +2271,7 @@ class Superset(BaseSupersetView):
sql=sql, sql=sql,
schema=schema, schema=schema,
select_as_cta=select_as_cta, select_as_cta=select_as_cta,
ctas_method=ctas_method,
start_time=now_as_float(), start_time=now_as_float(),
tab_name=tab_name, tab_name=tab_name,
status=status, status=status,

View File

@ -128,6 +128,7 @@ class DatabaseRestApi(DatabaseMixin, BaseSupersetModelRestApi):
"database_name", "database_name",
"expose_in_sqllab", "expose_in_sqllab",
"allow_ctas", "allow_ctas",
"allow_cvas",
"force_ctas_schema", "force_ctas_schema",
"allow_run_async", "allow_run_async",
"allow_dml", "allow_dml",

View File

@ -60,6 +60,7 @@ class DatabaseMixin:
"allow_run_async", "allow_run_async",
"allow_csv_upload", "allow_csv_upload",
"allow_ctas", "allow_ctas",
"allow_cvas",
"allow_dml", "allow_dml",
"force_ctas_schema", "force_ctas_schema",
"impersonate_user", "impersonate_user",
@ -111,6 +112,7 @@ class DatabaseMixin:
"for more information." "for more information."
), ),
"allow_ctas": _("Allow CREATE TABLE AS option in SQL Lab"), "allow_ctas": _("Allow CREATE TABLE AS option in SQL Lab"),
"allow_cvas": _("Allow CREATE VIEW AS option in SQL Lab"),
"allow_dml": _( "allow_dml": _(
"Allow users to run non-SELECT statements " "Allow users to run non-SELECT statements "
"(UPDATE, DELETE, CREATE, ...) " "(UPDATE, DELETE, CREATE, ...) "
@ -182,6 +184,7 @@ class DatabaseMixin:
label_columns = { label_columns = {
"expose_in_sqllab": _("Expose in SQL Lab"), "expose_in_sqllab": _("Expose in SQL Lab"),
"allow_ctas": _("Allow CREATE TABLE AS"), "allow_ctas": _("Allow CREATE TABLE AS"),
"allow_cvas": _("Allow CREATE VIEW AS"),
"allow_dml": _("Allow DML"), "allow_dml": _("Allow DML"),
"force_ctas_schema": _("CTAS Schema"), "force_ctas_schema": _("CTAS Schema"),
"database_name": _("Database"), "database_name": _("Database"),

View File

@ -27,6 +27,7 @@ from flask_appbuilder.security.sqla import models as ab_models
from flask_testing import TestCase from flask_testing import TestCase
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from superset.sql_parse import CtasMethod
from tests.test_app import app # isort:skip from tests.test_app import app # isort:skip
from superset import db, security_manager from superset import db, security_manager
from superset.connectors.base.models import BaseDatasource from superset.connectors.base.models import BaseDatasource
@ -259,6 +260,7 @@ class SupersetTestCase(TestCase):
select_as_cta=False, select_as_cta=False,
tmp_table_name=None, tmp_table_name=None,
schema=None, schema=None,
ctas_method=CtasMethod.TABLE,
): ):
if user_name: if user_name:
self.logout() self.logout()
@ -270,6 +272,7 @@ class SupersetTestCase(TestCase):
"client_id": client_id, "client_id": client_id,
"queryLimit": query_limit, "queryLimit": query_limit,
"sql_editor_id": sql_editor_id, "sql_editor_id": sql_editor_id,
"ctas_method": ctas_method,
} }
if tmp_table_name: if tmp_table_name:
json_payload["tmp_table_name"] = tmp_table_name json_payload["tmp_table_name"] = tmp_table_name

View File

@ -17,20 +17,15 @@
# isort:skip_file # isort:skip_file
"""Unit tests for Superset Celery worker""" """Unit tests for Superset Celery worker"""
import datetime import datetime
import io
import json import json
import logging from parameterized import parameterized
import subprocess import subprocess
import time import time
import unittest import unittest
import unittest.mock as mock import unittest.mock as mock
import flask import flask
import sqlalchemy
from contextlib2 import contextmanager
from flask import current_app from flask import current_app
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
from tests.test_app import app from tests.test_app import app
from superset import db, sql_lab from superset import db, sql_lab
@ -39,7 +34,7 @@ from superset.db_engine_specs.base import BaseEngineSpec
from superset.extensions import celery_app from superset.extensions import celery_app
from superset.models.helpers import QueryStatus from superset.models.helpers import QueryStatus
from superset.models.sql_lab import Query from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery from superset.sql_parse import ParsedQuery, CtasMethod
from superset.utils.core import get_example_database from superset.utils.core import get_example_database
from .base_tests import SupersetTestCase from .base_tests import SupersetTestCase
@ -120,7 +115,14 @@ class CeleryTestCase(SupersetTestCase):
db.session.commit() db.session.commit()
def run_sql( def run_sql(
self, db_id, sql, client_id=None, cta=False, tmp_table="tmp", async_=False self,
db_id,
sql,
client_id=None,
cta=False,
tmp_table="tmp",
async_=False,
ctas_method=CtasMethod.TABLE,
): ):
self.login() self.login()
resp = self.client.post( resp = self.client.post(
@ -132,34 +134,55 @@ class CeleryTestCase(SupersetTestCase):
select_as_cta=cta, select_as_cta=cta,
tmp_table_name=tmp_table, tmp_table_name=tmp_table,
client_id=client_id, client_id=client_id,
ctas_method=ctas_method,
), ),
) )
self.logout() self.logout()
return json.loads(resp.data) return json.loads(resp.data)
def test_run_sync_query_dont_exist(self): @parameterized.expand(
[CtasMethod.TABLE,]
)
def test_run_sync_query_dont_exist(self, ctas_method):
main_db = get_example_database() main_db = get_example_database()
db_id = main_db.id db_id = main_db.id
sql_dont_exist = "SELECT name FROM table_dont_exist" sql_dont_exist = "SELECT name FROM table_dont_exist"
result1 = self.run_sql(db_id, sql_dont_exist, "1", cta=True) result = self.run_sql(
self.assertTrue("error" in result1) db_id, sql_dont_exist, f"1_{ctas_method}", cta=True, ctas_method=ctas_method
)
if (
get_example_database().backend != "sqlite"
and ctas_method == CtasMethod.VIEW
):
self.assertEqual(QueryStatus.SUCCESS, result["status"], msg=result)
else:
self.assertEqual(QueryStatus.FAILED, result["status"], msg=result)
def test_run_sync_query_cta(self): @parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW])
def test_run_sync_query_cta(self, ctas_method):
main_db = get_example_database() main_db = get_example_database()
db_id = main_db.id db_id = main_db.id
tmp_table_name = "tmp_async_22" tmp_table_name = f"tmp_sync_23_{ctas_method.lower()}"
self.drop_table_if_exists(tmp_table_name, main_db) self.drop_table_if_exists(tmp_table_name, main_db)
name = "James" name = "James"
sql_where = f"SELECT name FROM birth_names WHERE name='{name}' LIMIT 1" sql_where = f"SELECT name FROM birth_names WHERE name='{name}' LIMIT 1"
result = self.run_sql(db_id, sql_where, "2", tmp_table=tmp_table_name, cta=True) result = self.run_sql(
self.assertEqual(QueryStatus.SUCCESS, result["query"]["state"]) db_id,
sql_where,
f"2_{ctas_method}",
tmp_table=tmp_table_name,
cta=True,
ctas_method=ctas_method,
)
# provide better error message
self.assertEqual(QueryStatus.SUCCESS, result["query"]["state"], msg=result)
self.assertEqual([], result["data"]) self.assertEqual([], result["data"])
self.assertEqual([], result["columns"]) self.assertEqual([], result["columns"])
query2 = self.get_query_by_id(result["query"]["serverId"]) query2 = self.get_query_by_id(result["query"]["serverId"])
# Check the data in the tmp table. # Check the data in the tmp table.
results = self.run_sql(db_id, query2.select_sql, "sdf2134") results = self.run_sql(db_id, query2.select_sql, f"7_{ctas_method}")
self.assertEqual(results["status"], "success") self.assertEqual(QueryStatus.SUCCESS, results["status"], msg=results)
self.assertGreater(len(results["data"]), 0) self.assertGreater(len(results["data"]), 0)
# cleanup tmp table # cleanup tmp table
@ -186,71 +209,88 @@ class CeleryTestCase(SupersetTestCase):
db.session.flush() db.session.flush()
return self.run_sql(db_id, sql) return self.run_sql(db_id, sql)
@mock.patch( @parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW])
"superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME def test_run_sync_query_cta_config(self, ctas_method):
) with mock.patch(
def test_run_sync_query_cta_config(self): "superset.views.core.get_cta_schema_name",
lambda d, u, s, sql: CTAS_SCHEMA_NAME,
):
main_db = get_example_database() main_db = get_example_database()
db_id = main_db.id db_id = main_db.id
if main_db.backend == "sqlite": if main_db.backend == "sqlite":
# sqlite doesn't support schemas # sqlite doesn't support schemas
return return
tmp_table_name = "tmp_async_22" tmp_table_name = f"tmp_async_22_{ctas_method.lower()}"
expected_full_table_name = f"{CTAS_SCHEMA_NAME}.{tmp_table_name}" quote = (
main_db.inspector.engine.dialect.identifier_preparer.quote_identifier
)
expected_full_table_name = f"{CTAS_SCHEMA_NAME}.{quote(tmp_table_name)}"
self.drop_table_if_exists(expected_full_table_name, main_db) self.drop_table_if_exists(expected_full_table_name, main_db)
name = "James" name = "James"
sql_where = f"SELECT name FROM birth_names WHERE name='{name}'" sql_where = f"SELECT name FROM birth_names WHERE name='{name}'"
result = self.run_sql( result = self.run_sql(
db_id, sql_where, "cid2", tmp_table=tmp_table_name, cta=True db_id,
sql_where,
f"3_{ctas_method}",
tmp_table=tmp_table_name,
cta=True,
ctas_method=ctas_method,
) )
self.assertEqual(QueryStatus.SUCCESS, result["query"]["state"]) self.assertEqual(QueryStatus.SUCCESS, result["query"]["state"], msg=result)
self.assertEqual([], result["data"]) self.assertEqual([], result["data"])
self.assertEqual([], result["columns"]) self.assertEqual([], result["columns"])
query = self.get_query_by_id(result["query"]["serverId"]) query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual( self.assertEqual(
f"CREATE TABLE {expected_full_table_name} AS \n" f"CREATE {ctas_method} {CTAS_SCHEMA_NAME}.{tmp_table_name} AS \n"
"SELECT name FROM birth_names " "SELECT name FROM birth_names "
"WHERE name='James'", "WHERE name='James'",
query.executed_sql, query.executed_sql,
) )
self.assertEqual( self.assertEqual(
"SELECT *\n" f"FROM {expected_full_table_name}", query.select_sql "SELECT *\n" f"FROM {CTAS_SCHEMA_NAME}.{tmp_table_name}",
query.select_sql,
) )
time.sleep(CELERY_SHORT_SLEEP_TIME) time.sleep(CELERY_SHORT_SLEEP_TIME)
results = self.run_sql(db_id, query.select_sql) results = self.run_sql(db_id, query.select_sql)
self.assertEqual(results["status"], "success") self.assertEqual(QueryStatus.SUCCESS, results["status"], msg=result)
self.drop_table_if_exists(expected_full_table_name, get_example_database()) self.drop_table_if_exists(expected_full_table_name, get_example_database())
@mock.patch( @parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW])
"superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME def test_run_async_query_cta_config(self, ctas_method):
) with mock.patch(
def test_run_async_query_cta_config(self): "superset.views.core.get_cta_schema_name",
lambda d, u, s, sql: CTAS_SCHEMA_NAME,
):
main_db = get_example_database() main_db = get_example_database()
db_id = main_db.id db_id = main_db.id
if main_db.backend == "sqlite": if main_db.backend == "sqlite":
# sqlite doesn't support schemas # sqlite doesn't support schemas
return return
tmp_table_name = "sqllab_test_table_async_1" tmp_table_name = f"sqllab_test_table_async_1_{ctas_method}"
expected_full_table_name = f"{CTAS_SCHEMA_NAME}.{tmp_table_name}" quote = (
main_db.inspector.engine.dialect.identifier_preparer.quote_identifier
)
expected_full_table_name = f"{CTAS_SCHEMA_NAME}.{quote(tmp_table_name)}"
self.drop_table_if_exists(expected_full_table_name, main_db) self.drop_table_if_exists(expected_full_table_name, main_db)
sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10" sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10"
result = self.run_sql( result = self.run_sql(
db_id, db_id,
sql_where, sql_where,
"cid3", f"4_{ctas_method}",
async_=True, async_=True,
tmp_table="sqllab_test_table_async_1", tmp_table=tmp_table_name,
cta=True, cta=True,
ctas_method=ctas_method,
) )
db.session.close() db.session.close()
time.sleep(CELERY_SLEEP_TIME) time.sleep(CELERY_SLEEP_TIME)
query = self.get_query_by_id(result["query"]["serverId"]) query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query.status) self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertTrue(f"FROM {expected_full_table_name}" in query.select_sql) self.assertIn(expected_full_table_name, query.select_sql)
self.assertEqual( self.assertEqual(
f"CREATE TABLE {expected_full_table_name} AS \n" f"CREATE {ctas_method} {CTAS_SCHEMA_NAME}.{tmp_table_name} AS \n"
"SELECT name FROM birth_names " "SELECT name FROM birth_names "
"WHERE name='James' " "WHERE name='James' "
"LIMIT 10", "LIMIT 10",
@ -258,17 +298,24 @@ class CeleryTestCase(SupersetTestCase):
) )
self.drop_table_if_exists(expected_full_table_name, get_example_database()) self.drop_table_if_exists(expected_full_table_name, get_example_database())
def test_run_async_cta_query(self): @parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW])
def test_run_async_cta_query(self, ctas_method):
main_db = get_example_database() main_db = get_example_database()
db_id = main_db.id db_id = main_db.id
table_name = "tmp_async_4" table_name = f"tmp_async_4_{ctas_method}"
self.drop_table_if_exists(table_name, main_db) self.drop_table_if_exists(table_name, main_db)
time.sleep(DROP_TABLE_SLEEP_TIME) time.sleep(DROP_TABLE_SLEEP_TIME)
sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10" sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10"
result = self.run_sql( result = self.run_sql(
db_id, sql_where, "cid4", async_=True, tmp_table="tmp_async_4", cta=True db_id,
sql_where,
f"5_{ctas_method}",
async_=True,
tmp_table=table_name,
cta=True,
ctas_method=ctas_method,
) )
db.session.close() db.session.close()
assert result["query"]["state"] in ( assert result["query"]["state"] in (
@ -282,10 +329,10 @@ class CeleryTestCase(SupersetTestCase):
query = self.get_query_by_id(result["query"]["serverId"]) query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query.status) self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertTrue(f"FROM {table_name}" in query.select_sql) self.assertIn(table_name, query.select_sql)
self.assertEqual( self.assertEqual(
f"CREATE TABLE {table_name} AS \n" f"CREATE {ctas_method} {table_name} AS \n"
"SELECT name FROM birth_names " "SELECT name FROM birth_names "
"WHERE name='James' " "WHERE name='James' "
"LIMIT 10", "LIMIT 10",
@ -296,15 +343,22 @@ class CeleryTestCase(SupersetTestCase):
self.assertEqual(True, query.select_as_cta) self.assertEqual(True, query.select_as_cta)
self.assertEqual(True, query.select_as_cta_used) self.assertEqual(True, query.select_as_cta_used)
def test_run_async_cta_query_with_lower_limit(self): @parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW])
def test_run_async_cta_query_with_lower_limit(self, ctas_method):
main_db = get_example_database() main_db = get_example_database()
db_id = main_db.id db_id = main_db.id
tmp_table = "tmp_async_2" tmp_table = f"tmp_async_2_{ctas_method}"
self.drop_table_if_exists(tmp_table, main_db) self.drop_table_if_exists(tmp_table, main_db)
sql_where = "SELECT name FROM birth_names LIMIT 1" sql_where = "SELECT name FROM birth_names LIMIT 1"
result = self.run_sql( result = self.run_sql(
db_id, sql_where, "id1", async_=True, tmp_table=tmp_table, cta=True db_id,
sql_where,
f"6_{ctas_method}",
async_=True,
tmp_table=tmp_table,
cta=True,
ctas_method=ctas_method,
) )
db.session.close() db.session.close()
assert result["query"]["state"] in ( assert result["query"]["state"] in (
@ -318,9 +372,10 @@ class CeleryTestCase(SupersetTestCase):
query = self.get_query_by_id(result["query"]["serverId"]) query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query.status) self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertIn(f"FROM {tmp_table}", query.select_sql) self.assertIn(tmp_table, query.select_sql)
self.assertEqual( self.assertEqual(
f"CREATE TABLE {tmp_table} AS \n" "SELECT name FROM birth_names LIMIT 1", f"CREATE {ctas_method} {tmp_table} AS \n"
"SELECT name FROM birth_names LIMIT 1",
query.executed_sql, query.executed_sql,
) )
self.assertEqual(sql_where, query.sql) self.assertEqual(sql_where, query.sql)

View File

@ -43,6 +43,7 @@ class DatabaseApiTests(SupersetTestCase):
expected_columns = [ expected_columns = [
"allow_csv_upload", "allow_csv_upload",
"allow_ctas", "allow_ctas",
"allow_cvas",
"allow_dml", "allow_dml",
"allow_multi_schema_metadata_fetch", "allow_multi_schema_metadata_fetch",
"allow_run_async", "allow_run_async",

View File

@ -18,6 +18,7 @@
"""Unit tests for Sql Lab""" """Unit tests for Sql Lab"""
import json import json
from datetime import datetime, timedelta from datetime import datetime, timedelta
from parameterized import parameterized
from random import random from random import random
from unittest import mock from unittest import mock
@ -29,6 +30,7 @@ from superset.connectors.sqla.models import SqlaTable
from superset.db_engine_specs import BaseEngineSpec from superset.db_engine_specs import BaseEngineSpec
from superset.models.sql_lab import Query from superset.models.sql_lab import Query
from superset.result_set import SupersetResultSet from superset.result_set import SupersetResultSet
from superset.sql_parse import CtasMethod
from superset.utils.core import datetime_to_epoch, get_example_database from superset.utils.core import datetime_to_epoch, get_example_database
from .base_tests import SupersetTestCase from .base_tests import SupersetTestCase
@ -67,36 +69,42 @@ class SqlLabTests(SupersetTestCase):
data = self.run_sql("SELECT * FROM unexistant_table", "2") data = self.run_sql("SELECT * FROM unexistant_table", "2")
self.assertLess(0, len(data["error"])) self.assertLess(0, len(data["error"]))
@mock.patch( @parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW])
"superset.views.core.get_cta_schema_name", def test_sql_json_cta_dynamic_db(self, ctas_method):
lambda d, u, s, sql: f"{u.username}_database",
)
def test_sql_json_cta_dynamic_db(self):
main_db = get_example_database() main_db = get_example_database()
if main_db.backend == "sqlite": if main_db.backend == "sqlite":
# sqlite doesn't support database creation # sqlite doesn't support database creation
return return
with mock.patch(
"superset.views.core.get_cta_schema_name",
lambda d, u, s, sql: f"{u.username}_database",
):
old_allow_ctas = main_db.allow_ctas old_allow_ctas = main_db.allow_ctas
main_db.allow_ctas = True # enable cta main_db.allow_ctas = True # enable cta
self.login("admin") self.login("admin")
tmp_table_name = f"test_target_{ctas_method.lower()}"
self.run_sql( self.run_sql(
"SELECT * FROM birth_names", "SELECT * FROM birth_names",
"1", "1",
database_name="examples", database_name="examples",
tmp_table_name="test_target", tmp_table_name=tmp_table_name,
select_as_cta=True, select_as_cta=True,
ctas_method=ctas_method,
) )
# assertions # assertions
data = db.session.execute("SELECT * FROM admin_database.test_target").fetchall() db.session.commit()
data = db.session.execute(
f"SELECT * FROM admin_database.{tmp_table_name}"
).fetchall()
self.assertEqual( self.assertEqual(
75691, len(data) 75691, len(data)
) # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True ) # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True
# cleanup # cleanup
db.session.execute("DROP TABLE admin_database.test_target") db.session.execute(f"DROP {ctas_method} admin_database.{tmp_table_name}")
main_db.allow_ctas = old_allow_ctas main_db.allow_ctas = old_allow_ctas
db.session.commit() db.session.commit()