Add per database permissions for the SQL Lab. (#885)

This commit is contained in:
Bogdan 2016-08-09 17:53:23 -07:00 committed by GitHub
parent b48101ca51
commit d6bb8c6935
4 changed files with 94 additions and 29 deletions

View File

@ -504,6 +504,11 @@ class Database(Model, AuditMixinNullable):
def sql_link(self): def sql_link(self):
return '<a href="{}">SQL</a>'.format(self.sql_url) return '<a href="{}">SQL</a>'.format(self.sql_url)
@property
def perm(self):
return (
"[{obj.database_name}].(id:{obj.id})").format(obj=self)
class SqlaTable(Model, Queryable, AuditMixinNullable): class SqlaTable(Model, Queryable, AuditMixinNullable):

View File

@ -200,7 +200,7 @@ def init(caravel):
perms = db.session.query(ab_models.PermissionView).all() perms = db.session.query(ab_models.PermissionView).all()
for perm in perms: for perm in perms:
if perm.permission.name == 'datasource_access': if perm.permission.name in ('datasource_access', 'database_access'):
continue continue
if perm.view_menu and perm.view_menu.name not in ( if perm.view_menu and perm.view_menu.name not in (
'UserDBModelView', 'RoleModelView', 'ResetPasswordView', 'UserDBModelView', 'RoleModelView', 'ResetPasswordView',
@ -226,6 +226,7 @@ def init(caravel):
'can_edit', 'can_edit',
'can_save', 'can_save',
'datasource_access', 'datasource_access',
'database_access',
'muldelete', 'muldelete',
)): )):
sm.add_permission_role(gamma, perm) sm.add_permission_role(gamma, perm)
@ -239,6 +240,9 @@ def init(caravel):
for table_perm in table_perms: for table_perm in table_perms:
merge_perm(sm, 'datasource_access', table_perm) merge_perm(sm, 'datasource_access', table_perm)
db_perms = [db.perm for db in session.query(models.Database).all()]
for db_perm in db_perms:
merge_perm(sm, 'database_access', db_perm)
init_metrics_perm(caravel) init_metrics_perm(caravel)

View File

@ -407,6 +407,7 @@ class DatabaseView(CaravelModelView, DeleteMixin): # noqa
db.password = conn.password db.password = conn.password
conn.password = "X" * 10 if conn.password else None conn.password = "X" * 10 if conn.password else None
db.sqlalchemy_uri = str(conn) # hides the password db.sqlalchemy_uri = str(conn) # hides the password
utils.merge_perm(sm, 'database_access', db.perm)
def pre_update(self, db): def pre_update(self, db):
self.pre_add(db) self.pre_add(db)
@ -1176,15 +1177,17 @@ class Caravel(BaseCaravelView):
@expose("/sql/<database_id>/") @expose("/sql/<database_id>/")
@log_this @log_this
def sql(self, database_id): def sql(self, database_id):
if (
not self.can_access(
'all_datasource_access', 'all_datasource_access')):
flash(
"This view requires the `all_datasource_access` "
"permission", "danger")
return redirect("/tablemodelview/list/")
mydb = db.session.query( mydb = db.session.query(
models.Database).filter_by(id=database_id).first() models.Database).filter_by(id=database_id).first()
if not (self.can_access(
'all_datasource_access', 'all_datasource_access') or
self.can_access('database_access', mydb.perm)):
flash(
"This view requires the specific database or "
"`all_datasource_access` permission", "danger"
)
return redirect("/tablemodelview/list/")
engine = mydb.get_sqla_engine() engine = mydb.get_sqla_engine()
tables = engine.table_names() tables = engine.table_names()
@ -1221,6 +1224,18 @@ class Caravel(BaseCaravelView):
mydb = db.session.query( mydb = db.session.query(
models.Database).filter_by(id=database_id).first() models.Database).filter_by(id=database_id).first()
t = mydb.get_table(table_name) t = mydb.get_table(table_name)
# Prevent exposing column fields to users that cannot access DB.
if not (self.can_access(
'all_datasource_access', 'all_datasource_access') or
self.can_access('database_access', mydb.perm) or
self.can_access('datasource_access', t.perm)):
flash(
"This view requires the specific database, table or "
"`all_datasource_access` permission", "danger"
)
return redirect("/tablemodelview/list/")
fields = ", ".join( fields = ", ".join(
[c.name for c in t.columns] or "*") [c.name for c in t.columns] or "*")
s = "SELECT\n{}\nFROM {}".format(fields, table_name) s = "SELECT\n{}\nFROM {}".format(fields, table_name)
@ -1242,11 +1257,13 @@ class Caravel(BaseCaravelView):
database_id = data.get('database_id') database_id = data.get('database_id')
mydb = session.query(models.Database).filter_by(id=database_id).first() mydb = session.query(models.Database).filter_by(id=database_id).first()
if ( if not (self.can_access(
not self.can_access( 'all_datasource_access', 'all_datasource_access') or
'all_datasource_access', 'all_datasource_access')): self.can_access('database_access', mydb.perm)):
raise utils.CaravelSecurityException(_( raise utils.CaravelSecurityException(_(
"SQL Lab requires the `all_datasource_access` permission")) "SQL Lab requires the `all_datasource_access` or "
"specific db permission"))
content = "" content = ""
if mydb: if mydb:
eng = mydb.get_sqla_engine() eng = mydb.get_sqla_engine()
@ -1254,10 +1271,12 @@ class Caravel(BaseCaravelView):
sql = sql.strip().strip(';') sql = sql.strip().strip(';')
qry = ( qry = (
select('*') select('*')
.select_from(TextAsFrom(text(sql), ['*']).alias('inner_qry')) .select_from(TextAsFrom(text(sql), ['*'])
.alias('inner_qry'))
.limit(limit) .limit(limit)
) )
sql = str(qry.compile(eng, compile_kwargs={"literal_binds": True})) sql = '{}'.format(qry.compile(
eng, compile_kwargs={"literal_binds": True}))
try: try:
df = pd.read_sql_query(sql=sql, con=eng) df = pd.read_sql_query(sql=sql, con=eng)
content = df.to_html( content = df.to_html(
@ -1289,11 +1308,12 @@ class Caravel(BaseCaravelView):
database_id = request.form.get('database_id') database_id = request.form.get('database_id')
mydb = session.query(models.Database).filter_by(id=database_id).first() mydb = session.query(models.Database).filter_by(id=database_id).first()
if ( if not (self.can_access(
not self.can_access( 'all_datasource_access', 'all_datasource_access') or
'all_datasource_access', 'all_datasource_access')): self.can_access('database_access', mydb.perm)):
raise utils.CaravelSecurityException(_( raise utils.CaravelSecurityException(_(
"This view requires the `all_datasource_access` permission")) "SQL Lab requires the `all_datasource_access` or "
"specific DB permission"))
error_msg = "" error_msg = ""
if not mydb: if not mydb:
@ -1304,10 +1324,12 @@ class Caravel(BaseCaravelView):
sql = sql.strip().strip(';') sql = sql.strip().strip(';')
qry = ( qry = (
select('*') select('*')
.select_from(TextAsFrom(text(sql), ['*']).alias('inner_qry')) .select_from(TextAsFrom(text(sql), ['*'])
.alias('inner_qry'))
.limit(limit) .limit(limit)
) )
sql = str(qry.compile(eng, compile_kwargs={"literal_binds": True})) sql = '{}'.format(qry.compile(
eng, compile_kwargs={"literal_binds": True}))
try: try:
df = pd.read_sql_query(sql=sql, con=eng) df = pd.read_sql_query(sql=sql, con=eng)
df = df.fillna(0) # TODO make sure NULL df = df.fillna(0) # TODO make sure NULL
@ -1328,7 +1350,8 @@ class Caravel(BaseCaravelView):
'columns': [c for c in df.columns], 'columns': [c for c in df.columns],
'data': df.to_dict(orient='records'), 'data': df.to_dict(orient='records'),
} }
return json.dumps(data, default=utils.json_int_dttm_ser, allow_nan=False) return json.dumps(
data, default=utils.json_int_dttm_ser, allow_nan=False)
@has_access @has_access
@expose("/refresh_datasources/") @expose("/refresh_datasources/")
@ -1342,7 +1365,7 @@ class Caravel(BaseCaravelView):
except Exception as e: except Exception as e:
flash( flash(
"Error while processing cluster '{}'\n{}".format( "Error while processing cluster '{}'\n{}".format(
cluster_name, str(e)), cluster_name, utils.error_msg_from_exception(e)),
"danger") "danger")
logging.exception(e) logging.exception(e)
return redirect('/druidclustermodelview/list/') return redirect('/druidclustermodelview/list/')

View File

@ -16,7 +16,7 @@ from flask import escape
from flask_appbuilder.security.sqla import models as ab_models from flask_appbuilder.security.sqla import models as ab_models
import caravel import caravel
from caravel import app, db, models, utils, appbuilder from caravel import app, db, models, utils, appbuilder, sm
from caravel.models import DruidCluster, DruidDatasource from caravel.models import DruidCluster, DruidDatasource
os.environ['CARAVEL_CONFIG'] = 'tests.caravel_test_config' os.environ['CARAVEL_CONFIG'] = 'tests.caravel_test_config'
@ -247,8 +247,8 @@ class CoreTests(CaravelTestCase):
resp = self.client.get('/dashboardmodelview/list/') resp = self.client.get('/dashboardmodelview/list/')
assert "List Dashboard" in resp.data.decode('utf-8') assert "List Dashboard" in resp.data.decode('utf-8')
def run_sql(self, sql): def run_sql(self, sql, user_name):
self.login(username='admin') self.login(username=user_name)
dbid = ( dbid = (
db.session.query(models.Database) db.session.query(models.Database)
.filter_by(database_name="main") .filter_by(database_name="main")
@ -258,13 +258,47 @@ class CoreTests(CaravelTestCase):
'/caravel/sql_json/', '/caravel/sql_json/',
data=dict(database_id=dbid, sql=sql), data=dict(database_id=dbid, sql=sql),
) )
self.logout()
return json.loads(resp.data.decode('utf-8')) return json.loads(resp.data.decode('utf-8'))
def test_sql_json(self): def test_sql_json_no_access(self):
data = self.run_sql("SELECT * FROM ab_user") self.assertRaises(
utils.CaravelSecurityException,
self.run_sql, "SELECT * FROM ab_user", 'gamma')
def test_sql_json_has_access(self):
main_db = (
db.session.query(models.Database).filter_by(database_name="main")
.first()
)
utils.merge_perm(sm, 'database_access', main_db.perm)
db.session.commit()
main_db_permission_view = (
db.session.query(ab_models.PermissionView)
.join(ab_models.ViewMenu)
.filter(ab_models.ViewMenu.name == '[main].(id:1)')
.first()
)
astronaut = sm.add_role("Astronaut")
sm.add_permission_role(astronaut, main_db_permission_view)
# Astronaut role is Gamme + main db permissions
for gamma_perm in sm.find_role('Gamma').permissions:
sm.add_permission_role(astronaut, gamma_perm)
gagarin = appbuilder.sm.find_user('gagarin')
if not gagarin:
appbuilder.sm.add_user(
'gagarin', 'Iurii', 'Gagarin', 'gagarin@cosmos.ussr',
appbuilder.sm.find_role('Astronaut'),
password='general')
data = self.run_sql('SELECT * FROM ab_user', 'gagarin')
assert len(data['data']) > 0 assert len(data['data']) > 0
data = self.run_sql("SELECT * FROM unexistant_table") def test_sql_json(self):
data = self.run_sql("SELECT * FROM ab_user", 'admin')
assert len(data['data']) > 0
data = self.run_sql("SELECT * FROM unexistant_table", 'admin')
assert len(data['error']) > 0 assert len(data['error']) > 0
def test_public_user_dashboard_access(self): def test_public_user_dashboard_access(self):
@ -301,7 +335,6 @@ class CoreTests(CaravelTestCase):
data = resp.data.decode('utf-8') data = resp.data.decode('utf-8')
assert "/caravel/dashboard/world_health/" not in data assert "/caravel/dashboard/world_health/" not in data
def test_only_owners_can_save(self): def test_only_owners_can_save(self):
dash = ( dash = (
db.session db.session