feat: add optional prophet forecasting functionality to chart data api (#10324)

* feat: add prophet post processing operation

* add tests

* lint

* whitespace

* remove whitespace

* address comments

* add note to UPDATING.md
Ville Brofeldt 2020-07-20 18:46:51 +03:00 committed by GitHub
parent 73797b8b64
commit 7af8b2b3f8
7 changed files with 357 additions and 23 deletions

View File

@@ -23,6 +23,8 @@ assists people when migrating to a new version.
## Next
* [10324](https://github.com/apache/incubator-superset/pull/10324): Facebook Prophet has been introduced as an optional dependency to add support for timeseries forecasting in the chart data API. To enable this feature, install Superset with the optional dependency `prophet` or directly `pip install fbprophet`.
* [10320](https://github.com/apache/incubator-superset/pull/10320): References to blacklist/whitelist language have been replaced with more appropriate alternatives. All configs referencing `WHITE`/`BLACK` have been replaced with `ALLOW`/`DENY`. Affected config variables that need to be updated: `TIME_GRAIN_BLACKLIST`, `VIZ_TYPE_BLACKLIST`, `DRUID_DATA_SOURCE_BLACKLIST`.
* [9964](https://github.com/apache/incubator-superset/pull/9964): Breaking change on Flask-AppBuilder 3. If you're using OAuth, find out what needs to be changed [here](https://github.com/dpgaspar/Flask-AppBuilder/blob/master/README.rst#change-log).
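For orientation, the new operation is requested through the `post_processing` list of a chart data API query. A minimal sketch of such an entry, based on the schema and API test added in this commit (values are illustrative):

# Sketch of a chart data API query fragment using the new "prophet" operation;
# field names follow the ChartDataProphetOptionsSchema introduced below.
post_processing = [
    {
        "operation": "prophet",
        "options": {
            "time_grain": "P1D",         # ISO 8601 duration from TIME_GRAINS
            "periods": 7,                # predict 7 time_grain units ahead
            "confidence_interval": 0.8,  # strictly between 0 and 1 (exclusive)
            # yearly/weekly/daily_seasonality are optional; see the schema below
        },
    }
]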

View File

@@ -124,6 +124,7 @@ setup(
"cockroachdb": ["cockroachdb==0.3.3"],
"thumbnails": ["Pillow>=7.0.0, <8.0.0"],
"excel": ["xlrd>=1.2.0, <1.3"],
"prophet": ["fbprophet>=0.6, <0.7"],
},
python_requires="~=3.6",
author="Apache Software Foundation",
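Because `fbprophet` is shipped only as an optional extra, the runtime code imports it lazily and surfaces a validation error when the package is missing (see `_prophet_fit_and_predict` further down). A standalone sketch of that pattern, with an illustrative helper name and a plain RuntimeError standing in for Superset's QueryObjectValidationError:

# Illustrative only: the lazy import keeps Superset importable without the extra.
def _load_prophet():
    try:
        from fbprophet import Prophet  # available only with the "prophet" extra
    except ModuleNotFoundError as ex:
        raise RuntimeError("fbprophet package not installed") from ex
    return Prophet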

View File

@@ -101,6 +101,26 @@ openapi_spec_methods_override = {
}
TIME_GRAINS = (
"PT1S",
"PT1M",
"PT5M",
"PT10M",
"PT15M",
"PT0.5H",
"PT1H",
"P1D",
"P1W",
"P1M",
"P0.25Y",
"P1Y",
"1969-12-28T00:00:00Z/P1W", # Week starting Sunday
"1969-12-29T00:00:00Z/P1W", # Week starting Monday
"P1W/1970-01-03T00:00:00Z", # Week ending Saturday
"P1W/1970-01-04T00:00:00Z", # Week ending Sunday
)
class ChartPostSchema(Schema):
"""
Schema to add a new chart.
@@ -423,6 +443,62 @@ class ChartDataContributionOptionsSchema(ChartDataPostProcessingOperationOptions
)
class ChartDataProphetOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
"""
Prophet operation config.
"""
time_grain = fields.String(
description="Time grain used to specify time period increments in prediction. "
"Supports [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) "
"durations.",
validate=validate.OneOf(choices=TIME_GRAINS),
example="P1D",
required=True,
)
periods = fields.Integer(
description="Time periods (in units of `time_grain`) to predict into the future",
min=1,
example=7,
required=True,
)
confidence_interval = fields.Float(
description="Width of predicted confidence interval",
validate=[
Range(
min=0,
max=1,
min_inclusive=False,
max_inclusive=False,
error=_("`confidence_interval` must be between 0 and 1 (exclusive)"),
)
],
example=0.8,
required=True,
)
yearly_seasonality = fields.Raw(
# TODO: add correct union type once supported by Marshmallow
description="Should yearly seasonality be applied. "
"An integer value will specify Fourier order of seasonality, `None` will "
"automatically detect seasonality.",
example=False,
)
weekly_seasonality = fields.Raw(
# TODO: add correct union type once supported by Marshmallow
description="Should weekly seasonality be applied. "
"An integer value will specify Fourier order of seasonality, `None` will "
"automatically detect seasonality.",
example=False,
)
daily_seasonality = fields.Raw(
# TODO: add correct union type once supported by Marshmallow
description="Should daily seasonality be applied. "
"An integer value will specify Fourier order of seasonality, `None` will "
"automatically detect seasonality.",
example=False,
)
class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
"""
Pivot operation config.
@@ -534,6 +610,7 @@ class ChartDataPostProcessingOperationSchema(Schema):
"geohash_decode",
"geohash_encode",
"pivot",
"prophet",
"rolling",
"select",
"sort",
@@ -613,26 +690,7 @@ class ChartDataExtrasSchema(Schema):
description="To what level of granularity should the temporal column be "
"aggregated. Supports "
"[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) durations.",
validate=validate.OneOf(
choices=(
"PT1S",
"PT1M",
"PT5M",
"PT10M",
"PT15M",
"PT0.5H",
"PT1H",
"P1D",
"P1W",
"P1M",
"P0.25Y",
"P1Y",
"1969-12-28T00:00:00Z/P1W", # Week starting Sunday
"1969-12-29T00:00:00Z/P1W", # Week starting Monday
"P1W/1970-01-03T00:00:00Z", # Week ending Saturday
"P1W/1970-01-04T00:00:00Z", # Week ending Sunday
),
),
validate=validate.OneOf(choices=TIME_GRAINS),
example="P1D",
allow_none=True,
)
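The three seasonality fields are declared as `fields.Raw` because the `bool | int | None` union cannot yet be expressed in Marshmallow; a short sketch of the values a client may send and how they are interpreted downstream by `_prophet_parse_seasonality` (defined in the next file), with illustrative option values:

# Accepted seasonality values and their downstream interpretation:
#   None  -> "auto" (let Prophet detect the seasonal component)
#   False -> False  (disable the component)
#   True  -> True   (enable it with Prophet's default Fourier order)
#   3     -> 3      (explicit Fourier order)
options = {
    "time_grain": "P1D",
    "periods": 7,
    "confidence_interval": 0.8,
    "yearly_seasonality": True,
    "weekly_seasonality": 3,     # Fourier order 3
    "daily_seasonality": None,   # auto-detect
}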

View File

@@ -72,6 +72,25 @@ ALLOWLIST_CUMULATIVE_FUNCTIONS = (
"cumsum",
)
PROPHET_TIME_GRAIN_MAP = {
"PT1S": "S",
"PT1M": "min",
"PT5M": "5min",
"PT10M": "10min",
"PT15M": "15min",
"PT0.5H": "30min",
"PT1H": "H",
"P1D": "D",
"P1W": "W",
"P1M": "M",
"P0.25Y": "Q",
"P1Y": "A",
"1969-12-28T00:00:00Z/P1W": "W",
"1969-12-29T00:00:00Z/P1W": "W",
"P1W/1970-01-03T00:00:00Z": "W",
"P1W/1970-01-04T00:00:00Z": "W",
}
def _flatten_column_after_pivot(
column: Union[str, Tuple[str, ...]], aggregates: Dict[str, Dict[str, Any]]
@@ -544,3 +563,126 @@ def contribution(
if temporal_series is not None:
contribution_df.insert(0, DTTM_ALIAS, temporal_series)
return contribution_df
def _prophet_parse_seasonality(
input_value: Optional[Union[bool, int]]
) -> Union[bool, str, int]:
if input_value is None:
return "auto"
if isinstance(input_value, bool):
return input_value
try:
return int(input_value)
except ValueError:
return input_value
def _prophet_fit_and_predict( # pylint: disable=too-many-arguments
df: DataFrame,
confidence_interval: float,
yearly_seasonality: Union[bool, str, int],
weekly_seasonality: Union[bool, str, int],
daily_seasonality: Union[bool, str, int],
periods: int,
freq: str,
) -> DataFrame:
"""
Fit a prophet model and return a DataFrame with predicted results.
"""
try:
from fbprophet import Prophet # pylint: disable=import-error
except ModuleNotFoundError:
raise QueryObjectValidationError(_("`fbprophet` package not installed"))
model = Prophet(
interval_width=confidence_interval,
yearly_seasonality=yearly_seasonality,
weekly_seasonality=weekly_seasonality,
daily_seasonality=daily_seasonality,
)
model.fit(df)
future = model.make_future_dataframe(periods=periods, freq=freq)
forecast = model.predict(future)[["ds", "yhat", "yhat_lower", "yhat_upper"]]
return forecast.join(df.set_index("ds"), on="ds").set_index(["ds"])
def prophet( # pylint: disable=too-many-arguments
df: DataFrame,
time_grain: str,
periods: int,
confidence_interval: float,
yearly_seasonality: Optional[Union[bool, int]] = None,
weekly_seasonality: Optional[Union[bool, int]] = None,
daily_seasonality: Optional[Union[bool, int]] = None,
) -> DataFrame:
"""
Add forecasts to each series in a timeseries dataframe, along with confidence
intervals for the prediction. For each series, the operation creates three
new columns with the column name suffixed with the following values:
- `__yhat`: the forecast for the given date
- `__yhat_lower`: the lower bound of the forecast for the given date
- `__yhat_upper`: the upper bound of the forecast for the given date
:param df: DataFrame containing all-numeric data (temporal column ignored)
:param time_grain: Time grain used to specify time period increments in prediction
:param periods: Time periods (in units of `time_grain`) to predict into the future
:param confidence_interval: Width of predicted confidence interval
:param yearly_seasonality: Should yearly seasonality be applied.
An integer value will specify Fourier order of seasonality, `None` will automatically detect seasonality.
:param weekly_seasonality: Should weekly seasonality be applied.
An integer value will specify Fourier order of seasonality, `None` will
automatically detect seasonality.
:param daily_seasonality: Should daily seasonality be applied.
An integer value will specify Fourier order of seasonality, `None` will
automatically detect seasonality.
:return: DataFrame with forecast columns added, with temporal column at beginning if present
"""
# validate inputs
if not time_grain:
raise QueryObjectValidationError(_("Time grain missing"))
if time_grain not in PROPHET_TIME_GRAIN_MAP:
raise QueryObjectValidationError(
_("Unsupported time grain: %(time_grain)s", time_grain=time_grain,)
)
freq = PROPHET_TIME_GRAIN_MAP[time_grain]
# check type at runtime due to marshmallow schema not being able to handle
# union types
if not isinstance(periods, int) or periods < 1:
raise QueryObjectValidationError(_("Periods must be a positive integer value"))
if not confidence_interval or confidence_interval <= 0 or confidence_interval >= 1:
raise QueryObjectValidationError(
_("Confidence interval must be between 0 and 1 (exclusive)")
)
if DTTM_ALIAS not in df.columns:
raise QueryObjectValidationError(_("DataFrame must include temporal column"))
if len(df.columns) < 2:
raise QueryObjectValidationError(_("DataFrame must include at least one series"))
target_df = DataFrame()
for column in [column for column in df.columns if column != DTTM_ALIAS]:
fit_df = _prophet_fit_and_predict(
df=df[[DTTM_ALIAS, column]].rename(columns={DTTM_ALIAS: "ds", column: "y"}),
confidence_interval=confidence_interval,
yearly_seasonality=_prophet_parse_seasonality(yearly_seasonality),
weekly_seasonality=_prophet_parse_seasonality(weekly_seasonality),
daily_seasonality=_prophet_parse_seasonality(daily_seasonality),
periods=periods,
freq=freq,
)
new_columns = [
f"{column}__yhat",
f"{column}__yhat_lower",
f"{column}__yhat_upper",
f"{column}",
]
fit_df.columns = new_columns
if target_df.empty:
target_df = fit_df
else:
for new_column in new_columns:
target_df = target_df.assign(**{new_column: fit_df[new_column]})
target_df.reset_index(level=0, inplace=True)
return target_df.rename(columns={"ds": DTTM_ALIAS})
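The new operation can also be exercised directly against a timeseries DataFrame; a minimal sketch, assuming `fbprophet` is installed (the data and column name are made up, the output naming mirrors the docstring above):

from datetime import datetime

from pandas import DataFrame

from superset.utils import pandas_postprocessing as proc
from superset.utils.core import DTTM_ALIAS

df = DataFrame(
    {
        DTTM_ALIAS: [datetime(year, 12, 31) for year in range(2015, 2021)],
        "sales": [10.0, 12.0, 11.5, 14.0, 13.2, 15.8],
    }
)

# Forecast three yearly periods ahead with a 90% confidence interval. The result
# keeps the original series and adds sales__yhat, sales__yhat_lower and
# sales__yhat_upper columns, one row per historical and forecast period.
forecast = proc.prophet(df=df, time_grain="P1Y", periods=3, confidence_interval=0.9)
print(forecast.columns.tolist())  # ['__timestamp', 'sales__yhat', ...]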

View File

@@ -21,8 +21,9 @@ from typing import List, Optional
from datetime import datetime
from unittest import mock
import prison
import humanize
import prison
import pytest
from sqlalchemy.sql import func
from tests.test_app import app
@@ -796,6 +797,42 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 10)
def test_chart_data_prophet(self):
"""
Chart data API: Ensure prophet post-processing operation works
"""
pytest.importorskip("fbprophet")
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
time_grain = "P1Y"
request_payload["queries"][0]["is_timeseries"] = True
request_payload["queries"][0]["groupby"] = []
request_payload["queries"][0]["extras"] = {"time_grain_sqla": time_grain}
request_payload["queries"][0]["granularity"] = "ds"
request_payload["queries"][0]["post_processing"] = [
{
"operation": "prophet",
"options": {
"time_grain": time_grain,
"periods": 3,
"confidence_interval": 0.9,
},
}
]
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
print(rv.data)
self.assertEqual(rv.status_code, 200)
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
row = result["data"][0]
self.assertIn("__timestamp", row)
self.assertIn("sum__num", row)
self.assertIn("sum__num__yhat", row)
self.assertIn("sum__num__yhat_upper", row)
self.assertIn("sum__num__yhat_lower", row)
self.assertEqual(result["rowcount"], 47)
def test_chart_data_no_data(self):
"""
Chart data API: Test chart data with empty result

View File

@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import date
from datetime import date, datetime
from pandas import DataFrame, to_datetime
@@ -133,3 +133,16 @@ lonlat_df = DataFrame(
],
}
)
prophet_df = DataFrame(
{
"__timestamp": [
datetime(2018, 12, 31),
datetime(2019, 12, 31),
datetime(2020, 12, 31),
datetime(2021, 12, 31),
],
"a": [1.1, 1, 1.9, 3.15],
"b": [4, 3, 4.1, 3.95],
}
)

View File

@@ -20,13 +20,14 @@ import math
from typing import Any, List, Optional
from pandas import DataFrame, Series
import pytest
from superset.exceptions import QueryObjectValidationError
from superset.utils import pandas_postprocessing as proc
from superset.utils.core import DTTM_ALIAS, PostProcessingContributionOrientation
from .base_tests import SupersetTestCase
from .fixtures.dataframes import categories_df, lonlat_df, timeseries_df
from .fixtures.dataframes import categories_df, lonlat_df, timeseries_df, prophet_df
AGGREGATES_SINGLE = {"idx_nulls": {"operator": "sum"}}
AGGREGATES_MULTIPLE = {
@@ -508,3 +509,83 @@ class TestPostProcessing(SupersetTestCase):
self.assertListEqual(df.columns.tolist(), ["a", "b"])
self.assertListEqual(series_to_list(column_df["a"]), [0.25, 0.75])
self.assertListEqual(series_to_list(column_df["b"]), [0.1, 0.9])
def test_prophet_valid(self):
pytest.importorskip("fbprophet")
df = proc.prophet(
df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9
)
columns = {column for column in df.columns}
assert columns == {
DTTM_ALIAS,
"a__yhat",
"a__yhat_upper",
"a__yhat_lower",
"a",
"b__yhat",
"b__yhat_upper",
"b__yhat_lower",
"b",
}
assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 3, 31)
assert len(df) == 7
df = proc.prophet(
df=prophet_df, time_grain="P1M", periods=5, confidence_interval=0.9
)
assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 5, 31)
assert len(df) == 9
def test_prophet_missing_temporal_column(self):
df = prophet_df.drop(DTTM_ALIAS, axis=1)
self.assertRaises(
QueryObjectValidationError,
proc.prophet,
df=df,
time_grain="P1M",
periods=3,
confidence_interval=0.9,
)
def test_prophet_incorrect_confidence_interval(self):
self.assertRaises(
QueryObjectValidationError,
proc.prophet,
df=prophet_df,
time_grain="P1M",
periods=3,
confidence_interval=0.0,
)
self.assertRaises(
QueryObjectValidationError,
proc.prophet,
df=prophet_df,
time_grain="P1M",
periods=3,
confidence_interval=1.0,
)
def test_prophet_incorrect_periods(self):
self.assertRaises(
QueryObjectValidationError,
proc.prophet,
df=prophet_df,
time_grain="P1M",
periods=0,
confidence_interval=0.8,
)
def test_prophet_incorrect_time_grain(self):
self.assertRaises(
QueryObjectValidationError,
proc.prophet,
df=prophet_df,
time_grain="yearly",
periods=10,
confidence_interval=0.8,
)