Refactor the explore view (#1252)

* Refactor the explore view

* Fixing the tests

* Addressing comments
This commit is contained in:
Maxime Beauchemin 2016-10-07 16:24:39 -07:00 committed by GitHub
parent b7d1f78f5e
commit f70d301f0d
8 changed files with 194 additions and 137 deletions

View File

@ -12,3 +12,11 @@ class SourceRegistry(object):
for class_name in class_names:
source_class = getattr(module_obj, class_name)
cls.sources[source_class.type] = source_class
@classmethod
def get_datasource(cls, datasource_type, datasource_id, session):
    """Look up one datasource of a registered type by its id.

    :param datasource_type: key into ``cls.sources`` selecting the model
    :param datasource_id: value matched against the model's ``id``
    :param session: object exposing ``query()``
        (presumably a SQLAlchemy session -- confirm against callers)
    :returns: the result of ``.one()`` on the filtered query
    """
    query = session.query(cls.sources[datasource_type])
    return query.filter_by(id=datasource_id).one()

View File

@ -1,6 +1,6 @@
<html>
<head>
<title>{{viz.token}}</title>
<title>{{ viz.token }}</title>
<link rel="stylesheet" type="text/css" href="/static/assets/node_modules/font-awesome/css/font-awesome.min.css" />
<link rel="stylesheet" type="text/css" href="/static/assets/stylesheets/caravel.css" />
<link rel="stylesheet" type="text/css" href="/static/appbuilder/css/flags/flags16.css" />

View File

@ -26,7 +26,6 @@ from flask_babel import lazy_gettext as _
from flask_appbuilder.models.sqla.filters import BaseFilter
from sqlalchemy import create_engine
from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.routing import BaseConverter
from wtforms.validators import ValidationError
@ -244,7 +243,8 @@ class FilterDruidDatasource(CaravelFilter):
druid_datasources = []
for perm in perms:
match = re.search(r'\(id:(\d+)\)', perm)
druid_datasources.append(match.group(1))
if match:
druid_datasources.append(match.group(1))
qry = query.filter(self.model.id.in_(druid_datasources))
return qry
@ -672,6 +672,7 @@ class DruidClusterModelView(CaravelModelView, DeleteMixin): # noqa
'broker_port': _("Broker Port"),
'broker_endpoint': _("Broker Endpoint"),
}
def pre_add(self, db):
utils.merge_perm(sm, 'database_access', db.perm)
@ -699,7 +700,8 @@ class SliceModelView(CaravelModelView, DeleteMixin): # noqa
list_columns = [
'slice_link', 'viz_type', 'datasource_link', 'creator', 'modified']
edit_columns = [
'slice_name', 'description', 'viz_type', 'owners', 'dashboards', 'params', 'cache_timeout']
'slice_name', 'description', 'viz_type', 'owners', 'dashboards',
'params', 'cache_timeout']
base_order = ('changed_on', 'desc')
description_columns = {
'description': Markup(
@ -1099,61 +1101,80 @@ class Caravel(BaseCaravelView):
session.commit()
return redirect('/accessrequestsmodelview/list/')
def get_viz(
        self,
        slice_id=None,
        args=None,
        datasource_type=None,
        datasource_id=None):
    """Return a viz object for a saved slice or for a datasource.

    :param slice_id: when truthy, the slice with this id is loaded and
        its own ``get_viz()`` result is returned; the remaining
        parameters are not used in that case
    :param args: form data exposing ``.get``; when ``None`` the current
        ``request.args`` is used instead
    :param datasource_type: registry key passed to
        ``SourceRegistry.get_datasource``
    :param datasource_id: id passed to ``SourceRegistry.get_datasource``
    :returns: the viz object
    """
    if slice_id:
        slc = db.session.query(models.Slice).filter_by(id=slice_id).one()
        return slc.get_viz()

    # Use a single source of form data: the viz type used to be read from
    # ``args`` while the viz itself was built from ``request.args``, and
    # the declared default ``args=None`` raised AttributeError on ``.get``.
    form_data = request.args if args is None else args
    viz_type = form_data.get('viz_type', 'table')
    datasource = SourceRegistry.get_datasource(
        datasource_type, datasource_id, db.session)
    return viz.viz_types[viz_type](datasource, form_data)
@has_access
@expose("/explore/<datasource_type>/<datasource_id>/<slice_id>/")
@expose("/explore/<datasource_type>/<datasource_id>/")
@expose("/datasource/<datasource_type>/<datasource_id>/") # Legacy url
@expose("/slice/<slice_id>/")
def slice(self, slice_id):
viz_obj = self.get_viz(slice_id)
return redirect(viz_obj.get_url(**request.args))
@has_access_api
@expose("/explore_json/<datasource_type>/<datasource_id>/")
def explore_json(self, datasource_type, datasource_id):
    """Serve the JSON payload of a viz built from the request arguments.

    Responds with the viz JSON and status 200 when
    ``self.datasource_access`` accepts the viz's datasource; otherwise
    responds with a JSON error body and status 404.
    """
    viz_obj = self.get_viz(
        args=request.args,
        datasource_type=datasource_type,
        datasource_id=datasource_id)
    json_mimetype = "application/json"

    if self.datasource_access(viz_obj.datasource):
        return Response(
            viz_obj.get_json(), status=200, mimetype=json_mimetype)

    denied_body = json.dumps(
        {'error': _("You don't have access to this datasource")})
    return Response(denied_body, status=404, mimetype=json_mimetype)
@log_this
def explore(self, datasource_type, datasource_id, slice_id=None):
@has_access
@expose("/explore/<datasource_type>/<datasource_id>/")
def explore(self, datasource_type, datasource_id):
viz_type = request.args.get("viz_type")
slice_id = request.args.get('slice_id')
slc = db.session.query(models.Slice).filter_by(id=slice_id).first()
error_redirect = '/slicemodelview/list/'
datasource_class = SourceRegistry.sources[datasource_type]
datasources = db.session.query(datasource_class).all()
datasources = sorted(datasources, key=lambda ds: ds.full_name)
datasource = [ds for ds in datasources if int(datasource_id) == ds.id]
datasource = datasource[0] if datasource else None
if not datasource:
viz_obj = self.get_viz(
datasource_type=datasource_type,
datasource_id=datasource_id,
args=request.args)
if not viz_obj.datasource:
flash(DATASOURCE_MISSING_ERR, "alert")
return redirect(error_redirect)
if not self.datasource_access(datasource):
if not self.datasource_access(viz_obj.datasource):
flash(
__(get_datasource_access_error_msg(datasource.name)), "danger")
__(get_datasource_access_error_msg(viz_obj.datasource.name)),
"danger")
return redirect(
'caravel/request_access/?'
'datasource_type={datasource_type}&'
'datasource_id={datasource_id}&'
''.format(**locals()))
request_args_multi_dict = request.args # MultiDict
slice_id = slice_id or request_args_multi_dict.get("slice_id")
slc = None
# build viz_obj and get its params
if slice_id:
slc = db.session.query(models.Slice).filter_by(id=slice_id).first()
try:
viz_obj = slc.get_viz(
url_params_multidict=request_args_multi_dict)
except Exception as e:
logging.exception(e)
flash(utils.error_msg_from_exception(e), "danger")
return redirect(error_redirect)
else:
viz_type = request_args_multi_dict.get("viz_type")
if not viz_type and datasource.default_endpoint:
return redirect(datasource.default_endpoint)
# default to table if no default endpoint and no viz_type
viz_type = viz_type or "table"
# validate viz params
try:
viz_obj = viz.viz_types[viz_type](
datasource, request_args_multi_dict)
except Exception as e:
logging.exception(e)
flash(utils.error_msg_from_exception(e), "danger")
return redirect(error_redirect)
slice_params_multi_dict = ImmutableMultiDict(viz_obj.orig_form_data)
if not viz_type and viz_obj.datasource.default_endpoint:
return redirect(viz_obj.datasource.default_endpoint)
# slc perms
slice_add_perm = self.can_access('can_add', 'SliceModelView')
@ -1161,45 +1182,29 @@ class Caravel(BaseCaravelView):
slice_download_perm = self.can_access('can_download', 'SliceModelView')
# handle save or overwrite
action = slice_params_multi_dict.get('action')
action = request.args.get('action')
if action in ('saveas', 'overwrite'):
return self.save_or_overwrite_slice(
slice_params_multi_dict, slc, slice_add_perm, slice_edit_perm)
request.args, slc, slice_add_perm, slice_edit_perm)
# handle different endpoints
if slice_params_multi_dict.get("json") == "true":
if config.get("DEBUG"):
# Allows for nice debugger stack traces in debug mode
return Response(
viz_obj.get_json(),
status=200,
mimetype="application/json")
try:
return Response(
viz_obj.get_json(),
status=200,
mimetype="application/json")
except Exception as e:
logging.exception(e)
return json_error_response(utils.error_msg_from_exception(e))
elif slice_params_multi_dict.get("csv") == "true":
if request.args.get("csv") == "true":
payload = viz_obj.get_csv()
return Response(
payload,
status=200,
headers=generate_download_headers("csv"),
mimetype="application/csv")
elif request.args.get("standalone") == "true":
return self.render_template("caravel/standalone.html", viz=viz_obj)
else:
if slice_params_multi_dict.get("standalone") == "true":
template = "caravel/standalone.html"
else:
template = "caravel/explore.html"
return self.render_template(
template, viz=viz_obj, slice=slc, datasources=datasources,
"caravel/explore.html",
viz=viz_obj, slice=slc, datasources=datasources,
can_add=slice_add_perm, can_edit=slice_edit_perm,
can_download=slice_download_perm,
userid=g.user.get_id() if g.user else '')
userid=g.user.get_id() if g.user else ''
)
@has_access
@expose("/exploreV2/<datasource_type>/<datasource_id>/<slice_id>/")
@ -1705,7 +1710,11 @@ class Caravel(BaseCaravelView):
data = json.loads(request.args.get('data'))
table_name = data.get('datasourceName')
viz_type = data.get('chartType')
table = db.session.query(models.SqlaTable).filter_by(table_name=table_name).first()
table = (
db.session.query(models.SqlaTable)
.filter_by(table_name=table_name)
.first()
)
if not table:
table = models.SqlaTable(table_name=table_name)
table.database_id = data.get('dbId')

View File

@ -104,7 +104,7 @@ class BaseViz(object):
def reassignments(self):
pass
def get_url(self, for_cache_key=False, **kwargs):
def get_url(self, for_cache_key=False, json_endpoint=False, **kwargs):
"""Returns the URL for the viz
:param for_cache_key: when getting the url as the identifier to hash
@ -140,8 +140,12 @@ class BaseViz(object):
for item in v:
od.add(key, item)
base_endpoint = '/caravel/explore'
if json_endpoint:
base_endpoint = '/caravel/explore_json'
href = Href(
'/caravel/explore/{self.datasource.type}/'
'{base_endpoint}/{self.datasource.type}/'
'{self.datasource.id}/'.format(**locals()))
if for_cache_key and 'force' in od:
del od['force']
@ -373,7 +377,7 @@ class BaseViz(object):
@property
def json_endpoint(self):
return self.get_url(json="true")
return self.get_url(json_endpoint=True)
@property
def cache_key(self):
@ -1261,7 +1265,6 @@ class HistogramViz(BaseViz):
}
}
def query_obj(self):
"""Returns the query object for this visualization"""
d = super(HistogramViz, self).query_obj()
@ -1272,7 +1275,6 @@ class HistogramViz(BaseViz):
d['columns'] = [numeric_column]
return d
def get_df(self, query_obj=None):
"""Returns a pandas dataframe based on the query object"""
if not query_obj:
@ -1289,7 +1291,6 @@ class HistogramViz(BaseViz):
df = df.fillna(0)
return df
def get_data(self):
"""Returns the chart data"""
df = self.get_df()

View File

@ -5,4 +5,4 @@ export CARAVEL_CONFIG=tests.caravel_test_config
set -e
caravel/bin/caravel version -v
export SOLO_TEST=1
nosetests tests.core_tests:CoreTests.test_public_user_dashboard_access
nosetests tests.core_tests:CoreTests.test_slice_endpoint

View File

@ -81,6 +81,12 @@ class CaravelTestCase(unittest.TestCase):
utils.init(caravel)
def get_or_create(self, cls, criteria, session):
    """Return the first ``cls`` row matching ``criteria``, else a new one.

    :param cls: class that is queried and, on a miss, instantiated
    :param criteria: dict used both as ``filter_by`` keyword arguments
        and as constructor keyword arguments
    :param session: object exposing ``query()``
    :returns: the first match when it is truthy, otherwise
        ``cls(**criteria)``; the new object is not added to ``session``
        by this helper
    """
    found = session.query(cls).filter_by(**criteria).first()
    return found if found else cls(**criteria)
def login(self, username='admin', password='general'):
resp = self.client.post(
'/login/',
@ -104,6 +110,15 @@ class CaravelTestCase(unittest.TestCase):
session.close()
return query
def get_slice(self, slice_name, session):
    """Fetch the slice named ``slice_name`` via ``session``.

    :param slice_name: value matched against ``models.Slice.slice_name``
    :param session: object exposing ``query()`` and ``expunge_all()``
    :returns: the result of ``.one()`` on the filtered query
    """
    by_name = session.query(models.Slice).filter_by(slice_name=slice_name)
    slc = by_name.one()
    # NOTE(review): expunge_all() runs before returning, presumably so the
    # slice is handed back detached from the session -- confirm intent.
    session.expunge_all()
    return slc
def get_resp(self, url):
"""Shortcut to get the parsed results while following redirects"""
resp = self.client.get(url, follow_redirects=True)
@ -124,11 +139,6 @@ class CaravelTestCase(unittest.TestCase):
def logout(self):
    """Request the logout URL with the test client, following redirects."""
    logout_url = '/logout/'
    self.client.get(logout_url, follow_redirects=True)
def test_welcome(self):
self.login()
resp = self.client.get('/caravel/welcome')
assert 'Welcome' in resp.data.decode('utf-8')
def setup_public_access_for_dashboard(self, table_name):
public_role = appbuilder.sm.find_role('Public')
perms = db.session.query(ab_models.PermissionView).all()

View File

@ -44,6 +44,33 @@ class CoreTests(CaravelTestCase):
def tearDown(self):
    """No-op: this hook performs no per-test cleanup."""
    return None
def test_welcome(self):
    """After logging in, the welcome page body contains 'Welcome'."""
    self.login()
    page = self.client.get('/caravel/welcome').data.decode('utf-8')
    assert 'Welcome' in page
def test_slice_endpoint(self):
    """The slice URL shows the full page unless standalone is requested."""
    self.login(username='admin')
    slice_id = self.get_slice("Girls", db.session).id

    full_page = self.get_resp('/caravel/slice/{}/'.format(slice_id))
    assert 'Time Column' in full_page
    assert 'List Roles' in full_page

    # Testing overrides
    standalone_page = self.get_resp(
        '/caravel/slice/{}/?standalone=true'.format(slice_id))
    assert 'List Roles' not in standalone_page
def test_endpoints_for_a_slice(self):
    """The CSV and JSON endpoints of the "Girls" slice both return data."""
    self.login(username='admin')
    slc = self.get_slice("Girls", db.session)
    expectations = (
        ('csv_endpoint', 'Jennifer,'),
        ('json_endpoint', '"Jennifer"'),
    )
    # slc.viz is looked up afresh for each endpoint, in CSV-then-JSON order.
    for endpoint_attr, expected in expectations:
        resp = self.get_resp(getattr(slc.viz, endpoint_attr))
        assert expected in resp
def test_admin_only_permissions(self):
def assert_admin_permission_in(role_name, assert_func):
role = sm.find_role(role_name)
@ -73,13 +100,7 @@ class CoreTests(CaravelTestCase):
def test_save_slice(self):
self.login(username='admin')
slc = (
db.session.query(models.Slice.id)
.filter_by(slice_name="Energy Sankey")
.first())
slice_id = slc.id
slice_id = self.get_slice("Energy Sankey", db.session).id
copy_name = "Test Sankey Save"
tbl_id = self.table_ids.get('energy_usage')
url = (

View File

@ -14,7 +14,6 @@ from caravel import db, sm, utils
from caravel.models import DruidCluster, DruidDatasource
from .base_tests import CaravelTestCase
from flask_appbuilder.security.sqla import models as ab_models
SEGMENT_METADATA = [{
@ -118,25 +117,40 @@ class DruidTests(CaravelTestCase):
datasource_id))
assert "[test_cluster].[test_datasource]" in resp.data.decode('utf-8')
resp = self.client.get(
'/caravel/explore/druid/{}/?viz_type=table&granularity=one+day&'
url = (
'/caravel/explore_json/druid/{}/?viz_type=table&granularity=one+day&'
'druid_time_origin=&since=7+days+ago&until=now&row_limit=5000&'
'include_search=false&metrics=count&groupby=name&flt_col_0=dim1&'
'flt_op_0=in&flt_eq_0=&slice_id=&slice_name=&collapsed_fieldsets=&'
'action=&datasource_name=test_datasource&datasource_id={}&'
'datasource_type=druid&previous_viz_type=table&json=true&'
'datasource_type=druid&previous_viz_type=table&'
'force=true'.format(datasource_id, datasource_id))
assert "Canada" in resp.data.decode('utf-8')
resp = self.get_resp(url)
assert "Canada" in resp
def test_druid_sync_from_config(self):
CLUSTER_NAME = 'new_druid'
self.login()
cluster = DruidCluster(cluster_name="new_druid")
db.session.add(cluster)
cluster = self.get_or_create(
DruidCluster,
{'cluster_name': CLUSTER_NAME},
db.session)
db.session.merge(cluster)
db.session.commit()
ds = (
db.session.query(DruidDatasource)
.filter_by(datasource_name='test_click')
.first()
)
if ds:
db.session.delete(ds)
db.session.commit()
cfg = {
"user": "admin",
"cluster": "new_druid",
"cluster": CLUSTER_NAME,
"config": {
"name": "test_click",
"dimensions": ["affiliate_id", "campaign", "first_seen"],
@ -152,30 +166,24 @@ class DruidTests(CaravelTestCase):
}
}
}
resp = self.client.post('/caravel/sync_druid/', data=json.dumps(cfg))
def check():
resp = self.client.post('/caravel/sync_druid/', data=json.dumps(cfg))
druid_ds = db.session.query(DruidDatasource).filter_by(
datasource_name="test_click").first()
col_names = set([c.column_name for c in druid_ds.columns])
assert {"affiliate_id", "campaign", "first_seen"} == col_names
metric_names = {m.metric_name for m in druid_ds.metrics}
assert {"count", "sum"} == metric_names
assert resp.status_code == 201
druid_ds = db.session.query(DruidDatasource).filter_by(
datasource_name="test_click").first()
assert set([c.column_name for c in druid_ds.columns]) == set(
["affiliate_id", "campaign", "first_seen"])
assert set([m.metric_name for m in druid_ds.metrics]) == set(
["count", "sum"])
assert resp.status_code == 201
# datasource exists, no changes required
resp = self.client.post('/caravel/sync_druid/', data=json.dumps(cfg))
druid_ds = db.session.query(DruidDatasource).filter_by(
datasource_name="test_click").first()
assert set([c.column_name for c in druid_ds.columns]) == set(
["affiliate_id", "campaign", "first_seen"])
assert set([m.metric_name for m in druid_ds.metrics]) == set(
["count", "sum"])
assert resp.status_code == 201
check()
# checking twice to make sure a second sync yields the same results
check()
# datasource exists, add new metrics and dimensions
cfg = {
"user": "admin",
"cluster": "new_druid",
"cluster": CLUSTER_NAME,
"config": {
"name": "test_click",
"dimensions": ["affiliate_id", "second_seen"],
@ -200,26 +208,33 @@ class DruidTests(CaravelTestCase):
assert resp.status_code == 201
def test_filter_druid_datasource(self):
gamma_ds = DruidDatasource(
datasource_name="datasource_for_gamma",
)
db.session.add(gamma_ds)
no_gamma_ds = DruidDatasource(
datasource_name="datasource_not_for_gamma",
)
db.session.add(no_gamma_ds)
db.session.commit()
CLUSTER_NAME = 'new_druid'
cluster = self.get_or_create(
DruidCluster,
{'cluster_name': CLUSTER_NAME},
db.session)
db.session.merge(cluster)
gamma_ds = self.get_or_create(
DruidDatasource, {'datasource_name': 'datasource_for_gamma'},
db.session)
gamma_ds.cluster = cluster
db.session.merge(gamma_ds)
no_gamma_ds = self.get_or_create(
DruidDatasource, {'datasource_name': 'datasource_not_for_gamma'},
db.session)
no_gamma_ds.cluster = cluster
db.session.merge(no_gamma_ds)
utils.merge_perm(sm, 'datasource_access', gamma_ds.perm)
utils.merge_perm(sm, 'datasource_access', no_gamma_ds.perm)
db.session.commit()
gamma_ds_permission_view = (
db.session.query(ab_models.PermissionView)
.join(ab_models.ViewMenu)
.filter(ab_models.ViewMenu.name == gamma_ds.perm)
.first()
)
sm.add_permission_role(sm.find_role('Gamma'), gamma_ds_permission_view)
perm = sm.find_permission_view_menu('datasource_access', gamma_ds.perm)
sm.add_permission_role(sm.find_role('Gamma'), perm)
db.session.commit()
self.login(username='gamma')
url = '/druiddatasourcemodelview/list/'
@ -227,13 +242,6 @@ class DruidTests(CaravelTestCase):
assert 'datasource_for_gamma' in resp
assert 'datasource_not_for_gamma' not in resp
def test_add_filter(self, username='admin'):
# navigate to energy_usage slice with "Electricity,heat" in filter values
data = (
"/caravel/explore/table/1/?viz_type=table&groupby=source&metric=count&flt_col_1=source&flt_op_1=in&flt_eq_1=%27Electricity%2Cheat%27"
"&userid=1&datasource_name=energy_usage&datasource_id=1&datasource_type=tablerdo_save=saveas")
assert "source" in self.get_resp(data)
if __name__ == '__main__':
unittest.main()