diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 8293065d9e..f68c4ff4b8 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -542,7 +542,7 @@ class ChartDataProphetOptionsSchema(ChartDataPostProcessingOperationOptionsSchem ) periods = fields.Integer( descrption="Time periods (in units of `time_grain`) to predict into the future", - min=1, + min=0, example=7, required=True, ) diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py index 8ad5099552..ea4e40986b 100644 --- a/superset/utils/pandas_postprocessing.py +++ b/superset/utils/pandas_postprocessing.py @@ -820,8 +820,8 @@ def prophet( # pylint: disable=too-many-arguments freq = PROPHET_TIME_GRAIN_MAP[time_grain] # check type at runtime due to marhsmallow schema not being able to handle # union types - if not periods or periods < 0 or not isinstance(periods, int): - raise QueryObjectValidationError(_("Periods must be a positive integer value")) + if not isinstance(periods, int) or periods < 0: + raise QueryObjectValidationError(_("Periods must be a whole number")) if not confidence_interval or confidence_interval <= 0 or confidence_interval >= 1: raise QueryObjectValidationError( _("Confidence interval must be between 0 and 1 (exclusive)") diff --git a/tests/integration_tests/pandas_postprocessing_tests.py b/tests/integration_tests/pandas_postprocessing_tests.py index 7221130be8..feabef6e2f 100644 --- a/tests/integration_tests/pandas_postprocessing_tests.py +++ b/tests/integration_tests/pandas_postprocessing_tests.py @@ -830,6 +830,28 @@ class TestPostProcessing(SupersetTestCase): assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 5, 31) assert len(df) == 9 + def test_prophet_valid_zero_periods(self): + pytest.importorskip("prophet") + + df = proc.prophet( + df=prophet_df, time_grain="P1M", periods=0, confidence_interval=0.9 + ) + columns = {column for column in df.columns} + assert columns == { + DTTM_ALIAS, + "a__yhat", + "a__yhat_upper", + "a__yhat_lower", + "a", + "b__yhat", + "b__yhat_upper", + "b__yhat_lower", + "b", + } + assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31) + assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2021, 12, 31) + assert len(df) == 4 + def test_prophet_import(self): prophet = find_spec("prophet") if prophet is None: @@ -875,7 +897,7 @@ class TestPostProcessing(SupersetTestCase): proc.prophet, df=prophet_df, time_grain="P1M", - periods=0, + periods=-1, confidence_interval=0.8, )