Prevent 'main' database connection creation (#8038)

* prevent 'main' database connection creation

* fix tests

* removing get_main_database

* Kill get_main_database

* Point to examples tables
Maxime Beauchemin 2019-09-08 10:18:09 -07:00 committed by GitHub
parent 9d350aadf0
commit 68c4c3a0b9
13 changed files with 186 additions and 240 deletions

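In short, this PR deletes the helpers that lazily created a "main" database record (get_or_create_main_db, get_main_database) and repoints the CLI and test suite at the examples database and its seeded tables (e.g. birth_names). A minimal sketch of the swap, using names taken from the hunks below (it assumes a running Superset app context, so it is illustrative rather than standalone):

# was: from superset.utils.core import get_main_database
from superset.utils.core import get_example_database

examples_db = get_example_database()   # fetches/registers the "examples" database
db_id = examples_db.id                 # was: get_main_database().id, which created "main" on the fly
sql = "SELECT name FROM birth_names LIMIT 10"  # seeded examples table, replacing ab_user/ab_role/ab_permission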
View File

@ -48,7 +48,6 @@ def make_shell_context():
@app.cli.command()
def init():
"""Inits the Superset application"""
utils.get_or_create_main_db()
utils.get_example_database()
appbuilder.add_permissions(update_perms=True)
security_manager.sync_role_definitions()
@ -430,75 +429,40 @@ def load_test_users_run():
Syncs permissions for those users/roles
"""
if config.get("TESTING"):
security_manager.sync_role_definitions()
gamma_sqllab_role = security_manager.add_role("gamma_sqllab")
for perm in security_manager.find_role("Gamma").permissions:
security_manager.add_permission_role(gamma_sqllab_role, perm)
utils.get_or_create_main_db()
db_perm = utils.get_main_database().perm
security_manager.add_permission_view_menu("database_access", db_perm)
db_pvm = security_manager.find_permission_view_menu(
view_menu_name=db_perm, permission_name="database_access"
sm = security_manager
examples_db = utils.get_example_database()
examples_pv = sm.add_permission_view_menu("database_access", examples_db.perm)
sm.sync_role_definitions()
gamma_sqllab_role = sm.add_role("gamma_sqllab")
sm.add_permission_role(gamma_sqllab_role, examples_pv)
for role in ["Gamma", "sql_lab"]:
for perm in sm.find_role(role).permissions:
sm.add_permission_role(gamma_sqllab_role, perm)
users = (
("admin", "Admin"),
("gamma", "Gamma"),
("gamma2", "Gamma"),
("gamma_sqllab", "gamma_sqllab"),
("alpha", "Alpha"),
)
gamma_sqllab_role.permissions.append(db_pvm)
for perm in security_manager.find_role("sql_lab").permissions:
security_manager.add_permission_role(gamma_sqllab_role, perm)
admin = security_manager.find_user("admin")
if not admin:
security_manager.add_user(
"admin",
"admin",
" user",
"admin@fab.org",
security_manager.find_role("Admin"),
password="general",
)
gamma = security_manager.find_user("gamma")
if not gamma:
security_manager.add_user(
"gamma",
"gamma",
"user",
"gamma@fab.org",
security_manager.find_role("Gamma"),
password="general",
)
gamma2 = security_manager.find_user("gamma2")
if not gamma2:
security_manager.add_user(
"gamma2",
"gamma2",
"user",
"gamma2@fab.org",
security_manager.find_role("Gamma"),
password="general",
)
gamma_sqllab_user = security_manager.find_user("gamma_sqllab")
if not gamma_sqllab_user:
security_manager.add_user(
"gamma_sqllab",
"gamma_sqllab",
"user",
"gamma_sqllab@fab.org",
gamma_sqllab_role,
password="general",
)
alpha = security_manager.find_user("alpha")
if not alpha:
security_manager.add_user(
"alpha",
"alpha",
"user",
"alpha@fab.org",
security_manager.find_role("Alpha"),
password="general",
)
security_manager.get_session.commit()
for username, role in users:
user = sm.find_user(username)
if not user:
sm.add_user(
username,
username,
"user",
username + "@fab.org",
sm.find_role(role),
password="general",
)
sm.get_session.commit()
@app.cli.command()

View File

@ -200,7 +200,6 @@ class SupersetSecurityManager(SecurityManager):
:param database: The Superset database
:returns: Whether the user can access the Superset database
"""
return (
self.all_datasource_access()
or self.all_database_access()
@ -269,9 +268,9 @@ class SupersetSecurityManager(SecurityManager):
:param tables: The list of denied SQL table names
:returns: The error message
"""
return f"""You need access to the following tables: {", ".join(tables)}, all
database access or `all_datasource_access` permission"""
quoted_tables = [f"`{t}`" for t in tables]
return f"""You need access to the following tables: {", ".join(quoted_tables)},
`all_database_access` or `all_datasource_access` permission"""
def get_table_access_link(self, tables: List[str]) -> Optional[str]:
"""

View File

@ -936,10 +936,6 @@ def user_label(user: User) -> Optional[str]:
return None
def get_or_create_main_db():
get_main_database()
def get_or_create_db(database_name, sqlalchemy_uri, *args, **kwargs):
from superset import db
from superset.models import core as models
@ -957,12 +953,6 @@ def get_or_create_db(database_name, sqlalchemy_uri, *args, **kwargs):
return database
def get_main_database():
from superset import conf
return get_or_create_db("main", conf.get("SQLALCHEMY_DATABASE_URI"))
def get_example_database():
from superset import conf

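With get_or_create_main_db() and get_main_database() removed, get_example_database() becomes the single entry point for a default database reference. The hunk above truncates its body; a rough sketch of the surviving helper, where the SQLALCHEMY_EXAMPLES_URI fallback is an assumption rather than something shown in this diff:

def get_example_database():
    from superset import conf

    # assumed: prefer a dedicated examples URI, fall back to the metadata URI,
    # then register/fetch the "examples" record via get_or_create_db()
    db_uri = conf.get("SQLALCHEMY_EXAMPLES_URI") or conf.get("SQLALCHEMY_DATABASE_URI")
    return get_or_create_db("examples", db_uri)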
View File

@ -2705,11 +2705,7 @@ class Superset(BaseSupersetView):
query.sql, query.database, query.schema
)
if rejected_tables:
flash(
security_manager.get_table_access_error_msg(
"{}".format(rejected_tables)
)
)
flash(security_manager.get_table_access_error_msg(rejected_tables))
return redirect("/")
blob = None
if results_backend and query.results_key:

View File

@ -28,7 +28,7 @@ from superset.connectors.druid.models import DruidCluster, DruidDatasource
from superset.connectors.sqla.models import SqlaTable
from superset.models import core as models
from superset.models.core import Database
from superset.utils.core import get_main_database
from superset.utils.core import get_example_database
BASE_DIR = app.config.get("BASE_DIR")
@ -168,6 +168,12 @@ class SupersetTestCase(unittest.TestCase):
):
security_manager.del_permission_role(public_role, perm)
def _get_database_by_name(self, database_name="main"):
if database_name == "examples":
return get_example_database()
else:
raise ValueError("Database doesn't exist")
def run_sql(
self,
sql,
@ -175,11 +181,12 @@ class SupersetTestCase(unittest.TestCase):
user_name=None,
raise_on_error=False,
query_limit=None,
database_name="examples",
):
if user_name:
self.logout()
self.login(username=(user_name if user_name else "admin"))
dbid = get_main_database().id
self.login(username=(user_name or "admin"))
dbid = self._get_database_by_name(database_name).id
resp = self.get_json_resp(
"/superset/sql_json/",
raise_on_error=False,
@ -195,11 +202,35 @@ class SupersetTestCase(unittest.TestCase):
raise Exception("run_sql failed")
return resp
def validate_sql(self, sql, client_id=None, user_name=None, raise_on_error=False):
def create_fake_db(self):
self.login(username="admin")
database_name = "fake_db_100"
db_id = 100
extra = """{
"schemas_allowed_for_csv_upload":
["this_schema_is_allowed", "this_schema_is_allowed_too"]
}"""
return self.get_or_create(
cls=models.Database,
criteria={"database_name": database_name},
session=db.session,
id=db_id,
extra=extra,
)
def validate_sql(
self,
sql,
client_id=None,
user_name=None,
raise_on_error=False,
database_name="examples",
):
if user_name:
self.logout()
self.login(username=(user_name if user_name else "admin"))
dbid = get_main_database().id
dbid = self._get_database_by_name(database_name).id
resp = self.get_json_resp(
"/superset/validate_sql_json/",
raise_on_error=False,

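The new database_name parameter and the create_fake_db() helper are exercised by the SQL Lab and CSV-upload tests later in this diff; a hypothetical usage sketch built on the SupersetTestCase machinery above:

from .base_tests import SupersetTestCase

class ExamplesRoutingTests(SupersetTestCase):
    def test_examples_routing(self):
        # run_sql()/validate_sql() now resolve the target database by name;
        # only "examples" is known, anything else raises ValueError
        data = self.run_sql(
            "SELECT name FROM birth_names LIMIT 5",
            client_id="demo_1",
            user_name="admin",
            database_name="examples",  # the default, shown here for clarity
        )
        assert len(data["data"]) > 0

        # create_fake_db() registers a throwaway Database row ("fake_db_100",
        # id=100) with csv-upload schemas allowed, for permission tests
        fake_db = self.create_fake_db()
        assert fake_db.database_name == "fake_db_100"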
View File

@ -28,7 +28,7 @@ from superset.db_engine_specs.base import BaseEngineSpec
from superset.models.helpers import QueryStatus
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery
from superset.utils.core import get_main_database
from superset.utils.core import get_example_database
from .base_tests import SupersetTestCase
@ -132,20 +132,20 @@ class CeleryTestCase(SupersetTestCase):
return json.loads(resp.data)
def test_run_sync_query_dont_exist(self):
main_db = get_main_database()
main_db = get_example_database()
db_id = main_db.id
sql_dont_exist = "SELECT name FROM table_dont_exist"
result1 = self.run_sql(db_id, sql_dont_exist, "1", cta="true")
self.assertTrue("error" in result1)
def test_run_sync_query_cta(self):
main_db = get_main_database()
main_db = get_example_database()
backend = main_db.backend
db_id = main_db.id
tmp_table_name = "tmp_async_22"
self.drop_table_if_exists(tmp_table_name, main_db)
perm_name = "can_sql_json"
sql_where = "SELECT name FROM ab_permission WHERE name='{}'".format(perm_name)
name = "James"
sql_where = f"SELECT name FROM birth_names WHERE name='{name}' LIMIT 1"
result = self.run_sql(
db_id, sql_where, "2", tmp_table=tmp_table_name, cta="true"
)
@ -162,9 +162,9 @@ class CeleryTestCase(SupersetTestCase):
self.assertGreater(len(results["data"]), 0)
def test_run_sync_query_cta_no_data(self):
main_db = get_main_database()
main_db = get_example_database()
db_id = main_db.id
sql_empty_result = "SELECT * FROM ab_user WHERE id=666"
sql_empty_result = "SELECT * FROM birth_names WHERE name='random'"
result3 = self.run_sql(db_id, sql_empty_result, "3")
self.assertEqual(QueryStatus.SUCCESS, result3["query"]["state"])
self.assertEqual([], result3["data"])
@ -183,12 +183,12 @@ class CeleryTestCase(SupersetTestCase):
return self.run_sql(db_id, sql)
def test_run_async_query(self):
main_db = get_main_database()
main_db = get_example_database()
db_id = main_db.id
self.drop_table_if_exists("tmp_async_1", main_db)
sql_where = "SELECT name FROM ab_role WHERE name='Admin'"
sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10"
result = self.run_sql(
db_id, sql_where, "4", async_="true", tmp_table="tmp_async_1", cta="true"
)
@ -202,12 +202,13 @@ class CeleryTestCase(SupersetTestCase):
query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertTrue("FROM tmp_async_1" in query.select_sql)
self.assertEqual(
"CREATE TABLE tmp_async_1 AS \n"
"SELECT name FROM ab_role "
"WHERE name='Admin'\n"
"LIMIT 666",
"SELECT name FROM birth_names "
"WHERE name='James' "
"LIMIT 10",
query.executed_sql,
)
self.assertEqual(sql_where, query.sql)
@ -216,13 +217,14 @@ class CeleryTestCase(SupersetTestCase):
self.assertEqual(True, query.select_as_cta_used)
def test_run_async_query_with_lower_limit(self):
main_db = get_main_database()
main_db = get_example_database()
db_id = main_db.id
self.drop_table_if_exists("tmp_async_2", main_db)
tmp_table = "tmp_async_2"
self.drop_table_if_exists(tmp_table, main_db)
sql_where = "SELECT name FROM ab_role WHERE name='Alpha' LIMIT 1"
sql_where = "SELECT name FROM birth_names LIMIT 1"
result = self.run_sql(
db_id, sql_where, "5", async_="true", tmp_table="tmp_async_2", cta="true"
db_id, sql_where, "5", async_="true", tmp_table=tmp_table, cta="true"
)
assert result["query"]["state"] in (
QueryStatus.PENDING,
@ -234,10 +236,9 @@ class CeleryTestCase(SupersetTestCase):
query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertTrue("FROM tmp_async_2" in query.select_sql)
self.assertTrue(f"FROM {tmp_table}" in query.select_sql)
self.assertEqual(
"CREATE TABLE tmp_async_2 AS \nSELECT name FROM ab_role "
"WHERE name='Alpha' LIMIT 1",
f"CREATE TABLE {tmp_table} AS \n" "SELECT name FROM birth_names LIMIT 1",
query.executed_sql,
)
self.assertEqual(sql_where, query.sql)

View File

@ -346,13 +346,13 @@ class CoreTests(SupersetTestCase):
def test_testconn(self, username="admin"):
self.login(username=username)
database = utils.get_main_database()
database = utils.get_example_database()
# validate that the endpoint works with the password-masked sqlalchemy uri
data = json.dumps(
{
"uri": database.safe_sqlalchemy_uri(),
"name": "main",
"name": "examples",
"impersonate_user": False,
}
)
@ -366,7 +366,7 @@ class CoreTests(SupersetTestCase):
data = json.dumps(
{
"uri": database.sqlalchemy_uri_decrypted,
"name": "main",
"name": "examples",
"impersonate_user": False,
}
)
@ -377,7 +377,7 @@ class CoreTests(SupersetTestCase):
assert response.headers["Content-Type"] == "application/json"
def test_custom_password_store(self):
database = utils.get_main_database()
database = utils.get_example_database()
conn_pre = sqla.engine.url.make_url(database.sqlalchemy_uri_decrypted)
def custom_password_store(uri):
@ -395,13 +395,13 @@ class CoreTests(SupersetTestCase):
# validate that sending a password-masked uri does not over-write the decrypted
# uri
self.login(username=username)
database = utils.get_main_database()
database = utils.get_example_database()
sqlalchemy_uri_decrypted = database.sqlalchemy_uri_decrypted
url = "databaseview/edit/{}".format(database.id)
data = {k: database.__getattribute__(k) for k in DatabaseView.add_columns}
data["sqlalchemy_uri"] = database.safe_sqlalchemy_uri()
self.client.post(url, data=data)
database = utils.get_main_database()
database = utils.get_example_database()
self.assertEqual(sqlalchemy_uri_decrypted, database.sqlalchemy_uri_decrypted)
def test_warm_up_cache(self):
@ -460,51 +460,51 @@ class CoreTests(SupersetTestCase):
def test_csv_endpoint(self):
self.login("admin")
sql = """
SELECT first_name, last_name
FROM ab_user
WHERE first_name='admin'
SELECT name
FROM birth_names
WHERE name = 'James'
LIMIT 1
"""
client_id = "{}".format(random.getrandbits(64))[:10]
self.run_sql(sql, client_id, raise_on_error=True)
resp = self.get_resp("/superset/csv/{}".format(client_id))
data = csv.reader(io.StringIO(resp))
expected_data = csv.reader(io.StringIO("first_name,last_name\nadmin, user\n"))
expected_data = csv.reader(io.StringIO("name\nJames\n"))
sql = "SELECT first_name FROM ab_user WHERE first_name LIKE '%admin%'"
client_id = "{}".format(random.getrandbits(64))[:10]
self.run_sql(sql, client_id, raise_on_error=True)
resp = self.get_resp("/superset/csv/{}".format(client_id))
data = csv.reader(io.StringIO(resp))
expected_data = csv.reader(io.StringIO("first_name\nadmin\n"))
expected_data = csv.reader(io.StringIO("name\nJames\n"))
self.assertEqual(list(expected_data), list(data))
self.logout()
def test_extra_table_metadata(self):
self.login("admin")
dbid = utils.get_main_database().id
dbid = utils.get_example_database().id
self.get_json_resp(
f"/superset/extra_table_metadata/{dbid}/" "ab_permission_view/panoramix/"
f"/superset/extra_table_metadata/{dbid}/birth_names/superset/"
)
def test_process_template(self):
maindb = utils.get_main_database()
maindb = utils.get_example_database()
sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'"
tp = jinja_context.get_template_processor(database=maindb)
rendered = tp.process_template(sql)
self.assertEqual("SELECT '2017-01-01T00:00:00'", rendered)
def test_get_template_kwarg(self):
maindb = utils.get_main_database()
maindb = utils.get_example_database()
s = "{{ foo }}"
tp = jinja_context.get_template_processor(database=maindb, foo="bar")
rendered = tp.process_template(s)
self.assertEqual("bar", rendered)
def test_template_kwarg(self):
maindb = utils.get_main_database()
maindb = utils.get_example_database()
s = "{{ foo }}"
tp = jinja_context.get_template_processor(database=maindb)
rendered = tp.process_template(s, foo="bar")
@ -517,23 +517,12 @@ class CoreTests(SupersetTestCase):
self.assertEqual(data["data"][0]["test"], "2017-01-01T00:00:00")
def test_table_metadata(self):
maindb = utils.get_main_database()
backend = maindb.backend
data = self.get_json_resp("/superset/table/{}/ab_user/null/".format(maindb.id))
self.assertEqual(data["name"], "ab_user")
maindb = utils.get_example_database()
data = self.get_json_resp(f"/superset/table/{maindb.id}/birth_names/null/")
self.assertEqual(data["name"], "birth_names")
assert len(data["columns"]) > 5
assert data.get("selectStar").startswith("SELECT")
# Engine specific tests
if backend in ("mysql", "postgresql"):
self.assertEqual(data.get("primaryKey").get("type"), "pk")
self.assertEqual(data.get("primaryKey").get("column_names")[0], "id")
self.assertEqual(len(data.get("foreignKeys")), 2)
if backend == "mysql":
self.assertEqual(len(data.get("indexes")), 7)
elif backend == "postgresql":
self.assertEqual(len(data.get("indexes")), 5)
def test_fetch_datasource_metadata(self):
self.login(username="admin")
url = "/superset/fetch_datasource_metadata?" "datasourceKey=1__table"
@ -746,24 +735,11 @@ class CoreTests(SupersetTestCase):
def test_schemas_access_for_csv_upload_endpoint(
self, mock_all_datasource_access, mock_database_access, mock_schemas_accessible
):
self.login(username="admin")
dbobj = self.create_fake_db()
mock_all_datasource_access.return_value = False
mock_database_access.return_value = False
mock_schemas_accessible.return_value = ["this_schema_is_allowed_too"]
database_name = "fake_db_100"
db_id = 100
extra = """{
"schemas_allowed_for_csv_upload":
["this_schema_is_allowed", "this_schema_is_allowed_too"]
}"""
self.login(username="admin")
dbobj = self.get_or_create(
cls=models.Database,
criteria={"database_name": database_name},
session=db.session,
id=db_id,
extra=extra,
)
data = self.get_json_resp(
url="/superset/schemas_access_for_csv_upload?db_id={db_id}".format(
db_id=dbobj.id

View File

@ -23,7 +23,7 @@ import yaml
from superset import db
from superset.connectors.druid.models import DruidColumn, DruidDatasource, DruidMetric
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.utils.core import get_main_database
from superset.utils.core import get_example_database
from .base_tests import SupersetTestCase
DBREF = "dict_import__export_test"
@ -63,7 +63,7 @@ class DictImportExportTests(SupersetTestCase):
params = {DBREF: id, "database_name": database_name}
dict_rep = {
"database_id": get_main_database().id,
"database_id": get_example_database().id,
"table_name": name,
"schema": schema,
"id": id,

View File

@ -22,7 +22,7 @@ from sqlalchemy.engine.url import make_url
from superset import app
from superset.models.core import Database
from superset.utils.core import get_example_database, get_main_database, QueryStatus
from superset.utils.core import get_example_database, QueryStatus
from .base_tests import SupersetTestCase
@ -149,7 +149,7 @@ class DatabaseModelTestCase(SupersetTestCase):
assert sql.startswith(expected)
def test_single_statement(self):
main_db = get_main_database()
main_db = get_example_database()
if main_db.backend == "mysql":
df = main_db.get_df("SELECT 1", None)
@ -159,7 +159,7 @@ class DatabaseModelTestCase(SupersetTestCase):
self.assertEquals(df.iat[0, 0], 1)
def test_multi_statement(self):
main_db = get_main_database()
main_db = get_example_database()
if main_db.backend == "mysql":
df = main_db.get_df("USE superset; SELECT 1", None)

View File

@ -449,41 +449,41 @@ class SupersetTestCase(unittest.TestCase):
self.assertEquals({"SalesOrderHeader"}, self.extract_tables(query))
def test_get_query_with_new_limit_comment(self):
sql = "SELECT * FROM ab_user -- SOME COMMENT"
sql = "SELECT * FROM birth_names -- SOME COMMENT"
parsed = sql_parse.ParsedQuery(sql)
newsql = parsed.get_query_with_new_limit(1000)
self.assertEquals(newsql, sql + "\nLIMIT 1000")
def test_get_query_with_new_limit_comment_with_limit(self):
sql = "SELECT * FROM ab_user -- SOME COMMENT WITH LIMIT 555"
sql = "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555"
parsed = sql_parse.ParsedQuery(sql)
newsql = parsed.get_query_with_new_limit(1000)
self.assertEquals(newsql, sql + "\nLIMIT 1000")
def test_get_query_with_new_limit(self):
sql = "SELECT * FROM ab_user LIMIT 555"
sql = "SELECT * FROM birth_names LIMIT 555"
parsed = sql_parse.ParsedQuery(sql)
newsql = parsed.get_query_with_new_limit(1000)
expected = "SELECT * FROM ab_user LIMIT 1000"
expected = "SELECT * FROM birth_names LIMIT 1000"
self.assertEquals(newsql, expected)
def test_basic_breakdown_statements(self):
multi_sql = """
SELECT * FROM ab_user;
SELECT * FROM ab_user LIMIT 1;
SELECT * FROM birth_names;
SELECT * FROM birth_names LIMIT 1;
"""
parsed = sql_parse.ParsedQuery(multi_sql)
statements = parsed.get_statements()
self.assertEquals(len(statements), 2)
expected = ["SELECT * FROM ab_user", "SELECT * FROM ab_user LIMIT 1"]
expected = ["SELECT * FROM birth_names", "SELECT * FROM birth_names LIMIT 1"]
self.assertEquals(statements, expected)
def test_messy_breakdown_statements(self):
multi_sql = """
SELECT 1;\t\n\n\n \t
\t\nSELECT 2;
SELECT * FROM ab_user;;;
SELECT * FROM ab_user LIMIT 1
SELECT * FROM birth_names;;;
SELECT * FROM birth_names LIMIT 1
"""
parsed = sql_parse.ParsedQuery(multi_sql)
statements = parsed.get_statements()
@ -491,8 +491,8 @@ class SupersetTestCase(unittest.TestCase):
expected = [
"SELECT 1",
"SELECT 2",
"SELECT * FROM ab_user",
"SELECT * FROM ab_user LIMIT 1",
"SELECT * FROM birth_names",
"SELECT * FROM birth_names LIMIT 1",
]
self.assertEquals(statements, expected)

View File

@ -53,7 +53,7 @@ class SqlValidatorEndpointTests(SupersetTestCase):
app.config["SQL_VALIDATORS_BY_ENGINE"] = {}
resp = self.validate_sql(
"SELECT * FROM ab_user", client_id="1", raise_on_error=False
"SELECT * FROM birth_names", client_id="1", raise_on_error=False
)
self.assertIn("error", resp)
self.assertIn("no SQL validator is configured", resp["error"])
@ -97,7 +97,7 @@ class SqlValidatorEndpointTests(SupersetTestCase):
validator.validate.side_effect = Exception("Kaboom!")
resp = self.validate_sql(
"SELECT * FROM ab_user", client_id="1", raise_on_error=False
"SELECT * FROM birth_names", client_id="1", raise_on_error=False
)
self.assertIn("error", resp)
self.assertIn("Kaboom!", resp["error"])
@ -186,7 +186,7 @@ class PrestoValidatorTests(SupersetTestCase):
# validator for sqlite, this test will fail because the validator
# will no longer error out.
resp = self.validate_sql(
"SELECT * FROM ab_user", client_id="1", raise_on_error=False
"SELECT * FROM birth_names", client_id="1", raise_on_error=False
)
self.assertIn("error", resp)
self.assertIn("no SQL validator is configured", resp["error"])

View File

@ -16,7 +16,7 @@
# under the License.
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.db_engine_specs.druid import DruidEngineSpec
from superset.utils.core import get_main_database
from superset.utils.core import get_example_database
from .base_tests import SupersetTestCase
@ -43,7 +43,7 @@ class DatabaseModelTestCase(SupersetTestCase):
def test_has_extra_cache_keys(self):
query = "SELECT '{{ cache_key_wrapper('user_1') }}' as user"
table = SqlaTable(sql=query, database=get_main_database())
table = SqlaTable(sql=query, database=get_example_database())
query_obj = {
"granularity": None,
"from_dttm": None,
@ -60,7 +60,7 @@ class DatabaseModelTestCase(SupersetTestCase):
def test_has_no_extra_cache_keys(self):
query = "SELECT 'abc' as user"
table = SqlaTable(sql=query, database=get_main_database())
table = SqlaTable(sql=query, database=get_example_database())
query_obj = {
"granularity": None,
"from_dttm": None,

View File

@ -17,18 +17,20 @@
"""Unit tests for Sql Lab"""
from datetime import datetime, timedelta
import json
import unittest
from flask_appbuilder.security.sqla import models as ab_models
import prison
from superset import db, security_manager
from superset.dataframe import SupersetDataFrame
from superset.db_engine_specs import BaseEngineSpec
from superset.models.sql_lab import Query
from superset.utils.core import datetime_to_epoch, get_main_database
from superset.utils.core import datetime_to_epoch, get_example_database
from .base_tests import SupersetTestCase
QUERY_1 = "SELECT * FROM birth_names LIMIT 1"
QUERY_2 = "SELECT * FROM NO_TABLE"
QUERY_3 = "SELECT * FROM birth_names LIMIT 10"
class SqlLabTests(SupersetTestCase):
"""Testings for Sql Lab"""
@ -39,17 +41,9 @@ class SqlLabTests(SupersetTestCase):
def run_some_queries(self):
db.session.query(Query).delete()
db.session.commit()
self.run_sql(
"SELECT * FROM ab_user", client_id="client_id_1", user_name="admin"
)
self.run_sql(
"SELECT * FROM NO_TABLE", client_id="client_id_3", user_name="admin"
)
self.run_sql(
"SELECT * FROM ab_permission",
client_id="client_id_2",
user_name="gamma_sqllab",
)
self.run_sql(QUERY_1, client_id="client_id_1", user_name="admin")
self.run_sql(QUERY_2, client_id="client_id_3", user_name="admin")
self.run_sql(QUERY_3, client_id="client_id_2", user_name="gamma_sqllab")
self.logout()
def tearDown(self):
@ -61,7 +55,7 @@ class SqlLabTests(SupersetTestCase):
def test_sql_json(self):
self.login("admin")
data = self.run_sql("SELECT * FROM ab_user", "1")
data = self.run_sql("SELECT * FROM birth_names LIMIT 10", "1")
self.assertLess(0, len(data["data"]))
data = self.run_sql("SELECT * FROM unexistant_table", "2")
@ -71,8 +65,8 @@ class SqlLabTests(SupersetTestCase):
self.login("admin")
multi_sql = """
SELECT first_name FROM ab_user;
SELECT first_name FROM ab_user;
SELECT * FROM birth_names LIMIT 1;
SELECT * FROM birth_names LIMIT 2;
"""
data = self.run_sql(multi_sql, "2234")
self.assertLess(0, len(data["data"]))
@ -80,24 +74,18 @@ class SqlLabTests(SupersetTestCase):
def test_explain(self):
self.login("admin")
data = self.run_sql("EXPLAIN SELECT * FROM ab_user", "1")
data = self.run_sql("EXPLAIN SELECT * FROM birth_names", "1")
self.assertLess(0, len(data["data"]))
def test_sql_json_has_access(self):
main_db = get_main_database()
security_manager.add_permission_view_menu("database_access", main_db.perm)
db.session.commit()
main_db_permission_view = (
db.session.query(ab_models.PermissionView)
.join(ab_models.ViewMenu)
.join(ab_models.Permission)
.filter(ab_models.ViewMenu.name == "[main].(id:1)")
.filter(ab_models.Permission.name == "database_access")
.first()
examples_db = get_example_database()
examples_db_permission_view = security_manager.add_permission_view_menu(
"database_access", examples_db.perm
)
astronaut = security_manager.add_role("Astronaut")
security_manager.add_permission_role(astronaut, main_db_permission_view)
# Astronaut role is Gamma + sqllab + main db permissions
security_manager.add_permission_role(astronaut, examples_db_permission_view)
# Astronaut role is Gamma + sqllab + db permissions
for perm in security_manager.find_role("Gamma").permissions:
security_manager.add_permission_role(astronaut, perm)
for perm in security_manager.find_role("sql_lab").permissions:
@ -113,7 +101,7 @@ class SqlLabTests(SupersetTestCase):
astronaut,
password="general",
)
data = self.run_sql("SELECT * FROM ab_user", "3", user_name="gagarin")
data = self.run_sql(QUERY_1, "3", user_name="gagarin")
db.session.query(Query).delete()
db.session.commit()
self.assertLess(0, len(data["data"]))
@ -132,8 +120,8 @@ class SqlLabTests(SupersetTestCase):
self.assertEquals(2, len(data))
# Run 2 more queries
self.run_sql("SELECT * FROM ab_user LIMIT 1", client_id="client_id_4")
self.run_sql("SELECT * FROM ab_user LIMIT 2", client_id="client_id_5")
self.run_sql("SELECT * FROM birth_names LIMIT 1", client_id="client_id_4")
self.run_sql("SELECT * FROM birth_names LIMIT 2", client_id="client_id_5")
self.login("admin")
data = self.get_json_resp("/superset/queries/0")
self.assertEquals(4, len(data))
@ -141,7 +129,7 @@ class SqlLabTests(SupersetTestCase):
now = datetime.now() + timedelta(days=1)
query = (
db.session.query(Query)
.filter_by(sql="SELECT * FROM ab_user LIMIT 1")
.filter_by(sql="SELECT * FROM birth_names LIMIT 1")
.first()
)
query.changed_on = now
@ -160,11 +148,15 @@ class SqlLabTests(SupersetTestCase):
def test_search_query_on_db_id(self):
self.run_some_queries()
self.login("admin")
examples_dbid = get_example_database().id
# Test search queries on database Id
data = self.get_json_resp("/superset/search_queries?database_id=1")
data = self.get_json_resp(
f"/superset/search_queries?database_id={examples_dbid}"
)
self.assertEquals(3, len(data))
db_ids = [k["dbId"] for k in data]
self.assertEquals([1, 1, 1], db_ids)
self.assertEquals([examples_dbid for i in range(3)], db_ids)
resp = self.get_resp("/superset/search_queries?database_id=-1")
data = json.loads(resp)
@ -205,19 +197,19 @@ class SqlLabTests(SupersetTestCase):
def test_search_query_on_text(self):
self.run_some_queries()
self.login("admin")
url = "/superset/search_queries?search_text=permission"
url = "/superset/search_queries?search_text=birth"
data = self.get_json_resp(url)
self.assertEquals(1, len(data))
self.assertIn("permission", data[0]["sql"])
self.assertEquals(2, len(data))
self.assertIn("birth", data[0]["sql"])
def test_search_query_on_time(self):
self.run_some_queries()
self.login("admin")
first_query_time = (
db.session.query(Query).filter_by(sql="SELECT * FROM ab_user").one()
db.session.query(Query).filter_by(sql=QUERY_1).one()
).start_time
second_query_time = (
db.session.query(Query).filter_by(sql="SELECT * FROM ab_permission").one()
db.session.query(Query).filter_by(sql=QUERY_3).one()
).start_time
# Test search queries on time filter
from_time = "from={}".format(int(first_query_time))
@ -265,7 +257,7 @@ class SqlLabTests(SupersetTestCase):
def test_alias_duplicate(self):
self.run_sql(
"SELECT username as col, id as col, username FROM ab_user",
"SELECT name as col, gender as col FROM birth_names LIMIT 10",
client_id="2e2df3",
user_name="admin",
raise_on_error=True,
@ -281,7 +273,7 @@ class SqlLabTests(SupersetTestCase):
def test_df_conversion_tuple(self):
cols = ["string_col", "int_col", "list_col", "float_col"]
data = [(u"Text", 111, [123], 1.0)]
data = [("Text", 111, [123], 1.0)]
cdf = SupersetDataFrame(data, cols, BaseEngineSpec)
self.assertEquals(len(data), cdf.size)
@ -296,6 +288,7 @@ class SqlLabTests(SupersetTestCase):
self.assertEquals(len(cols), len(cdf.columns))
def test_sqllab_viz(self):
examples_dbid = get_example_database().id
payload = {
"chartType": "dist_bar",
"datasourceName": "test_viz_flow_table",
@ -316,11 +309,10 @@ class SqlLabTests(SupersetTestCase):
},
],
"sql": """\
SELECT viz_type, count(1) as ccount
FROM slices
WHERE viz_type LIKE '%a%'
GROUP BY viz_type""",
"dbId": 1,
SELECT *
FROM birth_names
LIMIT 10""",
"dbId": examples_dbid,
}
data = {"data": json.dumps(payload)}
resp = self.get_json_resp("/superset/sqllab_viz/", data=data)
@ -329,20 +321,20 @@ class SqlLabTests(SupersetTestCase):
def test_sql_limit(self):
self.login("admin")
test_limit = 1
data = self.run_sql("SELECT * FROM ab_user", client_id="sql_limit_1")
data = self.run_sql("SELECT * FROM birth_names", client_id="sql_limit_1")
self.assertGreater(len(data["data"]), test_limit)
data = self.run_sql(
"SELECT * FROM ab_user", client_id="sql_limit_2", query_limit=test_limit
"SELECT * FROM birth_names", client_id="sql_limit_2", query_limit=test_limit
)
self.assertEquals(len(data["data"]), test_limit)
data = self.run_sql(
"SELECT * FROM ab_user LIMIT {}".format(test_limit),
"SELECT * FROM birth_names LIMIT {}".format(test_limit),
client_id="sql_limit_3",
query_limit=test_limit + 1,
)
self.assertEquals(len(data["data"]), test_limit)
data = self.run_sql(
"SELECT * FROM ab_user LIMIT {}".format(test_limit + 1),
"SELECT * FROM birth_names LIMIT {}".format(test_limit + 1),
client_id="sql_limit_4",
query_limit=test_limit,
)
@ -406,6 +398,7 @@ class SqlLabTests(SupersetTestCase):
def test_api_database(self):
self.login("admin")
self.create_fake_db()
arguments = {
"keys": [],
@ -415,12 +408,8 @@ class SqlLabTests(SupersetTestCase):
"page": 0,
"page_size": -1,
}
expected_results = ["examples", "fake_db_100", "main"]
url = "api/v1/database/?{}={}".format("q", prison.dumps(arguments))
data = self.get_json_resp(url)
for i, expected_result in enumerate(expected_results):
self.assertEquals(expected_result, data["result"][i]["database_name"])
if __name__ == "__main__":
unittest.main()
self.assertEquals(
{"examples", "fake_db_100"},
{r.get("database_name") for r in self.get_json_resp(url)["result"]},
)