mirror of
https://github.com/apache/superset.git
synced 2024-09-13 00:59:37 -04:00
0584e3629f
* Add separate limit setting for SqlLab Use separate param for wrap sql Get query limit from config unit tests for limit control rendering in sql editor py unit test pg tests Add max rows limit Remove concept of infinity, always require defined limits consistency Assert on validation errors instead of tooltip fix unit tests attempt persist state pr comments and linting * load configs in via common param * default to 1k
172 lines
6.1 KiB
Python
172 lines
6.1 KiB
Python
"""Unit tests for Superset"""
|
|
import json
|
|
import unittest
|
|
|
|
from flask_appbuilder.security.sqla import models as ab_models
|
|
from mock import Mock
|
|
import pandas as pd
|
|
|
|
from superset import app, db, security_manager
|
|
from superset.connectors.druid.models import DruidCluster, DruidDatasource
|
|
from superset.connectors.sqla.models import SqlaTable
|
|
from superset.models import core as models
|
|
from superset.utils.core import get_main_database
|
|
|
|
BASE_DIR = app.config.get('BASE_DIR')
|
|
|
|
|
|
class SupersetTestCase(unittest.TestCase):
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
super(SupersetTestCase, self).__init__(*args, **kwargs)
|
|
self.client = app.test_client()
|
|
self.maxDiff = None
|
|
|
|
@classmethod
|
|
def create_druid_test_objects(cls):
|
|
# create druid cluster and druid datasources
|
|
session = db.session
|
|
cluster = (
|
|
session.query(DruidCluster)
|
|
.filter_by(cluster_name='druid_test')
|
|
.first()
|
|
)
|
|
if not cluster:
|
|
cluster = DruidCluster(cluster_name='druid_test')
|
|
session.add(cluster)
|
|
session.commit()
|
|
|
|
druid_datasource1 = DruidDatasource(
|
|
datasource_name='druid_ds_1',
|
|
cluster_name='druid_test',
|
|
)
|
|
session.add(druid_datasource1)
|
|
druid_datasource2 = DruidDatasource(
|
|
datasource_name='druid_ds_2',
|
|
cluster_name='druid_test',
|
|
)
|
|
session.add(druid_datasource2)
|
|
session.commit()
|
|
|
|
def get_table(self, table_id):
|
|
return (
|
|
db.session
|
|
.query(SqlaTable)
|
|
.filter_by(id=table_id)
|
|
.one()
|
|
)
|
|
|
|
def get_or_create(self, cls, criteria, session, **kwargs):
|
|
obj = session.query(cls).filter_by(**criteria).first()
|
|
if not obj:
|
|
obj = cls(**criteria)
|
|
obj.__dict__.update(**kwargs)
|
|
session.add(obj)
|
|
session.commit()
|
|
return obj
|
|
|
|
def login(self, username='admin', password='general'):
|
|
resp = self.get_resp(
|
|
'/login/',
|
|
data=dict(username=username, password=password))
|
|
self.assertNotIn('User confirmation needed', resp)
|
|
|
|
def get_slice(self, slice_name, session):
|
|
slc = (
|
|
session.query(models.Slice)
|
|
.filter_by(slice_name=slice_name)
|
|
.one()
|
|
)
|
|
session.expunge_all()
|
|
return slc
|
|
|
|
def get_table_by_name(self, name):
|
|
return db.session.query(SqlaTable).filter_by(table_name=name).one()
|
|
|
|
def get_druid_ds_by_name(self, name):
|
|
return db.session.query(DruidDatasource).filter_by(
|
|
datasource_name=name).first()
|
|
|
|
def get_datasource_mock(self):
|
|
datasource = Mock()
|
|
results = Mock()
|
|
results.query = Mock()
|
|
results.status = Mock()
|
|
results.error_message = None
|
|
results.df = pd.DataFrame()
|
|
datasource.type = 'table'
|
|
datasource.query = Mock(return_value=results)
|
|
mock_dttm_col = Mock()
|
|
datasource.get_col = Mock(return_value=mock_dttm_col)
|
|
datasource.query = Mock(return_value=results)
|
|
datasource.database = Mock()
|
|
datasource.database.db_engine_spec = Mock()
|
|
datasource.database.db_engine_spec.mutate_expression_label = lambda x: x
|
|
return datasource
|
|
|
|
def get_resp(
|
|
self, url, data=None, follow_redirects=True, raise_on_error=True):
|
|
"""Shortcut to get the parsed results while following redirects"""
|
|
if data:
|
|
resp = self.client.post(
|
|
url, data=data, follow_redirects=follow_redirects)
|
|
else:
|
|
resp = self.client.get(url, follow_redirects=follow_redirects)
|
|
if raise_on_error and resp.status_code > 400:
|
|
raise Exception(
|
|
'http request failed with code {}'.format(resp.status_code))
|
|
return resp.data.decode('utf-8')
|
|
|
|
def get_json_resp(
|
|
self, url, data=None, follow_redirects=True, raise_on_error=True):
|
|
"""Shortcut to get the parsed results while following redirects"""
|
|
resp = self.get_resp(url, data, follow_redirects, raise_on_error)
|
|
return json.loads(resp)
|
|
|
|
def get_access_requests(self, username, ds_type, ds_id):
|
|
DAR = models.DatasourceAccessRequest
|
|
return (
|
|
db.session.query(DAR)
|
|
.filter(
|
|
DAR.created_by == security_manager.find_user(username=username),
|
|
DAR.datasource_type == ds_type,
|
|
DAR.datasource_id == ds_id,
|
|
)
|
|
.first()
|
|
)
|
|
|
|
def logout(self):
|
|
self.client.get('/logout/', follow_redirects=True)
|
|
|
|
def grant_public_access_to_table(self, table):
|
|
public_role = security_manager.find_role('Public')
|
|
perms = db.session.query(ab_models.PermissionView).all()
|
|
for perm in perms:
|
|
if (perm.permission.name == 'datasource_access' and
|
|
perm.view_menu and table.perm in perm.view_menu.name):
|
|
security_manager.add_permission_role(public_role, perm)
|
|
|
|
def revoke_public_access_to_table(self, table):
|
|
public_role = security_manager.find_role('Public')
|
|
perms = db.session.query(ab_models.PermissionView).all()
|
|
for perm in perms:
|
|
if (perm.permission.name == 'datasource_access' and
|
|
perm.view_menu and table.perm in perm.view_menu.name):
|
|
security_manager.del_permission_role(public_role, perm)
|
|
|
|
def run_sql(self, sql, client_id=None, user_name=None, raise_on_error=False,
|
|
query_limit=None):
|
|
if user_name:
|
|
self.logout()
|
|
self.login(username=(user_name if user_name else 'admin'))
|
|
dbid = get_main_database(db.session).id
|
|
resp = self.get_json_resp(
|
|
'/superset/sql_json/',
|
|
raise_on_error=False,
|
|
data=dict(database_id=dbid, sql=sql, select_as_create_as=False,
|
|
client_id=client_id, queryLimit=query_limit),
|
|
)
|
|
if raise_on_error and 'error' in resp:
|
|
raise Exception('run_sql failed')
|
|
return resp
|