2019-01-15 18:53:27 -05:00
|
|
|
# Licensed to the Apache Software Foundation (ASF) under one
|
|
|
|
# or more contributor license agreements. See the NOTICE file
|
|
|
|
# distributed with this work for additional information
|
|
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
|
|
# to you under the Apache License, Version 2.0 (the
|
|
|
|
# "License"); you may not use this file except in compliance
|
|
|
|
# with the License. You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing,
|
|
|
|
# software distributed under the License is distributed on an
|
|
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
|
|
# KIND, either express or implied. See the License for the
|
|
|
|
# specific language governing permissions and limitations
|
|
|
|
# under the License.
|
2019-11-20 10:47:06 -05:00
|
|
|
# isort:skip_file
|
2016-11-10 02:08:22 -05:00
|
|
|
"""Unit tests for Superset"""
|
2019-06-24 01:37:41 -04:00
|
|
|
import imp
|
2016-10-19 12:17:08 -04:00
|
|
|
import json
|
2019-12-16 16:10:33 -05:00
|
|
|
from typing import Union
|
2019-11-20 10:47:06 -05:00
|
|
|
from unittest.mock import Mock
|
2016-08-30 00:55:31 -04:00
|
|
|
|
2018-08-29 00:04:06 -04:00
|
|
|
import pandas as pd
|
2019-10-18 17:44:27 -04:00
|
|
|
from flask_appbuilder.security.sqla import models as ab_models
|
2019-11-20 10:47:06 -05:00
|
|
|
from flask_testing import TestCase
|
2016-08-30 00:55:31 -04:00
|
|
|
|
2019-11-20 10:47:06 -05:00
|
|
|
from tests.test_app import app # isort:skip
|
|
|
|
from superset import db, security_manager
|
2017-11-07 23:23:40 -05:00
|
|
|
from superset.connectors.druid.models import DruidCluster, DruidDatasource
|
|
|
|
from superset.connectors.sqla.models import SqlaTable
|
2017-03-10 12:11:51 -05:00
|
|
|
from superset.models import core as models
|
2019-12-18 14:40:45 -05:00
|
|
|
from superset.models.slice import Slice
|
2019-02-04 15:34:24 -05:00
|
|
|
from superset.models.core import Database
|
2019-12-18 14:40:45 -05:00
|
|
|
from superset.models.dashboard import Dashboard
|
2019-12-17 19:17:49 -05:00
|
|
|
from superset.models.datasource_access_request import DatasourceAccessRequest
|
2019-09-08 13:18:09 -04:00
|
|
|
from superset.utils.core import get_example_database
|
2016-08-30 00:55:31 -04:00
|
|
|
|
2019-11-20 10:47:06 -05:00
|
|
|
FAKE_DB_NAME = "fake_db_100"
|
2016-08-30 00:55:31 -04:00
|
|
|
|
|
|
|
|
2019-11-20 10:47:06 -05:00
|
|
|
class SupersetTestCase(TestCase):
    """Base class providing helpers shared by Superset's unit tests."""

    # Default schema each supported backend exposes when no schema is given.
    default_schema_backend_map = dict(
        sqlite="main",
        mysql="superset",
        postgresql="public",
    )
|
|
|
|
|
2016-08-30 00:55:31 -04:00
|
|
|
def __init__(self, *args, **kwargs):
|
2016-11-10 02:08:22 -05:00
|
|
|
super(SupersetTestCase, self).__init__(*args, **kwargs)
|
2016-09-22 12:53:14 -04:00
|
|
|
self.maxDiff = None
|
2016-11-17 14:58:33 -05:00
|
|
|
|
2019-11-20 10:47:06 -05:00
|
|
|
def create_app(self):
|
|
|
|
return app
|
|
|
|
|
2019-12-16 16:10:33 -05:00
|
|
|
@staticmethod
|
|
|
|
def create_user(
|
|
|
|
username: str,
|
|
|
|
password: str,
|
|
|
|
role_name: str,
|
|
|
|
first_name: str = "admin",
|
|
|
|
last_name: str = "user",
|
|
|
|
email: str = "admin@fab.org",
|
|
|
|
) -> Union[ab_models.User, bool]:
|
|
|
|
role_admin = security_manager.find_role(role_name)
|
|
|
|
return security_manager.add_user(
|
|
|
|
username, first_name, last_name, email, role_admin, password
|
|
|
|
)
|
|
|
|
|
2020-01-21 13:04:52 -05:00
|
|
|
@staticmethod
|
|
|
|
def get_user(username: str) -> ab_models.User:
|
|
|
|
user = (
|
|
|
|
db.session.query(security_manager.user_model)
|
|
|
|
.filter_by(username=username)
|
|
|
|
.one_or_none()
|
|
|
|
)
|
|
|
|
return user
|
|
|
|
|
2018-10-16 20:59:34 -04:00
|
|
|
@classmethod
|
|
|
|
def create_druid_test_objects(cls):
|
2016-09-22 12:53:14 -04:00
|
|
|
# create druid cluster and druid datasources
|
2020-01-13 14:02:36 -05:00
|
|
|
|
2019-11-20 10:47:06 -05:00
|
|
|
with app.app_context():
|
|
|
|
session = db.session
|
|
|
|
cluster = (
|
|
|
|
session.query(DruidCluster).filter_by(cluster_name="druid_test").first()
|
2016-09-22 12:53:14 -04:00
|
|
|
)
|
2019-11-20 10:47:06 -05:00
|
|
|
if not cluster:
|
|
|
|
cluster = DruidCluster(cluster_name="druid_test")
|
|
|
|
session.add(cluster)
|
|
|
|
session.commit()
|
|
|
|
|
|
|
|
druid_datasource1 = DruidDatasource(
|
2020-01-13 14:02:36 -05:00
|
|
|
datasource_name="druid_ds_1", cluster=cluster
|
2019-11-20 10:47:06 -05:00
|
|
|
)
|
|
|
|
session.add(druid_datasource1)
|
|
|
|
druid_datasource2 = DruidDatasource(
|
2020-01-13 14:02:36 -05:00
|
|
|
datasource_name="druid_ds_2", cluster=cluster
|
2019-11-20 10:47:06 -05:00
|
|
|
)
|
|
|
|
session.add(druid_datasource2)
|
|
|
|
session.commit()
|
2016-09-22 12:53:14 -04:00
|
|
|
|
2017-01-13 22:30:17 -05:00
|
|
|
def get_table(self, table_id):
|
2019-06-25 16:34:48 -04:00
|
|
|
return db.session.query(SqlaTable).filter_by(id=table_id).one()
|
2017-01-13 22:30:17 -05:00
|
|
|
|
2019-06-24 01:37:41 -04:00
|
|
|
@staticmethod
|
|
|
|
def is_module_installed(module_name):
|
|
|
|
try:
|
|
|
|
imp.find_module(module_name)
|
|
|
|
return True
|
|
|
|
except ImportError:
|
|
|
|
return False
|
|
|
|
|
2018-09-20 14:21:11 -04:00
|
|
|
def get_or_create(self, cls, criteria, session, **kwargs):
|
2016-10-07 19:24:39 -04:00
|
|
|
obj = session.query(cls).filter_by(**criteria).first()
|
|
|
|
if not obj:
|
|
|
|
obj = cls(**criteria)
|
2018-09-20 14:21:11 -04:00
|
|
|
obj.__dict__.update(**kwargs)
|
|
|
|
session.add(obj)
|
|
|
|
session.commit()
|
2016-10-07 19:24:39 -04:00
|
|
|
return obj
|
|
|
|
|
2019-06-25 16:34:48 -04:00
|
|
|
def login(self, username="admin", password="general"):
|
|
|
|
resp = self.get_resp("/login/", data=dict(username=username, password=password))
|
|
|
|
self.assertNotIn("User confirmation needed", resp)
|
2016-08-30 00:55:31 -04:00
|
|
|
|
2016-10-07 19:24:39 -04:00
|
|
|
def get_slice(self, slice_name, session):
|
2019-12-18 14:40:45 -05:00
|
|
|
slc = session.query(Slice).filter_by(slice_name=slice_name).one()
|
2016-10-07 19:24:39 -04:00
|
|
|
session.expunge_all()
|
|
|
|
return slc
|
|
|
|
|
2016-10-20 18:30:09 -04:00
|
|
|
def get_table_by_name(self, name):
|
2018-08-06 18:30:13 -04:00
|
|
|
return db.session.query(SqlaTable).filter_by(table_name=name).one()
|
2016-10-20 18:30:09 -04:00
|
|
|
|
2019-02-04 15:34:24 -05:00
|
|
|
def get_database_by_id(self, db_id):
|
|
|
|
return db.session.query(Database).filter_by(id=db_id).one()
|
|
|
|
|
2016-10-20 18:30:09 -04:00
|
|
|
def get_druid_ds_by_name(self, name):
|
2019-06-25 16:34:48 -04:00
|
|
|
return db.session.query(DruidDatasource).filter_by(datasource_name=name).first()
|
2016-10-20 18:30:09 -04:00
|
|
|
|
2018-08-29 00:04:06 -04:00
|
|
|
def get_datasource_mock(self):
|
|
|
|
datasource = Mock()
|
|
|
|
results = Mock()
|
|
|
|
results.query = Mock()
|
|
|
|
results.status = Mock()
|
|
|
|
results.error_message = None
|
|
|
|
results.df = pd.DataFrame()
|
2019-06-25 16:34:48 -04:00
|
|
|
datasource.type = "table"
|
2018-08-29 00:04:06 -04:00
|
|
|
datasource.query = Mock(return_value=results)
|
|
|
|
mock_dttm_col = Mock()
|
|
|
|
datasource.get_col = Mock(return_value=mock_dttm_col)
|
|
|
|
datasource.query = Mock(return_value=results)
|
|
|
|
datasource.database = Mock()
|
|
|
|
datasource.database.db_engine_spec = Mock()
|
|
|
|
datasource.database.db_engine_spec.mutate_expression_label = lambda x: x
|
|
|
|
return datasource
|
|
|
|
|
2019-09-23 12:09:12 -04:00
|
|
|
def get_resp(
|
|
|
|
self, url, data=None, follow_redirects=True, raise_on_error=True, json_=None
|
|
|
|
):
|
2016-10-02 21:03:19 -04:00
|
|
|
"""Shortcut to get the parsed results while following redirects"""
|
2016-11-17 14:58:33 -05:00
|
|
|
if data:
|
2019-06-25 16:34:48 -04:00
|
|
|
resp = self.client.post(url, data=data, follow_redirects=follow_redirects)
|
2019-09-23 12:09:12 -04:00
|
|
|
elif json_:
|
|
|
|
resp = self.client.post(url, json=json_, follow_redirects=follow_redirects)
|
2016-11-17 14:58:33 -05:00
|
|
|
else:
|
|
|
|
resp = self.client.get(url, follow_redirects=follow_redirects)
|
2016-12-01 18:21:18 -05:00
|
|
|
if raise_on_error and resp.status_code > 400:
|
2019-06-25 16:34:48 -04:00
|
|
|
raise Exception("http request failed with code {}".format(resp.status_code))
|
|
|
|
return resp.data.decode("utf-8")
|
2016-11-17 14:58:33 -05:00
|
|
|
|
2019-09-23 12:09:12 -04:00
|
|
|
def get_json_resp(
|
|
|
|
self, url, data=None, follow_redirects=True, raise_on_error=True, json_=None
|
|
|
|
):
|
2016-10-19 12:17:08 -04:00
|
|
|
"""Shortcut to get the parsed results while following redirects"""
|
2019-09-23 12:09:12 -04:00
|
|
|
resp = self.get_resp(url, data, follow_redirects, raise_on_error, json_)
|
2016-10-19 12:17:08 -04:00
|
|
|
return json.loads(resp)
|
|
|
|
|
2016-09-22 12:53:14 -04:00
|
|
|
def get_access_requests(self, username, ds_type, ds_id):
|
2019-12-17 19:17:49 -05:00
|
|
|
DAR = DatasourceAccessRequest
|
2016-11-30 17:05:09 -05:00
|
|
|
return (
|
|
|
|
db.session.query(DAR)
|
2017-11-10 15:06:22 -05:00
|
|
|
.filter(
|
2018-03-27 19:46:02 -04:00
|
|
|
DAR.created_by == security_manager.find_user(username=username),
|
2016-11-30 17:05:09 -05:00
|
|
|
DAR.datasource_type == ds_type,
|
|
|
|
DAR.datasource_id == ds_id,
|
2017-11-10 15:06:22 -05:00
|
|
|
)
|
|
|
|
.first()
|
2016-11-30 17:05:09 -05:00
|
|
|
)
|
2016-09-22 12:53:14 -04:00
|
|
|
|
2016-08-30 00:55:31 -04:00
|
|
|
def logout(self):
|
2019-06-25 16:34:48 -04:00
|
|
|
self.client.get("/logout/", follow_redirects=True)
|
2016-08-30 00:55:31 -04:00
|
|
|
|
2016-11-17 14:58:33 -05:00
|
|
|
def grant_public_access_to_table(self, table):
|
2019-06-25 16:34:48 -04:00
|
|
|
public_role = security_manager.find_role("Public")
|
2016-08-30 00:55:31 -04:00
|
|
|
perms = db.session.query(ab_models.PermissionView).all()
|
|
|
|
for perm in perms:
|
2019-06-25 16:34:48 -04:00
|
|
|
if (
|
|
|
|
perm.permission.name == "datasource_access"
|
|
|
|
and perm.view_menu
|
|
|
|
and table.perm in perm.view_menu.name
|
|
|
|
):
|
2018-03-27 19:46:02 -04:00
|
|
|
security_manager.add_permission_role(public_role, perm)
|
2016-08-30 00:55:31 -04:00
|
|
|
|
2016-11-17 14:58:33 -05:00
|
|
|
def revoke_public_access_to_table(self, table):
|
2019-06-25 16:34:48 -04:00
|
|
|
public_role = security_manager.find_role("Public")
|
2016-08-30 00:55:31 -04:00
|
|
|
perms = db.session.query(ab_models.PermissionView).all()
|
|
|
|
for perm in perms:
|
2019-06-25 16:34:48 -04:00
|
|
|
if (
|
|
|
|
perm.permission.name == "datasource_access"
|
|
|
|
and perm.view_menu
|
|
|
|
and table.perm in perm.view_menu.name
|
|
|
|
):
|
2018-03-27 19:46:02 -04:00
|
|
|
security_manager.del_permission_role(public_role, perm)
|
2016-11-01 23:48:31 -04:00
|
|
|
|
2019-09-08 13:18:09 -04:00
|
|
|
def _get_database_by_name(self, database_name="main"):
|
|
|
|
if database_name == "examples":
|
|
|
|
return get_example_database()
|
|
|
|
else:
|
|
|
|
raise ValueError("Database doesn't exist")
|
|
|
|
|
2019-06-25 16:34:48 -04:00
|
|
|
def run_sql(
|
|
|
|
self,
|
|
|
|
sql,
|
|
|
|
client_id=None,
|
|
|
|
user_name=None,
|
|
|
|
raise_on_error=False,
|
|
|
|
query_limit=None,
|
2019-09-08 13:18:09 -04:00
|
|
|
database_name="examples",
|
2019-12-09 19:12:40 -05:00
|
|
|
sql_editor_id=None,
|
2020-03-03 12:52:20 -05:00
|
|
|
select_as_cta=False,
|
|
|
|
tmp_table_name=None,
|
2019-06-25 16:34:48 -04:00
|
|
|
):
|
2016-11-17 14:58:33 -05:00
|
|
|
if user_name:
|
|
|
|
self.logout()
|
2019-09-08 13:18:09 -04:00
|
|
|
self.login(username=(user_name or "admin"))
|
|
|
|
dbid = self._get_database_by_name(database_name).id
|
2020-03-03 12:52:20 -05:00
|
|
|
json_payload = {
|
|
|
|
"database_id": dbid,
|
|
|
|
"sql": sql,
|
|
|
|
"client_id": client_id,
|
|
|
|
"queryLimit": query_limit,
|
|
|
|
"sql_editor_id": sql_editor_id,
|
|
|
|
}
|
|
|
|
if tmp_table_name:
|
|
|
|
json_payload["tmp_table_name"] = tmp_table_name
|
|
|
|
if select_as_cta:
|
|
|
|
json_payload["select_as_cta"] = select_as_cta
|
|
|
|
|
2016-11-17 14:58:33 -05:00
|
|
|
resp = self.get_json_resp(
|
2020-03-03 12:52:20 -05:00
|
|
|
"/superset/sql_json/", raise_on_error=False, json_=json_payload
|
2016-11-01 23:48:31 -04:00
|
|
|
)
|
2019-06-25 16:34:48 -04:00
|
|
|
if raise_on_error and "error" in resp:
|
|
|
|
raise Exception("run_sql failed")
|
2016-11-17 14:58:33 -05:00
|
|
|
return resp
|
2019-02-01 16:21:25 -05:00
|
|
|
|
2019-09-08 13:18:09 -04:00
|
|
|
def create_fake_db(self):
|
|
|
|
self.login(username="admin")
|
2019-11-20 10:47:06 -05:00
|
|
|
database_name = FAKE_DB_NAME
|
2019-09-08 13:18:09 -04:00
|
|
|
db_id = 100
|
|
|
|
extra = """{
|
|
|
|
"schemas_allowed_for_csv_upload":
|
|
|
|
["this_schema_is_allowed", "this_schema_is_allowed_too"]
|
|
|
|
}"""
|
|
|
|
|
|
|
|
return self.get_or_create(
|
|
|
|
cls=models.Database,
|
|
|
|
criteria={"database_name": database_name},
|
|
|
|
session=db.session,
|
2020-03-10 12:20:37 -04:00
|
|
|
sqlalchemy_uri="sqlite:///:memory:",
|
2019-09-08 13:18:09 -04:00
|
|
|
id=db_id,
|
|
|
|
extra=extra,
|
|
|
|
)
|
|
|
|
|
2019-11-20 10:47:06 -05:00
|
|
|
def delete_fake_db(self):
|
|
|
|
database = (
|
|
|
|
db.session.query(Database)
|
|
|
|
.filter(Database.database_name == FAKE_DB_NAME)
|
|
|
|
.scalar()
|
|
|
|
)
|
|
|
|
if database:
|
|
|
|
db.session.delete(database)
|
|
|
|
|
2020-04-07 16:00:42 -04:00
|
|
|
def create_fake_presto_db(self):
|
|
|
|
self.login(username="admin")
|
|
|
|
database_name = "presto"
|
|
|
|
db_id = 200
|
|
|
|
return self.get_or_create(
|
|
|
|
cls=models.Database,
|
|
|
|
criteria={"database_name": database_name},
|
|
|
|
session=db.session,
|
|
|
|
sqlalchemy_uri="presto://user@host:8080/hive",
|
|
|
|
id=db_id,
|
|
|
|
)
|
|
|
|
|
|
|
|
def delete_fake_presto_db(self):
|
|
|
|
database = (
|
|
|
|
db.session.query(Database)
|
|
|
|
.filter(Database.database_name == "presto")
|
|
|
|
.scalar()
|
|
|
|
)
|
|
|
|
if database:
|
|
|
|
db.session.delete(database)
|
|
|
|
db.session.commit()
|
|
|
|
|
2019-09-08 13:18:09 -04:00
|
|
|
def validate_sql(
|
|
|
|
self,
|
|
|
|
sql,
|
|
|
|
client_id=None,
|
|
|
|
user_name=None,
|
|
|
|
raise_on_error=False,
|
|
|
|
database_name="examples",
|
|
|
|
):
|
2019-05-06 13:21:02 -04:00
|
|
|
if user_name:
|
|
|
|
self.logout()
|
2019-06-25 16:34:48 -04:00
|
|
|
self.login(username=(user_name if user_name else "admin"))
|
2019-09-08 13:18:09 -04:00
|
|
|
dbid = self._get_database_by_name(database_name).id
|
2019-05-06 13:21:02 -04:00
|
|
|
resp = self.get_json_resp(
|
2019-06-25 16:34:48 -04:00
|
|
|
"/superset/validate_sql_json/",
|
2019-05-06 13:21:02 -04:00
|
|
|
raise_on_error=False,
|
|
|
|
data=dict(database_id=dbid, sql=sql, client_id=client_id),
|
|
|
|
)
|
2019-06-25 16:34:48 -04:00
|
|
|
if raise_on_error and "error" in resp:
|
|
|
|
raise Exception("validate_sql failed")
|
2019-05-06 13:21:02 -04:00
|
|
|
return resp
|
|
|
|
|
2019-07-17 00:36:56 -04:00
|
|
|
def get_dash_by_slug(self, dash_slug):
|
|
|
|
sesh = db.session()
|
2019-12-18 14:40:45 -05:00
|
|
|
return sesh.query(Dashboard).filter_by(slug=dash_slug).first()
|