[security] improving the security scheme (#1587)

* [security] improving the security scheme

* Addressing comments

* improving docs

* Creating security module to organize things

* Moving CLI to its own module

* perms

* Materializing perms

* progress

* Addressing comments, linting
This commit is contained in:
Maxime Beauchemin 2016-11-17 11:58:33 -08:00 committed by GitHub
parent aad9744d85
commit bce02e3f51
19 changed files with 765 additions and 543 deletions

1
.gitignore vendored
View File

@ -8,6 +8,7 @@ changelog.sh
_build
_static
_images
_modules
superset/bin/supersetc
env_py3
.eggs

View File

@ -7,8 +7,19 @@ FAB provides authentication, user management, permissions and roles.
Provided Roles
--------------
Superset ships with 3 roles that are handled by Superset itself. You can
assume that these 3 roles will stay up-to-date as Superset evolves.
Superset ships with a set of roles that are handled by Superset itself.
You can assume that these roles will stay up-to-date as Superset evolves.
Even though it's possible for ``Admin`` users to do so, it is not recommended
that you alter these roles in any way by removing
or adding permissions to them as these roles will be re-synchronized to
their original values as you run your next ``superset init`` command.
Since it's not recommended to alter the roles described here, it's right
to assume that your security strategy should be to compose user access based
on these base roles and roles that you create. For instance you could
create a role ``Financial Analyst`` that would be made of a set of permissions
to a set of data sources (tables) and/or databases. Users would then be
granted ``Gamma``, ``Financial Analyst``, and perhaps ``sql_lab``.
Admin
"""""
@ -33,6 +44,12 @@ mostly content consumers, though they can create slices and dashboards.
Also note that when Gamma users look at the dashboards and slices list view,
they will only see the objects that they have access to.
sql_lab
"""""""
The ``sql_lab`` role grants access to SQL Lab. Note that while ``Admin``
users have access to all databases by default, both ``Alpha`` and ``Gamma``
users need to be given access on a per database basis.
Managing Gamma per data source access
-------------------------------------

View File

@ -23,6 +23,8 @@ CONFIG_MODULE = os.environ.get('SUPERSET_CONFIG', 'superset.config')
app = Flask(__name__)
app.config.from_object(CONFIG_MODULE)
conf = app.config
if not app.debug:
# In production mode, add log handler to sys.stderr.
app.logger.addHandler(logging.StreamHandler())

View File

@ -158,7 +158,7 @@ export function setNetworkStatus(networkOn) {
export function addAlert(alert) {
const o = Object.assign({}, alert);
o.id = shortid.generate();
return { type: ADD_ALERT, o };
return { type: ADD_ALERT, alert: o };
}
export function removeAlert(alert) {

View File

@ -2,6 +2,13 @@ const $ = window.$ = require('jquery');
import React from 'react';
import Select from 'react-select';
const propTypes = {
onChange: React.PropTypes.func,
actions: React.PropTypes.object,
databaseId: React.PropTypes.number,
valueRenderer: React.PropTypes.func,
};
class DatabaseSelect extends React.PureComponent {
constructor(props) {
super(props);
@ -23,6 +30,12 @@ class DatabaseSelect extends React.PureComponent {
const options = data.result.map((db) => ({ value: db.id, label: db.database_name }));
this.setState({ databaseOptions: options, databaseLoading: false });
this.props.actions.setDatabases(data.result);
if (data.result.length === 0) {
this.props.actions.addAlert({
bsStyle: 'danger',
msg: "It seems you don't have access to any database",
});
}
});
}
render() {
@ -43,11 +56,6 @@ class DatabaseSelect extends React.PureComponent {
}
}
DatabaseSelect.propTypes = {
onChange: React.PropTypes.func,
actions: React.PropTypes.object,
databaseId: React.PropTypes.number,
valueRenderer: React.PropTypes.func,
};
DatabaseSelect.propTypes = propTypes;
export default DatabaseSelect;

View File

@ -4,156 +4,7 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import logging
import celery
from celery.bin import worker as celery_worker
from datetime import datetime
from subprocess import Popen
from flask_migrate import MigrateCommand
from flask_script import Manager
import superset
from superset import app, ascii_art, db, data, utils
config = app.config
manager = Manager(app)
manager.add_command('db', MigrateCommand)
@manager.option(
'-d', '--debug', action='store_true',
help="Start the web server in debug mode")
@manager.option(
'-a', '--address', default=config.get("SUPERSET_WEBSERVER_ADDRESS"),
help="Specify the address to which to bind the web server")
@manager.option(
'-p', '--port', default=config.get("SUPERSET_WEBSERVER_PORT"),
help="Specify the port on which to run the web server")
@manager.option(
'-w', '--workers', default=config.get("SUPERSET_WORKERS", 2),
help="Number of gunicorn web server workers to fire up")
@manager.option(
'-t', '--timeout', default=config.get("SUPERSET_WEBSERVER_TIMEOUT"),
help="Specify the timeout (seconds) for the gunicorn web server")
def runserver(debug, address, port, timeout, workers):
"""Starts a Superset web server"""
debug = debug or config.get("DEBUG")
if debug:
app.run(
host='0.0.0.0',
port=int(port),
threaded=True,
debug=True)
else:
cmd = (
"gunicorn "
"-w {workers} "
"--timeout {timeout} "
"-b {address}:{port} "
"--limit-request-line 0 "
"--limit-request-field_size 0 "
"superset:app").format(**locals())
print("Starting server with command: " + cmd)
Popen(cmd, shell=True).wait()
@manager.command
def init():
"""Inits the Superset application"""
utils.init(superset)
@manager.option(
'-v', '--verbose', action='store_true',
help="Show extra information")
def version(verbose):
"""Prints the current version number"""
s = (
"\n{boat}\n\n"
"-----------------------\n"
"Superset {version}\n"
"-----------------------").format(
boat=ascii_art.boat, version=config.get('VERSION_STRING'))
print(s)
if verbose:
print("[DB] : " + "{}".format(db.engine))
@manager.option(
'-t', '--load-test-data', action='store_true',
help="Load additional test data")
def load_examples(load_test_data):
"""Loads a set of Slices and Dashboards and a supporting dataset """
print("Loading examples into {}".format(db))
data.load_css_templates()
print("Loading energy related dataset")
data.load_energy()
print("Loading [World Bank's Health Nutrition and Population Stats]")
data.load_world_bank_health_n_pop()
print("Loading [Birth names]")
data.load_birth_names()
print("Loading [Random time series data]")
data.load_random_time_series_data()
print("Loading [Random long/lat data]")
data.load_long_lat_data()
print("Loading [Multiformat time series]")
data.load_multiformat_time_series_data()
print("Loading [Misc Charts] dashboard")
data.load_misc_dashboard()
if load_test_data:
print("Loading [Unicode test data]")
data.load_unicode_test_data()
@manager.option(
'-d', '--datasource',
help=(
"Specify which datasource name to load, if omitted, all "
"datasources will be refreshed"))
def refresh_druid(datasource):
"""Refresh druid datasources"""
session = db.session()
from superset import models
for cluster in session.query(models.DruidCluster).all():
try:
cluster.refresh_datasources(datasource_name=datasource)
except Exception as e:
print(
"Error while processing cluster '{}'\n{}".format(
cluster, str(e)))
logging.exception(e)
cluster.metadata_last_refreshed = datetime.now()
print(
"Refreshed metadata from cluster "
"[" + cluster.cluster_name + "]")
session.commit()
@manager.command
def worker():
"""Starts a Superset worker for async SQL query execution."""
# celery -A tasks worker --loglevel=info
print("Starting SQL Celery worker.")
if config.get('CELERY_CONFIG'):
print("Celery broker url: ")
print(config.get('CELERY_CONFIG').BROKER_URL)
application = celery.current_app._get_current_object()
c_worker = celery_worker.worker(app=application)
options = {
'broker': config.get('CELERY_CONFIG').BROKER_URL,
'loglevel': 'INFO',
'traceback': True,
}
c_worker.run(**options)
from superset.cli import manager
if __name__ == "__main__":
manager.run()

158
superset/cli.py Executable file
View File

@ -0,0 +1,158 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import logging
import celery
from celery.bin import worker as celery_worker
from datetime import datetime
from subprocess import Popen
from flask_migrate import MigrateCommand
from flask_script import Manager
from superset import app, ascii_art, db, data, security
config = app.config
manager = Manager(app)
manager.add_command('db', MigrateCommand)
@manager.option(
    '-d', '--debug', action='store_true',
    help="Start the web server in debug mode")
@manager.option(
    '-a', '--address', default=config.get("SUPERSET_WEBSERVER_ADDRESS"),
    help="Specify the address to which to bind the web server")
@manager.option(
    '-p', '--port', default=config.get("SUPERSET_WEBSERVER_PORT"),
    help="Specify the port on which to run the web server")
@manager.option(
    '-w', '--workers', default=config.get("SUPERSET_WORKERS", 2),
    help="Number of gunicorn web server workers to fire up")
@manager.option(
    '-t', '--timeout', default=config.get("SUPERSET_WEBSERVER_TIMEOUT"),
    help="Specify the timeout (seconds) for the gunicorn web server")
def runserver(debug, address, port, timeout, workers):
    """Starts a Superset web server"""
    # Config-level DEBUG also forces the Flask development server path.
    debug = debug or config.get("DEBUG")
    if debug:
        app.run(
            host='0.0.0.0',
            port=int(port),
            threaded=True,
            debug=True)
    else:
        # NOTE(review): the gunicorn command line is built from config
        # values and executed through the shell (shell=True). Values are
        # operator-controlled, but confirm none can be derived from
        # untrusted input before re-exposing this entry point.
        cmd = (
            "gunicorn "
            "-w {workers} "
            "--timeout {timeout} "
            "-b {address}:{port} "
            "--limit-request-line 0 "
            "--limit-request-field_size 0 "
            "superset:app").format(**locals())
        print("Starting server with command: " + cmd)
        # Blocks until the gunicorn process exits.
        Popen(cmd, shell=True).wait()
@manager.command
def init():
    """Inits the Superset application"""
    # Delegates to the security module: creates/updates the builtin
    # roles and materializes datasource/database permissions.
    security.sync_role_definitions()
@manager.option(
    '-v', '--verbose', action='store_true',
    help="Show extra information")
def version(verbose):
    """Prints the current version number"""
    # Banner combines the ASCII-art boat with the configured version string.
    s = (
        "\n{boat}\n\n"
        "-----------------------\n"
        "Superset {version}\n"
        "-----------------------").format(
            boat=ascii_art.boat, version=config.get('VERSION_STRING'))
    print(s)
    if verbose:
        # Shows which metadata database engine this install points at.
        print("[DB] : " + "{}".format(db.engine))
@manager.option(
    '-t', '--load-test-data', action='store_true',
    help="Load additional test data")
def load_examples(load_test_data):
    """Loads a set of Slices and Dashboards and a supporting dataset """
    print("Loading examples into {}".format(db))
    data.load_css_templates()
    # Table of (announcement, loader) pairs, executed in order.
    example_loaders = [
        ("Loading energy related dataset", data.load_energy),
        ("Loading [World Bank's Health Nutrition and Population Stats]",
         data.load_world_bank_health_n_pop),
        ("Loading [Birth names]", data.load_birth_names),
        ("Loading [Random time series data]",
         data.load_random_time_series_data),
        ("Loading [Random long/lat data]", data.load_long_lat_data),
        ("Loading [Multiformat time series]",
         data.load_multiformat_time_series_data),
        ("Loading [Misc Charts] dashboard", data.load_misc_dashboard),
    ]
    if load_test_data:
        example_loaders.append(
            ("Loading [Unicode test data]", data.load_unicode_test_data))
    for message, loader in example_loaders:
        print(message)
        loader()
@manager.option(
    '-d', '--datasource',
    help=(
        "Specify which datasource name to load, if omitted, all "
        "datasources will be refreshed"))
def refresh_druid(datasource):
    """Refresh druid datasources"""
    session = db.session()
    from superset import models
    for cluster in session.query(models.DruidCluster).all():
        try:
            cluster.refresh_datasources(datasource_name=datasource)
        except Exception as e:
            # Best-effort: a failing cluster is reported but does not stop
            # the refresh of the remaining clusters.
            print(
                "Error while processing cluster '{}'\n{}".format(
                    cluster, str(e)))
            logging.exception(e)
        # Timestamp is advanced even when the refresh failed, so the
        # "last refreshed" marker reflects the last attempt.
        cluster.metadata_last_refreshed = datetime.now()
        print(
            "Refreshed metadata from cluster "
            "[" + cluster.cluster_name + "]")
    # Single commit persists all clusters' timestamps at once.
    session.commit()
@manager.command
def worker():
    """Starts a Superset worker for async SQL query execution."""
    # celery -A tasks worker --loglevel=info
    print("Starting SQL Celery worker.")
    if config.get('CELERY_CONFIG'):
        print("Celery broker url: ")
        print(config.get('CELERY_CONFIG').BROKER_URL)

    # Reuses the already-configured Celery application object rather than
    # creating a new one, so task registrations are preserved.
    application = celery.current_app._get_current_object()
    c_worker = celery_worker.worker(app=application)
    options = {
        'broker': config.get('CELERY_CONFIG').BROKER_URL,
        'loglevel': 'INFO',
        'traceback': True,
    }
    # Runs in the foreground until interrupted.
    c_worker.run(**options)

View File

@ -8,7 +8,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from superset import app
import json
import os

View File

@ -14,8 +14,8 @@ import random
import pandas as pd
from sqlalchemy import String, DateTime, Date, Float, BigInteger
import superset
from superset import app, db, models, utils
from superset.security import get_or_create_main_db
# Shortcuts
DB = models.Database
@ -67,7 +67,7 @@ def load_energy():
tbl = TBL(table_name=tbl_name)
tbl.description = "Energy consumption"
tbl.is_featured = True
tbl.database = utils.get_or_create_main_db(superset)
tbl.database = get_or_create_main_db()
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
@ -194,7 +194,7 @@ def load_world_bank_health_n_pop():
tbl.description = utils.readfile(os.path.join(DATA_FOLDER, 'countries.md'))
tbl.main_dttm_col = 'year'
tbl.is_featured = True
tbl.database = utils.get_or_create_main_db(superset)
tbl.database = get_or_create_main_db()
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
@ -586,7 +586,7 @@ def load_birth_names():
if not obj:
obj = TBL(table_name='birth_names')
obj.main_dttm_col = 'ds'
obj.database = utils.get_or_create_main_db(superset)
obj.database = get_or_create_main_db()
obj.is_featured = True
db.session.merge(obj)
db.session.commit()
@ -834,7 +834,7 @@ def load_unicode_test_data():
if not obj:
obj = TBL(table_name='unicode_test')
obj.main_dttm_col = 'date'
obj.database = utils.get_or_create_main_db(superset)
obj.database = get_or_create_main_db()
obj.is_featured = False
db.session.merge(obj)
db.session.commit()
@ -872,7 +872,11 @@ def load_unicode_test_data():
merge_slice(slc)
print("Creating a dashboard")
dash = db.session.query(Dash).filter_by(dashboard_title="Unicode Test").first()
dash = (
db.session.query(Dash)
.filter_by(dashboard_title="Unicode Test")
.first()
)
if not dash:
dash = Dash()
@ -913,7 +917,7 @@ def load_random_time_series_data():
if not obj:
obj = TBL(table_name='random_time_series')
obj.main_dttm_col = 'ds'
obj.database = utils.get_or_create_main_db(superset)
obj.database = get_or_create_main_db()
obj.is_featured = False
db.session.merge(obj)
db.session.commit()
@ -981,7 +985,7 @@ def load_long_lat_data():
if not obj:
obj = TBL(table_name='long_lat')
obj.main_dttm_col = 'date'
obj.database = utils.get_or_create_main_db(superset)
obj.database = get_or_create_main_db()
obj.is_featured = False
db.session.merge(obj)
db.session.commit()
@ -1046,7 +1050,7 @@ def load_multiformat_time_series_data():
if not obj:
obj = TBL(table_name='multiformat_time_series')
obj.main_dttm_col = 'ds'
obj.database = utils.get_or_create_main_db(superset)
obj.database = get_or_create_main_db()
obj.is_featured = False
dttm_and_expr_dict = {
'ds': [None, None],

View File

@ -0,0 +1,27 @@
"""materialize perms
Revision ID: e46f2d27a08e
Revises: c611f2b591b8
Create Date: 2016-11-14 15:23:32.594898
"""
# revision identifiers, used by Alembic.
revision = 'e46f2d27a08e'
down_revision = 'c611f2b591b8'
from alembic import op
import sqlalchemy as sa
def upgrade():
    # Add a nullable materialized `perm` column to every securable table.
    for table_name in ('datasources', 'dbs', 'tables'):
        op.add_column(
            table_name,
            sa.Column('perm', sa.String(length=1000), nullable=True))
def downgrade():
    # Drop the materialized `perm` columns, mirroring upgrade() in reverse.
    for table_name in ('tables', 'datasources', 'dbs'):
        op.drop_column(table_name, 'perm')

View File

@ -21,6 +21,7 @@ import requests
import sqlalchemy as sqla
from sqlalchemy.engine.url import make_url
from sqlalchemy.orm import subqueryload
from sqlalchemy.ext.hybrid import hybrid_property
import sqlparse
from dateutil.parser import parse
@ -69,6 +70,27 @@ QueryResult = namedtuple('namedtuple', ['df', 'query', 'duration'])
FillterPattern = re.compile(r'''((?:[^,"']|"[^"]*"|'[^']*')+)''')
def set_perm(mapper, connection, target):  # noqa
    # SQLAlchemy before_insert/before_update listener: materializes the
    # object's permission string onto its `perm` column so access checks
    # can filter on it without recomputing.
    target.perm = target.get_perm()
def init_metrics_perm(metrics=None):
    """Create permissions for restricted metrics

    Registers a ``metric_access`` permission/view-menu entry for every
    restricted metric that has a materialized perm string.

    :param metrics: a list of metrics to be processed, if not specified,
        all metrics are processed
    :type metrics: models.SqlMetric or models.DruidMetric
    """
    if not metrics:
        # Default: scan every SQL and Druid metric in the database.
        metrics = []
        for model in [SqlMetric, DruidMetric]:
            metrics += list(db.session.query(model).all())

    for metric in metrics:
        if metric.is_restricted and metric.perm:
            sm.add_permission_view_menu('metric_access', metric.perm)
class JavascriptPostAggregator(Postaggregator):
def __init__(self, name, field_names, function):
self.post_aggregator = {
@ -198,7 +220,7 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
params = Column(Text)
description = Column(Text)
cache_timeout = Column(Integer)
perm = Column(String(2000))
perm = Column(String(1000))
owners = relationship("User", secondary=slice_user)
export_fields = ('slice_name', 'datasource_type', 'datasource_name',
@ -365,14 +387,14 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
return slc_to_import.id
def set_perm(mapper, connection, target): # noqa
def set_related_perm(mapper, connection, target): # noqa
src_class = target.cls_model
id_ = target.datasource_id
ds = db.session.query(src_class).filter_by(id=int(id_)).first()
target.perm = ds.perm
sqla.event.listen(Slice, 'before_insert', set_perm)
sqla.event.listen(Slice, 'before_update', set_perm)
sqla.event.listen(Slice, 'before_insert', set_related_perm)
sqla.event.listen(Slice, 'before_update', set_related_perm)
dashboard_slices = Table(
@ -663,6 +685,7 @@ class Database(Model, AuditMixinNullable):
"engine_params": {}
}
"""))
perm = Column(String(1000))
def __repr__(self):
return self.database_name
@ -826,11 +849,13 @@ class Database(Model, AuditMixinNullable):
def sql_url(self):
return '/superset/sql/{}/'.format(self.id)
@property
def perm(self):
def get_perm(self):
return (
"[{obj.database_name}].(id:{obj.id})").format(obj=self)
sqla.event.listen(Database, 'before_insert', set_perm)
sqla.event.listen(Database, 'before_update', set_perm)
class SqlaTable(Model, Queryable, AuditMixinNullable, ImportMixin):
@ -857,6 +882,7 @@ class SqlaTable(Model, Queryable, AuditMixinNullable, ImportMixin):
schema = Column(String(255))
sql = Column(Text)
params = Column(Text)
perm = Column(String(1000))
baselink = "tablemodelview"
export_fields = (
@ -882,8 +908,7 @@ class SqlaTable(Model, Queryable, AuditMixinNullable, ImportMixin):
return Markup(
'<a href="{self.explore_url}">{table_name}</a>'.format(**locals()))
@property
def perm(self):
def get_perm(self):
return (
"[{obj.database}].[{obj.table_name}]"
"(id:{obj.id})").format(obj=self)
@ -1299,6 +1324,9 @@ class SqlaTable(Model, Queryable, AuditMixinNullable, ImportMixin):
return datasource.id
sqla.event.listen(SqlaTable, 'before_insert', set_perm)
sqla.event.listen(SqlaTable, 'before_update', set_perm)
class SqlMetric(Model, AuditMixinNullable, ImportMixin):
@ -1574,6 +1602,7 @@ class DruidDatasource(Model, AuditMixinNullable, Queryable):
'DruidCluster', backref='datasources', foreign_keys=[cluster_name])
offset = Column(Integer, default=0)
cache_timeout = Column(Integer)
perm = Column(String(1000))
@property
def database(self):
@ -1597,8 +1626,7 @@ class DruidDatasource(Model, AuditMixinNullable, Queryable):
def name(self):
return self.datasource_name
@property
def perm(self):
def get_perm(self):
return (
"[{obj.cluster_name}].[{obj.datasource_name}]"
"(id:{obj.id})").format(obj=self)
@ -2178,6 +2206,9 @@ class DruidDatasource(Model, AuditMixinNullable, Queryable):
filters = cond
return filters
sqla.event.listen(DruidDatasource, 'before_insert', set_perm)
sqla.event.listen(DruidDatasource, 'before_update', set_perm)
class Log(Model):
@ -2403,7 +2434,7 @@ class DruidColumn(Model, AuditMixinNullable):
session.add(metric)
session.flush()
utils.init_metrics_perm(superset, new_metrics)
init_metrics_perm(new_metrics)
class FavStar(Model):

178
superset/security.py Normal file
View File

@ -0,0 +1,178 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from itertools import product
import logging
from flask_appbuilder.security.sqla import models as ab_models
from superset import conf, db, models, sm
# View menus on which read-only access may be granted to non-admin roles.
READ_ONLY_MODELVIEWS = {
    'DatabaseAsync',
    'DatabaseView',
    'DruidClusterModelView',
}

# View menus reserved entirely for the Admin role (read-only views are
# folded in; the READ_ONLY_PRODUCT carve-out below re-opens those).
ADMIN_ONLY_VIEW_MENUES = {
    'AccessRequestsModelView',
    'Manage',
    'SQL Lab',
    'Queries',
    'Refresh Druid Metadata',
    'ResetPasswordView',
    'RoleModelView',
    'Security',
    'UserDBModelView',
} | READ_ONLY_MODELVIEWS

# Permission names only the Admin role may hold.
# (Fixed: the original listed 'can_override_role_permissions' twice.)
ADMIN_ONLY_PERMISSIONS = {
    'all_datasource_access',
    'all_database_access',
    'datasource_access',
    'database_access',
    'can_sql_json',
    'can_override_role_permissions',
    'can_sync_druid_source',
    'can_approve',
}

# Permission names considered read-only.
READ_ONLY_PERMISSION = {
    'can_show',
    'can_list',
}

# Permissions withheld from Gamma (on top of the admin-only set).
ALPHA_ONLY_PERMISSIONS = {
    'can_add',
    'can_download',
    'can_delete',
    'can_edit',
    'can_save',
    'datasource_access',
    'database_access',
    'muldelete',
}

# Every (permission, view menu) pair that is read-only, and therefore safe
# to grant to non-admin roles even on otherwise admin-only view menus.
READ_ONLY_PRODUCT = set(
    product(READ_ONLY_PERMISSION, READ_ONLY_MODELVIEWS))
def get_or_create_main_db():
    """Return the 'main' Database record, creating it if missing.

    The returned object always points at the configured
    SQLALCHEMY_DATABASE_URI and is exposed in SQL Lab; existing records
    are updated to match the current configuration.
    """
    logging.info("Creating database reference")
    dbobj = (
        db.session.query(models.Database)
        .filter_by(database_name='main')
        .first()
    )
    if not dbobj:
        dbobj = models.Database(database_name="main")
    # Re-applied on every call so config changes propagate to the record.
    logging.info(conf.get("SQLALCHEMY_DATABASE_URI"))
    dbobj.set_sqlalchemy_uri(conf.get("SQLALCHEMY_DATABASE_URI"))
    dbobj.expose_in_sqllab = True
    dbobj.allow_run_sync = True
    db.session.add(dbobj)
    db.session.commit()
    return dbobj
def sync_role_definitions():
    """Inits the Superset application with security roles and such

    Idempotent: re-grants/revokes permissions on the builtin roles
    (Admin, Alpha, Gamma, Public, sql_lab, granter) so they match the
    permission sets defined at module level, then makes sure a
    materialized perm exists for every datasource, database and
    restricted metric.
    """
    logging.info("Syncing role definition")

    # Creating default roles (add_role is a no-op if the role exists)
    alpha = sm.add_role("Alpha")
    admin = sm.add_role("Admin")
    gamma = sm.add_role("Gamma")
    public = sm.add_role("Public")
    sql_lab = sm.add_role("sql_lab")
    granter = sm.add_role("granter")

    get_or_create_main_db()

    # Global perms
    sm.add_permission_view_menu(
        'all_datasource_access', 'all_datasource_access')
    sm.add_permission_view_menu('all_database_access', 'all_database_access')

    perms = db.session.query(ab_models.PermissionView).all()
    # Skip dangling permission views missing either half of the pair
    perms = [p for p in perms if p.permission and p.view_menu]

    logging.info("Syncing admin perms")
    for p in perms:
        # Admin gets every permission on every view menu
        sm.add_permission_role(admin, p)

    logging.info("Syncing alpha perms")
    for p in perms:
        # Alpha: everything that is not admin-only, plus read-only pairs
        if (
            (
                p.view_menu.name not in ADMIN_ONLY_VIEW_MENUES and
                p.permission.name not in ADMIN_ONLY_PERMISSIONS
            ) or
            (p.permission.name, p.view_menu.name) in READ_ONLY_PRODUCT
        ):
            sm.add_permission_role(alpha, p)
        else:
            sm.del_permission_role(alpha, p)

    logging.info("Syncing gamma perms and public if specified")
    PUBLIC_ROLE_LIKE_GAMMA = conf.get('PUBLIC_ROLE_LIKE_GAMMA', False)
    for p in perms:
        # Gamma: alpha's restrictions plus the alpha-only exclusions
        if (
            (
                p.view_menu.name not in ADMIN_ONLY_VIEW_MENUES and
                p.permission.name not in ADMIN_ONLY_PERMISSIONS and
                p.permission.name not in ALPHA_ONLY_PERMISSIONS
            ) or
            (p.permission.name, p.view_menu.name) in READ_ONLY_PRODUCT
        ):
            sm.add_permission_role(gamma, p)
            if PUBLIC_ROLE_LIKE_GAMMA:
                sm.add_permission_role(public, p)
        else:
            sm.del_permission_role(gamma, p)
            sm.del_permission_role(public, p)

    logging.info("Syncing sql_lab perms")
    for p in perms:
        if (
            p.view_menu.name in {'SQL Lab'} or
            p.permission.name in {
                'can_sql_json', 'can_csv', 'can_search_queries'}
        ):
            sm.add_permission_role(sql_lab, p)
        else:
            sm.del_permission_role(sql_lab, p)

    logging.info("Syncing granter perms")
    for p in perms:
        # BUGFIX: 'can_aprove' was a misspelling of 'can_approve' (spelled
        # correctly in ADMIN_ONLY_PERMISSIONS), so granter never received
        # the approve permission.
        if (
            p.permission.name in {
                'can_override_role_permissions', 'can_approve'}
        ):
            sm.add_permission_role(granter, p)
        else:
            sm.del_permission_role(granter, p)

    logging.info("Making sure all data source perms have been created")
    session = db.session()
    datasources = [
        o for o in session.query(models.SqlaTable).all()]
    datasources += [
        o for o in session.query(models.DruidDatasource).all()]
    for datasource in datasources:
        perm = datasource.get_perm()
        sm.add_permission_view_menu('datasource_access', perm)
        # Refresh the materialized perm if it drifted from the computed one
        if perm != datasource.perm:
            datasource.perm = perm

    logging.info("Making sure all database perms have been created")
    databases = [o for o in session.query(models.Database).all()]
    for database in databases:
        perm = database.get_perm()
        if perm != database.perm:
            database.perm = perm
        sm.add_permission_view_menu('database_access', perm)
    session.commit()

    logging.info("Making sure all metrics perms exist")
    models.init_metrics_perm()

View File

@ -20,7 +20,6 @@ import parsedatetime
import sqlalchemy as sa
from dateutil.parser import parse
from flask import flash, Markup
from flask_appbuilder.security.sqla import models as ab_models
import markdown as md
from sqlalchemy.types import TypeDecorator, TEXT
from pydruid.utils.having import Having
@ -109,23 +108,6 @@ class memoized(object): # noqa
return functools.partial(self.__call__, obj)
def get_or_create_main_db(superset):
db = superset.db
config = superset.app.config
DB = superset.models.Database
logging.info("Creating database reference")
dbobj = db.session.query(DB).filter_by(database_name='main').first()
if not dbobj:
dbobj = DB(database_name="main")
logging.info(config.get("SQLALCHEMY_DATABASE_URI"))
dbobj.set_sqlalchemy_uri(config.get("SQLALCHEMY_DATABASE_URI"))
dbobj.expose_in_sqllab = True
dbobj.allow_run_sync = True
db.session.add(dbobj)
db.session.commit()
return dbobj
class DimSelector(Having):
def __init__(self, **args):
# Just a hack to prevent any exceptions
@ -185,12 +167,6 @@ def dttm_from_timtuple(d):
d.tm_year, d.tm_mon, d.tm_mday, d.tm_hour, d.tm_min, d.tm_sec)
def merge_perm(sm, permission_name, view_menu_name):
pv = sm.find_permission_view_menu(permission_name, view_menu_name)
if not pv:
sm.add_permission_view_menu(permission_name, view_menu_name)
def parse_human_timedelta(s):
"""
Returns ``datetime.datetime`` from natural language time deltas
@ -224,113 +200,6 @@ class JSONEncodedDict(TypeDecorator):
return value
def init(superset):
"""Inits the Superset application with security roles and such"""
ADMIN_ONLY_VIEW_MENUES = set([
'ResetPasswordView',
'RoleModelView',
'Security',
'UserDBModelView',
'SQL Lab',
'AccessRequestsModelView',
'Manage',
])
ADMIN_ONLY_PERMISSIONS = set([
'can_sync_druid_source',
'can_override_role_permissions',
'can_approve',
])
ALPHA_ONLY_PERMISSIONS = set([
'all_datasource_access',
'can_add',
'can_download',
'can_delete',
'can_edit',
'can_save',
'datasource_access',
'database_access',
'muldelete',
])
db = superset.db
models = superset.models
config = superset.app.config
sm = superset.appbuilder.sm
alpha = sm.add_role("Alpha")
admin = sm.add_role("Admin")
get_or_create_main_db(superset)
merge_perm(sm, 'all_datasource_access', 'all_datasource_access')
perms = db.session.query(ab_models.PermissionView).all()
# set alpha and admin permissions
for perm in perms:
if (
perm.permission and
perm.permission.name in ('datasource_access', 'database_access')):
continue
if (
perm.view_menu and
perm.view_menu.name not in ADMIN_ONLY_VIEW_MENUES and
perm.permission and
perm.permission.name not in ADMIN_ONLY_PERMISSIONS):
sm.add_permission_role(alpha, perm)
sm.add_permission_role(admin, perm)
gamma = sm.add_role("Gamma")
public_role = sm.find_role("Public")
public_role_like_gamma = \
public_role and config.get('PUBLIC_ROLE_LIKE_GAMMA', False)
# set gamma permissions
for perm in perms:
if (
perm.view_menu and
perm.view_menu.name not in ADMIN_ONLY_VIEW_MENUES and
perm.permission and
perm.permission.name not in ADMIN_ONLY_PERMISSIONS and
perm.permission.name not in ALPHA_ONLY_PERMISSIONS):
sm.add_permission_role(gamma, perm)
if public_role_like_gamma:
sm.add_permission_role(public_role, perm)
session = db.session()
table_perms = [
table.perm for table in session.query(models.SqlaTable).all()]
table_perms += [
table.perm for table in session.query(models.DruidDatasource).all()]
for table_perm in table_perms:
merge_perm(sm, 'datasource_access', table_perm)
db_perms = [db.perm for db in session.query(models.Database).all()]
for db_perm in db_perms:
merge_perm(sm, 'database_access', db_perm)
init_metrics_perm(superset)
def init_metrics_perm(superset, metrics=None):
"""Create permissions for restricted metrics
:param metrics: a list of metrics to be processed, if not specified,
all metrics are processed
:type metrics: models.SqlMetric or models.DruidMetric
"""
db = superset.db
models = superset.models
sm = superset.appbuilder.sm
if not metrics:
metrics = []
for model in [models.SqlMetric, models.DruidMetric]:
metrics += list(db.session.query(model).all())
for metric in metrics:
if metric.is_restricted and metric.perm:
merge_perm(sm, 'metric_access', metric.perm)
def datetime_f(dttm):
"""Formats datetime to take less room when it is recent"""
if dttm:

View File

@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from datetime import datetime, timedelta
import json
import logging
import pickle
@ -11,7 +12,6 @@ import sys
import time
import traceback
import zlib
from datetime import datetime, timedelta
import functools
import sqlalchemy as sqla
@ -28,14 +28,13 @@ from flask_babel import lazy_gettext as _
from flask_appbuilder.models.sqla.filters import BaseFilter
from sqlalchemy import create_engine
from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.routing import BaseConverter
from wtforms.validators import ValidationError
import superset
from superset import (
appbuilder, cache, db, models, viz, utils, app,
sm, ascii_art, sql_lab, results_backend
sm, ascii_art, sql_lab, results_backend, security,
)
from superset.source_registry import SourceRegistry
from superset.models import DatasourceAccessRequest as DAR
@ -55,12 +54,17 @@ class BaseSupersetView(BaseView):
"all_datasource_access", "all_datasource_access")
def database_access(self, database):
return (self.all_datasource_access() or
self.can_access("database_access", database.perm))
return (
self.can_access("all_database_access", "all_database_access") or
self.can_access("database_access", database.perm)
)
def datasource_access(self, datasource):
return (self.database_access(datasource.database) or
self.can_access("datasource_access", datasource.perm))
return (
self.database_access(datasource.database) or
self.can_access("all_database_access", "all_database_access") or
self.can_access("datasource_access", datasource.perm)
)
class ListWidgetWithCheckboxes(ListWidget):
@ -181,47 +185,94 @@ def get_user_roles():
class SupersetFilter(BaseFilter):
def get_perms(self):
perms = []
"""Add utility function to make BaseFilter easy and fast
These utility function exist in the SecurityManager, but would do
a database round trip at every check. Here we cache the role objects
to be able to make multiple checks but query the db only once
"""
def get_user_roles(self):
attr = '__get_user_roles'
if not hasattr(self, attr):
setattr(self, attr, get_user_roles())
return getattr(self, attr)
def get_all_permissions(self):
"""Returns a set of tuples with the perm name and view menu name"""
perms = set()
for role in get_user_roles():
for perm_view in role.permissions:
if perm_view.permission.name == 'datasource_access':
perms.append(perm_view.view_menu.name)
t = (perm_view.permission.name, perm_view.view_menu.name)
perms.add(t)
return perms
def has_role(self, role_name_or_list):
"""Whether the user has this role name"""
if not isinstance(role_name_or_list, list):
role_name_or_list = [role_name_or_list]
return any(
[r.name in role_name_or_list for r in self.get_user_roles()])
class TableSlice(SupersetFilter):
def has_perm(self, permission_name, view_menu_name):
"""Whether the user has this perm"""
return (permission_name, view_menu_name) in self.get_all_permissions()
def get_view_menus(self, permission_name):
"""Returns the details of view_menus for a perm name"""
vm = set()
for perm_name, vm_name in self.get_all_permissions():
if perm_name == permission_name:
vm.add(vm_name)
return vm
def has_all_datasource_access(self):
return (
self.has_role(['Admin', 'Alpha']) or
self.has_perm('all_datasource_access', 'all_datasource_access'))
class DatabaseFilter(SupersetFilter):
def apply(self, query, func): # noqa
if any([r.name in ('Admin', 'Alpha') for r in get_user_roles()]):
if (
self.has_role('Admin') or
self.has_perm('all_database_access', 'all_database_access')):
return query
perms = self.get_perms()
tables = []
for perm in perms:
match = re.search(r'\(id:(\d+)\)', perm)
tables.append(match.group(1))
qry = query.filter(self.model.id.in_(tables))
return qry
perms = self.get_view_menus('database_access')
return query.filter(self.model.perm.in_(perms))
class FilterSlice(SupersetFilter):
class DatasourceFilter(SupersetFilter):
def apply(self, query, func): # noqa
if any([r.name in ('Admin', 'Alpha') for r in get_user_roles()]):
if self.has_all_datasource_access():
return query
qry = query.filter(self.model.perm.in_(self.get_perms()))
return qry
perms = self.get_view_menus('datasource_access')
return query.filter(self.model.perm.in_(perms))
class FilterDashboard(SupersetFilter):
class SliceFilter(SupersetFilter):
def apply(self, query, func): # noqa
if self.has_all_datasource_access():
return query
perms = self.get_view_menus('datasource_access')
return query.filter(self.model.perm.in_(perms))
class DashboardFilter(SupersetFilter):
"""List dashboards for which users have access to at least one slice"""
def apply(self, query, func): # noqa
if any([r.name in ('Admin', 'Alpha') for r in get_user_roles()]):
if self.has_all_datasource_access():
return query
Slice = models.Slice # noqa
Dash = models.Dashboard # noqa
datasource_perms = self.get_view_menus('datasource_access')
slice_ids_qry = (
db.session
.query(Slice.id)
.filter(Slice.perm.in_(self.get_perms()))
.filter(Slice.perm.in_(datasource_perms))
)
query = query.filter(
Dash.id.in_(
@ -233,37 +284,6 @@ class FilterDashboard(SupersetFilter):
)
return query
class FilterDashboardSlices(SupersetFilter):
def apply(self, query, value): # noqa
if any([r.name in ('Admin', 'Alpha') for r in get_user_roles()]):
return query
qry = query.filter(self.model.perm.in_(self.get_perms()))
return qry
class FilterDashboardOwners(SupersetFilter):
def apply(self, query, value): # noqa
if any([r.name in ('Admin', 'Alpha') for r in get_user_roles()]):
return query
qry = query.filter_by(id=g.user.id)
return qry
class FilterDruidDatasource(SupersetFilter):
def apply(self, query, func): # noqa
if any([r.name in ('Admin', 'Alpha') for r in get_user_roles()]):
return query
perms = self.get_perms()
druid_datasources = []
for perm in perms:
match = re.search(r'\(id:(\d+)\)', perm)
if match:
druid_datasources.append(match.group(1))
qry = query.filter(self.model.id.in_(druid_datasources))
return qry
def validate_json(form, field): # noqa
try:
json.loads(field.data)
@ -494,10 +514,11 @@ class DatabaseView(SupersetModelView, DeleteMixin): # noqa
'extra',
'database_name',
'sqlalchemy_uri',
'perm',
'created_by',
'created_on',
'changed_by',
'changed_on'
'changed_on',
]
add_template = "superset/models/database/add.html"
edit_template = "superset/models/database/edit.html"
@ -551,7 +572,7 @@ class DatabaseView(SupersetModelView, DeleteMixin): # noqa
def pre_add(self, db):
db.set_sqlalchemy_uri(db.sqlalchemy_uri)
utils.merge_perm(sm, 'database_access', db.perm)
security.merge_perm(sm, 'database_access', db.perm)
def pre_update(self, db):
self.pre_add(db)
@ -578,6 +599,7 @@ appbuilder.add_view(
class DatabaseAsync(DatabaseView):
base_filters = [['id', DatabaseFilter, lambda: []]]
list_columns = [
'id', 'database_name',
'expose_in_sqllab', 'allow_ctas', 'force_ctas_schema',
@ -605,6 +627,7 @@ class TableModelView(SupersetModelView, DeleteMixin): # noqa
'table_name', 'sql', 'is_featured', 'database', 'schema',
'description', 'owner',
'main_dttm_col', 'default_endpoint', 'offset', 'cache_timeout']
show_columns = edit_columns + ['perm']
related_views = [TableColumnInlineView, SqlMetricInlineView]
base_order = ('changed_on', 'desc')
description_columns = {
@ -622,7 +645,7 @@ class TableModelView(SupersetModelView, DeleteMixin): # noqa
"run a query against this string as a subquery."
),
}
base_filters = [['id', TableSlice, lambda: []]]
base_filters = [['id', DatasourceFilter, lambda: []]]
label_columns = {
'link': _("Table"),
'changed_by_': _("Changed By"),
@ -659,7 +682,7 @@ class TableModelView(SupersetModelView, DeleteMixin): # noqa
def post_add(self, table):
table.fetch_metadata()
utils.merge_perm(sm, 'datasource_access', table.perm)
security.merge_perm(sm, 'datasource_access', table.perm)
flash(_(
"The table was created. As part of this two phase configuration "
"process, you should now click the edit button by "
@ -725,7 +748,7 @@ class DruidClusterModelView(SupersetModelView, DeleteMixin): # noqa
}
def pre_add(self, cluster):
utils.merge_perm(sm, 'database_access', cluster.perm)
security.merge_perm(sm, 'database_access', cluster.perm)
def pre_update(self, cluster):
self.pre_add(cluster)
@ -769,7 +792,7 @@ class SliceModelView(SupersetModelView, DeleteMixin): # noqa
"Duration (in seconds) of the caching timeout for this slice."
),
}
base_filters = [['id', FilterSlice, lambda: []]]
base_filters = [['id', SliceFilter, lambda: []]]
label_columns = {
'cache_timeout': _("Cache Timeout"),
'creator': _("Creator"),
@ -865,15 +888,11 @@ class DashboardModelView(SupersetModelView, DeleteMixin): # noqa
"want to alter specific parameters."),
'owners': _("Owners is a list of users who can alter the dashboard."),
}
base_filters = [['slice', FilterDashboard, lambda: []]]
base_filters = [['slice', DashboardFilter, lambda: []]]
add_form_query_rel_fields = {
'slices': [['slices', FilterDashboardSlices, None]],
'owners': [['owners', FilterDashboardOwners, None]],
}
edit_form_query_rel_fields = {
'slices': [['slices', FilterDashboardSlices, None]],
'owners': [['owners', FilterDashboardOwners, None]],
'slices': [['slices', SliceFilter, None]],
}
edit_form_query_rel_fields = add_form_query_rel_fields
label_columns = {
'dashboard_link': _("Dashboard"),
'dashboard_title': _("Title"),
@ -964,6 +983,19 @@ appbuilder.add_view(
icon="fa-list-ol")
class QueryView(SupersetModelView):
datamodel = SQLAInterface(models.Query)
list_columns = ['user', 'database', 'status', 'start_time', 'end_time']
appbuilder.add_view(
QueryView,
"Queries",
label=__("Queries"),
category="Manage",
category_label=__("Manage"),
icon="fa-search")
class DruidDatasourceModelView(SupersetModelView, DeleteMixin): # noqa
datamodel = SQLAInterface(models.DruidDatasource)
list_widget = ListWidgetWithCheckboxes
@ -977,6 +1009,7 @@ class DruidDatasourceModelView(SupersetModelView, DeleteMixin): # noqa
'is_featured', 'is_hidden', 'default_endpoint', 'offset',
'cache_timeout']
add_columns = edit_columns
show_columns = add_columns + ['perm']
page_size = 500
base_order = ('datasource_name', 'asc')
description_columns = {
@ -985,7 +1018,7 @@ class DruidDatasourceModelView(SupersetModelView, DeleteMixin): # noqa
"Supports <a href='"
"https://daringfireball.net/projects/markdown/'>markdown</a>"),
}
base_filters = [['id', FilterDruidDatasource, lambda: []]]
base_filters = [['id', DatasourceFilter, lambda: []]]
label_columns = {
'datasource_link': _("Data Source"),
'cluster': _("Cluster"),
@ -1013,7 +1046,7 @@ class DruidDatasourceModelView(SupersetModelView, DeleteMixin): # noqa
def post_add(self, datasource):
datasource.generate_metrics()
utils.merge_perm(sm, 'datasource_access', datasource.perm)
security.merge_perm(sm, 'datasource_access', datasource.perm)
def post_update(self, datasource):
self.post_add(datasource)
@ -1073,7 +1106,7 @@ appbuilder.add_view_no_menu(R)
class Superset(BaseSupersetView):
"""The base views for Superset!"""
@has_access
@has_access_api
@expose("/override_role_permissions/", methods=['POST'])
def override_role_permissions(self):
"""Updates the role with the give datasource permissions.
@ -1863,29 +1896,6 @@ class Superset(BaseSupersetView):
url = '/superset/explore/table/{table.id}/?{params}'.format(**locals())
return redirect(url)
@has_access
@expose("/sql/<database_id>/")
@log_this
def sql(self, database_id):
if not self.all_datasource_access():
flash(ALL_DATASOURCE_ACCESS_ERR, "danger")
return redirect("/tablemodelview/list/")
mydb = db.session.query(
models.Database).filter_by(id=database_id).first()
if not self.database_access(mydb.perm):
flash(get_database_access_error_msg(mydb.database_name), "danger")
return redirect("/tablemodelview/list/")
engine = mydb.get_sqla_engine()
tables = engine.table_names()
table_name = request.args.get('table_name')
return self.render_template(
"superset/sql.html",
tables=tables,
table_name=table_name,
db=mydb)
@has_access
@expose("/table/<database_id>/<table_name>/<schema>/")
@log_this
@ -2284,7 +2294,6 @@ class Superset(BaseSupersetView):
title=ascii_art.stacktrace,
art=ascii_art.error), 500
@has_access
@expose("/welcome")
def welcome(self):
"""Personalized welcome page"""
@ -2301,6 +2310,7 @@ appbuilder.add_view_no_menu(Superset)
if config['DRUID_IS_ACTIVE']:
appbuilder.add_link(
"Refresh Druid Metadata",
label=__("Refresh Druid Metadata"),
href='/superset/refresh_datasources/',
category='Sources',
category_label=__("Sources"),

View File

@ -4,20 +4,19 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import imp
import logging
import json
import os
import unittest
from flask_appbuilder.security.sqla import models as ab_models
import superset
from superset import app, db, models, utils, appbuilder, sm
from superset import app, cli, db, models, appbuilder, sm
from superset.security import sync_role_definitions
os.environ['SUPERSET_CONFIG'] = 'tests.superset_test_config'
BASE_DIR = app.config.get("BASE_DIR")
cli = imp.load_source('cli', BASE_DIR + "/bin/superset")
class SupersetTestCase(unittest.TestCase):
@ -30,13 +29,22 @@ class SupersetTestCase(unittest.TestCase):
not os.environ.get('SOLO_TEST') and
not os.environ.get('examples_loaded')
):
logging.info("Loading examples")
cli.load_examples(load_test_data=True)
utils.init(superset)
logging.info("Done loading examples")
sync_role_definitions()
os.environ['examples_loaded'] = '1'
else:
sync_role_definitions()
super(SupersetTestCase, self).__init__(*args, **kwargs)
self.client = app.test_client()
self.maxDiff = None
utils.init(superset)
gamma_sqllab = sm.add_role("gamma_sqllab")
for perm in sm.find_role('Gamma').permissions:
sm.add_permission_role(gamma_sqllab, perm)
for perm in sm.find_role('sql_lab').permissions:
sm.add_permission_role(gamma_sqllab, perm)
admin = appbuilder.sm.find_user('admin')
if not admin:
@ -52,6 +60,13 @@ class SupersetTestCase(unittest.TestCase):
appbuilder.sm.find_role('Gamma'),
password='general')
gamma_sqllab = appbuilder.sm.find_user('gamma_sqllab')
if not gamma_sqllab:
gamma_sqllab = appbuilder.sm.add_user(
'gamma_sqllab', 'gamma_sqllab', 'user', 'gamma_sqllab@fab.org',
appbuilder.sm.find_role('gamma_sqllab'),
password='general')
alpha = appbuilder.sm.find_user('alpha')
if not alpha:
appbuilder.sm.add_user(
@ -80,7 +95,6 @@ class SupersetTestCase(unittest.TestCase):
session.add(druid_datasource2)
session.commit()
utils.init(superset)
def get_or_create(self, cls, criteria, session):
obj = session.query(cls).filter_by(**criteria).first()
@ -89,11 +103,10 @@ class SupersetTestCase(unittest.TestCase):
return obj
def login(self, username='admin', password='general'):
resp = self.client.post(
resp = self.get_resp(
'/login/',
data=dict(username=username, password=password),
follow_redirects=True)
assert 'Welcome' in resp.data.decode('utf-8')
data=dict(username=username, password=password))
self.assertIn('Welcome', resp)
def get_query_by_sql(self, sql):
session = db.create_scoped_session()
@ -128,14 +141,19 @@ class SupersetTestCase(unittest.TestCase):
return db.session.query(models.DruidDatasource).filter_by(
datasource_name=name).first()
def get_resp(self, url):
def get_resp(self, url, data=None, follow_redirects=True):
"""Shortcut to get the parsed results while following redirects"""
resp = self.client.get(url, follow_redirects=True)
return resp.data.decode('utf-8')
if data:
resp = self.client.post(
url, data=data, follow_redirects=follow_redirects)
return resp.data.decode('utf-8')
else:
resp = self.client.get(url, follow_redirects=follow_redirects)
return resp.data.decode('utf-8')
def get_json_resp(self, url):
def get_json_resp(self, url, data=None):
"""Shortcut to get the parsed results while following redirects"""
resp = self.get_resp(url)
resp = self.get_resp(url, data=data)
return json.loads(resp)
def get_main_database(self, session):
@ -160,29 +178,30 @@ class SupersetTestCase(unittest.TestCase):
def logout(self):
self.client.get('/logout/', follow_redirects=True)
def setup_public_access_for_dashboard(self, table_name):
def grant_public_access_to_table(self, table):
public_role = appbuilder.sm.find_role('Public')
perms = db.session.query(ab_models.PermissionView).all()
for perm in perms:
if (perm.permission.name == 'datasource_access' and
perm.view_menu and table_name in perm.view_menu.name):
perm.view_menu and table.perm in perm.view_menu.name):
appbuilder.sm.add_permission_role(public_role, perm)
def revoke_public_access(self, table_name):
def revoke_public_access_to_table(self, table):
public_role = appbuilder.sm.find_role('Public')
perms = db.session.query(ab_models.PermissionView).all()
for perm in perms:
if (perm.permission.name == 'datasource_access' and
perm.view_menu and table_name in perm.view_menu.name):
perm.view_menu and table.perm in perm.view_menu.name):
appbuilder.sm.del_permission_role(public_role, perm)
def run_sql(self, sql, user_name, client_id):
self.login(username=(user_name if user_name else 'admin'))
def run_sql(self, sql, client_id, user_name=None):
if user_name:
self.logout()
self.login(username=(user_name if user_name else 'admin'))
dbid = self.get_main_database(db.session).id
resp = self.client.post(
resp = self.get_json_resp(
'/superset/sql_json/',
data=dict(database_id=dbid, sql=sql, select_as_create_as=False,
client_id=client_id),
)
self.logout()
return json.loads(resp.data.decode('utf-8'))
return resp

View File

@ -4,7 +4,6 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import imp
import json
import os
import subprocess
@ -13,15 +12,14 @@ import unittest
import pandas as pd
import superset
from superset import app, appbuilder, db, models, sql_lab, utils, dataframe
from superset import app, appbuilder, cli, db, models, sql_lab, dataframe
from superset.security import sync_role_definitions
from .base_tests import SupersetTestCase
QueryStatus = models.QueryStatus
BASE_DIR = app.config.get('BASE_DIR')
cli = imp.load_source('cli', BASE_DIR + '/bin/superset')
class CeleryConfig(object):
@ -99,7 +97,7 @@ class CeleryTestCase(SupersetTestCase):
except OSError as e:
app.logger.warn(str(e))
utils.init(superset)
sync_role_definitions()
worker_command = BASE_DIR + '/bin/superset worker'
subprocess.Popen(
@ -179,6 +177,7 @@ class CeleryTestCase(SupersetTestCase):
def test_run_sync_query(self):
main_db = self.get_main_database(db.session)
eng = main_db.get_sqla_engine()
perm_name = 'can_sql_json'
db_id = main_db.id
# Case 1.
@ -189,7 +188,8 @@ class CeleryTestCase(SupersetTestCase):
# Case 2.
# Table and DB exists, CTA call to the backend.
sql_where = "SELECT name FROM ab_permission WHERE name='can_sql'"
sql_where = (
"SELECT name FROM ab_permission WHERE name='{}'".format(perm_name))
result2 = self.run_sql(
db_id, sql_where, "2", tmp_table='tmp_table_2', cta='true')
self.assertEqual(QueryStatus.SUCCESS, result2['query']['state'])
@ -200,7 +200,7 @@ class CeleryTestCase(SupersetTestCase):
# Check the data in the tmp table.
df2 = pd.read_sql_query(sql=query2.select_sql, con=eng)
data2 = df2.to_dict(orient='records')
self.assertEqual([{'name': 'can_sql'}], data2)
self.assertEqual([{'name': perm_name}], data2)
# Case 3.
# Table and DB exists, CTA call to the backend, no data.

View File

@ -24,8 +24,6 @@ class CoreTests(SupersetTestCase):
requires_examples = True
def __init__(self, *args, **kwargs):
# Load examples first, so that we setup proper permission-view
# relations for all example data sources.
super(CoreTests, self).__init__(*args, **kwargs)
@classmethod
@ -118,7 +116,9 @@ class CoreTests(SupersetTestCase):
def test_save_slice(self):
self.login(username='admin')
slice_id = self.get_slice("Energy Sankey", db.session).id
slice_name = "Energy Sankey"
slice_id = self.get_slice(slice_name, db.session).id
db.session.commit()
copy_name = "Test Sankey Save"
tbl_id = self.table_ids.get('energy_usage')
url = (
@ -128,9 +128,15 @@ class CoreTests(SupersetTestCase):
"collapsed_fieldsets=&action={}&datasource_name=energy_usage&"
"datasource_id=1&datasource_type=table&previous_viz_type=sankey")
db.session.commit()
# Changing name
resp = self.get_resp(url.format(tbl_id, slice_id, copy_name, 'save'))
assert copy_name in resp
# Setting the name back to its original name
resp = self.get_resp(url.format(tbl_id, slice_id, slice_name, 'save'))
assert slice_name in resp
# Doing a basic overwrite
assert 'Energy' in self.get_resp(
url.format(tbl_id, slice_id, copy_name, 'overwrite'))
@ -281,15 +287,15 @@ class CoreTests(SupersetTestCase):
assert "List Dashboard" in self.get_resp('/dashboardmodelview/list/')
def test_csv_endpoint(self):
self.login('admin')
sql = """
SELECT first_name, last_name
FROM ab_user
WHERE first_name='admin'
"""
client_id = "{}".format(random.getrandbits(64))[:10]
self.run_sql(sql, 'admin', client_id)
self.run_sql(sql, client_id)
self.login('admin')
resp = self.get_resp('/superset/csv/{}'.format(client_id))
data = csv.reader(io.StringIO(resp))
expected_data = csv.reader(
@ -299,36 +305,48 @@ class CoreTests(SupersetTestCase):
self.logout()
def test_public_user_dashboard_access(self):
table = (
db.session
.query(models.SqlaTable)
.filter_by(table_name='birth_names')
.one()
)
# Try access before adding appropriate permissions.
self.revoke_public_access('birth_names')
self.revoke_public_access_to_table(table)
self.logout()
resp = self.get_resp('/slicemodelview/list/')
assert 'birth_names</a>' not in resp
self.assertNotIn('birth_names</a>', resp)
resp = self.get_resp('/dashboardmodelview/list/')
assert '/superset/dashboard/births/' not in resp
self.assertNotIn('/superset/dashboard/births/', resp)
self.setup_public_access_for_dashboard('birth_names')
self.grant_public_access_to_table(table)
# Try access after adding appropriate permissions.
assert 'birth_names' in self.get_resp('/slicemodelview/list/')
self.assertIn('birth_names', self.get_resp('/slicemodelview/list/'))
resp = self.get_resp('/dashboardmodelview/list/')
assert "/superset/dashboard/births/" in resp
self.assertIn("/superset/dashboard/births/", resp)
assert 'Births' in self.get_resp('/superset/dashboard/births/')
self.assertIn('Births', self.get_resp('/superset/dashboard/births/'))
# Confirm that public doesn't have access to other datasets.
resp = self.get_resp('/slicemodelview/list/')
assert 'wb_health_population</a>' not in resp
self.assertNotIn('wb_health_population</a>', resp)
resp = self.get_resp('/dashboardmodelview/list/')
assert "/superset/dashboard/world_health/" not in resp
self.assertNotIn("/superset/dashboard/world_health/", resp)
def test_dashboard_with_created_by_can_be_accessed_by_public_users(self):
self.logout()
self.setup_public_access_for_dashboard('birth_names')
table = (
db.session
.query(models.SqlaTable)
.filter_by(table_name='birth_names')
.one()
)
self.grant_public_access_to_table(table)
dash = db.session.query(models.Dashboard).filter_by(dashboard_title="Births").first()
dash.owners = [appbuilder.sm.find_user('admin')]
@ -382,8 +400,9 @@ class CoreTests(SupersetTestCase):
self.assertEqual("SELECT '2017-01-01T00:00:00'", rendered)
def test_templated_sql_json(self):
self.login('admin')
sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}' as test"
data = self.run_sql(sql, "admin", "fdaklj3ws")
data = self.run_sql(sql, "fdaklj3ws")
self.assertEqual(data['data'][0]['test'], "2017-01-01T00:00:00")
def test_table_metadata(self):

View File

@ -241,8 +241,8 @@ class DruidTests(SupersetTestCase):
no_gamma_ds.cluster = cluster
db.session.merge(no_gamma_ds)
utils.merge_perm(sm, 'datasource_access', gamma_ds.perm)
utils.merge_perm(sm, 'datasource_access', no_gamma_ds.perm)
sm.add_permission_view_menu('datasource_access', gamma_ds.perm)
sm.add_permission_view_menu('datasource_access', no_gamma_ds.perm)
db.session.commit()

View File

@ -4,9 +4,8 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import csv
from datetime import datetime, timedelta
import json
import io
import unittest
from flask_appbuilder.security.sqla import models as ab_models
@ -20,25 +19,41 @@ class SqlLabTests(SupersetTestCase):
def __init__(self, *args, **kwargs):
super(SqlLabTests, self).__init__(*args, **kwargs)
def setUp(self):
def run_some_queries(self):
self.logout()
db.session.query(models.Query).delete()
self.run_sql("SELECT * FROM ab_user", 'admin', client_id='client_id_1')
self.run_sql("SELECT * FROM NO_TABLE", 'admin', client_id='client_id_3')
self.run_sql("SELECT * FROM ab_permission", 'gamma', client_id='client_id_2')
db.session.commit()
self.run_sql(
"SELECT * FROM ab_user",
client_id='client_id_1',
user_name='admin')
self.run_sql(
"SELECT * FROM NO_TABLE",
client_id='client_id_3',
user_name='admin')
self.run_sql(
"SELECT * FROM ab_permission",
client_id='client_id_2',
user_name='gamma_sqllab')
self.logout()
def tearDown(self):
db.session.query(models.Query).delete()
db.session.commit()
self.logout()
def test_sql_json(self):
data = self.run_sql('SELECT * FROM ab_user', 'admin', "1")
self.login('admin')
data = self.run_sql('SELECT * FROM ab_user', "1")
self.assertLess(0, len(data['data']))
data = self.run_sql('SELECT * FROM unexistant_table', 'admin', "2")
data = self.run_sql('SELECT * FROM unexistant_table', "2")
self.assertLess(0, len(data['error']))
def test_sql_json_has_access(self):
main_db = self.get_main_database(db.session)
utils.merge_perm(sm, 'database_access', main_db.perm)
sm.add_permission_view_menu('database_access', main_db.perm)
db.session.commit()
main_db_permission_view = (
db.session.query(ab_models.PermissionView)
@ -48,119 +63,133 @@ class SqlLabTests(SupersetTestCase):
)
astronaut = sm.add_role("Astronaut")
sm.add_permission_role(astronaut, main_db_permission_view)
# Astronaut role is Gamma + main db permissions
for gamma_perm in sm.find_role('Gamma').permissions:
sm.add_permission_role(astronaut, gamma_perm)
# Astronaut role is Gamma + sqllab + main db permissions
for perm in sm.find_role('Gamma').permissions:
sm.add_permission_role(astronaut, perm)
for perm in sm.find_role('sql_lab').permissions:
sm.add_permission_role(astronaut, perm)
gagarin = appbuilder.sm.find_user('gagarin')
if not gagarin:
appbuilder.sm.add_user(
'gagarin', 'Iurii', 'Gagarin', 'gagarin@cosmos.ussr',
appbuilder.sm.find_role('Astronaut'),
astronaut,
password='general')
data = self.run_sql('SELECT * FROM ab_user', 'gagarin', "3")
data = self.run_sql('SELECT * FROM ab_user', "3", user_name='gagarin')
db.session.query(models.Query).delete()
db.session.commit()
self.assertLess(0, len(data['data']))
def test_queries_endpoint(self):
resp = self.client.get('/superset/queries/{}'.format(0))
self.run_some_queries()
# Not logged in, should error out
resp = self.client.get('/superset/queries/0')
self.assertEquals(403, resp.status_code)
# Admin sees queries
self.login('admin')
data = self.get_json_resp('/superset/queries/{}'.format(0))
data = self.get_json_resp('/superset/queries/0')
self.assertEquals(2, len(data))
self.logout()
self.run_sql("SELECT * FROM ab_user1", 'admin', client_id='client_id_4')
self.run_sql("SELECT * FROM ab_user2", 'admin', client_id='client_id_5')
# Run 2 more queries
self.run_sql("SELECT * FROM ab_user1", client_id='client_id_4')
self.run_sql("SELECT * FROM ab_user2", client_id='client_id_5')
self.login('admin')
data = self.get_json_resp('/superset/queries/{}'.format(0))
data = self.get_json_resp('/superset/queries/0')
self.assertEquals(4, len(data))
now = datetime.now() + timedelta(days=1)
query = db.session.query(models.Query).filter_by(
sql='SELECT * FROM ab_user1').first()
query.changed_on = utils.EPOCH
query.changed_on = now
db.session.commit()
data = self.get_json_resp('/superset/queries/{}'.format(123456000))
self.assertEquals(3, len(data))
data = self.get_json_resp(
'/superset/queries/{}'.format(
int(utils.datetime_to_epoch(now))-1000))
self.assertEquals(1, len(data))
self.logout()
resp = self.client.get('/superset/queries/{}'.format(0))
resp = self.client.get('/superset/queries/0')
self.assertEquals(403, resp.status_code)
def test_search_query_on_db_id(self):
self.login('admin')
# Test search queries on database Id
resp = self.get_resp('/superset/search_queries?database_id=1')
data = json.loads(resp)
self.assertEquals(3, len(data))
db_ids = [data[k]['dbId'] for k in data]
self.assertEquals([1, 1, 1], db_ids)
self.run_some_queries()
self.login('admin')
# Test search queries on database Id
resp = self.get_resp('/superset/search_queries?database_id=1')
data = json.loads(resp)
self.assertEquals(3, len(data))
db_ids = [data[k]['dbId'] for k in data]
self.assertEquals([1, 1, 1], db_ids)
resp = self.get_resp('/superset/search_queries?database_id=-1')
data = json.loads(resp)
self.assertEquals(0, len(data))
self.logout()
resp = self.get_resp('/superset/search_queries?database_id=-1')
data = json.loads(resp)
self.assertEquals(0, len(data))
def test_search_query_on_user(self):
self.login('admin')
# Test search queries on user Id
user = appbuilder.sm.find_user('admin')
resp = self.get_resp('/superset/search_queries?user_id={}'.format(user.id))
data = json.loads(resp)
self.assertEquals(2, len(data))
user_ids = [data[k]['userId'] for k in data]
self.assertEquals([user.id, user.id], user_ids)
self.run_some_queries()
self.login('admin')
user = appbuilder.sm.find_user('gamma')
resp = self.get_resp('/superset/search_queries?user_id={}'.format(user.id))
data = json.loads(resp)
self.assertEquals(1, len(data))
self.assertEquals(list(data.values())[0]['userId'] , user.id)
self.logout()
# Test search queries on user Id
user = appbuilder.sm.find_user('admin')
data = self.get_json_resp(
'/superset/search_queries?user_id={}'.format(user.id))
self.assertEquals(2, len(data))
user_ids = {data[k]['userId'] for k in data}
self.assertEquals(set([user.id]), user_ids)
user = appbuilder.sm.find_user('gamma_sqllab')
resp = self.get_resp('/superset/search_queries?user_id={}'.format(user.id))
data = json.loads(resp)
self.assertEquals(1, len(data))
self.assertEquals(list(data.values())[0]['userId'] , user.id)
def test_search_query_on_status(self):
self.login('admin')
# Test search queries on status
resp = self.get_resp('/superset/search_queries?status=success')
data = json.loads(resp)
self.assertEquals(2, len(data))
states = [data[k]['state'] for k in data]
self.assertEquals(['success', 'success'], states)
self.run_some_queries()
self.login('admin')
# Test search queries on status
resp = self.get_resp('/superset/search_queries?status=success')
data = json.loads(resp)
self.assertEquals(2, len(data))
states = [data[k]['state'] for k in data]
self.assertEquals(['success', 'success'], states)
resp = self.get_resp('/superset/search_queries?status=failed')
data = json.loads(resp)
self.assertEquals(1, len(data))
self.assertEquals(list(data.values())[0]['state'], 'failed')
self.logout()
resp = self.get_resp('/superset/search_queries?status=failed')
data = json.loads(resp)
self.assertEquals(1, len(data))
self.assertEquals(list(data.values())[0]['state'], 'failed')
def test_search_query_on_text(self):
self.login('admin')
resp = self.get_resp('/superset/search_queries?search_text=permission')
data = json.loads(resp)
self.assertEquals(1, len(data))
self.assertIn('permission', list(data.values())[0]['sql'])
self.logout()
self.run_some_queries()
self.login('admin')
resp = self.get_resp('/superset/search_queries?search_text=permission')
data = json.loads(resp)
self.assertEquals(1, len(data))
self.assertIn('permission', list(data.values())[0]['sql'])
def test_search_query_on_time(self):
self.login('admin')
first_query_time = db.session.query(models.Query).filter_by(
sql='SELECT * FROM ab_user').first().start_time
second_query_time = db.session.query(models.Query).filter_by(
sql='SELECT * FROM ab_permission').first().start_time
# Test search queries on time filter
from_time = 'from={}'.format(int(first_query_time))
to_time = 'to={}'.format(int(second_query_time))
params = [from_time, to_time]
resp = self.get_resp('/superset/search_queries?'+'&'.join(params))
data = json.loads(resp)
self.assertEquals(2, len(data))
for _, v in data.items():
self.assertLess(int(first_query_time), v['startDttm'])
self.assertLess(v['startDttm'], int(second_query_time))
self.logout()
self.run_some_queries()
self.login('admin')
first_query_time = (
db.session.query(models.Query)
.filter_by(sql='SELECT * FROM ab_user').one()
).start_time
second_query_time = (
db.session.query(models.Query)
.filter_by(sql='SELECT * FROM ab_permission').one()
).start_time
# Test search queries on time filter
from_time = 'from={}'.format(int(first_query_time))
to_time = 'to={}'.format(int(second_query_time))
params = [from_time, to_time]
resp = self.get_resp('/superset/search_queries?'+'&'.join(params))
data = json.loads(resp)
self.assertEquals(2, len(data))
for _, v in data.items():
self.assertLess(int(first_query_time), v['startDttm'])
self.assertLess(v['startDttm'], int(second_query_time))
if __name__ == '__main__':