Refactoring Druid & SQLa into a proper "Connector" interface (#2362)

* Formalizing the Connector interface

* Checkpoint

* Fixing views

* Fixing tests

* Adding migration

* Tests

* Final

* Addressing comments
Maxime Beauchemin 2017-03-10 09:11:51 -08:00 committed by GitHub
parent 9a8c3a0447
commit 2969cc9993
32 changed files with 3781 additions and 3573 deletions


@@ -13,9 +13,9 @@ from flask import Flask, redirect
from flask_appbuilder import SQLA, AppBuilder, IndexView
from flask_appbuilder.baseviews import expose
from flask_migrate import Migrate
from superset.source_registry import SourceRegistry
from superset.connectors.connector_registry import ConnectorRegistry
from werkzeug.contrib.fixers import ProxyFix
from superset import utils
from superset import utils, config # noqa
APP_DIR = os.path.dirname(__file__)
@@ -104,6 +104,6 @@ results_backend = app.config.get("RESULTS_BACKEND")
# Registering sources
module_datasource_map = app.config.get("DEFAULT_MODULE_DS_MAP")
module_datasource_map.update(app.config.get("ADDITIONAL_MODULE_DS_MAP"))
SourceRegistry.register_sources(module_datasource_map)
ConnectorRegistry.register_sources(module_datasource_map)
from superset import views, config # noqa
from superset import views # noqa
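ConnectorRegistry.register_sources replaces the old SourceRegistry entry point; its body is not part of this hunk. A rough sketch of the idea, assuming the module-to-class map defined in config.py below (illustrative only, the real implementation may differ in detail):

# Illustrative sketch of what register_sources does with the module/class map
@classmethod
def register_sources(cls, datasource_config):
    for module_name, class_names in datasource_config.items():
        module_obj = __import__(module_name, fromlist=class_names)
        for class_name in class_names:
            source_class = getattr(module_obj, class_name)
            # keyed by the class-level `type` attribute, e.g. 'table' for SqlaTable
            cls.sources[source_class.type] = source_class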


@@ -13,7 +13,7 @@ from subprocess import Popen
from flask_migrate import MigrateCommand
from flask_script import Manager
from superset import app, db, data, security
from superset import app, db, security
config = app.config
@@ -89,6 +89,7 @@ def version(verbose):
help="Load additional test data")
def load_examples(load_test_data):
"""Loads a set of Slices and Dashboards and a supporting dataset """
from superset import data
print("Loading examples into {}".format(db))
data.load_css_templates()


@@ -178,7 +178,10 @@ DRUID_DATA_SOURCE_BLACKLIST = []
# --------------------------------------------------
# Modules, datasources and middleware to be registered
# --------------------------------------------------
DEFAULT_MODULE_DS_MAP = {'superset.models': ['DruidDatasource', 'SqlaTable']}
DEFAULT_MODULE_DS_MAP = {
'superset.connectors.druid.models': ['DruidDatasource'],
'superset.connectors.sqla.models': ['SqlaTable'],
}
ADDITIONAL_MODULE_DS_MAP = {}
ADDITIONAL_MIDDLEWARE = []
@@ -292,14 +295,17 @@ SILENCE_FAB = True
BLUEPRINTS = []
try:
if CONFIG_PATH_ENV_VAR in os.environ:
# Explicitly import config module that is not in pythonpath; useful
# for case where app is being executed via pex.
print('Loaded your LOCAL configuration at [{}]'.format(
os.environ[CONFIG_PATH_ENV_VAR]))
imp.load_source('superset_config', os.environ[CONFIG_PATH_ENV_VAR])
from superset_config import * # noqa
import superset_config
print('Loaded your LOCAL configuration at [{}]'.format(
superset_config.__file__))
else:
from superset_config import * # noqa
import superset_config
print('Loaded your LOCAL configuration at [{}]'.format(
superset_config.__file__))
except ImportError:
pass
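A deployment that ships its own connector can point Superset at it from superset_config.py; a hypothetical example (my_company.connectors.models and MyDatasource are made-up names):

# superset_config.py -- hypothetical local override picked up by the import logic above
ADDITIONAL_MODULE_DS_MAP = {
    'my_company.connectors.models': ['MyDatasource'],  # made-up module and class
}

At import time superset/__init__.py merges this map into DEFAULT_MODULE_DS_MAP and passes the result to ConnectorRegistry.register_sources, as shown earlier in this diff.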


superset/connectors/base.py Normal file

@@ -0,0 +1,160 @@
import json
from sqlalchemy import Column, Integer, String, Text, Boolean
from superset import utils
from superset.models.helpers import AuditMixinNullable, ImportMixin
class BaseDatasource(AuditMixinNullable, ImportMixin):
"""A common interface to objects that are queryable (tables and datasources)"""
__tablename__ = None # {connector_name}_datasource
# Used to do code highlighting when displaying the query in the UI
query_language = None
@property
def column_names(self):
return sorted([c.column_name for c in self.columns])
@property
def main_dttm_col(self):
return "timestamp"
@property
def groupby_column_names(self):
return sorted([c.column_name for c in self.columns if c.groupby])
@property
def filterable_column_names(self):
return sorted([c.column_name for c in self.columns if c.filterable])
@property
def dttm_cols(self):
return []
@property
def url(self):
return '/{}/edit/{}'.format(self.baselink, self.id)
@property
def explore_url(self):
if self.default_endpoint:
return self.default_endpoint
else:
return "/superset/explore/{obj.type}/{obj.id}/".format(obj=self)
@property
def column_formats(self):
return {
m.metric_name: m.d3format
for m in self.metrics
if m.d3format
}
@property
def data(self):
"""Data representation of the datasource sent to the frontend"""
order_by_choices = []
for s in sorted(self.column_names):
order_by_choices.append((json.dumps([s, True]), s + ' [asc]'))
order_by_choices.append((json.dumps([s, False]), s + ' [desc]'))
d = {
'all_cols': utils.choicify(self.column_names),
'column_formats': self.column_formats,
'edit_url': self.url,
'filter_select': self.filter_select_enabled,
'filterable_cols': utils.choicify(self.filterable_column_names),
'gb_cols': utils.choicify(self.groupby_column_names),
'id': self.id,
'metrics_combo': self.metrics_combo,
'name': self.name,
'order_by_choices': order_by_choices,
'type': self.type,
}
# TODO move this block to SqlaTable.data
if self.type == 'table':
grains = self.database.grains() or []
if grains:
grains = [(g.name, g.name) for g in grains]
d['granularity_sqla'] = utils.choicify(self.dttm_cols)
d['time_grain_sqla'] = grains
return d
class BaseColumn(AuditMixinNullable, ImportMixin):
"""Interface for column"""
__tablename__ = None # {connector_name}_column
id = Column(Integer, primary_key=True)
column_name = Column(String(255))
verbose_name = Column(String(1024))
is_active = Column(Boolean, default=True)
type = Column(String(32))
groupby = Column(Boolean, default=False)
count_distinct = Column(Boolean, default=False)
sum = Column(Boolean, default=False)
avg = Column(Boolean, default=False)
max = Column(Boolean, default=False)
min = Column(Boolean, default=False)
filterable = Column(Boolean, default=False)
description = Column(Text)
# [optional] Set this to support import/export functionality
export_fields = []
def __repr__(self):
return self.column_name
num_types = ('DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'LONG', 'REAL', 'NUMERIC')
date_types = ('DATE', 'TIME', 'DATETIME')
str_types = ('VARCHAR', 'STRING', 'CHAR')
@property
def is_num(self):
return any([t in self.type.upper() for t in self.num_types])
@property
def is_time(self):
return any([t in self.type.upper() for t in self.date_types])
@property
def is_string(self):
return any([t in self.type.upper() for t in self.str_types])
class BaseMetric(AuditMixinNullable, ImportMixin):
"""Interface for Metrics"""
__tablename__ = None # {connector_name}_metric
id = Column(Integer, primary_key=True)
metric_name = Column(String(512))
verbose_name = Column(String(1024))
metric_type = Column(String(32))
description = Column(Text)
is_restricted = Column(Boolean, default=False, nullable=True)
d3format = Column(String(128))
"""
The interface should also declare a datasource relationship pointing
to a derivative of BaseDatasource, along with a FK
datasource_name = Column(
String(255),
ForeignKey('datasources.datasource_name'))
datasource = relationship(
# needs to be altered to point to {Connector}Datasource
'BaseDatasource',
backref=backref('metrics', cascade='all, delete-orphan'),
enable_typechecks=False)
"""
@property
def perm(self):
raise NotImplementedError()
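Together, BaseDatasource, BaseColumn and BaseMetric spell out what a connector has to provide. A bare-bones sketch of a hypothetical third-party connector (all names invented; the druid and sqla packages below are the real reference implementations, and several attributes used by BaseDatasource.data, such as name, metrics_combo, filter_select_enabled and default_endpoint, are omitted for brevity):

# Hypothetical connector skeleton -- illustrative only, not part of this commit
from flask_appbuilder import Model
from sqlalchemy import Column, ForeignKey, Integer, String
from sqlalchemy.orm import backref, relationship

from superset.connectors.base import BaseColumn, BaseDatasource, BaseMetric


class MyColumn(Model, BaseColumn):
    __tablename__ = 'my_columns'
    datasource_id = Column(Integer, ForeignKey('my_datasources.id'))


class MyMetric(Model, BaseMetric):
    __tablename__ = 'my_metrics'
    datasource_id = Column(Integer, ForeignKey('my_datasources.id'))


class MyDatasource(Model, BaseDatasource):
    __tablename__ = 'my_datasources'
    type = 'my_type'        # key used in ConnectorRegistry.sources
    query_language = 'sql'  # used for code highlighting in the UI
    baselink = 'mydatasourcemodelview'

    id = Column(Integer, primary_key=True)
    datasource_name = Column(String(255))
    columns = relationship(MyColumn, backref=backref('datasource'))
    metrics = relationship(MyMetric, backref=backref('datasource'))

    def query(self, query_obj):
        # both built-in connectors expose a query(query_obj) entry point that
        # returns a superset.models.helpers.QueryResult (see SqlaTable.query below)
        raise NotImplementedError()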


@@ -1,7 +1,7 @@
from sqlalchemy.orm import subqueryload
class SourceRegistry(object):
class ConnectorRegistry(object):
""" Central Registry for all available datasource engines"""
sources = {}
@@ -26,15 +26,15 @@ class SourceRegistry(object):
@classmethod
def get_all_datasources(cls, session):
datasources = []
for source_type in SourceRegistry.sources:
for source_type in ConnectorRegistry.sources:
datasources.extend(
session.query(SourceRegistry.sources[source_type]).all())
session.query(ConnectorRegistry.sources[source_type]).all())
return datasources
@classmethod
def get_datasource_by_name(cls, session, datasource_type, datasource_name,
schema, database_name):
datasource_class = SourceRegistry.sources[datasource_type]
datasource_class = ConnectorRegistry.sources[datasource_type]
datasources = session.query(datasource_class).all()
# Filter datasources that don't have a database.
@@ -45,7 +45,7 @@ class SourceRegistry(object):
@classmethod
def query_datasources_by_permissions(cls, session, database, permissions):
datasource_class = SourceRegistry.sources[database.type]
datasource_class = ConnectorRegistry.sources[database.type]
return (
session.query(datasource_class)
.filter_by(database_id=database.id)
@@ -56,7 +56,7 @@ class SourceRegistry(object):
@classmethod
def get_eager_datasource(cls, session, datasource_type, datasource_id):
"""Returns datasource with columns and metrics."""
datasource_class = SourceRegistry.sources[datasource_type]
datasource_class = ConnectorRegistry.sources[datasource_type]
return (
session.query(datasource_class)
.options(

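Call sites interact with the registry through these classmethods; a short usage sketch based on the signatures above (assumes a Flask app context; 'birth_names' and 'main' are just example datasource and database names):

# Usage sketch -- signatures taken from ConnectorRegistry above
from superset import db
from superset.connectors.connector_registry import ConnectorRegistry

session = db.session
all_datasources = ConnectorRegistry.get_all_datasources(session)
ds = ConnectorRegistry.get_datasource_by_name(
    session, 'table', 'birth_names', schema=None, database_name='main')
if ds:
    # eager-loads columns and metrics, e.g. for export / serialization
    eager = ConnectorRegistry.get_eager_datasource(session, 'table', ds.id)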

@@ -0,0 +1,2 @@
from . import models # noqa
from . import views # noqa

File diff suppressed because it is too large


@@ -0,0 +1,203 @@
import sqlalchemy as sqla
from flask import Markup
from flask_appbuilder import CompactCRUDMixin
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import lazy_gettext as _
from flask_babel import gettext as __
import superset
from superset import db, utils, appbuilder, sm, security
from superset.views.base import (
SupersetModelView, validate_json, DeleteMixin, ListWidgetWithCheckboxes,
DatasourceFilter, get_datasource_exist_error_mgs)
from . import models
class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView): # noqa
datamodel = SQLAInterface(models.DruidColumn)
edit_columns = [
'column_name', 'description', 'dimension_spec_json', 'datasource',
'groupby', 'count_distinct', 'sum', 'min', 'max']
add_columns = edit_columns
list_columns = [
'column_name', 'type', 'groupby', 'filterable', 'count_distinct',
'sum', 'min', 'max']
can_delete = False
page_size = 500
label_columns = {
'column_name': _("Column"),
'type': _("Type"),
'datasource': _("Datasource"),
'groupby': _("Groupable"),
'filterable': _("Filterable"),
'count_distinct': _("Count Distinct"),
'sum': _("Sum"),
'min': _("Min"),
'max': _("Max"),
}
description_columns = {
'dimension_spec_json': utils.markdown(
"this field can be used to specify "
"a `dimensionSpec` as documented [here]"
"(http://druid.io/docs/latest/querying/dimensionspecs.html). "
"Make sure to input valid JSON and that the "
"`outputName` matches the `column_name` defined "
"above.",
True),
}
def post_update(self, col):
col.generate_metrics()
utils.validate_json(col.dimension_spec_json)
def post_add(self, col):
self.post_update(col)
appbuilder.add_view_no_menu(DruidColumnInlineView)
class DruidMetricInlineView(CompactCRUDMixin, SupersetModelView): # noqa
datamodel = SQLAInterface(models.DruidMetric)
list_columns = ['metric_name', 'verbose_name', 'metric_type']
edit_columns = [
'metric_name', 'description', 'verbose_name', 'metric_type', 'json',
'datasource', 'd3format', 'is_restricted']
add_columns = edit_columns
page_size = 500
validators_columns = {
'json': [validate_json],
}
description_columns = {
'metric_type': utils.markdown(
"use `postagg` as the metric type if you are defining a "
"[Druid Post Aggregation]"
"(http://druid.io/docs/latest/querying/post-aggregations.html)",
True),
'is_restricted': _("Whether the access to this metric is restricted "
"to certain roles. Only roles with the permission "
"'metric access on XXX (the name of this metric)' "
"are allowed to access this metric"),
}
label_columns = {
'metric_name': _("Metric"),
'description': _("Description"),
'verbose_name': _("Verbose Name"),
'metric_type': _("Type"),
'json': _("JSON"),
'datasource': _("Druid Datasource"),
}
def post_add(self, metric):
utils.init_metrics_perm(superset, [metric])
def post_update(self, metric):
utils.init_metrics_perm(superset, [metric])
appbuilder.add_view_no_menu(DruidMetricInlineView)
class DruidClusterModelView(SupersetModelView, DeleteMixin): # noqa
datamodel = SQLAInterface(models.DruidCluster)
add_columns = [
'cluster_name',
'coordinator_host', 'coordinator_port', 'coordinator_endpoint',
'broker_host', 'broker_port', 'broker_endpoint', 'cache_timeout',
]
edit_columns = add_columns
list_columns = ['cluster_name', 'metadata_last_refreshed']
label_columns = {
'cluster_name': _("Cluster"),
'coordinator_host': _("Coordinator Host"),
'coordinator_port': _("Coordinator Port"),
'coordinator_endpoint': _("Coordinator Endpoint"),
'broker_host': _("Broker Host"),
'broker_port': _("Broker Port"),
'broker_endpoint': _("Broker Endpoint"),
}
def pre_add(self, cluster):
security.merge_perm(sm, 'database_access', cluster.perm)
def pre_update(self, cluster):
self.pre_add(cluster)
appbuilder.add_view(
DruidClusterModelView,
name="Druid Clusters",
label=__("Druid Clusters"),
icon="fa-cubes",
category="Sources",
category_label=__("Sources"),
category_icon='fa-database',)
class DruidDatasourceModelView(SupersetModelView, DeleteMixin): # noqa
datamodel = SQLAInterface(models.DruidDatasource)
list_widget = ListWidgetWithCheckboxes
list_columns = [
'datasource_link', 'cluster', 'changed_by_', 'changed_on_', 'offset']
order_columns = [
'datasource_link', 'changed_on_', 'offset']
related_views = [DruidColumnInlineView, DruidMetricInlineView]
edit_columns = [
'datasource_name', 'cluster', 'description', 'owner',
'is_featured', 'is_hidden', 'filter_select_enabled',
'default_endpoint', 'offset', 'cache_timeout']
add_columns = edit_columns
show_columns = add_columns + ['perm']
page_size = 500
base_order = ('datasource_name', 'asc')
description_columns = {
'offset': _("Timezone offset (in hours) for this datasource"),
'description': Markup(
"Supports <a href='"
"https://daringfireball.net/projects/markdown/'>markdown</a>"),
}
base_filters = [['id', DatasourceFilter, lambda: []]]
label_columns = {
'datasource_link': _("Data Source"),
'cluster': _("Cluster"),
'description': _("Description"),
'owner': _("Owner"),
'is_featured': _("Is Featured"),
'is_hidden': _("Is Hidden"),
'filter_select_enabled': _("Enable Filter Select"),
'default_endpoint': _("Default Endpoint"),
'offset': _("Time Offset"),
'cache_timeout': _("Cache Timeout"),
}
def pre_add(self, datasource):
number_of_existing_datasources = db.session.query(
sqla.func.count('*')).filter(
models.DruidDatasource.datasource_name ==
datasource.datasource_name,
models.DruidDatasource.cluster_name == datasource.cluster.id
).scalar()
# table object is already added to the session
if number_of_existing_datasources > 1:
raise Exception(get_datasource_exist_error_mgs(
datasource.full_name))
def post_add(self, datasource):
datasource.generate_metrics()
security.merge_perm(sm, 'datasource_access', datasource.get_perm())
if datasource.schema:
security.merge_perm(sm, 'schema_access', datasource.schema_perm)
def post_update(self, datasource):
self.post_add(datasource)
appbuilder.add_view(
DruidDatasourceModelView,
"Druid Datasources",
label=__("Druid Datasources"),
category="Sources",
category_label=__("Sources"),
icon="fa-cube")


@@ -0,0 +1,2 @@
from . import models # noqa
from . import views # noqa


@@ -0,0 +1,664 @@
from datetime import datetime
import logging
import sqlparse
import pandas as pd
from sqlalchemy import (
Column, Integer, String, ForeignKey, Text, Boolean,
DateTime,
)
import sqlalchemy as sa
from sqlalchemy import asc, and_, desc, select
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import ColumnClause, TextAsFrom
from sqlalchemy.orm import backref, relationship
from sqlalchemy.sql import table, literal_column, text, column
from flask import escape, Markup
from flask_appbuilder import Model
from flask_babel import lazy_gettext as _
from superset import db, utils, import_util
from superset.connectors.base import BaseDatasource, BaseColumn, BaseMetric
from superset.utils import (
wrap_clause_in_parens,
DTTM_ALIAS, QueryStatus
)
from superset.models.helpers import QueryResult
from superset.models.core import Database
from superset.jinja_context import get_template_processor
from superset.models.helpers import set_perm
class TableColumn(Model, BaseColumn):
"""ORM object for table columns, each table can have multiple columns"""
__tablename__ = 'table_columns'
table_id = Column(Integer, ForeignKey('tables.id'))
table = relationship(
'SqlaTable',
backref=backref('columns', cascade='all, delete-orphan'),
foreign_keys=[table_id])
is_dttm = Column(Boolean, default=False)
expression = Column(Text, default='')
python_date_format = Column(String(255))
database_expression = Column(String(255))
export_fields = (
'table_id', 'column_name', 'verbose_name', 'is_dttm', 'is_active',
'type', 'groupby', 'count_distinct', 'sum', 'avg', 'max', 'min',
'filterable', 'expression', 'description', 'python_date_format',
'database_expression'
)
@property
def sqla_col(self):
name = self.column_name
if not self.expression:
col = column(self.column_name).label(name)
else:
col = literal_column(self.expression).label(name)
return col
def get_time_filter(self, start_dttm, end_dttm):
col = self.sqla_col.label('__time')
return and_(
col >= text(self.dttm_sql_literal(start_dttm)),
col <= text(self.dttm_sql_literal(end_dttm)),
)
def get_timestamp_expression(self, time_grain):
"""Getting the time component of the query"""
expr = self.expression or self.column_name
if not self.expression and not time_grain:
return column(expr, type_=DateTime).label(DTTM_ALIAS)
if time_grain:
pdf = self.python_date_format
if pdf in ('epoch_s', 'epoch_ms'):
# if epoch, translate to DATE using db specific conf
db_spec = self.table.database.db_engine_spec
if pdf == 'epoch_s':
expr = db_spec.epoch_to_dttm().format(col=expr)
elif pdf == 'epoch_ms':
expr = db_spec.epoch_ms_to_dttm().format(col=expr)
grain = self.table.database.grains_dict().get(time_grain, '{col}')
expr = grain.function.format(col=expr)
return literal_column(expr, type_=DateTime).label(DTTM_ALIAS)
@classmethod
def import_obj(cls, i_column):
def lookup_obj(lookup_column):
return db.session.query(TableColumn).filter(
TableColumn.table_id == lookup_column.table_id,
TableColumn.column_name == lookup_column.column_name).first()
return import_util.import_simple_obj(db.session, i_column, lookup_obj)
def dttm_sql_literal(self, dttm):
"""Convert datetime object to a SQL expression string
If database_expression is empty, the internal dttm
will be formatted as a string using the pattern the
user provided (python_date_format).
If database_expression is not empty, the internal dttm
will be rendered through that SQL expression for the database to convert.
"""
tf = self.python_date_format or '%Y-%m-%d %H:%M:%S.%f'
if self.database_expression:
return self.database_expression.format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
elif tf == 'epoch_s':
return str((dttm - datetime(1970, 1, 1)).total_seconds())
elif tf == 'epoch_ms':
return str((dttm - datetime(1970, 1, 1)).total_seconds() * 1000.0)
else:
s = self.table.database.db_engine_spec.convert_dttm(
self.type, dttm)
return s or "'{}'".format(dttm.strftime(tf))
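For readability, the same branching restated as a free-standing function (the db_engine_spec.convert_dttm fallback used for typed columns is left out); purely illustrative:

# Illustration of the TableColumn.dttm_sql_literal logic above
from datetime import datetime

def dttm_sql_literal_sketch(dttm, python_date_format=None, database_expression=None):
    tf = python_date_format or '%Y-%m-%d %H:%M:%S.%f'
    if database_expression:
        # e.g. "TO_DATE('{}', 'YYYY-MM-DD HH24:MI:SS')" for Oracle
        return database_expression.format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
    if tf == 'epoch_s':
        return str((dttm - datetime(1970, 1, 1)).total_seconds())
    if tf == 'epoch_ms':
        return str((dttm - datetime(1970, 1, 1)).total_seconds() * 1000.0)
    return "'{}'".format(dttm.strftime(tf))

# dttm_sql_literal_sketch(datetime(2017, 3, 10))             -> "'2017-03-10 00:00:00.000000'"
# dttm_sql_literal_sketch(datetime(2017, 3, 10), 'epoch_s')  -> "1489104000.0"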
class SqlMetric(Model, BaseMetric):
"""ORM object for metrics, each table can have multiple metrics"""
__tablename__ = 'sql_metrics'
table_id = Column(Integer, ForeignKey('tables.id'))
table = relationship(
'SqlaTable',
backref=backref('metrics', cascade='all, delete-orphan'),
foreign_keys=[table_id])
expression = Column(Text)
export_fields = (
'metric_name', 'verbose_name', 'metric_type', 'table_id', 'expression',
'description', 'is_restricted', 'd3format')
@property
def sqla_col(self):
name = self.metric_name
return literal_column(self.expression).label(name)
@property
def perm(self):
return (
"{parent_name}.[{obj.metric_name}](id:{obj.id})"
).format(obj=self,
parent_name=self.table.full_name) if self.table else None
@classmethod
def import_obj(cls, i_metric):
def lookup_obj(lookup_metric):
return db.session.query(SqlMetric).filter(
SqlMetric.table_id == lookup_metric.table_id,
SqlMetric.metric_name == lookup_metric.metric_name).first()
return import_util.import_simple_obj(db.session, i_metric, lookup_obj)
class SqlaTable(Model, BaseDatasource):
"""An ORM object for SqlAlchemy table references"""
type = "table"
query_language = 'sql'
metric_class = SqlMetric
__tablename__ = 'tables'
id = Column(Integer, primary_key=True)
table_name = Column(String(250))
main_dttm_col = Column(String(250))
description = Column(Text)
default_endpoint = Column(Text)
database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False)
is_featured = Column(Boolean, default=False)
filter_select_enabled = Column(Boolean, default=False)
user_id = Column(Integer, ForeignKey('ab_user.id'))
owner = relationship('User', backref='tables', foreign_keys=[user_id])
database = relationship(
'Database',
backref=backref('tables', cascade='all, delete-orphan'),
foreign_keys=[database_id])
offset = Column(Integer, default=0)
cache_timeout = Column(Integer)
schema = Column(String(255))
sql = Column(Text)
params = Column(Text)
perm = Column(String(1000))
baselink = "tablemodelview"
column_cls = TableColumn
metric_cls = SqlMetric
export_fields = (
'table_name', 'main_dttm_col', 'description', 'default_endpoint',
'database_id', 'is_featured', 'offset', 'cache_timeout', 'schema',
'sql', 'params')
__table_args__ = (
sa.UniqueConstraint(
'database_id', 'schema', 'table_name',
name='_customer_location_uc'),)
def __repr__(self):
return self.name
@property
def description_markeddown(self):
return utils.markdown(self.description)
@property
def link(self):
name = escape(self.name)
return Markup(
'<a href="{self.explore_url}">{name}</a>'.format(**locals()))
@property
def schema_perm(self):
"""Returns schema permission if present, database one otherwise."""
return utils.get_schema_perm(self.database, self.schema)
def get_perm(self):
return (
"[{obj.database}].[{obj.table_name}]"
"(id:{obj.id})").format(obj=self)
@property
def name(self):
if not self.schema:
return self.table_name
return "{}.{}".format(self.schema, self.table_name)
@property
def full_name(self):
return utils.get_datasource_full_name(
self.database, self.table_name, schema=self.schema)
@property
def dttm_cols(self):
l = [c.column_name for c in self.columns if c.is_dttm]
if self.main_dttm_col and self.main_dttm_col not in l:
l.append(self.main_dttm_col)
return l
@property
def num_cols(self):
return [c.column_name for c in self.columns if c.is_num]
@property
def any_dttm_col(self):
cols = self.dttm_cols
if cols:
return cols[0]
@property
def html(self):
t = ((c.column_name, c.type) for c in self.columns)
df = pd.DataFrame(t)
df.columns = ['field', 'type']
return df.to_html(
index=False,
classes=(
"dataframe table table-striped table-bordered "
"table-condensed"))
@property
def metrics_combo(self):
return sorted(
[
(m.metric_name, m.verbose_name or m.metric_name)
for m in self.metrics],
key=lambda x: x[1])
@property
def sql_url(self):
return self.database.sql_url + "?table_name=" + str(self.table_name)
@property
def time_column_grains(self):
return {
"time_columns": self.dttm_cols,
"time_grains": [grain.name for grain in self.database.grains()]
}
def get_col(self, col_name):
columns = self.columns
for col in columns:
if col_name == col.column_name:
return col
def values_for_column(self,
column_name,
from_dttm,
to_dttm,
limit=500):
"""Runs query against sqla to retrieve some
sample values for the given column.
"""
granularity = self.main_dttm_col
cols = {col.column_name: col for col in self.columns}
target_col = cols[column_name]
tbl = table(self.table_name)
qry = sa.select([target_col.sqla_col])
qry = qry.select_from(tbl)
qry = qry.distinct(column_name)
qry = qry.limit(limit)
if granularity:
dttm_col = cols[granularity]
timestamp = dttm_col.sqla_col.label('timestamp')
time_filter = [
timestamp >= text(dttm_col.dttm_sql_literal(from_dttm)),
timestamp <= text(dttm_col.dttm_sql_literal(to_dttm)),
]
qry = qry.where(and_(*time_filter))
engine = self.database.get_sqla_engine()
sql = "{}".format(
qry.compile(
engine, compile_kwargs={"literal_binds": True}, ),
)
return pd.read_sql_query(
sql=sql,
con=engine
)
def get_query_str( # sqla
self, engine, qry_start_dttm,
groupby, metrics,
granularity,
from_dttm, to_dttm,
filter=None, # noqa
is_timeseries=True,
timeseries_limit=15,
timeseries_limit_metric=None,
row_limit=None,
inner_from_dttm=None,
inner_to_dttm=None,
orderby=None,
extras=None,
columns=None):
"""Querying any sqla table from this common interface"""
template_processor = get_template_processor(
table=self, database=self.database)
# For backward compatibility
if granularity not in self.dttm_cols:
granularity = self.main_dttm_col
cols = {col.column_name: col for col in self.columns}
metrics_dict = {m.metric_name: m for m in self.metrics}
if not granularity and is_timeseries:
raise Exception(_(
"Datetime column not provided as part of the table configuration "
"and is required by this type of chart"))
for m in metrics:
if m not in metrics_dict:
raise Exception(_("Metric '{}' is not valid".format(m)))
metrics_exprs = [metrics_dict.get(m).sqla_col for m in metrics]
timeseries_limit_metric = metrics_dict.get(timeseries_limit_metric)
timeseries_limit_metric_expr = None
if timeseries_limit_metric:
timeseries_limit_metric_expr = \
timeseries_limit_metric.sqla_col
if metrics:
main_metric_expr = metrics_exprs[0]
else:
main_metric_expr = literal_column("COUNT(*)").label("ccount")
select_exprs = []
groupby_exprs = []
if groupby:
select_exprs = []
inner_select_exprs = []
inner_groupby_exprs = []
for s in groupby:
col = cols[s]
outer = col.sqla_col
inner = col.sqla_col.label(col.column_name + '__')
groupby_exprs.append(outer)
select_exprs.append(outer)
inner_groupby_exprs.append(inner)
inner_select_exprs.append(inner)
elif columns:
for s in columns:
select_exprs.append(cols[s].sqla_col)
metrics_exprs = []
if granularity:
@compiles(ColumnClause)
def visit_column(element, compiler, **kw):
"""Patch for sqlalchemy bug
TODO: sqlalchemy 1.2 release should be doing this on its own.
Patch only if the column clause is specific for DateTime
set and granularity is selected.
"""
text = compiler.visit_column(element, **kw)
try:
if (
element.is_literal and
hasattr(element.type, 'python_type') and
type(element.type) is DateTime
):
text = text.replace('%%', '%')
except NotImplementedError:
# Some elements raise NotImplementedError for python_type
pass
return text
dttm_col = cols[granularity]
time_grain = extras.get('time_grain_sqla')
if is_timeseries:
timestamp = dttm_col.get_timestamp_expression(time_grain)
select_exprs += [timestamp]
groupby_exprs += [timestamp]
time_filter = dttm_col.get_time_filter(from_dttm, to_dttm)
select_exprs += metrics_exprs
qry = sa.select(select_exprs)
tbl = table(self.table_name)
if self.schema:
tbl.schema = self.schema
# Supporting arbitrary SQL statements in place of tables
if self.sql:
tbl = TextAsFrom(sa.text(self.sql), []).alias('expr_qry')
if not columns:
qry = qry.group_by(*groupby_exprs)
where_clause_and = []
having_clause_and = []
for flt in filter:
if not all([flt.get(s) for s in ['col', 'op', 'val']]):
continue
col = flt['col']
op = flt['op']
eq = flt['val']
col_obj = cols.get(col)
if col_obj and op in ('in', 'not in'):
values = [types.strip("'").strip('"') for types in eq]
if col_obj.is_num:
values = [utils.js_string_to_num(s) for s in values]
cond = col_obj.sqla_col.in_(values)
if op == 'not in':
cond = ~cond
where_clause_and.append(cond)
if extras:
where = extras.get('where')
if where:
where_clause_and += [wrap_clause_in_parens(
template_processor.process_template(where))]
having = extras.get('having')
if having:
having_clause_and += [wrap_clause_in_parens(
template_processor.process_template(having))]
if granularity:
qry = qry.where(and_(*([time_filter] + where_clause_and)))
else:
qry = qry.where(and_(*where_clause_and))
qry = qry.having(and_(*having_clause_and))
if groupby:
qry = qry.order_by(desc(main_metric_expr))
elif orderby:
for col, ascending in orderby:
direction = asc if ascending else desc
qry = qry.order_by(direction(col))
qry = qry.limit(row_limit)
if is_timeseries and timeseries_limit and groupby:
# some sql dialects require for order by expressions
# to also be in the select clause -- others, e.g. vertica,
# require a unique inner alias
inner_main_metric_expr = main_metric_expr.label('mme_inner__')
inner_select_exprs += [inner_main_metric_expr]
subq = select(inner_select_exprs)
subq = subq.select_from(tbl)
inner_time_filter = dttm_col.get_time_filter(
inner_from_dttm or from_dttm,
inner_to_dttm or to_dttm,
)
subq = subq.where(and_(*(where_clause_and + [inner_time_filter])))
subq = subq.group_by(*inner_groupby_exprs)
ob = inner_main_metric_expr
if timeseries_limit_metric_expr is not None:
ob = timeseries_limit_metric_expr
subq = subq.order_by(desc(ob))
subq = subq.limit(timeseries_limit)
on_clause = []
for i, gb in enumerate(groupby):
on_clause.append(
groupby_exprs[i] == column(gb + '__'))
tbl = tbl.join(subq.alias(), and_(*on_clause))
qry = qry.select_from(tbl)
sql = "{}".format(
qry.compile(
engine, compile_kwargs={"literal_binds": True},),
)
logging.info(sql)
sql = sqlparse.format(sql, reindent=True)
return sql
def query(self, query_obj):
qry_start_dttm = datetime.now()
engine = self.database.get_sqla_engine()
sql = self.get_query_str(engine, qry_start_dttm, **query_obj)
status = QueryStatus.SUCCESS
error_message = None
df = None
try:
df = pd.read_sql_query(sql, con=engine)
except Exception as e:
status = QueryStatus.FAILED
error_message = str(e)
return QueryResult(
status=status,
df=df,
duration=datetime.now() - qry_start_dttm,
query=sql,
error_message=error_message)
def get_sqla_table_object(self):
return self.database.get_table(self.table_name, schema=self.schema)
def fetch_metadata(self):
"""Fetches the metadata for the table and merges it in"""
try:
table = self.get_sqla_table_object()
except Exception:
raise Exception(
"Table doesn't seem to exist in the specified database, "
"couldn't fetch column information")
TC = TableColumn # noqa shortcut to class
M = SqlMetric # noqa
metrics = []
any_date_col = None
for col in table.columns:
try:
datatype = "{}".format(col.type).upper()
except Exception as e:
datatype = "UNKNOWN"
logging.error(
"Unrecognized data type in {}.{}".format(table, col.name))
logging.exception(e)
dbcol = (
db.session
.query(TC)
.filter(TC.table == self)
.filter(TC.column_name == col.name)
.first()
)
db.session.flush()
if not dbcol:
dbcol = TableColumn(column_name=col.name, type=datatype)
dbcol.groupby = dbcol.is_string
dbcol.filterable = dbcol.is_string
dbcol.sum = dbcol.is_num
dbcol.avg = dbcol.is_num
dbcol.is_dttm = dbcol.is_time
db.session.merge(self)
self.columns.append(dbcol)
if not any_date_col and dbcol.is_time:
any_date_col = col.name
quoted = "{}".format(
column(dbcol.column_name).compile(dialect=db.engine.dialect))
if dbcol.sum:
metrics.append(M(
metric_name='sum__' + dbcol.column_name,
verbose_name='sum__' + dbcol.column_name,
metric_type='sum',
expression="SUM({})".format(quoted)
))
if dbcol.avg:
metrics.append(M(
metric_name='avg__' + dbcol.column_name,
verbose_name='avg__' + dbcol.column_name,
metric_type='avg',
expression="AVG({})".format(quoted)
))
if dbcol.max:
metrics.append(M(
metric_name='max__' + dbcol.column_name,
verbose_name='max__' + dbcol.column_name,
metric_type='max',
expression="MAX({})".format(quoted)
))
if dbcol.min:
metrics.append(M(
metric_name='min__' + dbcol.column_name,
verbose_name='min__' + dbcol.column_name,
metric_type='min',
expression="MIN({})".format(quoted)
))
if dbcol.count_distinct:
metrics.append(M(
metric_name='count_distinct__' + dbcol.column_name,
verbose_name='count_distinct__' + dbcol.column_name,
metric_type='count_distinct',
expression="COUNT(DISTINCT {})".format(quoted)
))
dbcol.type = datatype
db.session.merge(self)
db.session.commit()
metrics.append(M(
metric_name='count',
verbose_name='COUNT(*)',
metric_type='count',
expression="COUNT(*)"
))
for metric in metrics:
m = (
db.session.query(M)
.filter(M.metric_name == metric.metric_name)
.filter(M.table_id == self.id)
.first()
)
metric.table_id = self.id
if not m:
db.session.add(metric)
db.session.commit()
if not self.main_dttm_col:
self.main_dttm_col = any_date_col
@classmethod
def import_obj(cls, i_datasource, import_time=None):
"""Imports the datasource from the object to the database.
Metrics, columns and the datasource will be overridden if they already exist.
This function can be used to import/export dashboards between multiple
superset instances. Audit metadata isn't copied over.
"""
def lookup_sqlatable(table):
return db.session.query(SqlaTable).join(Database).filter(
SqlaTable.table_name == table.table_name,
SqlaTable.schema == table.schema,
Database.id == table.database_id,
).first()
def lookup_database(table):
return db.session.query(Database).filter_by(
database_name=table.params_dict['database_name']).one()
return import_util.import_datasource(
db.session, i_datasource, lookup_database, lookup_sqlatable,
import_time)
sa.event.listen(SqlaTable, 'after_insert', set_perm)
sa.event.listen(SqlaTable, 'after_update', set_perm)
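For reference, the query_obj dict that query() forwards to get_query_str mirrors the keyword arguments above; a minimal hand-built example ('ds', 'name' and 'sum__num' are assumed column and metric names on the target table):

# Minimal sketch of a query_obj accepted by SqlaTable.query()
from datetime import datetime

query_obj = {
    'groupby': ['name'],
    'metrics': ['sum__num'],
    'granularity': 'ds',
    'from_dttm': datetime(2017, 1, 1),
    'to_dttm': datetime(2017, 3, 10),
    'filter': [{'col': 'name', 'op': 'in', 'val': ['Maxime']}],
    'is_timeseries': True,
    'timeseries_limit': 15,
    'row_limit': 5000,
    'extras': {'where': '', 'having': '', 'time_grain_sqla': None},
}
# result = some_sqla_table.query(query_obj)  # returns a QueryResult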


@@ -0,0 +1,213 @@
import logging
from flask import Markup, flash
from flask_appbuilder import CompactCRUDMixin
from flask_appbuilder.models.sqla.interface import SQLAInterface
import sqlalchemy as sa
from flask_babel import lazy_gettext as _
from flask_babel import gettext as __
from superset import appbuilder, db, utils, security, sm
from superset.views.base import (
SupersetModelView, ListWidgetWithCheckboxes, DeleteMixin, DatasourceFilter,
get_datasource_exist_error_mgs,
)
from . import models
class TableColumnInlineView(CompactCRUDMixin, SupersetModelView): # noqa
datamodel = SQLAInterface(models.TableColumn)
can_delete = False
list_widget = ListWidgetWithCheckboxes
edit_columns = [
'column_name', 'verbose_name', 'description', 'groupby', 'filterable',
'table', 'count_distinct', 'sum', 'min', 'max', 'expression',
'is_dttm', 'python_date_format', 'database_expression']
add_columns = edit_columns
list_columns = [
'column_name', 'type', 'groupby', 'filterable', 'count_distinct',
'sum', 'min', 'max', 'is_dttm']
page_size = 500
description_columns = {
'is_dttm': (_(
"Whether to make this column available as a "
"[Time Granularity] option, column has to be DATETIME or "
"DATETIME-like")),
'expression': utils.markdown(
"a valid SQL expression as supported by the underlying backend. "
"Example: `substr(name, 1, 1)`", True),
'python_date_format': utils.markdown(Markup(
"The pattern of timestamp format, use "
"<a href='https://docs.python.org/2/library/"
"datetime.html#strftime-strptime-behavior'>"
"python datetime string pattern</a> "
"expression. If time is stored in epoch "
"format, put `epoch_s` or `epoch_ms`. Leave `Database Expression` "
"below empty if timestamp is stored in "
"String or Integer(epoch) type"), True),
'database_expression': utils.markdown(
"The database expression to cast internal datetime "
"constants to database date/timestamp type according to the DBAPI. "
"The expression should follow the pattern of "
"%Y-%m-%d %H:%M:%S, based on different DBAPI. "
"The string should be a python string formatter \n"
"`Ex: TO_DATE('{}', 'YYYY-MM-DD HH24:MI:SS')` for Oracle. "
"Superset uses a default expression based on the DB URI if this "
"field is blank.", True),
}
label_columns = {
'column_name': _("Column"),
'verbose_name': _("Verbose Name"),
'description': _("Description"),
'groupby': _("Groupable"),
'filterable': _("Filterable"),
'table': _("Table"),
'count_distinct': _("Count Distinct"),
'sum': _("Sum"),
'min': _("Min"),
'max': _("Max"),
'expression': _("Expression"),
'is_dttm': _("Is temporal"),
'python_date_format': _("Datetime Format"),
'database_expression': _("Database Expression")
}
appbuilder.add_view_no_menu(TableColumnInlineView)
class SqlMetricInlineView(CompactCRUDMixin, SupersetModelView): # noqa
datamodel = SQLAInterface(models.SqlMetric)
list_columns = ['metric_name', 'verbose_name', 'metric_type']
edit_columns = [
'metric_name', 'description', 'verbose_name', 'metric_type',
'expression', 'table', 'd3format', 'is_restricted']
description_columns = {
'expression': utils.markdown(
"a valid SQL expression as supported by the underlying backend. "
"Example: `count(DISTINCT userid)`", True),
'is_restricted': _("Whether the access to this metric is restricted "
"to certain roles. Only roles with the permission "
"'metric access on XXX (the name of this metric)' "
"are allowed to access this metric"),
'd3format': utils.markdown(
"d3 formatting string as defined [here]"
"(https://github.com/d3/d3-format/blob/master/README.md#format). "
"For instance, this default formatting applies in the Table "
"visualization and allows different metrics to use different "
"formats", True
),
}
add_columns = edit_columns
page_size = 500
label_columns = {
'metric_name': _("Metric"),
'description': _("Description"),
'verbose_name': _("Verbose Name"),
'metric_type': _("Type"),
'expression': _("SQL Expression"),
'table': _("Table"),
}
def post_add(self, metric):
if metric.is_restricted:
security.merge_perm(sm, 'metric_access', metric.get_perm())
def post_update(self, metric):
if metric.is_restricted:
security.merge_perm(sm, 'metric_access', metric.get_perm())
appbuilder.add_view_no_menu(SqlMetricInlineView)
class TableModelView(SupersetModelView, DeleteMixin): # noqa
datamodel = SQLAInterface(models.SqlaTable)
list_columns = [
'link', 'database', 'is_featured',
'changed_by_', 'changed_on_']
order_columns = [
'link', 'database', 'is_featured', 'changed_on_']
add_columns = ['database', 'schema', 'table_name']
edit_columns = [
'table_name', 'sql', 'is_featured', 'filter_select_enabled',
'database', 'schema',
'description', 'owner',
'main_dttm_col', 'default_endpoint', 'offset', 'cache_timeout']
show_columns = edit_columns + ['perm']
related_views = [TableColumnInlineView, SqlMetricInlineView]
base_order = ('changed_on', 'desc')
description_columns = {
'offset': _("Timezone offset (in hours) for this datasource"),
'table_name': _(
"Name of the table that exists in the source database"),
'schema': _(
"Schema, as used only in some databases like Postgres, Redshift "
"and DB2"),
'description': Markup(
"Supports <a href='https://daringfireball.net/projects/markdown/'>"
"markdown</a>"),
'sql': _(
"This field acts as a Superset view, meaning that Superset will "
"run a query against this string as a subquery."
),
}
base_filters = [['id', DatasourceFilter, lambda: []]]
label_columns = {
'link': _("Table"),
'changed_by_': _("Changed By"),
'database': _("Database"),
'changed_on_': _("Last Changed"),
'is_featured': _("Is Featured"),
'filter_select_enabled': _("Enable Filter Select"),
'schema': _("Schema"),
'default_endpoint': _("Default Endpoint"),
'offset': _("Offset"),
'cache_timeout': _("Cache Timeout"),
}
def pre_add(self, table):
number_of_existing_tables = db.session.query(
sa.func.count('*')).filter(
models.SqlaTable.table_name == table.table_name,
models.SqlaTable.schema == table.schema,
models.SqlaTable.database_id == table.database.id
).scalar()
# table object is already added to the session
if number_of_existing_tables > 1:
raise Exception(get_datasource_exist_error_mgs(table.full_name))
# Fail before adding if the table can't be found
try:
table.get_sqla_table_object()
except Exception as e:
logging.exception(e)
raise Exception(
"Table [{}] could not be found, "
"please double check your "
"database connection, schema, and "
"table name".format(table.name))
def post_add(self, table):
table.fetch_metadata()
security.merge_perm(sm, 'datasource_access', table.get_perm())
if table.schema:
security.merge_perm(sm, 'schema_access', table.schema_perm)
flash(_(
"The table was created. As part of this two phase configuration "
"process, you should now click the edit button by "
"the new table to configure it."),
"info")
def post_update(self, table):
self.post_add(table)
appbuilder.add_view(
TableModelView,
"Tables",
label=__("Tables"),
category="Sources",
category_label=__("Sources"),
icon='fa-table',)
appbuilder.add_separator("Sources")
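The same two-phase flow (create the table, then fetch metadata and grant permissions) can also be driven from code, roughly the way the example loader in superset/data.py seeds tables; a sketch assuming an app context, an existing Database named 'main' and a physical table named 'my_table':

# Sketch mirroring TableModelView.post_add above -- assumed names, not part of this commit
from superset import db, security, sm
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database

database = db.session.query(Database).filter_by(database_name='main').one()
tbl = SqlaTable(table_name='my_table', database=database)
db.session.add(tbl)
db.session.commit()
tbl.fetch_metadata()  # creates TableColumn rows and default metrics
security.merge_perm(sm, 'datasource_access', tbl.get_perm())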


@@ -14,15 +14,19 @@ import random
import pandas as pd
from sqlalchemy import String, DateTime, Date, Float, BigInteger
from superset import app, db, models, utils
from superset import app, db, utils
from superset.models import core as models
from superset.security import get_or_create_main_db
from superset.connectors.connector_registry import ConnectorRegistry
# Shortcuts
DB = models.Database
Slice = models.Slice
TBL = models.SqlaTable
Dash = models.Dashboard
TBL = ConnectorRegistry.sources['table']
config = app.config
DATA_FOLDER = os.path.join(config.get("BASE_DIR"), 'data')


@@ -0,0 +1,22 @@
"""adding verbose_name to druid column
Revision ID: b318dfe5fb6c
Revises: d6db5a5cdb5d
Create Date: 2017-03-08 11:48:10.835741
"""
# revision identifiers, used by Alembic.
revision = 'b318dfe5fb6c'
down_revision = 'd6db5a5cdb5d'
from alembic import op
import sqlalchemy as sa
def upgrade():
op.add_column('columns', sa.Column('verbose_name', sa.String(length=1024), nullable=True))
def downgrade():
op.drop_column('columns', 'verbose_name')

File diff suppressed because it is too large


@@ -0,0 +1 @@
from . import core # noqa

superset/models/core.py Normal file

@@ -0,0 +1,951 @@
"""A collection of ORM sqlalchemy models for Superset"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import functools
import json
import logging
import numpy
import pickle
import re
import textwrap
from future.standard_library import install_aliases
from copy import copy
from datetime import datetime, date
import pandas as pd
import sqlalchemy as sqla
from sqlalchemy.engine.url import make_url
from sqlalchemy.orm import subqueryload
from flask import escape, g, Markup, request
from flask_appbuilder import Model
from flask_appbuilder.models.decorators import renders
from sqlalchemy import (
Column, Integer, String, ForeignKey, Text, Boolean,
DateTime, Date, Table, Numeric,
create_engine, MetaData, select
)
from sqlalchemy.orm import backref, relationship
from sqlalchemy.orm.session import make_transient
from sqlalchemy.sql import text
from sqlalchemy.sql.expression import TextAsFrom
from sqlalchemy_utils import EncryptedType
from superset import app, db, db_engine_specs, utils, sm
from superset.connectors.connector_registry import ConnectorRegistry
from superset.viz import viz_types
from superset.utils import QueryStatus
from superset.models.helpers import AuditMixinNullable, ImportMixin, set_perm
install_aliases()
from urllib import parse # noqa
config = app.config
def set_related_perm(mapper, connection, target): # noqa
src_class = target.cls_model
id_ = target.datasource_id
ds = db.session.query(src_class).filter_by(id=int(id_)).first()
target.perm = ds.perm
class Url(Model, AuditMixinNullable):
"""Used for the short url feature"""
__tablename__ = 'url'
id = Column(Integer, primary_key=True)
url = Column(Text)
class KeyValue(Model):
"""Used for any type of key-value store"""
__tablename__ = 'keyvalue'
id = Column(Integer, primary_key=True)
value = Column(Text, nullable=False)
class CssTemplate(Model, AuditMixinNullable):
"""CSS templates for dashboards"""
__tablename__ = 'css_templates'
id = Column(Integer, primary_key=True)
template_name = Column(String(250))
css = Column(Text, default='')
slice_user = Table('slice_user', Model.metadata,
Column('id', Integer, primary_key=True),
Column('user_id', Integer, ForeignKey('ab_user.id')),
Column('slice_id', Integer, ForeignKey('slices.id'))
)
class Slice(Model, AuditMixinNullable, ImportMixin):
"""A slice is essentially a report or a view on data"""
__tablename__ = 'slices'
id = Column(Integer, primary_key=True)
slice_name = Column(String(250))
datasource_id = Column(Integer)
datasource_type = Column(String(200))
datasource_name = Column(String(2000))
viz_type = Column(String(250))
params = Column(Text)
description = Column(Text)
cache_timeout = Column(Integer)
perm = Column(String(1000))
owners = relationship("User", secondary=slice_user)
export_fields = ('slice_name', 'datasource_type', 'datasource_name',
'viz_type', 'params', 'cache_timeout')
def __repr__(self):
return self.slice_name
@property
def cls_model(self):
return ConnectorRegistry.sources[self.datasource_type]
@property
def datasource(self):
return self.get_datasource
@datasource.getter
@utils.memoized
def get_datasource(self):
ds = db.session.query(
self.cls_model).filter_by(
id=self.datasource_id).first()
return ds
@renders('datasource_name')
def datasource_link(self):
datasource = self.datasource
if datasource:
return self.datasource.link
@property
def datasource_edit_url(self):
return self.datasource.url
@property
@utils.memoized
def viz(self):
d = json.loads(self.params)
viz_class = viz_types[self.viz_type]
return viz_class(self.datasource, form_data=d)
@property
def description_markeddown(self):
return utils.markdown(self.description)
@property
def data(self):
"""Data used to render slice in templates"""
d = {}
self.token = ''
try:
d = self.viz.data
self.token = d.get('token')
except Exception as e:
logging.exception(e)
d['error'] = str(e)
return {
'datasource': self.datasource_name,
'description': self.description,
'description_markeddown': self.description_markeddown,
'edit_url': self.edit_url,
'form_data': self.form_data,
'slice_id': self.id,
'slice_name': self.slice_name,
'slice_url': self.slice_url,
}
@property
def json_data(self):
return json.dumps(self.data)
@property
def form_data(self):
form_data = json.loads(self.params)
form_data['slice_id'] = self.id
form_data['viz_type'] = self.viz_type
form_data['datasource'] = (
str(self.datasource_id) + '__' + self.datasource_type)
return form_data
@property
def slice_url(self):
"""Defines the url to access the slice"""
return (
"/superset/explore/{obj.datasource_type}/"
"{obj.datasource_id}/?form_data={params}".format(
obj=self, params=parse.quote(json.dumps(self.form_data))))
@property
def slice_id_url(self):
return (
"/superset/{slc.datasource_type}/{slc.datasource_id}/{slc.id}/"
).format(slc=self)
@property
def edit_url(self):
return "/slicemodelview/edit/{}".format(self.id)
@property
def slice_link(self):
url = self.slice_url
name = escape(self.slice_name)
return Markup('<a href="{url}">{name}</a>'.format(**locals()))
def get_viz(self, url_params_multidict=None):
"""Creates :py:class:viz.BaseViz object from the url_params_multidict.
:param werkzeug.datastructures.MultiDict url_params_multidict:
Contains the visualization params, they override the self.params
stored in the database
:return: object of the 'viz_type' type that is taken from the
url_params_multidict or self.params.
:rtype: :py:class:viz.BaseViz
"""
slice_params = json.loads(self.params)
slice_params['slice_id'] = self.id
slice_params['json'] = "false"
slice_params['slice_name'] = self.slice_name
slice_params['viz_type'] = self.viz_type if self.viz_type else "table"
return viz_types[slice_params.get('viz_type')](
self.datasource,
form_data=slice_params,
slice_=self
)
@classmethod
def import_obj(cls, slc_to_import, import_time=None):
"""Inserts or overrides slc in the database.
remote_id and import_time fields in params_dict are set to track the
slice origin and ensure correct overrides for multiple imports.
Slice.perm is used to find the datasources and connect them.
"""
session = db.session
make_transient(slc_to_import)
slc_to_import.dashboards = []
slc_to_import.alter_params(
remote_id=slc_to_import.id, import_time=import_time)
# find if the slice was already imported
slc_to_override = None
for slc in session.query(Slice).all():
if ('remote_id' in slc.params_dict and
slc.params_dict['remote_id'] == slc_to_import.id):
slc_to_override = slc
slc_to_import = slc_to_import.copy()
params = slc_to_import.params_dict
slc_to_import.datasource_id = ConnectorRegistry.get_datasource_by_name(
session, slc_to_import.datasource_type, params['datasource_name'],
params['schema'], params['database_name']).id
if slc_to_override:
slc_to_override.override(slc_to_import)
session.flush()
return slc_to_override.id
session.add(slc_to_import)
logging.info('Final slice: {}'.format(slc_to_import.to_json()))
session.flush()
return slc_to_import.id
sqla.event.listen(Slice, 'before_insert', set_related_perm)
sqla.event.listen(Slice, 'before_update', set_related_perm)
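The cls_model / get_datasource plumbing above is how a chart finds its datasource without a hard dependency on any one connector; a short sketch (assumes an app context and at least one saved Slice):

# Sketch: resolving a Slice's datasource through the registry
from superset import db
from superset.connectors.connector_registry import ConnectorRegistry
from superset.models.core import Slice

slc = db.session.query(Slice).first()
model = ConnectorRegistry.sources[slc.datasource_type]  # e.g. SqlaTable or DruidDatasource
ds = db.session.query(model).filter_by(id=slc.datasource_id).first()
# slc.datasource performs the same lookup via the memoized get_datasource property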
dashboard_slices = Table(
'dashboard_slices', Model.metadata,
Column('id', Integer, primary_key=True),
Column('dashboard_id', Integer, ForeignKey('dashboards.id')),
Column('slice_id', Integer, ForeignKey('slices.id')),
)
dashboard_user = Table(
'dashboard_user', Model.metadata,
Column('id', Integer, primary_key=True),
Column('user_id', Integer, ForeignKey('ab_user.id')),
Column('dashboard_id', Integer, ForeignKey('dashboards.id'))
)
class Dashboard(Model, AuditMixinNullable, ImportMixin):
"""The dashboard object!"""
__tablename__ = 'dashboards'
id = Column(Integer, primary_key=True)
dashboard_title = Column(String(500))
position_json = Column(Text)
description = Column(Text)
css = Column(Text)
json_metadata = Column(Text)
slug = Column(String(255), unique=True)
slices = relationship(
'Slice', secondary=dashboard_slices, backref='dashboards')
owners = relationship("User", secondary=dashboard_user)
export_fields = ('dashboard_title', 'position_json', 'json_metadata',
'description', 'css', 'slug')
def __repr__(self):
return self.dashboard_title
@property
def table_names(self):
return ", ".join(
{"{}".format(s.datasource.name) for s in self.slices})
@property
def url(self):
return "/superset/dashboard/{}/".format(self.slug or self.id)
@property
def datasources(self):
return {slc.datasource for slc in self.slices}
@property
def sqla_metadata(self):
metadata = MetaData(bind=self.get_sqla_engine())
return metadata.reflect()
def dashboard_link(self):
title = escape(self.dashboard_title)
return Markup(
'<a href="{self.url}">{title}</a>'.format(**locals()))
@property
def json_data(self):
positions = self.position_json
if positions:
positions = json.loads(positions)
d = {
'id': self.id,
'metadata': self.params_dict,
'css': self.css,
'dashboard_title': self.dashboard_title,
'slug': self.slug,
'slices': [slc.data for slc in self.slices],
'position_json': positions,
}
return json.dumps(d)
@property
def params(self):
return self.json_metadata
@params.setter
def params(self, value):
self.json_metadata = value
@property
def position_array(self):
if self.position_json:
return json.loads(self.position_json)
return []
@classmethod
def import_obj(cls, dashboard_to_import, import_time=None):
"""Imports the dashboard from the object to the database.
Once dashboard is imported, json_metadata field is extended and stores
remote_id and import_time. It helps to decide if the dashboard has to
be overridden or just copied over. Slices that belong to this
dashboard will be wired to existing tables. This function can be used
to import/export dashboards between multiple superset instances.
Audit metadata isn't copied over.
"""
def alter_positions(dashboard, old_to_new_slc_id_dict):
""" Updates slice_ids in the position json.
Sample position json:
[{
"col": 5,
"row": 10,
"size_x": 4,
"size_y": 2,
"slice_id": "3610"
}]
"""
position_array = dashboard.position_array
for position in position_array:
if 'slice_id' not in position:
continue
old_slice_id = int(position['slice_id'])
if old_slice_id in old_to_new_slc_id_dict:
position['slice_id'] = '{}'.format(
old_to_new_slc_id_dict[old_slice_id])
dashboard.position_json = json.dumps(position_array)
logging.info('Started import of the dashboard: {}'
.format(dashboard_to_import.to_json()))
session = db.session
logging.info('Dashboard has {} slices'
.format(len(dashboard_to_import.slices)))
# copy slices object as Slice.import_obj will mutate the slice
# and will remove the existing dashboard - slice association
slices = copy(dashboard_to_import.slices)
old_to_new_slc_id_dict = {}
new_filter_immune_slices = []
new_expanded_slices = {}
i_params_dict = dashboard_to_import.params_dict
for slc in slices:
logging.info('Importing slice {} from the dashboard: {}'.format(
slc.to_json(), dashboard_to_import.dashboard_title))
new_slc_id = Slice.import_obj(slc, import_time=import_time)
old_to_new_slc_id_dict[slc.id] = new_slc_id
# update json metadata that deals with slice ids
new_slc_id_str = '{}'.format(new_slc_id)
old_slc_id_str = '{}'.format(slc.id)
if ('filter_immune_slices' in i_params_dict and
old_slc_id_str in i_params_dict['filter_immune_slices']):
new_filter_immune_slices.append(new_slc_id_str)
if ('expanded_slices' in i_params_dict and
old_slc_id_str in i_params_dict['expanded_slices']):
new_expanded_slices[new_slc_id_str] = (
i_params_dict['expanded_slices'][old_slc_id_str])
# override the dashboard
existing_dashboard = None
for dash in session.query(Dashboard).all():
if ('remote_id' in dash.params_dict and
dash.params_dict['remote_id'] ==
dashboard_to_import.id):
existing_dashboard = dash
dashboard_to_import.id = None
alter_positions(dashboard_to_import, old_to_new_slc_id_dict)
dashboard_to_import.alter_params(import_time=import_time)
if new_expanded_slices:
dashboard_to_import.alter_params(
expanded_slices=new_expanded_slices)
if new_filter_immune_slices:
dashboard_to_import.alter_params(
filter_immune_slices=new_filter_immune_slices)
new_slices = session.query(Slice).filter(
Slice.id.in_(old_to_new_slc_id_dict.values())).all()
if existing_dashboard:
existing_dashboard.override(dashboard_to_import)
existing_dashboard.slices = new_slices
session.flush()
return existing_dashboard.id
else:
# session.add(dashboard_to_import) causes sqlalchemy failures
# related to the attached users / slices. Creating a new object
# avoids conflicts in the sqlalchemy state.
copied_dash = dashboard_to_import.copy()
copied_dash.slices = new_slices
session.add(copied_dash)
session.flush()
return copied_dash.id
@classmethod
def export_dashboards(cls, dashboard_ids):
copied_dashboards = []
datasource_ids = set()
for dashboard_id in dashboard_ids:
# make sure that dashboard_id is an integer
dashboard_id = int(dashboard_id)
copied_dashboard = (
db.session.query(Dashboard)
.options(subqueryload(Dashboard.slices))
.filter_by(id=dashboard_id).first()
)
make_transient(copied_dashboard)
for slc in copied_dashboard.slices:
datasource_ids.add((slc.datasource_id, slc.datasource_type))
# add extra params for the import
slc.alter_params(
remote_id=slc.id,
datasource_name=slc.datasource.name,
schema=slc.datasource.name,
database_name=slc.datasource.database.name,
)
copied_dashboard.alter_params(remote_id=dashboard_id)
copied_dashboards.append(copied_dashboard)
eager_datasources = []
for dashboard_id, dashboard_type in datasource_ids:
eager_datasource = ConnectorRegistry.get_eager_datasource(
db.session, dashboard_type, dashboard_id)
eager_datasource.alter_params(
remote_id=eager_datasource.id,
database_name=eager_datasource.database.name,
)
make_transient(eager_datasource)
eager_datasources.append(eager_datasource)
return pickle.dumps({
'dashboards': copied_dashboards,
'datasources': eager_datasources,
})
class Database(Model, AuditMixinNullable):
"""An ORM object that stores Database related information"""
__tablename__ = 'dbs'
type = "table"
id = Column(Integer, primary_key=True)
database_name = Column(String(250), unique=True)
sqlalchemy_uri = Column(String(1024))
password = Column(EncryptedType(String(1024), config.get('SECRET_KEY')))
cache_timeout = Column(Integer)
select_as_create_table_as = Column(Boolean, default=False)
expose_in_sqllab = Column(Boolean, default=False)
allow_run_sync = Column(Boolean, default=True)
allow_run_async = Column(Boolean, default=False)
allow_ctas = Column(Boolean, default=False)
allow_dml = Column(Boolean, default=False)
force_ctas_schema = Column(String(250))
extra = Column(Text, default=textwrap.dedent("""\
{
"metadata_params": {},
"engine_params": {}
}
"""))
perm = Column(String(1000))
def __repr__(self):
return self.database_name
@property
def name(self):
return self.database_name
@property
def backend(self):
url = make_url(self.sqlalchemy_uri_decrypted)
return url.get_backend_name()
def set_sqlalchemy_uri(self, uri):
password_mask = "X" * 10
conn = sqla.engine.url.make_url(uri)
if conn.password != password_mask:
# do not over-write the password with the password mask
self.password = conn.password
conn.password = password_mask if conn.password else None
self.sqlalchemy_uri = str(conn) # hides the password
def get_sqla_engine(self, schema=None):
extra = self.get_extra()
url = make_url(self.sqlalchemy_uri_decrypted)
params = extra.get('engine_params', {})
url.database = self.get_database_for_various_backend(url, schema)
return create_engine(url, **params)
def get_database_for_various_backend(self, uri, default_database=None):
database = uri.database
if self.backend == 'presto' and default_database:
if '/' in database:
database = database.split('/')[0] + '/' + default_database
else:
database += '/' + default_database
# Postgres and Redshift use the concept of schema as a logical entity
# on top of the database, so the database should not be changed
# even if passed default_database
elif self.backend == 'redshift' or self.backend == 'postgresql':
pass
elif default_database:
database = default_database
return database
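Roughly, the branching above plays out as follows; the sketch assumes a hypothetical `some_presto_db` Database whose URI database is 'hive/default':

url = make_url(some_presto_db.sqlalchemy_uri_decrypted)
url.database = some_presto_db.get_database_for_various_backend(url, 'web')
# presto:              'hive/default' + 'web' -> 'hive/web'
# postgresql/redshift: the database part is left untouched (schema is purely logical)
# other backends:      the passed default_database replaces the database part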
def get_reserved_words(self):
return self.get_sqla_engine().dialect.preparer.reserved_words
def get_quoter(self):
return self.get_sqla_engine().dialect.identifier_preparer.quote
def get_df(self, sql, schema):
sql = sql.strip().strip(';')
eng = self.get_sqla_engine(schema=schema)
cur = eng.execute(sql, schema=schema)
cols = [col[0] for col in cur.cursor.description]
df = pd.DataFrame(cur.fetchall(), columns=cols)
def needs_conversion(df_series):
if df_series.empty:
return False
for df_type in [list, dict]:
if isinstance(df_series[0], df_type):
return True
return False
for k, v in df.dtypes.iteritems():
if v.type == numpy.object_ and needs_conversion(df[k]):
df[k] = df[k].apply(utils.json_dumps_w_dates)
return df
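A small usage sketch for get_df; the query text and `some_database` are illustrative:

df = some_database.get_df("SELECT name, num FROM birth_names LIMIT 10", None)
# object columns whose first value is a list or dict are re-serialized with
# utils.json_dumps_w_dates so the frame can be serialized downstream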
def compile_sqla_query(self, qry, schema=None):
eng = self.get_sqla_engine(schema=schema)
compiled = qry.compile(eng, compile_kwargs={"literal_binds": True})
return '{}'.format(compiled)
def select_star(
self, table_name, schema=None, limit=100, show_cols=False,
indent=True):
"""Generates a ``select *`` statement in the proper dialect"""
return self.db_engine_spec.select_star(
self, table_name, schema=schema, limit=limit, show_cols=show_cols,
indent=indent)
def wrap_sql_limit(self, sql, limit=1000):
qry = (
select('*')
.select_from(
TextAsFrom(text(sql), ['*'])
.alias('inner_qry')
).limit(limit)
)
return self.compile_sqla_query(qry)
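The wrapper above compiles arbitrary SQL into a limited subquery; roughly, and depending on the dialect:

limited = some_database.wrap_sql_limit("SELECT name, num FROM birth_names", limit=100)
# -> SELECT * FROM (SELECT name, num FROM birth_names) AS inner_qry LIMIT 100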
def safe_sqlalchemy_uri(self):
return self.sqlalchemy_uri
@property
def inspector(self):
engine = self.get_sqla_engine()
return sqla.inspect(engine)
def all_table_names(self, schema=None, force=False):
if not schema:
tables_dict = self.db_engine_spec.fetch_result_sets(
self, 'table', force=force)
return tables_dict.get("", [])
return sorted(self.inspector.get_table_names(schema))
def all_view_names(self, schema=None, force=False):
if not schema:
views_dict = self.db_engine_spec.fetch_result_sets(
self, 'view', force=force)
return views_dict.get("", [])
views = []
try:
views = self.inspector.get_view_names(schema)
except Exception:
pass
return views
def all_schema_names(self):
return sorted(self.inspector.get_schema_names())
@property
def db_engine_spec(self):
engine_name = self.get_sqla_engine().name or 'base'
return db_engine_specs.engines.get(
engine_name, db_engine_specs.BaseEngineSpec)
def grains(self):
"""Defines time granularity database-specific expressions.
The idea here is to make it easy for users to change the time grain
from a datetime (maybe the source grain is arbitrary timestamps, daily
or 5-minute increments) to another, "truncated" datetime. Since
each database has slightly different but similar datetime functions,
this allows a mapping between database engines and actual functions.
"""
return self.db_engine_spec.time_grains
def grains_dict(self):
return {grain.name: grain for grain in self.grains()}
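As a hedged illustration, a Postgres-backed Database might expose grains roughly like this; the exact grain names and SQL templates come from db_engine_specs and may differ:

grains = some_database.grains_dict()
# grains['day'].function   might render as  "DATE_TRUNC('day', {col})"
# grains['week'].function  might render as  "DATE_TRUNC('week', {col})"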
def get_extra(self):
extra = {}
if self.extra:
try:
extra = json.loads(self.extra)
except Exception as e:
logging.error(e)
return extra
def get_table(self, table_name, schema=None):
extra = self.get_extra()
meta = MetaData(**extra.get('metadata_params', {}))
return Table(
table_name, meta,
schema=schema or None,
autoload=True,
autoload_with=self.get_sqla_engine())
def get_columns(self, table_name, schema=None):
return self.inspector.get_columns(table_name, schema)
def get_indexes(self, table_name, schema=None):
return self.inspector.get_indexes(table_name, schema)
def get_pk_constraint(self, table_name, schema=None):
return self.inspector.get_pk_constraint(table_name, schema)
def get_foreign_keys(self, table_name, schema=None):
return self.inspector.get_foreign_keys(table_name, schema)
@property
def sqlalchemy_uri_decrypted(self):
conn = sqla.engine.url.make_url(self.sqlalchemy_uri)
conn.password = self.password
return str(conn)
@property
def sql_url(self):
return '/superset/sql/{}/'.format(self.id)
def get_perm(self):
return (
"[{obj.database_name}].(id:{obj.id})").format(obj=self)
sqla.event.listen(Database, 'after_insert', set_perm)
sqla.event.listen(Database, 'after_update', set_perm)
class Log(Model):
"""ORM object used to log Superset actions to the database"""
__tablename__ = 'logs'
id = Column(Integer, primary_key=True)
action = Column(String(512))
user_id = Column(Integer, ForeignKey('ab_user.id'))
dashboard_id = Column(Integer)
slice_id = Column(Integer)
json = Column(Text)
user = relationship('User', backref='logs', foreign_keys=[user_id])
dttm = Column(DateTime, default=datetime.utcnow)
dt = Column(Date, default=date.today())
duration_ms = Column(Integer)
referrer = Column(String(1024))
@classmethod
def log_this(cls, f):
"""Decorator to log user actions"""
@functools.wraps(f)
def wrapper(*args, **kwargs):
start_dttm = datetime.now()
user_id = None
if g.user:
user_id = g.user.get_id()
d = request.args.to_dict()
post_data = request.form or {}
d.update(post_data)
d.update(kwargs)
slice_id = d.get('slice_id', 0)
try:
slice_id = int(slice_id) if slice_id else 0
except ValueError:
slice_id = 0
params = ""
try:
params = json.dumps(d)
except:
pass
value = f(*args, **kwargs)
sesh = db.session()
log = cls(
action=f.__name__,
json=params,
dashboard_id=d.get('dashboard_id') or None,
slice_id=slice_id,
duration_ms=(
datetime.now() - start_dttm).total_seconds() * 1000,
referrer=request.referrer[:1000] if request.referrer else None,
user_id=user_id)
sesh.add(log)
sesh.commit()
return value
return wrapper
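A hedged usage sketch of the decorator above; this commit binds it as `log_this = models.Log.log_this` in views/core.py and applies it to view methods, which run inside a Flask request context:

log_this = Log.log_this

@log_this
def my_view(**kwargs):  # illustrative endpoint, not part of this commit
    return "OK"         # each call commits a Log row with action='my_view'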
class FavStar(Model):
__tablename__ = 'favstar'
id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey('ab_user.id'))
class_name = Column(String(50))
obj_id = Column(Integer)
dttm = Column(DateTime, default=datetime.utcnow)
class Query(Model):
"""ORM model for SQL query"""
__tablename__ = 'query'
id = Column(Integer, primary_key=True)
client_id = Column(String(11), unique=True, nullable=False)
database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False)
# Store the tmp table into the DB only if the user asks for it.
tmp_table_name = Column(String(256))
user_id = Column(
Integer, ForeignKey('ab_user.id'), nullable=True)
status = Column(String(16), default=QueryStatus.PENDING)
tab_name = Column(String(256))
sql_editor_id = Column(String(256))
schema = Column(String(256))
sql = Column(Text)
# Query to retrieve the results,
# used only in case of select_as_cta_used is true.
select_sql = Column(Text)
executed_sql = Column(Text)
# Could be configured in the superset config.
limit = Column(Integer)
limit_used = Column(Boolean, default=False)
limit_reached = Column(Boolean, default=False)
select_as_cta = Column(Boolean)
select_as_cta_used = Column(Boolean, default=False)
progress = Column(Integer, default=0) # 1..100
# # of rows in the result set or rows modified.
rows = Column(Integer)
error_message = Column(Text)
# key used to store the results in the results backend
results_key = Column(String(64), index=True)
# Using Numeric in place of DateTime for sub-second precision
# stored as seconds since epoch, allowing for milliseconds
start_time = Column(Numeric(precision=3))
end_time = Column(Numeric(precision=3))
changed_on = Column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=True)
database = relationship(
'Database',
foreign_keys=[database_id],
backref=backref('queries', cascade='all, delete-orphan')
)
user = relationship(
'User',
backref=backref('queries', cascade='all, delete-orphan'),
foreign_keys=[user_id])
__table_args__ = (
sqla.Index('ti_user_id_changed_on', user_id, changed_on),
)
@property
def limit_reached(self):
return self.rows == self.limit if self.limit_used else False
def to_dict(self):
return {
'changedOn': self.changed_on,
'changed_on': self.changed_on.isoformat(),
'dbId': self.database_id,
'db': self.database.database_name,
'endDttm': self.end_time,
'errorMessage': self.error_message,
'executedSql': self.executed_sql,
'id': self.client_id,
'limit': self.limit,
'progress': self.progress,
'rows': self.rows,
'schema': self.schema,
'ctas': self.select_as_cta,
'serverId': self.id,
'sql': self.sql,
'sqlEditorId': self.sql_editor_id,
'startDttm': self.start_time,
'state': self.status.lower(),
'tab': self.tab_name,
'tempTable': self.tmp_table_name,
'userId': self.user_id,
'user': self.user.username,
'limit_reached': self.limit_reached,
'resultsKey': self.results_key,
}
@property
def name(self):
ts = datetime.now().isoformat()
ts = ts.replace('-', '').replace(':', '').split('.')[0]
tab = self.tab_name.replace(' ', '_').lower() if self.tab_name else 'notab'
tab = re.sub(r'\W+', '', tab)
return "sqllab_{tab}_{ts}".format(**locals())
class DatasourceAccessRequest(Model, AuditMixinNullable):
"""ORM model for the access requests for datasources and dbs."""
__tablename__ = 'access_request'
id = Column(Integer, primary_key=True)
datasource_id = Column(Integer)
datasource_type = Column(String(200))
ROLES_BLACKLIST = set(config.get('ROBOT_PERMISSION_ROLES', []))
@property
def cls_model(self):
return ConnectorRegistry.sources[self.datasource_type]
@property
def username(self):
return self.creator()
@property
def datasource(self):
return self.get_datasource
@datasource.getter
@utils.memoized
def get_datasource(self):
ds = db.session.query(self.cls_model).filter_by(
id=self.datasource_id).first()
return ds
@property
def datasource_link(self):
return self.datasource.link
@property
def roles_with_datasource(self):
action_list = ''
pv = sm.find_permission_view_menu(
'datasource_access', self.datasource.perm)
for r in pv.role:
if r.name in self.ROLES_BLACKLIST:
continue
url = (
'/superset/approve?datasource_type={self.datasource_type}&'
'datasource_id={self.datasource_id}&'
'created_by={self.created_by.username}&role_to_grant={r.name}'
.format(**locals())
)
href = '<a href="{}">Grant {} Role</a>'.format(url, r.name)
action_list = action_list + '<li>' + href + '</li>'
return '<ul>' + action_list + '</ul>'
@property
def user_roles(self):
action_list = ''
for r in self.created_by.roles:
url = (
'/superset/approve?datasource_type={self.datasource_type}&'
'datasource_id={self.datasource_id}&'
'created_by={self.created_by.username}&role_to_extend={r.name}'
.format(**locals())
)
href = '<a href="{}">Extend {} Role</a>'.format(url, r.name)
if r.name in self.ROLES_BLACKLIST:
href = "{} Role".format(r.name)
action_list = action_list + '<li>' + href + '</li>'
return '<ul>' + action_list + '</ul>'
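For reference, a small sketch of the registry lookups this refactor relies on; the id is illustrative:

ds_class = ConnectorRegistry.sources['table']    # -> SqlaTable
druid_cls = ConnectorRegistry.sources['druid']   # -> DruidDatasource
ds = db.session.query(ds_class).filter_by(id=1).first()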

superset/models/helpers.py

@ -0,0 +1,127 @@
from datetime import datetime
import humanize
import json
import re
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declared_attr
from flask import escape, Markup
from flask_appbuilder.models.mixins import AuditMixin
from flask_appbuilder.models.decorators import renders
from superset.utils import QueryStatus
class ImportMixin(object):
def override(self, obj):
"""Overrides the plain fields of the dashboard."""
for field in obj.__class__.export_fields:
setattr(self, field, getattr(obj, field))
def copy(self):
"""Creates a copy of the dashboard without relationships."""
new_obj = self.__class__()
new_obj.override(self)
return new_obj
def alter_params(self, **kwargs):
d = self.params_dict
d.update(kwargs)
self.params = json.dumps(d)
@property
def params_dict(self):
if self.params:
params = re.sub(",[ \t\r\n]+}", "}", self.params)
params = re.sub(",[ \t\r\n]+\]", "]", params)
return json.loads(params)
else:
return {}
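A minimal sketch of the params helpers above; the object and values are illustrative:

slc.params = '{"viz_type": "table", }'  # note the trailing comma
slc.params_dict                         # -> {'viz_type': 'table'}; the regexes strip the comma
slc.alter_params(remote_id=42)          # merges the kwarg and re-serializes into slc.params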
class AuditMixinNullable(AuditMixin):
"""Altering the AuditMixin to use nullable fields
Allows creating objects programmatically outside of CRUD
"""
created_on = sa.Column(sa.DateTime, default=datetime.now, nullable=True)
changed_on = sa.Column(
sa.DateTime, default=datetime.now,
onupdate=datetime.now, nullable=True)
@declared_attr
def created_by_fk(cls): # noqa
return sa.Column(
sa.Integer, sa.ForeignKey('ab_user.id'),
default=cls.get_user_id, nullable=True)
@declared_attr
def changed_by_fk(cls): # noqa
return sa.Column(
sa.Integer, sa.ForeignKey('ab_user.id'),
default=cls.get_user_id, onupdate=cls.get_user_id, nullable=True)
def _user_link(self, user):
if not user:
return ''
url = '/superset/profile/{}/'.format(user.username)
return Markup('<a href="{}">{}</a>'.format(url, escape(user) or ''))
@renders('created_by')
def creator(self): # noqa
return self._user_link(self.created_by)
@property
def changed_by_(self):
return self._user_link(self.changed_by)
@renders('changed_on')
def changed_on_(self):
return Markup(
'<span class="no-wrap">{}</span>'.format(self.changed_on))
@renders('changed_on')
def modified(self):
s = humanize.naturaltime(datetime.now() - self.changed_on)
return Markup('<span class="no-wrap">{}</span>'.format(s))
@property
def icons(self):
return """
<a
href="{self.datasource_edit_url}"
data-toggle="tooltip"
title="{self.datasource}">
<i class="fa fa-database"></i>
</a>
""".format(**locals())
class QueryResult(object):
"""Object returned by the query interface"""
def __init__( # noqa
self,
df,
query,
duration,
status=QueryStatus.SUCCESS,
error_message=None):
self.df = df
self.query = query
self.duration = duration
self.status = status
self.error_message = error_message
def set_perm(mapper, connection, target): # noqa
if target.perm != target.get_perm():
link_table = target.__table__
connection.execute(
link_table.update()
.where(link_table.c.id == target.id)
.values(perm=target.get_perm())
)


@ -6,7 +6,9 @@ from __future__ import unicode_literals
import logging
from flask_appbuilder.security.sqla import models as ab_models
from superset import conf, db, models, sm, source_registry
from superset import conf, db, sm
from superset.models import core as models
from superset.connectors.connector_registry import ConnectorRegistry
READ_ONLY_MODEL_VIEWS = {
@ -155,7 +157,7 @@ def create_custom_permissions():
def create_missing_datasource_perms(view_menu_set):
logging.info("Creating missing datasource permissions.")
datasources = source_registry.SourceRegistry.get_all_datasources(
datasources = ConnectorRegistry.get_all_datasources(
db.session)
for datasource in datasources:
if datasource and datasource.perm not in view_menu_set:
@ -181,8 +183,8 @@ def create_missing_metrics_perm(view_menu_set):
"""
logging.info("Creating missing metrics permissions")
metrics = []
for model in [models.SqlMetric, models.DruidMetric]:
metrics += list(db.session.query(model).all())
for datasource_class in ConnectorRegistry.sources.values():
metrics += list(db.session.query(datasource_class.metric_class).all())
for metric in metrics:
if (metric.is_restricted and metric.perm and
@ -216,7 +218,9 @@ def sync_role_definitions():
if conf.get('PUBLIC_ROLE_LIKE_GAMMA', False):
set_role('Public', pvms, is_gamma_pvm)
view_menu_set = db.session.query(models.SqlaTable).all()
view_menu_set = []
for datasource_class in ConnectorRegistry.sources.values():
view_menu_set += list(db.session.query(datasource_class).all())
create_missing_datasource_perms(view_menu_set)
create_missing_database_perms(view_menu_set)
create_missing_metrics_perm(view_menu_set)


@ -12,11 +12,12 @@ from sqlalchemy.pool import NullPool
from sqlalchemy.orm import sessionmaker
from superset import (
app, db, models, utils, dataframe, results_backend)
app, db, utils, dataframe, results_backend)
from superset.models import core as models
from superset.sql_parse import SupersetQuery
from superset.db_engine_specs import LimitMethod
from superset.jinja_context import get_template_processor
QueryStatus = models.QueryStatus
from superset.utils import QueryStatus
celery_app = celery.Celery(config_source=app.config.get('CELERY_CONFIG'))


@ -438,7 +438,7 @@ def pessimistic_connection_handling(target):
cursor.close()
class QueryStatus:
class QueryStatus(object):
"""Enum-type class for query statuses"""


@ -0,0 +1,2 @@
from . import base # noqa
from . import core # noqa

superset/views/base.py

@ -0,0 +1,201 @@
import logging
import json
from flask import g, redirect
from flask_babel import gettext as __
from flask_appbuilder import BaseView
from flask_appbuilder import ModelView
from flask_appbuilder.widgets import ListWidget
from flask_appbuilder.actions import action
from flask_appbuilder.models.sqla.filters import BaseFilter
from flask_appbuilder.security.sqla import models as ab_models
from superset import appbuilder, conf, db, utils, sm, sql_parse
from superset.connectors.connector_registry import ConnectorRegistry
def get_datasource_exist_error_mgs(full_name):
return __("Datasource %(name)s already exists", name=full_name)
def get_user_roles():
if g.user.is_anonymous():
public_role = conf.get('AUTH_ROLE_PUBLIC')
return [appbuilder.sm.find_role(public_role)] if public_role else []
return g.user.roles
class BaseSupersetView(BaseView):
def can_access(self, permission_name, view_name, user=None):
if not user:
user = g.user
return utils.can_access(
appbuilder.sm, permission_name, view_name, user)
def all_datasource_access(self, user=None):
return self.can_access(
"all_datasource_access", "all_datasource_access", user=user)
def database_access(self, database, user=None):
return (
self.can_access(
"all_database_access", "all_database_access", user=user) or
self.can_access("database_access", database.perm, user=user)
)
def schema_access(self, datasource, user=None):
return (
self.database_access(datasource.database, user=user) or
self.all_datasource_access(user=user) or
self.can_access("schema_access", datasource.schema_perm, user=user)
)
def datasource_access(self, datasource, user=None):
return (
self.schema_access(datasource, user=user) or
self.can_access("datasource_access", datasource.perm, user=user)
)
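The checks above compose from the broadest grant down to a single datasource; an equivalent flattening, with `view` standing in for a BaseSupersetView instance and `ds` for a datasource:

ok = (
    view.database_access(ds.database) or              # all_database_access / database_access
    view.all_datasource_access() or                   # global grant
    view.can_access('schema_access', ds.schema_perm) or
    view.can_access('datasource_access', ds.perm)
)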
def datasource_access_by_name(
self, database, datasource_name, schema=None):
if self.database_access(database) or self.all_datasource_access():
return True
schema_perm = utils.get_schema_perm(database, schema)
if schema and utils.can_access(
sm, 'schema_access', schema_perm, g.user):
return True
datasources = ConnectorRegistry.query_datasources_by_name(
db.session, database, datasource_name, schema=schema)
for datasource in datasources:
if self.can_access("datasource_access", datasource.perm):
return True
return False
def datasource_access_by_fullname(
self, database, full_table_name, schema):
table_name_pieces = full_table_name.split(".")
if len(table_name_pieces) == 2:
table_schema = table_name_pieces[0]
table_name = table_name_pieces[1]
else:
table_schema = schema
table_name = table_name_pieces[0]
return self.datasource_access_by_name(
database, table_name, schema=table_schema)
def rejected_datasources(self, sql, database, schema):
superset_query = sql_parse.SupersetQuery(sql)
return [
t for t in superset_query.tables if not
self.datasource_access_by_fullname(database, t, schema)]
def accessible_by_user(self, database, datasource_names, schema=None):
if self.database_access(database) or self.all_datasource_access():
return datasource_names
schema_perm = utils.get_schema_perm(database, schema)
if schema and utils.can_access(
sm, 'schema_access', schema_perm, g.user):
return datasource_names
role_ids = set([role.id for role in g.user.roles])
# TODO: cache user_perms or user_datasources
user_pvms = (
db.session.query(ab_models.PermissionView)
.join(ab_models.Permission)
.filter(ab_models.Permission.name == 'datasource_access')
.filter(ab_models.PermissionView.role.any(
ab_models.Role.id.in_(role_ids)))
.all()
)
user_perms = set([pvm.view_menu.name for pvm in user_pvms])
user_datasources = ConnectorRegistry.query_datasources_by_permissions(
db.session, database, user_perms)
full_names = set([d.full_name for d in user_datasources])
return [d for d in datasource_names if d in full_names]
class SupersetModelView(ModelView):
page_size = 500
class ListWidgetWithCheckboxes(ListWidget):
"""An alternative to list view that renders Boolean fields as checkboxes
Works in conjunction with the `checkbox` view."""
template = 'superset/fab_overrides/list_with_checkboxes.html'
def validate_json(form, field): # noqa
try:
json.loads(field.data)
except Exception as e:
logging.exception(e)
raise Exception("json isn't valid")
class DeleteMixin(object):
@action(
"muldelete", "Delete", "Delete all Really?", "fa-trash", single=False)
def muldelete(self, items):
self.datamodel.delete_all(items)
self.update_redirect()
return redirect(self.get_redirect())
class SupersetFilter(BaseFilter):
"""Add utility function to make BaseFilter easy and fast
These utility functions exist in the SecurityManager, but would do
a database round trip at every check. Here we cache the role objects
to be able to make multiple checks but query the db only once
"""
def get_user_roles(self):
return get_user_roles()
def get_all_permissions(self):
"""Returns a set of tuples with the perm name and view menu name"""
perms = set()
for role in self.get_user_roles():
for perm_view in role.permissions:
t = (perm_view.permission.name, perm_view.view_menu.name)
perms.add(t)
return perms
def has_role(self, role_name_or_list):
"""Whether the user has this role name"""
if not isinstance(role_name_or_list, list):
role_name_or_list = [role_name_or_list]
return any(
[r.name in role_name_or_list for r in self.get_user_roles()])
def has_perm(self, permission_name, view_menu_name):
"""Whether the user has this perm"""
return (permission_name, view_menu_name) in self.get_all_permissions()
def get_view_menus(self, permission_name):
"""Returns the details of view_menus for a perm name"""
vm = set()
for perm_name, vm_name in self.get_all_permissions():
if perm_name == permission_name:
vm.add(vm_name)
return vm
def has_all_datasource_access(self):
return (
self.has_role(['Admin', 'Alpha']) or
self.has_perm('all_datasource_access', 'all_datasource_access'))
class DatasourceFilter(SupersetFilter):
def apply(self, query, func): # noqa
if self.has_all_datasource_access():
return query
perms = self.get_view_menus('datasource_access')
# TODO(bogdan): add `schema_access` support here
return query.filter(self.model.perm.in_(perms))


@ -19,12 +19,10 @@ import sqlalchemy as sqla
from flask import (
g, request, redirect, flash, Response, render_template, Markup)
from flask_appbuilder import ModelView, CompactCRUDMixin, BaseView, expose
from flask_appbuilder import expose
from flask_appbuilder.actions import action
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_appbuilder.security.decorators import has_access_api
from flask_appbuilder.widgets import ListWidget
from flask_appbuilder.models.sqla.filters import BaseFilter
from flask_appbuilder.security.sqla import models as ab_models
from flask_babel import gettext as __
@ -33,120 +31,26 @@ from flask_babel import lazy_gettext as _
from sqlalchemy import create_engine
from werkzeug.routing import BaseConverter
import superset
from superset import (
appbuilder, cache, db, models, viz, utils, app,
sm, sql_lab, sql_parse, results_backend, security,
appbuilder, cache, db, viz, utils, app,
sm, sql_lab, results_backend, security,
)
from superset.legacy import cast_form_data
from superset.utils import has_access
from superset.source_registry import SourceRegistry
from superset.models import DatasourceAccessRequest as DAR
from superset.connectors.connector_registry import ConnectorRegistry
import superset.models.core as models
from superset.sql_parse import SupersetQuery
from .base import (
SupersetModelView, BaseSupersetView, DeleteMixin,
SupersetFilter, get_user_roles
)
config = app.config
log_this = models.Log.log_this
can_access = utils.can_access
QueryStatus = models.QueryStatus
class BaseSupersetView(BaseView):
def can_access(self, permission_name, view_name, user=None):
if not user:
user = g.user
return utils.can_access(
appbuilder.sm, permission_name, view_name, user)
def all_datasource_access(self, user=None):
return self.can_access(
"all_datasource_access", "all_datasource_access", user=user)
def database_access(self, database, user=None):
return (
self.can_access(
"all_database_access", "all_database_access", user=user) or
self.can_access("database_access", database.perm, user=user)
)
def schema_access(self, datasource, user=None):
return (
self.database_access(datasource.database, user=user) or
self.all_datasource_access(user=user) or
self.can_access("schema_access", datasource.schema_perm, user=user)
)
def datasource_access(self, datasource, user=None):
return (
self.schema_access(datasource, user=user) or
self.can_access("datasource_access", datasource.perm, user=user)
)
def datasource_access_by_name(
self, database, datasource_name, schema=None):
if self.database_access(database) or self.all_datasource_access():
return True
schema_perm = utils.get_schema_perm(database, schema)
if schema and utils.can_access(
sm, 'schema_access', schema_perm, g.user):
return True
datasources = SourceRegistry.query_datasources_by_name(
db.session, database, datasource_name, schema=schema)
for datasource in datasources:
if self.can_access("datasource_access", datasource.perm):
return True
return False
def datasource_access_by_fullname(
self, database, full_table_name, schema):
table_name_pieces = full_table_name.split(".")
if len(table_name_pieces) == 2:
table_schema = table_name_pieces[0]
table_name = table_name_pieces[1]
else:
table_schema = schema
table_name = table_name_pieces[0]
return self.datasource_access_by_name(
database, table_name, schema=table_schema)
def rejected_datasources(self, sql, database, schema):
superset_query = sql_parse.SupersetQuery(sql)
return [
t for t in superset_query.tables if not
self.datasource_access_by_fullname(database, t, schema)]
def accessible_by_user(self, database, datasource_names, schema=None):
if self.database_access(database) or self.all_datasource_access():
return datasource_names
schema_perm = utils.get_schema_perm(database, schema)
if schema and utils.can_access(
sm, 'schema_access', schema_perm, g.user):
return datasource_names
role_ids = set([role.id for role in g.user.roles])
# TODO: cache user_perms or user_datasources
user_pvms = (
db.session.query(ab_models.PermissionView)
.join(ab_models.Permission)
.filter(ab_models.Permission.name == 'datasource_access')
.filter(ab_models.PermissionView.role.any(
ab_models.Role.id.in_(role_ids)))
.all()
)
user_perms = set([pvm.view_menu.name for pvm in user_pvms])
user_datasources = SourceRegistry.query_datasources_by_permissions(
db.session, database, user_perms)
full_names = set([d.full_name for d in user_datasources])
return [d for d in datasource_names if d in full_names]
class ListWidgetWithCheckboxes(ListWidget):
"""An alternative to list view that renders Boolean fields as checkboxes
Works in conjunction with the `checkbox` view."""
template = 'superset/fab_overrides/list_with_checkboxes.html'
DAR = models.DatasourceAccessRequest
ALL_DATASOURCE_ACCESS_ERR = __(
@ -168,10 +72,6 @@ def get_datasource_access_error_msg(datasource_name):
"`all_datasource_access` permission", name=datasource_name)
def get_datasource_exist_error_mgs(full_name):
return __("Datasource %(name)s already exists", name=full_name)
def get_error_msg():
if config.get("SHOW_STACKTRACE"):
error_msg = traceback.format_exc()
@ -211,10 +111,12 @@ def api(f):
return functools.update_wrapper(wraps, f)
def is_owner(obj, user):
""" Check if user is owner of the slice """
return obj and obj.owners and user in obj.owners
def check_ownership(obj, raise_if_false=True):
"""Meant to be used in `pre_update` hooks on models to enforce ownership
@ -257,68 +159,6 @@ def check_ownership(obj, raise_if_false=True):
return False
def get_user_roles():
if g.user.is_anonymous():
public_role = config.get('AUTH_ROLE_PUBLIC')
return [appbuilder.sm.find_role(public_role)] if public_role else []
return g.user.roles
class SupersetFilter(BaseFilter):
"""Add utility function to make BaseFilter easy and fast
These utility functions exist in the SecurityManager, but would do
a database round trip at every check. Here we cache the role objects
to be able to make multiple checks but query the db only once
"""
def get_user_roles(self):
return get_user_roles()
def get_all_permissions(self):
"""Returns a set of tuples with the perm name and view menu name"""
perms = set()
for role in get_user_roles():
for perm_view in role.permissions:
t = (perm_view.permission.name, perm_view.view_menu.name)
perms.add(t)
return perms
def has_role(self, role_name_or_list):
"""Whether the user has this role name"""
if not isinstance(role_name_or_list, list):
role_name_or_list = [role_name_or_list]
return any(
[r.name in role_name_or_list for r in self.get_user_roles()])
def has_perm(self, permission_name, view_menu_name):
"""Whether the user has this perm"""
return (permission_name, view_menu_name) in self.get_all_permissions()
def get_view_menus(self, permission_name):
"""Returns the details of view_menus for a perm name"""
vm = set()
for perm_name, vm_name in self.get_all_permissions():
if perm_name == permission_name:
vm.add(vm_name)
return vm
def has_all_datasource_access(self):
return (
self.has_role(['Admin', 'Alpha']) or
self.has_perm('all_datasource_access', 'all_datasource_access'))
class DatasourceFilter(SupersetFilter):
def apply(self, query, func): # noqa
if self.has_all_datasource_access():
return query
perms = self.get_view_menus('datasource_access')
# TODO(bogdan): add `schema_access` support here
return query.filter(self.model.perm.in_(perms))
class SliceFilter(SupersetFilter):
def apply(self, query, func): # noqa
if self.has_all_datasource_access():
@ -355,14 +195,6 @@ class DashboardFilter(SupersetFilter):
return query
def validate_json(form, field): # noqa
try:
json.loads(field.data)
except Exception as e:
logging.exception(e)
raise Exception("json isn't valid")
def generate_download_headers(extension):
filename = datetime.now().strftime("%Y%m%d_%H%M%S")
content_disp = "attachment; filename={}.{}".format(filename, extension)
@ -372,206 +204,6 @@ def generate_download_headers(extension):
return headers
class DeleteMixin(object):
@action(
"muldelete", "Delete", "Delete all Really?", "fa-trash", single=False)
def muldelete(self, items):
self.datamodel.delete_all(items)
self.update_redirect()
return redirect(self.get_redirect())
class SupersetModelView(ModelView):
page_size = 500
class TableColumnInlineView(CompactCRUDMixin, SupersetModelView): # noqa
datamodel = SQLAInterface(models.TableColumn)
can_delete = False
list_widget = ListWidgetWithCheckboxes
edit_columns = [
'column_name', 'verbose_name', 'description', 'groupby', 'filterable',
'table', 'count_distinct', 'sum', 'min', 'max', 'expression',
'is_dttm', 'python_date_format', 'database_expression']
add_columns = edit_columns
list_columns = [
'column_name', 'type', 'groupby', 'filterable', 'count_distinct',
'sum', 'min', 'max', 'is_dttm']
page_size = 500
description_columns = {
'is_dttm': (_(
"Whether to make this column available as a "
"[Time Granularity] option, column has to be DATETIME or "
"DATETIME-like")),
'expression': utils.markdown(
"a valid SQL expression as supported by the underlying backend. "
"Example: `substr(name, 1, 1)`", True),
'python_date_format': utils.markdown(Markup(
"The pattern of timestamp format, use "
"<a href='https://docs.python.org/2/library/"
"datetime.html#strftime-strptime-behavior'>"
"python datetime string pattern</a> "
"expression. If time is stored in epoch "
"format, put `epoch_s` or `epoch_ms`. Leave `Database Expression` "
"below empty if timestamp is stored in "
"String or Integer(epoch) type"), True),
'database_expression': utils.markdown(
"The database expression to cast internal datetime "
"constants to database date/timestamp type according to the DBAPI. "
"The expression should follow the pattern of "
"%Y-%m-%d %H:%M:%S, based on different DBAPI. "
"The string should be a python string formatter \n"
"`Ex: TO_DATE('{}', 'YYYY-MM-DD HH24:MI:SS')` for Oracle"
"Superset uses default expression based on DB URI if this "
"field is blank.", True),
}
label_columns = {
'column_name': _("Column"),
'verbose_name': _("Verbose Name"),
'description': _("Description"),
'groupby': _("Groupable"),
'filterable': _("Filterable"),
'table': _("Table"),
'count_distinct': _("Count Distinct"),
'sum': _("Sum"),
'min': _("Min"),
'max': _("Max"),
'expression': _("Expression"),
'is_dttm': _("Is temporal"),
'python_date_format': _("Datetime Format"),
'database_expression': _("Database Expression")
}
appbuilder.add_view_no_menu(TableColumnInlineView)
class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView): # noqa
datamodel = SQLAInterface(models.DruidColumn)
edit_columns = [
'column_name', 'description', 'dimension_spec_json', 'datasource',
'groupby', 'count_distinct', 'sum', 'min', 'max']
add_columns = edit_columns
list_columns = [
'column_name', 'type', 'groupby', 'filterable', 'count_distinct',
'sum', 'min', 'max']
can_delete = False
page_size = 500
label_columns = {
'column_name': _("Column"),
'type': _("Type"),
'datasource': _("Datasource"),
'groupby': _("Groupable"),
'filterable': _("Filterable"),
'count_distinct': _("Count Distinct"),
'sum': _("Sum"),
'min': _("Min"),
'max': _("Max"),
}
description_columns = {
'dimension_spec_json': utils.markdown(
"this field can be used to specify "
"a `dimensionSpec` as documented [here]"
"(http://druid.io/docs/latest/querying/dimensionspecs.html). "
"Make sure to input valid JSON and that the "
"`outputName` matches the `column_name` defined "
"above.",
True),
}
def post_update(self, col):
col.generate_metrics()
utils.validate_json(col.dimension_spec_json)
def post_add(self, col):
self.post_update(col)
appbuilder.add_view_no_menu(DruidColumnInlineView)
class SqlMetricInlineView(CompactCRUDMixin, SupersetModelView): # noqa
datamodel = SQLAInterface(models.SqlMetric)
list_columns = ['metric_name', 'verbose_name', 'metric_type']
edit_columns = [
'metric_name', 'description', 'verbose_name', 'metric_type',
'expression', 'table', 'd3format', 'is_restricted']
description_columns = {
'expression': utils.markdown(
"a valid SQL expression as supported by the underlying backend. "
"Example: `count(DISTINCT userid)`", True),
'is_restricted': _("Whether the access to this metric is restricted "
"to certain roles. Only roles with the permission "
"'metric access on XXX (the name of this metric)' "
"are allowed to access this metric"),
'd3format': utils.markdown(
"d3 formatting string as defined [here]"
"(https://github.com/d3/d3-format/blob/master/README.md#format). "
"For instance, this default formatting applies in the Table "
"visualization and allow for different metric to use different "
"formats", True
),
}
add_columns = edit_columns
page_size = 500
label_columns = {
'metric_name': _("Metric"),
'description': _("Description"),
'verbose_name': _("Verbose Name"),
'metric_type': _("Type"),
'expression': _("SQL Expression"),
'table': _("Table"),
}
def post_add(self, metric):
if metric.is_restricted:
security.merge_perm(sm, 'metric_access', metric.get_perm())
def post_update(self, metric):
if metric.is_restricted:
security.merge_perm(sm, 'metric_access', metric.get_perm())
appbuilder.add_view_no_menu(SqlMetricInlineView)
class DruidMetricInlineView(CompactCRUDMixin, SupersetModelView): # noqa
datamodel = SQLAInterface(models.DruidMetric)
list_columns = ['metric_name', 'verbose_name', 'metric_type']
edit_columns = [
'metric_name', 'description', 'verbose_name', 'metric_type', 'json',
'datasource', 'd3format', 'is_restricted']
add_columns = edit_columns
page_size = 500
validators_columns = {
'json': [validate_json],
}
description_columns = {
'metric_type': utils.markdown(
"use `postagg` as the metric type if you are defining a "
"[Druid Post Aggregation]"
"(http://druid.io/docs/latest/querying/post-aggregations.html)",
True),
'is_restricted': _("Whether the access to this metric is restricted "
"to certain roles. Only roles with the permission "
"'metric access on XXX (the name of this metric)' "
"are allowed to access this metric"),
}
label_columns = {
'metric_name': _("Metric"),
'description': _("Description"),
'verbose_name': _("Verbose Name"),
'metric_type': _("Type"),
'json': _("JSON"),
'datasource': _("Druid Datasource"),
}
def post_add(self, metric):
utils.init_metrics_perm(superset, [metric])
def post_update(self, metric):
utils.init_metrics_perm(superset, [metric])
appbuilder.add_view_no_menu(DruidMetricInlineView)
class DatabaseView(SupersetModelView, DeleteMixin): # noqa
datamodel = SQLAInterface(models.Database)
list_columns = [
@ -692,99 +324,6 @@ class DatabaseTablesAsync(DatabaseView):
appbuilder.add_view_no_menu(DatabaseTablesAsync)
class TableModelView(SupersetModelView, DeleteMixin): # noqa
datamodel = SQLAInterface(models.SqlaTable)
list_columns = [
'link', 'database', 'is_featured',
'changed_by_', 'changed_on_']
order_columns = [
'link', 'database', 'is_featured', 'changed_on_']
add_columns = ['database', 'schema', 'table_name']
edit_columns = [
'table_name', 'sql', 'is_featured', 'filter_select_enabled',
'database', 'schema',
'description', 'owner',
'main_dttm_col', 'default_endpoint', 'offset', 'cache_timeout']
show_columns = edit_columns + ['perm']
related_views = [TableColumnInlineView, SqlMetricInlineView]
base_order = ('changed_on', 'desc')
description_columns = {
'offset': _("Timezone offset (in hours) for this datasource"),
'table_name': _(
"Name of the table that exists in the source database"),
'schema': _(
"Schema, as used only in some databases like Postgres, Redshift "
"and DB2"),
'description': Markup(
"Supports <a href='https://daringfireball.net/projects/markdown/'>"
"markdown</a>"),
'sql': _(
"This fields acts a Superset view, meaning that Superset will "
"run a query against this string as a subquery."
),
}
base_filters = [['id', DatasourceFilter, lambda: []]]
label_columns = {
'link': _("Table"),
'changed_by_': _("Changed By"),
'database': _("Database"),
'changed_on_': _("Last Changed"),
'is_featured': _("Is Featured"),
'filter_select_enabled': _("Enable Filter Select"),
'schema': _("Schema"),
'default_endpoint': _("Default Endpoint"),
'offset': _("Offset"),
'cache_timeout': _("Cache Timeout"),
}
def pre_add(self, table):
number_of_existing_tables = db.session.query(
sqla.func.count('*')).filter(
models.SqlaTable.table_name == table.table_name,
models.SqlaTable.schema == table.schema,
models.SqlaTable.database_id == table.database.id
).scalar()
# table object is already added to the session
if number_of_existing_tables > 1:
raise Exception(get_datasource_exist_error_mgs(table.full_name))
# Fail before adding if the table can't be found
try:
table.get_sqla_table_object()
except Exception as e:
logging.exception(e)
raise Exception(
"Table [{}] could not be found, "
"please double check your "
"database connection, schema, and "
"table name".format(table.name))
def post_add(self, table):
table.fetch_metadata()
security.merge_perm(sm, 'datasource_access', table.get_perm())
if table.schema:
security.merge_perm(sm, 'schema_access', table.schema_perm)
flash(_(
"The table was created. As part of this two phase configuration "
"process, you should now click the edit button by "
"the new table to configure it."),
"info")
def post_update(self, table):
self.post_add(table)
appbuilder.add_view(
TableModelView,
"Tables",
label=__("Tables"),
category="Sources",
category_label=__("Sources"),
icon='fa-table',)
appbuilder.add_separator("Sources")
class AccessRequestsModelView(SupersetModelView, DeleteMixin):
datamodel = SQLAInterface(DAR)
list_columns = [
@ -810,43 +349,6 @@ appbuilder.add_view(
icon='fa-table',)
class DruidClusterModelView(SupersetModelView, DeleteMixin): # noqa
datamodel = SQLAInterface(models.DruidCluster)
add_columns = [
'cluster_name',
'coordinator_host', 'coordinator_port', 'coordinator_endpoint',
'broker_host', 'broker_port', 'broker_endpoint', 'cache_timeout',
]
edit_columns = add_columns
list_columns = ['cluster_name', 'metadata_last_refreshed']
label_columns = {
'cluster_name': _("Cluster"),
'coordinator_host': _("Coordinator Host"),
'coordinator_port': _("Coordinator Port"),
'coordinator_endpoint': _("Coordinator Endpoint"),
'broker_host': _("Broker Host"),
'broker_port': _("Broker Port"),
'broker_endpoint': _("Broker Endpoint"),
}
def pre_add(self, cluster):
security.merge_perm(sm, 'database_access', cluster.perm)
def pre_update(self, cluster):
self.pre_add(cluster)
if config['DRUID_IS_ACTIVE']:
appbuilder.add_view(
DruidClusterModelView,
name="Druid Clusters",
label=__("Druid Clusters"),
icon="fa-cubes",
category="Sources",
category_label=__("Sources"),
category_icon='fa-database',)
class SliceModelView(SupersetModelView, DeleteMixin): # noqa
datamodel = SQLAInterface(models.Slice)
can_add = False
@ -903,9 +405,9 @@ class SliceModelView(SupersetModelView, DeleteMixin): # noqa
if not widget:
return redirect(self.get_redirect())
sources = SourceRegistry.sources
sources = ConnectorRegistry.sources
for source in sources:
ds = db.session.query(SourceRegistry.sources[source]).first()
ds = db.session.query(ConnectorRegistry.sources[source]).first()
if ds is not None:
url = "/{}/list/".format(ds.baselink)
msg = _("Click on a {} link to create a Slice".format(source))
@ -1078,74 +580,6 @@ appbuilder.add_view(
icon="fa-search")
class DruidDatasourceModelView(SupersetModelView, DeleteMixin): # noqa
datamodel = SQLAInterface(models.DruidDatasource)
list_widget = ListWidgetWithCheckboxes
list_columns = [
'datasource_link', 'cluster', 'changed_by_', 'changed_on_', 'offset']
order_columns = [
'datasource_link', 'changed_on_', 'offset']
related_views = [DruidColumnInlineView, DruidMetricInlineView]
edit_columns = [
'datasource_name', 'cluster', 'description', 'owner',
'is_featured', 'is_hidden', 'filter_select_enabled',
'default_endpoint', 'offset', 'cache_timeout']
add_columns = edit_columns
show_columns = add_columns + ['perm']
page_size = 500
base_order = ('datasource_name', 'asc')
description_columns = {
'offset': _("Timezone offset (in hours) for this datasource"),
'description': Markup(
"Supports <a href='"
"https://daringfireball.net/projects/markdown/'>markdown</a>"),
}
base_filters = [['id', DatasourceFilter, lambda: []]]
label_columns = {
'datasource_link': _("Data Source"),
'cluster': _("Cluster"),
'description': _("Description"),
'owner': _("Owner"),
'is_featured': _("Is Featured"),
'is_hidden': _("Is Hidden"),
'filter_select_enabled': _("Enable Filter Select"),
'default_endpoint': _("Default Endpoint"),
'offset': _("Time Offset"),
'cache_timeout': _("Cache Timeout"),
}
def pre_add(self, datasource):
number_of_existing_datasources = db.session.query(
sqla.func.count('*')).filter(
models.DruidDatasource.datasource_name ==
datasource.datasource_name,
models.DruidDatasource.cluster_name == datasource.cluster.id
).scalar()
# table object is already added to the session
if number_of_existing_datasources > 1:
raise Exception(get_datasource_exist_error_mgs(
datasource.full_name))
def post_add(self, datasource):
datasource.generate_metrics()
security.merge_perm(sm, 'datasource_access', datasource.get_perm())
if datasource.schema:
security.merge_perm(sm, 'schema_access', datasource.schema_perm)
def post_update(self, datasource):
self.post_add(datasource)
if config['DRUID_IS_ACTIVE']:
appbuilder.add_view(
DruidDatasourceModelView,
"Druid Datasources",
label=__("Druid Datasources"),
category="Sources",
category_label=__("Sources"),
icon="fa-cube")
@app.route('/health')
def health():
return "OK"
@ -1283,7 +717,7 @@ class Superset(BaseSupersetView):
@has_access_api
@expose("/datasources/")
def datasources(self):
datasources = SourceRegistry.get_all_datasources(db.session)
datasources = ConnectorRegistry.get_all_datasources(db.session)
datasources = [(str(o.id) + '__' + o.type, repr(o)) for o in datasources]
return self.json_response(datasources)
@ -1318,7 +752,7 @@ class Superset(BaseSupersetView):
dbs['name'], ds_name, schema=schema['name'])
db_ds_names.add(fullname)
existing_datasources = SourceRegistry.get_all_datasources(db.session)
existing_datasources = ConnectorRegistry.get_all_datasources(db.session)
datasources = [
d for d in existing_datasources if d.full_name in db_ds_names]
role = sm.find_role(role_name)
@ -1356,7 +790,7 @@ class Superset(BaseSupersetView):
datasource_id = request.args.get('datasource_id')
datasource_type = request.args.get('datasource_type')
if datasource_id:
ds_class = SourceRegistry.sources.get(datasource_type)
ds_class = ConnectorRegistry.sources.get(datasource_type)
datasource = (
db.session.query(ds_class)
.filter_by(id=int(datasource_id))
@ -1385,7 +819,7 @@ class Superset(BaseSupersetView):
def approve(self):
def clean_fulfilled_requests(session):
for r in session.query(DAR).all():
datasource = SourceRegistry.get_datasource(
datasource = ConnectorRegistry.get_datasource(
r.datasource_type, r.datasource_id, session)
user = sm.get_user_by_id(r.created_by_fk)
if not datasource or \
@ -1400,7 +834,7 @@ class Superset(BaseSupersetView):
role_to_extend = request.args.get('role_to_extend')
session = db.session
datasource = SourceRegistry.get_datasource(
datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, session)
if not datasource:
@ -1501,9 +935,9 @@ class Superset(BaseSupersetView):
)
return slc.get_viz()
else:
form_data=self.get_form_data()
form_data = self.get_form_data()
viz_type = form_data.get('viz_type', 'table')
datasource = SourceRegistry.get_datasource(
datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session)
viz_obj = viz.viz_types[viz_type](
datasource,
@ -1542,7 +976,6 @@ class Superset(BaseSupersetView):
utils.error_msg_from_exception(e),
stacktrace=traceback.format_exc())
if not self.datasource_access(viz_obj.datasource):
return json_error_response(DATASOURCE_ACCESS_ERR, status=404)
@ -1598,11 +1031,8 @@ class Superset(BaseSupersetView):
data = pickle.load(f)
# TODO: import DRUID datasources
for table in data['datasources']:
if table.type == 'table':
models.SqlaTable.import_obj(table, import_time=current_tt)
else:
models.DruidDatasource.import_obj(
table, import_time=current_tt)
ds_class = ConnectorRegistry.sources.get(table.type)
ds_class.import_obj(table, import_time=current_tt)
db.session.commit()
for dashboard in data['dashboards']:
models.Dashboard.import_obj(
@ -1628,7 +1058,7 @@ class Superset(BaseSupersetView):
error_redirect = '/slicemodelview/list/'
datasource = (
db.session.query(SourceRegistry.sources[datasource_type])
db.session.query(ConnectorRegistry.sources[datasource_type])
.filter_by(id=datasource_id)
.one()
)
@ -1705,8 +1135,7 @@ class Superset(BaseSupersetView):
"""
# TODO: Cache endpoint by user, datasource and column
error_redirect = '/slicemodelview/list/'
datasource_class = models.SqlaTable \
if datasource_type == "table" else models.DruidDatasource
datasource_class = ConnectorRegistry.sources[datasource_type]
datasource = db.session.query(
datasource_class).filter_by(id=datasource_id).first()
@ -2194,12 +1623,13 @@ class Superset(BaseSupersetView):
return json_error_response(__(
"Slice %(id)s not found", id=slice_id), status=404)
elif table_name and db_name:
SqlaTable = ConnectorRegistry.sources['table']
table = (
session.query(models.SqlaTable)
session.query(SqlaTable)
.join(models.Database)
.filter(
models.Database.database_name == db_name or
models.SqlaTable.table_name == table_name)
SqlaTable.table_name == table_name)
).first()
if not table:
return json_error_response(__(
@ -2209,15 +1639,15 @@ class Superset(BaseSupersetView):
datasource_id=table.id,
datasource_type=table.type).all()
for slice in slices:
for slc in slices:
try:
obj = slice.get_viz()
obj = slc.get_viz()
obj.get_json(force=True)
except Exception as e:
return json_error_response(utils.error_msg_from_exception(e))
return json_success(json.dumps(
[{"slice_id": session.id, "slice_name": session.slice_name}
for session in slices]))
[{"slice_id": slc.id, "slice_name": slc.slice_name}
for slc in slices]))
@expose("/favstar/<class_name>/<obj_id>/<action>/")
def favstar(self, class_name, obj_id, action):
@ -2322,12 +1752,14 @@ class Superset(BaseSupersetView):
cluster_name = payload['cluster']
user = sm.find_user(username=user_name)
DruidDatasource = ConnectorRegistry.sources['druid']
DruidCluster = DruidDatasource.cluster_class
if not user:
err_msg = __("Can't find User '%(name)s', please ask your admin "
"to create one.", name=user_name)
logging.error(err_msg)
return json_error_response(err_msg)
cluster = db.session.query(models.DruidCluster).filter_by(
cluster = db.session.query(DruidCluster).filter_by(
cluster_name=cluster_name).first()
if not cluster:
err_msg = __("Can't find DruidCluster with cluster_name = "
@ -2335,7 +1767,7 @@ class Superset(BaseSupersetView):
logging.error(err_msg)
return json_error_response(err_msg)
try:
models.DruidDatasource.sync_to_db_from_config(
DruidDatasource.sync_to_db_from_config(
druid_config, user, cluster)
except Exception as e:
logging.exception(utils.error_msg_from_exception(e))
@ -2349,13 +1781,14 @@ class Superset(BaseSupersetView):
data = json.loads(request.form.get('data'))
table_name = data.get('datasourceName')
viz_type = data.get('chartType')
SqlaTable = ConnectorRegistry.sources['table']
table = (
db.session.query(models.SqlaTable)
db.session.query(SqlaTable)
.filter_by(table_name=table_name)
.first()
)
if not table:
table = models.SqlaTable(table_name=table_name)
table = SqlaTable(table_name=table_name)
table.database_id = data.get('dbId')
q = SupersetQuery(data.get('sql'))
table.sql = q.stripped()
@ -2642,7 +2075,7 @@ class Superset(BaseSupersetView):
def fetch_datasource_metadata(self):
datasource_id, datasource_type = (
request.args.get('datasourceKey').split('__'))
datasource_class = SourceRegistry.sources[datasource_type]
datasource_class = ConnectorRegistry.sources[datasource_type]
datasource = (
db.session.query(datasource_class)
.filter_by(id=int(datasource_id))
@ -2740,7 +2173,8 @@ class Superset(BaseSupersetView):
def refresh_datasources(self):
"""endpoint that refreshes druid datasources metadata"""
session = db.session()
for cluster in session.query(models.DruidCluster).all():
DruidCluster = ConnectorRegistry.sources['druid']
for cluster in session.query(DruidCluster).all():
cluster_name = cluster.cluster_name
try:
cluster.refresh_datasources()


@ -9,7 +9,11 @@ import mock
import unittest
from superset import db, models, sm, security
from superset.source_registry import SourceRegistry
from superset.models import core as models
from superset.connectors.connector_registry import ConnectorRegistry
from superset.connectors.sqla.models import SqlaTable
from superset.connectors.druid.models import DruidDatasource
from .base_tests import SupersetTestCase
@ -58,7 +62,7 @@ SCHEMA_ACCESS_ROLE = 'schema_access_role'
def create_access_request(session, ds_type, ds_name, role_name, user_name):
ds_class = SourceRegistry.sources[ds_type]
ds_class = ConnectorRegistry.sources[ds_type]
# TODO: generalize datasource names
if ds_type == 'table':
ds = session.query(ds_class).filter(
@ -293,7 +297,7 @@ class RequestAccessTests(SupersetTestCase):
access_request2 = create_access_request(
session, 'table', 'wb_health_population', TEST_ROLE_2, 'gamma2')
ds_1_id = access_request1.datasource_id
ds = session.query(models.SqlaTable).filter_by(
ds = session.query(SqlaTable).filter_by(
table_name='wb_health_population').first()
@ -314,7 +318,7 @@ class RequestAccessTests(SupersetTestCase):
gamma_user = sm.find_user(username='gamma')
gamma_user.roles.remove(sm.find_role(SCHEMA_ACCESS_ROLE))
ds = session.query(models.SqlaTable).filter_by(
ds = session.query(SqlaTable).filter_by(
table_name='wb_health_population').first()
ds.schema = None
@ -441,7 +445,7 @@ class RequestAccessTests(SupersetTestCase):
# Request table access; there are no roles that have this table.
table1 = session.query(models.SqlaTable).filter_by(
table1 = session.query(SqlaTable).filter_by(
table_name='random_time_series').first()
table_1_id = table1.id
@ -454,7 +458,7 @@ class RequestAccessTests(SupersetTestCase):
# Request access; roles exist that contain the table.
# add table to the existing roles
table3 = session.query(models.SqlaTable).filter_by(
table3 = session.query(SqlaTable).filter_by(
table_name='energy_usage').first()
table_3_id = table3.id
table3_perm = table3.perm
@ -479,7 +483,7 @@ class RequestAccessTests(SupersetTestCase):
'<ul><li>{}</li></ul>'.format(approve_link_3))
# Request druid access; there are no roles that have this datasource.
druid_ds_4 = session.query(models.DruidDatasource).filter_by(
druid_ds_4 = session.query(DruidDatasource).filter_by(
datasource_name='druid_ds_1').first()
druid_ds_4_id = druid_ds_4.id
@ -493,7 +497,7 @@ class RequestAccessTests(SupersetTestCase):
# Case 5. Roles exist that contain the druid datasource.
# add druid ds to the existing roles
druid_ds_5 = session.query(models.DruidDatasource).filter_by(
druid_ds_5 = session.query(DruidDatasource).filter_by(
datasource_name='druid_ds_2').first()
druid_ds_5_id = druid_ds_5.id
druid_ds_5_perm = druid_ds_5.perm


@ -11,8 +11,11 @@ import unittest
from flask_appbuilder.security.sqla import models as ab_models
from superset import app, cli, db, models, appbuilder, security, sm
from superset import app, cli, db, appbuilder, security, sm
from superset.models import core as models
from superset.security import sync_role_definitions
from superset.connectors.sqla.models import SqlaTable
from superset.connectors.druid.models import DruidCluster, DruidDatasource
os.environ['SUPERSET_CONFIG'] = 'tests.superset_test_config'
@ -85,30 +88,34 @@ class SupersetTestCase(unittest.TestCase):
appbuilder.sm.find_role('Alpha'),
password='general')
sm.get_session.commit()
# create druid cluster and druid datasources
session = db.session
cluster = session.query(models.DruidCluster).filter_by(
cluster_name="druid_test").first()
cluster = (
session.query(DruidCluster)
.filter_by(cluster_name="druid_test")
.first()
)
if not cluster:
cluster = models.DruidCluster(cluster_name="druid_test")
cluster = DruidCluster(cluster_name="druid_test")
session.add(cluster)
session.commit()
druid_datasource1 = models.DruidDatasource(
druid_datasource1 = DruidDatasource(
datasource_name='druid_ds_1',
cluster_name='druid_test'
)
session.add(druid_datasource1)
druid_datasource2 = models.DruidDatasource(
druid_datasource2 = DruidDatasource(
datasource_name='druid_ds_2',
cluster_name='druid_test'
)
session.add(druid_datasource2)
session.commit()
def get_table(self, table_id):
return db.session.query(models.SqlaTable).filter_by(
return db.session.query(SqlaTable).filter_by(
id=table_id).first()
def get_or_create(self, cls, criteria, session):
@ -149,11 +156,11 @@ class SupersetTestCase(unittest.TestCase):
return slc
def get_table_by_name(self, name):
return db.session.query(models.SqlaTable).filter_by(
return db.session.query(SqlaTable).filter_by(
table_name=name).first()
def get_druid_ds_by_name(self, name):
return db.session.query(models.DruidDatasource).filter_by(
return db.session.query(DruidDatasource).filter_by(
datasource_name=name).first()
def get_resp(


@ -12,14 +12,14 @@ import unittest
import pandas as pd
from superset import app, appbuilder, cli, db, models, dataframe
from superset import app, appbuilder, cli, db, dataframe
from superset.models import core as models
from superset.models.helpers import QueryStatus
from superset.security import sync_role_definitions
from superset.sql_parse import SupersetQuery
from .base_tests import SupersetTestCase
QueryStatus = models.QueryStatus
BASE_DIR = app.config.get('BASE_DIR')


@ -7,17 +7,19 @@ from __future__ import unicode_literals
import csv
import doctest
import json
import logging
import io
import random
import unittest
from flask import escape
from superset import db, models, utils, appbuilder, sm, jinja_context, sql_lab
from superset.views import DatabaseView
from superset import db, utils, appbuilder, sm, jinja_context, sql_lab
from superset.models import core as models
from superset.views.core import DatabaseView
from superset.connectors.sqla.models import SqlaTable
from .base_tests import SupersetTestCase
import logging
class CoreTests(SupersetTestCase):
@ -31,7 +33,7 @@ class CoreTests(SupersetTestCase):
def setUpClass(cls):
cls.table_ids = {tbl.table_name: tbl.id for tbl in (
db.session
.query(models.SqlaTable)
.query(SqlaTable)
.all()
)}
@ -186,7 +188,7 @@ class CoreTests(SupersetTestCase):
slice_id = self.get_slice(slice_name, db.session).id
db.session.commit()
tbl_id = self.table_ids.get('energy_usage')
table = db.session.query(models.SqlaTable).filter(models.SqlaTable.id == tbl_id)
table = db.session.query(SqlaTable).filter(SqlaTable.id == tbl_id)
table.filter_select_enabled = True
url = (
"/superset/filter/table/{}/target/?viz_type=sankey&groupby=source"
@ -220,7 +222,7 @@ class CoreTests(SupersetTestCase):
url = '/tablemodelview/list/'
resp = self.get_resp(url)
table = db.session.query(models.SqlaTable).first()
table = db.session.query(SqlaTable).first()
assert table.name in resp
assert '/superset/explore/table/{}'.format(table.id) in resp
@ -459,7 +461,7 @@ class CoreTests(SupersetTestCase):
def test_public_user_dashboard_access(self):
table = (
db.session
.query(models.SqlaTable)
.query(SqlaTable)
.filter_by(table_name='birth_names')
.one()
)
@ -494,7 +496,7 @@ class CoreTests(SupersetTestCase):
self.logout()
table = (
db.session
.query(models.SqlaTable)
.query(SqlaTable)
.filter_by(table_name='birth_names')
.one()
)

View File

@ -11,7 +11,8 @@ import unittest
from mock import Mock, patch
from superset import db, sm, security
from superset.models import DruidCluster, DruidDatasource
from superset.connectors.druid.models import DruidCluster, DruidDatasource
from superset.connectors.druid.models import PyDruid
from .base_tests import SupersetTestCase
@ -70,7 +71,7 @@ class DruidTests(SupersetTestCase):
def __init__(self, *args, **kwargs):
super(DruidTests, self).__init__(*args, **kwargs)
@patch('superset.models.PyDruid')
@patch('superset.connectors.druid.models.PyDruid')
def test_client(self, PyDruid):
self.login(username='admin')
instance = PyDruid.return_value
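The mock target moves with the model: mock.patch has to name the module where PyDruid is looked up at test time, which after this refactor is superset.connectors.druid.models rather than superset.models. A condensed, hedged sketch of the decorated test above (body trimmed, class and helper names as they appear in the hunks):

from mock import patch
from .base_tests import SupersetTestCase

class DruidTests(SupersetTestCase):
    @patch('superset.connectors.druid.models.PyDruid')
    def test_client(self, PyDruid):
        # PyDruid is the patched class; the connector code under test receives
        # PyDruid.return_value as its Druid client.
        instance = PyDruid.return_value
        self.login(username='admin')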
@ -197,8 +198,12 @@ class DruidTests(SupersetTestCase):
}
def check():
resp = self.client.post('/superset/sync_druid/', data=json.dumps(cfg))
druid_ds = db.session.query(DruidDatasource).filter_by(
datasource_name="test_click").first()
druid_ds = (
db.session
.query(DruidDatasource)
.filter_by(datasource_name="test_click")
.one()
)
col_names = set([c.column_name for c in druid_ds.columns])
assert {"affiliate_id", "campaign", "first_seen"} == col_names
metric_names = {m.metric_name for m in druid_ds.metrics}
@ -224,7 +229,7 @@ class DruidTests(SupersetTestCase):
}
resp = self.client.post('/superset/sync_druid/', data=json.dumps(cfg))
druid_ds = db.session.query(DruidDatasource).filter_by(
datasource_name="test_click").first()
datasource_name="test_click").one()
# columns and metrics are not deleted if the config is changed, as
# users may have defined their own dimensions / metrics and want to keep them
assert set([c.column_name for c in druid_ds.columns]) == set(
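Two details in the hunk above are worth calling out: the lookup switches from .first() to .one(), so a missing datasource now fails the test at the query itself rather than surfacing later as an error on None, and the assertion checks that previously synced columns survive a changed config. A condensed sketch of that lookup-and-check, reusing the names from the hunk (the exact expected sets are those in the test above):

from superset import db
from superset.connectors.druid.models import DruidDatasource

druid_ds = (
    db.session
    .query(DruidDatasource)
    .filter_by(datasource_name="test_click")
    .one()  # NoResultFound here means the sync endpoint never created the datasource
)
col_names = {c.column_name for c in druid_ds.columns}
assert {"affiliate_id", "campaign", "first_seen"} == col_names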

View File

@ -10,7 +10,11 @@ import json
import pickle
import unittest
from superset import db, models
from superset import db
from superset.models import core as models
from superset.connectors.druid.models import (
DruidDatasource, DruidColumn, DruidMetric)
from superset.connectors.sqla.models import SqlaTable, TableColumn, SqlMetric
from .base_tests import SupersetTestCase
@ -31,10 +35,10 @@ class ImportExportTests(SupersetTestCase):
for dash in session.query(models.Dashboard):
if 'remote_id' in dash.params_dict:
session.delete(dash)
for table in session.query(models.SqlaTable):
for table in session.query(SqlaTable):
if 'remote_id' in table.params_dict:
session.delete(table)
for datasource in session.query(models.DruidDatasource):
for datasource in session.query(DruidDatasource):
if 'remote_id' in datasource.params_dict:
session.delete(datasource)
session.commit()
@ -90,7 +94,7 @@ class ImportExportTests(SupersetTestCase):
def create_table(
self, name, schema='', id=0, cols_names=[], metric_names=[]):
params = {'remote_id': id, 'database_name': 'main'}
table = models.SqlaTable(
table = SqlaTable(
id=id,
schema=schema,
table_name=name,
@ -98,15 +102,15 @@ class ImportExportTests(SupersetTestCase):
)
for col_name in cols_names:
table.columns.append(
models.TableColumn(column_name=col_name))
TableColumn(column_name=col_name))
for metric_name in metric_names:
table.metrics.append(models.SqlMetric(metric_name=metric_name))
table.metrics.append(SqlMetric(metric_name=metric_name))
return table
def create_druid_datasource(
self, name, id=0, cols_names=[], metric_names=[]):
params = {'remote_id': id, 'database_name': 'druid_test'}
datasource = models.DruidDatasource(
datasource = DruidDatasource(
id=id,
datasource_name=name,
cluster_name='druid_test',
@ -114,9 +118,9 @@ class ImportExportTests(SupersetTestCase):
)
for col_name in cols_names:
datasource.columns.append(
models.DruidColumn(column_name=col_name))
DruidColumn(column_name=col_name))
for metric_name in metric_names:
datasource.metrics.append(models.DruidMetric(
datasource.metrics.append(DruidMetric(
metric_name=metric_name))
return datasource
@ -136,11 +140,11 @@ class ImportExportTests(SupersetTestCase):
slug=dash_slug).first()
def get_datasource(self, datasource_id):
return db.session.query(models.DruidDatasource).filter_by(
return db.session.query(DruidDatasource).filter_by(
id=datasource_id).first()
def get_table_by_name(self, name):
return db.session.query(models.SqlaTable).filter_by(
return db.session.query(SqlaTable).filter_by(
table_name=name).first()
def assert_dash_equals(self, expected_dash, actual_dash,
@ -392,7 +396,7 @@ class ImportExportTests(SupersetTestCase):
def test_import_table_no_metadata(self):
table = self.create_table('pure_table', id=10001)
imported_id = models.SqlaTable.import_obj(table, import_time=1989)
imported_id = SqlaTable.import_obj(table, import_time=1989)
imported = self.get_table(imported_id)
self.assert_table_equals(table, imported)
@ -400,7 +404,7 @@ class ImportExportTests(SupersetTestCase):
table = self.create_table(
'table_1_col_1_met', id=10002,
cols_names=["col1"], metric_names=["metric1"])
imported_id = models.SqlaTable.import_obj(table, import_time=1990)
imported_id = SqlaTable.import_obj(table, import_time=1990)
imported = self.get_table(imported_id)
self.assert_table_equals(table, imported)
self.assertEquals(
@ -411,7 +415,7 @@ class ImportExportTests(SupersetTestCase):
table = self.create_table(
'table_2_col_2_met', id=10003, cols_names=['c1', 'c2'],
metric_names=['m1', 'm2'])
imported_id = models.SqlaTable.import_obj(table, import_time=1991)
imported_id = SqlaTable.import_obj(table, import_time=1991)
imported = self.get_table(imported_id)
self.assert_table_equals(table, imported)
@ -420,12 +424,12 @@ class ImportExportTests(SupersetTestCase):
table = self.create_table(
'table_override', id=10003, cols_names=['col1'],
metric_names=['m1'])
imported_id = models.SqlaTable.import_obj(table, import_time=1991)
imported_id = SqlaTable.import_obj(table, import_time=1991)
table_over = self.create_table(
'table_override', id=10003, cols_names=['new_col1', 'col2', 'col3'],
metric_names=['new_metric1'])
imported_over_id = models.SqlaTable.import_obj(
imported_over_id = SqlaTable.import_obj(
table_over, import_time=1992)
imported_over = self.get_table(imported_over_id)
@ -439,12 +443,12 @@ class ImportExportTests(SupersetTestCase):
table = self.create_table(
'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
metric_names=['new_metric1'])
imported_id = models.SqlaTable.import_obj(table, import_time=1993)
imported_id = SqlaTable.import_obj(table, import_time=1993)
copy_table = self.create_table(
'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
metric_names=['new_metric1'])
imported_id_copy = models.SqlaTable.import_obj(
imported_id_copy = SqlaTable.import_obj(
copy_table, import_time=1994)
self.assertEquals(imported_id, imported_id_copy)
@ -452,7 +456,7 @@ class ImportExportTests(SupersetTestCase):
def test_import_druid_no_metadata(self):
datasource = self.create_druid_datasource('pure_druid', id=10001)
imported_id = models.DruidDatasource.import_obj(
imported_id = DruidDatasource.import_obj(
datasource, import_time=1989)
imported = self.get_datasource(imported_id)
self.assert_datasource_equals(datasource, imported)
@ -461,7 +465,7 @@ class ImportExportTests(SupersetTestCase):
datasource = self.create_druid_datasource(
'druid_1_col_1_met', id=10002,
cols_names=["col1"], metric_names=["metric1"])
imported_id = models.DruidDatasource.import_obj(
imported_id = DruidDatasource.import_obj(
datasource, import_time=1990)
imported = self.get_datasource(imported_id)
self.assert_datasource_equals(datasource, imported)
@ -474,7 +478,7 @@ class ImportExportTests(SupersetTestCase):
datasource = self.create_druid_datasource(
'druid_2_col_2_met', id=10003, cols_names=['c1', 'c2'],
metric_names=['m1', 'm2'])
imported_id = models.DruidDatasource.import_obj(
imported_id = DruidDatasource.import_obj(
datasource, import_time=1991)
imported = self.get_datasource(imported_id)
self.assert_datasource_equals(datasource, imported)
@ -483,14 +487,14 @@ class ImportExportTests(SupersetTestCase):
datasource = self.create_druid_datasource(
'druid_override', id=10003, cols_names=['col1'],
metric_names=['m1'])
imported_id = models.DruidDatasource.import_obj(
imported_id = DruidDatasource.import_obj(
datasource, import_time=1991)
table_over = self.create_druid_datasource(
'druid_override', id=10003,
cols_names=['new_col1', 'col2', 'col3'],
metric_names=['new_metric1'])
imported_over_id = models.DruidDatasource.import_obj(
imported_over_id = DruidDatasource.import_obj(
table_over, import_time=1992)
imported_over = self.get_datasource(imported_over_id)
@ -504,13 +508,13 @@ class ImportExportTests(SupersetTestCase):
datasource = self.create_druid_datasource(
'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
metric_names=['new_metric1'])
imported_id = models.DruidDatasource.import_obj(
imported_id = DruidDatasource.import_obj(
datasource, import_time=1993)
copy_datasource = self.create_druid_datasource(
'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
metric_names=['new_metric1'])
imported_id_copy = models.DruidDatasource.import_obj(
imported_id_copy = DruidDatasource.import_obj(
copy_datasource, import_time=1994)
self.assertEquals(imported_id, imported_id_copy)
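The override and copy-import cases above all hinge on the remote_id stored in each object's JSON params: the cleanup loop near the top of the file finds test artifacts by that key, and importing a second object with the same remote_id is expected to resolve to the already-imported row, which is what the assertEquals on the two ids verifies. A small sketch of that convention, with constructor keywords assumed to match the create_table helper above:

import json
from superset.connectors.sqla.models import SqlaTable

table = SqlaTable(
    table_name='pure_table',
    params=json.dumps({'remote_id': 10001, 'database_name': 'main'}),
)
# params_dict parses the JSON params, so objects created by the import tests
# can be found (and cleaned up) by the presence of 'remote_id'.
assert 'remote_id' in table.params_dict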

View File

@ -2,7 +2,7 @@ import unittest
from sqlalchemy.engine.url import make_url
from superset.models import Database
from superset.models.core import Database
class DatabaseModelTestCase(unittest.TestCase):

View File

@ -9,7 +9,9 @@ import json
import unittest
from flask_appbuilder.security.sqla import models as ab_models
from superset import db, models, utils, appbuilder, security, sm
from superset import db, utils, appbuilder, sm
from superset.models import core as models
from .base_tests import SupersetTestCase
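Taken together, the test updates in this section converge on a single import layout. As far as these hunks show, it is:

# Import layout the updated tests rely on (collected from the hunks above):
from superset.models import core as models          # Dashboard, Database, ...
from superset.models.helpers import QueryStatus     # shared model helpers
from superset.connectors.sqla.models import (       # SQLAlchemy connector models
    SqlaTable, TableColumn, SqlMetric)
from superset.connectors.druid.models import (      # Druid connector models
    DruidCluster, DruidDatasource, DruidColumn, DruidMetric)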