Generalize switch between different datasources (#1078)

* Generalize switch between different datasources.

* Fix previous migration since slice model changed

* Fix warm up cache and other small stuff

* Adding modules and datasources through config

* Replace tabs w/ spaces

* Fix other style issues

* Change add method for SliceModelView to pick the first non-empty ds

* Remove tests on slice add redirect

* Change way of db migration

* Fix styling

* Fix create slice

* Small fixes

* Fix code climate check

* Adding notes on how to create new datasource in CONTRIBUTING.md

* Fix last merge

* A commit just to trigger travis build again

* Add migration to merge two heads

* Fix codeclimate

* Simplify source_registry

* Fix codeclimate

* Remove all getter methods
This commit is contained in:
ShengyaoQian 2016-09-21 09:52:05 -07:00 committed by Maxime Beauchemin
parent ed2feaf84b
commit 5a0e06e7a2
13 changed files with 223 additions and 110 deletions

View File

@ -251,3 +251,20 @@ You can then translate the strings gathered in files located under
to take effect, they need to be compiled using this command:
fabmanager babel-compile --target caravel/translations/
## Adding new datasources
1. Create Models and Views for the datasource, add them under caravel folder, like a new my_models.py
with models for cluster, datasources, columns and metrics and my_views.py with clustermodelview
and datasourcemodelview.
2. Create db migration files for the new models
3. Specify this variable to add the datasource model and from which module it is from in config.py:
For example:
`ADDITIONAL_MODULE_DS_MAP = {'caravel.my_models': ['MyDatasource', 'MyOtherDatasource']}`
This means it'll register MyDatasource and MyOtherDatasource in caravel.my_models module in the source registry.

View File

@ -14,6 +14,7 @@ from sqlalchemy import event, exc
from flask_appbuilder.baseviews import expose
from flask_cache import Cache
from flask_migrate import Migrate
from caravel import source_registry
from werkzeug.contrib.fixers import ProxyFix
@ -95,5 +96,7 @@ appbuilder = AppBuilder(
sm = appbuilder.sm
src_registry = source_registry.SourceRegistry()
get_session = appbuilder.get_session
from caravel import config, views # noqa
from caravel import views, config # noqa

View File

@ -20,6 +20,14 @@ config = app.config
manager = Manager(app)
manager.add_command('db', MigrateCommand)
module_datasource_map = config.get("DEFAULT_MODULE_DS_MAP")
module_datasource_map.update(config.get("ADDITIONAL_MODULE_DS_MAP"))
datasources = {}
for module in module_datasource_map:
datasources[module] = __import__(module, fromlist=module_datasource_map[module])
utils.register_sources(datasources, module_datasource_map, caravel.src_registry)
@manager.option(

View File

@ -164,6 +164,13 @@ VIZ_TYPE_BLACKLIST = []
DRUID_DATA_SOURCE_BLACKLIST = []
# --------------------------------------------------
# Modules and datasources to be registered
# --------------------------------------------------
DEFAULT_MODULE_DS_MAP = {'caravel.models': ['DruidDatasource', 'SqlaTable']}
ADDITIONAL_MODULE_DS_MAP = {}
"""
1) http://docs.python-guide.org/en/latest/writing/logging/
2) https://docs.python.org/2/library/logging.config.html

View File

@ -75,7 +75,7 @@ def load_energy():
slice_name="Energy Sankey",
viz_type='sankey',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=textwrap.dedent("""\
{
"collapsed_fieldsets": "",
@ -105,7 +105,7 @@ def load_energy():
slice_name="Energy Force Layout",
viz_type='directed_force',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=textwrap.dedent("""\
{
"charge": "-500",
@ -136,7 +136,7 @@ def load_energy():
slice_name="Heatmap",
viz_type='heatmap',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=textwrap.dedent("""\
{
"all_columns_x": "source",
@ -224,7 +224,7 @@ def load_world_bank_health_n_pop():
slice_name="Region Filter",
viz_type='filter_box',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type='filter_box',
@ -233,7 +233,7 @@ def load_world_bank_health_n_pop():
slice_name="World's Population",
viz_type='big_number',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
since='2000',
@ -245,7 +245,7 @@ def load_world_bank_health_n_pop():
slice_name="Most Populated Countries",
viz_type='table',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type='table',
@ -255,7 +255,7 @@ def load_world_bank_health_n_pop():
slice_name="Growth Rate",
viz_type='line',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type='line',
@ -267,7 +267,7 @@ def load_world_bank_health_n_pop():
slice_name="% Rural",
viz_type='world_map',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type='world_map',
@ -277,7 +277,7 @@ def load_world_bank_health_n_pop():
slice_name="Life Expectancy VS Rural %",
viz_type='bubble',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type='bubble',
@ -298,7 +298,7 @@ def load_world_bank_health_n_pop():
slice_name="Rural Breakdown",
viz_type='sunburst',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type='sunburst',
@ -310,7 +310,7 @@ def load_world_bank_health_n_pop():
slice_name="World's Pop Growth",
viz_type='area',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
since="1960-01-01",
@ -321,7 +321,7 @@ def load_world_bank_health_n_pop():
slice_name="Box plot",
viz_type='box_plot',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
since="1960-01-01",
@ -333,7 +333,7 @@ def load_world_bank_health_n_pop():
slice_name="Treemap",
viz_type='treemap',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
since="1960-01-01",
@ -345,7 +345,7 @@ def load_world_bank_health_n_pop():
slice_name="Parallel Coordinates",
viz_type='para',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
since="2011-01-01",
@ -615,7 +615,7 @@ def load_birth_names():
slice_name="Girls",
viz_type='table',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
groupby=['name'],
@ -625,7 +625,7 @@ def load_birth_names():
slice_name="Boys",
viz_type='table',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
groupby=['name'],
@ -636,7 +636,7 @@ def load_birth_names():
slice_name="Participants",
viz_type='big_number',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type="big_number", granularity="ds",
@ -645,7 +645,7 @@ def load_birth_names():
slice_name="Genders",
viz_type='pie',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type="pie", groupby=['gender'])),
@ -653,7 +653,7 @@ def load_birth_names():
slice_name="Genders by State",
viz_type='dist_bar',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
flt_eq_1="other", viz_type="dist_bar",
@ -663,7 +663,7 @@ def load_birth_names():
slice_name="Trends",
viz_type='line',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type="line", groupby=['name'],
@ -672,7 +672,7 @@ def load_birth_names():
slice_name="Title",
viz_type='markup',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type="markup", markup_type="html",
@ -690,7 +690,7 @@ def load_birth_names():
slice_name="Name Cloud",
viz_type='word_cloud',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type="word_cloud", size_from="10",
@ -700,7 +700,7 @@ def load_birth_names():
slice_name="Pivot Table",
viz_type='pivot_table',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type="pivot_table", metrics=['sum__num'],
@ -709,7 +709,7 @@ def load_birth_names():
slice_name="Number of Girls",
viz_type='big_number_total',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type="big_number_total", granularity="ds",
@ -862,7 +862,7 @@ def load_unicode_test_data():
slice_name="Unicode Cloud",
viz_type='word_cloud',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(slice_data),
)
merge_slice(slc)
@ -935,7 +935,7 @@ def load_random_time_series_data():
slice_name="Calendar Heatmap",
viz_type='cal_heatmap',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(slice_data),
)
merge_slice(slc)
@ -1005,7 +1005,7 @@ def load_long_lat_data():
slice_name="Mapbox Long/Lat",
viz_type='mapbox',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(slice_data),
)
merge_slice(slc)
@ -1084,7 +1084,7 @@ def load_multiformat_time_series_data():
slice_name="Calendar Heatmap multiformat" + str(i),
viz_type='cal_heatmap',
datasource_type='table',
table=tbl,
datasource_id=tbl.id,
params=get_slice_json(slice_data),
)
merge_slice(slc)

View File

@ -11,15 +11,34 @@ revision = '27ae655e4247'
down_revision = 'd8bc074f7aad'
from alembic import op
from caravel import db, models
from caravel import db
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from flask_appbuilder import Model
from sqlalchemy import (
Column, Integer, ForeignKey, Table)
Base = declarative_base()
class Slice(Base):
"""Declarative class to do query in upgrade"""
__tablename__ = 'slices'
id = Column(Integer, primary_key=True)
class Dashboard(Base):
"""Declarative class to do query in upgrade"""
__tablename__ = 'dashboards'
id = Column(Integer, primary_key=True)
def upgrade():
bind = op.get_bind()
session = db.Session(bind=bind)
objects = session.query(models.Slice).all()
objects += session.query(models.Dashboard).all()
objects = session.query(Slice).all()
objects += session.query(Dashboard).all()
for obj in objects:
if obj.created_by and obj.created_by not in obj.owners:
obj.owners.append(obj.created_by)

View File

@ -0,0 +1,59 @@
from alembic import op
import sqlalchemy as sa
from caravel import db
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import (
Column, Integer, String)
"""update slice model
Revision ID: 33d996bcc382
Revises: 41f6a59a61f2
Create Date: 2016-09-07 23:50:59.366779
"""
# revision identifiers, used by Alembic.
revision = '33d996bcc382'
down_revision = '41f6a59a61f2'
Base = declarative_base()
class Slice(Base):
"""Declarative class to do query in upgrade"""
__tablename__ = 'slices'
id = Column(Integer, primary_key=True)
datasource_id = Column(Integer)
druid_datasource_id = Column(Integer)
table_id = Column(Integer)
datasource_type = Column(String(200))
def upgrade():
bind = op.get_bind()
op.add_column('slices', sa.Column('datasource_id', sa.Integer()))
session = db.Session(bind=bind)
for slc in session.query(Slice).all():
if slc.druid_datasource_id:
slc.datasource_id = slc.druid_datasource_id
if slc.table_id:
slc.datasource_id = slc.table_id
session.merge(slc)
session.commit()
session.close()
def downgrade():
bind = op.get_bind()
session = db.Session(bind=bind)
for slc in session.query(Slice).all():
if slc.datasource_type == 'druid':
slc.druid_datasource_id = slc.datasource_id
if slc.datasource_type == 'table':
slc.table_id = slc.datasource_id
session.merge(slc)
session.commit()
session.close()
op.drop_column('slices', 'datasource_id')

View File

@ -0,0 +1,19 @@
"""empty message
Revision ID: b347b202819b
Revises: ('33d996bcc382', '65903709c321')
Create Date: 2016-09-19 17:22:40.138601
"""
# revision identifiers, used by Alembic.
revision = 'b347b202819b'
down_revision = ('33d996bcc382', '65903709c321')
def upgrade():
pass
def downgrade():
pass

View File

@ -49,7 +49,7 @@ from sqlalchemy_utils import EncryptedType
from werkzeug.datastructures import ImmutableMultiDict
import caravel
from caravel import app, db, get_session, utils, sm
from caravel import app, db, get_session, utils, sm, src_registry
from caravel.viz import viz_types
from caravel.utils import flasher, MetricPermException, DimSelector
@ -156,8 +156,7 @@ class Slice(Model, AuditMixinNullable):
__tablename__ = 'slices'
id = Column(Integer, primary_key=True)
slice_name = Column(String(250))
druid_datasource_id = Column(Integer, ForeignKey('datasources.id'))
table_id = Column(Integer, ForeignKey('tables.id'))
datasource_id = Column(Integer)
datasource_type = Column(String(200))
datasource_name = Column(String(2000))
viz_type = Column(String(250))
@ -165,33 +164,34 @@ class Slice(Model, AuditMixinNullable):
description = Column(Text)
cache_timeout = Column(Integer)
perm = Column(String(2000))
table = relationship(
'SqlaTable', foreign_keys=[table_id], backref='slices')
druid_datasource = relationship(
'DruidDatasource', foreign_keys=[druid_datasource_id], backref='slices')
owners = relationship("User", secondary=slice_user)
def __repr__(self):
return self.slice_name
@property
def cls_model(self):
return src_registry.sources[self.datasource_type]
@property
def datasource(self):
return self.table or self.druid_datasource
return self.get_datasource
@datasource.getter
@utils.memoized
def get_datasource(self):
ds = db.session.query(
self.cls_model).filter_by(
id=self.datasource_id).first()
return ds
@renders('datasource_name')
def datasource_link(self):
if self.table:
return self.table.link
elif self.druid_datasource:
return self.druid_datasource.link
return self.datasource.link
@property
def datasource_edit_url(self):
if self.table:
return self.table.url
elif self.druid_datasource:
return self.druid_datasource.url
self.datasource.url
@property
@utils.memoized
@ -204,10 +204,6 @@ class Slice(Model, AuditMixinNullable):
def description_markeddown(self):
return utils.markdown(self.description)
@property
def datasource_id(self):
return self.table_id or self.druid_datasource_id
@property
def data(self):
"""Data used to render slice in templates"""
@ -283,12 +279,8 @@ class Slice(Model, AuditMixinNullable):
def set_perm(mapper, connection, target): # noqa
if target.table_id:
src_class = SqlaTable
id_ = target.table_id
elif target.druid_datasource_id:
src_class = DruidDatasource
id_ = target.druid_datasource_id
src_class = target.cls_model
id_ = target.datasource_id
ds = db.session.query(src_class).filter_by(id=int(id_)).first()
target.perm = ds.perm

View File

@ -0,0 +1,15 @@
from flask import flash
class SourceRegistry(object):
""" Central Registry for all available datasource engines"""
sources = {}
def add_source(self, ds_type, cls_model):
if ds_type not in self.sources:
self.sources[ds_type] = cls_model
if self.sources[ds_type] is not cls_model:
raise Exception(
'source type: {} is already associated with Model: {}'.format(
ds_type, self.sources[ds_type]))

View File

@ -410,6 +410,14 @@ def readfile(filepath):
return content
def register_sources(datasources, module_datasource_map, registry):
for m in datasources:
datasource_list = module_datasource_map[m]
for ds in datasource_list:
ds_class = getattr(datasources[m], ds)
registry.add_source(ds_class.type, ds_class)
def generic_find_constraint_name(table, columns, referenced, db):
"""Utility to find a constraint name in alembic migrations"""
t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine)

View File

@ -33,7 +33,8 @@ from wtforms.validators import ValidationError
import caravel
from caravel import (
appbuilder, cache, db, models, viz, utils, app, sm, ascii_art, sql_lab
appbuilder, cache, db, models, viz, utils, app,
sm, ascii_art, sql_lab, src_registry
)
config = app.config
@ -675,8 +676,7 @@ class SliceModelView(CaravelModelView, DeleteMixin): # noqa
list_columns = [
'slice_link', 'viz_type', 'datasource_link', 'creator', 'modified']
edit_columns = [
'slice_name', 'description', 'viz_type', 'druid_datasource',
'table', 'owners', 'dashboards', 'params', 'cache_timeout']
'slice_name', 'description', 'viz_type', 'owners', 'dashboards', 'params', 'cache_timeout']
base_order = ('changed_on', 'desc')
description_columns = {
'description': Markup(
@ -722,18 +722,13 @@ class SliceModelView(CaravelModelView, DeleteMixin): # noqa
if not widget:
return redirect(self.get_redirect())
a_druid_datasource = db.session.query(models.DruidDatasource).first()
if a_druid_datasource is not None:
url = "/druiddatasourcemodelview/list/"
msg = _(
"Click on a datasource link to create a Slice, "
"or click on a table link "
"<a href='/tablemodelview/list/'>here</a> "
"to create a Slice for a table"
)
else:
url = "/tablemodelview/list/"
msg = _("Click on a table link to create a Slice")
sources = src_registry.sources
for source in sources:
ds = db.session.query(src_registry.sources[source]).first()
if ds is not None:
url = "/{}/list/".format(ds.baselink)
msg = _("Click on a {} link to create a Slice".format(source))
break
redirect_url = "/r/msg/?url={}&msg={}".format(url, msg)
return redirect(redirect_url)
@ -978,8 +973,8 @@ class Caravel(BaseCaravelView):
@log_this
def explore(self, datasource_type, datasource_id, slice_id=None):
error_redirect = '/slicemodelview/list/'
datasource_class = models.SqlaTable \
if datasource_type == "table" else models.DruidDatasource
datasource_class = src_registry.sources[datasource_type]
datasources = (
db.session
.query(datasource_class)
@ -1093,12 +1088,8 @@ class Caravel(BaseCaravelView):
if k not in as_list and isinstance(v, list):
d[k] = v[0]
table_id = druid_datasource_id = None
datasource_type = args.get('datasource_type')
if datasource_type in ('datasource', 'druid'):
druid_datasource_id = args.get('datasource_id')
elif datasource_type == 'table':
table_id = args.get('datasource_id')
datasource_id = args.get('datasource_id')
if action in ('saveas'):
d.pop('slice_id') # don't save old slice_id
@ -1107,9 +1098,8 @@ class Caravel(BaseCaravelView):
slc.params = json.dumps(d, indent=4, sort_keys=True)
slc.datasource_name = args.get('datasource_name')
slc.viz_type = args.get('viz_type')
slc.druid_datasource_id = druid_datasource_id
slc.table_id = table_id
slc.datasource_type = datasource_type
slc.datasource_id = datasource_id
slc.slice_name = slice_name
if action in ('saveas') and slice_add_perm:
@ -1330,7 +1320,9 @@ class Caravel(BaseCaravelView):
json_error_response(__(
"Table %(t)s wasn't found in the database %(d)s",
t=table_name, s=db_name), status=404)
slices = table.slices
slices = session.query(models.Slice).filter_by(
datasource_id=table.id,
datasource_type=table.type).all()
for slice in slices:
try:

View File

@ -210,32 +210,6 @@ class CoreTests(CaravelTestCase):
assert new_slice in dash.slices
assert len(set(dash.slices)) == len(dash.slices)
def test_add_slice_redirect_to_sqla(self, username='admin'):
self.login(username=username)
url = '/slicemodelview/add'
resp = self.client.get(url, follow_redirects=True)
assert (
"Click on a table link to create a Slice" in
resp.data.decode('utf-8')
)
def test_add_slice_redirect_to_druid(self, username='admin'):
datasource = DruidDatasource(
datasource_name="datasource_name",
)
db.session.add(datasource)
db.session.commit()
self.login(username=username)
url = '/slicemodelview/add'
resp = self.client.get(url, follow_redirects=True)
assert (
"Click on a datasource link to create a Slice"
in resp.data.decode('utf-8')
)
db.session.delete(datasource)
db.session.commit()
def test_druid_sync_from_config(self):
cluster = models.DruidCluster(cluster_name="new_druid")