mirror of https://github.com/apache/superset.git
Refactoring Druid & SQLa into a proper "Connector" interface (#2362)
* Formalizing the Connector interface * Checkpoint * Fixing views * Fixing tests * Adding migrtion * Tests * Final * Addressing comments
This commit is contained in:
parent
9a8c3a0447
commit
2969cc9993
|
@ -13,9 +13,9 @@ from flask import Flask, redirect
|
|||
from flask_appbuilder import SQLA, AppBuilder, IndexView
|
||||
from flask_appbuilder.baseviews import expose
|
||||
from flask_migrate import Migrate
|
||||
from superset.source_registry import SourceRegistry
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from werkzeug.contrib.fixers import ProxyFix
|
||||
from superset import utils
|
||||
from superset import utils, config # noqa
|
||||
|
||||
|
||||
APP_DIR = os.path.dirname(__file__)
|
||||
|
@ -104,6 +104,6 @@ results_backend = app.config.get("RESULTS_BACKEND")
|
|||
# Registering sources
|
||||
module_datasource_map = app.config.get("DEFAULT_MODULE_DS_MAP")
|
||||
module_datasource_map.update(app.config.get("ADDITIONAL_MODULE_DS_MAP"))
|
||||
SourceRegistry.register_sources(module_datasource_map)
|
||||
ConnectorRegistry.register_sources(module_datasource_map)
|
||||
|
||||
from superset import views, config # noqa
|
||||
from superset import views # noqa
|
||||
|
|
|
@ -13,7 +13,7 @@ from subprocess import Popen
|
|||
from flask_migrate import MigrateCommand
|
||||
from flask_script import Manager
|
||||
|
||||
from superset import app, db, data, security
|
||||
from superset import app, db, security
|
||||
|
||||
config = app.config
|
||||
|
||||
|
@ -89,6 +89,7 @@ def version(verbose):
|
|||
help="Load additional test data")
|
||||
def load_examples(load_test_data):
|
||||
"""Loads a set of Slices and Dashboards and a supporting dataset """
|
||||
from superset import data
|
||||
print("Loading examples into {}".format(db))
|
||||
|
||||
data.load_css_templates()
|
||||
|
|
|
@ -178,7 +178,10 @@ DRUID_DATA_SOURCE_BLACKLIST = []
|
|||
# --------------------------------------------------
|
||||
# Modules, datasources and middleware to be registered
|
||||
# --------------------------------------------------
|
||||
DEFAULT_MODULE_DS_MAP = {'superset.models': ['DruidDatasource', 'SqlaTable']}
|
||||
DEFAULT_MODULE_DS_MAP = {
|
||||
'superset.connectors.druid.models': ['DruidDatasource'],
|
||||
'superset.connectors.sqla.models': ['SqlaTable'],
|
||||
}
|
||||
ADDITIONAL_MODULE_DS_MAP = {}
|
||||
ADDITIONAL_MIDDLEWARE = []
|
||||
|
||||
|
@ -292,14 +295,17 @@ SILENCE_FAB = True
|
|||
BLUEPRINTS = []
|
||||
|
||||
try:
|
||||
|
||||
if CONFIG_PATH_ENV_VAR in os.environ:
|
||||
# Explicitly import config module that is not in pythonpath; useful
|
||||
# for case where app is being executed via pex.
|
||||
print('Loaded your LOCAL configuration at [{}]'.format(
|
||||
os.environ[CONFIG_PATH_ENV_VAR]))
|
||||
imp.load_source('superset_config', os.environ[CONFIG_PATH_ENV_VAR])
|
||||
|
||||
from superset_config import * # noqa
|
||||
import superset_config
|
||||
print('Loaded your LOCAL configuration at [{}]'.format(
|
||||
superset_config.__file__))
|
||||
else:
|
||||
from superset_config import * # noqa
|
||||
import superset_config
|
||||
print('Loaded your LOCAL configuration at [{}]'.format(
|
||||
superset_config.__file__))
|
||||
except ImportError:
|
||||
pass
|
||||
|
|
|
@ -0,0 +1,160 @@
|
|||
import json
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Text, Boolean
|
||||
|
||||
from superset import utils
|
||||
from superset.models.helpers import AuditMixinNullable, ImportMixin
|
||||
|
||||
|
||||
class BaseDatasource(AuditMixinNullable, ImportMixin):
|
||||
|
||||
"""A common interface to objects that are queryable (tables and datasources)"""
|
||||
|
||||
__tablename__ = None # {connector_name}_datasource
|
||||
|
||||
# Used to do code highlighting when displaying the query in the UI
|
||||
query_language = None
|
||||
|
||||
@property
|
||||
def column_names(self):
|
||||
return sorted([c.column_name for c in self.columns])
|
||||
|
||||
@property
|
||||
def main_dttm_col(self):
|
||||
return "timestamp"
|
||||
|
||||
@property
|
||||
def groupby_column_names(self):
|
||||
return sorted([c.column_name for c in self.columns if c.groupby])
|
||||
|
||||
@property
|
||||
def filterable_column_names(self):
|
||||
return sorted([c.column_name for c in self.columns if c.filterable])
|
||||
|
||||
@property
|
||||
def dttm_cols(self):
|
||||
return []
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
return '/{}/edit/{}'.format(self.baselink, self.id)
|
||||
|
||||
@property
|
||||
def explore_url(self):
|
||||
if self.default_endpoint:
|
||||
return self.default_endpoint
|
||||
else:
|
||||
return "/superset/explore/{obj.type}/{obj.id}/".format(obj=self)
|
||||
|
||||
@property
|
||||
def column_formats(self):
|
||||
return {
|
||||
m.metric_name: m.d3format
|
||||
for m in self.metrics
|
||||
if m.d3format
|
||||
}
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
"""Data representation of the datasource sent to the frontend"""
|
||||
order_by_choices = []
|
||||
for s in sorted(self.column_names):
|
||||
order_by_choices.append((json.dumps([s, True]), s + ' [asc]'))
|
||||
order_by_choices.append((json.dumps([s, False]), s + ' [desc]'))
|
||||
|
||||
d = {
|
||||
'all_cols': utils.choicify(self.column_names),
|
||||
'column_formats': self.column_formats,
|
||||
'edit_url': self.url,
|
||||
'filter_select': self.filter_select_enabled,
|
||||
'filterable_cols': utils.choicify(self.filterable_column_names),
|
||||
'gb_cols': utils.choicify(self.groupby_column_names),
|
||||
'id': self.id,
|
||||
'metrics_combo': self.metrics_combo,
|
||||
'name': self.name,
|
||||
'order_by_choices': order_by_choices,
|
||||
'type': self.type,
|
||||
}
|
||||
|
||||
# TODO move this block to SqlaTable.data
|
||||
if self.type == 'table':
|
||||
grains = self.database.grains() or []
|
||||
if grains:
|
||||
grains = [(g.name, g.name) for g in grains]
|
||||
d['granularity_sqla'] = utils.choicify(self.dttm_cols)
|
||||
d['time_grain_sqla'] = grains
|
||||
return d
|
||||
|
||||
|
||||
class BaseColumn(AuditMixinNullable, ImportMixin):
|
||||
"""Interface for column"""
|
||||
|
||||
__tablename__ = None # {connector_name}_column
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
column_name = Column(String(255))
|
||||
verbose_name = Column(String(1024))
|
||||
is_active = Column(Boolean, default=True)
|
||||
type = Column(String(32))
|
||||
groupby = Column(Boolean, default=False)
|
||||
count_distinct = Column(Boolean, default=False)
|
||||
sum = Column(Boolean, default=False)
|
||||
avg = Column(Boolean, default=False)
|
||||
max = Column(Boolean, default=False)
|
||||
min = Column(Boolean, default=False)
|
||||
filterable = Column(Boolean, default=False)
|
||||
description = Column(Text)
|
||||
|
||||
# [optional] Set this to support import/export functionality
|
||||
export_fields = []
|
||||
|
||||
def __repr__(self):
|
||||
return self.column_name
|
||||
|
||||
num_types = ('DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'LONG', 'REAL', 'NUMERIC')
|
||||
date_types = ('DATE', 'TIME', 'DATETIME')
|
||||
str_types = ('VARCHAR', 'STRING', 'CHAR')
|
||||
|
||||
@property
|
||||
def is_num(self):
|
||||
return any([t in self.type.upper() for t in self.num_types])
|
||||
|
||||
@property
|
||||
def is_time(self):
|
||||
return any([t in self.type.upper() for t in self.date_types])
|
||||
|
||||
@property
|
||||
def is_string(self):
|
||||
return any([t in self.type.upper() for t in self.str_types])
|
||||
|
||||
|
||||
class BaseMetric(AuditMixinNullable, ImportMixin):
|
||||
|
||||
"""Interface for Metrics"""
|
||||
|
||||
__tablename__ = None # {connector_name}_metric
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
metric_name = Column(String(512))
|
||||
verbose_name = Column(String(1024))
|
||||
metric_type = Column(String(32))
|
||||
description = Column(Text)
|
||||
is_restricted = Column(Boolean, default=False, nullable=True)
|
||||
d3format = Column(String(128))
|
||||
|
||||
"""
|
||||
The interface should also declare a datasource relationship pointing
|
||||
to a derivative of BaseDatasource, along with a FK
|
||||
|
||||
datasource_name = Column(
|
||||
String(255),
|
||||
ForeignKey('datasources.datasource_name'))
|
||||
datasource = relationship(
|
||||
# needs to be altered to point to {Connector}Datasource
|
||||
'BaseDatasource',
|
||||
backref=backref('metrics', cascade='all, delete-orphan'),
|
||||
enable_typechecks=False)
|
||||
"""
|
||||
@property
|
||||
def perm(self):
|
||||
raise NotImplementedError()
|
|
@ -1,7 +1,7 @@
|
|||
from sqlalchemy.orm import subqueryload
|
||||
|
||||
|
||||
class SourceRegistry(object):
|
||||
class ConnectorRegistry(object):
|
||||
""" Central Registry for all available datasource engines"""
|
||||
|
||||
sources = {}
|
||||
|
@ -26,15 +26,15 @@ class SourceRegistry(object):
|
|||
@classmethod
|
||||
def get_all_datasources(cls, session):
|
||||
datasources = []
|
||||
for source_type in SourceRegistry.sources:
|
||||
for source_type in ConnectorRegistry.sources:
|
||||
datasources.extend(
|
||||
session.query(SourceRegistry.sources[source_type]).all())
|
||||
session.query(ConnectorRegistry.sources[source_type]).all())
|
||||
return datasources
|
||||
|
||||
@classmethod
|
||||
def get_datasource_by_name(cls, session, datasource_type, datasource_name,
|
||||
schema, database_name):
|
||||
datasource_class = SourceRegistry.sources[datasource_type]
|
||||
datasource_class = ConnectorRegistry.sources[datasource_type]
|
||||
datasources = session.query(datasource_class).all()
|
||||
|
||||
# Filter datasoures that don't have database.
|
||||
|
@ -45,7 +45,7 @@ class SourceRegistry(object):
|
|||
|
||||
@classmethod
|
||||
def query_datasources_by_permissions(cls, session, database, permissions):
|
||||
datasource_class = SourceRegistry.sources[database.type]
|
||||
datasource_class = ConnectorRegistry.sources[database.type]
|
||||
return (
|
||||
session.query(datasource_class)
|
||||
.filter_by(database_id=database.id)
|
||||
|
@ -56,7 +56,7 @@ class SourceRegistry(object):
|
|||
@classmethod
|
||||
def get_eager_datasource(cls, session, datasource_type, datasource_id):
|
||||
"""Returns datasource with columns and metrics."""
|
||||
datasource_class = SourceRegistry.sources[datasource_type]
|
||||
datasource_class = ConnectorRegistry.sources[datasource_type]
|
||||
return (
|
||||
session.query(datasource_class)
|
||||
.options(
|
|
@ -0,0 +1,2 @@
|
|||
from . import models # noqa
|
||||
from . import views # noqa
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,203 @@
|
|||
import sqlalchemy as sqla
|
||||
|
||||
from flask import Markup
|
||||
from flask_appbuilder import CompactCRUDMixin
|
||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||
|
||||
from flask_babel import lazy_gettext as _
|
||||
from flask_babel import gettext as __
|
||||
|
||||
import superset
|
||||
from superset import db, utils, appbuilder, sm, security
|
||||
from superset.views.base import (
|
||||
SupersetModelView, validate_json, DeleteMixin, ListWidgetWithCheckboxes,
|
||||
DatasourceFilter, get_datasource_exist_error_mgs)
|
||||
|
||||
from . import models
|
||||
|
||||
|
||||
class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView): # noqa
|
||||
datamodel = SQLAInterface(models.DruidColumn)
|
||||
edit_columns = [
|
||||
'column_name', 'description', 'dimension_spec_json', 'datasource',
|
||||
'groupby', 'count_distinct', 'sum', 'min', 'max']
|
||||
add_columns = edit_columns
|
||||
list_columns = [
|
||||
'column_name', 'type', 'groupby', 'filterable', 'count_distinct',
|
||||
'sum', 'min', 'max']
|
||||
can_delete = False
|
||||
page_size = 500
|
||||
label_columns = {
|
||||
'column_name': _("Column"),
|
||||
'type': _("Type"),
|
||||
'datasource': _("Datasource"),
|
||||
'groupby': _("Groupable"),
|
||||
'filterable': _("Filterable"),
|
||||
'count_distinct': _("Count Distinct"),
|
||||
'sum': _("Sum"),
|
||||
'min': _("Min"),
|
||||
'max': _("Max"),
|
||||
}
|
||||
description_columns = {
|
||||
'dimension_spec_json': utils.markdown(
|
||||
"this field can be used to specify "
|
||||
"a `dimensionSpec` as documented [here]"
|
||||
"(http://druid.io/docs/latest/querying/dimensionspecs.html). "
|
||||
"Make sure to input valid JSON and that the "
|
||||
"`outputName` matches the `column_name` defined "
|
||||
"above.",
|
||||
True),
|
||||
}
|
||||
|
||||
def post_update(self, col):
|
||||
col.generate_metrics()
|
||||
utils.validate_json(col.dimension_spec_json)
|
||||
|
||||
def post_add(self, col):
|
||||
self.post_update(col)
|
||||
|
||||
appbuilder.add_view_no_menu(DruidColumnInlineView)
|
||||
|
||||
|
||||
class DruidMetricInlineView(CompactCRUDMixin, SupersetModelView): # noqa
|
||||
datamodel = SQLAInterface(models.DruidMetric)
|
||||
list_columns = ['metric_name', 'verbose_name', 'metric_type']
|
||||
edit_columns = [
|
||||
'metric_name', 'description', 'verbose_name', 'metric_type', 'json',
|
||||
'datasource', 'd3format', 'is_restricted']
|
||||
add_columns = edit_columns
|
||||
page_size = 500
|
||||
validators_columns = {
|
||||
'json': [validate_json],
|
||||
}
|
||||
description_columns = {
|
||||
'metric_type': utils.markdown(
|
||||
"use `postagg` as the metric type if you are defining a "
|
||||
"[Druid Post Aggregation]"
|
||||
"(http://druid.io/docs/latest/querying/post-aggregations.html)",
|
||||
True),
|
||||
'is_restricted': _("Whether the access to this metric is restricted "
|
||||
"to certain roles. Only roles with the permission "
|
||||
"'metric access on XXX (the name of this metric)' "
|
||||
"are allowed to access this metric"),
|
||||
}
|
||||
label_columns = {
|
||||
'metric_name': _("Metric"),
|
||||
'description': _("Description"),
|
||||
'verbose_name': _("Verbose Name"),
|
||||
'metric_type': _("Type"),
|
||||
'json': _("JSON"),
|
||||
'datasource': _("Druid Datasource"),
|
||||
}
|
||||
|
||||
def post_add(self, metric):
|
||||
utils.init_metrics_perm(superset, [metric])
|
||||
|
||||
def post_update(self, metric):
|
||||
utils.init_metrics_perm(superset, [metric])
|
||||
|
||||
|
||||
appbuilder.add_view_no_menu(DruidMetricInlineView)
|
||||
|
||||
|
||||
class DruidClusterModelView(SupersetModelView, DeleteMixin): # noqa
|
||||
datamodel = SQLAInterface(models.DruidCluster)
|
||||
add_columns = [
|
||||
'cluster_name',
|
||||
'coordinator_host', 'coordinator_port', 'coordinator_endpoint',
|
||||
'broker_host', 'broker_port', 'broker_endpoint', 'cache_timeout',
|
||||
]
|
||||
edit_columns = add_columns
|
||||
list_columns = ['cluster_name', 'metadata_last_refreshed']
|
||||
label_columns = {
|
||||
'cluster_name': _("Cluster"),
|
||||
'coordinator_host': _("Coordinator Host"),
|
||||
'coordinator_port': _("Coordinator Port"),
|
||||
'coordinator_endpoint': _("Coordinator Endpoint"),
|
||||
'broker_host': _("Broker Host"),
|
||||
'broker_port': _("Broker Port"),
|
||||
'broker_endpoint': _("Broker Endpoint"),
|
||||
}
|
||||
|
||||
def pre_add(self, cluster):
|
||||
security.merge_perm(sm, 'database_access', cluster.perm)
|
||||
|
||||
def pre_update(self, cluster):
|
||||
self.pre_add(cluster)
|
||||
|
||||
|
||||
appbuilder.add_view(
|
||||
DruidClusterModelView,
|
||||
name="Druid Clusters",
|
||||
label=__("Druid Clusters"),
|
||||
icon="fa-cubes",
|
||||
category="Sources",
|
||||
category_label=__("Sources"),
|
||||
category_icon='fa-database',)
|
||||
|
||||
|
||||
class DruidDatasourceModelView(SupersetModelView, DeleteMixin): # noqa
|
||||
datamodel = SQLAInterface(models.DruidDatasource)
|
||||
list_widget = ListWidgetWithCheckboxes
|
||||
list_columns = [
|
||||
'datasource_link', 'cluster', 'changed_by_', 'changed_on_', 'offset']
|
||||
order_columns = [
|
||||
'datasource_link', 'changed_on_', 'offset']
|
||||
related_views = [DruidColumnInlineView, DruidMetricInlineView]
|
||||
edit_columns = [
|
||||
'datasource_name', 'cluster', 'description', 'owner',
|
||||
'is_featured', 'is_hidden', 'filter_select_enabled',
|
||||
'default_endpoint', 'offset', 'cache_timeout']
|
||||
add_columns = edit_columns
|
||||
show_columns = add_columns + ['perm']
|
||||
page_size = 500
|
||||
base_order = ('datasource_name', 'asc')
|
||||
description_columns = {
|
||||
'offset': _("Timezone offset (in hours) for this datasource"),
|
||||
'description': Markup(
|
||||
"Supports <a href='"
|
||||
"https://daringfireball.net/projects/markdown/'>markdown</a>"),
|
||||
}
|
||||
base_filters = [['id', DatasourceFilter, lambda: []]]
|
||||
label_columns = {
|
||||
'datasource_link': _("Data Source"),
|
||||
'cluster': _("Cluster"),
|
||||
'description': _("Description"),
|
||||
'owner': _("Owner"),
|
||||
'is_featured': _("Is Featured"),
|
||||
'is_hidden': _("Is Hidden"),
|
||||
'filter_select_enabled': _("Enable Filter Select"),
|
||||
'default_endpoint': _("Default Endpoint"),
|
||||
'offset': _("Time Offset"),
|
||||
'cache_timeout': _("Cache Timeout"),
|
||||
}
|
||||
|
||||
def pre_add(self, datasource):
|
||||
number_of_existing_datasources = db.session.query(
|
||||
sqla.func.count('*')).filter(
|
||||
models.DruidDatasource.datasource_name ==
|
||||
datasource.datasource_name,
|
||||
models.DruidDatasource.cluster_name == datasource.cluster.id
|
||||
).scalar()
|
||||
|
||||
# table object is already added to the session
|
||||
if number_of_existing_datasources > 1:
|
||||
raise Exception(get_datasource_exist_error_mgs(
|
||||
datasource.full_name))
|
||||
|
||||
def post_add(self, datasource):
|
||||
datasource.generate_metrics()
|
||||
security.merge_perm(sm, 'datasource_access', datasource.get_perm())
|
||||
if datasource.schema:
|
||||
security.merge_perm(sm, 'schema_access', datasource.schema_perm)
|
||||
|
||||
def post_update(self, datasource):
|
||||
self.post_add(datasource)
|
||||
|
||||
appbuilder.add_view(
|
||||
DruidDatasourceModelView,
|
||||
"Druid Datasources",
|
||||
label=__("Druid Datasources"),
|
||||
category="Sources",
|
||||
category_label=__("Sources"),
|
||||
icon="fa-cube")
|
|
@ -0,0 +1,2 @@
|
|||
from . import models # noqa
|
||||
from . import views # noqa
|
|
@ -0,0 +1,664 @@
|
|||
from datetime import datetime
|
||||
import logging
|
||||
import sqlparse
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from sqlalchemy import (
|
||||
Column, Integer, String, ForeignKey, Text, Boolean,
|
||||
DateTime,
|
||||
)
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import asc, and_, desc, select
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.sql.expression import ColumnClause, TextAsFrom
|
||||
from sqlalchemy.orm import backref, relationship
|
||||
from sqlalchemy.sql import table, literal_column, text, column
|
||||
|
||||
from flask import escape, Markup
|
||||
from flask_appbuilder import Model
|
||||
from flask_babel import lazy_gettext as _
|
||||
|
||||
from superset import db, utils, import_util
|
||||
from superset.connectors.base import BaseDatasource, BaseColumn, BaseMetric
|
||||
from superset.utils import (
|
||||
wrap_clause_in_parens,
|
||||
DTTM_ALIAS, QueryStatus
|
||||
)
|
||||
from superset.models.helpers import QueryResult
|
||||
from superset.models.core import Database
|
||||
from superset.jinja_context import get_template_processor
|
||||
from superset.models.helpers import set_perm
|
||||
|
||||
|
||||
class TableColumn(Model, BaseColumn):
|
||||
|
||||
"""ORM object for table columns, each table can have multiple columns"""
|
||||
|
||||
__tablename__ = 'table_columns'
|
||||
table_id = Column(Integer, ForeignKey('tables.id'))
|
||||
table = relationship(
|
||||
'SqlaTable',
|
||||
backref=backref('columns', cascade='all, delete-orphan'),
|
||||
foreign_keys=[table_id])
|
||||
is_dttm = Column(Boolean, default=False)
|
||||
expression = Column(Text, default='')
|
||||
python_date_format = Column(String(255))
|
||||
database_expression = Column(String(255))
|
||||
|
||||
export_fields = (
|
||||
'table_id', 'column_name', 'verbose_name', 'is_dttm', 'is_active',
|
||||
'type', 'groupby', 'count_distinct', 'sum', 'avg', 'max', 'min',
|
||||
'filterable', 'expression', 'description', 'python_date_format',
|
||||
'database_expression'
|
||||
)
|
||||
|
||||
@property
|
||||
def sqla_col(self):
|
||||
name = self.column_name
|
||||
if not self.expression:
|
||||
col = column(self.column_name).label(name)
|
||||
else:
|
||||
col = literal_column(self.expression).label(name)
|
||||
return col
|
||||
|
||||
def get_time_filter(self, start_dttm, end_dttm):
|
||||
col = self.sqla_col.label('__time')
|
||||
return and_(
|
||||
col >= text(self.dttm_sql_literal(start_dttm)),
|
||||
col <= text(self.dttm_sql_literal(end_dttm)),
|
||||
)
|
||||
|
||||
def get_timestamp_expression(self, time_grain):
|
||||
"""Getting the time component of the query"""
|
||||
expr = self.expression or self.column_name
|
||||
if not self.expression and not time_grain:
|
||||
return column(expr, type_=DateTime).label(DTTM_ALIAS)
|
||||
if time_grain:
|
||||
pdf = self.python_date_format
|
||||
if pdf in ('epoch_s', 'epoch_ms'):
|
||||
# if epoch, translate to DATE using db specific conf
|
||||
db_spec = self.table.database.db_engine_spec
|
||||
if pdf == 'epoch_s':
|
||||
expr = db_spec.epoch_to_dttm().format(col=expr)
|
||||
elif pdf == 'epoch_ms':
|
||||
expr = db_spec.epoch_ms_to_dttm().format(col=expr)
|
||||
grain = self.table.database.grains_dict().get(time_grain, '{col}')
|
||||
expr = grain.function.format(col=expr)
|
||||
return literal_column(expr, type_=DateTime).label(DTTM_ALIAS)
|
||||
|
||||
@classmethod
|
||||
def import_obj(cls, i_column):
|
||||
def lookup_obj(lookup_column):
|
||||
return db.session.query(TableColumn).filter(
|
||||
TableColumn.table_id == lookup_column.table_id,
|
||||
TableColumn.column_name == lookup_column.column_name).first()
|
||||
return import_util.import_simple_obj(db.session, i_column, lookup_obj)
|
||||
|
||||
def dttm_sql_literal(self, dttm):
|
||||
"""Convert datetime object to a SQL expression string
|
||||
|
||||
If database_expression is empty, the internal dttm
|
||||
will be parsed as the string with the pattern that
|
||||
the user inputted (python_date_format)
|
||||
If database_expression is not empty, the internal dttm
|
||||
will be parsed as the sql sentence for the database to convert
|
||||
"""
|
||||
|
||||
tf = self.python_date_format or '%Y-%m-%d %H:%M:%S.%f'
|
||||
if self.database_expression:
|
||||
return self.database_expression.format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
|
||||
elif tf == 'epoch_s':
|
||||
return str((dttm - datetime(1970, 1, 1)).total_seconds())
|
||||
elif tf == 'epoch_ms':
|
||||
return str((dttm - datetime(1970, 1, 1)).total_seconds() * 1000.0)
|
||||
else:
|
||||
s = self.table.database.db_engine_spec.convert_dttm(
|
||||
self.type, dttm)
|
||||
return s or "'{}'".format(dttm.strftime(tf))
|
||||
|
||||
|
||||
class SqlMetric(Model, BaseMetric):
|
||||
|
||||
"""ORM object for metrics, each table can have multiple metrics"""
|
||||
|
||||
__tablename__ = 'sql_metrics'
|
||||
table_id = Column(Integer, ForeignKey('tables.id'))
|
||||
table = relationship(
|
||||
'SqlaTable',
|
||||
backref=backref('metrics', cascade='all, delete-orphan'),
|
||||
foreign_keys=[table_id])
|
||||
expression = Column(Text)
|
||||
|
||||
export_fields = (
|
||||
'metric_name', 'verbose_name', 'metric_type', 'table_id', 'expression',
|
||||
'description', 'is_restricted', 'd3format')
|
||||
|
||||
@property
|
||||
def sqla_col(self):
|
||||
name = self.metric_name
|
||||
return literal_column(self.expression).label(name)
|
||||
|
||||
@property
|
||||
def perm(self):
|
||||
return (
|
||||
"{parent_name}.[{obj.metric_name}](id:{obj.id})"
|
||||
).format(obj=self,
|
||||
parent_name=self.table.full_name) if self.table else None
|
||||
|
||||
@classmethod
|
||||
def import_obj(cls, i_metric):
|
||||
def lookup_obj(lookup_metric):
|
||||
return db.session.query(SqlMetric).filter(
|
||||
SqlMetric.table_id == lookup_metric.table_id,
|
||||
SqlMetric.metric_name == lookup_metric.metric_name).first()
|
||||
return import_util.import_simple_obj(db.session, i_metric, lookup_obj)
|
||||
|
||||
|
||||
class SqlaTable(Model, BaseDatasource):
|
||||
|
||||
"""An ORM object for SqlAlchemy table references"""
|
||||
|
||||
type = "table"
|
||||
query_language = 'sql'
|
||||
metric_class = SqlMetric
|
||||
|
||||
__tablename__ = 'tables'
|
||||
id = Column(Integer, primary_key=True)
|
||||
table_name = Column(String(250))
|
||||
main_dttm_col = Column(String(250))
|
||||
description = Column(Text)
|
||||
default_endpoint = Column(Text)
|
||||
database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False)
|
||||
is_featured = Column(Boolean, default=False)
|
||||
filter_select_enabled = Column(Boolean, default=False)
|
||||
user_id = Column(Integer, ForeignKey('ab_user.id'))
|
||||
owner = relationship('User', backref='tables', foreign_keys=[user_id])
|
||||
database = relationship(
|
||||
'Database',
|
||||
backref=backref('tables', cascade='all, delete-orphan'),
|
||||
foreign_keys=[database_id])
|
||||
offset = Column(Integer, default=0)
|
||||
cache_timeout = Column(Integer)
|
||||
schema = Column(String(255))
|
||||
sql = Column(Text)
|
||||
params = Column(Text)
|
||||
perm = Column(String(1000))
|
||||
|
||||
baselink = "tablemodelview"
|
||||
column_cls = TableColumn
|
||||
metric_cls = SqlMetric
|
||||
export_fields = (
|
||||
'table_name', 'main_dttm_col', 'description', 'default_endpoint',
|
||||
'database_id', 'is_featured', 'offset', 'cache_timeout', 'schema',
|
||||
'sql', 'params')
|
||||
|
||||
__table_args__ = (
|
||||
sa.UniqueConstraint(
|
||||
'database_id', 'schema', 'table_name',
|
||||
name='_customer_location_uc'),)
|
||||
|
||||
def __repr__(self):
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def description_markeddown(self):
|
||||
return utils.markdown(self.description)
|
||||
|
||||
@property
|
||||
def link(self):
|
||||
name = escape(self.name)
|
||||
return Markup(
|
||||
'<a href="{self.explore_url}">{name}</a>'.format(**locals()))
|
||||
|
||||
@property
|
||||
def schema_perm(self):
|
||||
"""Returns schema permission if present, database one otherwise."""
|
||||
return utils.get_schema_perm(self.database, self.schema)
|
||||
|
||||
def get_perm(self):
|
||||
return (
|
||||
"[{obj.database}].[{obj.table_name}]"
|
||||
"(id:{obj.id})").format(obj=self)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
if not self.schema:
|
||||
return self.table_name
|
||||
return "{}.{}".format(self.schema, self.table_name)
|
||||
|
||||
@property
|
||||
def full_name(self):
|
||||
return utils.get_datasource_full_name(
|
||||
self.database, self.table_name, schema=self.schema)
|
||||
|
||||
@property
|
||||
def dttm_cols(self):
|
||||
l = [c.column_name for c in self.columns if c.is_dttm]
|
||||
if self.main_dttm_col and self.main_dttm_col not in l:
|
||||
l.append(self.main_dttm_col)
|
||||
return l
|
||||
|
||||
@property
|
||||
def num_cols(self):
|
||||
return [c.column_name for c in self.columns if c.is_num]
|
||||
|
||||
@property
|
||||
def any_dttm_col(self):
|
||||
cols = self.dttm_cols
|
||||
if cols:
|
||||
return cols[0]
|
||||
|
||||
@property
|
||||
def html(self):
|
||||
t = ((c.column_name, c.type) for c in self.columns)
|
||||
df = pd.DataFrame(t)
|
||||
df.columns = ['field', 'type']
|
||||
return df.to_html(
|
||||
index=False,
|
||||
classes=(
|
||||
"dataframe table table-striped table-bordered "
|
||||
"table-condensed"))
|
||||
|
||||
@property
|
||||
def metrics_combo(self):
|
||||
return sorted(
|
||||
[
|
||||
(m.metric_name, m.verbose_name or m.metric_name)
|
||||
for m in self.metrics],
|
||||
key=lambda x: x[1])
|
||||
|
||||
@property
|
||||
def sql_url(self):
|
||||
return self.database.sql_url + "?table_name=" + str(self.table_name)
|
||||
|
||||
@property
|
||||
def time_column_grains(self):
|
||||
return {
|
||||
"time_columns": self.dttm_cols,
|
||||
"time_grains": [grain.name for grain in self.database.grains()]
|
||||
}
|
||||
|
||||
def get_col(self, col_name):
|
||||
columns = self.columns
|
||||
for col in columns:
|
||||
if col_name == col.column_name:
|
||||
return col
|
||||
|
||||
def values_for_column(self,
|
||||
column_name,
|
||||
from_dttm,
|
||||
to_dttm,
|
||||
limit=500):
|
||||
"""Runs query against sqla to retrieve some
|
||||
sample values for the given column.
|
||||
"""
|
||||
granularity = self.main_dttm_col
|
||||
|
||||
cols = {col.column_name: col for col in self.columns}
|
||||
target_col = cols[column_name]
|
||||
|
||||
tbl = table(self.table_name)
|
||||
qry = sa.select([target_col.sqla_col])
|
||||
qry = qry.select_from(tbl)
|
||||
qry = qry.distinct(column_name)
|
||||
qry = qry.limit(limit)
|
||||
|
||||
if granularity:
|
||||
dttm_col = cols[granularity]
|
||||
timestamp = dttm_col.sqla_col.label('timestamp')
|
||||
time_filter = [
|
||||
timestamp >= text(dttm_col.dttm_sql_literal(from_dttm)),
|
||||
timestamp <= text(dttm_col.dttm_sql_literal(to_dttm)),
|
||||
]
|
||||
qry = qry.where(and_(*time_filter))
|
||||
|
||||
engine = self.database.get_sqla_engine()
|
||||
sql = "{}".format(
|
||||
qry.compile(
|
||||
engine, compile_kwargs={"literal_binds": True}, ),
|
||||
)
|
||||
|
||||
return pd.read_sql_query(
|
||||
sql=sql,
|
||||
con=engine
|
||||
)
|
||||
|
||||
def get_query_str( # sqla
|
||||
self, engine, qry_start_dttm,
|
||||
groupby, metrics,
|
||||
granularity,
|
||||
from_dttm, to_dttm,
|
||||
filter=None, # noqa
|
||||
is_timeseries=True,
|
||||
timeseries_limit=15,
|
||||
timeseries_limit_metric=None,
|
||||
row_limit=None,
|
||||
inner_from_dttm=None,
|
||||
inner_to_dttm=None,
|
||||
orderby=None,
|
||||
extras=None,
|
||||
columns=None):
|
||||
"""Querying any sqla table from this common interface"""
|
||||
template_processor = get_template_processor(
|
||||
table=self, database=self.database)
|
||||
|
||||
# For backward compatibility
|
||||
if granularity not in self.dttm_cols:
|
||||
granularity = self.main_dttm_col
|
||||
|
||||
cols = {col.column_name: col for col in self.columns}
|
||||
metrics_dict = {m.metric_name: m for m in self.metrics}
|
||||
|
||||
if not granularity and is_timeseries:
|
||||
raise Exception(_(
|
||||
"Datetime column not provided as part table configuration "
|
||||
"and is required by this type of chart"))
|
||||
for m in metrics:
|
||||
if m not in metrics_dict:
|
||||
raise Exception(_("Metric '{}' is not valid".format(m)))
|
||||
metrics_exprs = [metrics_dict.get(m).sqla_col for m in metrics]
|
||||
timeseries_limit_metric = metrics_dict.get(timeseries_limit_metric)
|
||||
timeseries_limit_metric_expr = None
|
||||
if timeseries_limit_metric:
|
||||
timeseries_limit_metric_expr = \
|
||||
timeseries_limit_metric.sqla_col
|
||||
if metrics:
|
||||
main_metric_expr = metrics_exprs[0]
|
||||
else:
|
||||
main_metric_expr = literal_column("COUNT(*)").label("ccount")
|
||||
|
||||
select_exprs = []
|
||||
groupby_exprs = []
|
||||
|
||||
if groupby:
|
||||
select_exprs = []
|
||||
inner_select_exprs = []
|
||||
inner_groupby_exprs = []
|
||||
for s in groupby:
|
||||
col = cols[s]
|
||||
outer = col.sqla_col
|
||||
inner = col.sqla_col.label(col.column_name + '__')
|
||||
|
||||
groupby_exprs.append(outer)
|
||||
select_exprs.append(outer)
|
||||
inner_groupby_exprs.append(inner)
|
||||
inner_select_exprs.append(inner)
|
||||
elif columns:
|
||||
for s in columns:
|
||||
select_exprs.append(cols[s].sqla_col)
|
||||
metrics_exprs = []
|
||||
|
||||
if granularity:
|
||||
@compiles(ColumnClause)
|
||||
def visit_column(element, compiler, **kw):
|
||||
"""Patch for sqlalchemy bug
|
||||
|
||||
TODO: sqlalchemy 1.2 release should be doing this on its own.
|
||||
Patch only if the column clause is specific for DateTime
|
||||
set and granularity is selected.
|
||||
"""
|
||||
text = compiler.visit_column(element, **kw)
|
||||
try:
|
||||
if (
|
||||
element.is_literal and
|
||||
hasattr(element.type, 'python_type') and
|
||||
type(element.type) is DateTime
|
||||
):
|
||||
text = text.replace('%%', '%')
|
||||
except NotImplementedError:
|
||||
# Some elements raise NotImplementedError for python_type
|
||||
pass
|
||||
return text
|
||||
|
||||
dttm_col = cols[granularity]
|
||||
time_grain = extras.get('time_grain_sqla')
|
||||
|
||||
if is_timeseries:
|
||||
timestamp = dttm_col.get_timestamp_expression(time_grain)
|
||||
select_exprs += [timestamp]
|
||||
groupby_exprs += [timestamp]
|
||||
|
||||
time_filter = dttm_col.get_time_filter(from_dttm, to_dttm)
|
||||
|
||||
select_exprs += metrics_exprs
|
||||
qry = sa.select(select_exprs)
|
||||
|
||||
tbl = table(self.table_name)
|
||||
if self.schema:
|
||||
tbl.schema = self.schema
|
||||
|
||||
# Supporting arbitrary SQL statements in place of tables
|
||||
if self.sql:
|
||||
tbl = TextAsFrom(sa.text(self.sql), []).alias('expr_qry')
|
||||
|
||||
if not columns:
|
||||
qry = qry.group_by(*groupby_exprs)
|
||||
|
||||
where_clause_and = []
|
||||
having_clause_and = []
|
||||
for flt in filter:
|
||||
if not all([flt.get(s) for s in ['col', 'op', 'val']]):
|
||||
continue
|
||||
col = flt['col']
|
||||
op = flt['op']
|
||||
eq = flt['val']
|
||||
col_obj = cols.get(col)
|
||||
if col_obj and op in ('in', 'not in'):
|
||||
values = [types.strip("'").strip('"') for types in eq]
|
||||
if col_obj.is_num:
|
||||
values = [utils.js_string_to_num(s) for s in values]
|
||||
cond = col_obj.sqla_col.in_(values)
|
||||
if op == 'not in':
|
||||
cond = ~cond
|
||||
where_clause_and.append(cond)
|
||||
if extras:
|
||||
where = extras.get('where')
|
||||
if where:
|
||||
where_clause_and += [wrap_clause_in_parens(
|
||||
template_processor.process_template(where))]
|
||||
having = extras.get('having')
|
||||
if having:
|
||||
having_clause_and += [wrap_clause_in_parens(
|
||||
template_processor.process_template(having))]
|
||||
if granularity:
|
||||
qry = qry.where(and_(*([time_filter] + where_clause_and)))
|
||||
else:
|
||||
qry = qry.where(and_(*where_clause_and))
|
||||
qry = qry.having(and_(*having_clause_and))
|
||||
if groupby:
|
||||
qry = qry.order_by(desc(main_metric_expr))
|
||||
elif orderby:
|
||||
for col, ascending in orderby:
|
||||
direction = asc if ascending else desc
|
||||
qry = qry.order_by(direction(col))
|
||||
|
||||
qry = qry.limit(row_limit)
|
||||
|
||||
if is_timeseries and timeseries_limit and groupby:
|
||||
# some sql dialects require for order by expressions
|
||||
# to also be in the select clause -- others, e.g. vertica,
|
||||
# require a unique inner alias
|
||||
inner_main_metric_expr = main_metric_expr.label('mme_inner__')
|
||||
inner_select_exprs += [inner_main_metric_expr]
|
||||
subq = select(inner_select_exprs)
|
||||
subq = subq.select_from(tbl)
|
||||
inner_time_filter = dttm_col.get_time_filter(
|
||||
inner_from_dttm or from_dttm,
|
||||
inner_to_dttm or to_dttm,
|
||||
)
|
||||
subq = subq.where(and_(*(where_clause_and + [inner_time_filter])))
|
||||
subq = subq.group_by(*inner_groupby_exprs)
|
||||
ob = inner_main_metric_expr
|
||||
if timeseries_limit_metric_expr is not None:
|
||||
ob = timeseries_limit_metric_expr
|
||||
subq = subq.order_by(desc(ob))
|
||||
subq = subq.limit(timeseries_limit)
|
||||
on_clause = []
|
||||
for i, gb in enumerate(groupby):
|
||||
on_clause.append(
|
||||
groupby_exprs[i] == column(gb + '__'))
|
||||
|
||||
tbl = tbl.join(subq.alias(), and_(*on_clause))
|
||||
|
||||
qry = qry.select_from(tbl)
|
||||
|
||||
sql = "{}".format(
|
||||
qry.compile(
|
||||
engine, compile_kwargs={"literal_binds": True},),
|
||||
)
|
||||
logging.info(sql)
|
||||
sql = sqlparse.format(sql, reindent=True)
|
||||
return sql
|
||||
|
||||
def query(self, query_obj):
|
||||
qry_start_dttm = datetime.now()
|
||||
engine = self.database.get_sqla_engine()
|
||||
sql = self.get_query_str(engine, qry_start_dttm, **query_obj)
|
||||
status = QueryStatus.SUCCESS
|
||||
error_message = None
|
||||
df = None
|
||||
try:
|
||||
df = pd.read_sql_query(sql, con=engine)
|
||||
except Exception as e:
|
||||
status = QueryStatus.FAILED
|
||||
error_message = str(e)
|
||||
|
||||
return QueryResult(
|
||||
status=status,
|
||||
df=df,
|
||||
duration=datetime.now() - qry_start_dttm,
|
||||
query=sql,
|
||||
error_message=error_message)
|
||||
|
||||
def get_sqla_table_object(self):
|
||||
return self.database.get_table(self.table_name, schema=self.schema)
|
||||
|
||||
def fetch_metadata(self):
|
||||
"""Fetches the metadata for the table and merges it in"""
|
||||
try:
|
||||
table = self.get_sqla_table_object()
|
||||
except Exception:
|
||||
raise Exception(
|
||||
"Table doesn't seem to exist in the specified database, "
|
||||
"couldn't fetch column information")
|
||||
|
||||
TC = TableColumn # noqa shortcut to class
|
||||
M = SqlMetric # noqa
|
||||
metrics = []
|
||||
any_date_col = None
|
||||
for col in table.columns:
|
||||
try:
|
||||
datatype = "{}".format(col.type).upper()
|
||||
except Exception as e:
|
||||
datatype = "UNKNOWN"
|
||||
logging.error(
|
||||
"Unrecognized data type in {}.{}".format(table, col.name))
|
||||
logging.exception(e)
|
||||
dbcol = (
|
||||
db.session
|
||||
.query(TC)
|
||||
.filter(TC.table == self)
|
||||
.filter(TC.column_name == col.name)
|
||||
.first()
|
||||
)
|
||||
db.session.flush()
|
||||
if not dbcol:
|
||||
dbcol = TableColumn(column_name=col.name, type=datatype)
|
||||
dbcol.groupby = dbcol.is_string
|
||||
dbcol.filterable = dbcol.is_string
|
||||
dbcol.sum = dbcol.is_num
|
||||
dbcol.avg = dbcol.is_num
|
||||
dbcol.is_dttm = dbcol.is_time
|
||||
|
||||
db.session.merge(self)
|
||||
self.columns.append(dbcol)
|
||||
|
||||
if not any_date_col and dbcol.is_time:
|
||||
any_date_col = col.name
|
||||
|
||||
quoted = "{}".format(
|
||||
column(dbcol.column_name).compile(dialect=db.engine.dialect))
|
||||
if dbcol.sum:
|
||||
metrics.append(M(
|
||||
metric_name='sum__' + dbcol.column_name,
|
||||
verbose_name='sum__' + dbcol.column_name,
|
||||
metric_type='sum',
|
||||
expression="SUM({})".format(quoted)
|
||||
))
|
||||
if dbcol.avg:
|
||||
metrics.append(M(
|
||||
metric_name='avg__' + dbcol.column_name,
|
||||
verbose_name='avg__' + dbcol.column_name,
|
||||
metric_type='avg',
|
||||
expression="AVG({})".format(quoted)
|
||||
))
|
||||
if dbcol.max:
|
||||
metrics.append(M(
|
||||
metric_name='max__' + dbcol.column_name,
|
||||
verbose_name='max__' + dbcol.column_name,
|
||||
metric_type='max',
|
||||
expression="MAX({})".format(quoted)
|
||||
))
|
||||
if dbcol.min:
|
||||
metrics.append(M(
|
||||
metric_name='min__' + dbcol.column_name,
|
||||
verbose_name='min__' + dbcol.column_name,
|
||||
metric_type='min',
|
||||
expression="MIN({})".format(quoted)
|
||||
))
|
||||
if dbcol.count_distinct:
|
||||
metrics.append(M(
|
||||
metric_name='count_distinct__' + dbcol.column_name,
|
||||
verbose_name='count_distinct__' + dbcol.column_name,
|
||||
metric_type='count_distinct',
|
||||
expression="COUNT(DISTINCT {})".format(quoted)
|
||||
))
|
||||
dbcol.type = datatype
|
||||
db.session.merge(self)
|
||||
db.session.commit()
|
||||
|
||||
metrics.append(M(
|
||||
metric_name='count',
|
||||
verbose_name='COUNT(*)',
|
||||
metric_type='count',
|
||||
expression="COUNT(*)"
|
||||
))
|
||||
for metric in metrics:
|
||||
m = (
|
||||
db.session.query(M)
|
||||
.filter(M.metric_name == metric.metric_name)
|
||||
.filter(M.table_id == self.id)
|
||||
.first()
|
||||
)
|
||||
metric.table_id = self.id
|
||||
if not m:
|
||||
db.session.add(metric)
|
||||
db.session.commit()
|
||||
if not self.main_dttm_col:
|
||||
self.main_dttm_col = any_date_col
|
||||
|
||||
@classmethod
|
||||
def import_obj(cls, i_datasource, import_time=None):
|
||||
"""Imports the datasource from the object to the database.
|
||||
|
||||
Metrics and columns and datasource will be overrided if exists.
|
||||
This function can be used to import/export dashboards between multiple
|
||||
superset instances. Audit metadata isn't copies over.
|
||||
"""
|
||||
def lookup_sqlatable(table):
|
||||
return db.session.query(SqlaTable).join(Database).filter(
|
||||
SqlaTable.table_name == table.table_name,
|
||||
SqlaTable.schema == table.schema,
|
||||
Database.id == table.database_id,
|
||||
).first()
|
||||
|
||||
def lookup_database(table):
|
||||
return db.session.query(Database).filter_by(
|
||||
database_name=table.params_dict['database_name']).one()
|
||||
return import_util.import_datasource(
|
||||
db.session, i_datasource, lookup_database, lookup_sqlatable,
|
||||
import_time)
|
||||
|
||||
sa.event.listen(SqlaTable, 'after_insert', set_perm)
|
||||
sa.event.listen(SqlaTable, 'after_update', set_perm)
|
|
@ -0,0 +1,213 @@
|
|||
import logging
|
||||
|
||||
from flask import Markup, flash
|
||||
from flask_appbuilder import CompactCRUDMixin
|
||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||
import sqlalchemy as sa
|
||||
|
||||
from flask_babel import lazy_gettext as _
|
||||
from flask_babel import gettext as __
|
||||
|
||||
from superset import appbuilder, db, utils, security, sm
|
||||
from superset.views.base import (
|
||||
SupersetModelView, ListWidgetWithCheckboxes, DeleteMixin, DatasourceFilter,
|
||||
get_datasource_exist_error_mgs,
|
||||
)
|
||||
|
||||
from . import models
|
||||
|
||||
|
||||
class TableColumnInlineView(CompactCRUDMixin, SupersetModelView): # noqa
|
||||
datamodel = SQLAInterface(models.TableColumn)
|
||||
can_delete = False
|
||||
list_widget = ListWidgetWithCheckboxes
|
||||
edit_columns = [
|
||||
'column_name', 'verbose_name', 'description', 'groupby', 'filterable',
|
||||
'table', 'count_distinct', 'sum', 'min', 'max', 'expression',
|
||||
'is_dttm', 'python_date_format', 'database_expression']
|
||||
add_columns = edit_columns
|
||||
list_columns = [
|
||||
'column_name', 'type', 'groupby', 'filterable', 'count_distinct',
|
||||
'sum', 'min', 'max', 'is_dttm']
|
||||
page_size = 500
|
||||
description_columns = {
|
||||
'is_dttm': (_(
|
||||
"Whether to make this column available as a "
|
||||
"[Time Granularity] option, column has to be DATETIME or "
|
||||
"DATETIME-like")),
|
||||
'expression': utils.markdown(
|
||||
"a valid SQL expression as supported by the underlying backend. "
|
||||
"Example: `substr(name, 1, 1)`", True),
|
||||
'python_date_format': utils.markdown(Markup(
|
||||
"The pattern of timestamp format, use "
|
||||
"<a href='https://docs.python.org/2/library/"
|
||||
"datetime.html#strftime-strptime-behavior'>"
|
||||
"python datetime string pattern</a> "
|
||||
"expression. If time is stored in epoch "
|
||||
"format, put `epoch_s` or `epoch_ms`. Leave `Database Expression` "
|
||||
"below empty if timestamp is stored in "
|
||||
"String or Integer(epoch) type"), True),
|
||||
'database_expression': utils.markdown(
|
||||
"The database expression to cast internal datetime "
|
||||
"constants to database date/timestamp type according to the DBAPI. "
|
||||
"The expression should follow the pattern of "
|
||||
"%Y-%m-%d %H:%M:%S, based on different DBAPI. "
|
||||
"The string should be a python string formatter \n"
|
||||
"`Ex: TO_DATE('{}', 'YYYY-MM-DD HH24:MI:SS')` for Oracle"
|
||||
"Superset uses default expression based on DB URI if this "
|
||||
"field is blank.", True),
|
||||
}
|
||||
label_columns = {
|
||||
'column_name': _("Column"),
|
||||
'verbose_name': _("Verbose Name"),
|
||||
'description': _("Description"),
|
||||
'groupby': _("Groupable"),
|
||||
'filterable': _("Filterable"),
|
||||
'table': _("Table"),
|
||||
'count_distinct': _("Count Distinct"),
|
||||
'sum': _("Sum"),
|
||||
'min': _("Min"),
|
||||
'max': _("Max"),
|
||||
'expression': _("Expression"),
|
||||
'is_dttm': _("Is temporal"),
|
||||
'python_date_format': _("Datetime Format"),
|
||||
'database_expression': _("Database Expression")
|
||||
}
|
||||
appbuilder.add_view_no_menu(TableColumnInlineView)
|
||||
|
||||
|
||||
class SqlMetricInlineView(CompactCRUDMixin, SupersetModelView): # noqa
|
||||
datamodel = SQLAInterface(models.SqlMetric)
|
||||
list_columns = ['metric_name', 'verbose_name', 'metric_type']
|
||||
edit_columns = [
|
||||
'metric_name', 'description', 'verbose_name', 'metric_type',
|
||||
'expression', 'table', 'd3format', 'is_restricted']
|
||||
description_columns = {
|
||||
'expression': utils.markdown(
|
||||
"a valid SQL expression as supported by the underlying backend. "
|
||||
"Example: `count(DISTINCT userid)`", True),
|
||||
'is_restricted': _("Whether the access to this metric is restricted "
|
||||
"to certain roles. Only roles with the permission "
|
||||
"'metric access on XXX (the name of this metric)' "
|
||||
"are allowed to access this metric"),
|
||||
'd3format': utils.markdown(
|
||||
"d3 formatting string as defined [here]"
|
||||
"(https://github.com/d3/d3-format/blob/master/README.md#format). "
|
||||
"For instance, this default formatting applies in the Table "
|
||||
"visualization and allow for different metric to use different "
|
||||
"formats", True
|
||||
),
|
||||
}
|
||||
add_columns = edit_columns
|
||||
page_size = 500
|
||||
label_columns = {
|
||||
'metric_name': _("Metric"),
|
||||
'description': _("Description"),
|
||||
'verbose_name': _("Verbose Name"),
|
||||
'metric_type': _("Type"),
|
||||
'expression': _("SQL Expression"),
|
||||
'table': _("Table"),
|
||||
}
|
||||
|
||||
def post_add(self, metric):
|
||||
if metric.is_restricted:
|
||||
security.merge_perm(sm, 'metric_access', metric.get_perm())
|
||||
|
||||
def post_update(self, metric):
|
||||
if metric.is_restricted:
|
||||
security.merge_perm(sm, 'metric_access', metric.get_perm())
|
||||
|
||||
appbuilder.add_view_no_menu(SqlMetricInlineView)
|
||||
|
||||
|
||||
class TableModelView(SupersetModelView, DeleteMixin): # noqa
|
||||
datamodel = SQLAInterface(models.SqlaTable)
|
||||
list_columns = [
|
||||
'link', 'database', 'is_featured',
|
||||
'changed_by_', 'changed_on_']
|
||||
order_columns = [
|
||||
'link', 'database', 'is_featured', 'changed_on_']
|
||||
add_columns = ['database', 'schema', 'table_name']
|
||||
edit_columns = [
|
||||
'table_name', 'sql', 'is_featured', 'filter_select_enabled',
|
||||
'database', 'schema',
|
||||
'description', 'owner',
|
||||
'main_dttm_col', 'default_endpoint', 'offset', 'cache_timeout']
|
||||
show_columns = edit_columns + ['perm']
|
||||
related_views = [TableColumnInlineView, SqlMetricInlineView]
|
||||
base_order = ('changed_on', 'desc')
|
||||
description_columns = {
|
||||
'offset': _("Timezone offset (in hours) for this datasource"),
|
||||
'table_name': _(
|
||||
"Name of the table that exists in the source database"),
|
||||
'schema': _(
|
||||
"Schema, as used only in some databases like Postgres, Redshift "
|
||||
"and DB2"),
|
||||
'description': Markup(
|
||||
"Supports <a href='https://daringfireball.net/projects/markdown/'>"
|
||||
"markdown</a>"),
|
||||
'sql': _(
|
||||
"This fields acts a Superset view, meaning that Superset will "
|
||||
"run a query against this string as a subquery."
|
||||
),
|
||||
}
|
||||
base_filters = [['id', DatasourceFilter, lambda: []]]
|
||||
label_columns = {
|
||||
'link': _("Table"),
|
||||
'changed_by_': _("Changed By"),
|
||||
'database': _("Database"),
|
||||
'changed_on_': _("Last Changed"),
|
||||
'is_featured': _("Is Featured"),
|
||||
'filter_select_enabled': _("Enable Filter Select"),
|
||||
'schema': _("Schema"),
|
||||
'default_endpoint': _("Default Endpoint"),
|
||||
'offset': _("Offset"),
|
||||
'cache_timeout': _("Cache Timeout"),
|
||||
}
|
||||
|
||||
def pre_add(self, table):
|
||||
number_of_existing_tables = db.session.query(
|
||||
sa.func.count('*')).filter(
|
||||
models.SqlaTable.table_name == table.table_name,
|
||||
models.SqlaTable.schema == table.schema,
|
||||
models.SqlaTable.database_id == table.database.id
|
||||
).scalar()
|
||||
# table object is already added to the session
|
||||
if number_of_existing_tables > 1:
|
||||
raise Exception(get_datasource_exist_error_mgs(table.full_name))
|
||||
|
||||
# Fail before adding if the table can't be found
|
||||
try:
|
||||
table.get_sqla_table_object()
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
raise Exception(
|
||||
"Table [{}] could not be found, "
|
||||
"please double check your "
|
||||
"database connection, schema, and "
|
||||
"table name".format(table.name))
|
||||
|
||||
def post_add(self, table):
|
||||
table.fetch_metadata()
|
||||
security.merge_perm(sm, 'datasource_access', table.get_perm())
|
||||
if table.schema:
|
||||
security.merge_perm(sm, 'schema_access', table.schema_perm)
|
||||
|
||||
flash(_(
|
||||
"The table was created. As part of this two phase configuration "
|
||||
"process, you should now click the edit button by "
|
||||
"the new table to configure it."),
|
||||
"info")
|
||||
|
||||
def post_update(self, table):
|
||||
self.post_add(table)
|
||||
|
||||
appbuilder.add_view(
|
||||
TableModelView,
|
||||
"Tables",
|
||||
label=__("Tables"),
|
||||
category="Sources",
|
||||
category_label=__("Sources"),
|
||||
icon='fa-table',)
|
||||
|
||||
appbuilder.add_separator("Sources")
|
|
@ -14,15 +14,19 @@ import random
|
|||
import pandas as pd
|
||||
from sqlalchemy import String, DateTime, Date, Float, BigInteger
|
||||
|
||||
from superset import app, db, models, utils
|
||||
from superset import app, db, utils
|
||||
from superset.models import core as models
|
||||
from superset.security import get_or_create_main_db
|
||||
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
|
||||
# Shortcuts
|
||||
DB = models.Database
|
||||
Slice = models.Slice
|
||||
TBL = models.SqlaTable
|
||||
Dash = models.Dashboard
|
||||
|
||||
TBL = ConnectorRegistry.sources['table']
|
||||
|
||||
config = app.config
|
||||
|
||||
DATA_FOLDER = os.path.join(config.get("BASE_DIR"), 'data')
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
"""adding verbose_name to druid column
|
||||
|
||||
Revision ID: b318dfe5fb6c
|
||||
Revises: d6db5a5cdb5d
|
||||
Create Date: 2017-03-08 11:48:10.835741
|
||||
|
||||
"""
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'b318dfe5fb6c'
|
||||
down_revision = 'd6db5a5cdb5d'
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.add_column('columns', sa.Column('verbose_name', sa.String(length=1024), nullable=True))
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_column('columns', 'verbose_name')
|
2876
superset/models.py
2876
superset/models.py
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1 @@
|
|||
from . import core # noqa
|
|
@ -0,0 +1,951 @@
|
|||
"""A collection of ORM sqlalchemy models for Superset"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import numpy
|
||||
import pickle
|
||||
import re
|
||||
import textwrap
|
||||
from future.standard_library import install_aliases
|
||||
from copy import copy
|
||||
from datetime import datetime, date
|
||||
|
||||
import pandas as pd
|
||||
import sqlalchemy as sqla
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.orm import subqueryload
|
||||
|
||||
from flask import escape, g, Markup, request
|
||||
from flask_appbuilder import Model
|
||||
from flask_appbuilder.models.decorators import renders
|
||||
|
||||
from sqlalchemy import (
|
||||
Column, Integer, String, ForeignKey, Text, Boolean,
|
||||
DateTime, Date, Table, Numeric,
|
||||
create_engine, MetaData, select
|
||||
)
|
||||
from sqlalchemy.orm import backref, relationship
|
||||
from sqlalchemy.orm.session import make_transient
|
||||
from sqlalchemy.sql import text
|
||||
from sqlalchemy.sql.expression import TextAsFrom
|
||||
from sqlalchemy_utils import EncryptedType
|
||||
|
||||
from superset import app, db, db_engine_specs, utils, sm
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.viz import viz_types
|
||||
from superset.utils import QueryStatus
|
||||
from superset.models.helpers import AuditMixinNullable, ImportMixin, set_perm
|
||||
install_aliases()
|
||||
from urllib import parse # noqa
|
||||
|
||||
config = app.config
|
||||
|
||||
|
||||
def set_related_perm(mapper, connection, target): # noqa
|
||||
src_class = target.cls_model
|
||||
id_ = target.datasource_id
|
||||
ds = db.session.query(src_class).filter_by(id=int(id_)).first()
|
||||
target.perm = ds.perm
|
||||
|
||||
|
||||
class Url(Model, AuditMixinNullable):
|
||||
"""Used for the short url feature"""
|
||||
|
||||
__tablename__ = 'url'
|
||||
id = Column(Integer, primary_key=True)
|
||||
url = Column(Text)
|
||||
|
||||
|
||||
class KeyValue(Model):
|
||||
|
||||
"""Used for any type of key-value store"""
|
||||
|
||||
__tablename__ = 'keyvalue'
|
||||
id = Column(Integer, primary_key=True)
|
||||
value = Column(Text, nullable=False)
|
||||
|
||||
|
||||
class CssTemplate(Model, AuditMixinNullable):
|
||||
|
||||
"""CSS templates for dashboards"""
|
||||
|
||||
__tablename__ = 'css_templates'
|
||||
id = Column(Integer, primary_key=True)
|
||||
template_name = Column(String(250))
|
||||
css = Column(Text, default='')
|
||||
|
||||
|
||||
slice_user = Table('slice_user', Model.metadata,
|
||||
Column('id', Integer, primary_key=True),
|
||||
Column('user_id', Integer, ForeignKey('ab_user.id')),
|
||||
Column('slice_id', Integer, ForeignKey('slices.id'))
|
||||
)
|
||||
|
||||
|
||||
class Slice(Model, AuditMixinNullable, ImportMixin):
|
||||
|
||||
"""A slice is essentially a report or a view on data"""
|
||||
|
||||
__tablename__ = 'slices'
|
||||
id = Column(Integer, primary_key=True)
|
||||
slice_name = Column(String(250))
|
||||
datasource_id = Column(Integer)
|
||||
datasource_type = Column(String(200))
|
||||
datasource_name = Column(String(2000))
|
||||
viz_type = Column(String(250))
|
||||
params = Column(Text)
|
||||
description = Column(Text)
|
||||
cache_timeout = Column(Integer)
|
||||
perm = Column(String(1000))
|
||||
owners = relationship("User", secondary=slice_user)
|
||||
|
||||
export_fields = ('slice_name', 'datasource_type', 'datasource_name',
|
||||
'viz_type', 'params', 'cache_timeout')
|
||||
|
||||
def __repr__(self):
|
||||
return self.slice_name
|
||||
|
||||
@property
|
||||
def cls_model(self):
|
||||
return ConnectorRegistry.sources[self.datasource_type]
|
||||
|
||||
@property
|
||||
def datasource(self):
|
||||
return self.get_datasource
|
||||
|
||||
@datasource.getter
|
||||
@utils.memoized
|
||||
def get_datasource(self):
|
||||
ds = db.session.query(
|
||||
self.cls_model).filter_by(
|
||||
id=self.datasource_id).first()
|
||||
return ds
|
||||
|
||||
@renders('datasource_name')
|
||||
def datasource_link(self):
|
||||
datasource = self.datasource
|
||||
if datasource:
|
||||
return self.datasource.link
|
||||
|
||||
@property
|
||||
def datasource_edit_url(self):
|
||||
self.datasource.url
|
||||
|
||||
@property
|
||||
@utils.memoized
|
||||
def viz(self):
|
||||
d = json.loads(self.params)
|
||||
viz_class = viz_types[self.viz_type]
|
||||
return viz_class(self.datasource, form_data=d)
|
||||
|
||||
@property
|
||||
def description_markeddown(self):
|
||||
return utils.markdown(self.description)
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
"""Data used to render slice in templates"""
|
||||
d = {}
|
||||
self.token = ''
|
||||
try:
|
||||
d = self.viz.data
|
||||
self.token = d.get('token')
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
d['error'] = str(e)
|
||||
return {
|
||||
'datasource': self.datasource_name,
|
||||
'description': self.description,
|
||||
'description_markeddown': self.description_markeddown,
|
||||
'edit_url': self.edit_url,
|
||||
'form_data': self.form_data,
|
||||
'slice_id': self.id,
|
||||
'slice_name': self.slice_name,
|
||||
'slice_url': self.slice_url,
|
||||
}
|
||||
|
||||
@property
|
||||
def json_data(self):
|
||||
return json.dumps(self.data)
|
||||
|
||||
@property
|
||||
def form_data(self):
|
||||
form_data = json.loads(self.params)
|
||||
form_data['slice_id'] = self.id
|
||||
form_data['viz_type'] = self.viz_type
|
||||
form_data['datasource'] = (
|
||||
str(self.datasource_id) + '__' + self.datasource_type)
|
||||
return form_data
|
||||
|
||||
@property
|
||||
def slice_url(self):
|
||||
"""Defines the url to access the slice"""
|
||||
return (
|
||||
"/superset/explore/{obj.datasource_type}/"
|
||||
"{obj.datasource_id}/?form_data={params}".format(
|
||||
obj=self, params=parse.quote(json.dumps(self.form_data))))
|
||||
|
||||
@property
|
||||
def slice_id_url(self):
|
||||
return (
|
||||
"/superset/{slc.datasource_type}/{slc.datasource_id}/{slc.id}/"
|
||||
).format(slc=self)
|
||||
|
||||
@property
|
||||
def edit_url(self):
|
||||
return "/slicemodelview/edit/{}".format(self.id)
|
||||
|
||||
@property
|
||||
def slice_link(self):
|
||||
url = self.slice_url
|
||||
name = escape(self.slice_name)
|
||||
return Markup('<a href="{url}">{name}</a>'.format(**locals()))
|
||||
|
||||
def get_viz(self, url_params_multidict=None):
|
||||
"""Creates :py:class:viz.BaseViz object from the url_params_multidict.
|
||||
|
||||
:param werkzeug.datastructures.MultiDict url_params_multidict:
|
||||
Contains the visualization params, they override the self.params
|
||||
stored in the database
|
||||
:return: object of the 'viz_type' type that is taken from the
|
||||
url_params_multidict or self.params.
|
||||
:rtype: :py:class:viz.BaseViz
|
||||
"""
|
||||
slice_params = json.loads(self.params)
|
||||
slice_params['slice_id'] = self.id
|
||||
slice_params['json'] = "false"
|
||||
slice_params['slice_name'] = self.slice_name
|
||||
slice_params['viz_type'] = self.viz_type if self.viz_type else "table"
|
||||
|
||||
return viz_types[slice_params.get('viz_type')](
|
||||
self.datasource,
|
||||
form_data=slice_params,
|
||||
slice_=self
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def import_obj(cls, slc_to_import, import_time=None):
|
||||
"""Inserts or overrides slc in the database.
|
||||
|
||||
remote_id and import_time fields in params_dict are set to track the
|
||||
slice origin and ensure correct overrides for multiple imports.
|
||||
Slice.perm is used to find the datasources and connect them.
|
||||
"""
|
||||
session = db.session
|
||||
make_transient(slc_to_import)
|
||||
slc_to_import.dashboards = []
|
||||
slc_to_import.alter_params(
|
||||
remote_id=slc_to_import.id, import_time=import_time)
|
||||
|
||||
# find if the slice was already imported
|
||||
slc_to_override = None
|
||||
for slc in session.query(Slice).all():
|
||||
if ('remote_id' in slc.params_dict and
|
||||
slc.params_dict['remote_id'] == slc_to_import.id):
|
||||
slc_to_override = slc
|
||||
|
||||
slc_to_import = slc_to_import.copy()
|
||||
params = slc_to_import.params_dict
|
||||
slc_to_import.datasource_id = ConnectorRegistry.get_datasource_by_name(
|
||||
session, slc_to_import.datasource_type, params['datasource_name'],
|
||||
params['schema'], params['database_name']).id
|
||||
if slc_to_override:
|
||||
slc_to_override.override(slc_to_import)
|
||||
session.flush()
|
||||
return slc_to_override.id
|
||||
session.add(slc_to_import)
|
||||
logging.info('Final slice: {}'.format(slc_to_import.to_json()))
|
||||
session.flush()
|
||||
return slc_to_import.id
|
||||
|
||||
|
||||
sqla.event.listen(Slice, 'before_insert', set_related_perm)
|
||||
sqla.event.listen(Slice, 'before_update', set_related_perm)
|
||||
|
||||
|
||||
dashboard_slices = Table(
|
||||
'dashboard_slices', Model.metadata,
|
||||
Column('id', Integer, primary_key=True),
|
||||
Column('dashboard_id', Integer, ForeignKey('dashboards.id')),
|
||||
Column('slice_id', Integer, ForeignKey('slices.id')),
|
||||
)
|
||||
|
||||
dashboard_user = Table(
|
||||
'dashboard_user', Model.metadata,
|
||||
Column('id', Integer, primary_key=True),
|
||||
Column('user_id', Integer, ForeignKey('ab_user.id')),
|
||||
Column('dashboard_id', Integer, ForeignKey('dashboards.id'))
|
||||
)
|
||||
|
||||
|
||||
class Dashboard(Model, AuditMixinNullable, ImportMixin):
|
||||
|
||||
"""The dashboard object!"""
|
||||
|
||||
__tablename__ = 'dashboards'
|
||||
id = Column(Integer, primary_key=True)
|
||||
dashboard_title = Column(String(500))
|
||||
position_json = Column(Text)
|
||||
description = Column(Text)
|
||||
css = Column(Text)
|
||||
json_metadata = Column(Text)
|
||||
slug = Column(String(255), unique=True)
|
||||
slices = relationship(
|
||||
'Slice', secondary=dashboard_slices, backref='dashboards')
|
||||
owners = relationship("User", secondary=dashboard_user)
|
||||
|
||||
export_fields = ('dashboard_title', 'position_json', 'json_metadata',
|
||||
'description', 'css', 'slug')
|
||||
|
||||
def __repr__(self):
|
||||
return self.dashboard_title
|
||||
|
||||
@property
|
||||
def table_names(self):
|
||||
return ", ".join(
|
||||
{"{}".format(s.datasource.name) for s in self.slices})
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
return "/superset/dashboard/{}/".format(self.slug or self.id)
|
||||
|
||||
@property
|
||||
def datasources(self):
|
||||
return {slc.datasource for slc in self.slices}
|
||||
|
||||
@property
|
||||
def sqla_metadata(self):
|
||||
metadata = MetaData(bind=self.get_sqla_engine())
|
||||
return metadata.reflect()
|
||||
|
||||
def dashboard_link(self):
|
||||
title = escape(self.dashboard_title)
|
||||
return Markup(
|
||||
'<a href="{self.url}">{title}</a>'.format(**locals()))
|
||||
|
||||
@property
|
||||
def json_data(self):
|
||||
positions = self.position_json
|
||||
if positions:
|
||||
positions = json.loads(positions)
|
||||
d = {
|
||||
'id': self.id,
|
||||
'metadata': self.params_dict,
|
||||
'css': self.css,
|
||||
'dashboard_title': self.dashboard_title,
|
||||
'slug': self.slug,
|
||||
'slices': [slc.data for slc in self.slices],
|
||||
'position_json': positions,
|
||||
}
|
||||
return json.dumps(d)
|
||||
|
||||
@property
|
||||
def params(self):
|
||||
return self.json_metadata
|
||||
|
||||
@params.setter
|
||||
def params(self, value):
|
||||
self.json_metadata = value
|
||||
|
||||
@property
|
||||
def position_array(self):
|
||||
if self.position_json:
|
||||
return json.loads(self.position_json)
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def import_obj(cls, dashboard_to_import, import_time=None):
|
||||
"""Imports the dashboard from the object to the database.
|
||||
|
||||
Once dashboard is imported, json_metadata field is extended and stores
|
||||
remote_id and import_time. It helps to decide if the dashboard has to
|
||||
be overridden or just copies over. Slices that belong to this
|
||||
dashboard will be wired to existing tables. This function can be used
|
||||
to import/export dashboards between multiple superset instances.
|
||||
Audit metadata isn't copies over.
|
||||
"""
|
||||
def alter_positions(dashboard, old_to_new_slc_id_dict):
|
||||
""" Updates slice_ids in the position json.
|
||||
|
||||
Sample position json:
|
||||
[{
|
||||
"col": 5,
|
||||
"row": 10,
|
||||
"size_x": 4,
|
||||
"size_y": 2,
|
||||
"slice_id": "3610"
|
||||
}]
|
||||
"""
|
||||
position_array = dashboard.position_array
|
||||
for position in position_array:
|
||||
if 'slice_id' not in position:
|
||||
continue
|
||||
old_slice_id = int(position['slice_id'])
|
||||
if old_slice_id in old_to_new_slc_id_dict:
|
||||
position['slice_id'] = '{}'.format(
|
||||
old_to_new_slc_id_dict[old_slice_id])
|
||||
dashboard.position_json = json.dumps(position_array)
|
||||
|
||||
logging.info('Started import of the dashboard: {}'
|
||||
.format(dashboard_to_import.to_json()))
|
||||
session = db.session
|
||||
logging.info('Dashboard has {} slices'
|
||||
.format(len(dashboard_to_import.slices)))
|
||||
# copy slices object as Slice.import_slice will mutate the slice
|
||||
# and will remove the existing dashboard - slice association
|
||||
slices = copy(dashboard_to_import.slices)
|
||||
old_to_new_slc_id_dict = {}
|
||||
new_filter_immune_slices = []
|
||||
new_expanded_slices = {}
|
||||
i_params_dict = dashboard_to_import.params_dict
|
||||
for slc in slices:
|
||||
logging.info('Importing slice {} from the dashboard: {}'.format(
|
||||
slc.to_json(), dashboard_to_import.dashboard_title))
|
||||
new_slc_id = Slice.import_obj(slc, import_time=import_time)
|
||||
old_to_new_slc_id_dict[slc.id] = new_slc_id
|
||||
# update json metadata that deals with slice ids
|
||||
new_slc_id_str = '{}'.format(new_slc_id)
|
||||
old_slc_id_str = '{}'.format(slc.id)
|
||||
if ('filter_immune_slices' in i_params_dict and
|
||||
old_slc_id_str in i_params_dict['filter_immune_slices']):
|
||||
new_filter_immune_slices.append(new_slc_id_str)
|
||||
if ('expanded_slices' in i_params_dict and
|
||||
old_slc_id_str in i_params_dict['expanded_slices']):
|
||||
new_expanded_slices[new_slc_id_str] = (
|
||||
i_params_dict['expanded_slices'][old_slc_id_str])
|
||||
|
||||
# override the dashboard
|
||||
existing_dashboard = None
|
||||
for dash in session.query(Dashboard).all():
|
||||
if ('remote_id' in dash.params_dict and
|
||||
dash.params_dict['remote_id'] ==
|
||||
dashboard_to_import.id):
|
||||
existing_dashboard = dash
|
||||
|
||||
dashboard_to_import.id = None
|
||||
alter_positions(dashboard_to_import, old_to_new_slc_id_dict)
|
||||
dashboard_to_import.alter_params(import_time=import_time)
|
||||
if new_expanded_slices:
|
||||
dashboard_to_import.alter_params(
|
||||
expanded_slices=new_expanded_slices)
|
||||
if new_filter_immune_slices:
|
||||
dashboard_to_import.alter_params(
|
||||
filter_immune_slices=new_filter_immune_slices)
|
||||
|
||||
new_slices = session.query(Slice).filter(
|
||||
Slice.id.in_(old_to_new_slc_id_dict.values())).all()
|
||||
|
||||
if existing_dashboard:
|
||||
existing_dashboard.override(dashboard_to_import)
|
||||
existing_dashboard.slices = new_slices
|
||||
session.flush()
|
||||
return existing_dashboard.id
|
||||
else:
|
||||
# session.add(dashboard_to_import) causes sqlachemy failures
|
||||
# related to the attached users / slices. Creating new object
|
||||
# allows to avoid conflicts in the sql alchemy state.
|
||||
copied_dash = dashboard_to_import.copy()
|
||||
copied_dash.slices = new_slices
|
||||
session.add(copied_dash)
|
||||
session.flush()
|
||||
return copied_dash.id
|
||||
|
||||
@classmethod
|
||||
def export_dashboards(cls, dashboard_ids):
|
||||
copied_dashboards = []
|
||||
datasource_ids = set()
|
||||
for dashboard_id in dashboard_ids:
|
||||
# make sure that dashboard_id is an integer
|
||||
dashboard_id = int(dashboard_id)
|
||||
copied_dashboard = (
|
||||
db.session.query(Dashboard)
|
||||
.options(subqueryload(Dashboard.slices))
|
||||
.filter_by(id=dashboard_id).first()
|
||||
)
|
||||
make_transient(copied_dashboard)
|
||||
for slc in copied_dashboard.slices:
|
||||
datasource_ids.add((slc.datasource_id, slc.datasource_type))
|
||||
# add extra params for the import
|
||||
slc.alter_params(
|
||||
remote_id=slc.id,
|
||||
datasource_name=slc.datasource.name,
|
||||
schema=slc.datasource.name,
|
||||
database_name=slc.datasource.database.name,
|
||||
)
|
||||
copied_dashboard.alter_params(remote_id=dashboard_id)
|
||||
copied_dashboards.append(copied_dashboard)
|
||||
|
||||
eager_datasources = []
|
||||
for dashboard_id, dashboard_type in datasource_ids:
|
||||
eager_datasource = ConnectorRegistry.get_eager_datasource(
|
||||
db.session, dashboard_type, dashboard_id)
|
||||
eager_datasource.alter_params(
|
||||
remote_id=eager_datasource.id,
|
||||
database_name=eager_datasource.database.name,
|
||||
)
|
||||
make_transient(eager_datasource)
|
||||
eager_datasources.append(eager_datasource)
|
||||
|
||||
return pickle.dumps({
|
||||
'dashboards': copied_dashboards,
|
||||
'datasources': eager_datasources,
|
||||
})
|
||||
|
||||
|
||||
class Database(Model, AuditMixinNullable):
|
||||
|
||||
"""An ORM object that stores Database related information"""
|
||||
|
||||
__tablename__ = 'dbs'
|
||||
type = "table"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
database_name = Column(String(250), unique=True)
|
||||
sqlalchemy_uri = Column(String(1024))
|
||||
password = Column(EncryptedType(String(1024), config.get('SECRET_KEY')))
|
||||
cache_timeout = Column(Integer)
|
||||
select_as_create_table_as = Column(Boolean, default=False)
|
||||
expose_in_sqllab = Column(Boolean, default=False)
|
||||
allow_run_sync = Column(Boolean, default=True)
|
||||
allow_run_async = Column(Boolean, default=False)
|
||||
allow_ctas = Column(Boolean, default=False)
|
||||
allow_dml = Column(Boolean, default=False)
|
||||
force_ctas_schema = Column(String(250))
|
||||
extra = Column(Text, default=textwrap.dedent("""\
|
||||
{
|
||||
"metadata_params": {},
|
||||
"engine_params": {}
|
||||
}
|
||||
"""))
|
||||
perm = Column(String(1000))
|
||||
|
||||
def __repr__(self):
|
||||
return self.database_name
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.database_name
|
||||
|
||||
@property
|
||||
def backend(self):
|
||||
url = make_url(self.sqlalchemy_uri_decrypted)
|
||||
return url.get_backend_name()
|
||||
|
||||
def set_sqlalchemy_uri(self, uri):
|
||||
password_mask = "X" * 10
|
||||
conn = sqla.engine.url.make_url(uri)
|
||||
if conn.password != password_mask:
|
||||
# do not over-write the password with the password mask
|
||||
self.password = conn.password
|
||||
conn.password = password_mask if conn.password else None
|
||||
self.sqlalchemy_uri = str(conn) # hides the password
|
||||
|
||||
def get_sqla_engine(self, schema=None):
|
||||
extra = self.get_extra()
|
||||
url = make_url(self.sqlalchemy_uri_decrypted)
|
||||
params = extra.get('engine_params', {})
|
||||
url.database = self.get_database_for_various_backend(url, schema)
|
||||
return create_engine(url, **params)
|
||||
|
||||
def get_database_for_various_backend(self, uri, default_database=None):
|
||||
database = uri.database
|
||||
if self.backend == 'presto' and default_database:
|
||||
if '/' in database:
|
||||
database = database.split('/')[0] + '/' + default_database
|
||||
else:
|
||||
database += '/' + default_database
|
||||
# Postgres and Redshift use the concept of schema as a logical entity
|
||||
# on top of the database, so the database should not be changed
|
||||
# even if passed default_database
|
||||
elif self.backend == 'redshift' or self.backend == 'postgresql':
|
||||
pass
|
||||
elif default_database:
|
||||
database = default_database
|
||||
return database
|
||||
|
||||
def get_reserved_words(self):
|
||||
return self.get_sqla_engine().dialect.preparer.reserved_words
|
||||
|
||||
def get_quoter(self):
|
||||
return self.get_sqla_engine().dialect.identifier_preparer.quote
|
||||
|
||||
def get_df(self, sql, schema):
|
||||
sql = sql.strip().strip(';')
|
||||
eng = self.get_sqla_engine(schema=schema)
|
||||
cur = eng.execute(sql, schema=schema)
|
||||
cols = [col[0] for col in cur.cursor.description]
|
||||
df = pd.DataFrame(cur.fetchall(), columns=cols)
|
||||
|
||||
def needs_conversion(df_series):
|
||||
if df_series.empty:
|
||||
return False
|
||||
for df_type in [list, dict]:
|
||||
if isinstance(df_series[0], df_type):
|
||||
return True
|
||||
return False
|
||||
|
||||
for k, v in df.dtypes.iteritems():
|
||||
if v.type == numpy.object_ and needs_conversion(df[k]):
|
||||
df[k] = df[k].apply(utils.json_dumps_w_dates)
|
||||
return df
|
||||
|
||||
def compile_sqla_query(self, qry, schema=None):
|
||||
eng = self.get_sqla_engine(schema=schema)
|
||||
compiled = qry.compile(eng, compile_kwargs={"literal_binds": True})
|
||||
return '{}'.format(compiled)
|
||||
|
||||
def select_star(
|
||||
self, table_name, schema=None, limit=100, show_cols=False,
|
||||
indent=True):
|
||||
"""Generates a ``select *`` statement in the proper dialect"""
|
||||
return self.db_engine_spec.select_star(
|
||||
self, table_name, schema=schema, limit=limit, show_cols=show_cols,
|
||||
indent=indent)
|
||||
|
||||
def wrap_sql_limit(self, sql, limit=1000):
|
||||
qry = (
|
||||
select('*')
|
||||
.select_from(
|
||||
TextAsFrom(text(sql), ['*'])
|
||||
.alias('inner_qry')
|
||||
).limit(limit)
|
||||
)
|
||||
return self.compile_sqla_query(qry)
|
||||
|
||||
def safe_sqlalchemy_uri(self):
|
||||
return self.sqlalchemy_uri
|
||||
|
||||
@property
|
||||
def inspector(self):
|
||||
engine = self.get_sqla_engine()
|
||||
return sqla.inspect(engine)
|
||||
|
||||
def all_table_names(self, schema=None, force=False):
|
||||
if not schema:
|
||||
tables_dict = self.db_engine_spec.fetch_result_sets(
|
||||
self, 'table', force=force)
|
||||
return tables_dict.get("", [])
|
||||
return sorted(self.inspector.get_table_names(schema))
|
||||
|
||||
def all_view_names(self, schema=None, force=False):
|
||||
if not schema:
|
||||
views_dict = self.db_engine_spec.fetch_result_sets(
|
||||
self, 'view', force=force)
|
||||
return views_dict.get("", [])
|
||||
views = []
|
||||
try:
|
||||
views = self.inspector.get_view_names(schema)
|
||||
except Exception:
|
||||
pass
|
||||
return views
|
||||
|
||||
def all_schema_names(self):
|
||||
return sorted(self.inspector.get_schema_names())
|
||||
|
||||
@property
|
||||
def db_engine_spec(self):
|
||||
engine_name = self.get_sqla_engine().name or 'base'
|
||||
return db_engine_specs.engines.get(
|
||||
engine_name, db_engine_specs.BaseEngineSpec)
|
||||
|
||||
def grains(self):
|
||||
"""Defines time granularity database-specific expressions.
|
||||
|
||||
The idea here is to make it easy for users to change the time grain
|
||||
form a datetime (maybe the source grain is arbitrary timestamps, daily
|
||||
or 5 minutes increments) to another, "truncated" datetime. Since
|
||||
each database has slightly different but similar datetime functions,
|
||||
this allows a mapping between database engines and actual functions.
|
||||
"""
|
||||
return self.db_engine_spec.time_grains
|
||||
|
||||
def grains_dict(self):
|
||||
return {grain.name: grain for grain in self.grains()}
|
||||
|
||||
def get_extra(self):
|
||||
extra = {}
|
||||
if self.extra:
|
||||
try:
|
||||
extra = json.loads(self.extra)
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
return extra
|
||||
|
||||
def get_table(self, table_name, schema=None):
|
||||
extra = self.get_extra()
|
||||
meta = MetaData(**extra.get('metadata_params', {}))
|
||||
return Table(
|
||||
table_name, meta,
|
||||
schema=schema or None,
|
||||
autoload=True,
|
||||
autoload_with=self.get_sqla_engine())
|
||||
|
||||
def get_columns(self, table_name, schema=None):
|
||||
return self.inspector.get_columns(table_name, schema)
|
||||
|
||||
def get_indexes(self, table_name, schema=None):
|
||||
return self.inspector.get_indexes(table_name, schema)
|
||||
|
||||
def get_pk_constraint(self, table_name, schema=None):
|
||||
return self.inspector.get_pk_constraint(table_name, schema)
|
||||
|
||||
def get_foreign_keys(self, table_name, schema=None):
|
||||
return self.inspector.get_foreign_keys(table_name, schema)
|
||||
|
||||
@property
|
||||
def sqlalchemy_uri_decrypted(self):
|
||||
conn = sqla.engine.url.make_url(self.sqlalchemy_uri)
|
||||
conn.password = self.password
|
||||
return str(conn)
|
||||
|
||||
@property
|
||||
def sql_url(self):
|
||||
return '/superset/sql/{}/'.format(self.id)
|
||||
|
||||
def get_perm(self):
|
||||
return (
|
||||
"[{obj.database_name}].(id:{obj.id})").format(obj=self)
|
||||
|
||||
sqla.event.listen(Database, 'after_insert', set_perm)
|
||||
sqla.event.listen(Database, 'after_update', set_perm)
|
||||
|
||||
|
||||
class Log(Model):
|
||||
|
||||
"""ORM object used to log Superset actions to the database"""
|
||||
|
||||
__tablename__ = 'logs'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
action = Column(String(512))
|
||||
user_id = Column(Integer, ForeignKey('ab_user.id'))
|
||||
dashboard_id = Column(Integer)
|
||||
slice_id = Column(Integer)
|
||||
json = Column(Text)
|
||||
user = relationship('User', backref='logs', foreign_keys=[user_id])
|
||||
dttm = Column(DateTime, default=datetime.utcnow)
|
||||
dt = Column(Date, default=date.today())
|
||||
duration_ms = Column(Integer)
|
||||
referrer = Column(String(1024))
|
||||
|
||||
@classmethod
|
||||
def log_this(cls, f):
|
||||
"""Decorator to log user actions"""
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
start_dttm = datetime.now()
|
||||
user_id = None
|
||||
if g.user:
|
||||
user_id = g.user.get_id()
|
||||
d = request.args.to_dict()
|
||||
post_data = request.form or {}
|
||||
d.update(post_data)
|
||||
d.update(kwargs)
|
||||
slice_id = d.get('slice_id', 0)
|
||||
try:
|
||||
slice_id = int(slice_id) if slice_id else 0
|
||||
except ValueError:
|
||||
slice_id = 0
|
||||
params = ""
|
||||
try:
|
||||
params = json.dumps(d)
|
||||
except:
|
||||
pass
|
||||
value = f(*args, **kwargs)
|
||||
|
||||
sesh = db.session()
|
||||
log = cls(
|
||||
action=f.__name__,
|
||||
json=params,
|
||||
dashboard_id=d.get('dashboard_id') or None,
|
||||
slice_id=slice_id,
|
||||
duration_ms=(
|
||||
datetime.now() - start_dttm).total_seconds() * 1000,
|
||||
referrer=request.referrer[:1000] if request.referrer else None,
|
||||
user_id=user_id)
|
||||
sesh.add(log)
|
||||
sesh.commit()
|
||||
return value
|
||||
return wrapper
|
||||
|
||||
|
||||
class FavStar(Model):
|
||||
__tablename__ = 'favstar'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
user_id = Column(Integer, ForeignKey('ab_user.id'))
|
||||
class_name = Column(String(50))
|
||||
obj_id = Column(Integer)
|
||||
dttm = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
|
||||
class Query(Model):
|
||||
|
||||
"""ORM model for SQL query"""
|
||||
|
||||
__tablename__ = 'query'
|
||||
id = Column(Integer, primary_key=True)
|
||||
client_id = Column(String(11), unique=True, nullable=False)
|
||||
|
||||
database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False)
|
||||
|
||||
# Store the tmp table into the DB only if the user asks for it.
|
||||
tmp_table_name = Column(String(256))
|
||||
user_id = Column(
|
||||
Integer, ForeignKey('ab_user.id'), nullable=True)
|
||||
status = Column(String(16), default=QueryStatus.PENDING)
|
||||
tab_name = Column(String(256))
|
||||
sql_editor_id = Column(String(256))
|
||||
schema = Column(String(256))
|
||||
sql = Column(Text)
|
||||
# Query to retrieve the results,
|
||||
# used only in case of select_as_cta_used is true.
|
||||
select_sql = Column(Text)
|
||||
executed_sql = Column(Text)
|
||||
# Could be configured in the superset config.
|
||||
limit = Column(Integer)
|
||||
limit_used = Column(Boolean, default=False)
|
||||
limit_reached = Column(Boolean, default=False)
|
||||
select_as_cta = Column(Boolean)
|
||||
select_as_cta_used = Column(Boolean, default=False)
|
||||
|
||||
progress = Column(Integer, default=0) # 1..100
|
||||
# # of rows in the result set or rows modified.
|
||||
rows = Column(Integer)
|
||||
error_message = Column(Text)
|
||||
# key used to store the results in the results backend
|
||||
results_key = Column(String(64), index=True)
|
||||
|
||||
# Using Numeric in place of DateTime for sub-second precision
|
||||
# stored as seconds since epoch, allowing for milliseconds
|
||||
start_time = Column(Numeric(precision=3))
|
||||
end_time = Column(Numeric(precision=3))
|
||||
changed_on = Column(
|
||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=True)
|
||||
|
||||
database = relationship(
|
||||
'Database',
|
||||
foreign_keys=[database_id],
|
||||
backref=backref('queries', cascade='all, delete-orphan')
|
||||
)
|
||||
user = relationship(
|
||||
'User',
|
||||
backref=backref('queries', cascade='all, delete-orphan'),
|
||||
foreign_keys=[user_id])
|
||||
|
||||
__table_args__ = (
|
||||
sqla.Index('ti_user_id_changed_on', user_id, changed_on),
|
||||
)
|
||||
|
||||
@property
|
||||
def limit_reached(self):
|
||||
return self.rows == self.limit if self.limit_used else False
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'changedOn': self.changed_on,
|
||||
'changed_on': self.changed_on.isoformat(),
|
||||
'dbId': self.database_id,
|
||||
'db': self.database.database_name,
|
||||
'endDttm': self.end_time,
|
||||
'errorMessage': self.error_message,
|
||||
'executedSql': self.executed_sql,
|
||||
'id': self.client_id,
|
||||
'limit': self.limit,
|
||||
'progress': self.progress,
|
||||
'rows': self.rows,
|
||||
'schema': self.schema,
|
||||
'ctas': self.select_as_cta,
|
||||
'serverId': self.id,
|
||||
'sql': self.sql,
|
||||
'sqlEditorId': self.sql_editor_id,
|
||||
'startDttm': self.start_time,
|
||||
'state': self.status.lower(),
|
||||
'tab': self.tab_name,
|
||||
'tempTable': self.tmp_table_name,
|
||||
'userId': self.user_id,
|
||||
'user': self.user.username,
|
||||
'limit_reached': self.limit_reached,
|
||||
'resultsKey': self.results_key,
|
||||
}
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
ts = datetime.now().isoformat()
|
||||
ts = ts.replace('-', '').replace(':', '').split('.')[0]
|
||||
tab = self.tab_name.replace(' ', '_').lower() if self.tab_name else 'notab'
|
||||
tab = re.sub(r'\W+', '', tab)
|
||||
return "sqllab_{tab}_{ts}".format(**locals())
|
||||
|
||||
|
||||
class DatasourceAccessRequest(Model, AuditMixinNullable):
|
||||
"""ORM model for the access requests for datasources and dbs."""
|
||||
__tablename__ = 'access_request'
|
||||
id = Column(Integer, primary_key=True)
|
||||
|
||||
datasource_id = Column(Integer)
|
||||
datasource_type = Column(String(200))
|
||||
|
||||
ROLES_BLACKLIST = set(config.get('ROBOT_PERMISSION_ROLES', []))
|
||||
|
||||
@property
|
||||
def cls_model(self):
|
||||
return ConnectorRegistry.sources[self.datasource_type]
|
||||
|
||||
@property
|
||||
def username(self):
|
||||
return self.creator()
|
||||
|
||||
@property
|
||||
def datasource(self):
|
||||
return self.get_datasource
|
||||
|
||||
@datasource.getter
|
||||
@utils.memoized
|
||||
def get_datasource(self):
|
||||
ds = db.session.query(self.cls_model).filter_by(
|
||||
id=self.datasource_id).first()
|
||||
return ds
|
||||
|
||||
@property
|
||||
def datasource_link(self):
|
||||
return self.datasource.link
|
||||
|
||||
@property
|
||||
def roles_with_datasource(self):
|
||||
action_list = ''
|
||||
pv = sm.find_permission_view_menu(
|
||||
'datasource_access', self.datasource.perm)
|
||||
for r in pv.role:
|
||||
if r.name in self.ROLES_BLACKLIST:
|
||||
continue
|
||||
url = (
|
||||
'/superset/approve?datasource_type={self.datasource_type}&'
|
||||
'datasource_id={self.datasource_id}&'
|
||||
'created_by={self.created_by.username}&role_to_grant={r.name}'
|
||||
.format(**locals())
|
||||
)
|
||||
href = '<a href="{}">Grant {} Role</a>'.format(url, r.name)
|
||||
action_list = action_list + '<li>' + href + '</li>'
|
||||
return '<ul>' + action_list + '</ul>'
|
||||
|
||||
@property
|
||||
def user_roles(self):
|
||||
action_list = ''
|
||||
for r in self.created_by.roles:
|
||||
url = (
|
||||
'/superset/approve?datasource_type={self.datasource_type}&'
|
||||
'datasource_id={self.datasource_id}&'
|
||||
'created_by={self.created_by.username}&role_to_extend={r.name}'
|
||||
.format(**locals())
|
||||
)
|
||||
href = '<a href="{}">Extend {} Role</a>'.format(url, r.name)
|
||||
if r.name in self.ROLES_BLACKLIST:
|
||||
href = "{} Role".format(r.name)
|
||||
action_list = action_list + '<li>' + href + '</li>'
|
||||
return '<ul>' + action_list + '</ul>'
|
|
@ -0,0 +1,127 @@
|
|||
from datetime import datetime
|
||||
import humanize
|
||||
import json
|
||||
import re
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy.ext.declarative import declared_attr
|
||||
|
||||
from flask import escape, Markup
|
||||
from flask_appbuilder.models.mixins import AuditMixin
|
||||
from flask_appbuilder.models.decorators import renders
|
||||
from superset.utils import QueryStatus
|
||||
|
||||
|
||||
class ImportMixin(object):
|
||||
def override(self, obj):
|
||||
"""Overrides the plain fields of the dashboard."""
|
||||
for field in obj.__class__.export_fields:
|
||||
setattr(self, field, getattr(obj, field))
|
||||
|
||||
def copy(self):
|
||||
"""Creates a copy of the dashboard without relationships."""
|
||||
new_obj = self.__class__()
|
||||
new_obj.override(self)
|
||||
return new_obj
|
||||
|
||||
def alter_params(self, **kwargs):
|
||||
d = self.params_dict
|
||||
d.update(kwargs)
|
||||
self.params = json.dumps(d)
|
||||
|
||||
@property
|
||||
def params_dict(self):
|
||||
if self.params:
|
||||
params = re.sub(",[ \t\r\n]+}", "}", self.params)
|
||||
params = re.sub(",[ \t\r\n]+\]", "]", params)
|
||||
return json.loads(params)
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
class AuditMixinNullable(AuditMixin):
|
||||
|
||||
"""Altering the AuditMixin to use nullable fields
|
||||
|
||||
Allows creating objects programmatically outside of CRUD
|
||||
"""
|
||||
|
||||
created_on = sa.Column(sa.DateTime, default=datetime.now, nullable=True)
|
||||
changed_on = sa.Column(
|
||||
sa.DateTime, default=datetime.now,
|
||||
onupdate=datetime.now, nullable=True)
|
||||
|
||||
@declared_attr
|
||||
def created_by_fk(cls): # noqa
|
||||
return sa.Column(
|
||||
sa.Integer, sa.ForeignKey('ab_user.id'),
|
||||
default=cls.get_user_id, nullable=True)
|
||||
|
||||
@declared_attr
|
||||
def changed_by_fk(cls): # noqa
|
||||
return sa.Column(
|
||||
sa.Integer, sa.ForeignKey('ab_user.id'),
|
||||
default=cls.get_user_id, onupdate=cls.get_user_id, nullable=True)
|
||||
|
||||
def _user_link(self, user):
|
||||
if not user:
|
||||
return ''
|
||||
url = '/superset/profile/{}/'.format(user.username)
|
||||
return Markup('<a href="{}">{}</a>'.format(url, escape(user) or ''))
|
||||
|
||||
@renders('created_by')
|
||||
def creator(self): # noqa
|
||||
return self._user_link(self.created_by)
|
||||
|
||||
@property
|
||||
def changed_by_(self):
|
||||
return self._user_link(self.changed_by)
|
||||
|
||||
@renders('changed_on')
|
||||
def changed_on_(self):
|
||||
return Markup(
|
||||
'<span class="no-wrap">{}</span>'.format(self.changed_on))
|
||||
|
||||
@renders('changed_on')
|
||||
def modified(self):
|
||||
s = humanize.naturaltime(datetime.now() - self.changed_on)
|
||||
return Markup('<span class="no-wrap">{}</span>'.format(s))
|
||||
|
||||
@property
|
||||
def icons(self):
|
||||
return """
|
||||
<a
|
||||
href="{self.datasource_edit_url}"
|
||||
data-toggle="tooltip"
|
||||
title="{self.datasource}">
|
||||
<i class="fa fa-database"></i>
|
||||
</a>
|
||||
""".format(**locals())
|
||||
|
||||
|
||||
class QueryResult(object):
|
||||
|
||||
"""Object returned by the query interface"""
|
||||
|
||||
def __init__( # noqa
|
||||
self,
|
||||
df,
|
||||
query,
|
||||
duration,
|
||||
status=QueryStatus.SUCCESS,
|
||||
error_message=None):
|
||||
self.df = df
|
||||
self.query = query
|
||||
self.duration = duration
|
||||
self.status = status
|
||||
self.error_message = error_message
|
||||
|
||||
|
||||
def set_perm(mapper, connection, target): # noqa
|
||||
if target.perm != target.get_perm():
|
||||
link_table = target.__table__
|
||||
connection.execute(
|
||||
link_table.update()
|
||||
.where(link_table.c.id == target.id)
|
||||
.values(perm=target.get_perm())
|
||||
)
|
|
@ -6,7 +6,9 @@ from __future__ import unicode_literals
|
|||
import logging
|
||||
from flask_appbuilder.security.sqla import models as ab_models
|
||||
|
||||
from superset import conf, db, models, sm, source_registry
|
||||
from superset import conf, db, sm
|
||||
from superset.models import core as models
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
|
||||
|
||||
READ_ONLY_MODEL_VIEWS = {
|
||||
|
@ -155,7 +157,7 @@ def create_custom_permissions():
|
|||
|
||||
def create_missing_datasource_perms(view_menu_set):
|
||||
logging.info("Creating missing datasource permissions.")
|
||||
datasources = source_registry.SourceRegistry.get_all_datasources(
|
||||
datasources = ConnectorRegistry.get_all_datasources(
|
||||
db.session)
|
||||
for datasource in datasources:
|
||||
if datasource and datasource.perm not in view_menu_set:
|
||||
|
@ -181,8 +183,8 @@ def create_missing_metrics_perm(view_menu_set):
|
|||
"""
|
||||
logging.info("Creating missing metrics permissions")
|
||||
metrics = []
|
||||
for model in [models.SqlMetric, models.DruidMetric]:
|
||||
metrics += list(db.session.query(model).all())
|
||||
for datasource_class in ConnectorRegistry.sources.values():
|
||||
metrics += list(db.session.query(datasource_class.metric_class).all())
|
||||
|
||||
for metric in metrics:
|
||||
if (metric.is_restricted and metric.perm and
|
||||
|
@ -216,7 +218,9 @@ def sync_role_definitions():
|
|||
if conf.get('PUBLIC_ROLE_LIKE_GAMMA', False):
|
||||
set_role('Public', pvms, is_gamma_pvm)
|
||||
|
||||
view_menu_set = db.session.query(models.SqlaTable).all()
|
||||
view_menu_set = []
|
||||
for datasource_class in ConnectorRegistry.sources.values():
|
||||
view_menu_set += list(db.session.query(datasource_class).all())
|
||||
create_missing_datasource_perms(view_menu_set)
|
||||
create_missing_database_perms(view_menu_set)
|
||||
create_missing_metrics_perm(view_menu_set)
|
||||
|
|
|
@ -12,11 +12,12 @@ from sqlalchemy.pool import NullPool
|
|||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from superset import (
|
||||
app, db, models, utils, dataframe, results_backend)
|
||||
app, db, utils, dataframe, results_backend)
|
||||
from superset.models import core as models
|
||||
from superset.sql_parse import SupersetQuery
|
||||
from superset.db_engine_specs import LimitMethod
|
||||
from superset.jinja_context import get_template_processor
|
||||
QueryStatus = models.QueryStatus
|
||||
from superset.utils import QueryStatus
|
||||
|
||||
celery_app = celery.Celery(config_source=app.config.get('CELERY_CONFIG'))
|
||||
|
||||
|
|
|
@ -438,7 +438,7 @@ def pessimistic_connection_handling(target):
|
|||
cursor.close()
|
||||
|
||||
|
||||
class QueryStatus:
|
||||
class QueryStatus(object):
|
||||
|
||||
"""Enum-type class for query statuses"""
|
||||
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
from . import base # noqa
|
||||
from . import core # noqa
|
|
@ -0,0 +1,201 @@
|
|||
import logging
|
||||
import json
|
||||
|
||||
from flask import g, redirect
|
||||
from flask_babel import gettext as __
|
||||
|
||||
from flask_appbuilder import BaseView
|
||||
from flask_appbuilder import ModelView
|
||||
from flask_appbuilder.widgets import ListWidget
|
||||
from flask_appbuilder.actions import action
|
||||
from flask_appbuilder.models.sqla.filters import BaseFilter
|
||||
from flask_appbuilder.security.sqla import models as ab_models
|
||||
|
||||
from superset import appbuilder, conf, db, utils, sm, sql_parse
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
|
||||
|
||||
def get_datasource_exist_error_mgs(full_name):
|
||||
return __("Datasource %(name)s already exists", name=full_name)
|
||||
|
||||
|
||||
def get_user_roles():
|
||||
if g.user.is_anonymous():
|
||||
public_role = conf.get('AUTH_ROLE_PUBLIC')
|
||||
return [appbuilder.sm.find_role(public_role)] if public_role else []
|
||||
return g.user.roles
|
||||
|
||||
|
||||
class BaseSupersetView(BaseView):
|
||||
def can_access(self, permission_name, view_name, user=None):
|
||||
if not user:
|
||||
user = g.user
|
||||
return utils.can_access(
|
||||
appbuilder.sm, permission_name, view_name, user)
|
||||
|
||||
def all_datasource_access(self, user=None):
|
||||
return self.can_access(
|
||||
"all_datasource_access", "all_datasource_access", user=user)
|
||||
|
||||
def database_access(self, database, user=None):
|
||||
return (
|
||||
self.can_access(
|
||||
"all_database_access", "all_database_access", user=user) or
|
||||
self.can_access("database_access", database.perm, user=user)
|
||||
)
|
||||
|
||||
def schema_access(self, datasource, user=None):
|
||||
return (
|
||||
self.database_access(datasource.database, user=user) or
|
||||
self.all_datasource_access(user=user) or
|
||||
self.can_access("schema_access", datasource.schema_perm, user=user)
|
||||
)
|
||||
|
||||
def datasource_access(self, datasource, user=None):
|
||||
return (
|
||||
self.schema_access(datasource, user=user) or
|
||||
self.can_access("datasource_access", datasource.perm, user=user)
|
||||
)
|
||||
|
||||
def datasource_access_by_name(
|
||||
self, database, datasource_name, schema=None):
|
||||
if self.database_access(database) or self.all_datasource_access():
|
||||
return True
|
||||
|
||||
schema_perm = utils.get_schema_perm(database, schema)
|
||||
if schema and utils.can_access(
|
||||
sm, 'schema_access', schema_perm, g.user):
|
||||
return True
|
||||
|
||||
datasources = ConnectorRegistry.query_datasources_by_name(
|
||||
db.session, database, datasource_name, schema=schema)
|
||||
for datasource in datasources:
|
||||
if self.can_access("datasource_access", datasource.perm):
|
||||
return True
|
||||
return False
|
||||
|
||||
def datasource_access_by_fullname(
|
||||
self, database, full_table_name, schema):
|
||||
table_name_pieces = full_table_name.split(".")
|
||||
if len(table_name_pieces) == 2:
|
||||
table_schema = table_name_pieces[0]
|
||||
table_name = table_name_pieces[1]
|
||||
else:
|
||||
table_schema = schema
|
||||
table_name = table_name_pieces[0]
|
||||
return self.datasource_access_by_name(
|
||||
database, table_name, schema=table_schema)
|
||||
|
||||
def rejected_datasources(self, sql, database, schema):
|
||||
superset_query = sql_parse.SupersetQuery(sql)
|
||||
return [
|
||||
t for t in superset_query.tables if not
|
||||
self.datasource_access_by_fullname(database, t, schema)]
|
||||
|
||||
def accessible_by_user(self, database, datasource_names, schema=None):
|
||||
if self.database_access(database) or self.all_datasource_access():
|
||||
return datasource_names
|
||||
|
||||
schema_perm = utils.get_schema_perm(database, schema)
|
||||
if schema and utils.can_access(
|
||||
sm, 'schema_access', schema_perm, g.user):
|
||||
return datasource_names
|
||||
|
||||
role_ids = set([role.id for role in g.user.roles])
|
||||
# TODO: cache user_perms or user_datasources
|
||||
user_pvms = (
|
||||
db.session.query(ab_models.PermissionView)
|
||||
.join(ab_models.Permission)
|
||||
.filter(ab_models.Permission.name == 'datasource_access')
|
||||
.filter(ab_models.PermissionView.role.any(
|
||||
ab_models.Role.id.in_(role_ids)))
|
||||
.all()
|
||||
)
|
||||
user_perms = set([pvm.view_menu.name for pvm in user_pvms])
|
||||
user_datasources = ConnectorRegistry.query_datasources_by_permissions(
|
||||
db.session, database, user_perms)
|
||||
full_names = set([d.full_name for d in user_datasources])
|
||||
return [d for d in datasource_names if d in full_names]
|
||||
|
||||
|
||||
class SupersetModelView(ModelView):
|
||||
page_size = 500
|
||||
|
||||
|
||||
class ListWidgetWithCheckboxes(ListWidget):
|
||||
"""An alternative to list view that renders Boolean fields as checkboxes
|
||||
|
||||
Works in conjunction with the `checkbox` view."""
|
||||
template = 'superset/fab_overrides/list_with_checkboxes.html'
|
||||
|
||||
|
||||
def validate_json(form, field): # noqa
|
||||
try:
|
||||
json.loads(field.data)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
raise Exception("json isn't valid")
|
||||
|
||||
|
||||
class DeleteMixin(object):
|
||||
@action(
|
||||
"muldelete", "Delete", "Delete all Really?", "fa-trash", single=False)
|
||||
def muldelete(self, items):
|
||||
self.datamodel.delete_all(items)
|
||||
self.update_redirect()
|
||||
return redirect(self.get_redirect())
|
||||
|
||||
|
||||
class SupersetFilter(BaseFilter):
|
||||
|
||||
"""Add utility function to make BaseFilter easy and fast
|
||||
|
||||
These utility function exist in the SecurityManager, but would do
|
||||
a database round trip at every check. Here we cache the role objects
|
||||
to be able to make multiple checks but query the db only once
|
||||
"""
|
||||
|
||||
def get_user_roles(self):
|
||||
return get_user_roles()
|
||||
|
||||
def get_all_permissions(self):
|
||||
"""Returns a set of tuples with the perm name and view menu name"""
|
||||
perms = set()
|
||||
for role in self.get_user_roles():
|
||||
for perm_view in role.permissions:
|
||||
t = (perm_view.permission.name, perm_view.view_menu.name)
|
||||
perms.add(t)
|
||||
return perms
|
||||
|
||||
def has_role(self, role_name_or_list):
|
||||
"""Whether the user has this role name"""
|
||||
if not isinstance(role_name_or_list, list):
|
||||
role_name_or_list = [role_name_or_list]
|
||||
return any(
|
||||
[r.name in role_name_or_list for r in self.get_user_roles()])
|
||||
|
||||
def has_perm(self, permission_name, view_menu_name):
|
||||
"""Whether the user has this perm"""
|
||||
return (permission_name, view_menu_name) in self.get_all_permissions()
|
||||
|
||||
def get_view_menus(self, permission_name):
|
||||
"""Returns the details of view_menus for a perm name"""
|
||||
vm = set()
|
||||
for perm_name, vm_name in self.get_all_permissions():
|
||||
if perm_name == permission_name:
|
||||
vm.add(vm_name)
|
||||
return vm
|
||||
|
||||
def has_all_datasource_access(self):
|
||||
return (
|
||||
self.has_role(['Admin', 'Alpha']) or
|
||||
self.has_perm('all_datasource_access', 'all_datasource_access'))
|
||||
|
||||
|
||||
class DatasourceFilter(SupersetFilter):
|
||||
def apply(self, query, func): # noqa
|
||||
if self.has_all_datasource_access():
|
||||
return query
|
||||
perms = self.get_view_menus('datasource_access')
|
||||
# TODO(bogdan): add `schema_access` support here
|
||||
return query.filter(self.model.perm.in_(perms))
|
|
@ -19,12 +19,10 @@ import sqlalchemy as sqla
|
|||
|
||||
from flask import (
|
||||
g, request, redirect, flash, Response, render_template, Markup)
|
||||
from flask_appbuilder import ModelView, CompactCRUDMixin, BaseView, expose
|
||||
from flask_appbuilder import expose
|
||||
from flask_appbuilder.actions import action
|
||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||
from flask_appbuilder.security.decorators import has_access_api
|
||||
from flask_appbuilder.widgets import ListWidget
|
||||
from flask_appbuilder.models.sqla.filters import BaseFilter
|
||||
from flask_appbuilder.security.sqla import models as ab_models
|
||||
|
||||
from flask_babel import gettext as __
|
||||
|
@ -33,120 +31,26 @@ from flask_babel import lazy_gettext as _
|
|||
from sqlalchemy import create_engine
|
||||
from werkzeug.routing import BaseConverter
|
||||
|
||||
import superset
|
||||
from superset import (
|
||||
appbuilder, cache, db, models, viz, utils, app,
|
||||
sm, sql_lab, sql_parse, results_backend, security,
|
||||
appbuilder, cache, db, viz, utils, app,
|
||||
sm, sql_lab, results_backend, security,
|
||||
)
|
||||
from superset.legacy import cast_form_data
|
||||
from superset.utils import has_access
|
||||
from superset.source_registry import SourceRegistry
|
||||
from superset.models import DatasourceAccessRequest as DAR
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
import superset.models.core as models
|
||||
from superset.sql_parse import SupersetQuery
|
||||
|
||||
from .base import (
|
||||
SupersetModelView, BaseSupersetView, DeleteMixin,
|
||||
SupersetFilter, get_user_roles
|
||||
)
|
||||
|
||||
config = app.config
|
||||
log_this = models.Log.log_this
|
||||
can_access = utils.can_access
|
||||
QueryStatus = models.QueryStatus
|
||||
|
||||
|
||||
class BaseSupersetView(BaseView):
|
||||
def can_access(self, permission_name, view_name, user=None):
|
||||
if not user:
|
||||
user = g.user
|
||||
return utils.can_access(
|
||||
appbuilder.sm, permission_name, view_name, user)
|
||||
|
||||
def all_datasource_access(self, user=None):
|
||||
return self.can_access(
|
||||
"all_datasource_access", "all_datasource_access", user=user)
|
||||
|
||||
def database_access(self, database, user=None):
|
||||
return (
|
||||
self.can_access(
|
||||
"all_database_access", "all_database_access", user=user) or
|
||||
self.can_access("database_access", database.perm, user=user)
|
||||
)
|
||||
|
||||
def schema_access(self, datasource, user=None):
|
||||
return (
|
||||
self.database_access(datasource.database, user=user) or
|
||||
self.all_datasource_access(user=user) or
|
||||
self.can_access("schema_access", datasource.schema_perm, user=user)
|
||||
)
|
||||
|
||||
def datasource_access(self, datasource, user=None):
|
||||
return (
|
||||
self.schema_access(datasource, user=user) or
|
||||
self.can_access("datasource_access", datasource.perm, user=user)
|
||||
)
|
||||
|
||||
def datasource_access_by_name(
|
||||
self, database, datasource_name, schema=None):
|
||||
if self.database_access(database) or self.all_datasource_access():
|
||||
return True
|
||||
|
||||
schema_perm = utils.get_schema_perm(database, schema)
|
||||
if schema and utils.can_access(
|
||||
sm, 'schema_access', schema_perm, g.user):
|
||||
return True
|
||||
|
||||
datasources = SourceRegistry.query_datasources_by_name(
|
||||
db.session, database, datasource_name, schema=schema)
|
||||
for datasource in datasources:
|
||||
if self.can_access("datasource_access", datasource.perm):
|
||||
return True
|
||||
return False
|
||||
|
||||
def datasource_access_by_fullname(
|
||||
self, database, full_table_name, schema):
|
||||
table_name_pieces = full_table_name.split(".")
|
||||
if len(table_name_pieces) == 2:
|
||||
table_schema = table_name_pieces[0]
|
||||
table_name = table_name_pieces[1]
|
||||
else:
|
||||
table_schema = schema
|
||||
table_name = table_name_pieces[0]
|
||||
return self.datasource_access_by_name(
|
||||
database, table_name, schema=table_schema)
|
||||
|
||||
def rejected_datasources(self, sql, database, schema):
|
||||
superset_query = sql_parse.SupersetQuery(sql)
|
||||
return [
|
||||
t for t in superset_query.tables if not
|
||||
self.datasource_access_by_fullname(database, t, schema)]
|
||||
|
||||
def accessible_by_user(self, database, datasource_names, schema=None):
|
||||
if self.database_access(database) or self.all_datasource_access():
|
||||
return datasource_names
|
||||
|
||||
schema_perm = utils.get_schema_perm(database, schema)
|
||||
if schema and utils.can_access(
|
||||
sm, 'schema_access', schema_perm, g.user):
|
||||
return datasource_names
|
||||
|
||||
role_ids = set([role.id for role in g.user.roles])
|
||||
# TODO: cache user_perms or user_datasources
|
||||
user_pvms = (
|
||||
db.session.query(ab_models.PermissionView)
|
||||
.join(ab_models.Permission)
|
||||
.filter(ab_models.Permission.name == 'datasource_access')
|
||||
.filter(ab_models.PermissionView.role.any(
|
||||
ab_models.Role.id.in_(role_ids)))
|
||||
.all()
|
||||
)
|
||||
user_perms = set([pvm.view_menu.name for pvm in user_pvms])
|
||||
user_datasources = SourceRegistry.query_datasources_by_permissions(
|
||||
db.session, database, user_perms)
|
||||
full_names = set([d.full_name for d in user_datasources])
|
||||
return [d for d in datasource_names if d in full_names]
|
||||
|
||||
|
||||
class ListWidgetWithCheckboxes(ListWidget):
|
||||
"""An alternative to list view that renders Boolean fields as checkboxes
|
||||
|
||||
Works in conjunction with the `checkbox` view."""
|
||||
template = 'superset/fab_overrides/list_with_checkboxes.html'
|
||||
DAR = models.DatasourceAccessRequest
|
||||
|
||||
|
||||
ALL_DATASOURCE_ACCESS_ERR = __(
|
||||
|
@ -168,10 +72,6 @@ def get_datasource_access_error_msg(datasource_name):
|
|||
"`all_datasource_access` permission", name=datasource_name)
|
||||
|
||||
|
||||
def get_datasource_exist_error_mgs(full_name):
|
||||
return __("Datasource %(name)s already exists", name=full_name)
|
||||
|
||||
|
||||
def get_error_msg():
|
||||
if config.get("SHOW_STACKTRACE"):
|
||||
error_msg = traceback.format_exc()
|
||||
|
@ -211,10 +111,12 @@ def api(f):
|
|||
|
||||
return functools.update_wrapper(wraps, f)
|
||||
|
||||
|
||||
def is_owner(obj, user):
|
||||
""" Check if user is owner of the slice """
|
||||
return obj and obj.owners and user in obj.owners
|
||||
|
||||
|
||||
def check_ownership(obj, raise_if_false=True):
|
||||
"""Meant to be used in `pre_update` hooks on models to enforce ownership
|
||||
|
||||
|
@ -257,68 +159,6 @@ def check_ownership(obj, raise_if_false=True):
|
|||
return False
|
||||
|
||||
|
||||
def get_user_roles():
|
||||
if g.user.is_anonymous():
|
||||
public_role = config.get('AUTH_ROLE_PUBLIC')
|
||||
return [appbuilder.sm.find_role(public_role)] if public_role else []
|
||||
return g.user.roles
|
||||
|
||||
|
||||
class SupersetFilter(BaseFilter):
|
||||
|
||||
"""Add utility function to make BaseFilter easy and fast
|
||||
|
||||
These utility function exist in the SecurityManager, but would do
|
||||
a database round trip at every check. Here we cache the role objects
|
||||
to be able to make multiple checks but query the db only once
|
||||
"""
|
||||
|
||||
def get_user_roles(self):
|
||||
return get_user_roles()
|
||||
|
||||
def get_all_permissions(self):
|
||||
"""Returns a set of tuples with the perm name and view menu name"""
|
||||
perms = set()
|
||||
for role in get_user_roles():
|
||||
for perm_view in role.permissions:
|
||||
t = (perm_view.permission.name, perm_view.view_menu.name)
|
||||
perms.add(t)
|
||||
return perms
|
||||
|
||||
def has_role(self, role_name_or_list):
|
||||
"""Whether the user has this role name"""
|
||||
if not isinstance(role_name_or_list, list):
|
||||
role_name_or_list = [role_name_or_list]
|
||||
return any(
|
||||
[r.name in role_name_or_list for r in self.get_user_roles()])
|
||||
|
||||
def has_perm(self, permission_name, view_menu_name):
|
||||
"""Whether the user has this perm"""
|
||||
return (permission_name, view_menu_name) in self.get_all_permissions()
|
||||
|
||||
def get_view_menus(self, permission_name):
|
||||
"""Returns the details of view_menus for a perm name"""
|
||||
vm = set()
|
||||
for perm_name, vm_name in self.get_all_permissions():
|
||||
if perm_name == permission_name:
|
||||
vm.add(vm_name)
|
||||
return vm
|
||||
|
||||
def has_all_datasource_access(self):
|
||||
return (
|
||||
self.has_role(['Admin', 'Alpha']) or
|
||||
self.has_perm('all_datasource_access', 'all_datasource_access'))
|
||||
|
||||
|
||||
class DatasourceFilter(SupersetFilter):
|
||||
def apply(self, query, func): # noqa
|
||||
if self.has_all_datasource_access():
|
||||
return query
|
||||
perms = self.get_view_menus('datasource_access')
|
||||
# TODO(bogdan): add `schema_access` support here
|
||||
return query.filter(self.model.perm.in_(perms))
|
||||
|
||||
|
||||
class SliceFilter(SupersetFilter):
|
||||
def apply(self, query, func): # noqa
|
||||
if self.has_all_datasource_access():
|
||||
|
@ -355,14 +195,6 @@ class DashboardFilter(SupersetFilter):
|
|||
return query
|
||||
|
||||
|
||||
def validate_json(form, field): # noqa
|
||||
try:
|
||||
json.loads(field.data)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
raise Exception("json isn't valid")
|
||||
|
||||
|
||||
def generate_download_headers(extension):
|
||||
filename = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
content_disp = "attachment; filename={}.{}".format(filename, extension)
|
||||
|
@ -372,206 +204,6 @@ def generate_download_headers(extension):
|
|||
return headers
|
||||
|
||||
|
||||
class DeleteMixin(object):
|
||||
@action(
|
||||
"muldelete", "Delete", "Delete all Really?", "fa-trash", single=False)
|
||||
def muldelete(self, items):
|
||||
self.datamodel.delete_all(items)
|
||||
self.update_redirect()
|
||||
return redirect(self.get_redirect())
|
||||
|
||||
|
||||
class SupersetModelView(ModelView):
|
||||
page_size = 500
|
||||
|
||||
|
||||
class TableColumnInlineView(CompactCRUDMixin, SupersetModelView): # noqa
|
||||
datamodel = SQLAInterface(models.TableColumn)
|
||||
can_delete = False
|
||||
list_widget = ListWidgetWithCheckboxes
|
||||
edit_columns = [
|
||||
'column_name', 'verbose_name', 'description', 'groupby', 'filterable',
|
||||
'table', 'count_distinct', 'sum', 'min', 'max', 'expression',
|
||||
'is_dttm', 'python_date_format', 'database_expression']
|
||||
add_columns = edit_columns
|
||||
list_columns = [
|
||||
'column_name', 'type', 'groupby', 'filterable', 'count_distinct',
|
||||
'sum', 'min', 'max', 'is_dttm']
|
||||
page_size = 500
|
||||
description_columns = {
|
||||
'is_dttm': (_(
|
||||
"Whether to make this column available as a "
|
||||
"[Time Granularity] option, column has to be DATETIME or "
|
||||
"DATETIME-like")),
|
||||
'expression': utils.markdown(
|
||||
"a valid SQL expression as supported by the underlying backend. "
|
||||
"Example: `substr(name, 1, 1)`", True),
|
||||
'python_date_format': utils.markdown(Markup(
|
||||
"The pattern of timestamp format, use "
|
||||
"<a href='https://docs.python.org/2/library/"
|
||||
"datetime.html#strftime-strptime-behavior'>"
|
||||
"python datetime string pattern</a> "
|
||||
"expression. If time is stored in epoch "
|
||||
"format, put `epoch_s` or `epoch_ms`. Leave `Database Expression` "
|
||||
"below empty if timestamp is stored in "
|
||||
"String or Integer(epoch) type"), True),
|
||||
'database_expression': utils.markdown(
|
||||
"The database expression to cast internal datetime "
|
||||
"constants to database date/timestamp type according to the DBAPI. "
|
||||
"The expression should follow the pattern of "
|
||||
"%Y-%m-%d %H:%M:%S, based on different DBAPI. "
|
||||
"The string should be a python string formatter \n"
|
||||
"`Ex: TO_DATE('{}', 'YYYY-MM-DD HH24:MI:SS')` for Oracle"
|
||||
"Superset uses default expression based on DB URI if this "
|
||||
"field is blank.", True),
|
||||
}
|
||||
label_columns = {
|
||||
'column_name': _("Column"),
|
||||
'verbose_name': _("Verbose Name"),
|
||||
'description': _("Description"),
|
||||
'groupby': _("Groupable"),
|
||||
'filterable': _("Filterable"),
|
||||
'table': _("Table"),
|
||||
'count_distinct': _("Count Distinct"),
|
||||
'sum': _("Sum"),
|
||||
'min': _("Min"),
|
||||
'max': _("Max"),
|
||||
'expression': _("Expression"),
|
||||
'is_dttm': _("Is temporal"),
|
||||
'python_date_format': _("Datetime Format"),
|
||||
'database_expression': _("Database Expression")
|
||||
}
|
||||
appbuilder.add_view_no_menu(TableColumnInlineView)
|
||||
|
||||
|
||||
class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView): # noqa
|
||||
datamodel = SQLAInterface(models.DruidColumn)
|
||||
edit_columns = [
|
||||
'column_name', 'description', 'dimension_spec_json', 'datasource',
|
||||
'groupby', 'count_distinct', 'sum', 'min', 'max']
|
||||
add_columns = edit_columns
|
||||
list_columns = [
|
||||
'column_name', 'type', 'groupby', 'filterable', 'count_distinct',
|
||||
'sum', 'min', 'max']
|
||||
can_delete = False
|
||||
page_size = 500
|
||||
label_columns = {
|
||||
'column_name': _("Column"),
|
||||
'type': _("Type"),
|
||||
'datasource': _("Datasource"),
|
||||
'groupby': _("Groupable"),
|
||||
'filterable': _("Filterable"),
|
||||
'count_distinct': _("Count Distinct"),
|
||||
'sum': _("Sum"),
|
||||
'min': _("Min"),
|
||||
'max': _("Max"),
|
||||
}
|
||||
description_columns = {
|
||||
'dimension_spec_json': utils.markdown(
|
||||
"this field can be used to specify "
|
||||
"a `dimensionSpec` as documented [here]"
|
||||
"(http://druid.io/docs/latest/querying/dimensionspecs.html). "
|
||||
"Make sure to input valid JSON and that the "
|
||||
"`outputName` matches the `column_name` defined "
|
||||
"above.",
|
||||
True),
|
||||
}
|
||||
|
||||
def post_update(self, col):
|
||||
col.generate_metrics()
|
||||
utils.validate_json(col.dimension_spec_json)
|
||||
|
||||
def post_add(self, col):
|
||||
self.post_update(col)
|
||||
|
||||
appbuilder.add_view_no_menu(DruidColumnInlineView)
|
||||
|
||||
|
||||
class SqlMetricInlineView(CompactCRUDMixin, SupersetModelView): # noqa
|
||||
datamodel = SQLAInterface(models.SqlMetric)
|
||||
list_columns = ['metric_name', 'verbose_name', 'metric_type']
|
||||
edit_columns = [
|
||||
'metric_name', 'description', 'verbose_name', 'metric_type',
|
||||
'expression', 'table', 'd3format', 'is_restricted']
|
||||
description_columns = {
|
||||
'expression': utils.markdown(
|
||||
"a valid SQL expression as supported by the underlying backend. "
|
||||
"Example: `count(DISTINCT userid)`", True),
|
||||
'is_restricted': _("Whether the access to this metric is restricted "
|
||||
"to certain roles. Only roles with the permission "
|
||||
"'metric access on XXX (the name of this metric)' "
|
||||
"are allowed to access this metric"),
|
||||
'd3format': utils.markdown(
|
||||
"d3 formatting string as defined [here]"
|
||||
"(https://github.com/d3/d3-format/blob/master/README.md#format). "
|
||||
"For instance, this default formatting applies in the Table "
|
||||
"visualization and allow for different metric to use different "
|
||||
"formats", True
|
||||
),
|
||||
}
|
||||
add_columns = edit_columns
|
||||
page_size = 500
|
||||
label_columns = {
|
||||
'metric_name': _("Metric"),
|
||||
'description': _("Description"),
|
||||
'verbose_name': _("Verbose Name"),
|
||||
'metric_type': _("Type"),
|
||||
'expression': _("SQL Expression"),
|
||||
'table': _("Table"),
|
||||
}
|
||||
|
||||
def post_add(self, metric):
|
||||
if metric.is_restricted:
|
||||
security.merge_perm(sm, 'metric_access', metric.get_perm())
|
||||
|
||||
def post_update(self, metric):
|
||||
if metric.is_restricted:
|
||||
security.merge_perm(sm, 'metric_access', metric.get_perm())
|
||||
|
||||
appbuilder.add_view_no_menu(SqlMetricInlineView)
|
||||
|
||||
|
||||
class DruidMetricInlineView(CompactCRUDMixin, SupersetModelView): # noqa
|
||||
datamodel = SQLAInterface(models.DruidMetric)
|
||||
list_columns = ['metric_name', 'verbose_name', 'metric_type']
|
||||
edit_columns = [
|
||||
'metric_name', 'description', 'verbose_name', 'metric_type', 'json',
|
||||
'datasource', 'd3format', 'is_restricted']
|
||||
add_columns = edit_columns
|
||||
page_size = 500
|
||||
validators_columns = {
|
||||
'json': [validate_json],
|
||||
}
|
||||
description_columns = {
|
||||
'metric_type': utils.markdown(
|
||||
"use `postagg` as the metric type if you are defining a "
|
||||
"[Druid Post Aggregation]"
|
||||
"(http://druid.io/docs/latest/querying/post-aggregations.html)",
|
||||
True),
|
||||
'is_restricted': _("Whether the access to this metric is restricted "
|
||||
"to certain roles. Only roles with the permission "
|
||||
"'metric access on XXX (the name of this metric)' "
|
||||
"are allowed to access this metric"),
|
||||
}
|
||||
label_columns = {
|
||||
'metric_name': _("Metric"),
|
||||
'description': _("Description"),
|
||||
'verbose_name': _("Verbose Name"),
|
||||
'metric_type': _("Type"),
|
||||
'json': _("JSON"),
|
||||
'datasource': _("Druid Datasource"),
|
||||
}
|
||||
|
||||
def post_add(self, metric):
|
||||
utils.init_metrics_perm(superset, [metric])
|
||||
|
||||
def post_update(self, metric):
|
||||
utils.init_metrics_perm(superset, [metric])
|
||||
|
||||
|
||||
appbuilder.add_view_no_menu(DruidMetricInlineView)
|
||||
|
||||
|
||||
class DatabaseView(SupersetModelView, DeleteMixin): # noqa
|
||||
datamodel = SQLAInterface(models.Database)
|
||||
list_columns = [
|
||||
|
@ -692,99 +324,6 @@ class DatabaseTablesAsync(DatabaseView):
|
|||
appbuilder.add_view_no_menu(DatabaseTablesAsync)
|
||||
|
||||
|
||||
class TableModelView(SupersetModelView, DeleteMixin): # noqa
|
||||
datamodel = SQLAInterface(models.SqlaTable)
|
||||
list_columns = [
|
||||
'link', 'database', 'is_featured',
|
||||
'changed_by_', 'changed_on_']
|
||||
order_columns = [
|
||||
'link', 'database', 'is_featured', 'changed_on_']
|
||||
add_columns = ['database', 'schema', 'table_name']
|
||||
edit_columns = [
|
||||
'table_name', 'sql', 'is_featured', 'filter_select_enabled',
|
||||
'database', 'schema',
|
||||
'description', 'owner',
|
||||
'main_dttm_col', 'default_endpoint', 'offset', 'cache_timeout']
|
||||
show_columns = edit_columns + ['perm']
|
||||
related_views = [TableColumnInlineView, SqlMetricInlineView]
|
||||
base_order = ('changed_on', 'desc')
|
||||
description_columns = {
|
||||
'offset': _("Timezone offset (in hours) for this datasource"),
|
||||
'table_name': _(
|
||||
"Name of the table that exists in the source database"),
|
||||
'schema': _(
|
||||
"Schema, as used only in some databases like Postgres, Redshift "
|
||||
"and DB2"),
|
||||
'description': Markup(
|
||||
"Supports <a href='https://daringfireball.net/projects/markdown/'>"
|
||||
"markdown</a>"),
|
||||
'sql': _(
|
||||
"This fields acts a Superset view, meaning that Superset will "
|
||||
"run a query against this string as a subquery."
|
||||
),
|
||||
}
|
||||
base_filters = [['id', DatasourceFilter, lambda: []]]
|
||||
label_columns = {
|
||||
'link': _("Table"),
|
||||
'changed_by_': _("Changed By"),
|
||||
'database': _("Database"),
|
||||
'changed_on_': _("Last Changed"),
|
||||
'is_featured': _("Is Featured"),
|
||||
'filter_select_enabled': _("Enable Filter Select"),
|
||||
'schema': _("Schema"),
|
||||
'default_endpoint': _("Default Endpoint"),
|
||||
'offset': _("Offset"),
|
||||
'cache_timeout': _("Cache Timeout"),
|
||||
}
|
||||
|
||||
def pre_add(self, table):
|
||||
number_of_existing_tables = db.session.query(
|
||||
sqla.func.count('*')).filter(
|
||||
models.SqlaTable.table_name == table.table_name,
|
||||
models.SqlaTable.schema == table.schema,
|
||||
models.SqlaTable.database_id == table.database.id
|
||||
).scalar()
|
||||
# table object is already added to the session
|
||||
if number_of_existing_tables > 1:
|
||||
raise Exception(get_datasource_exist_error_mgs(table.full_name))
|
||||
|
||||
# Fail before adding if the table can't be found
|
||||
try:
|
||||
table.get_sqla_table_object()
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
raise Exception(
|
||||
"Table [{}] could not be found, "
|
||||
"please double check your "
|
||||
"database connection, schema, and "
|
||||
"table name".format(table.name))
|
||||
|
||||
def post_add(self, table):
|
||||
table.fetch_metadata()
|
||||
security.merge_perm(sm, 'datasource_access', table.get_perm())
|
||||
if table.schema:
|
||||
security.merge_perm(sm, 'schema_access', table.schema_perm)
|
||||
|
||||
flash(_(
|
||||
"The table was created. As part of this two phase configuration "
|
||||
"process, you should now click the edit button by "
|
||||
"the new table to configure it."),
|
||||
"info")
|
||||
|
||||
def post_update(self, table):
|
||||
self.post_add(table)
|
||||
|
||||
appbuilder.add_view(
|
||||
TableModelView,
|
||||
"Tables",
|
||||
label=__("Tables"),
|
||||
category="Sources",
|
||||
category_label=__("Sources"),
|
||||
icon='fa-table',)
|
||||
|
||||
appbuilder.add_separator("Sources")
|
||||
|
||||
|
||||
class AccessRequestsModelView(SupersetModelView, DeleteMixin):
|
||||
datamodel = SQLAInterface(DAR)
|
||||
list_columns = [
|
||||
|
@ -810,43 +349,6 @@ appbuilder.add_view(
|
|||
icon='fa-table',)
|
||||
|
||||
|
||||
class DruidClusterModelView(SupersetModelView, DeleteMixin): # noqa
|
||||
datamodel = SQLAInterface(models.DruidCluster)
|
||||
add_columns = [
|
||||
'cluster_name',
|
||||
'coordinator_host', 'coordinator_port', 'coordinator_endpoint',
|
||||
'broker_host', 'broker_port', 'broker_endpoint', 'cache_timeout',
|
||||
]
|
||||
edit_columns = add_columns
|
||||
list_columns = ['cluster_name', 'metadata_last_refreshed']
|
||||
label_columns = {
|
||||
'cluster_name': _("Cluster"),
|
||||
'coordinator_host': _("Coordinator Host"),
|
||||
'coordinator_port': _("Coordinator Port"),
|
||||
'coordinator_endpoint': _("Coordinator Endpoint"),
|
||||
'broker_host': _("Broker Host"),
|
||||
'broker_port': _("Broker Port"),
|
||||
'broker_endpoint': _("Broker Endpoint"),
|
||||
}
|
||||
|
||||
def pre_add(self, cluster):
|
||||
security.merge_perm(sm, 'database_access', cluster.perm)
|
||||
|
||||
def pre_update(self, cluster):
|
||||
self.pre_add(cluster)
|
||||
|
||||
|
||||
if config['DRUID_IS_ACTIVE']:
|
||||
appbuilder.add_view(
|
||||
DruidClusterModelView,
|
||||
name="Druid Clusters",
|
||||
label=__("Druid Clusters"),
|
||||
icon="fa-cubes",
|
||||
category="Sources",
|
||||
category_label=__("Sources"),
|
||||
category_icon='fa-database',)
|
||||
|
||||
|
||||
class SliceModelView(SupersetModelView, DeleteMixin): # noqa
|
||||
datamodel = SQLAInterface(models.Slice)
|
||||
can_add = False
|
||||
|
@ -903,9 +405,9 @@ class SliceModelView(SupersetModelView, DeleteMixin): # noqa
|
|||
if not widget:
|
||||
return redirect(self.get_redirect())
|
||||
|
||||
sources = SourceRegistry.sources
|
||||
sources = ConnectorRegistry.sources
|
||||
for source in sources:
|
||||
ds = db.session.query(SourceRegistry.sources[source]).first()
|
||||
ds = db.session.query(ConnectorRegistry.sources[source]).first()
|
||||
if ds is not None:
|
||||
url = "/{}/list/".format(ds.baselink)
|
||||
msg = _("Click on a {} link to create a Slice".format(source))
|
||||
|
@ -1078,74 +580,6 @@ appbuilder.add_view(
|
|||
icon="fa-search")
|
||||
|
||||
|
||||
class DruidDatasourceModelView(SupersetModelView, DeleteMixin): # noqa
|
||||
datamodel = SQLAInterface(models.DruidDatasource)
|
||||
list_widget = ListWidgetWithCheckboxes
|
||||
list_columns = [
|
||||
'datasource_link', 'cluster', 'changed_by_', 'changed_on_', 'offset']
|
||||
order_columns = [
|
||||
'datasource_link', 'changed_on_', 'offset']
|
||||
related_views = [DruidColumnInlineView, DruidMetricInlineView]
|
||||
edit_columns = [
|
||||
'datasource_name', 'cluster', 'description', 'owner',
|
||||
'is_featured', 'is_hidden', 'filter_select_enabled',
|
||||
'default_endpoint', 'offset', 'cache_timeout']
|
||||
add_columns = edit_columns
|
||||
show_columns = add_columns + ['perm']
|
||||
page_size = 500
|
||||
base_order = ('datasource_name', 'asc')
|
||||
description_columns = {
|
||||
'offset': _("Timezone offset (in hours) for this datasource"),
|
||||
'description': Markup(
|
||||
"Supports <a href='"
|
||||
"https://daringfireball.net/projects/markdown/'>markdown</a>"),
|
||||
}
|
||||
base_filters = [['id', DatasourceFilter, lambda: []]]
|
||||
label_columns = {
|
||||
'datasource_link': _("Data Source"),
|
||||
'cluster': _("Cluster"),
|
||||
'description': _("Description"),
|
||||
'owner': _("Owner"),
|
||||
'is_featured': _("Is Featured"),
|
||||
'is_hidden': _("Is Hidden"),
|
||||
'filter_select_enabled': _("Enable Filter Select"),
|
||||
'default_endpoint': _("Default Endpoint"),
|
||||
'offset': _("Time Offset"),
|
||||
'cache_timeout': _("Cache Timeout"),
|
||||
}
|
||||
|
||||
def pre_add(self, datasource):
|
||||
number_of_existing_datasources = db.session.query(
|
||||
sqla.func.count('*')).filter(
|
||||
models.DruidDatasource.datasource_name ==
|
||||
datasource.datasource_name,
|
||||
models.DruidDatasource.cluster_name == datasource.cluster.id
|
||||
).scalar()
|
||||
|
||||
# table object is already added to the session
|
||||
if number_of_existing_datasources > 1:
|
||||
raise Exception(get_datasource_exist_error_mgs(
|
||||
datasource.full_name))
|
||||
|
||||
def post_add(self, datasource):
|
||||
datasource.generate_metrics()
|
||||
security.merge_perm(sm, 'datasource_access', datasource.get_perm())
|
||||
if datasource.schema:
|
||||
security.merge_perm(sm, 'schema_access', datasource.schema_perm)
|
||||
|
||||
def post_update(self, datasource):
|
||||
self.post_add(datasource)
|
||||
|
||||
if config['DRUID_IS_ACTIVE']:
|
||||
appbuilder.add_view(
|
||||
DruidDatasourceModelView,
|
||||
"Druid Datasources",
|
||||
label=__("Druid Datasources"),
|
||||
category="Sources",
|
||||
category_label=__("Sources"),
|
||||
icon="fa-cube")
|
||||
|
||||
|
||||
@app.route('/health')
|
||||
def health():
|
||||
return "OK"
|
||||
|
@ -1283,7 +717,7 @@ class Superset(BaseSupersetView):
|
|||
@has_access_api
|
||||
@expose("/datasources/")
|
||||
def datasources(self):
|
||||
datasources = SourceRegistry.get_all_datasources(db.session)
|
||||
datasources = ConnectorRegistry.get_all_datasources(db.session)
|
||||
datasources = [(str(o.id) + '__' + o.type, repr(o)) for o in datasources]
|
||||
return self.json_response(datasources)
|
||||
|
||||
|
@ -1318,7 +752,7 @@ class Superset(BaseSupersetView):
|
|||
dbs['name'], ds_name, schema=schema['name'])
|
||||
db_ds_names.add(fullname)
|
||||
|
||||
existing_datasources = SourceRegistry.get_all_datasources(db.session)
|
||||
existing_datasources = ConnectorRegistry.get_all_datasources(db.session)
|
||||
datasources = [
|
||||
d for d in existing_datasources if d.full_name in db_ds_names]
|
||||
role = sm.find_role(role_name)
|
||||
|
@ -1356,7 +790,7 @@ class Superset(BaseSupersetView):
|
|||
datasource_id = request.args.get('datasource_id')
|
||||
datasource_type = request.args.get('datasource_type')
|
||||
if datasource_id:
|
||||
ds_class = SourceRegistry.sources.get(datasource_type)
|
||||
ds_class = ConnectorRegistry.sources.get(datasource_type)
|
||||
datasource = (
|
||||
db.session.query(ds_class)
|
||||
.filter_by(id=int(datasource_id))
|
||||
|
@ -1385,7 +819,7 @@ class Superset(BaseSupersetView):
|
|||
def approve(self):
|
||||
def clean_fulfilled_requests(session):
|
||||
for r in session.query(DAR).all():
|
||||
datasource = SourceRegistry.get_datasource(
|
||||
datasource = ConnectorRegistry.get_datasource(
|
||||
r.datasource_type, r.datasource_id, session)
|
||||
user = sm.get_user_by_id(r.created_by_fk)
|
||||
if not datasource or \
|
||||
|
@ -1400,7 +834,7 @@ class Superset(BaseSupersetView):
|
|||
role_to_extend = request.args.get('role_to_extend')
|
||||
|
||||
session = db.session
|
||||
datasource = SourceRegistry.get_datasource(
|
||||
datasource = ConnectorRegistry.get_datasource(
|
||||
datasource_type, datasource_id, session)
|
||||
|
||||
if not datasource:
|
||||
|
@ -1501,9 +935,9 @@ class Superset(BaseSupersetView):
|
|||
)
|
||||
return slc.get_viz()
|
||||
else:
|
||||
form_data=self.get_form_data()
|
||||
form_data = self.get_form_data()
|
||||
viz_type = form_data.get('viz_type', 'table')
|
||||
datasource = SourceRegistry.get_datasource(
|
||||
datasource = ConnectorRegistry.get_datasource(
|
||||
datasource_type, datasource_id, db.session)
|
||||
viz_obj = viz.viz_types[viz_type](
|
||||
datasource,
|
||||
|
@ -1542,7 +976,6 @@ class Superset(BaseSupersetView):
|
|||
utils.error_msg_from_exception(e),
|
||||
stacktrace=traceback.format_exc())
|
||||
|
||||
|
||||
if not self.datasource_access(viz_obj.datasource):
|
||||
return json_error_response(DATASOURCE_ACCESS_ERR, status=404)
|
||||
|
||||
|
@ -1598,11 +1031,8 @@ class Superset(BaseSupersetView):
|
|||
data = pickle.load(f)
|
||||
# TODO: import DRUID datasources
|
||||
for table in data['datasources']:
|
||||
if table.type == 'table':
|
||||
models.SqlaTable.import_obj(table, import_time=current_tt)
|
||||
else:
|
||||
models.DruidDatasource.import_obj(
|
||||
table, import_time=current_tt)
|
||||
ds_class = ConnectorRegistry.sources.get(table.type)
|
||||
ds_class.import_obj(table, import_time=current_tt)
|
||||
db.session.commit()
|
||||
for dashboard in data['dashboards']:
|
||||
models.Dashboard.import_obj(
|
||||
|
@ -1628,7 +1058,7 @@ class Superset(BaseSupersetView):
|
|||
|
||||
error_redirect = '/slicemodelview/list/'
|
||||
datasource = (
|
||||
db.session.query(SourceRegistry.sources[datasource_type])
|
||||
db.session.query(ConnectorRegistry.sources[datasource_type])
|
||||
.filter_by(id=datasource_id)
|
||||
.one()
|
||||
)
|
||||
|
@ -1705,8 +1135,7 @@ class Superset(BaseSupersetView):
|
|||
"""
|
||||
# TODO: Cache endpoint by user, datasource and column
|
||||
error_redirect = '/slicemodelview/list/'
|
||||
datasource_class = models.SqlaTable \
|
||||
if datasource_type == "table" else models.DruidDatasource
|
||||
datasource_class = ConnectorRegistry.sources[datasource_type]
|
||||
|
||||
datasource = db.session.query(
|
||||
datasource_class).filter_by(id=datasource_id).first()
|
||||
|
@ -2194,12 +1623,13 @@ class Superset(BaseSupersetView):
|
|||
return json_error_response(__(
|
||||
"Slice %(id)s not found", id=slice_id), status=404)
|
||||
elif table_name and db_name:
|
||||
SqlaTable = ConnectorRegistry.sources['table']
|
||||
table = (
|
||||
session.query(models.SqlaTable)
|
||||
session.query(SqlaTable)
|
||||
.join(models.Database)
|
||||
.filter(
|
||||
models.Database.database_name == db_name or
|
||||
models.SqlaTable.table_name == table_name)
|
||||
SqlaTable.table_name == table_name)
|
||||
).first()
|
||||
if not table:
|
||||
return json_error_response(__(
|
||||
|
@ -2209,15 +1639,15 @@ class Superset(BaseSupersetView):
|
|||
datasource_id=table.id,
|
||||
datasource_type=table.type).all()
|
||||
|
||||
for slice in slices:
|
||||
for slc in slices:
|
||||
try:
|
||||
obj = slice.get_viz()
|
||||
obj = slc.get_viz()
|
||||
obj.get_json(force=True)
|
||||
except Exception as e:
|
||||
return json_error_response(utils.error_msg_from_exception(e))
|
||||
return json_success(json.dumps(
|
||||
[{"slice_id": session.id, "slice_name": session.slice_name}
|
||||
for session in slices]))
|
||||
[{"slice_id": slc.id, "slice_name": slc.slice_name}
|
||||
for slc in slices]))
|
||||
|
||||
@expose("/favstar/<class_name>/<obj_id>/<action>/")
|
||||
def favstar(self, class_name, obj_id, action):
|
||||
|
@ -2322,12 +1752,14 @@ class Superset(BaseSupersetView):
|
|||
cluster_name = payload['cluster']
|
||||
|
||||
user = sm.find_user(username=user_name)
|
||||
DruidDatasource = ConnectorRegistry.sources['druid']
|
||||
DruidCluster = DruidDatasource.cluster_class
|
||||
if not user:
|
||||
err_msg = __("Can't find User '%(name)s', please ask your admin "
|
||||
"to create one.", name=user_name)
|
||||
logging.error(err_msg)
|
||||
return json_error_response(err_msg)
|
||||
cluster = db.session.query(models.DruidCluster).filter_by(
|
||||
cluster = db.session.query(DruidCluster).filter_by(
|
||||
cluster_name=cluster_name).first()
|
||||
if not cluster:
|
||||
err_msg = __("Can't find DruidCluster with cluster_name = "
|
||||
|
@ -2335,7 +1767,7 @@ class Superset(BaseSupersetView):
|
|||
logging.error(err_msg)
|
||||
return json_error_response(err_msg)
|
||||
try:
|
||||
models.DruidDatasource.sync_to_db_from_config(
|
||||
DruidDatasource.sync_to_db_from_config(
|
||||
druid_config, user, cluster)
|
||||
except Exception as e:
|
||||
logging.exception(utils.error_msg_from_exception(e))
|
||||
|
@ -2349,13 +1781,14 @@ class Superset(BaseSupersetView):
|
|||
data = json.loads(request.form.get('data'))
|
||||
table_name = data.get('datasourceName')
|
||||
viz_type = data.get('chartType')
|
||||
SqlaTable = ConnectorRegistry.sources['table']
|
||||
table = (
|
||||
db.session.query(models.SqlaTable)
|
||||
db.session.query(SqlaTable)
|
||||
.filter_by(table_name=table_name)
|
||||
.first()
|
||||
)
|
||||
if not table:
|
||||
table = models.SqlaTable(table_name=table_name)
|
||||
table = SqlaTable(table_name=table_name)
|
||||
table.database_id = data.get('dbId')
|
||||
q = SupersetQuery(data.get('sql'))
|
||||
table.sql = q.stripped()
|
||||
|
@ -2642,7 +2075,7 @@ class Superset(BaseSupersetView):
|
|||
def fetch_datasource_metadata(self):
|
||||
datasource_id, datasource_type = (
|
||||
request.args.get('datasourceKey').split('__'))
|
||||
datasource_class = SourceRegistry.sources[datasource_type]
|
||||
datasource_class = ConnectorRegistry.sources[datasource_type]
|
||||
datasource = (
|
||||
db.session.query(datasource_class)
|
||||
.filter_by(id=int(datasource_id))
|
||||
|
@ -2740,7 +2173,8 @@ class Superset(BaseSupersetView):
|
|||
def refresh_datasources(self):
|
||||
"""endpoint that refreshes druid datasources metadata"""
|
||||
session = db.session()
|
||||
for cluster in session.query(models.DruidCluster).all():
|
||||
DruidCluster = ConnectorRegistry.sources['druid']
|
||||
for cluster in session.query(DruidCluster).all():
|
||||
cluster_name = cluster.cluster_name
|
||||
try:
|
||||
cluster.refresh_datasources()
|
|
@ -9,7 +9,11 @@ import mock
|
|||
import unittest
|
||||
|
||||
from superset import db, models, sm, security
|
||||
from superset.source_registry import SourceRegistry
|
||||
|
||||
from superset.models import core as models
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.connectors.druid.models import DruidDatasource
|
||||
|
||||
from .base_tests import SupersetTestCase
|
||||
|
||||
|
@ -58,7 +62,7 @@ SCHEMA_ACCESS_ROLE = 'schema_access_role'
|
|||
|
||||
|
||||
def create_access_request(session, ds_type, ds_name, role_name, user_name):
|
||||
ds_class = SourceRegistry.sources[ds_type]
|
||||
ds_class = ConnectorRegistry.sources[ds_type]
|
||||
# TODO: generalize datasource names
|
||||
if ds_type == 'table':
|
||||
ds = session.query(ds_class).filter(
|
||||
|
@ -293,7 +297,7 @@ class RequestAccessTests(SupersetTestCase):
|
|||
access_request2 = create_access_request(
|
||||
session, 'table', 'wb_health_population', TEST_ROLE_2, 'gamma2')
|
||||
ds_1_id = access_request1.datasource_id
|
||||
ds = session.query(models.SqlaTable).filter_by(
|
||||
ds = session.query(SqlaTable).filter_by(
|
||||
table_name='wb_health_population').first()
|
||||
|
||||
|
||||
|
@ -314,7 +318,7 @@ class RequestAccessTests(SupersetTestCase):
|
|||
gamma_user = sm.find_user(username='gamma')
|
||||
gamma_user.roles.remove(sm.find_role(SCHEMA_ACCESS_ROLE))
|
||||
|
||||
ds = session.query(models.SqlaTable).filter_by(
|
||||
ds = session.query(SqlaTable).filter_by(
|
||||
table_name='wb_health_population').first()
|
||||
ds.schema = None
|
||||
|
||||
|
@ -441,7 +445,7 @@ class RequestAccessTests(SupersetTestCase):
|
|||
|
||||
# Request table access, there are no roles have this table.
|
||||
|
||||
table1 = session.query(models.SqlaTable).filter_by(
|
||||
table1 = session.query(SqlaTable).filter_by(
|
||||
table_name='random_time_series').first()
|
||||
table_1_id = table1.id
|
||||
|
||||
|
@ -454,7 +458,7 @@ class RequestAccessTests(SupersetTestCase):
|
|||
|
||||
# Request access, roles exist that contains the table.
|
||||
# add table to the existing roles
|
||||
table3 = session.query(models.SqlaTable).filter_by(
|
||||
table3 = session.query(SqlaTable).filter_by(
|
||||
table_name='energy_usage').first()
|
||||
table_3_id = table3.id
|
||||
table3_perm = table3.perm
|
||||
|
@ -479,7 +483,7 @@ class RequestAccessTests(SupersetTestCase):
|
|||
'<ul><li>{}</li></ul>'.format(approve_link_3))
|
||||
|
||||
# Request druid access, there are no roles have this table.
|
||||
druid_ds_4 = session.query(models.DruidDatasource).filter_by(
|
||||
druid_ds_4 = session.query(DruidDatasource).filter_by(
|
||||
datasource_name='druid_ds_1').first()
|
||||
druid_ds_4_id = druid_ds_4.id
|
||||
|
||||
|
@ -493,7 +497,7 @@ class RequestAccessTests(SupersetTestCase):
|
|||
|
||||
# Case 5. Roles exist that contains the druid datasource.
|
||||
# add druid ds to the existing roles
|
||||
druid_ds_5 = session.query(models.DruidDatasource).filter_by(
|
||||
druid_ds_5 = session.query(DruidDatasource).filter_by(
|
||||
datasource_name='druid_ds_2').first()
|
||||
druid_ds_5_id = druid_ds_5.id
|
||||
druid_ds_5_perm = druid_ds_5.perm
|
||||
|
|
|
@ -11,8 +11,11 @@ import unittest
|
|||
|
||||
from flask_appbuilder.security.sqla import models as ab_models
|
||||
|
||||
from superset import app, cli, db, models, appbuilder, security, sm
|
||||
from superset import app, cli, db, appbuilder, security, sm
|
||||
from superset.models import core as models
|
||||
from superset.security import sync_role_definitions
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.connectors.druid.models import DruidCluster, DruidDatasource
|
||||
|
||||
os.environ['SUPERSET_CONFIG'] = 'tests.superset_test_config'
|
||||
|
||||
|
@ -85,30 +88,34 @@ class SupersetTestCase(unittest.TestCase):
|
|||
appbuilder.sm.find_role('Alpha'),
|
||||
password='general')
|
||||
sm.get_session.commit()
|
||||
|
||||
# create druid cluster and druid datasources
|
||||
session = db.session
|
||||
cluster = session.query(models.DruidCluster).filter_by(
|
||||
cluster_name="druid_test").first()
|
||||
cluster = (
|
||||
session.query(DruidCluster)
|
||||
.filter_by(cluster_name="druid_test")
|
||||
.first()
|
||||
)
|
||||
if not cluster:
|
||||
cluster = models.DruidCluster(cluster_name="druid_test")
|
||||
cluster = DruidCluster(cluster_name="druid_test")
|
||||
session.add(cluster)
|
||||
session.commit()
|
||||
|
||||
druid_datasource1 = models.DruidDatasource(
|
||||
druid_datasource1 = DruidDatasource(
|
||||
datasource_name='druid_ds_1',
|
||||
cluster_name='druid_test'
|
||||
)
|
||||
session.add(druid_datasource1)
|
||||
druid_datasource2 = models.DruidDatasource(
|
||||
druid_datasource2 = DruidDatasource(
|
||||
datasource_name='druid_ds_2',
|
||||
cluster_name='druid_test'
|
||||
)
|
||||
session.add(druid_datasource2)
|
||||
session.commit()
|
||||
|
||||
|
||||
|
||||
def get_table(self, table_id):
|
||||
return db.session.query(models.SqlaTable).filter_by(
|
||||
return db.session.query(SqlaTable).filter_by(
|
||||
id=table_id).first()
|
||||
|
||||
def get_or_create(self, cls, criteria, session):
|
||||
|
@ -149,11 +156,11 @@ class SupersetTestCase(unittest.TestCase):
|
|||
return slc
|
||||
|
||||
def get_table_by_name(self, name):
|
||||
return db.session.query(models.SqlaTable).filter_by(
|
||||
return db.session.query(SqlaTable).filter_by(
|
||||
table_name=name).first()
|
||||
|
||||
def get_druid_ds_by_name(self, name):
|
||||
return db.session.query(models.DruidDatasource).filter_by(
|
||||
return db.session.query(DruidDatasource).filter_by(
|
||||
datasource_name=name).first()
|
||||
|
||||
def get_resp(
|
||||
|
|
|
@ -12,14 +12,14 @@ import unittest
|
|||
|
||||
import pandas as pd
|
||||
|
||||
from superset import app, appbuilder, cli, db, models, dataframe
|
||||
from superset import app, appbuilder, cli, db, dataframe
|
||||
from superset.models import core as models
|
||||
from superset.models.helpers import QueryStatus
|
||||
from superset.security import sync_role_definitions
|
||||
from superset.sql_parse import SupersetQuery
|
||||
|
||||
from .base_tests import SupersetTestCase
|
||||
|
||||
QueryStatus = models.QueryStatus
|
||||
|
||||
BASE_DIR = app.config.get('BASE_DIR')
|
||||
|
||||
|
||||
|
|
|
@ -7,17 +7,19 @@ from __future__ import unicode_literals
|
|||
import csv
|
||||
import doctest
|
||||
import json
|
||||
import logging
|
||||
import io
|
||||
import random
|
||||
import unittest
|
||||
|
||||
from flask import escape
|
||||
|
||||
from superset import db, models, utils, appbuilder, sm, jinja_context, sql_lab
|
||||
from superset.views import DatabaseView
|
||||
from superset import db, utils, appbuilder, sm, jinja_context, sql_lab
|
||||
from superset.models import core as models
|
||||
from superset.views.core import DatabaseView
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
from .base_tests import SupersetTestCase
|
||||
import logging
|
||||
|
||||
|
||||
class CoreTests(SupersetTestCase):
|
||||
|
@ -31,7 +33,7 @@ class CoreTests(SupersetTestCase):
|
|||
def setUpClass(cls):
|
||||
cls.table_ids = {tbl.table_name: tbl.id for tbl in (
|
||||
db.session
|
||||
.query(models.SqlaTable)
|
||||
.query(SqlaTable)
|
||||
.all()
|
||||
)}
|
||||
|
||||
|
@ -186,7 +188,7 @@ class CoreTests(SupersetTestCase):
|
|||
slice_id = self.get_slice(slice_name, db.session).id
|
||||
db.session.commit()
|
||||
tbl_id = self.table_ids.get('energy_usage')
|
||||
table = db.session.query(models.SqlaTable).filter(models.SqlaTable.id == tbl_id)
|
||||
table = db.session.query(SqlaTable).filter(SqlaTable.id == tbl_id)
|
||||
table.filter_select_enabled = True
|
||||
url = (
|
||||
"/superset/filter/table/{}/target/?viz_type=sankey&groupby=source"
|
||||
|
@ -220,7 +222,7 @@ class CoreTests(SupersetTestCase):
|
|||
url = '/tablemodelview/list/'
|
||||
resp = self.get_resp(url)
|
||||
|
||||
table = db.session.query(models.SqlaTable).first()
|
||||
table = db.session.query(SqlaTable).first()
|
||||
assert table.name in resp
|
||||
assert '/superset/explore/table/{}'.format(table.id) in resp
|
||||
|
||||
|
@ -459,7 +461,7 @@ class CoreTests(SupersetTestCase):
|
|||
def test_public_user_dashboard_access(self):
|
||||
table = (
|
||||
db.session
|
||||
.query(models.SqlaTable)
|
||||
.query(SqlaTable)
|
||||
.filter_by(table_name='birth_names')
|
||||
.one()
|
||||
)
|
||||
|
@ -494,7 +496,7 @@ class CoreTests(SupersetTestCase):
|
|||
self.logout()
|
||||
table = (
|
||||
db.session
|
||||
.query(models.SqlaTable)
|
||||
.query(SqlaTable)
|
||||
.filter_by(table_name='birth_names')
|
||||
.one()
|
||||
)
|
||||
|
|
|
@ -11,7 +11,8 @@ import unittest
|
|||
from mock import Mock, patch
|
||||
|
||||
from superset import db, sm, security
|
||||
from superset.models import DruidCluster, DruidDatasource
|
||||
from superset.connectors.druid.models import DruidCluster, DruidDatasource
|
||||
from superset.connectors.druid.models import PyDruid
|
||||
|
||||
from .base_tests import SupersetTestCase
|
||||
|
||||
|
@ -70,7 +71,7 @@ class DruidTests(SupersetTestCase):
|
|||
def __init__(self, *args, **kwargs):
|
||||
super(DruidTests, self).__init__(*args, **kwargs)
|
||||
|
||||
@patch('superset.models.PyDruid')
|
||||
@patch('superset.connectors.druid.models.PyDruid')
|
||||
def test_client(self, PyDruid):
|
||||
self.login(username='admin')
|
||||
instance = PyDruid.return_value
|
||||
|
@ -197,8 +198,12 @@ class DruidTests(SupersetTestCase):
|
|||
}
|
||||
def check():
|
||||
resp = self.client.post('/superset/sync_druid/', data=json.dumps(cfg))
|
||||
druid_ds = db.session.query(DruidDatasource).filter_by(
|
||||
datasource_name="test_click").first()
|
||||
druid_ds = (
|
||||
db.session
|
||||
.query(DruidDatasource)
|
||||
.filter_by(datasource_name="test_click")
|
||||
.one()
|
||||
)
|
||||
col_names = set([c.column_name for c in druid_ds.columns])
|
||||
assert {"affiliate_id", "campaign", "first_seen"} == col_names
|
||||
metric_names = {m.metric_name for m in druid_ds.metrics}
|
||||
|
@ -224,7 +229,7 @@ class DruidTests(SupersetTestCase):
|
|||
}
|
||||
resp = self.client.post('/superset/sync_druid/', data=json.dumps(cfg))
|
||||
druid_ds = db.session.query(DruidDatasource).filter_by(
|
||||
datasource_name="test_click").first()
|
||||
datasource_name="test_click").one()
|
||||
# columns and metrics are not deleted if config is changed as
|
||||
# user could define his own dimensions / metrics and want to keep them
|
||||
assert set([c.column_name for c in druid_ds.columns]) == set(
|
||||
|
|
|
@ -10,7 +10,11 @@ import json
|
|||
import pickle
|
||||
import unittest
|
||||
|
||||
from superset import db, models
|
||||
from superset import db
|
||||
from superset.models import core as models
|
||||
from superset.connectors.druid.models import (
|
||||
DruidDatasource, DruidColumn, DruidMetric)
|
||||
from superset.connectors.sqla.models import SqlaTable, TableColumn, SqlMetric
|
||||
|
||||
from .base_tests import SupersetTestCase
|
||||
|
||||
|
@ -31,10 +35,10 @@ class ImportExportTests(SupersetTestCase):
|
|||
for dash in session.query(models.Dashboard):
|
||||
if 'remote_id' in dash.params_dict:
|
||||
session.delete(dash)
|
||||
for table in session.query(models.SqlaTable):
|
||||
for table in session.query(SqlaTable):
|
||||
if 'remote_id' in table.params_dict:
|
||||
session.delete(table)
|
||||
for datasource in session.query(models.DruidDatasource):
|
||||
for datasource in session.query(DruidDatasource):
|
||||
if 'remote_id' in datasource.params_dict:
|
||||
session.delete(datasource)
|
||||
session.commit()
|
||||
|
@ -90,7 +94,7 @@ class ImportExportTests(SupersetTestCase):
|
|||
def create_table(
|
||||
self, name, schema='', id=0, cols_names=[], metric_names=[]):
|
||||
params = {'remote_id': id, 'database_name': 'main'}
|
||||
table = models.SqlaTable(
|
||||
table = SqlaTable(
|
||||
id=id,
|
||||
schema=schema,
|
||||
table_name=name,
|
||||
|
@ -98,15 +102,15 @@ class ImportExportTests(SupersetTestCase):
|
|||
)
|
||||
for col_name in cols_names:
|
||||
table.columns.append(
|
||||
models.TableColumn(column_name=col_name))
|
||||
TableColumn(column_name=col_name))
|
||||
for metric_name in metric_names:
|
||||
table.metrics.append(models.SqlMetric(metric_name=metric_name))
|
||||
table.metrics.append(SqlMetric(metric_name=metric_name))
|
||||
return table
|
||||
|
||||
def create_druid_datasource(
|
||||
self, name, id=0, cols_names=[], metric_names=[]):
|
||||
params = {'remote_id': id, 'database_name': 'druid_test'}
|
||||
datasource = models.DruidDatasource(
|
||||
datasource = DruidDatasource(
|
||||
id=id,
|
||||
datasource_name=name,
|
||||
cluster_name='druid_test',
|
||||
|
@ -114,9 +118,9 @@ class ImportExportTests(SupersetTestCase):
|
|||
)
|
||||
for col_name in cols_names:
|
||||
datasource.columns.append(
|
||||
models.DruidColumn(column_name=col_name))
|
||||
DruidColumn(column_name=col_name))
|
||||
for metric_name in metric_names:
|
||||
datasource.metrics.append(models.DruidMetric(
|
||||
datasource.metrics.append(DruidMetric(
|
||||
metric_name=metric_name))
|
||||
return datasource
|
||||
|
||||
|
@ -136,11 +140,11 @@ class ImportExportTests(SupersetTestCase):
|
|||
slug=dash_slug).first()
|
||||
|
||||
def get_datasource(self, datasource_id):
|
||||
return db.session.query(models.DruidDatasource).filter_by(
|
||||
return db.session.query(DruidDatasource).filter_by(
|
||||
id=datasource_id).first()
|
||||
|
||||
def get_table_by_name(self, name):
|
||||
return db.session.query(models.SqlaTable).filter_by(
|
||||
return db.session.query(SqlaTable).filter_by(
|
||||
table_name=name).first()
|
||||
|
||||
def assert_dash_equals(self, expected_dash, actual_dash,
|
||||
|
@ -392,7 +396,7 @@ class ImportExportTests(SupersetTestCase):
|
|||
|
||||
def test_import_table_no_metadata(self):
|
||||
table = self.create_table('pure_table', id=10001)
|
||||
imported_id = models.SqlaTable.import_obj(table, import_time=1989)
|
||||
imported_id = SqlaTable.import_obj(table, import_time=1989)
|
||||
imported = self.get_table(imported_id)
|
||||
self.assert_table_equals(table, imported)
|
||||
|
||||
|
@ -400,7 +404,7 @@ class ImportExportTests(SupersetTestCase):
|
|||
table = self.create_table(
|
||||
'table_1_col_1_met', id=10002,
|
||||
cols_names=["col1"], metric_names=["metric1"])
|
||||
imported_id = models.SqlaTable.import_obj(table, import_time=1990)
|
||||
imported_id = SqlaTable.import_obj(table, import_time=1990)
|
||||
imported = self.get_table(imported_id)
|
||||
self.assert_table_equals(table, imported)
|
||||
self.assertEquals(
|
||||
|
@ -411,7 +415,7 @@ class ImportExportTests(SupersetTestCase):
|
|||
table = self.create_table(
|
||||
'table_2_col_2_met', id=10003, cols_names=['c1', 'c2'],
|
||||
metric_names=['m1', 'm2'])
|
||||
imported_id = models.SqlaTable.import_obj(table, import_time=1991)
|
||||
imported_id = SqlaTable.import_obj(table, import_time=1991)
|
||||
|
||||
imported = self.get_table(imported_id)
|
||||
self.assert_table_equals(table, imported)
|
||||
|
@ -420,12 +424,12 @@ class ImportExportTests(SupersetTestCase):
|
|||
table = self.create_table(
|
||||
'table_override', id=10003, cols_names=['col1'],
|
||||
metric_names=['m1'])
|
||||
imported_id = models.SqlaTable.import_obj(table, import_time=1991)
|
||||
imported_id = SqlaTable.import_obj(table, import_time=1991)
|
||||
|
||||
table_over = self.create_table(
|
||||
'table_override', id=10003, cols_names=['new_col1', 'col2', 'col3'],
|
||||
metric_names=['new_metric1'])
|
||||
imported_over_id = models.SqlaTable.import_obj(
|
||||
imported_over_id = SqlaTable.import_obj(
|
||||
table_over, import_time=1992)
|
||||
|
||||
imported_over = self.get_table(imported_over_id)
|
||||
|
@ -439,12 +443,12 @@ class ImportExportTests(SupersetTestCase):
|
|||
table = self.create_table(
|
||||
'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
|
||||
metric_names=['new_metric1'])
|
||||
imported_id = models.SqlaTable.import_obj(table, import_time=1993)
|
||||
imported_id = SqlaTable.import_obj(table, import_time=1993)
|
||||
|
||||
copy_table = self.create_table(
|
||||
'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
|
||||
metric_names=['new_metric1'])
|
||||
imported_id_copy = models.SqlaTable.import_obj(
|
||||
imported_id_copy = SqlaTable.import_obj(
|
||||
copy_table, import_time=1994)
|
||||
|
||||
self.assertEquals(imported_id, imported_id_copy)
|
||||
|
@ -452,7 +456,7 @@ class ImportExportTests(SupersetTestCase):
|
|||
|
||||
def test_import_druid_no_metadata(self):
|
||||
datasource = self.create_druid_datasource('pure_druid', id=10001)
|
||||
imported_id = models.DruidDatasource.import_obj(
|
||||
imported_id = DruidDatasource.import_obj(
|
||||
datasource, import_time=1989)
|
||||
imported = self.get_datasource(imported_id)
|
||||
self.assert_datasource_equals(datasource, imported)
|
||||
|
@ -461,7 +465,7 @@ class ImportExportTests(SupersetTestCase):
|
|||
datasource = self.create_druid_datasource(
|
||||
'druid_1_col_1_met', id=10002,
|
||||
cols_names=["col1"], metric_names=["metric1"])
|
||||
imported_id = models.DruidDatasource.import_obj(
|
||||
imported_id = DruidDatasource.import_obj(
|
||||
datasource, import_time=1990)
|
||||
imported = self.get_datasource(imported_id)
|
||||
self.assert_datasource_equals(datasource, imported)
|
||||
|
@ -474,7 +478,7 @@ class ImportExportTests(SupersetTestCase):
|
|||
datasource = self.create_druid_datasource(
|
||||
'druid_2_col_2_met', id=10003, cols_names=['c1', 'c2'],
|
||||
metric_names=['m1', 'm2'])
|
||||
imported_id = models.DruidDatasource.import_obj(
|
||||
imported_id = DruidDatasource.import_obj(
|
||||
datasource, import_time=1991)
|
||||
imported = self.get_datasource(imported_id)
|
||||
self.assert_datasource_equals(datasource, imported)
|
||||
|
@ -483,14 +487,14 @@ class ImportExportTests(SupersetTestCase):
|
|||
datasource = self.create_druid_datasource(
|
||||
'druid_override', id=10003, cols_names=['col1'],
|
||||
metric_names=['m1'])
|
||||
imported_id = models.DruidDatasource.import_obj(
|
||||
imported_id = DruidDatasource.import_obj(
|
||||
datasource, import_time=1991)
|
||||
|
||||
table_over = self.create_druid_datasource(
|
||||
'druid_override', id=10003,
|
||||
cols_names=['new_col1', 'col2', 'col3'],
|
||||
metric_names=['new_metric1'])
|
||||
imported_over_id = models.DruidDatasource.import_obj(
|
||||
imported_over_id = DruidDatasource.import_obj(
|
||||
table_over, import_time=1992)
|
||||
|
||||
imported_over = self.get_datasource(imported_over_id)
|
||||
|
@ -504,13 +508,13 @@ class ImportExportTests(SupersetTestCase):
|
|||
datasource = self.create_druid_datasource(
|
||||
'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
|
||||
metric_names=['new_metric1'])
|
||||
imported_id = models.DruidDatasource.import_obj(
|
||||
imported_id = DruidDatasource.import_obj(
|
||||
datasource, import_time=1993)
|
||||
|
||||
copy_datasource = self.create_druid_datasource(
|
||||
'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
|
||||
metric_names=['new_metric1'])
|
||||
imported_id_copy = models.DruidDatasource.import_obj(
|
||||
imported_id_copy = DruidDatasource.import_obj(
|
||||
copy_datasource, import_time=1994)
|
||||
|
||||
self.assertEquals(imported_id, imported_id_copy)
|
||||
|
|
|
@ -2,7 +2,7 @@ import unittest
|
|||
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
from superset.models import Database
|
||||
from superset.models.core import Database
|
||||
|
||||
|
||||
class DatabaseModelTestCase(unittest.TestCase):
|
||||
|
|
|
@ -9,7 +9,9 @@ import json
|
|||
import unittest
|
||||
|
||||
from flask_appbuilder.security.sqla import models as ab_models
|
||||
from superset import db, models, utils, appbuilder, security, sm
|
||||
from superset import db, utils, appbuilder, sm
|
||||
from superset.models import core as models
|
||||
|
||||
from .base_tests import SupersetTestCase
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue