From 686023c8ddd58d4d562128c794d48524bc635217 Mon Sep 17 00:00:00 2001
From: Beto Dealmeida
Date: Fri, 5 Jan 2018 13:52:58 -0800
Subject: [PATCH] Druid support via SQLAlchemy (#4163)

* Use druiddb

* Remove auto formatting

* Show prequeries

* Fix subtle bug with lists

* Move arguments to query object

* Fix druid run_query
---
 superset/connectors/druid/models.py |  5 +-
 superset/connectors/sqla/models.py  | 99 +++++++++++++++++++++--------
 superset/db_engine_specs.py         |  2 +
 superset/views/core.py              |  6 ++
 superset/viz.py                     |  2 +
 5 files changed, 87 insertions(+), 27 deletions(-)

diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index 45d5d25913..4689ef2adb 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -1022,7 +1022,10 @@ class DruidDatasource(Model, BaseDatasource):
             orderby=None,
             extras=None,  # noqa
             columns=None, phase=2, client=None, form_data=None,
-            order_desc=True):
+            order_desc=True,
+            prequeries=None,
+            is_prequery=False,
+        ):
         """Runs a query against Druid and returns a dataframe.
         """
         # TODO refactor into using a TBD Query object
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index e13a8dfd4a..889aea3ba9 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -370,6 +370,8 @@ class SqlaTable(Model, BaseDatasource):
         )
         logging.info(sql)
         sql = sqlparse.format(sql, reindent=True)
+        if query_obj['is_prequery']:
+            query_obj['prequeries'].append(sql)
         return sql
 
     def get_sqla_table(self):
@@ -405,7 +407,10 @@ class SqlaTable(Model, BaseDatasource):
             extras=None,
             columns=None,
             form_data=None,
-            order_desc=True):
+            order_desc=True,
+            prequeries=None,
+            is_prequery=False,
+        ):
         """Querying any sqla table from this common interface"""
         template_kwargs = {
             'from_dttm': from_dttm,
@@ -564,37 +569,73 @@ class SqlaTable(Model, BaseDatasource):
 
         if is_timeseries and \
                 timeseries_limit and groupby and not time_groupby_inline:
-            # some sql dialects require for order by expressions
-            # to also be in the select clause -- others, e.g. vertica,
-            # require a unique inner alias
-            inner_main_metric_expr = main_metric_expr.label('mme_inner__')
-            inner_select_exprs += [inner_main_metric_expr]
-            subq = select(inner_select_exprs)
-            subq = subq.select_from(tbl)
-            inner_time_filter = dttm_col.get_time_filter(
-                inner_from_dttm or from_dttm,
-                inner_to_dttm or to_dttm,
-            )
-            subq = subq.where(and_(*(where_clause_and + [inner_time_filter])))
-            subq = subq.group_by(*inner_groupby_exprs)
+            if self.database.db_engine_spec.inner_joins:
+                # some sql dialects require for order by expressions
+                # to also be in the select clause -- others, e.g. vertica,
+                # require a unique inner alias
+                inner_main_metric_expr = main_metric_expr.label('mme_inner__')
+                inner_select_exprs += [inner_main_metric_expr]
+                subq = select(inner_select_exprs)
+                subq = subq.select_from(tbl)
+                inner_time_filter = dttm_col.get_time_filter(
+                    inner_from_dttm or from_dttm,
+                    inner_to_dttm or to_dttm,
+                )
+                subq = subq.where(and_(*(where_clause_and + [inner_time_filter])))
+                subq = subq.group_by(*inner_groupby_exprs)
 
-            ob = inner_main_metric_expr
-            if timeseries_limit_metric:
-                timeseries_limit_metric = metrics_dict.get(timeseries_limit_metric)
-                ob = timeseries_limit_metric.sqla_col
-            direction = desc if order_desc else asc
-            subq = subq.order_by(direction(ob))
-            subq = subq.limit(timeseries_limit)
+                ob = inner_main_metric_expr
+                if timeseries_limit_metric:
+                    timeseries_limit_metric = metrics_dict.get(timeseries_limit_metric)
+                    ob = timeseries_limit_metric.sqla_col
+                direction = desc if order_desc else asc
+                subq = subq.order_by(direction(ob))
+                subq = subq.limit(timeseries_limit)
 
-            on_clause = []
-            for i, gb in enumerate(groupby):
-                on_clause.append(
-                    groupby_exprs[i] == column(gb + '__'))
+                on_clause = []
+                for i, gb in enumerate(groupby):
+                    on_clause.append(
+                        groupby_exprs[i] == column(gb + '__'))
 
-            tbl = tbl.join(subq.alias(), and_(*on_clause))
+                tbl = tbl.join(subq.alias(), and_(*on_clause))
+            else:
+                # run subquery to get top groups
+                subquery_obj = {
+                    'prequeries': prequeries,
+                    'is_prequery': True,
+                    'is_timeseries': False,
+                    'row_limit': timeseries_limit,
+                    'groupby': groupby,
+                    'metrics': metrics,
+                    'granularity': granularity,
+                    'from_dttm': inner_from_dttm or from_dttm,
+                    'to_dttm': inner_to_dttm or to_dttm,
+                    'filter': filter,
+                    'orderby': orderby,
+                    'extras': extras,
+                    'columns': columns,
+                    'form_data': form_data,
+                    'order_desc': True,
+                }
+                result = self.query(subquery_obj)
+                dimensions = [c for c in result.df.columns if c not in metrics]
+                top_groups = self._get_top_groups(result.df, dimensions)
+                qry = qry.where(top_groups)
 
         return qry.select_from(tbl)
 
+    def _get_top_groups(self, df, dimensions):
+        cols = {col.column_name: col for col in self.columns}
+        groups = []
+        for unused, row in df.iterrows():
+            group = []
+            for dimension in dimensions:
+                col_obj = cols.get(dimension)
+                group.append(col_obj.sqla_col == row[dimension])
+            groups.append(and_(*group))
+
+        return or_(*groups)
+
     def query(self, query_obj):
         qry_start_dttm = datetime.now()
         sql = self.get_query_str(query_obj)
@@ -609,6 +650,12 @@ class SqlaTable(Model, BaseDatasource):
             error_message = (
                 self.database.db_engine_spec.extract_error_message(e))
 
+        # if this is a main query with prequeries, combine them together
+        if not query_obj['is_prequery']:
+            query_obj['prequeries'].append(sql)
+            sql = ';\n\n'.join(query_obj['prequeries'])
+            sql += ';'
+
         return QueryResult(
             status=status,
             df=df,
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index e02d477f43..19f541d80e 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -62,6 +62,7 @@ class BaseEngineSpec(object):
     time_groupby_inline = False
     limit_method = LimitMethod.FETCH_MANY
     time_secondary_columns = False
+    inner_joins = True
 
     @classmethod
     def fetch_data(cls, cursor, limit):
@@ -1229,6 +1230,7 @@ class DruidEngineSpec(BaseEngineSpec):
     """Engine spec for Druid.io"""
     engine = 'druid'
     limit_method = LimitMethod.FETCH_MANY
+    inner_joins = False
 
 
 engines = {
diff --git a/superset/views/core.py b/superset/views/core.py
index 2f0c0e5c24..1f5cfd022f 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -991,6 +991,12 @@ class Superset(BaseSupersetView):
             query = viz_obj.datasource.get_query_str(query_obj)
         except Exception as e:
             return json_error_response(e)
+
+        if query_obj['prequeries']:
+            query_obj['prequeries'].append(query)
+            query = ';\n\n'.join(query_obj['prequeries'])
+            query += ';'
+
         return Response(
             json.dumps({
                 'query': query,
diff --git a/superset/viz.py b/superset/viz.py
index e649456de1..a97b04e985 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -205,6 +205,8 @@ class BaseViz(object):
             'timeseries_limit_metric': timeseries_limit_metric,
             'form_data': form_data,
             'order_desc': order_desc,
+            'prequeries': [],
+            'is_prequery': False,
        }
         return d
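
How the prequery mechanism fits together (a sketch for reviewers, not part of the patch): when an engine spec sets inner_joins = False, get_sqla_query can't apply the timeseries limit with an inner join, so it issues a separate "prequery" for the top groups, records that statement in query_obj['prequeries'], and filters the main query to those groups; the main query path then joins the collected statements with ';\n\n' so the full SQL can be shown to the user. The standalone Python sketch below illustrates only that bookkeeping; the run_* helpers and the example SQL strings are hypothetical, while the query_obj keys and the joining behaviour mirror the patch.

# Sketch of the prequery bookkeeping added by this patch. Only the
# 'prequeries' / 'is_prequery' keys and the ';\n\n' joining come from the
# patch; the helper functions and SQL strings are illustrative stand-ins.

def run_prequery(query_obj):
    # Stand-in for the top-groups subquery issued when inner_joins is False.
    sql = "SELECT gender FROM birth_names GROUP BY gender ORDER BY SUM(num) DESC LIMIT 10"
    if query_obj['is_prequery']:
        query_obj['prequeries'].append(sql)
    return sql


def run_main_query(query_obj):
    # Stand-in for the main query; folds any recorded prequeries into the
    # SQL string that gets displayed to the user.
    sql = "SELECT gender, SUM(num) FROM birth_names WHERE gender IN (...) GROUP BY gender"
    if not query_obj['is_prequery']:
        query_obj['prequeries'].append(sql)
        sql = ';\n\n'.join(query_obj['prequeries'])
        sql += ';'
    return sql


if __name__ == '__main__':
    query_obj = {'prequeries': [], 'is_prequery': False}
    # The subquery shares the same 'prequeries' list but is flagged as a prequery.
    subquery_obj = dict(query_obj, is_prequery=True)
    run_prequery(subquery_obj)
    # Prints both statements, separated by a blank line and each ending in ';'.
    print(run_main_query(query_obj))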