feat(prophet): enable confidence intervals and y_hat without forecast (#17658)

* enable confidence intervals and y_hat without forecast

* fix if statement

Co-authored-by: Ville Brofeldt <33317356+villebro@users.noreply.github.com>

Co-authored-by: Ville Brofeldt <33317356+villebro@users.noreply.github.com>
This commit is contained in:
Shiva Raisinghani 2021-12-07 23:56:18 -08:00 committed by GitHub
parent 418c0b4e48
commit cd88b8e81e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 4 deletions

View File

@ -542,7 +542,7 @@ class ChartDataProphetOptionsSchema(ChartDataPostProcessingOperationOptionsSchem
)
periods = fields.Integer(
descrption="Time periods (in units of `time_grain`) to predict into the future",
min=1,
min=0,
example=7,
required=True,
)

View File

@ -820,8 +820,8 @@ def prophet( # pylint: disable=too-many-arguments
freq = PROPHET_TIME_GRAIN_MAP[time_grain]
# check type at runtime due to marhsmallow schema not being able to handle
# union types
if not periods or periods < 0 or not isinstance(periods, int):
raise QueryObjectValidationError(_("Periods must be a positive integer value"))
if not isinstance(periods, int) or periods < 0:
raise QueryObjectValidationError(_("Periods must be a whole number"))
if not confidence_interval or confidence_interval <= 0 or confidence_interval >= 1:
raise QueryObjectValidationError(
_("Confidence interval must be between 0 and 1 (exclusive)")

View File

@ -830,6 +830,28 @@ class TestPostProcessing(SupersetTestCase):
assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 5, 31)
assert len(df) == 9
def test_prophet_valid_zero_periods(self):
pytest.importorskip("prophet")
df = proc.prophet(
df=prophet_df, time_grain="P1M", periods=0, confidence_interval=0.9
)
columns = {column for column in df.columns}
assert columns == {
DTTM_ALIAS,
"a__yhat",
"a__yhat_upper",
"a__yhat_lower",
"a",
"b__yhat",
"b__yhat_upper",
"b__yhat_lower",
"b",
}
assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2021, 12, 31)
assert len(df) == 4
def test_prophet_import(self):
prophet = find_spec("prophet")
if prophet is None:
@ -875,7 +897,7 @@ class TestPostProcessing(SupersetTestCase):
proc.prophet,
df=prophet_df,
time_grain="P1M",
periods=0,
periods=-1,
confidence_interval=0.8,
)