superset/tests/unit_tests/pandas_postprocessing/test_prophet.py

187 lines
5.4 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from importlib.util import find_spec
import pandas as pd
import pytest
from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import DTTM_ALIAS
from superset.utils.pandas_postprocessing import prophet
from tests.unit_tests.fixtures.dataframes import prophet_df
def test_prophet_valid():
    """Check columns, date range and row count of valid prophet forecasts.

    Covers the monthly fixture with 3 and 5 forecast periods, plus several
    weekly time grains (plain ``P1W`` and week grains anchored to an explicit
    start or end date) on a two-row frame with one forecast period.
    """
    # prophet is an optional dependency: skip instead of failing when it is
    # not installed.
    pytest.importorskip("prophet")

    df = prophet(df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9)
    assert set(df.columns) == {
        DTTM_ALIAS,
        "a__yhat",
        "a__yhat_upper",
        "a__yhat_lower",
        "a",
        "b__yhat",
        "b__yhat_upper",
        "b__yhat_lower",
        "b",
    }
    assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
    assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 3, 31)
    assert len(df) == 7

    df = prophet(df=prophet_df, time_grain="P1M", periods=5, confidence_interval=0.9)
    assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
    assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 5, 31)
    assert len(df) == 9

    # (time grain, two input timestamps, expected last timestamp after
    # forecasting a single extra period)
    weekly_cases = [
        (
            "P1W",
            [datetime(2022, 1, 2), datetime(2022, 1, 9)],
            datetime(2022, 1, 16),
        ),
        (
            "1969-12-28T00:00:00Z/P1W",
            [datetime(2022, 1, 2), datetime(2022, 1, 9)],
            datetime(2022, 1, 16),
        ),
        (
            "1969-12-29T00:00:00Z/P1W",
            [datetime(2022, 1, 3), datetime(2022, 1, 10)],
            datetime(2022, 1, 17),
        ),
        (
            "P1W/1970-01-03T00:00:00Z",
            [datetime(2022, 1, 8), datetime(2022, 1, 15)],
            datetime(2022, 1, 22),
        ),
    ]
    for time_grain, timestamps, expected_last in weekly_cases:
        df = prophet(
            df=pd.DataFrame({"__timestamp": timestamps, "x": [1, 1]}),
            time_grain=time_grain,
            periods=1,
            confidence_interval=0.9,
        )
        # Include the grain in the failure message so the failing case is
        # identifiable without a debugger.
        assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == expected_last, time_grain
        assert len(df) == 3, time_grain
def test_prophet_valid_zero_periods():
    """Check the forecast produced when ``periods`` is zero.

    The result is expected to carry the original and ``__yhat*`` columns for
    both series, span 2018-12-31 through 2021-12-31, and contain four rows.
    """
    # prophet is an optional dependency: skip instead of failing when it is
    # not installed.
    pytest.importorskip("prophet")

    df = prophet(df=prophet_df, time_grain="P1M", periods=0, confidence_interval=0.9)
    assert set(df.columns) == {
        DTTM_ALIAS,
        "a__yhat",
        "a__yhat_upper",
        "a__yhat_lower",
        "a",
        "b__yhat",
        "b__yhat_upper",
        "b__yhat_lower",
        "b",
    }
    assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
    assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2021, 12, 31)
    assert len(df) == 4
def test_prophet_import():
    """Without the prophet package, a forecast call must raise.

    When ``find_spec`` can locate ``prophet`` there is nothing to verify and
    the test passes without calling the post-processing function.
    """
    if find_spec("prophet") is not None:
        return

    with pytest.raises(InvalidPostProcessingError):
        prophet(
            df=prophet_df,
            time_grain="P1M",
            periods=3,
            confidence_interval=0.9,
        )
def test_prophet_missing_temporal_column():
    """A frame lacking the ``DTTM_ALIAS`` column must be rejected."""
    frame_without_dttm = prophet_df.drop(columns=DTTM_ALIAS)
    forecast_args = {
        "time_grain": "P1M",
        "periods": 3,
        "confidence_interval": 0.9,
    }

    with pytest.raises(InvalidPostProcessingError):
        prophet(df=frame_without_dttm, **forecast_args)
def test_prophet_incorrect_confidence_interval():
    """Confidence intervals of exactly 0.0 and 1.0 must both be rejected."""
    # Checked in the same order as before: 0.0 first, then 1.0.
    for interval in (0.0, 1.0):
        with pytest.raises(InvalidPostProcessingError):
            prophet(
                df=prophet_df,
                time_grain="P1M",
                periods=3,
                confidence_interval=interval,
            )
def test_prophet_incorrect_periods():
    """A negative ``periods`` value must raise ``InvalidPostProcessingError``."""
    invalid_call = {
        "df": prophet_df,
        "time_grain": "P1M",
        "periods": -1,
        "confidence_interval": 0.8,
    }

    with pytest.raises(InvalidPostProcessingError):
        prophet(**invalid_call)
def test_prophet_incorrect_time_grain():
    """The time grain string ``"yearly"`` must be rejected."""
    rejected_grain = "yearly"
    remaining_args = {"periods": 10, "confidence_interval": 0.8}

    with pytest.raises(InvalidPostProcessingError):
        prophet(df=prophet_df, time_grain=rejected_grain, **remaining_args)