From 7af8b2b3f822d74094f09609cd7e740415f90354 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Mon, 20 Jul 2020 18:46:51 +0300 Subject: [PATCH] feat: add optional prophet forecasting functionality to chart data api (#10324) * feat: add prophet post processing operation * add tests * lint * whitespace * remove whitespace * address comments * add note to UPDATING.md --- UPDATING.md | 2 + setup.py | 1 + superset/charts/schemas.py | 98 ++++++++++++---- superset/utils/pandas_postprocessing.py | 142 ++++++++++++++++++++++++ tests/charts/api_tests.py | 39 ++++++- tests/fixtures/dataframes.py | 15 ++- tests/pandas_postprocessing_tests.py | 83 +++++++++++++- 7 files changed, 357 insertions(+), 23 deletions(-) diff --git a/UPDATING.md b/UPDATING.md index 5b4369114e..420cb03106 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -23,6 +23,8 @@ assists people when migrating to a new version. ## Next +* [10324](https://github.com/apache/incubator-superset/pull/10324): Facebook Prophet has been introduced as an optional dependency to add support for timeseries forecasting in the chart data API. To enable this feature, install Superset with the optional dependency `prophet` or directly `pip install fbprophet`. + * [10320](https://github.com/apache/incubator-superset/pull/10320): References to blacklst/whitelist language have been replaced with more appropriate alternatives. All configs refencing containing `WHITE`/`BLACK` have been replaced with `ALLOW`/`DENY`. Affected config variables that need to be updated: `TIME_GRAIN_BLACKLIST`, `VIZ_TYPE_BLACKLIST`, `DRUID_DATA_SOURCE_BLACKLIST`. * [9964](https://github.com/apache/incubator-superset/pull/9964): Breaking change on Flask-AppBuilder 3. If you're using OAuth, find out what needs to be changed [here](https://github.com/dpgaspar/Flask-AppBuilder/blob/master/README.rst#change-log). diff --git a/setup.py b/setup.py index 7a6c5d9f5c..58c22a1709 100644 --- a/setup.py +++ b/setup.py @@ -124,6 +124,7 @@ setup( "cockroachdb": ["cockroachdb==0.3.3"], "thumbnails": ["Pillow>=7.0.0, <8.0.0"], "excel": ["xlrd>=1.2.0, <1.3"], + "prophet": ["fbprophet>=0.6, <0.7"], }, python_requires="~=3.6", author="Apache Software Foundation", diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 9ccd0f20b9..39bda90b06 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -101,6 +101,26 @@ openapi_spec_methods_override = { } +TIME_GRAINS = ( + "PT1S", + "PT1M", + "PT5M", + "PT10M", + "PT15M", + "PT0.5H", + "PT1H", + "P1D", + "P1W", + "P1M", + "P0.25Y", + "P1Y", + "1969-12-28T00:00:00Z/P1W", # Week starting Sunday + "1969-12-29T00:00:00Z/P1W", # Week starting Monday + "P1W/1970-01-03T00:00:00Z", # Week ending Saturday + "P1W/1970-01-04T00:00:00Z", # Week ending Sunday +) + + class ChartPostSchema(Schema): """ Schema to add a new chart. @@ -423,6 +443,62 @@ class ChartDataContributionOptionsSchema(ChartDataPostProcessingOperationOptions ) +class ChartDataProphetOptionsSchema(ChartDataPostProcessingOperationOptionsSchema): + """ + Prophet operation config. + """ + + time_grain = fields.String( + description="Time grain used to specify time period increments in prediction. " + "Supports [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) " + "durations.", + validate=validate.OneOf(choices=TIME_GRAINS), + example="P1D", + required=True, + ) + periods = fields.Integer( + descrption="Time periods (in units of `time_grain`) to predict into the future", + min=1, + example=7, + required=True, + ) + confidence_interval = fields.Float( + description="Width of predicted confidence interval", + validate=[ + Range( + min=0, + max=1, + min_inclusive=False, + max_inclusive=False, + error=_("`confidence_interval` must be between 0 and 1 (exclusive)"), + ) + ], + example=0.8, + required=True, + ) + yearly_seasonality = fields.Raw( + # TODO: add correct union type once supported by Marshmallow + description="Should yearly seasonality be applied. " + "An integer value will specify Fourier order of seasonality, `None` will " + "automatically detect seasonality.", + example=False, + ) + weekly_seasonality = fields.Raw( + # TODO: add correct union type once supported by Marshmallow + description="Should weekly seasonality be applied. " + "An integer value will specify Fourier order of seasonality, `None` will " + "automatically detect seasonality.", + example=False, + ) + monthly_seasonality = fields.Raw( + # TODO: add correct union type once supported by Marshmallow + description="Should monthly seasonality be applied. " + "An integer value will specify Fourier order of seasonality, `None` will " + "automatically detect seasonality.", + example=False, + ) + + class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema): """ Pivot operation config. @@ -534,6 +610,7 @@ class ChartDataPostProcessingOperationSchema(Schema): "geohash_decode", "geohash_encode", "pivot", + "prophet", "rolling", "select", "sort", @@ -613,26 +690,7 @@ class ChartDataExtrasSchema(Schema): description="To what level of granularity should the temporal column be " "aggregated. Supports " "[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) durations.", - validate=validate.OneOf( - choices=( - "PT1S", - "PT1M", - "PT5M", - "PT10M", - "PT15M", - "PT0.5H", - "PT1H", - "P1D", - "P1W", - "P1M", - "P0.25Y", - "P1Y", - "1969-12-28T00:00:00Z/P1W", # Week starting Sunday - "1969-12-29T00:00:00Z/P1W", # Week starting Monday - "P1W/1970-01-03T00:00:00Z", # Week ending Saturday - "P1W/1970-01-04T00:00:00Z", # Week ending Sunday - ), - ), + validate=validate.OneOf(choices=TIME_GRAINS), example="P1D", allow_none=True, ) diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py index e71df74717..73336ebdb7 100644 --- a/superset/utils/pandas_postprocessing.py +++ b/superset/utils/pandas_postprocessing.py @@ -72,6 +72,25 @@ ALLOWLIST_CUMULATIVE_FUNCTIONS = ( "cumsum", ) +PROPHET_TIME_GRAIN_MAP = { + "PT1S": "S", + "PT1M": "min", + "PT5M": "5min", + "PT10M": "10min", + "PT15M": "15min", + "PT0.5H": "30min", + "PT1H": "H", + "P1D": "D", + "P1W": "W", + "P1M": "M", + "P0.25Y": "Q", + "P1Y": "A", + "1969-12-28T00:00:00Z/P1W": "W", + "1969-12-29T00:00:00Z/P1W": "W", + "P1W/1970-01-03T00:00:00Z": "W", + "P1W/1970-01-04T00:00:00Z": "W", +} + def _flatten_column_after_pivot( column: Union[str, Tuple[str, ...]], aggregates: Dict[str, Dict[str, Any]] @@ -544,3 +563,126 @@ def contribution( if temporal_series is not None: contribution_df.insert(0, DTTM_ALIAS, temporal_series) return contribution_df + + +def _prophet_parse_seasonality( + input_value: Optional[Union[bool, int]] +) -> Union[bool, str, int]: + if input_value is None: + return "auto" + if isinstance(input_value, bool): + return input_value + try: + return int(input_value) + except ValueError: + return input_value + + +def _prophet_fit_and_predict( # pylint: disable=too-many-arguments + df: DataFrame, + confidence_interval: float, + yearly_seasonality: Union[bool, str, int], + weekly_seasonality: Union[bool, str, int], + daily_seasonality: Union[bool, str, int], + periods: int, + freq: str, +) -> DataFrame: + """ + Fit a prophet model and return a DataFrame with predicted results. + """ + try: + from fbprophet import Prophet # pylint: disable=import-error + except ModuleNotFoundError: + raise QueryObjectValidationError(_("`fbprophet` package not installed")) + model = Prophet( + interval_width=confidence_interval, + yearly_seasonality=yearly_seasonality, + weekly_seasonality=weekly_seasonality, + daily_seasonality=daily_seasonality, + ) + model.fit(df) + future = model.make_future_dataframe(periods=periods, freq=freq) + forecast = model.predict(future)[["ds", "yhat", "yhat_lower", "yhat_upper"]] + return forecast.join(df.set_index("ds"), on="ds").set_index(["ds"]) + + +def prophet( # pylint: disable=too-many-arguments + df: DataFrame, + time_grain: str, + periods: int, + confidence_interval: float, + yearly_seasonality: Optional[Union[bool, int]] = None, + weekly_seasonality: Optional[Union[bool, int]] = None, + daily_seasonality: Optional[Union[bool, int]] = None, +) -> DataFrame: + """ + Add forecasts to each series in a timeseries dataframe, along with confidence + intervals for the prediction. For each series, the operation creates three + new columns with the column name suffixed with the following values: + + - `__yhat`: the forecast for the given date + - `__yhat_lower`: the lower bound of the forecast for the given date + - `__yhat_upper`: the upper bound of the forecast for the given date + - `__yhat_upper`: the upper bound of the forecast for the given date + + + :param df: DataFrame containing all-numeric data (temporal column ignored) + :param time_grain: Time grain used to specify time period increments in prediction + :param periods: Time periods (in units of `time_grain`) to predict into the future + :param confidence_interval: Width of predicted confidence interval + :param yearly_seasonality: Should yearly seasonality be applied. + An integer value will specify Fourier order of seasonality. + :param weekly_seasonality: Should weekly seasonality be applied. + An integer value will specify Fourier order of seasonality, `None` will + automatically detect seasonality. + :param daily_seasonality: Should daily seasonality be applied. + An integer value will specify Fourier order of seasonality, `None` will + automatically detect seasonality. + :return: DataFrame with contributions, with temporal column at beginning if present + """ + # validate inputs + if not time_grain: + raise QueryObjectValidationError(_("Time grain missing")) + if time_grain not in PROPHET_TIME_GRAIN_MAP: + raise QueryObjectValidationError( + _("Unsupported time grain: %(time_grain)s", time_grain=time_grain,) + ) + freq = PROPHET_TIME_GRAIN_MAP[time_grain] + # check type at runtime due to marhsmallow schema not being able to handle + # union types + if not periods or periods < 0 or not isinstance(periods, int): + raise QueryObjectValidationError(_("Periods must be a positive integer value")) + if not confidence_interval or confidence_interval <= 0 or confidence_interval >= 1: + raise QueryObjectValidationError( + _("Confidence interval must be between 0 and 1 (exclusive)") + ) + if DTTM_ALIAS not in df.columns: + raise QueryObjectValidationError(_("DataFrame must include temporal column")) + if len(df.columns) < 2: + raise QueryObjectValidationError(_("DataFrame include at least one series")) + + target_df = DataFrame() + for column in [column for column in df.columns if column != DTTM_ALIAS]: + fit_df = _prophet_fit_and_predict( + df=df[[DTTM_ALIAS, column]].rename(columns={DTTM_ALIAS: "ds", column: "y"}), + confidence_interval=confidence_interval, + yearly_seasonality=_prophet_parse_seasonality(yearly_seasonality), + weekly_seasonality=_prophet_parse_seasonality(weekly_seasonality), + daily_seasonality=_prophet_parse_seasonality(daily_seasonality), + periods=periods, + freq=freq, + ) + new_columns = [ + f"{column}__yhat", + f"{column}__yhat_lower", + f"{column}__yhat_upper", + f"{column}", + ] + fit_df.columns = new_columns + if target_df.empty: + target_df = fit_df + else: + for new_column in new_columns: + target_df = target_df.assign(**{new_column: fit_df[new_column]}) + target_df.reset_index(level=0, inplace=True) + return target_df.rename(columns={"ds": DTTM_ALIAS}) diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py index c115b1d9c7..9b42a9f697 100644 --- a/tests/charts/api_tests.py +++ b/tests/charts/api_tests.py @@ -21,8 +21,9 @@ from typing import List, Optional from datetime import datetime from unittest import mock -import prison import humanize +import prison +import pytest from sqlalchemy.sql import func from tests.test_app import app @@ -796,6 +797,42 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin): result = response_payload["result"][0] self.assertEqual(result["rowcount"], 10) + def test_chart_data_prophet(self): + """ + Chart data API: Ensure prophet post transformation works + """ + pytest.importorskip("fbprophet") + self.login(username="admin") + table = self.get_table_by_name("birth_names") + request_payload = get_query_context(table.name, table.id, table.type) + time_grain = "P1Y" + request_payload["queries"][0]["is_timeseries"] = True + request_payload["queries"][0]["groupby"] = [] + request_payload["queries"][0]["extras"] = {"time_grain_sqla": time_grain} + request_payload["queries"][0]["granularity"] = "ds" + request_payload["queries"][0]["post_processing"] = [ + { + "operation": "prophet", + "options": { + "time_grain": time_grain, + "periods": 3, + "confidence_interval": 0.9, + }, + } + ] + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + print(rv.data) + self.assertEqual(rv.status_code, 200) + response_payload = json.loads(rv.data.decode("utf-8")) + result = response_payload["result"][0] + row = result["data"][0] + self.assertIn("__timestamp", row) + self.assertIn("sum__num", row) + self.assertIn("sum__num__yhat", row) + self.assertIn("sum__num__yhat_upper", row) + self.assertIn("sum__num__yhat_lower", row) + self.assertEqual(result["rowcount"], 47) + def test_chart_data_no_data(self): """ Chart data API: Test chart data with empty result diff --git a/tests/fixtures/dataframes.py b/tests/fixtures/dataframes.py index dd01085a18..d932428799 100644 --- a/tests/fixtures/dataframes.py +++ b/tests/fixtures/dataframes.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from datetime import date +from datetime import date, datetime from pandas import DataFrame, to_datetime @@ -133,3 +133,16 @@ lonlat_df = DataFrame( ], } ) + +prophet_df = DataFrame( + { + "__timestamp": [ + datetime(2018, 12, 31), + datetime(2019, 12, 31), + datetime(2020, 12, 31), + datetime(2021, 12, 31), + ], + "a": [1.1, 1, 1.9, 3.15], + "b": [4, 3, 4.1, 3.95], + } +) diff --git a/tests/pandas_postprocessing_tests.py b/tests/pandas_postprocessing_tests.py index ea708349ea..479df423c6 100644 --- a/tests/pandas_postprocessing_tests.py +++ b/tests/pandas_postprocessing_tests.py @@ -20,13 +20,14 @@ import math from typing import Any, List, Optional from pandas import DataFrame, Series +import pytest from superset.exceptions import QueryObjectValidationError from superset.utils import pandas_postprocessing as proc from superset.utils.core import DTTM_ALIAS, PostProcessingContributionOrientation from .base_tests import SupersetTestCase -from .fixtures.dataframes import categories_df, lonlat_df, timeseries_df +from .fixtures.dataframes import categories_df, lonlat_df, timeseries_df, prophet_df AGGREGATES_SINGLE = {"idx_nulls": {"operator": "sum"}} AGGREGATES_MULTIPLE = { @@ -508,3 +509,83 @@ class TestPostProcessing(SupersetTestCase): self.assertListEqual(df.columns.tolist(), ["a", "b"]) self.assertListEqual(series_to_list(column_df["a"]), [0.25, 0.75]) self.assertListEqual(series_to_list(column_df["b"]), [0.1, 0.9]) + + def test_prophet_valid(self): + pytest.importorskip("fbprophet") + + df = proc.prophet( + df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9 + ) + columns = {column for column in df.columns} + assert columns == { + DTTM_ALIAS, + "a__yhat", + "a__yhat_upper", + "a__yhat_lower", + "a", + "b__yhat", + "b__yhat_upper", + "b__yhat_lower", + "b", + } + assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31) + assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 3, 31) + assert len(df) == 7 + + df = proc.prophet( + df=prophet_df, time_grain="P1M", periods=5, confidence_interval=0.9 + ) + assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31) + assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 5, 31) + assert len(df) == 9 + + def test_prophet_missing_temporal_column(self): + df = prophet_df.drop(DTTM_ALIAS, axis=1) + + self.assertRaises( + QueryObjectValidationError, + proc.prophet, + df=df, + time_grain="P1M", + periods=3, + confidence_interval=0.9, + ) + + def test_prophet_incorrect_confidence_interval(self): + self.assertRaises( + QueryObjectValidationError, + proc.prophet, + df=prophet_df, + time_grain="P1M", + periods=3, + confidence_interval=0.0, + ) + + self.assertRaises( + QueryObjectValidationError, + proc.prophet, + df=prophet_df, + time_grain="P1M", + periods=3, + confidence_interval=1.0, + ) + + def test_prophet_incorrect_periods(self): + self.assertRaises( + QueryObjectValidationError, + proc.prophet, + df=prophet_df, + time_grain="P1M", + periods=0, + confidence_interval=0.8, + ) + + def test_prophet_incorrect_time_grain(self): + self.assertRaises( + QueryObjectValidationError, + proc.prophet, + df=prophet_df, + time_grain="yearly", + periods=10, + confidence_interval=0.8, + )