Prevent 'main' database connection creation (#8038)

* prevent 'main' database connection creation

* fix tests

* removing get_main_database

* Kill get_main_database

* Point to examples tables
This commit is contained in:
Maxime Beauchemin 2019-09-08 10:18:09 -07:00 committed by GitHub
parent 9d350aadf0
commit 68c4c3a0b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 186 additions and 240 deletions

View File

@ -48,7 +48,6 @@ def make_shell_context():
@app.cli.command() @app.cli.command()
def init(): def init():
"""Inits the Superset application""" """Inits the Superset application"""
utils.get_or_create_main_db()
utils.get_example_database() utils.get_example_database()
appbuilder.add_permissions(update_perms=True) appbuilder.add_permissions(update_perms=True)
security_manager.sync_role_definitions() security_manager.sync_role_definitions()
@ -430,75 +429,40 @@ def load_test_users_run():
Syncs permissions for those users/roles Syncs permissions for those users/roles
""" """
if config.get("TESTING"): if config.get("TESTING"):
security_manager.sync_role_definitions()
gamma_sqllab_role = security_manager.add_role("gamma_sqllab")
for perm in security_manager.find_role("Gamma").permissions:
security_manager.add_permission_role(gamma_sqllab_role, perm)
utils.get_or_create_main_db()
db_perm = utils.get_main_database().perm
security_manager.add_permission_view_menu("database_access", db_perm)
db_pvm = security_manager.find_permission_view_menu(
view_menu_name=db_perm, permission_name="database_access"
)
gamma_sqllab_role.permissions.append(db_pvm)
for perm in security_manager.find_role("sql_lab").permissions:
security_manager.add_permission_role(gamma_sqllab_role, perm)
admin = security_manager.find_user("admin") sm = security_manager
if not admin:
security_manager.add_user( examples_db = utils.get_example_database()
"admin",
"admin", examples_pv = sm.add_permission_view_menu("database_access", examples_db.perm)
sm.sync_role_definitions()
gamma_sqllab_role = sm.add_role("gamma_sqllab")
sm.add_permission_role(gamma_sqllab_role, examples_pv)
for role in ["Gamma", "sql_lab"]:
for perm in sm.find_role(role).permissions:
sm.add_permission_role(gamma_sqllab_role, perm)
users = (
("admin", "Admin"),
("gamma", "Gamma"),
("gamma2", "Gamma"),
("gamma_sqllab", "gamma_sqllab"),
("alpha", "Alpha"),
)
for username, role in users:
user = sm.find_user(username)
if not user:
sm.add_user(
username,
username,
"user", "user",
"admin@fab.org", username + "@fab.org",
security_manager.find_role("Admin"), sm.find_role(role),
password="general", password="general",
) )
sm.get_session.commit()
gamma = security_manager.find_user("gamma")
if not gamma:
security_manager.add_user(
"gamma",
"gamma",
"user",
"gamma@fab.org",
security_manager.find_role("Gamma"),
password="general",
)
gamma2 = security_manager.find_user("gamma2")
if not gamma2:
security_manager.add_user(
"gamma2",
"gamma2",
"user",
"gamma2@fab.org",
security_manager.find_role("Gamma"),
password="general",
)
gamma_sqllab_user = security_manager.find_user("gamma_sqllab")
if not gamma_sqllab_user:
security_manager.add_user(
"gamma_sqllab",
"gamma_sqllab",
"user",
"gamma_sqllab@fab.org",
gamma_sqllab_role,
password="general",
)
alpha = security_manager.find_user("alpha")
if not alpha:
security_manager.add_user(
"alpha",
"alpha",
"user",
"alpha@fab.org",
security_manager.find_role("Alpha"),
password="general",
)
security_manager.get_session.commit()
@app.cli.command() @app.cli.command()

View File

@ -200,7 +200,6 @@ class SupersetSecurityManager(SecurityManager):
:param database: The Superset database :param database: The Superset database
:returns: Whether the user can access the Superset database :returns: Whether the user can access the Superset database
""" """
return ( return (
self.all_datasource_access() self.all_datasource_access()
or self.all_database_access() or self.all_database_access()
@ -269,9 +268,9 @@ class SupersetSecurityManager(SecurityManager):
:param tables: The list of denied SQL table names :param tables: The list of denied SQL table names
:returns: The error message :returns: The error message
""" """
quoted_tables = [f"`{t}`" for t in tables]
return f"""You need access to the following tables: {", ".join(tables)}, all return f"""You need access to the following tables: {", ".join(quoted_tables)},
database access or `all_datasource_access` permission""" `all_database_access` or `all_datasource_access` permission"""
def get_table_access_link(self, tables: List[str]) -> Optional[str]: def get_table_access_link(self, tables: List[str]) -> Optional[str]:
""" """

View File

@ -936,10 +936,6 @@ def user_label(user: User) -> Optional[str]:
return None return None
def get_or_create_main_db():
get_main_database()
def get_or_create_db(database_name, sqlalchemy_uri, *args, **kwargs): def get_or_create_db(database_name, sqlalchemy_uri, *args, **kwargs):
from superset import db from superset import db
from superset.models import core as models from superset.models import core as models
@ -957,12 +953,6 @@ def get_or_create_db(database_name, sqlalchemy_uri, *args, **kwargs):
return database return database
def get_main_database():
from superset import conf
return get_or_create_db("main", conf.get("SQLALCHEMY_DATABASE_URI"))
def get_example_database(): def get_example_database():
from superset import conf from superset import conf

View File

@ -2705,11 +2705,7 @@ class Superset(BaseSupersetView):
query.sql, query.database, query.schema query.sql, query.database, query.schema
) )
if rejected_tables: if rejected_tables:
flash( flash(security_manager.get_table_access_error_msg(rejected_tables))
security_manager.get_table_access_error_msg(
"{}".format(rejected_tables)
)
)
return redirect("/") return redirect("/")
blob = None blob = None
if results_backend and query.results_key: if results_backend and query.results_key:

View File

@ -28,7 +28,7 @@ from superset.connectors.druid.models import DruidCluster, DruidDatasource
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.models import core as models from superset.models import core as models
from superset.models.core import Database from superset.models.core import Database
from superset.utils.core import get_main_database from superset.utils.core import get_example_database
BASE_DIR = app.config.get("BASE_DIR") BASE_DIR = app.config.get("BASE_DIR")
@ -168,6 +168,12 @@ class SupersetTestCase(unittest.TestCase):
): ):
security_manager.del_permission_role(public_role, perm) security_manager.del_permission_role(public_role, perm)
def _get_database_by_name(self, database_name="main"):
if database_name == "examples":
return get_example_database()
else:
raise ValueError("Database doesn't exist")
def run_sql( def run_sql(
self, self,
sql, sql,
@ -175,11 +181,12 @@ class SupersetTestCase(unittest.TestCase):
user_name=None, user_name=None,
raise_on_error=False, raise_on_error=False,
query_limit=None, query_limit=None,
database_name="examples",
): ):
if user_name: if user_name:
self.logout() self.logout()
self.login(username=(user_name if user_name else "admin")) self.login(username=(user_name or "admin"))
dbid = get_main_database().id dbid = self._get_database_by_name(database_name).id
resp = self.get_json_resp( resp = self.get_json_resp(
"/superset/sql_json/", "/superset/sql_json/",
raise_on_error=False, raise_on_error=False,
@ -195,11 +202,35 @@ class SupersetTestCase(unittest.TestCase):
raise Exception("run_sql failed") raise Exception("run_sql failed")
return resp return resp
def validate_sql(self, sql, client_id=None, user_name=None, raise_on_error=False): def create_fake_db(self):
self.login(username="admin")
database_name = "fake_db_100"
db_id = 100
extra = """{
"schemas_allowed_for_csv_upload":
["this_schema_is_allowed", "this_schema_is_allowed_too"]
}"""
return self.get_or_create(
cls=models.Database,
criteria={"database_name": database_name},
session=db.session,
id=db_id,
extra=extra,
)
def validate_sql(
self,
sql,
client_id=None,
user_name=None,
raise_on_error=False,
database_name="examples",
):
if user_name: if user_name:
self.logout() self.logout()
self.login(username=(user_name if user_name else "admin")) self.login(username=(user_name if user_name else "admin"))
dbid = get_main_database().id dbid = self._get_database_by_name(database_name).id
resp = self.get_json_resp( resp = self.get_json_resp(
"/superset/validate_sql_json/", "/superset/validate_sql_json/",
raise_on_error=False, raise_on_error=False,

View File

@ -28,7 +28,7 @@ from superset.db_engine_specs.base import BaseEngineSpec
from superset.models.helpers import QueryStatus from superset.models.helpers import QueryStatus
from superset.models.sql_lab import Query from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery from superset.sql_parse import ParsedQuery
from superset.utils.core import get_main_database from superset.utils.core import get_example_database
from .base_tests import SupersetTestCase from .base_tests import SupersetTestCase
@ -132,20 +132,20 @@ class CeleryTestCase(SupersetTestCase):
return json.loads(resp.data) return json.loads(resp.data)
def test_run_sync_query_dont_exist(self): def test_run_sync_query_dont_exist(self):
main_db = get_main_database() main_db = get_example_database()
db_id = main_db.id db_id = main_db.id
sql_dont_exist = "SELECT name FROM table_dont_exist" sql_dont_exist = "SELECT name FROM table_dont_exist"
result1 = self.run_sql(db_id, sql_dont_exist, "1", cta="true") result1 = self.run_sql(db_id, sql_dont_exist, "1", cta="true")
self.assertTrue("error" in result1) self.assertTrue("error" in result1)
def test_run_sync_query_cta(self): def test_run_sync_query_cta(self):
main_db = get_main_database() main_db = get_example_database()
backend = main_db.backend backend = main_db.backend
db_id = main_db.id db_id = main_db.id
tmp_table_name = "tmp_async_22" tmp_table_name = "tmp_async_22"
self.drop_table_if_exists(tmp_table_name, main_db) self.drop_table_if_exists(tmp_table_name, main_db)
perm_name = "can_sql_json" name = "James"
sql_where = "SELECT name FROM ab_permission WHERE name='{}'".format(perm_name) sql_where = f"SELECT name FROM birth_names WHERE name='{name}' LIMIT 1"
result = self.run_sql( result = self.run_sql(
db_id, sql_where, "2", tmp_table=tmp_table_name, cta="true" db_id, sql_where, "2", tmp_table=tmp_table_name, cta="true"
) )
@ -162,9 +162,9 @@ class CeleryTestCase(SupersetTestCase):
self.assertGreater(len(results["data"]), 0) self.assertGreater(len(results["data"]), 0)
def test_run_sync_query_cta_no_data(self): def test_run_sync_query_cta_no_data(self):
main_db = get_main_database() main_db = get_example_database()
db_id = main_db.id db_id = main_db.id
sql_empty_result = "SELECT * FROM ab_user WHERE id=666" sql_empty_result = "SELECT * FROM birth_names WHERE name='random'"
result3 = self.run_sql(db_id, sql_empty_result, "3") result3 = self.run_sql(db_id, sql_empty_result, "3")
self.assertEqual(QueryStatus.SUCCESS, result3["query"]["state"]) self.assertEqual(QueryStatus.SUCCESS, result3["query"]["state"])
self.assertEqual([], result3["data"]) self.assertEqual([], result3["data"])
@ -183,12 +183,12 @@ class CeleryTestCase(SupersetTestCase):
return self.run_sql(db_id, sql) return self.run_sql(db_id, sql)
def test_run_async_query(self): def test_run_async_query(self):
main_db = get_main_database() main_db = get_example_database()
db_id = main_db.id db_id = main_db.id
self.drop_table_if_exists("tmp_async_1", main_db) self.drop_table_if_exists("tmp_async_1", main_db)
sql_where = "SELECT name FROM ab_role WHERE name='Admin'" sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10"
result = self.run_sql( result = self.run_sql(
db_id, sql_where, "4", async_="true", tmp_table="tmp_async_1", cta="true" db_id, sql_where, "4", async_="true", tmp_table="tmp_async_1", cta="true"
) )
@ -202,12 +202,13 @@ class CeleryTestCase(SupersetTestCase):
query = self.get_query_by_id(result["query"]["serverId"]) query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query.status) self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertTrue("FROM tmp_async_1" in query.select_sql) self.assertTrue("FROM tmp_async_1" in query.select_sql)
self.assertEqual( self.assertEqual(
"CREATE TABLE tmp_async_1 AS \n" "CREATE TABLE tmp_async_1 AS \n"
"SELECT name FROM ab_role " "SELECT name FROM birth_names "
"WHERE name='Admin'\n" "WHERE name='James' "
"LIMIT 666", "LIMIT 10",
query.executed_sql, query.executed_sql,
) )
self.assertEqual(sql_where, query.sql) self.assertEqual(sql_where, query.sql)
@ -216,13 +217,14 @@ class CeleryTestCase(SupersetTestCase):
self.assertEqual(True, query.select_as_cta_used) self.assertEqual(True, query.select_as_cta_used)
def test_run_async_query_with_lower_limit(self): def test_run_async_query_with_lower_limit(self):
main_db = get_main_database() main_db = get_example_database()
db_id = main_db.id db_id = main_db.id
self.drop_table_if_exists("tmp_async_2", main_db) tmp_table = "tmp_async_2"
self.drop_table_if_exists(tmp_table, main_db)
sql_where = "SELECT name FROM ab_role WHERE name='Alpha' LIMIT 1" sql_where = "SELECT name FROM birth_names LIMIT 1"
result = self.run_sql( result = self.run_sql(
db_id, sql_where, "5", async_="true", tmp_table="tmp_async_2", cta="true" db_id, sql_where, "5", async_="true", tmp_table=tmp_table, cta="true"
) )
assert result["query"]["state"] in ( assert result["query"]["state"] in (
QueryStatus.PENDING, QueryStatus.PENDING,
@ -234,10 +236,9 @@ class CeleryTestCase(SupersetTestCase):
query = self.get_query_by_id(result["query"]["serverId"]) query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query.status) self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertTrue("FROM tmp_async_2" in query.select_sql) self.assertTrue(f"FROM {tmp_table}" in query.select_sql)
self.assertEqual( self.assertEqual(
"CREATE TABLE tmp_async_2 AS \nSELECT name FROM ab_role " f"CREATE TABLE {tmp_table} AS \n" "SELECT name FROM birth_names LIMIT 1",
"WHERE name='Alpha' LIMIT 1",
query.executed_sql, query.executed_sql,
) )
self.assertEqual(sql_where, query.sql) self.assertEqual(sql_where, query.sql)

View File

@ -346,13 +346,13 @@ class CoreTests(SupersetTestCase):
def test_testconn(self, username="admin"): def test_testconn(self, username="admin"):
self.login(username=username) self.login(username=username)
database = utils.get_main_database() database = utils.get_example_database()
# validate that the endpoint works with the password-masked sqlalchemy uri # validate that the endpoint works with the password-masked sqlalchemy uri
data = json.dumps( data = json.dumps(
{ {
"uri": database.safe_sqlalchemy_uri(), "uri": database.safe_sqlalchemy_uri(),
"name": "main", "name": "examples",
"impersonate_user": False, "impersonate_user": False,
} }
) )
@ -366,7 +366,7 @@ class CoreTests(SupersetTestCase):
data = json.dumps( data = json.dumps(
{ {
"uri": database.sqlalchemy_uri_decrypted, "uri": database.sqlalchemy_uri_decrypted,
"name": "main", "name": "examples",
"impersonate_user": False, "impersonate_user": False,
} }
) )
@ -377,7 +377,7 @@ class CoreTests(SupersetTestCase):
assert response.headers["Content-Type"] == "application/json" assert response.headers["Content-Type"] == "application/json"
def test_custom_password_store(self): def test_custom_password_store(self):
database = utils.get_main_database() database = utils.get_example_database()
conn_pre = sqla.engine.url.make_url(database.sqlalchemy_uri_decrypted) conn_pre = sqla.engine.url.make_url(database.sqlalchemy_uri_decrypted)
def custom_password_store(uri): def custom_password_store(uri):
@ -395,13 +395,13 @@ class CoreTests(SupersetTestCase):
# validate that sending a password-masked uri does not over-write the decrypted # validate that sending a password-masked uri does not over-write the decrypted
# uri # uri
self.login(username=username) self.login(username=username)
database = utils.get_main_database() database = utils.get_example_database()
sqlalchemy_uri_decrypted = database.sqlalchemy_uri_decrypted sqlalchemy_uri_decrypted = database.sqlalchemy_uri_decrypted
url = "databaseview/edit/{}".format(database.id) url = "databaseview/edit/{}".format(database.id)
data = {k: database.__getattribute__(k) for k in DatabaseView.add_columns} data = {k: database.__getattribute__(k) for k in DatabaseView.add_columns}
data["sqlalchemy_uri"] = database.safe_sqlalchemy_uri() data["sqlalchemy_uri"] = database.safe_sqlalchemy_uri()
self.client.post(url, data=data) self.client.post(url, data=data)
database = utils.get_main_database() database = utils.get_example_database()
self.assertEqual(sqlalchemy_uri_decrypted, database.sqlalchemy_uri_decrypted) self.assertEqual(sqlalchemy_uri_decrypted, database.sqlalchemy_uri_decrypted)
def test_warm_up_cache(self): def test_warm_up_cache(self):
@ -460,51 +460,51 @@ class CoreTests(SupersetTestCase):
def test_csv_endpoint(self): def test_csv_endpoint(self):
self.login("admin") self.login("admin")
sql = """ sql = """
SELECT first_name, last_name SELECT name
FROM ab_user FROM birth_names
WHERE first_name='admin' WHERE name = 'James'
LIMIT 1
""" """
client_id = "{}".format(random.getrandbits(64))[:10] client_id = "{}".format(random.getrandbits(64))[:10]
self.run_sql(sql, client_id, raise_on_error=True) self.run_sql(sql, client_id, raise_on_error=True)
resp = self.get_resp("/superset/csv/{}".format(client_id)) resp = self.get_resp("/superset/csv/{}".format(client_id))
data = csv.reader(io.StringIO(resp)) data = csv.reader(io.StringIO(resp))
expected_data = csv.reader(io.StringIO("first_name,last_name\nadmin, user\n")) expected_data = csv.reader(io.StringIO("name\nJames\n"))
sql = "SELECT first_name FROM ab_user WHERE first_name LIKE '%admin%'"
client_id = "{}".format(random.getrandbits(64))[:10] client_id = "{}".format(random.getrandbits(64))[:10]
self.run_sql(sql, client_id, raise_on_error=True) self.run_sql(sql, client_id, raise_on_error=True)
resp = self.get_resp("/superset/csv/{}".format(client_id)) resp = self.get_resp("/superset/csv/{}".format(client_id))
data = csv.reader(io.StringIO(resp)) data = csv.reader(io.StringIO(resp))
expected_data = csv.reader(io.StringIO("first_name\nadmin\n")) expected_data = csv.reader(io.StringIO("name\nJames\n"))
self.assertEqual(list(expected_data), list(data)) self.assertEqual(list(expected_data), list(data))
self.logout() self.logout()
def test_extra_table_metadata(self): def test_extra_table_metadata(self):
self.login("admin") self.login("admin")
dbid = utils.get_main_database().id dbid = utils.get_example_database().id
self.get_json_resp( self.get_json_resp(
f"/superset/extra_table_metadata/{dbid}/" "ab_permission_view/panoramix/" f"/superset/extra_table_metadata/{dbid}/birth_names/superset/"
) )
def test_process_template(self): def test_process_template(self):
maindb = utils.get_main_database() maindb = utils.get_example_database()
sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'" sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'"
tp = jinja_context.get_template_processor(database=maindb) tp = jinja_context.get_template_processor(database=maindb)
rendered = tp.process_template(sql) rendered = tp.process_template(sql)
self.assertEqual("SELECT '2017-01-01T00:00:00'", rendered) self.assertEqual("SELECT '2017-01-01T00:00:00'", rendered)
def test_get_template_kwarg(self): def test_get_template_kwarg(self):
maindb = utils.get_main_database() maindb = utils.get_example_database()
s = "{{ foo }}" s = "{{ foo }}"
tp = jinja_context.get_template_processor(database=maindb, foo="bar") tp = jinja_context.get_template_processor(database=maindb, foo="bar")
rendered = tp.process_template(s) rendered = tp.process_template(s)
self.assertEqual("bar", rendered) self.assertEqual("bar", rendered)
def test_template_kwarg(self): def test_template_kwarg(self):
maindb = utils.get_main_database() maindb = utils.get_example_database()
s = "{{ foo }}" s = "{{ foo }}"
tp = jinja_context.get_template_processor(database=maindb) tp = jinja_context.get_template_processor(database=maindb)
rendered = tp.process_template(s, foo="bar") rendered = tp.process_template(s, foo="bar")
@ -517,23 +517,12 @@ class CoreTests(SupersetTestCase):
self.assertEqual(data["data"][0]["test"], "2017-01-01T00:00:00") self.assertEqual(data["data"][0]["test"], "2017-01-01T00:00:00")
def test_table_metadata(self): def test_table_metadata(self):
maindb = utils.get_main_database() maindb = utils.get_example_database()
backend = maindb.backend data = self.get_json_resp(f"/superset/table/{maindb.id}/birth_names/null/")
data = self.get_json_resp("/superset/table/{}/ab_user/null/".format(maindb.id)) self.assertEqual(data["name"], "birth_names")
self.assertEqual(data["name"], "ab_user")
assert len(data["columns"]) > 5 assert len(data["columns"]) > 5
assert data.get("selectStar").startswith("SELECT") assert data.get("selectStar").startswith("SELECT")
# Engine specific tests
if backend in ("mysql", "postgresql"):
self.assertEqual(data.get("primaryKey").get("type"), "pk")
self.assertEqual(data.get("primaryKey").get("column_names")[0], "id")
self.assertEqual(len(data.get("foreignKeys")), 2)
if backend == "mysql":
self.assertEqual(len(data.get("indexes")), 7)
elif backend == "postgresql":
self.assertEqual(len(data.get("indexes")), 5)
def test_fetch_datasource_metadata(self): def test_fetch_datasource_metadata(self):
self.login(username="admin") self.login(username="admin")
url = "/superset/fetch_datasource_metadata?" "datasourceKey=1__table" url = "/superset/fetch_datasource_metadata?" "datasourceKey=1__table"
@ -746,24 +735,11 @@ class CoreTests(SupersetTestCase):
def test_schemas_access_for_csv_upload_endpoint( def test_schemas_access_for_csv_upload_endpoint(
self, mock_all_datasource_access, mock_database_access, mock_schemas_accessible self, mock_all_datasource_access, mock_database_access, mock_schemas_accessible
): ):
self.login(username="admin")
dbobj = self.create_fake_db()
mock_all_datasource_access.return_value = False mock_all_datasource_access.return_value = False
mock_database_access.return_value = False mock_database_access.return_value = False
mock_schemas_accessible.return_value = ["this_schema_is_allowed_too"] mock_schemas_accessible.return_value = ["this_schema_is_allowed_too"]
database_name = "fake_db_100"
db_id = 100
extra = """{
"schemas_allowed_for_csv_upload":
["this_schema_is_allowed", "this_schema_is_allowed_too"]
}"""
self.login(username="admin")
dbobj = self.get_or_create(
cls=models.Database,
criteria={"database_name": database_name},
session=db.session,
id=db_id,
extra=extra,
)
data = self.get_json_resp( data = self.get_json_resp(
url="/superset/schemas_access_for_csv_upload?db_id={db_id}".format( url="/superset/schemas_access_for_csv_upload?db_id={db_id}".format(
db_id=dbobj.id db_id=dbobj.id

View File

@ -23,7 +23,7 @@ import yaml
from superset import db from superset import db
from superset.connectors.druid.models import DruidColumn, DruidDatasource, DruidMetric from superset.connectors.druid.models import DruidColumn, DruidDatasource, DruidMetric
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.utils.core import get_main_database from superset.utils.core import get_example_database
from .base_tests import SupersetTestCase from .base_tests import SupersetTestCase
DBREF = "dict_import__export_test" DBREF = "dict_import__export_test"
@ -63,7 +63,7 @@ class DictImportExportTests(SupersetTestCase):
params = {DBREF: id, "database_name": database_name} params = {DBREF: id, "database_name": database_name}
dict_rep = { dict_rep = {
"database_id": get_main_database().id, "database_id": get_example_database().id,
"table_name": name, "table_name": name,
"schema": schema, "schema": schema,
"id": id, "id": id,

View File

@ -22,7 +22,7 @@ from sqlalchemy.engine.url import make_url
from superset import app from superset import app
from superset.models.core import Database from superset.models.core import Database
from superset.utils.core import get_example_database, get_main_database, QueryStatus from superset.utils.core import get_example_database, QueryStatus
from .base_tests import SupersetTestCase from .base_tests import SupersetTestCase
@ -149,7 +149,7 @@ class DatabaseModelTestCase(SupersetTestCase):
assert sql.startswith(expected) assert sql.startswith(expected)
def test_single_statement(self): def test_single_statement(self):
main_db = get_main_database() main_db = get_example_database()
if main_db.backend == "mysql": if main_db.backend == "mysql":
df = main_db.get_df("SELECT 1", None) df = main_db.get_df("SELECT 1", None)
@ -159,7 +159,7 @@ class DatabaseModelTestCase(SupersetTestCase):
self.assertEquals(df.iat[0, 0], 1) self.assertEquals(df.iat[0, 0], 1)
def test_multi_statement(self): def test_multi_statement(self):
main_db = get_main_database() main_db = get_example_database()
if main_db.backend == "mysql": if main_db.backend == "mysql":
df = main_db.get_df("USE superset; SELECT 1", None) df = main_db.get_df("USE superset; SELECT 1", None)

View File

@ -449,41 +449,41 @@ class SupersetTestCase(unittest.TestCase):
self.assertEquals({"SalesOrderHeader"}, self.extract_tables(query)) self.assertEquals({"SalesOrderHeader"}, self.extract_tables(query))
def test_get_query_with_new_limit_comment(self): def test_get_query_with_new_limit_comment(self):
sql = "SELECT * FROM ab_user -- SOME COMMENT" sql = "SELECT * FROM birth_names -- SOME COMMENT"
parsed = sql_parse.ParsedQuery(sql) parsed = sql_parse.ParsedQuery(sql)
newsql = parsed.get_query_with_new_limit(1000) newsql = parsed.get_query_with_new_limit(1000)
self.assertEquals(newsql, sql + "\nLIMIT 1000") self.assertEquals(newsql, sql + "\nLIMIT 1000")
def test_get_query_with_new_limit_comment_with_limit(self): def test_get_query_with_new_limit_comment_with_limit(self):
sql = "SELECT * FROM ab_user -- SOME COMMENT WITH LIMIT 555" sql = "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555"
parsed = sql_parse.ParsedQuery(sql) parsed = sql_parse.ParsedQuery(sql)
newsql = parsed.get_query_with_new_limit(1000) newsql = parsed.get_query_with_new_limit(1000)
self.assertEquals(newsql, sql + "\nLIMIT 1000") self.assertEquals(newsql, sql + "\nLIMIT 1000")
def test_get_query_with_new_limit(self): def test_get_query_with_new_limit(self):
sql = "SELECT * FROM ab_user LIMIT 555" sql = "SELECT * FROM birth_names LIMIT 555"
parsed = sql_parse.ParsedQuery(sql) parsed = sql_parse.ParsedQuery(sql)
newsql = parsed.get_query_with_new_limit(1000) newsql = parsed.get_query_with_new_limit(1000)
expected = "SELECT * FROM ab_user LIMIT 1000" expected = "SELECT * FROM birth_names LIMIT 1000"
self.assertEquals(newsql, expected) self.assertEquals(newsql, expected)
def test_basic_breakdown_statements(self): def test_basic_breakdown_statements(self):
multi_sql = """ multi_sql = """
SELECT * FROM ab_user; SELECT * FROM birth_names;
SELECT * FROM ab_user LIMIT 1; SELECT * FROM birth_names LIMIT 1;
""" """
parsed = sql_parse.ParsedQuery(multi_sql) parsed = sql_parse.ParsedQuery(multi_sql)
statements = parsed.get_statements() statements = parsed.get_statements()
self.assertEquals(len(statements), 2) self.assertEquals(len(statements), 2)
expected = ["SELECT * FROM ab_user", "SELECT * FROM ab_user LIMIT 1"] expected = ["SELECT * FROM birth_names", "SELECT * FROM birth_names LIMIT 1"]
self.assertEquals(statements, expected) self.assertEquals(statements, expected)
def test_messy_breakdown_statements(self): def test_messy_breakdown_statements(self):
multi_sql = """ multi_sql = """
SELECT 1;\t\n\n\n \t SELECT 1;\t\n\n\n \t
\t\nSELECT 2; \t\nSELECT 2;
SELECT * FROM ab_user;;; SELECT * FROM birth_names;;;
SELECT * FROM ab_user LIMIT 1 SELECT * FROM birth_names LIMIT 1
""" """
parsed = sql_parse.ParsedQuery(multi_sql) parsed = sql_parse.ParsedQuery(multi_sql)
statements = parsed.get_statements() statements = parsed.get_statements()
@ -491,8 +491,8 @@ class SupersetTestCase(unittest.TestCase):
expected = [ expected = [
"SELECT 1", "SELECT 1",
"SELECT 2", "SELECT 2",
"SELECT * FROM ab_user", "SELECT * FROM birth_names",
"SELECT * FROM ab_user LIMIT 1", "SELECT * FROM birth_names LIMIT 1",
] ]
self.assertEquals(statements, expected) self.assertEquals(statements, expected)

View File

@ -53,7 +53,7 @@ class SqlValidatorEndpointTests(SupersetTestCase):
app.config["SQL_VALIDATORS_BY_ENGINE"] = {} app.config["SQL_VALIDATORS_BY_ENGINE"] = {}
resp = self.validate_sql( resp = self.validate_sql(
"SELECT * FROM ab_user", client_id="1", raise_on_error=False "SELECT * FROM birth_names", client_id="1", raise_on_error=False
) )
self.assertIn("error", resp) self.assertIn("error", resp)
self.assertIn("no SQL validator is configured", resp["error"]) self.assertIn("no SQL validator is configured", resp["error"])
@ -97,7 +97,7 @@ class SqlValidatorEndpointTests(SupersetTestCase):
validator.validate.side_effect = Exception("Kaboom!") validator.validate.side_effect = Exception("Kaboom!")
resp = self.validate_sql( resp = self.validate_sql(
"SELECT * FROM ab_user", client_id="1", raise_on_error=False "SELECT * FROM birth_names", client_id="1", raise_on_error=False
) )
self.assertIn("error", resp) self.assertIn("error", resp)
self.assertIn("Kaboom!", resp["error"]) self.assertIn("Kaboom!", resp["error"])
@ -186,7 +186,7 @@ class PrestoValidatorTests(SupersetTestCase):
# validator for sqlite, this test will fail because the validator # validator for sqlite, this test will fail because the validator
# will no longer error out. # will no longer error out.
resp = self.validate_sql( resp = self.validate_sql(
"SELECT * FROM ab_user", client_id="1", raise_on_error=False "SELECT * FROM birth_names", client_id="1", raise_on_error=False
) )
self.assertIn("error", resp) self.assertIn("error", resp)
self.assertIn("no SQL validator is configured", resp["error"]) self.assertIn("no SQL validator is configured", resp["error"])

View File

@ -16,7 +16,7 @@
# under the License. # under the License.
from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.db_engine_specs.druid import DruidEngineSpec from superset.db_engine_specs.druid import DruidEngineSpec
from superset.utils.core import get_main_database from superset.utils.core import get_example_database
from .base_tests import SupersetTestCase from .base_tests import SupersetTestCase
@ -43,7 +43,7 @@ class DatabaseModelTestCase(SupersetTestCase):
def test_has_extra_cache_keys(self): def test_has_extra_cache_keys(self):
query = "SELECT '{{ cache_key_wrapper('user_1') }}' as user" query = "SELECT '{{ cache_key_wrapper('user_1') }}' as user"
table = SqlaTable(sql=query, database=get_main_database()) table = SqlaTable(sql=query, database=get_example_database())
query_obj = { query_obj = {
"granularity": None, "granularity": None,
"from_dttm": None, "from_dttm": None,
@ -60,7 +60,7 @@ class DatabaseModelTestCase(SupersetTestCase):
def test_has_no_extra_cache_keys(self): def test_has_no_extra_cache_keys(self):
query = "SELECT 'abc' as user" query = "SELECT 'abc' as user"
table = SqlaTable(sql=query, database=get_main_database()) table = SqlaTable(sql=query, database=get_example_database())
query_obj = { query_obj = {
"granularity": None, "granularity": None,
"from_dttm": None, "from_dttm": None,

View File

@ -17,18 +17,20 @@
"""Unit tests for Sql Lab""" """Unit tests for Sql Lab"""
from datetime import datetime, timedelta from datetime import datetime, timedelta
import json import json
import unittest
from flask_appbuilder.security.sqla import models as ab_models
import prison import prison
from superset import db, security_manager from superset import db, security_manager
from superset.dataframe import SupersetDataFrame from superset.dataframe import SupersetDataFrame
from superset.db_engine_specs import BaseEngineSpec from superset.db_engine_specs import BaseEngineSpec
from superset.models.sql_lab import Query from superset.models.sql_lab import Query
from superset.utils.core import datetime_to_epoch, get_main_database from superset.utils.core import datetime_to_epoch, get_example_database
from .base_tests import SupersetTestCase from .base_tests import SupersetTestCase
QUERY_1 = "SELECT * FROM birth_names LIMIT 1"
QUERY_2 = "SELECT * FROM NO_TABLE"
QUERY_3 = "SELECT * FROM birth_names LIMIT 10"
class SqlLabTests(SupersetTestCase): class SqlLabTests(SupersetTestCase):
"""Testings for Sql Lab""" """Testings for Sql Lab"""
@ -39,17 +41,9 @@ class SqlLabTests(SupersetTestCase):
def run_some_queries(self): def run_some_queries(self):
db.session.query(Query).delete() db.session.query(Query).delete()
db.session.commit() db.session.commit()
self.run_sql( self.run_sql(QUERY_1, client_id="client_id_1", user_name="admin")
"SELECT * FROM ab_user", client_id="client_id_1", user_name="admin" self.run_sql(QUERY_2, client_id="client_id_3", user_name="admin")
) self.run_sql(QUERY_3, client_id="client_id_2", user_name="gamma_sqllab")
self.run_sql(
"SELECT * FROM NO_TABLE", client_id="client_id_3", user_name="admin"
)
self.run_sql(
"SELECT * FROM ab_permission",
client_id="client_id_2",
user_name="gamma_sqllab",
)
self.logout() self.logout()
def tearDown(self): def tearDown(self):
@ -61,7 +55,7 @@ class SqlLabTests(SupersetTestCase):
def test_sql_json(self): def test_sql_json(self):
self.login("admin") self.login("admin")
data = self.run_sql("SELECT * FROM ab_user", "1") data = self.run_sql("SELECT * FROM birth_names LIMIT 10", "1")
self.assertLess(0, len(data["data"])) self.assertLess(0, len(data["data"]))
data = self.run_sql("SELECT * FROM unexistant_table", "2") data = self.run_sql("SELECT * FROM unexistant_table", "2")
@ -71,8 +65,8 @@ class SqlLabTests(SupersetTestCase):
self.login("admin") self.login("admin")
multi_sql = """ multi_sql = """
SELECT first_name FROM ab_user; SELECT * FROM birth_names LIMIT 1;
SELECT first_name FROM ab_user; SELECT * FROM birth_names LIMIT 2;
""" """
data = self.run_sql(multi_sql, "2234") data = self.run_sql(multi_sql, "2234")
self.assertLess(0, len(data["data"])) self.assertLess(0, len(data["data"]))
@ -80,24 +74,18 @@ class SqlLabTests(SupersetTestCase):
def test_explain(self): def test_explain(self):
self.login("admin") self.login("admin")
data = self.run_sql("EXPLAIN SELECT * FROM ab_user", "1") data = self.run_sql("EXPLAIN SELECT * FROM birth_names", "1")
self.assertLess(0, len(data["data"])) self.assertLess(0, len(data["data"]))
def test_sql_json_has_access(self): def test_sql_json_has_access(self):
main_db = get_main_database() examples_db = get_example_database()
security_manager.add_permission_view_menu("database_access", main_db.perm) examples_db_permission_view = security_manager.add_permission_view_menu(
db.session.commit() "database_access", examples_db.perm
main_db_permission_view = (
db.session.query(ab_models.PermissionView)
.join(ab_models.ViewMenu)
.join(ab_models.Permission)
.filter(ab_models.ViewMenu.name == "[main].(id:1)")
.filter(ab_models.Permission.name == "database_access")
.first()
) )
astronaut = security_manager.add_role("Astronaut") astronaut = security_manager.add_role("Astronaut")
security_manager.add_permission_role(astronaut, main_db_permission_view) security_manager.add_permission_role(astronaut, examples_db_permission_view)
# Astronaut role is Gamma + sqllab + main db permissions # Astronaut role is Gamma + sqllab + db permissions
for perm in security_manager.find_role("Gamma").permissions: for perm in security_manager.find_role("Gamma").permissions:
security_manager.add_permission_role(astronaut, perm) security_manager.add_permission_role(astronaut, perm)
for perm in security_manager.find_role("sql_lab").permissions: for perm in security_manager.find_role("sql_lab").permissions:
@ -113,7 +101,7 @@ class SqlLabTests(SupersetTestCase):
astronaut, astronaut,
password="general", password="general",
) )
data = self.run_sql("SELECT * FROM ab_user", "3", user_name="gagarin") data = self.run_sql(QUERY_1, "3", user_name="gagarin")
db.session.query(Query).delete() db.session.query(Query).delete()
db.session.commit() db.session.commit()
self.assertLess(0, len(data["data"])) self.assertLess(0, len(data["data"]))
@ -132,8 +120,8 @@ class SqlLabTests(SupersetTestCase):
self.assertEquals(2, len(data)) self.assertEquals(2, len(data))
# Run 2 more queries # Run 2 more queries
self.run_sql("SELECT * FROM ab_user LIMIT 1", client_id="client_id_4") self.run_sql("SELECT * FROM birth_names LIMIT 1", client_id="client_id_4")
self.run_sql("SELECT * FROM ab_user LIMIT 2", client_id="client_id_5") self.run_sql("SELECT * FROM birth_names LIMIT 2", client_id="client_id_5")
self.login("admin") self.login("admin")
data = self.get_json_resp("/superset/queries/0") data = self.get_json_resp("/superset/queries/0")
self.assertEquals(4, len(data)) self.assertEquals(4, len(data))
@ -141,7 +129,7 @@ class SqlLabTests(SupersetTestCase):
now = datetime.now() + timedelta(days=1) now = datetime.now() + timedelta(days=1)
query = ( query = (
db.session.query(Query) db.session.query(Query)
.filter_by(sql="SELECT * FROM ab_user LIMIT 1") .filter_by(sql="SELECT * FROM birth_names LIMIT 1")
.first() .first()
) )
query.changed_on = now query.changed_on = now
@ -160,11 +148,15 @@ class SqlLabTests(SupersetTestCase):
def test_search_query_on_db_id(self): def test_search_query_on_db_id(self):
self.run_some_queries() self.run_some_queries()
self.login("admin") self.login("admin")
examples_dbid = get_example_database().id
# Test search queries on database Id # Test search queries on database Id
data = self.get_json_resp("/superset/search_queries?database_id=1") data = self.get_json_resp(
f"/superset/search_queries?database_id={examples_dbid}"
)
self.assertEquals(3, len(data)) self.assertEquals(3, len(data))
db_ids = [k["dbId"] for k in data] db_ids = [k["dbId"] for k in data]
self.assertEquals([1, 1, 1], db_ids) self.assertEquals([examples_dbid for i in range(3)], db_ids)
resp = self.get_resp("/superset/search_queries?database_id=-1") resp = self.get_resp("/superset/search_queries?database_id=-1")
data = json.loads(resp) data = json.loads(resp)
@ -205,19 +197,19 @@ class SqlLabTests(SupersetTestCase):
def test_search_query_on_text(self): def test_search_query_on_text(self):
self.run_some_queries() self.run_some_queries()
self.login("admin") self.login("admin")
url = "/superset/search_queries?search_text=permission" url = "/superset/search_queries?search_text=birth"
data = self.get_json_resp(url) data = self.get_json_resp(url)
self.assertEquals(1, len(data)) self.assertEquals(2, len(data))
self.assertIn("permission", data[0]["sql"]) self.assertIn("birth", data[0]["sql"])
def test_search_query_on_time(self): def test_search_query_on_time(self):
self.run_some_queries() self.run_some_queries()
self.login("admin") self.login("admin")
first_query_time = ( first_query_time = (
db.session.query(Query).filter_by(sql="SELECT * FROM ab_user").one() db.session.query(Query).filter_by(sql=QUERY_1).one()
).start_time ).start_time
second_query_time = ( second_query_time = (
db.session.query(Query).filter_by(sql="SELECT * FROM ab_permission").one() db.session.query(Query).filter_by(sql=QUERY_3).one()
).start_time ).start_time
# Test search queries on time filter # Test search queries on time filter
from_time = "from={}".format(int(first_query_time)) from_time = "from={}".format(int(first_query_time))
@ -265,7 +257,7 @@ class SqlLabTests(SupersetTestCase):
def test_alias_duplicate(self): def test_alias_duplicate(self):
self.run_sql( self.run_sql(
"SELECT username as col, id as col, username FROM ab_user", "SELECT name as col, gender as col FROM birth_names LIMIT 10",
client_id="2e2df3", client_id="2e2df3",
user_name="admin", user_name="admin",
raise_on_error=True, raise_on_error=True,
@ -281,7 +273,7 @@ class SqlLabTests(SupersetTestCase):
def test_df_conversion_tuple(self): def test_df_conversion_tuple(self):
cols = ["string_col", "int_col", "list_col", "float_col"] cols = ["string_col", "int_col", "list_col", "float_col"]
data = [(u"Text", 111, [123], 1.0)] data = [("Text", 111, [123], 1.0)]
cdf = SupersetDataFrame(data, cols, BaseEngineSpec) cdf = SupersetDataFrame(data, cols, BaseEngineSpec)
self.assertEquals(len(data), cdf.size) self.assertEquals(len(data), cdf.size)
@ -296,6 +288,7 @@ class SqlLabTests(SupersetTestCase):
self.assertEquals(len(cols), len(cdf.columns)) self.assertEquals(len(cols), len(cdf.columns))
def test_sqllab_viz(self): def test_sqllab_viz(self):
examples_dbid = get_example_database().id
payload = { payload = {
"chartType": "dist_bar", "chartType": "dist_bar",
"datasourceName": "test_viz_flow_table", "datasourceName": "test_viz_flow_table",
@ -316,11 +309,10 @@ class SqlLabTests(SupersetTestCase):
}, },
], ],
"sql": """\ "sql": """\
SELECT viz_type, count(1) as ccount SELECT *
FROM slices FROM birth_names
WHERE viz_type LIKE '%a%' LIMIT 10""",
GROUP BY viz_type""", "dbId": examples_dbid,
"dbId": 1,
} }
data = {"data": json.dumps(payload)} data = {"data": json.dumps(payload)}
resp = self.get_json_resp("/superset/sqllab_viz/", data=data) resp = self.get_json_resp("/superset/sqllab_viz/", data=data)
@ -329,20 +321,20 @@ class SqlLabTests(SupersetTestCase):
def test_sql_limit(self): def test_sql_limit(self):
self.login("admin") self.login("admin")
test_limit = 1 test_limit = 1
data = self.run_sql("SELECT * FROM ab_user", client_id="sql_limit_1") data = self.run_sql("SELECT * FROM birth_names", client_id="sql_limit_1")
self.assertGreater(len(data["data"]), test_limit) self.assertGreater(len(data["data"]), test_limit)
data = self.run_sql( data = self.run_sql(
"SELECT * FROM ab_user", client_id="sql_limit_2", query_limit=test_limit "SELECT * FROM birth_names", client_id="sql_limit_2", query_limit=test_limit
) )
self.assertEquals(len(data["data"]), test_limit) self.assertEquals(len(data["data"]), test_limit)
data = self.run_sql( data = self.run_sql(
"SELECT * FROM ab_user LIMIT {}".format(test_limit), "SELECT * FROM birth_names LIMIT {}".format(test_limit),
client_id="sql_limit_3", client_id="sql_limit_3",
query_limit=test_limit + 1, query_limit=test_limit + 1,
) )
self.assertEquals(len(data["data"]), test_limit) self.assertEquals(len(data["data"]), test_limit)
data = self.run_sql( data = self.run_sql(
"SELECT * FROM ab_user LIMIT {}".format(test_limit + 1), "SELECT * FROM birth_names LIMIT {}".format(test_limit + 1),
client_id="sql_limit_4", client_id="sql_limit_4",
query_limit=test_limit, query_limit=test_limit,
) )
@ -406,6 +398,7 @@ class SqlLabTests(SupersetTestCase):
def test_api_database(self): def test_api_database(self):
self.login("admin") self.login("admin")
self.create_fake_db()
arguments = { arguments = {
"keys": [], "keys": [],
@ -415,12 +408,8 @@ class SqlLabTests(SupersetTestCase):
"page": 0, "page": 0,
"page_size": -1, "page_size": -1,
} }
expected_results = ["examples", "fake_db_100", "main"]
url = "api/v1/database/?{}={}".format("q", prison.dumps(arguments)) url = "api/v1/database/?{}={}".format("q", prison.dumps(arguments))
data = self.get_json_resp(url) self.assertEquals(
for i, expected_result in enumerate(expected_results): {"examples", "fake_db_100"},
self.assertEquals(expected_result, data["result"][i]["database_name"]) {r.get("database_name") for r in self.get_json_resp(url)["result"]},
)
if __name__ == "__main__":
unittest.main()