fix: Time Column on Generic X-axis (#23021)

This commit is contained in:
Michael S. Molina 2023-02-10 13:33:07 -05:00 committed by GitHub
parent 85f07798bf
commit 464ddee4b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 97 additions and 56 deletions

View File

@ -20,16 +20,11 @@
import buildQueryObject from './buildQueryObject';
import DatasourceKey from './DatasourceKey';
import { QueryFieldAliases, QueryFormData } from './types/QueryFormData';
import {
BinaryQueryObjectFilterClause,
QueryContext,
QueryObject,
} from './types/Query';
import { QueryContext, QueryObject } from './types/Query';
import { SetDataMaskHook } from '../chart';
import { JsonObject } from '../connection';
import { normalizeTimeColumn } from './normalizeTimeColumn';
import { hasGenericChartAxes, isXAxisSet } from './getXAxis';
import { ensureIsArray } from '../utils';
import { isXAxisSet } from './getXAxis';
const WRAP_IN_ARRAY = (baseQueryObject: QueryObject) => [baseQueryObject];
@ -60,14 +55,6 @@ export default function buildQueryContext(
// eslint-disable-next-line no-param-reassign
query.post_processing = query.post_processing.filter(Boolean);
}
if (hasGenericChartAxes && query.time_range) {
// eslint-disable-next-line no-param-reassign
query.filters = ensureIsArray(query.filters).map(flt =>
flt?.op === 'TEMPORAL_RANGE'
? ({ ...flt, val: query.time_range } as BinaryQueryObjectFilterClause)
: flt,
);
}
});
if (isXAxisSet(formData)) {
queries = queries.map(query => normalizeTimeColumn(formData, query));

View File

@ -164,41 +164,4 @@ describe('buildQueryContext', () => {
expect(spyNormalizeTimeColumn).not.toBeCalled();
spyNormalizeTimeColumn.mockRestore();
});
it('should orverride time filter if GENERIC_CHART_AXES is enabled', () => {
Object.defineProperty(getXAxisModule, 'hasGenericChartAxes', {
value: true,
});
const queryContext = buildQueryContext(
{
datasource: '5__table',
viz_type: 'table',
},
() => [
{
filters: [
{
col: 'col1',
op: 'TEMPORAL_RANGE',
val: '2001 : 2002',
},
{
col: 'col2',
op: 'IN',
val: ['a', 'b'],
},
],
time_range: '1990 : 1991',
},
],
);
expect(queryContext.queries[0].filters).toEqual([
{ col: 'col1', op: 'TEMPORAL_RANGE', val: '1990 : 1991' },
{
col: 'col2',
op: 'IN',
val: ['a', 'b'],
},
]);
});
});

View File

@ -198,7 +198,8 @@ class SaveModal extends React.Component<SaveModalProps, SaveModalState> {
);
}
const { url_params, ...formData } = this.props.form_data || {};
const formData = this.props.form_data || {};
delete formData.url_params;
let dashboard: DashboardGetResponse | null = null;
if (this.state.newDashboardName || this.state.saveToDashboardId) {

View File

@ -22,6 +22,7 @@ from superset import app, db
from superset.charts.dao import ChartDAO
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.query_context import QueryContext
from superset.common.query_object import QueryObject
from superset.common.query_object_factory import QueryObjectFactory
from superset.datasource.dao import DatasourceDAO
from superset.models.slice import Slice
@ -65,8 +66,12 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
result_type = result_type or ChartDataResultType.FULL
result_format = result_format or ChartDataResultFormat.JSON
queries_ = [
self._query_object_factory.create(
result_type, datasource=datasource, **query_obj
self._process_query_object(
datasource_model_instance,
form_data,
self._query_object_factory.create(
result_type, datasource=datasource, **query_obj
),
)
for query_obj in queries
]
@ -90,7 +95,6 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
# pylint: disable=no-self-use
def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
return DatasourceDAO.get_datasource(
session=db.session,
datasource_type=DatasourceType(datasource["type"]),
@ -99,3 +103,89 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
def _get_slice(self, slice_id: Any) -> Optional[Slice]:
return ChartDAO.find_by_id(slice_id)
def _process_query_object(
self,
datasource: BaseDatasource,
form_data: Optional[Dict[str, Any]],
query_object: QueryObject,
) -> QueryObject:
self._apply_granularity(query_object, form_data, datasource)
self._apply_filters(query_object)
return query_object
def _apply_granularity(
self,
query_object: QueryObject,
form_data: Optional[Dict[str, Any]],
datasource: BaseDatasource,
) -> None:
temporal_columns = {
column.column_name
for column in datasource.columns
if (column["is_dttm"] if isinstance(column, dict) else column.is_dttm)
}
granularity = query_object.granularity
x_axis = form_data and form_data.get("x_axis")
if granularity:
filter_to_remove = None
if x_axis and x_axis in temporal_columns:
filter_to_remove = x_axis
x_axis_column = next(
(
column
for column in query_object.columns
if column == x_axis
or (
isinstance(column, dict)
and column["sqlExpression"] == x_axis
)
),
None,
)
# Replaces x-axis column values with granularity
if x_axis_column:
if isinstance(x_axis_column, dict):
x_axis_column["sqlExpression"] = granularity
x_axis_column["label"] = granularity
else:
query_object.columns = [
granularity if column == x_axis_column else column
for column in query_object.columns
]
for post_processing in query_object.post_processing:
if post_processing.get("operation") == "pivot":
post_processing["options"]["index"] = [granularity]
# If no temporal x-axis, then get the default temporal filter
if not filter_to_remove:
temporal_filters = [
filter["col"]
for filter in query_object.filter
if filter["op"] == "TEMPORAL_RANGE"
]
if len(temporal_filters) > 0:
# Use granularity if it's already in the filters
if granularity in temporal_filters:
filter_to_remove = granularity
else:
# Use the first temporal filter
filter_to_remove = temporal_filters[0]
# Removes the temporal filter which may be an x-axis or
# another temporal filter. A new filter based on the value of
# the granularity will be added later in the code.
# In practice, this is replacing the previous default temporal filter.
if filter_to_remove:
query_object.filter = [
filter
for filter in query_object.filter
if filter["col"] != filter_to_remove
]
def _apply_filters(self, query_object: QueryObject) -> None:
if query_object.time_range:
for filter_object in query_object.filter:
if filter_object["op"] == "TEMPORAL_RANGE":
filter_object["val"] = query_object.time_range