From 32d2aa0c404f47abaa270bf1c572fdd26feb16ed Mon Sep 17 00:00:00 2001 From: Yongjie Zhao Date: Wed, 28 Jul 2021 15:34:39 +0100 Subject: [PATCH] feat: run extra query on QueryObject and add compare operator for post_processing (#15279) * rebase master and resolve conflicts * pylint to makefile * fix crash when pivot operator * fix comments * add precision argument * query test * wip * fix ut * rename * set time_offsets to cache key wip * refactor get_df_payload wip * extra query cache * cache ut * normalize df * fix timeoffset * fix ut * make cache key logging sense * resolve conflicts * backend follow up iteration 1 wip * rolling window type * rebase master * py lint and minor follow ups * pylintrc --- Makefile | 3 + superset/charts/commands/exceptions.py | 23 +- superset/charts/schemas.py | 3 + superset/common/query_context.py | 281 ++++++++++++------ superset/common/query_object.py | 10 + superset/common/utils.py | 179 +++++++++++ superset/constants.py | 19 ++ superset/examples/birth_names.py | 5 +- superset/utils/core.py | 2 + superset/utils/date_parser.py | 58 ++-- superset/utils/pandas_postprocessing.py | 76 ++++- superset/views/api.py | 4 +- .../integration_tests/fixtures/dataframes.py | 9 + .../fixtures/query_context.py | 16 +- .../pandas_postprocessing_tests.py | 59 ++++ .../integration_tests/query_context_tests.py | 103 +++++++ .../utils/date_parser_tests.py | 43 ++- 17 files changed, 744 insertions(+), 149 deletions(-) create mode 100644 superset/common/utils.py diff --git a/Makefile b/Makefile index 5de06f057b..85027a175f 100644 --- a/Makefile +++ b/Makefile @@ -76,5 +76,8 @@ format: py-format js-format py-format: pre-commit pre-commit run black --all-files +py-lint: pre-commit + pylint -j 0 superset + js-format: cd superset-frontend; npm run prettier diff --git a/superset/charts/commands/exceptions.py b/superset/charts/commands/exceptions.py index ee369a544c..60a62e1987 100644 --- a/superset/charts/commands/exceptions.py +++ b/superset/charts/commands/exceptions.py @@ -28,15 +28,15 @@ from superset.commands.exceptions import ( ) -class TimeRangeUnclearError(ValidationError): +class TimeRangeAmbiguousError(ValidationError): """ - Time range is in valid error. + Time range is ambiguous error. """ def __init__(self, human_readable: str) -> None: super().__init__( _( - "Time string is unclear." + "Time string is ambiguous." " Please specify [%(human_readable)s ago]" " or [%(human_readable)s later].", human_readable=human_readable, @@ -56,6 +56,23 @@ class TimeRangeParseFailError(ValidationError): ) +class TimeDeltaAmbiguousError(ValidationError): + """ + Time delta is ambiguous error. + """ + + def __init__(self, human_readable: str) -> None: + super().__init__( + _( + "Time delta is ambiguous." 
+ " Please specify [%(human_readable)s ago]" + " or [%(human_readable)s later].", + human_readable=human_readable, + ), + field_name="time_range", + ) + + class DatabaseNotFoundValidationError(ValidationError): """ Marshmallow validation error for database does not exist diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 7ae22e78ca..795bb63fe3 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -730,6 +730,8 @@ class ChartDataPostProcessingOperationSchema(Schema): "rolling", "select", "sort", + "diff", + "compare", ) ), example="aggregate", @@ -1074,6 +1076,7 @@ class ChartDataQueryObjectSchema(Schema): description="Should the rowcount of the actual query be returned", allow_none=True, ) + time_offsets = fields.List(fields.String(), allow_none=True,) class ChartDataQueryContextSchema(Schema): diff --git a/superset/common/query_context.py b/superset/common/query_context.py index 6185562751..8735f4836b 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -16,26 +16,28 @@ # under the License. from __future__ import annotations +import copy import logging from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union import numpy as np import pandas as pd from flask_babel import _ +from pandas import DateOffset +from typing_extensions import TypedDict from superset import app, db, is_feature_enabled from superset.annotation_layers.dao import AnnotationLayerDAO from superset.charts.dao import ChartDAO from superset.common.query_actions import get_query_results from superset.common.query_object import QueryObject +from superset.common.utils import QueryCacheManager from superset.connectors.base.models import BaseDatasource from superset.connectors.connector_registry import ConnectorRegistry -from superset.exceptions import ( - CacheLoadError, - QueryObjectValidationError, - SupersetException, -) +from superset.constants import CacheRegion +from superset.exceptions import QueryObjectValidationError, SupersetException from superset.extensions import cache_manager, security_manager +from superset.models.helpers import QueryResult from superset.utils import csv from superset.utils.cache import generate_cache_key, set_and_log_cache from superset.utils.core import ( @@ -45,10 +47,12 @@ from superset.utils.core import ( DTTM_ALIAS, error_msg_from_exception, get_column_names_from_metrics, - get_stacktrace, + get_metric_names, normalize_dttm_col, QueryStatus, + TIME_COMPARISION, ) +from superset.utils.date_parser import get_past_or_future, normalize_time_delta from superset.views.utils import get_viz if TYPE_CHECKING: @@ -59,6 +63,12 @@ stats_logger: BaseStatsLogger = config["STATS_LOGGER"] logger = logging.getLogger(__name__) +class CachedTimeOffset(TypedDict): + df: pd.DataFrame + queries: List[str] + cache_keys: List[Optional[str]] + + class QueryContext: """ The query context contains the query object and additional fields necessary @@ -77,7 +87,8 @@ class QueryContext: # TODO: Type datasource and query_object dictionary with TypedDict when it becomes # a vanilla python type https://github.com/python/mypy/issues/5288 - def __init__( # pylint: disable=too-many-arguments + # pylint: disable=too-many-arguments + def __init__( self, datasource: DatasourceDict, queries: List[Dict[str, Any]], @@ -101,21 +112,143 @@ class QueryContext: "result_format": self.result_format, } - def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]: - """Returns a pandas dataframe based on the query object""" + 
@staticmethod + def left_join_on_dttm( + left_df: pd.DataFrame, right_df: pd.DataFrame + ) -> pd.DataFrame: + df = left_df.set_index(DTTM_ALIAS).join(right_df.set_index(DTTM_ALIAS)) + df.reset_index(level=0, inplace=True) + return df - # Here, we assume that all the queries will use the same datasource, which is - # a valid assumption for current setting. In the long term, we may - # support multiple queries from different data sources. + def processing_time_offsets( + self, df: pd.DataFrame, query_object: QueryObject, + ) -> CachedTimeOffset: + # ensure query_object is immutable + query_object_clone = copy.copy(query_object) + queries = [] + cache_keys = [] + time_offsets = query_object.time_offsets + outer_from_dttm = query_object.from_dttm + outer_to_dttm = query_object.to_dttm + for offset in time_offsets: + try: + query_object_clone.from_dttm = get_past_or_future( + offset, outer_from_dttm, + ) + query_object_clone.to_dttm = get_past_or_future(offset, outer_to_dttm) + except ValueError as ex: + raise QueryObjectValidationError(str(ex)) + # make sure subquery use main query where clause + query_object_clone.inner_from_dttm = outer_from_dttm + query_object_clone.inner_to_dttm = outer_to_dttm + query_object_clone.time_offsets = [] + query_object_clone.post_processing = [] + + if not query_object.from_dttm or not query_object.to_dttm: + raise QueryObjectValidationError( + _( + "An enclosed time range (both start and end) must be specified " + "when using a Time Comparison." + ) + ) + # `offset` is added to the hash function + cache_key = self.query_cache_key(query_object_clone, time_offset=offset) + cache = QueryCacheManager.get(cache_key, CacheRegion.DATA, self.force) + # whether hit in the cache + if cache.is_loaded: + df = self.left_join_on_dttm(df, cache.df) + queries.append(cache.query) + cache_keys.append(cache_key) + continue + + query_object_clone_dct = query_object_clone.to_dict() + result = self.datasource.query(query_object_clone_dct) + queries.append(result.query) + cache_keys.append(None) + + # rename metrics: SUM(value) => SUM(value) 1 year ago + columns_name_mapping = { + metric: TIME_COMPARISION.join([metric, offset]) + for metric in get_metric_names( + query_object_clone_dct.get("metrics", []) + ) + } + columns_name_mapping[DTTM_ALIAS] = DTTM_ALIAS + + offset_metrics_df = result.df + if offset_metrics_df.empty: + offset_metrics_df = pd.DataFrame( + {col: [np.NaN] for col in columns_name_mapping.values()} + ) + else: + # 1. normalize df, set dttm column + offset_metrics_df = self.normalize_df( + offset_metrics_df, query_object_clone + ) + + # 2. extract `metrics` columns and `dttm` column from extra query + offset_metrics_df = offset_metrics_df[columns_name_mapping.keys()] + + # 3. rename extra query columns + offset_metrics_df = offset_metrics_df.rename( + columns=columns_name_mapping + ) + + # 4. set offset for dttm column + offset_metrics_df[DTTM_ALIAS] = offset_metrics_df[ + DTTM_ALIAS + ] - DateOffset(**normalize_time_delta(offset)) + + # df left join `offset_metrics_df` on `DTTM` + df = self.left_join_on_dttm(df, offset_metrics_df) + + # set offset df to cache. 
+ value = { + "df": offset_metrics_df, + "query": result.query, + } + cache.set( + key=cache_key, + value=value, + timeout=self.cache_timeout, + datasource_uid=self.datasource.uid, + region=CacheRegion.DATA, + ) + + return CachedTimeOffset(df=df, queries=queries, cache_keys=cache_keys) + + def normalize_df(self, df: pd.DataFrame, query_object: QueryObject) -> pd.DataFrame: timestamp_format = None if self.datasource.type == "table": dttm_col = self.datasource.get_column(query_object.granularity) if dttm_col: timestamp_format = dttm_col.python_date_format + normalize_dttm_col( + df=df, + timestamp_format=timestamp_format, + offset=self.datasource.offset, + time_shift=query_object.time_shift, + ) + + if self.enforce_numerical_metrics: + self.df_metrics_to_num(df, query_object) + + df.replace([np.inf, -np.inf], np.nan, inplace=True) + + return df + + def get_query_result(self, query_object: QueryObject) -> QueryResult: + """Returns a pandas dataframe based on the query object""" + + # Here, we assume that all the queries will use the same datasource, which is + # a valid assumption for current setting. In the long term, we may + # support multiple queries from different data sources. + # The datasource here can be different backend but the interface is common result = self.datasource.query(query_object.to_dict()) + query = result.query + ";\n\n" df = result.df # Transform the timestamp we received from database to pandas supported @@ -124,25 +257,21 @@ class QueryContext: # If the datetime format is unix, the parse will use the corresponding # parsing logic if not df.empty: - normalize_dttm_col( - df=df, - timestamp_format=timestamp_format, - offset=self.datasource.offset, - time_shift=query_object.time_shift, - ) + df = self.normalize_df(df, query_object) - if self.enforce_numerical_metrics: - self.df_metrics_to_num(df, query_object) + if query_object.time_offsets: + time_offsets = self.processing_time_offsets(df, query_object) + df = time_offsets["df"] + queries = time_offsets["queries"] + + query += ";\n\n".join(queries) + query += ";\n\n" - df.replace([np.inf, -np.inf], np.nan, inplace=True) df = query_object.exec_post_processing(df) - return { - "query": result.query, - "status": result.status, - "error_message": result.error_message, - "df": df, - } + result.df = df + result.query = query + return result @staticmethod def df_metrics_to_num(df: pd.DataFrame, query_object: QueryObject) -> None: @@ -308,47 +437,16 @@ class QueryContext: ) return annotation_data - def get_df_payload( # pylint: disable=too-many-statements,too-many-locals + def get_df_payload( self, query_obj: QueryObject, force_cached: Optional[bool] = False, ) -> Dict[str, Any]: """Handles caching around the df payload retrieval""" cache_key = self.query_cache_key(query_obj) - logger.info("Cache key: %s", cache_key) - is_loaded = False - stacktrace = None - df = pd.DataFrame() - cache_value = None - status = None - query = "" - annotation_data = {} - error_message = None - if cache_key and cache_manager.data_cache and not self.force: - cache_value = cache_manager.data_cache.get(cache_key) - if cache_value: - stats_logger.incr("loading_from_cache") - try: - df = cache_value["df"] - query = cache_value["query"] - annotation_data = cache_value.get("annotation_data", {}) - status = QueryStatus.SUCCESS - is_loaded = True - stats_logger.incr("loaded_from_cache") - except KeyError as ex: - logger.exception(ex) - logger.error( - "Error reading cache: %s", - error_msg_from_exception(ex), - exc_info=True, - ) - logger.info("Serving 
from cache") + cache = QueryCacheManager.get( + cache_key, CacheRegion.DATA, self.force, force_cached, + ) - if force_cached and not is_loaded: - logger.warning( - "force_cached (QueryContext): value not found for key %s", cache_key - ) - raise CacheLoadError("Error loading data from cache") - - if query_obj and not is_loaded: + if query_obj and cache_key and not cache.is_loaded: try: invalid_columns = [ col @@ -365,47 +463,32 @@ class QueryContext: ) ) query_result = self.get_query_result(query_obj) - status = query_result["status"] - query = query_result["query"] - error_message = query_result["error_message"] - df = query_result["df"] annotation_data = self.get_annotation_data(query_obj) - - if status != QueryStatus.FAILED: - stats_logger.incr("loaded_from_source") - if not self.force: - stats_logger.incr("loaded_from_source_without_force") - is_loaded = True - except QueryObjectValidationError as ex: - error_message = str(ex) - status = QueryStatus.FAILED - except Exception as ex: # pylint: disable=broad-except - logger.exception(ex) - if not error_message: - error_message = str(ex) - status = QueryStatus.FAILED - stacktrace = get_stacktrace() - - if is_loaded and cache_key and status != QueryStatus.FAILED: - set_and_log_cache( - cache_manager.data_cache, - cache_key, - {"df": df, "query": query, "annotation_data": annotation_data}, - self.cache_timeout, - self.datasource.uid, + cache.set_query_result( + key=cache_key, + query_result=query_result, + annotation_data=annotation_data, + force_query=self.force, + timeout=self.cache_timeout, + datasource_uid=self.datasource.uid, + region=CacheRegion.DATA, ) + except QueryObjectValidationError as ex: + cache.error_message = str(ex) + cache.status = QueryStatus.FAILED + return { "cache_key": cache_key, - "cached_dttm": cache_value["dttm"] if cache_value is not None else None, + "cached_dttm": cache.cache_dttm, "cache_timeout": self.cache_timeout, - "df": df, - "annotation_data": annotation_data, - "error": error_message, - "is_cached": cache_value is not None, - "query": query, - "status": status, - "stacktrace": stacktrace, - "rowcount": len(df.index), + "df": cache.df, + "annotation_data": cache.annotation_data, + "error": cache.error_message, + "is_cached": cache.is_cached, + "query": cache.query, + "status": cache.status, + "stacktrace": cache.stacktrace, + "rowcount": len(cache.df.index), } def raise_for_access(self) -> None: diff --git a/superset/common/query_object.py b/superset/common/query_object.py index c16c4e7487..8ab4c620f4 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -77,6 +77,8 @@ class QueryObject: granularity: Optional[str] from_dttm: Optional[datetime] to_dttm: Optional[datetime] + inner_from_dttm: Optional[datetime] + inner_to_dttm: Optional[datetime] is_timeseries: bool time_shift: Optional[timedelta] groupby: List[str] @@ -94,6 +96,7 @@ class QueryObject: datasource: Optional[BaseDatasource] result_type: Optional[ChartDataResultType] is_rowcount: bool + time_offsets: List[str] def __init__( self, @@ -125,6 +128,9 @@ class QueryObject: groupby = groupby or [] extras = extras or {} annotation_layers = annotation_layers or [] + self.time_offsets = kwargs.get("time_offsets", []) + self.inner_from_dttm = kwargs.get("inner_from_dttm") + self.inner_to_dttm = kwargs.get("inner_to_dttm") self.is_rowcount = is_rowcount self.datasource = None @@ -268,6 +274,8 @@ class QueryObject: "groupby": self.groupby, "from_dttm": self.from_dttm, "to_dttm": self.to_dttm, + "inner_from_dttm": 
self.inner_from_dttm, + "inner_to_dttm": self.inner_to_dttm, "is_rowcount": self.is_rowcount, "is_timeseries": self.is_timeseries, "metrics": self.metrics, @@ -307,6 +315,8 @@ class QueryObject: cache_dict["time_range"] = self.time_range if self.post_processing: cache_dict["post_processing"] = self.post_processing + if self.time_offsets: + cache_dict["time_offsets"] = self.time_offsets for k in ["from_dttm", "to_dttm"]: del cache_dict[k] diff --git a/superset/common/utils.py b/superset/common/utils.py new file mode 100644 index 0000000000..ab83b84922 --- /dev/null +++ b/superset/common/utils.py @@ -0,0 +1,179 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import Any, Dict, Optional + +from flask_caching import Cache +from pandas import DataFrame + +from superset import app +from superset.constants import CacheRegion +from superset.exceptions import CacheLoadError +from superset.extensions import cache_manager +from superset.models.helpers import QueryResult +from superset.stats_logger import BaseStatsLogger +from superset.utils.cache import set_and_log_cache +from superset.utils.core import error_msg_from_exception, get_stacktrace, QueryStatus + +config = app.config +stats_logger: BaseStatsLogger = config["STATS_LOGGER"] +logger = logging.getLogger(__name__) + +_cache: Dict[CacheRegion, Cache] = { + CacheRegion.DEFAULT: cache_manager.cache, + CacheRegion.DATA: cache_manager.data_cache, +} + + +class QueryCacheManager: + """ + Class for manage query-cache getting and setting + """ + + # pylint: disable=too-many-instance-attributes,too-many-arguments + def __init__( + self, + df: DataFrame = DataFrame(), + query: str = "", + annotation_data: Optional[Dict[str, Any]] = None, + status: Optional[str] = None, + error_message: Optional[str] = None, + is_loaded: bool = False, + stacktrace: Optional[str] = None, + is_cached: Optional[bool] = None, + cache_dttm: Optional[str] = None, + cache_value: Optional[Dict[str, Any]] = None, + ) -> None: + self.df = df + self.query = query + self.annotation_data = {} if annotation_data is None else annotation_data + self.status = status + self.error_message = error_message + + self.is_loaded = is_loaded + self.stacktrace = stacktrace + self.is_cached = is_cached + self.cache_dttm = cache_dttm + self.cache_value = cache_value + + # pylint: disable=too-many-arguments + def set_query_result( + self, + key: str, + query_result: QueryResult, + annotation_data: Optional[Dict[str, Any]] = None, + force_query: Optional[bool] = False, + timeout: Optional[int] = None, + datasource_uid: Optional[str] = None, + region: CacheRegion = CacheRegion.DEFAULT, + ) -> None: + """ + Set dataframe of query-result to specific cache region + """ + try: + self.status = query_result.status + self.query = 
query_result.query + self.error_message = query_result.error_message + self.df = query_result.df + self.annotation_data = {} if annotation_data is None else annotation_data + + if self.status != QueryStatus.FAILED: + stats_logger.incr("loaded_from_source") + if not force_query: + stats_logger.incr("loaded_from_source_without_force") + self.is_loaded = True + + value = { + "df": self.df, + "query": self.query, + "annotation_data": self.annotation_data, + } + if self.is_loaded and key and self.status != QueryStatus.FAILED: + self.set( + key=key, + value=value, + timeout=timeout, + datasource_uid=datasource_uid, + region=region, + ) + except Exception as ex: # pylint: disable=broad-except + logger.exception(ex) + if not self.error_message: + self.error_message = str(ex) + self.status = QueryStatus.FAILED + self.stacktrace = get_stacktrace() + + @classmethod + def get( + cls, + key: Optional[str], + region: CacheRegion = CacheRegion.DEFAULT, + force_query: Optional[bool] = False, + force_cached: Optional[bool] = False, + ) -> "QueryCacheManager": + """ + Initialize QueryCacheManager by query-cache key + """ + query_cache = cls() + if not key or not _cache[region] or force_query: + return query_cache + + cache_value = _cache[region].get(key) + if cache_value: + logger.info("Cache key: %s", key) + stats_logger.incr("loading_from_cache") + try: + query_cache.df = cache_value["df"] + query_cache.query = cache_value["query"] + query_cache.annotation_data = cache_value.get("annotation_data", {}) + query_cache.status = QueryStatus.SUCCESS + query_cache.is_loaded = True + query_cache.is_cached = cache_value is not None + query_cache.cache_dttm = ( + cache_value["dttm"] if cache_value is not None else None + ) + query_cache.cache_value = cache_value + stats_logger.incr("loaded_from_cache") + except KeyError as ex: + logger.exception(ex) + logger.error( + "Error reading cache: %s", + error_msg_from_exception(ex), + exc_info=True, + ) + logger.info("Serving from cache") + + if force_cached and not query_cache.is_loaded: + logger.warning( + "force_cached (QueryContext): value not found for key %s", key + ) + raise CacheLoadError("Error loading data from cache") + return query_cache + + @staticmethod + def set( + key: Optional[str], + value: Dict[str, Any], + timeout: Optional[int] = None, + datasource_uid: Optional[str] = None, + region: CacheRegion = CacheRegion.DEFAULT, + ) -> None: + """ + set value to specify cache region, proxy for `set_and_log_cache` + """ + if key: + set_and_log_cache(_cache[region], key, value, timeout, datasource_uid) diff --git a/superset/constants.py b/superset/constants.py index 7defc34d09..e6398666e9 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -18,6 +18,8 @@ # ATTENTION: If you change any constants, make sure to also change utils/common.js # string to use when None values *need* to be converted to/from strings +from enum import Enum + NULL_STRING = "" @@ -154,3 +156,20 @@ EXTRA_FORM_DATA_OVERRIDE_KEYS = ( set(EXTRA_FORM_DATA_OVERRIDE_REGULAR_MAPPINGS.values()) | EXTRA_FORM_DATA_OVERRIDE_EXTRA_KEYS ) + + +class PandasAxis(int, Enum): + ROW = 0 + COLUMN = 1 + + +class PandasPostprocessingCompare(str, Enum): + ABS = "absolute" + PCT = "percentage" + RAT = "ratio" + + +class CacheRegion(str, Enum): + DEFAULT = "default" + DATA = "data" + THUMBNAIL = "thumbnail" diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py index 40b7103c1b..ea5884a01b 100644 --- a/superset/examples/birth_names.py +++ b/superset/examples/birth_names.py @@ 
-528,9 +528,9 @@ def create_dashboard(slices: List[Slice]) -> Dashboard: } }""" ) + # pylint: disable=line-too-long pos = json.loads( textwrap.dedent( - # pylint: disable=line-too-long """\ { "CHART-6GdlekVise": { @@ -800,9 +800,10 @@ def create_dashboard(slices: List[Slice]) -> Dashboard: "type": "ROW" } } - """ # pylint: enable=line-too-long + """ ) ) + # pylint: enable=line-too-long # dashboard v2 doesn't allow add markup slice dash.slices = [slc for slc in slices if slc.viz_type != "markup"] update_slice_ids(pos, dash.slices) diff --git a/superset/utils/core.py b/superset/utils/core.py index 0c4c18fe02..9e37b7caeb 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -115,6 +115,8 @@ logger = logging.getLogger(__name__) DTTM_ALIAS = "__timestamp" +TIME_COMPARISION = "__" + JS_MAX_INTEGER = 9007199254740991 # Largest int Java Script can handle 2^53-1 InputType = TypeVar("InputType") diff --git a/superset/utils/date_parser.py b/superset/utils/date_parser.py index 9bdf1d3026..51e5b8a8d8 100644 --- a/superset/utils/date_parser.py +++ b/superset/utils/date_parser.py @@ -19,7 +19,7 @@ import logging import re from datetime import datetime, timedelta from time import struct_time -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import parsedatetime from dateutil.parser import parse @@ -40,8 +40,9 @@ from pyparsing import ( ) from superset.charts.commands.exceptions import ( + TimeDeltaAmbiguousError, + TimeRangeAmbiguousError, TimeRangeParseFailError, - TimeRangeUnclearError, ) from superset.utils.memoized import memoized @@ -51,33 +52,10 @@ logger = logging.getLogger(__name__) def parse_human_datetime(human_readable: str) -> datetime: - """ - Returns ``datetime.datetime`` from human readable strings - - >>> from datetime import date, timedelta - >>> from dateutil.relativedelta import relativedelta - >>> parse_human_datetime('2015-04-03') - datetime.datetime(2015, 4, 3, 0, 0) - >>> parse_human_datetime('2/3/1969') - datetime.datetime(1969, 2, 3, 0, 0) - >>> parse_human_datetime('now') <= datetime.now() - True - >>> parse_human_datetime('yesterday') <= datetime.now() - True - >>> date.today() - timedelta(1) == parse_human_datetime('yesterday').date() - True - >>> year_ago_1 = parse_human_datetime('one year ago').date() - >>> year_ago_2 = (datetime.now() - relativedelta(years=1)).date() - >>> year_ago_1 == year_ago_2 - True - >>> year_after_1 = parse_human_datetime('2 years after').date() - >>> year_after_2 = (datetime.now() + relativedelta(years=2)).date() - >>> year_after_1 == year_after_2 - True - """ + """ Returns ``datetime.datetime`` from human readable strings """ x_periods = r"^\s*([0-9]+)\s+(second|minute|hour|day|week|month|quarter|year)s?\s*$" if re.search(x_periods, human_readable, re.IGNORECASE): - raise TimeRangeUnclearError(human_readable) + raise TimeRangeAmbiguousError(human_readable) try: default = datetime(year=datetime.now().year, month=1, day=1) dttm = parse(human_readable, default=default) @@ -95,6 +73,18 @@ def parse_human_datetime(human_readable: str) -> datetime: return dttm +def normalize_time_delta(human_readable: str) -> Dict[str, int]: + x_unit = r"^\s*([0-9]+)\s+(second|minute|hour|day|week|month|quarter|year)s?\s+(ago|later)*$" # pylint: disable=line-too-long + matched = re.match(x_unit, human_readable, re.IGNORECASE) + if not matched: + raise TimeDeltaAmbiguousError(human_readable) + + key = matched[2] + "s" + value = int(matched[1]) + value = -value if matched[3] == "ago" else value + return {key: value} + + 
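For reference, a minimal sketch (assumed behaviour, not part of the patch hunks) of what the new normalize_time_delta helper returns for the offset strings accepted by time_offsets, following the regex above:

from superset.utils.date_parser import normalize_time_delta

# the unit is pluralized ("year" -> "years") and "ago" negates the value,
# which is the keyword form pandas.DateOffset expects
assert normalize_time_delta("1 year ago") == {"years": -1}
assert normalize_time_delta("52 weeks later") == {"weeks": 52}
# strings without an explicit "ago"/"later" do not match the regex and raise
# TimeDeltaAmbiguousError, e.g. normalize_time_delta("1 year")
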
def dttm_from_timetuple(date_: struct_time) -> datetime: return datetime( date_.tm_year, @@ -106,6 +96,16 @@ def dttm_from_timetuple(date_: struct_time) -> datetime: ) +def get_past_or_future( + human_readable: Optional[str], source_time: Optional[datetime] = None, +) -> datetime: + cal = parsedatetime.Calendar() + source_dttm = dttm_from_timetuple( + source_time.timetuple() if source_time else datetime.now().timetuple() + ) + return dttm_from_timetuple(cal.parse(human_readable or "", source_dttm)[0]) + + def parse_human_timedelta( human_readable: Optional[str], source_time: Optional[datetime] = None, ) -> timedelta: @@ -115,12 +115,10 @@ def parse_human_timedelta( >>> parse_human_timedelta('1 day') == timedelta(days=1) True """ - cal = parsedatetime.Calendar() source_dttm = dttm_from_timetuple( source_time.timetuple() if source_time else datetime.now().timetuple() ) - modified_dttm = dttm_from_timetuple(cal.parse(human_readable or "", source_dttm)[0]) - return modified_dttm - source_dttm + return get_past_or_future(human_readable, source_time) - source_dttm def parse_past_timedelta( diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py index e22a9c744e..75daba5881 100644 --- a/superset/utils/pandas_postprocessing.py +++ b/superset/utils/pandas_postprocessing.py @@ -21,16 +21,18 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import geohash as geohash_lib import numpy as np +import pandas as pd from flask_babel import gettext as _ from geopy.point import Point from pandas import DataFrame, NamedAgg, Series, Timestamp -from superset.constants import NULL_STRING +from superset.constants import NULL_STRING, PandasAxis, PandasPostprocessingCompare from superset.exceptions import QueryObjectValidationError from superset.utils.core import ( DTTM_ALIAS, PostProcessingBoxplotWhiskerType, PostProcessingContributionOrientation, + TIME_COMPARISION, ) NUMPY_FUNCTIONS = { @@ -327,7 +329,7 @@ def rolling( # pylint: disable=too-many-arguments df: DataFrame, columns: Dict[str, str], rolling_type: str, - window: int, + window: Optional[int] = None, rolling_type_options: Optional[Dict[str, Any]] = None, center: bool = False, win_type: Optional[str] = None, @@ -357,8 +359,10 @@ def rolling( # pylint: disable=too-many-arguments rolling_type_options = rolling_type_options or {} df_rolling = df[columns.keys()] kwargs: Dict[str, Union[str, int]] = {} - if not window: + if window is None: raise QueryObjectValidationError(_("Undefined window for rolling operation")) + if window == 0: + raise QueryObjectValidationError(_("Window must be > 0")) kwargs["window"] = window if min_periods is not None: @@ -425,9 +429,14 @@ def select( @validate_column_args("columns") -def diff(df: DataFrame, columns: Dict[str, str], periods: int = 1,) -> DataFrame: +def diff( + df: DataFrame, + columns: Dict[str, str], + periods: int = 1, + axis: PandasAxis = PandasAxis.ROW, +) -> DataFrame: """ - Calculate row-by-row difference for select columns. + Calculate row-by-row or column-by-column difference for select columns. :param df: DataFrame on which the diff will be based. :param columns: columns on which to perform diff, mapping source column to @@ -436,14 +445,69 @@ def diff(df: DataFrame, columns: Dict[str, str], periods: int = 1,) -> DataFrame on diff values calculated from `y`, leaving the original column `y` unchanged. :param periods: periods to shift for calculating difference. + :param axis: 0 for row, 1 for column. default 0. 
:return: DataFrame with diffed columns :raises QueryObjectValidationError: If the request in incorrect """ df_diff = df[columns.keys()] - df_diff = df_diff.diff(periods=periods) + df_diff = df_diff.diff(periods=periods, axis=axis) return _append_columns(df, df_diff, columns) +# pylint: disable=too-many-arguments +@validate_column_args("source_columns", "compare_columns") +def compare( + df: DataFrame, + source_columns: List[str], + compare_columns: List[str], + compare_type: Optional[PandasPostprocessingCompare], + drop_original_columns: Optional[bool] = False, + precision: Optional[int] = 4, +) -> DataFrame: + """ + Calculate column-by-column changing for select columns. + + :param df: DataFrame on which the compare will be based. + :param source_columns: Main query columns + :param compare_columns: Columns being compared + :param compare_type: Type of compare. Choice of `absolute`, `percentage` or `ratio` + :param drop_original_columns: Whether to remove the source columns and + compare columns. + :param precision: Round a change rate to a variable number of decimal places. + :return: DataFrame with compared columns. + :raises QueryObjectValidationError: If the request in incorrect. + """ + if len(source_columns) != len(compare_columns): + raise QueryObjectValidationError( + _("`compare_columns` must have the same length as `source_columns`.") + ) + if compare_type not in tuple(PandasPostprocessingCompare): + raise QueryObjectValidationError( + _("`compare_type` must be `absolute`, `percentage` or `ratio`") + ) + if len(source_columns) == 0: + return df + + for s_col, c_col in zip(source_columns, compare_columns): + if compare_type == PandasPostprocessingCompare.ABS: + diff_series = df[s_col] - df[c_col] + elif compare_type == PandasPostprocessingCompare.PCT: + diff_series = ( + ((df[s_col] - df[c_col]) / df[s_col]).astype(float).round(precision) + ) + else: + # compare_type == "ratio" + diff_series = (df[s_col] / df[c_col]).astype(float).round(precision) + diff_df = diff_series.to_frame( + name=TIME_COMPARISION.join([compare_type, s_col, c_col]) + ) + df = pd.concat([df, diff_df], axis=1) + + if drop_original_columns: + df = df.drop(source_columns + compare_columns, axis=1) + return df + + @validate_column_args("columns") def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame: """ diff --git a/superset/views/api.py b/superset/views/api.py index 80cb5b6664..0ef973cdbe 100644 --- a/superset/views/api.py +++ b/superset/views/api.py @@ -25,8 +25,8 @@ from flask_appbuilder.security.decorators import has_access_api from superset import db, event_logger from superset.charts.commands.exceptions import ( + TimeRangeAmbiguousError, TimeRangeParseFailError, - TimeRangeUnclearError, ) from superset.common.query_context import QueryContext from superset.legacy import update_time_range @@ -97,6 +97,6 @@ class Api(BaseSupersetView): "timeRange": time_range, } return self.json_response({"result": result}) - except (ValueError, TimeRangeParseFailError, TimeRangeUnclearError) as error: + except (ValueError, TimeRangeParseFailError, TimeRangeAmbiguousError) as error: error_msg = {"message": f"Unexpected time range: {error}"} return self.json_response(error_msg, 400) diff --git a/tests/integration_tests/fixtures/dataframes.py b/tests/integration_tests/fixtures/dataframes.py index ab50425d73..28bc32fade 100644 --- a/tests/integration_tests/fixtures/dataframes.py +++ b/tests/integration_tests/fixtures/dataframes.py @@ -130,6 +130,15 @@ timeseries_df = DataFrame( data={"label": ["x", "y", 
"z", "q"], "y": [1.0, 2.0, 3.0, 4.0]}, ) +timeseries_df2 = DataFrame( + index=to_datetime(["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"]), + data={ + "label": ["x", "y", "z", "q"], + "y": [2.0, 2.0, 2.0, 2.0], + "z": [2.0, 4.0, 10.0, 8.0], + }, +) + lonlat_df = DataFrame( { "city": ["New York City", "Sydney"], diff --git a/tests/integration_tests/fixtures/query_context.py b/tests/integration_tests/fixtures/query_context.py index 268e403875..ff79289748 100644 --- a/tests/integration_tests/fixtures/query_context.py +++ b/tests/integration_tests/fixtures/query_context.py @@ -195,7 +195,7 @@ POSTPROCESSING_OPERATIONS = { def get_query_object( - query_name: str, add_postprocessing_operations: bool + query_name: str, add_postprocessing_operations: bool, add_time_offsets: bool, ) -> Dict[str, Any]: if query_name not in QUERY_OBJECTS: raise Exception(f"QueryObject fixture not defined for datasource: {query_name}") @@ -212,6 +212,9 @@ def get_query_object( query_object = copy.deepcopy(obj) if add_postprocessing_operations: query_object["post_processing"] = _get_postprocessing_operation(query_name) + if add_time_offsets: + query_object["time_offsets"] = ["1 year ago"] + return query_object @@ -224,7 +227,9 @@ def _get_postprocessing_operation(query_name: str) -> List[Dict[str, Any]]: def get_query_context( - query_name: str, add_postprocessing_operations: bool = False, + query_name: str, + add_postprocessing_operations: bool = False, + add_time_offsets: bool = False, ) -> Dict[str, Any]: """ Create a request payload for retrieving a QueryContext object via the @@ -236,11 +241,16 @@ def get_query_context( :param datasource_id: id of datasource to query. :param datasource_type: type of datasource to query. :param add_postprocessing_operations: Add post-processing operations to QueryObject + :param add_time_offsets: Add time offsets to QueryObject(advanced analytics) :return: Request payload """ table_name = query_name.split(":")[0] table = get_table_by_name(table_name) return { "datasource": {"id": table.id, "type": table.type}, - "queries": [get_query_object(query_name, add_postprocessing_operations)], + "queries": [ + get_query_object( + query_name, add_postprocessing_operations, add_time_offsets, + ) + ], } diff --git a/tests/integration_tests/pandas_postprocessing_tests.py b/tests/integration_tests/pandas_postprocessing_tests.py index 3f54f7e79d..5cb7d55113 100644 --- a/tests/integration_tests/pandas_postprocessing_tests.py +++ b/tests/integration_tests/pandas_postprocessing_tests.py @@ -38,6 +38,7 @@ from .fixtures.dataframes import ( names_df, timeseries_df, prophet_df, + timeseries_df2, ) AGGREGATES_SINGLE = {"idx_nulls": {"operator": "sum"}} @@ -422,6 +423,64 @@ class TestPostProcessing(SupersetTestCase): columns={"abc": "abc"}, ) + # diff by columns + post_df = proc.diff(df=timeseries_df2, columns={"y": "y", "z": "z"}, axis=1) + self.assertListEqual(post_df.columns.tolist(), ["label", "y", "z"]) + self.assertListEqual(series_to_list(post_df["z"]), [0.0, 2.0, 8.0, 6.0]) + + def test_compare(self): + # `absolute` comparison + post_df = proc.compare( + df=timeseries_df2, + source_columns=["y"], + compare_columns=["z"], + compare_type="absolute", + ) + self.assertListEqual( + post_df.columns.tolist(), ["label", "y", "z", "absolute__y__z",] + ) + self.assertListEqual( + series_to_list(post_df["absolute__y__z"]), [0.0, -2.0, -8.0, -6.0], + ) + + # drop original columns + post_df = proc.compare( + df=timeseries_df2, + source_columns=["y"], + compare_columns=["z"], + 
compare_type="absolute", + drop_original_columns=True, + ) + self.assertListEqual(post_df.columns.tolist(), ["label", "absolute__y__z",]) + + # `percentage` comparison + post_df = proc.compare( + df=timeseries_df2, + source_columns=["y"], + compare_columns=["z"], + compare_type="percentage", + ) + self.assertListEqual( + post_df.columns.tolist(), ["label", "y", "z", "percentage__y__z",] + ) + self.assertListEqual( + series_to_list(post_df["percentage__y__z"]), [0.0, -1.0, -4.0, -3], + ) + + # `ratio` comparison + post_df = proc.compare( + df=timeseries_df2, + source_columns=["y"], + compare_columns=["z"], + compare_type="ratio", + ) + self.assertListEqual( + post_df.columns.tolist(), ["label", "y", "z", "ratio__y__z",] + ) + self.assertListEqual( + series_to_list(post_df["ratio__y__z"]), [1.0, 0.5, 0.2, 0.25], + ) + def test_cum(self): # create new column (cumsum) post_df = proc.cum(df=timeseries_df, columns={"y": "y2"}, operator="sum",) diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index 9c04c62307..3d7821808b 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -222,6 +222,20 @@ class TestQueryContext(SupersetTestCase): cache_key = query_context.query_cache_key(query_object) self.assertNotEqual(cache_key_original, cache_key) + def test_query_cache_key_changes_when_time_offsets_is_updated(self): + self.login(username="admin") + payload = get_query_context("birth_names", add_time_offsets=True) + + query_context = ChartDataQueryContextSchema().load(payload) + query_object = query_context.queries[0] + cache_key_original = query_context.query_cache_key(query_object) + + payload["queries"][0]["time_offsets"].pop() + query_context = ChartDataQueryContextSchema().load(payload) + query_object = query_context.queries[0] + cache_key = query_context.query_cache_key(query_object) + self.assertNotEqual(cache_key_original, cache_key) + def test_query_context_time_range_endpoints(self): """ Ensure that time_range_endpoints are populated automatically when missing @@ -476,3 +490,92 @@ class TestQueryContext(SupersetTestCase): responses = query_context.get_payload() new_cache_key = responses["queries"][0]["cache_key"] self.assertEqual(orig_cache_key, new_cache_key) + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_time_offsets_in_query_object(self): + """ + Ensure that time_offsets can generate the correct query + """ + self.login(username="admin") + payload = get_query_context("birth_names") + payload["queries"][0]["metrics"] = ["sum__num"] + payload["queries"][0]["groupby"] = ["name"] + payload["queries"][0]["is_timeseries"] = True + payload["queries"][0]["timeseries_limit"] = 5 + payload["queries"][0]["time_offsets"] = ["1 year ago", "1 year later"] + payload["queries"][0]["time_range"] = "1990 : 1991" + query_context = ChartDataQueryContextSchema().load(payload) + responses = query_context.get_payload() + self.assertEqual( + responses["queries"][0]["colnames"], + [ + "__timestamp", + "name", + "sum__num", + "sum__num__1 year ago", + "sum__num__1 year later", + ], + ) + + sqls = [ + sql for sql in responses["queries"][0]["query"].split(";") if sql.strip() + ] + self.assertEqual(len(sqls), 3) + # 1 year ago + assert re.search(r"1989-01-01.+1990-01-01", sqls[1], re.S) + assert re.search(r"1990-01-01.+1991-01-01", sqls[1], re.S) + + # # 1 year later + assert re.search(r"1991-01-01.+1992-01-01", sqls[2], re.S) + assert 
re.search(r"1990-01-01.+1991-01-01", sqls[2], re.S) + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_processing_time_offsets_cache(self): + """ + Ensure that time_offsets can generate the correct query + """ + self.login(username="admin") + payload = get_query_context("birth_names") + payload["queries"][0]["metrics"] = ["sum__num"] + payload["queries"][0]["groupby"] = ["name"] + payload["queries"][0]["is_timeseries"] = True + payload["queries"][0]["timeseries_limit"] = 5 + payload["queries"][0]["time_offsets"] = [] + payload["queries"][0]["time_range"] = "1990 : 1991" + query_context = ChartDataQueryContextSchema().load(payload) + query_object = query_context.queries[0] + query_result = query_context.get_query_result(query_object) + # get main query dataframe + df = query_result.df + + payload["queries"][0]["time_offsets"] = ["1 year ago", "1 year later"] + query_context = ChartDataQueryContextSchema().load(payload) + query_object = query_context.queries[0] + # query without cache + query_context.processing_time_offsets(df, query_object) + # query with cache + rv = query_context.processing_time_offsets(df, query_object) + cache_keys = rv["cache_keys"] + cache_keys__1_year_ago = cache_keys[0] + cache_keys__1_year_later = cache_keys[1] + self.assertIsNotNone(cache_keys__1_year_ago) + self.assertIsNotNone(cache_keys__1_year_later) + self.assertNotEqual(cache_keys__1_year_ago, cache_keys__1_year_later) + + # swap offsets + payload["queries"][0]["time_offsets"] = ["1 year later", "1 year ago"] + query_context = ChartDataQueryContextSchema().load(payload) + query_object = query_context.queries[0] + rv = query_context.processing_time_offsets(df, query_object) + cache_keys = rv["cache_keys"] + self.assertEqual(cache_keys__1_year_ago, cache_keys[1]) + self.assertEqual(cache_keys__1_year_later, cache_keys[0]) + + # remove all offsets + payload["queries"][0]["time_offsets"] = [] + query_context = ChartDataQueryContextSchema().load(payload) + query_object = query_context.queries[0] + rv = query_context.processing_time_offsets(df, query_object,) + self.assertIs(rv["df"], df) + self.assertEqual(rv["queries"], []) + self.assertEqual(rv["cache_keys"], []) diff --git a/tests/integration_tests/utils/date_parser_tests.py b/tests/integration_tests/utils/date_parser_tests.py index f04f0da45f..4cf979e83c 100644 --- a/tests/integration_tests/utils/date_parser_tests.py +++ b/tests/integration_tests/utils/date_parser_tests.py @@ -14,16 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from datetime import datetime, timedelta +from datetime import date, datetime, timedelta from unittest.mock import patch +from dateutil.relativedelta import relativedelta + from superset.charts.commands.exceptions import ( + TimeRangeAmbiguousError, TimeRangeParseFailError, - TimeRangeUnclearError, ) from superset.utils.date_parser import ( DateRangeMigration, datetime_eval, + get_past_or_future, get_since_until, parse_human_datetime, parse_human_timedelta, @@ -288,16 +291,48 @@ class TestDateParser(SupersetTestCase): self.assertEqual(parse_past_timedelta("52 weeks"), timedelta(364)) self.assertEqual(parse_past_timedelta("1 month"), timedelta(31)) + def test_get_past_or_future(self): + # 2020 is a leap year + dttm = datetime(2020, 2, 29) + self.assertEqual(get_past_or_future("1 year", dttm), datetime(2021, 2, 28)) + self.assertEqual(get_past_or_future("-1 year", dttm), datetime(2019, 2, 28)) + self.assertEqual(get_past_or_future("1 month", dttm), datetime(2020, 3, 29)) + self.assertEqual(get_past_or_future("3 month", dttm), datetime(2020, 5, 29)) + def test_parse_human_datetime(self): - with self.assertRaises(TimeRangeUnclearError): + with self.assertRaises(TimeRangeAmbiguousError): parse_human_datetime(" 2 days ") - with self.assertRaises(TimeRangeUnclearError): + with self.assertRaises(TimeRangeAmbiguousError): parse_human_datetime("2 day") with self.assertRaises(TimeRangeParseFailError): parse_human_datetime("xxxxxxx") + self.assertEqual(parse_human_datetime("2015-04-03"), datetime(2015, 4, 3, 0, 0)) + + self.assertEqual( + parse_human_datetime("2/3/1969"), datetime(1969, 2, 3, 0, 0), + ) + + self.assertLessEqual(parse_human_datetime("now"), datetime.now()) + + self.assertLess(parse_human_datetime("yesterday"), datetime.now()) + + self.assertEqual( + date.today() - timedelta(1), parse_human_datetime("yesterday").date() + ) + + self.assertEqual( + parse_human_datetime("one year ago").date(), + (datetime.now() - relativedelta(years=1)).date(), + ) + + self.assertEqual( + parse_human_datetime("2 years after").date(), + (datetime.now() + relativedelta(years=2)).date(), + ) + def test_DateRangeMigration(self): params = '{"time_range": " 8 days : 2020-03-10T00:00:00"}' self.assertRegex(params, DateRangeMigration.x_dateunit_in_since)
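
Taken together, the changes above let a chart-data query request shifted copies of its metrics and then post-process them. A hedged sketch of such a payload, reusing the metric and column names from the birth_names tests (the exact request shape is illustrative; only time_offsets, the "compare" operation and the "sum__num__1 year ago" column name come directly from this patch):

# Illustrative query object for POST /api/v1/chart/data
query_object = {
    "metrics": ["sum__num"],
    "groupby": ["name"],
    "is_timeseries": True,
    "time_range": "1990 : 1991",
    # one extra query, shifted back a year and left-joined on __timestamp
    "time_offsets": ["1 year ago"],
    "post_processing": [
        {
            "operation": "compare",
            "options": {
                "source_columns": ["sum__num"],
                "compare_columns": ["sum__num__1 year ago"],
                "compare_type": "percentage",
                "drop_original_columns": False,
            },
        }
    ],
}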
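
And a minimal, runnable sketch of the new compare post-processing function itself, mirroring the expectations in test_compare (the comparison column is named <compare_type>__<source>__<compare> via the TIME_COMPARISION separator):

import pandas as pd
from superset.utils import pandas_postprocessing as proc

df = pd.DataFrame({"y": [2.0, 2.0, 2.0, 2.0], "z": [2.0, 4.0, 10.0, 8.0]})
# "absolute" appends y - z as a new column; the originals are kept because
# drop_original_columns defaults to False
out = proc.compare(
    df,
    source_columns=["y"],
    compare_columns=["z"],
    compare_type="absolute",
)
assert list(out.columns) == ["y", "z", "absolute__y__z"]
assert list(out["absolute__y__z"]) == [0.0, -2.0, -8.0, -6.0]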