mirror of https://github.com/apache/superset.git
[BugFix] Allowing limit ordering by post-aggregation metrics (#4646)
* Allowing limit ordering by post-aggregation metrics * don't overwrite og dictionaries * update tests * python3 compat * code review comments, add tests, implement it in groupby as well * python 3 compat for unittest * more self * Throw exception when get aggregations is called with postaggs * Treat adhoc metrics as another aggregation
This commit is contained in:
parent
68bfcefb27
commit
8be0bde683
|
@ -35,7 +35,7 @@ from sqlalchemy.orm import backref, relationship
|
|||
|
||||
from superset import conf, db, import_util, security_manager, utils
|
||||
from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
|
||||
from superset.exceptions import MetricPermException
|
||||
from superset.exceptions import MetricPermException, SupersetException
|
||||
from superset.models.helpers import (
|
||||
AuditMixinNullable, ImportMixin, QueryResult, set_perm,
|
||||
)
|
||||
|
@ -44,6 +44,7 @@ from superset.utils import (
|
|||
)
|
||||
|
||||
DRUID_TZ = conf.get('DRUID_TZ')
|
||||
POST_AGG_TYPE = 'postagg'
|
||||
|
||||
|
||||
# Function wrapper because bound methods cannot
|
||||
|
@ -843,7 +844,7 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
"""Return a list of metrics that are post aggregations"""
|
||||
postagg_metrics = [
|
||||
metrics_dict[name] for name in postagg_names
|
||||
if metrics_dict[name].metric_type == 'postagg'
|
||||
if metrics_dict[name].metric_type == POST_AGG_TYPE
|
||||
]
|
||||
# Remove post aggregations that were found
|
||||
for postagg in postagg_metrics:
|
||||
|
@ -893,8 +894,8 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
missing_postagg, post_aggs, agg_names, visited_postaggs, metrics_dict)
|
||||
post_aggs[postagg.metric_name] = DruidDatasource.get_post_agg(postagg.json_obj)
|
||||
|
||||
@staticmethod
|
||||
def metrics_and_post_aggs(metrics, metrics_dict):
|
||||
@classmethod
|
||||
def metrics_and_post_aggs(cls, metrics, metrics_dict):
|
||||
# Separate metrics into those that are aggregations
|
||||
# and those that are post aggregations
|
||||
saved_agg_names = set()
|
||||
|
@ -903,7 +904,7 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
for metric in metrics:
|
||||
if utils.is_adhoc_metric(metric):
|
||||
adhoc_agg_configs.append(metric)
|
||||
elif metrics_dict[metric].metric_type != 'postagg':
|
||||
elif metrics_dict[metric].metric_type != POST_AGG_TYPE:
|
||||
saved_agg_names.add(metric)
|
||||
else:
|
||||
postagg_names.append(metric)
|
||||
|
@ -914,9 +915,10 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
for postagg_name in postagg_names:
|
||||
postagg = metrics_dict[postagg_name]
|
||||
visited_postaggs.add(postagg_name)
|
||||
DruidDatasource.resolve_postagg(
|
||||
cls.resolve_postagg(
|
||||
postagg, post_aggs, saved_agg_names, visited_postaggs, metrics_dict)
|
||||
return list(saved_agg_names), adhoc_agg_configs, post_aggs
|
||||
aggs = cls.get_aggregations(metrics_dict, saved_agg_names, adhoc_agg_configs)
|
||||
return aggs, post_aggs
|
||||
|
||||
def values_for_column(self,
|
||||
column_name,
|
||||
|
@ -982,16 +984,35 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
else:
|
||||
return column_type + aggregate.capitalize()
|
||||
|
||||
def get_aggregations(self, saved_metrics, adhoc_metrics=[]):
|
||||
@staticmethod
|
||||
def get_aggregations(metrics_dict, saved_metrics, adhoc_metrics=[]):
|
||||
"""
|
||||
Returns a dictionary of aggregation metric names to aggregation json objects
|
||||
|
||||
:param metrics_dict: dictionary of all the metrics
|
||||
:param saved_metrics: list of saved metric names
|
||||
:param adhoc_metrics: list of adhoc metric names
|
||||
:raise SupersetException: if one or more metric names are not aggregations
|
||||
"""
|
||||
aggregations = OrderedDict()
|
||||
for m in self.metrics:
|
||||
if m.metric_name in saved_metrics:
|
||||
aggregations[m.metric_name] = m.json_obj
|
||||
invalid_metric_names = []
|
||||
for metric_name in saved_metrics:
|
||||
if metric_name in metrics_dict:
|
||||
metric = metrics_dict[metric_name]
|
||||
if metric.metric_type == POST_AGG_TYPE:
|
||||
invalid_metric_names.append(metric_name)
|
||||
else:
|
||||
aggregations[metric_name] = metric.json_obj
|
||||
else:
|
||||
invalid_metric_names.append(metric_name)
|
||||
if len(invalid_metric_names) > 0:
|
||||
raise SupersetException(
|
||||
_('Metric(s) {} must be aggregations.').format(invalid_metric_names))
|
||||
for adhoc_metric in adhoc_metrics:
|
||||
aggregations[adhoc_metric['label']] = {
|
||||
'fieldName': adhoc_metric['column']['column_name'],
|
||||
'fieldNames': [adhoc_metric['column']['column_name']],
|
||||
'type': self.druid_type_from_adhoc_metric(adhoc_metric),
|
||||
'type': DruidDatasource.druid_type_from_adhoc_metric(adhoc_metric),
|
||||
'name': adhoc_metric['label'],
|
||||
}
|
||||
return aggregations
|
||||
|
@ -1087,11 +1108,10 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
metrics_dict = {m.metric_name: m for m in self.metrics}
|
||||
columns_dict = {c.column_name: c for c in self.columns}
|
||||
|
||||
saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
|
||||
aggregations, post_aggs = DruidDatasource.metrics_and_post_aggs(
|
||||
metrics,
|
||||
metrics_dict)
|
||||
|
||||
aggregations = self.get_aggregations(saved_metrics, adhoc_metrics)
|
||||
self.check_restricted_metrics(aggregations)
|
||||
|
||||
# the dimensions list with dimensionSpecs expanded
|
||||
|
@ -1143,7 +1163,15 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
pre_qry = deepcopy(qry)
|
||||
if timeseries_limit_metric:
|
||||
order_by = timeseries_limit_metric
|
||||
pre_qry['aggregations'] = self.get_aggregations([timeseries_limit_metric])
|
||||
aggs_dict, post_aggs_dict = self.metrics_and_post_aggs(
|
||||
[timeseries_limit_metric],
|
||||
metrics_dict)
|
||||
if phase == 1:
|
||||
pre_qry['aggregations'].update(aggs_dict)
|
||||
pre_qry['post_aggregations'].update(post_aggs_dict)
|
||||
else:
|
||||
pre_qry['aggregations'] = aggs_dict
|
||||
pre_qry['post_aggregations'] = post_aggs_dict
|
||||
else:
|
||||
order_by = list(qry['aggregations'].keys())[0]
|
||||
# Limit on the number of timeseries, doing a two-phases query
|
||||
|
@ -1193,6 +1221,15 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
|
||||
if timeseries_limit_metric:
|
||||
order_by = timeseries_limit_metric
|
||||
aggs_dict, post_aggs_dict = self.metrics_and_post_aggs(
|
||||
[timeseries_limit_metric],
|
||||
metrics_dict)
|
||||
if phase == 1:
|
||||
pre_qry['aggregations'].update(aggs_dict)
|
||||
pre_qry['post_aggregations'].update(post_aggs_dict)
|
||||
else:
|
||||
pre_qry['aggregations'] = aggs_dict
|
||||
pre_qry['post_aggregations'] = post_aggs_dict
|
||||
|
||||
# Limit on the number of timeseries, doing a two-phases query
|
||||
pre_qry['granularity'] = 'all'
|
||||
|
|
|
@ -14,6 +14,7 @@ import superset.connectors.druid.models as models
|
|||
from superset.connectors.druid.models import (
|
||||
DruidColumn, DruidDatasource, DruidMetric,
|
||||
)
|
||||
from superset.exceptions import SupersetException
|
||||
|
||||
|
||||
def mock_metric(metric_name, is_postagg=False):
|
||||
|
@ -157,9 +158,9 @@ class DruidFuncTestCase(unittest.TestCase):
|
|||
col1 = DruidColumn(column_name='col1')
|
||||
col2 = DruidColumn(column_name='col2')
|
||||
ds.columns = [col1, col2]
|
||||
all_metrics = []
|
||||
aggs = []
|
||||
post_aggs = ['some_agg']
|
||||
ds._metrics_and_post_aggs = Mock(return_value=(all_metrics, post_aggs))
|
||||
ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs))
|
||||
groupby = []
|
||||
metrics = ['metric1']
|
||||
ds.get_having_filters = Mock(return_value=[])
|
||||
|
@ -242,9 +243,9 @@ class DruidFuncTestCase(unittest.TestCase):
|
|||
col1 = DruidColumn(column_name='col1')
|
||||
col2 = DruidColumn(column_name='col2')
|
||||
ds.columns = [col1, col2]
|
||||
all_metrics = ['metric1']
|
||||
aggs = ['metric1']
|
||||
post_aggs = ['some_agg']
|
||||
ds._metrics_and_post_aggs = Mock(return_value=(all_metrics, post_aggs))
|
||||
ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs))
|
||||
groupby = ['col1']
|
||||
metrics = ['metric1']
|
||||
ds.get_having_filters = Mock(return_value=[])
|
||||
|
@ -316,9 +317,9 @@ class DruidFuncTestCase(unittest.TestCase):
|
|||
col1 = DruidColumn(column_name='col1')
|
||||
col2 = DruidColumn(column_name='col2')
|
||||
ds.columns = [col1, col2]
|
||||
all_metrics = []
|
||||
aggs = []
|
||||
post_aggs = ['some_agg']
|
||||
ds._metrics_and_post_aggs = Mock(return_value=(all_metrics, post_aggs))
|
||||
ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs))
|
||||
groupby = ['col1', 'col2']
|
||||
metrics = ['metric1']
|
||||
ds.get_having_filters = Mock(return_value=[])
|
||||
|
@ -512,10 +513,10 @@ class DruidFuncTestCase(unittest.TestCase):
|
|||
depends_on('I', ['H', 'K'])
|
||||
depends_on('J', 'K')
|
||||
depends_on('K', ['m8', 'm9'])
|
||||
all_metrics, saved_metrics, postaggs = DruidDatasource.metrics_and_post_aggs(
|
||||
aggs, postaggs = DruidDatasource.metrics_and_post_aggs(
|
||||
metrics, metrics_dict)
|
||||
expected_metrics = set(all_metrics)
|
||||
self.assertEqual(9, len(all_metrics))
|
||||
expected_metrics = set(aggs.keys())
|
||||
self.assertEqual(9, len(aggs))
|
||||
for i in range(1, 10):
|
||||
expected_metrics.remove('m' + str(i))
|
||||
self.assertEqual(0, len(expected_metrics))
|
||||
|
@ -593,45 +594,40 @@ class DruidFuncTestCase(unittest.TestCase):
|
|||
}
|
||||
|
||||
metrics = ['some_sum']
|
||||
saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
|
||||
saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
|
||||
metrics, metrics_dict)
|
||||
|
||||
assert saved_metrics == ['some_sum']
|
||||
assert adhoc_metrics == []
|
||||
assert set(saved_metrics.keys()) == {'some_sum'}
|
||||
assert post_aggs == {}
|
||||
|
||||
metrics = [adhoc_metric]
|
||||
saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
|
||||
saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
|
||||
metrics, metrics_dict)
|
||||
|
||||
assert saved_metrics == []
|
||||
assert adhoc_metrics == [adhoc_metric]
|
||||
assert set(saved_metrics.keys()) == set([adhoc_metric['label']])
|
||||
assert post_aggs == {}
|
||||
|
||||
metrics = ['some_sum', adhoc_metric]
|
||||
saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
|
||||
saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
|
||||
metrics, metrics_dict)
|
||||
|
||||
assert saved_metrics == ['some_sum']
|
||||
assert adhoc_metrics == [adhoc_metric]
|
||||
assert set(saved_metrics.keys()) == {'some_sum', adhoc_metric['label']}
|
||||
assert post_aggs == {}
|
||||
|
||||
metrics = ['quantile_p95']
|
||||
saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
|
||||
saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
|
||||
metrics, metrics_dict)
|
||||
|
||||
result_postaggs = set(['quantile_p95'])
|
||||
assert saved_metrics == ['a_histogram']
|
||||
assert adhoc_metrics == []
|
||||
assert set(saved_metrics.keys()) == {'a_histogram'}
|
||||
assert set(post_aggs.keys()) == result_postaggs
|
||||
|
||||
metrics = ['aCustomPostAgg']
|
||||
saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
|
||||
saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
|
||||
metrics, metrics_dict)
|
||||
|
||||
result_postaggs = set(['aCustomPostAgg'])
|
||||
assert saved_metrics == ['aCustomMetric']
|
||||
assert adhoc_metrics == []
|
||||
assert set(saved_metrics.keys()) == {'aCustomMetric'}
|
||||
assert set(post_aggs.keys()) == result_postaggs
|
||||
|
||||
def test_druid_type_from_adhoc_metric(self):
|
||||
|
@ -663,3 +659,157 @@ class DruidFuncTestCase(unittest.TestCase):
|
|||
'label': 'My Adhoc Metric',
|
||||
})
|
||||
assert(druid_type == 'cardinality')
|
||||
|
||||
def test_run_query_order_by_metrics(self):
|
||||
client = Mock()
|
||||
client.query_builder.last_query.query_dict = {'mock': 0}
|
||||
from_dttm = Mock()
|
||||
to_dttm = Mock()
|
||||
ds = DruidDatasource(datasource_name='datasource')
|
||||
ds.get_having_filters = Mock(return_value=[])
|
||||
dim1 = DruidColumn(column_name='dim1')
|
||||
dim2 = DruidColumn(column_name='dim2')
|
||||
metrics_dict = {
|
||||
'count1': DruidMetric(
|
||||
metric_name='count1',
|
||||
metric_type='count',
|
||||
json=json.dumps({'type': 'count', 'name': 'count1'}),
|
||||
),
|
||||
'sum1': DruidMetric(
|
||||
metric_name='sum1',
|
||||
metric_type='doubleSum',
|
||||
json=json.dumps({'type': 'doubleSum', 'name': 'sum1'}),
|
||||
),
|
||||
'sum2': DruidMetric(
|
||||
metric_name='sum2',
|
||||
metric_type='doubleSum',
|
||||
json=json.dumps({'type': 'doubleSum', 'name': 'sum2'}),
|
||||
),
|
||||
'div1': DruidMetric(
|
||||
metric_name='div1',
|
||||
metric_type='postagg',
|
||||
json=json.dumps({
|
||||
'fn': '/',
|
||||
'type': 'arithmetic',
|
||||
'name': 'div1',
|
||||
'fields': [
|
||||
{
|
||||
'fieldName': 'sum1',
|
||||
'type': 'fieldAccess',
|
||||
},
|
||||
{
|
||||
'fieldName': 'sum2',
|
||||
'type': 'fieldAccess',
|
||||
},
|
||||
],
|
||||
}),
|
||||
),
|
||||
}
|
||||
ds.columns = [dim1, dim2]
|
||||
ds.metrics = list(metrics_dict.values())
|
||||
|
||||
groupby = ['dim1']
|
||||
metrics = ['count1']
|
||||
granularity = 'all'
|
||||
# get the counts of the top 5 'dim1's, order by 'sum1'
|
||||
ds.run_query(
|
||||
groupby, metrics, granularity, from_dttm, to_dttm,
|
||||
timeseries_limit=5, timeseries_limit_metric='sum1',
|
||||
client=client, order_desc=True, filter=[],
|
||||
)
|
||||
qry_obj = client.topn.call_args_list[0][1]
|
||||
self.assertEqual('dim1', qry_obj['dimension'])
|
||||
self.assertEqual('sum1', qry_obj['metric'])
|
||||
aggregations = qry_obj['aggregations']
|
||||
post_aggregations = qry_obj['post_aggregations']
|
||||
self.assertEqual({'count1', 'sum1'}, set(aggregations.keys()))
|
||||
self.assertEqual(set(), set(post_aggregations.keys()))
|
||||
|
||||
# get the counts of the top 5 'dim1's, order by 'div1'
|
||||
ds.run_query(
|
||||
groupby, metrics, granularity, from_dttm, to_dttm,
|
||||
timeseries_limit=5, timeseries_limit_metric='div1',
|
||||
client=client, order_desc=True, filter=[],
|
||||
)
|
||||
qry_obj = client.topn.call_args_list[1][1]
|
||||
self.assertEqual('dim1', qry_obj['dimension'])
|
||||
self.assertEqual('div1', qry_obj['metric'])
|
||||
aggregations = qry_obj['aggregations']
|
||||
post_aggregations = qry_obj['post_aggregations']
|
||||
self.assertEqual({'count1', 'sum1', 'sum2'}, set(aggregations.keys()))
|
||||
self.assertEqual({'div1'}, set(post_aggregations.keys()))
|
||||
|
||||
groupby = ['dim1', 'dim2']
|
||||
# get the counts of the top 5 ['dim1', 'dim2']s, order by 'sum1'
|
||||
ds.run_query(
|
||||
groupby, metrics, granularity, from_dttm, to_dttm,
|
||||
timeseries_limit=5, timeseries_limit_metric='sum1',
|
||||
client=client, order_desc=True, filter=[],
|
||||
)
|
||||
qry_obj = client.groupby.call_args_list[0][1]
|
||||
self.assertEqual({'dim1', 'dim2'}, set(qry_obj['dimensions']))
|
||||
self.assertEqual('sum1', qry_obj['limit_spec']['columns'][0]['dimension'])
|
||||
aggregations = qry_obj['aggregations']
|
||||
post_aggregations = qry_obj['post_aggregations']
|
||||
self.assertEqual({'count1', 'sum1'}, set(aggregations.keys()))
|
||||
self.assertEqual(set(), set(post_aggregations.keys()))
|
||||
|
||||
# get the counts of the top 5 ['dim1', 'dim2']s, order by 'div1'
|
||||
ds.run_query(
|
||||
groupby, metrics, granularity, from_dttm, to_dttm,
|
||||
timeseries_limit=5, timeseries_limit_metric='div1',
|
||||
client=client, order_desc=True, filter=[],
|
||||
)
|
||||
qry_obj = client.groupby.call_args_list[1][1]
|
||||
self.assertEqual({'dim1', 'dim2'}, set(qry_obj['dimensions']))
|
||||
self.assertEqual('div1', qry_obj['limit_spec']['columns'][0]['dimension'])
|
||||
aggregations = qry_obj['aggregations']
|
||||
post_aggregations = qry_obj['post_aggregations']
|
||||
self.assertEqual({'count1', 'sum1', 'sum2'}, set(aggregations.keys()))
|
||||
self.assertEqual({'div1'}, set(post_aggregations.keys()))
|
||||
|
||||
def test_get_aggregations(self):
|
||||
ds = DruidDatasource(datasource_name='datasource')
|
||||
metrics_dict = {
|
||||
'sum1': DruidMetric(
|
||||
metric_name='sum1',
|
||||
metric_type='doubleSum',
|
||||
json=json.dumps({'type': 'doubleSum', 'name': 'sum1'}),
|
||||
),
|
||||
'sum2': DruidMetric(
|
||||
metric_name='sum2',
|
||||
metric_type='doubleSum',
|
||||
json=json.dumps({'type': 'doubleSum', 'name': 'sum2'}),
|
||||
),
|
||||
'div1': DruidMetric(
|
||||
metric_name='div1',
|
||||
metric_type='postagg',
|
||||
json=json.dumps({
|
||||
'fn': '/',
|
||||
'type': 'arithmetic',
|
||||
'name': 'div1',
|
||||
'fields': [
|
||||
{
|
||||
'fieldName': 'sum1',
|
||||
'type': 'fieldAccess',
|
||||
},
|
||||
{
|
||||
'fieldName': 'sum2',
|
||||
'type': 'fieldAccess',
|
||||
},
|
||||
],
|
||||
}),
|
||||
),
|
||||
}
|
||||
metric_names = ['sum1', 'sum2']
|
||||
aggs = ds.get_aggregations(metrics_dict, metric_names)
|
||||
expected_agg = {name: metrics_dict[name].json_obj for name in metric_names}
|
||||
self.assertEqual(expected_agg, aggs)
|
||||
|
||||
metric_names = ['sum1', 'col1']
|
||||
self.assertRaises(
|
||||
SupersetException, ds.get_aggregations, metrics_dict, metric_names)
|
||||
|
||||
metric_names = ['sum1', 'div1']
|
||||
self.assertRaises(
|
||||
SupersetException, ds.get_aggregations, metrics_dict, metric_names)
|
||||
|
|
Loading…
Reference in New Issue