[bugfix] 'DruidCluster' object has no attribute 'db_engine_spec' (#5765)

* [bugfix] 'DruidCluster' object has no attribute 'db_engine_spec'

* Fix tests
This commit is contained in:
Maxime Beauchemin 2018-08-28 21:04:06 -07:00 committed by GitHub
parent 2da5db9fcd
commit 135539c109
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 73 additions and 60 deletions

View File

@ -120,9 +120,14 @@ class BaseViz(object):
def get_metric_label(self, metric):
if isinstance(metric, string_types):
return metric
if isinstance(metric, dict):
return self.datasource.database.db_engine_spec.mutate_expression_label(
metric.get('label'))
metric = metric.get('label')
if self.datasource.type == 'table':
db_engine_spec = self.datasource.database.db_engine_spec
metric = db_engine_spec.mutate_expression_label(metric)
return metric
@staticmethod
def handle_js_int_overflow(data):

View File

@ -11,6 +11,8 @@ import os
import unittest
from flask_appbuilder.security.sqla import models as ab_models
from mock import Mock
import pandas as pd
from superset import app, cli, db, security_manager, utils
from superset.connectors.druid.models import DruidCluster, DruidDatasource
@ -147,6 +149,23 @@ class SupersetTestCase(unittest.TestCase):
return db.session.query(DruidDatasource).filter_by(
datasource_name=name).first()
def get_datasource_mock(self):
datasource = Mock()
results = Mock()
results.query = Mock()
results.status = Mock()
results.error_message = None
results.df = pd.DataFrame()
datasource.type = 'table'
datasource.query = Mock(return_value=results)
mock_dttm_col = Mock()
datasource.get_col = Mock(return_value=mock_dttm_col)
datasource.query = Mock(return_value=results)
datasource.database = Mock()
datasource.database.db_engine_spec = Mock()
datasource.database.db_engine_spec.mutate_expression_label = lambda x: x
return datasource
def get_resp(
self, url, data=None, follow_redirects=True, raise_on_error=True):
"""Shortcut to get the parsed results while following redirects"""

View File

@ -5,7 +5,6 @@ from __future__ import print_function
from __future__ import unicode_literals
from datetime import datetime
import unittest
import uuid
from mock import Mock, patch
@ -15,10 +14,11 @@ from superset import app
from superset.exceptions import SpatialException
from superset.utils import DTTM_ALIAS
import superset.viz as viz
from .base_tests import SupersetTestCase
from .utils import load_fixture
class BaseVizTestCase(unittest.TestCase):
class BaseVizTestCase(SupersetTestCase):
def test_constructor_exception_no_datasource(self):
form_data = {}
@ -31,7 +31,7 @@ class BaseVizTestCase(unittest.TestCase):
'viz_type': 'table',
'token': '12345',
}
datasource = {'type': 'table'}
datasource = self.get_datasource_mock()
test_viz = viz.BaseViz(datasource, form_data)
self.assertEqual(
test_viz.default_fillna,
@ -39,16 +39,9 @@ class BaseVizTestCase(unittest.TestCase):
)
def test_get_df_returns_empty_df(self):
datasource = Mock()
datasource.type = 'table'
form_data = {'dummy': 123}
query_obj = {'granularity': 'day'}
results = Mock()
results.query = Mock()
results.status = Mock()
results.error_message = None
results.df = pd.DataFrame()
datasource.query = Mock(return_value=results)
datasource = self.get_datasource_mock()
test_viz = viz.BaseViz(datasource, form_data)
result = test_viz.get_df(query_obj)
self.assertEqual(type(result), pd.DataFrame)
@ -66,14 +59,20 @@ class BaseVizTestCase(unittest.TestCase):
datasource.query = Mock(return_value=results)
mock_dttm_col = Mock()
datasource.get_col = Mock(return_value=mock_dttm_col)
test_viz = viz.BaseViz(datasource, form_data)
test_viz.df_metrics_to_num = Mock()
test_viz.get_fillna_for_columns = Mock(return_value=0)
results.df = pd.DataFrame(data={DTTM_ALIAS: ['1960-01-01 05:00:00']})
datasource.offset = 0
mock_dttm_col = Mock()
datasource.get_col = Mock(return_value=mock_dttm_col)
mock_dttm_col.python_date_format = 'epoch_ms'
result = test_viz.get_df(query_obj)
print(result)
import logging
logging.info(result)
pd.testing.assert_series_equal(
result[DTTM_ALIAS],
pd.Series([datetime(1960, 1, 1, 5, 0)], name=DTTM_ALIAS),
@ -103,38 +102,28 @@ class BaseVizTestCase(unittest.TestCase):
)
def test_cache_timeout(self):
datasource = Mock()
datasource = self.get_datasource_mock()
datasource.cache_timeout = 0
test_viz = viz.BaseViz(datasource, form_data={})
self.assertEqual(0, test_viz.cache_timeout)
datasource.cache_timeout = 156
test_viz = viz.BaseViz(datasource, form_data={})
self.assertEqual(156, test_viz.cache_timeout)
datasource.cache_timeout = None
datasource.database = Mock()
datasource.database.cache_timeout = 0
self.assertEqual(0, test_viz.cache_timeout)
datasource.database.cache_timeout = 1666
self.assertEqual(1666, test_viz.cache_timeout)
datasource.database.cache_timeout = None
test_viz = viz.BaseViz(datasource, form_data={})
self.assertEqual(app.config['CACHE_DEFAULT_TIMEOUT'], test_viz.cache_timeout)
class TableVizTestCase(unittest.TestCase):
class DBEngineSpecMock:
@staticmethod
def mutate_expression_label(label):
return label
class DatabaseMock:
def __init__(self):
self.db_engine_spec = TableVizTestCase.DBEngineSpecMock()
class DatasourceMock:
def __init__(self):
self.database = TableVizTestCase.DatabaseMock()
class TableVizTestCase(SupersetTestCase):
def test_get_data_applies_percentage(self):
form_data = {
@ -151,7 +140,7 @@ class TableVizTestCase(unittest.TestCase):
'column': {'column_name': 'value1', 'type': 'DOUBLE'},
}, 'count', 'avg__C'],
}
datasource = TableVizTestCase.DatasourceMock()
datasource = self.get_datasource_mock()
raw = {}
raw['SUM(value1)'] = [15, 20, 25, 40]
raw['avg__B'] = [10, 20, 5, 15]
@ -227,7 +216,7 @@ class TableVizTestCase(unittest.TestCase):
},
],
}
datasource = Mock()
datasource = self.get_datasource_mock()
test_viz = viz.TableViz(datasource, form_data)
query_obj = test_viz.query_obj()
self.assertEqual(
@ -265,7 +254,7 @@ class TableVizTestCase(unittest.TestCase):
],
'having': 'SUM(value1) > 5',
}
datasource = Mock()
datasource = self.get_datasource_mock()
test_viz = viz.TableViz(datasource, form_data)
query_obj = test_viz.query_obj()
self.assertEqual(
@ -281,7 +270,7 @@ class TableVizTestCase(unittest.TestCase):
@patch('superset.viz.BaseViz.query_obj')
def test_query_obj_merges_percent_metrics(self, super_query_obj):
datasource = Mock()
datasource = self.get_datasource_mock()
form_data = {
'percent_metrics': ['sum__A', 'avg__B', 'max__Y'],
'metrics': ['sum__A', 'count', 'avg__C'],
@ -299,7 +288,7 @@ class TableVizTestCase(unittest.TestCase):
@patch('superset.viz.BaseViz.query_obj')
def test_query_obj_throws_columns_and_metrics(self, super_query_obj):
datasource = Mock()
datasource = self.get_datasource_mock()
form_data = {
'all_columns': ['A', 'B'],
'metrics': ['x', 'y'],
@ -316,7 +305,7 @@ class TableVizTestCase(unittest.TestCase):
@patch('superset.viz.BaseViz.query_obj')
def test_query_obj_merges_all_columns(self, super_query_obj):
datasource = Mock()
datasource = self.get_datasource_mock()
form_data = {
'all_columns': ['colA', 'colB', 'colC'],
'order_by_cols': ['["colA", "colB"]', '["colC"]'],
@ -333,7 +322,7 @@ class TableVizTestCase(unittest.TestCase):
@patch('superset.viz.BaseViz.query_obj')
def test_query_obj_uses_sortby(self, super_query_obj):
datasource = Mock()
datasource = self.get_datasource_mock()
form_data = {
'timeseries_limit_metric': '__time__',
'order_desc': False,
@ -351,20 +340,20 @@ class TableVizTestCase(unittest.TestCase):
)], query_obj['orderby'])
def test_should_be_timeseries_raises_when_no_granularity(self):
datasource = Mock()
datasource = self.get_datasource_mock()
form_data = {'include_time': True}
test_viz = viz.TableViz(datasource, form_data)
with self.assertRaises(Exception):
test_viz.should_be_timeseries()
class PairedTTestTestCase(unittest.TestCase):
class PairedTTestTestCase(SupersetTestCase):
def test_get_data_transforms_dataframe(self):
form_data = {
'groupby': ['groupA', 'groupB', 'groupC'],
'metrics': ['metric1', 'metric2', 'metric3'],
}
datasource = {'type': 'table'}
datasource = self.get_datasource_mock()
# Test data
raw = {}
raw[DTTM_ALIAS] = [100, 200, 300, 100, 200, 300, 100, 200, 300]
@ -456,7 +445,7 @@ class PairedTTestTestCase(unittest.TestCase):
'groupby': [],
'metrics': ['', None],
}
datasource = {'type': 'table'}
datasource = self.get_datasource_mock()
# Test data
raw = {}
raw[DTTM_ALIAS] = [100, 200, 300]
@ -490,11 +479,11 @@ class PairedTTestTestCase(unittest.TestCase):
self.assertEqual(data, expected)
class PartitionVizTestCase(unittest.TestCase):
class PartitionVizTestCase(SupersetTestCase):
@patch('superset.viz.BaseViz.query_obj')
def test_query_obj_time_series_option(self, super_query_obj):
datasource = Mock()
datasource = self.get_datasource_mock()
form_data = {}
test_viz = viz.PartitionViz(datasource, form_data)
super_query_obj.return_value = {}
@ -715,7 +704,7 @@ class PartitionVizTestCase(unittest.TestCase):
self.assertEqual(7, len(test_viz.nest_values.mock_calls))
class RoseVisTestCase(unittest.TestCase):
class RoseVisTestCase(SupersetTestCase):
def test_rose_vis_get_data(self):
raw = {}
@ -755,14 +744,14 @@ class RoseVisTestCase(unittest.TestCase):
self.assertEqual(expected, res)
class TimeSeriesTableVizTestCase(unittest.TestCase):
class TimeSeriesTableVizTestCase(SupersetTestCase):
def test_get_data_metrics(self):
form_data = {
'metrics': ['sum__A', 'count'],
'groupby': [],
}
datasource = Mock()
datasource = self.get_datasource_mock()
raw = {}
t1 = pd.Timestamp('2000')
t2 = pd.Timestamp('2002')
@ -792,7 +781,7 @@ class TimeSeriesTableVizTestCase(unittest.TestCase):
'metrics': ['sum__A'],
'groupby': ['groupby1'],
}
datasource = Mock()
datasource = self.get_datasource_mock()
raw = {}
t1 = pd.Timestamp('2000')
t2 = pd.Timestamp('2002')
@ -821,7 +810,7 @@ class TimeSeriesTableVizTestCase(unittest.TestCase):
@patch('superset.viz.BaseViz.query_obj')
def test_query_obj_throws_metrics_and_groupby(self, super_query_obj):
datasource = Mock()
datasource = self.get_datasource_mock()
form_data = {
'groupby': ['a'],
}
@ -835,11 +824,11 @@ class TimeSeriesTableVizTestCase(unittest.TestCase):
test_viz.query_obj()
class BaseDeckGLVizTestCase(unittest.TestCase):
class BaseDeckGLVizTestCase(SupersetTestCase):
def test_get_metrics(self):
form_data = load_fixture('deck_path_form_data.json')
datasource = {'type': 'table'}
datasource = self.get_datasource_mock()
test_viz_deckgl = viz.BaseDeckGLViz(datasource, form_data)
result = test_viz_deckgl.get_metrics()
assert result == [form_data.get('size')]
@ -851,7 +840,7 @@ class BaseDeckGLVizTestCase(unittest.TestCase):
def test_scatterviz_get_metrics(self):
form_data = load_fixture('deck_path_form_data.json')
datasource = {'type': 'table'}
datasource = self.get_datasource_mock()
form_data = {}
test_viz_deckgl = viz.DeckScatterViz(datasource, form_data)
@ -867,7 +856,7 @@ class BaseDeckGLVizTestCase(unittest.TestCase):
def test_get_js_columns(self):
form_data = load_fixture('deck_path_form_data.json')
datasource = {'type': 'table'}
datasource = self.get_datasource_mock()
mock_d = {
'a': 'dummy1',
'b': 'dummy2',
@ -881,7 +870,7 @@ class BaseDeckGLVizTestCase(unittest.TestCase):
def test_get_properties(self):
mock_d = {}
form_data = load_fixture('deck_path_form_data.json')
datasource = {'type': 'table'}
datasource = self.get_datasource_mock()
test_viz_deckgl = viz.BaseDeckGLViz(datasource, form_data)
with self.assertRaises(NotImplementedError) as context:
@ -891,7 +880,7 @@ class BaseDeckGLVizTestCase(unittest.TestCase):
def test_process_spatial_query_obj(self):
form_data = load_fixture('deck_path_form_data.json')
datasource = {'type': 'table'}
datasource = self.get_datasource_mock()
mock_key = 'spatial_key'
mock_gb = []
test_viz_deckgl = viz.BaseDeckGLViz(datasource, form_data)
@ -917,7 +906,7 @@ class BaseDeckGLVizTestCase(unittest.TestCase):
},
}
datasource = {'type': 'table'}
datasource = self.get_datasource_mock()
expected_results = {
'latlong_key': ['lon', 'lat'],
'delimited_key': ['lonlat'],
@ -931,7 +920,7 @@ class BaseDeckGLVizTestCase(unittest.TestCase):
def test_geojson_query_obj(self):
form_data = load_fixture('deck_geojson_form_data.json')
datasource = {'type': 'table'}
datasource = self.get_datasource_mock()
test_viz_deckgl = viz.DeckGeoJson(datasource, form_data)
results = test_viz_deckgl.query_obj()
@ -941,7 +930,7 @@ class BaseDeckGLVizTestCase(unittest.TestCase):
def test_parse_coordinates(self):
form_data = load_fixture('deck_path_form_data.json')
datasource = {'type': 'table'}
datasource = self.get_datasource_mock()
viz_instance = viz.BaseDeckGLViz(datasource, form_data)
coord = viz_instance.parse_coordinates('1.23, 3.21')
@ -956,7 +945,7 @@ class BaseDeckGLVizTestCase(unittest.TestCase):
def test_parse_coordinates_raises(self):
form_data = load_fixture('deck_path_form_data.json')
datasource = {'type': 'table'}
datasource = self.get_datasource_mock()
test_viz_deckgl = viz.BaseDeckGLViz(datasource, form_data)
with self.assertRaises(SpatialException):
@ -984,7 +973,7 @@ class BaseDeckGLVizTestCase(unittest.TestCase):
},
}
datasource = {'type': 'table'}
datasource = self.get_datasource_mock()
expected_results = {
'latlong_key': [{
'clause': 'WHERE',
@ -1027,10 +1016,10 @@ class BaseDeckGLVizTestCase(unittest.TestCase):
assert expected_results.get(mock_key) == adhoc_filters
class TimeSeriesVizTestCase(unittest.TestCase):
class TimeSeriesVizTestCase(SupersetTestCase):
def test_timeseries_unicode_data(self):
datasource = Mock()
datasource = self.get_datasource_mock()
form_data = {
'groupby': ['name'],
'metrics': ['sum__payout'],