From 131372740e79fa0d6ebd7484026cb9ac5918f631 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Thu, 23 Jun 2016 22:43:52 -0700 Subject: [PATCH] Adding orderby to Table 'not grouped by' and fixing metrics ordering (#669) --- caravel/assets/visualizations/table.js | 1 + caravel/forms.py | 10 ++++ caravel/models.py | 69 +++++++++++++++++--------- caravel/viz.py | 10 ++-- 4 files changed, 59 insertions(+), 31 deletions(-) diff --git a/caravel/assets/visualizations/table.js b/caravel/assets/visualizations/table.js index 5683819eef..813b403f5d 100644 --- a/caravel/assets/visualizations/table.js +++ b/caravel/assets/visualizations/table.js @@ -119,6 +119,7 @@ function tableVis(slice) { var height = slice.container.height(); var datatable = slice.container.find('.dataTable').DataTable({ paging: false, + aaSorting: [], searching: form_data.include_search, bInfo: false, scrollY: height + "px", diff --git a/caravel/forms.py b/caravel/forms.py index 86f7fef148..89b14c72c2 100644 --- a/caravel/forms.py +++ b/caravel/forms.py @@ -6,6 +6,7 @@ from __future__ import unicode_literals from collections import OrderedDict from copy import copy +import json import math from flask_babelpkg import lazy_gettext as _ @@ -129,6 +130,10 @@ class FormFactory(object): gb_cols = datasource.groupby_column_names default_groupby = gb_cols[0] if gb_cols else None group_by_choices = self.choicify(gb_cols) + order_by_choices = [] + for s in sorted(datasource.num_cols): + order_by_choices.append((json.dumps([s, True]), s + ' [asc]')) + order_by_choices.append((json.dumps([s, False]), s + ' [desc]')) # Pool of all the fields that can be used in Caravel field_data = { 'viz_type': (SelectField, { @@ -143,6 +148,11 @@ class FormFactory(object): "default": [default_metric], "description": _("One or many metrics to display") }), + 'order_by_cols': (SelectMultipleSortableField, { + "label": _("Ordering"), + "choices": order_by_choices, + "description": _("One or many metrics to display") + }), 'metric': (SelectField, { "label": _("Metric"), "choices": datasource.metrics_combo, diff --git a/caravel/models.py b/caravel/models.py index db04e39cd2..36d58255d1 100644 --- a/caravel/models.py +++ b/caravel/models.py @@ -32,7 +32,7 @@ from pydruid.utils.having import Having, Aggregation from six import string_types from sqlalchemy import ( Column, Integer, String, ForeignKey, Text, Boolean, DateTime, Date, - Table, create_engine, MetaData, desc, select, and_, func) + Table, create_engine, MetaData, desc, asc, select, and_, func) from sqlalchemy.engine import reflection from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import relationship @@ -533,6 +533,7 @@ class SqlaTable(Model, Queryable, AuditMixinNullable): 'database_id', 'schema', 'table_name', name='_customer_location_uc'),) + def __repr__(self): return self.table_name @@ -561,6 +562,10 @@ class SqlaTable(Model, Queryable, AuditMixinNullable): l.append(self.main_dttm_col) return l + @property + def num_cols(self): + return [c.column_name for c in self.columns if c.isnum] + @property def any_dttm_col(self): cols = self.dttm_cols @@ -610,6 +615,7 @@ class SqlaTable(Model, Queryable, AuditMixinNullable): is_timeseries=True, timeseries_limit=15, row_limit=None, inner_from_dttm=None, inner_to_dttm=None, + orderby=None, extras=None, columns=None): """Querying any sqla table from this common interface""" @@ -618,6 +624,7 @@ class SqlaTable(Model, Queryable, AuditMixinNullable): granularity = self.main_dttm_col cols = {col.column_name: col for col in self.columns} + metrics_dict = {m.metric_name: m for m in self.metrics} qry_start_dttm = datetime.now() if not granularity and is_timeseries: @@ -625,14 +632,10 @@ class SqlaTable(Model, Queryable, AuditMixinNullable): "Datetime column not provided as part table configuration " "and is required by this type of chart")) - metrics_exprs = [ - m.sqla_col - for m in self.metrics if m.metric_name in metrics] + metrics_exprs = [metrics_dict.get(m).sqla_col for m in metrics] if metrics: - main_metric_expr = [ - m.sqla_col for m in self.metrics - if m.metric_name == metrics[0]][0] + main_metric_expr = metrics_exprs[0] else: main_metric_expr = literal_column("COUNT(*)").label("ccount") @@ -720,6 +723,11 @@ class SqlaTable(Model, Queryable, AuditMixinNullable): qry = qry.having(and_(*having_clause_and)) if groupby: qry = qry.order_by(desc(main_metric_expr)) + elif orderby: + for col, ascending in orderby: + direction = asc if ascending else desc + qry = qry.order_by(direction(col)) + qry = qry.limit(row_limit) if timeseries_limit and groupby: @@ -768,9 +776,12 @@ class SqlaTable(Model, Queryable, AuditMixinNullable): any_date_col = None for col in table.columns: try: - datatype = str(col.type) + datatype = "{}".format(col.type).upper() except Exception as e: datatype = "UNKNOWN" + logging.error( + "Unrecognized data type in {}.{}".format(table, col.name)) + logging.exception(e) dbcol = ( db.session .query(TC) @@ -780,22 +791,16 @@ class SqlaTable(Model, Queryable, AuditMixinNullable): ) db.session.flush() if not dbcol: - dbcol = TableColumn(column_name=col.name) - num_types = ('DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'LONG') - date_types = ('DATE', 'TIME') - str_types = ('VARCHAR', 'STRING') - datatype = str(datatype).upper() - if any([t in datatype for t in str_types]): - dbcol.groupby = True - dbcol.filterable = True - elif any([t in datatype for t in num_types]): - dbcol.sum = True - elif any([t in datatype for t in date_types]): - dbcol.is_dttm = True + dbcol = TableColumn(column_name=col.name, type=datatype) + dbcol.groupby = dbcol.is_string + dbcol.filterable = dbcol.is_string + dbcol.sum = dbcol.isnum + dbcol.is_dttm = dbcol.is_time + db.session.merge(self) self.columns.append(dbcol) - if not any_date_col and 'date' in datatype.lower(): + if not any_date_col and dbcol.is_time: any_date_col = col.name quoted = "{}".format( @@ -905,13 +910,24 @@ class TableColumn(Model, AuditMixinNullable): expression = Column(Text, default='') description = Column(Text, default='') + num_types = ('DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'LONG') + date_types = ('DATE', 'TIME') + str_types = ('VARCHAR', 'STRING', 'CHAR') + def __repr__(self): return self.column_name @property def isnum(self): - types = ('LONG', 'DOUBLE', 'FLOAT', 'BIGINT', 'INT') - return any([t in self.type.upper() for t in types]) + return any([t in self.type.upper() for t in self.num_types]) + + @property + def is_time(self): + return any([t in self.type.upper() for t in self.date_types]) + + @property + def is_string(self): + return any([t in self.type.upper() for t in self.str_types]) @property def sqla_col(self): @@ -999,6 +1015,10 @@ class DruidDatasource(Model, AuditMixinNullable, Queryable): [(m.metric_name, m.verbose_name) for m in self.metrics], key=lambda x: x[1]) + @property + def num_cols(self): + return [c.column_name for c in self.columns if c.isnum] + @property def name(self): return self.datasource_name @@ -1119,6 +1139,7 @@ class DruidDatasource(Model, AuditMixinNullable, Queryable): timeseries_limit=None, row_limit=None, inner_from_dttm=None, inner_to_dttm=None, + orderby=None, extras=None, # noqa select=None,): # noqa """Runs a query against Druid and returns a dataframe. @@ -1458,7 +1479,7 @@ class DruidColumn(Model, AuditMixinNullable): @property def isnum(self): - return self.type in ('LONG', 'DOUBLE', 'FLOAT') + return self.type in ('LONG', 'DOUBLE', 'FLOAT', 'INT') def generate_metrics(self): """Generate metrics based on the column metadata""" diff --git a/caravel/viz.py b/caravel/viz.py index a5e327ec94..cfef8b1452 100644 --- a/caravel/viz.py +++ b/caravel/viz.py @@ -350,16 +350,11 @@ class TableViz(BaseViz): fieldsets = ({ 'label': _("GROUP BY"), 'description': _('Use this section if you want a query that aggregates'), - 'fields': ( - 'groupby', - 'metrics', - ) + 'fields': ('groupby', 'metrics') }, { 'label': _("NOT GROUPED BY"), 'description': _('Use this section if you want to query atomic rows'), - 'fields': ( - 'all_columns', - ) + 'fields': ('all_columns', 'order_by_cols'), }, { 'label': _("Options"), 'fields': ( @@ -385,6 +380,7 @@ class TableViz(BaseViz): if fd.get('all_columns'): d['columns'] = fd.get('all_columns') d['groupby'] = [] + d['orderby'] = [json.loads(t) for t in fd.get('order_by_cols', [])] return d def get_df(self, query_obj=None):