mirror of
https://github.com/apache/superset.git
synced 2024-09-20 04:29:47 -04:00
fd8461406d
* fix: rolling and cum operator on multiple series * add UT * updates
1010 lines
34 KiB
Python
1010 lines
34 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
# isort:skip_file
|
|
from datetime import datetime
|
|
from importlib.util import find_spec
|
|
import math
|
|
from typing import Any, List, Optional
|
|
|
|
import numpy as np
|
|
from pandas import DataFrame, Series, Timestamp, to_datetime
|
|
import pytest
|
|
|
|
from superset.exceptions import QueryObjectValidationError
|
|
from superset.utils import pandas_postprocessing as proc
|
|
from superset.utils.core import (
|
|
DTTM_ALIAS,
|
|
PostProcessingContributionOrientation,
|
|
PostProcessingBoxplotWhiskerType,
|
|
)
|
|
|
|
from .base_tests import SupersetTestCase
|
|
from .fixtures.dataframes import (
|
|
categories_df,
|
|
single_metric_df,
|
|
multiple_metrics_df,
|
|
lonlat_df,
|
|
names_df,
|
|
timeseries_df,
|
|
prophet_df,
|
|
timeseries_df2,
|
|
)
|
|
|
|
# Shared aggregate specs for the pivot/aggregate tests below.
AGGREGATES_SINGLE = {"idx_nulls": {"operator": "sum"}}
AGGREGATES_MULTIPLE = {
    "idx_nulls": {"operator": "sum"},
    "asc_idx": {"operator": "mean"},
}
|
|
|
|
|
|
def series_to_list(series: Series) -> List[Any]:
    """
    Convert a `Series` to a regular list, replacing NaN/inf with None.

    Only real floats are probed with `math.isnan`/`math.isinf`; previously any
    non-string value was probed, which raises `TypeError` for values such as
    `None` or `Timestamp`. Non-float values are passed through unchanged.

    :param series: Series to convert
    :return: list without nan or inf
    """
    return [
        None
        # np.float64 subclasses float, so numpy NaN/inf are covered too
        if isinstance(val, float) and (math.isnan(val) or math.isinf(val))
        else val
        for val in series.tolist()
    ]
|
|
|
|
|
|
def round_floats(
    floats: List[Optional[float]], precision: int
) -> List[Optional[float]]:
    """
    Round a list of floats to a given precision, preserving None entries.

    The previous implementation tested truthiness (`if val`), which wrongly
    mapped 0.0 to None; an explicit None check keeps zeros intact.

    :param floats: floats to round; entries may be None
    :param precision: intended decimal precision
    :return: rounded floats with None entries preserved
    """
    return [round(val, precision) if val is not None else None for val in floats]
|
|
|
|
|
|
class TestPostProcessing(SupersetTestCase):
|
|
def test_flatten_column_after_pivot(self):
|
|
"""
|
|
Test pivot column flattening function
|
|
"""
|
|
# single aggregate cases
|
|
self.assertEqual(
|
|
proc._flatten_column_after_pivot(
|
|
aggregates=AGGREGATES_SINGLE, column="idx_nulls",
|
|
),
|
|
"idx_nulls",
|
|
)
|
|
self.assertEqual(
|
|
proc._flatten_column_after_pivot(
|
|
aggregates=AGGREGATES_SINGLE, column=1234,
|
|
),
|
|
"1234",
|
|
)
|
|
self.assertEqual(
|
|
proc._flatten_column_after_pivot(
|
|
aggregates=AGGREGATES_SINGLE, column=Timestamp("2020-09-29T00:00:00"),
|
|
),
|
|
"2020-09-29 00:00:00",
|
|
)
|
|
self.assertEqual(
|
|
proc._flatten_column_after_pivot(
|
|
aggregates=AGGREGATES_SINGLE, column="idx_nulls",
|
|
),
|
|
"idx_nulls",
|
|
)
|
|
self.assertEqual(
|
|
proc._flatten_column_after_pivot(
|
|
aggregates=AGGREGATES_SINGLE, column=("idx_nulls", "col1"),
|
|
),
|
|
"col1",
|
|
)
|
|
self.assertEqual(
|
|
proc._flatten_column_after_pivot(
|
|
aggregates=AGGREGATES_SINGLE, column=("idx_nulls", "col1", 1234),
|
|
),
|
|
"col1, 1234",
|
|
)
|
|
|
|
# Multiple aggregate cases
|
|
self.assertEqual(
|
|
proc._flatten_column_after_pivot(
|
|
aggregates=AGGREGATES_MULTIPLE, column=("idx_nulls", "asc_idx", "col1"),
|
|
),
|
|
"idx_nulls, asc_idx, col1",
|
|
)
|
|
self.assertEqual(
|
|
proc._flatten_column_after_pivot(
|
|
aggregates=AGGREGATES_MULTIPLE,
|
|
column=("idx_nulls", "asc_idx", "col1", 1234),
|
|
),
|
|
"idx_nulls, asc_idx, col1, 1234",
|
|
)
|
|
|
|
def test_pivot_without_columns(self):
|
|
"""
|
|
Make sure pivot without columns returns correct DataFrame
|
|
"""
|
|
df = proc.pivot(df=categories_df, index=["name"], aggregates=AGGREGATES_SINGLE,)
|
|
self.assertListEqual(
|
|
df.columns.tolist(), ["name", "idx_nulls"],
|
|
)
|
|
self.assertEqual(len(df), 101)
|
|
self.assertEqual(df.sum()[1], 1050)
|
|
|
|
def test_pivot_with_single_column(self):
|
|
"""
|
|
Make sure pivot with single column returns correct DataFrame
|
|
"""
|
|
df = proc.pivot(
|
|
df=categories_df,
|
|
index=["name"],
|
|
columns=["category"],
|
|
aggregates=AGGREGATES_SINGLE,
|
|
)
|
|
self.assertListEqual(
|
|
df.columns.tolist(), ["name", "cat0", "cat1", "cat2"],
|
|
)
|
|
self.assertEqual(len(df), 101)
|
|
self.assertEqual(df.sum()[1], 315)
|
|
|
|
df = proc.pivot(
|
|
df=categories_df,
|
|
index=["dept"],
|
|
columns=["category"],
|
|
aggregates=AGGREGATES_SINGLE,
|
|
)
|
|
self.assertListEqual(
|
|
df.columns.tolist(), ["dept", "cat0", "cat1", "cat2"],
|
|
)
|
|
self.assertEqual(len(df), 5)
|
|
|
|
def test_pivot_with_multiple_columns(self):
|
|
"""
|
|
Make sure pivot with multiple columns returns correct DataFrame
|
|
"""
|
|
df = proc.pivot(
|
|
df=categories_df,
|
|
index=["name"],
|
|
columns=["category", "dept"],
|
|
aggregates=AGGREGATES_SINGLE,
|
|
)
|
|
self.assertEqual(len(df.columns), 1 + 3 * 5) # index + possible permutations
|
|
|
|
def test_pivot_fill_values(self):
|
|
"""
|
|
Make sure pivot with fill values returns correct DataFrame
|
|
"""
|
|
df = proc.pivot(
|
|
df=categories_df,
|
|
index=["name"],
|
|
columns=["category"],
|
|
metric_fill_value=1,
|
|
aggregates={"idx_nulls": {"operator": "sum"}},
|
|
)
|
|
self.assertEqual(df.sum()[1], 382)
|
|
|
|
def test_pivot_fill_column_values(self):
|
|
"""
|
|
Make sure pivot witn null column names returns correct DataFrame
|
|
"""
|
|
df_copy = categories_df.copy()
|
|
df_copy["category"] = None
|
|
df = proc.pivot(
|
|
df=df_copy,
|
|
index=["name"],
|
|
columns=["category"],
|
|
aggregates={"idx_nulls": {"operator": "sum"}},
|
|
)
|
|
assert len(df) == 101
|
|
assert df.columns.tolist() == ["name", "<NULL>"]
|
|
|
|
def test_pivot_exceptions(self):
|
|
"""
|
|
Make sure pivot raises correct Exceptions
|
|
"""
|
|
# Missing index
|
|
self.assertRaises(
|
|
TypeError,
|
|
proc.pivot,
|
|
df=categories_df,
|
|
columns=["dept"],
|
|
aggregates=AGGREGATES_SINGLE,
|
|
)
|
|
|
|
# invalid index reference
|
|
self.assertRaises(
|
|
QueryObjectValidationError,
|
|
proc.pivot,
|
|
df=categories_df,
|
|
index=["abc"],
|
|
columns=["dept"],
|
|
aggregates=AGGREGATES_SINGLE,
|
|
)
|
|
|
|
# invalid column reference
|
|
self.assertRaises(
|
|
QueryObjectValidationError,
|
|
proc.pivot,
|
|
df=categories_df,
|
|
index=["dept"],
|
|
columns=["abc"],
|
|
aggregates=AGGREGATES_SINGLE,
|
|
)
|
|
|
|
# invalid aggregate options
|
|
self.assertRaises(
|
|
QueryObjectValidationError,
|
|
proc.pivot,
|
|
df=categories_df,
|
|
index=["name"],
|
|
columns=["category"],
|
|
aggregates={"idx_nulls": {}},
|
|
)
|
|
|
|
def test_pivot_eliminate_cartesian_product_columns(self):
|
|
# single metric
|
|
mock_df = DataFrame(
|
|
{
|
|
"dttm": to_datetime(["2019-01-01", "2019-01-01"]),
|
|
"a": [0, 1],
|
|
"b": [0, 1],
|
|
"metric": [9, np.NAN],
|
|
}
|
|
)
|
|
|
|
df = proc.pivot(
|
|
df=mock_df,
|
|
index=["dttm"],
|
|
columns=["a", "b"],
|
|
aggregates={"metric": {"operator": "mean"}},
|
|
drop_missing_columns=False,
|
|
)
|
|
self.assertEqual(list(df.columns), ["dttm", "0, 0", "1, 1"])
|
|
self.assertTrue(np.isnan(df["1, 1"][0]))
|
|
|
|
# multiple metrics
|
|
mock_df = DataFrame(
|
|
{
|
|
"dttm": to_datetime(["2019-01-01", "2019-01-01"]),
|
|
"a": [0, 1],
|
|
"b": [0, 1],
|
|
"metric": [9, np.NAN],
|
|
"metric2": [10, 11],
|
|
}
|
|
)
|
|
|
|
df = proc.pivot(
|
|
df=mock_df,
|
|
index=["dttm"],
|
|
columns=["a", "b"],
|
|
aggregates={
|
|
"metric": {"operator": "mean"},
|
|
"metric2": {"operator": "mean"},
|
|
},
|
|
drop_missing_columns=False,
|
|
)
|
|
self.assertEqual(
|
|
list(df.columns),
|
|
["dttm", "metric, 0, 0", "metric, 1, 1", "metric2, 0, 0", "metric2, 1, 1"],
|
|
)
|
|
self.assertTrue(np.isnan(df["metric, 1, 1"][0]))
|
|
|
|
def test_pivot_without_flatten_columns_and_reset_index(self):
|
|
df = proc.pivot(
|
|
df=single_metric_df,
|
|
index=["dttm"],
|
|
columns=["country"],
|
|
aggregates={"sum_metric": {"operator": "sum"}},
|
|
flatten_columns=False,
|
|
reset_index=False,
|
|
)
|
|
# metric
|
|
# country UK US
|
|
# dttm
|
|
# 2019-01-01 5 6
|
|
# 2019-01-02 7 8
|
|
assert df.columns.to_list() == [("sum_metric", "UK"), ("sum_metric", "US")]
|
|
assert df.index.to_list() == to_datetime(["2019-01-01", "2019-01-02"]).to_list()
|
|
|
|
def test_aggregate(self):
|
|
aggregates = {
|
|
"asc sum": {"column": "asc_idx", "operator": "sum"},
|
|
"asc q2": {
|
|
"column": "asc_idx",
|
|
"operator": "percentile",
|
|
"options": {"q": 75},
|
|
},
|
|
"desc q1": {
|
|
"column": "desc_idx",
|
|
"operator": "percentile",
|
|
"options": {"q": 25},
|
|
},
|
|
}
|
|
df = proc.aggregate(
|
|
df=categories_df, groupby=["constant"], aggregates=aggregates
|
|
)
|
|
self.assertListEqual(
|
|
df.columns.tolist(), ["constant", "asc sum", "asc q2", "desc q1"]
|
|
)
|
|
self.assertEqual(series_to_list(df["asc sum"])[0], 5050)
|
|
self.assertEqual(series_to_list(df["asc q2"])[0], 75)
|
|
self.assertEqual(series_to_list(df["desc q1"])[0], 25)
|
|
|
|
def test_sort(self):
|
|
df = proc.sort(df=categories_df, columns={"category": True, "asc_idx": False})
|
|
self.assertEqual(96, series_to_list(df["asc_idx"])[1])
|
|
|
|
self.assertRaises(
|
|
QueryObjectValidationError, proc.sort, df=df, columns={"abc": True}
|
|
)
|
|
|
|
def test_rolling(self):
|
|
# sum rolling type
|
|
post_df = proc.rolling(
|
|
df=timeseries_df,
|
|
columns={"y": "y"},
|
|
rolling_type="sum",
|
|
window=2,
|
|
min_periods=0,
|
|
)
|
|
|
|
self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
|
|
self.assertListEqual(series_to_list(post_df["y"]), [1.0, 3.0, 5.0, 7.0])
|
|
|
|
# mean rolling type with alias
|
|
post_df = proc.rolling(
|
|
df=timeseries_df,
|
|
rolling_type="mean",
|
|
columns={"y": "y_mean"},
|
|
window=10,
|
|
min_periods=0,
|
|
)
|
|
self.assertListEqual(post_df.columns.tolist(), ["label", "y", "y_mean"])
|
|
self.assertListEqual(series_to_list(post_df["y_mean"]), [1.0, 1.5, 2.0, 2.5])
|
|
|
|
# count rolling type
|
|
post_df = proc.rolling(
|
|
df=timeseries_df,
|
|
rolling_type="count",
|
|
columns={"y": "y"},
|
|
window=10,
|
|
min_periods=0,
|
|
)
|
|
self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
|
|
self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 3.0, 4.0])
|
|
|
|
# quantile rolling type
|
|
post_df = proc.rolling(
|
|
df=timeseries_df,
|
|
columns={"y": "q1"},
|
|
rolling_type="quantile",
|
|
rolling_type_options={"quantile": 0.25},
|
|
window=10,
|
|
min_periods=0,
|
|
)
|
|
self.assertListEqual(post_df.columns.tolist(), ["label", "y", "q1"])
|
|
self.assertListEqual(series_to_list(post_df["q1"]), [1.0, 1.25, 1.5, 1.75])
|
|
|
|
# incorrect rolling type
|
|
self.assertRaises(
|
|
QueryObjectValidationError,
|
|
proc.rolling,
|
|
df=timeseries_df,
|
|
columns={"y": "y"},
|
|
rolling_type="abc",
|
|
window=2,
|
|
)
|
|
|
|
# incorrect rolling type options
|
|
self.assertRaises(
|
|
QueryObjectValidationError,
|
|
proc.rolling,
|
|
df=timeseries_df,
|
|
columns={"y": "y"},
|
|
rolling_type="quantile",
|
|
rolling_type_options={"abc": 123},
|
|
window=2,
|
|
)
|
|
|
|
def test_rolling_with_pivot_df_and_single_metric(self):
|
|
pivot_df = proc.pivot(
|
|
df=single_metric_df,
|
|
index=["dttm"],
|
|
columns=["country"],
|
|
aggregates={"sum_metric": {"operator": "sum"}},
|
|
flatten_columns=False,
|
|
reset_index=False,
|
|
)
|
|
rolling_df = proc.rolling(
|
|
df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True,
|
|
)
|
|
# dttm UK US
|
|
# 0 2019-01-01 5 6
|
|
# 1 2019-01-02 12 14
|
|
assert rolling_df["UK"].to_list() == [5.0, 12.0]
|
|
assert rolling_df["US"].to_list() == [6.0, 14.0]
|
|
assert (
|
|
rolling_df["dttm"].to_list()
|
|
== to_datetime(["2019-01-01", "2019-01-02",]).to_list()
|
|
)
|
|
|
|
rolling_df = proc.rolling(
|
|
df=pivot_df, rolling_type="sum", window=2, min_periods=2, is_pivot_df=True,
|
|
)
|
|
assert rolling_df.empty is True
|
|
|
|
def test_rolling_with_pivot_df_and_multiple_metrics(self):
|
|
pivot_df = proc.pivot(
|
|
df=multiple_metrics_df,
|
|
index=["dttm"],
|
|
columns=["country"],
|
|
aggregates={
|
|
"sum_metric": {"operator": "sum"},
|
|
"count_metric": {"operator": "sum"},
|
|
},
|
|
flatten_columns=False,
|
|
reset_index=False,
|
|
)
|
|
rolling_df = proc.rolling(
|
|
df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True,
|
|
)
|
|
# dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
|
|
# 0 2019-01-01 1.0 2.0 5.0 6.0
|
|
# 1 2019-01-02 4.0 6.0 12.0 14.0
|
|
assert rolling_df["count_metric, UK"].to_list() == [1.0, 4.0]
|
|
assert rolling_df["count_metric, US"].to_list() == [2.0, 6.0]
|
|
assert rolling_df["sum_metric, UK"].to_list() == [5.0, 12.0]
|
|
assert rolling_df["sum_metric, US"].to_list() == [6.0, 14.0]
|
|
assert (
|
|
rolling_df["dttm"].to_list()
|
|
== to_datetime(["2019-01-01", "2019-01-02",]).to_list()
|
|
)
|
|
|
|
def test_select(self):
|
|
# reorder columns
|
|
post_df = proc.select(df=timeseries_df, columns=["y", "label"])
|
|
self.assertListEqual(post_df.columns.tolist(), ["y", "label"])
|
|
|
|
# one column
|
|
post_df = proc.select(df=timeseries_df, columns=["label"])
|
|
self.assertListEqual(post_df.columns.tolist(), ["label"])
|
|
|
|
# rename and select one column
|
|
post_df = proc.select(df=timeseries_df, columns=["y"], rename={"y": "y1"})
|
|
self.assertListEqual(post_df.columns.tolist(), ["y1"])
|
|
|
|
# rename one and leave one unchanged
|
|
post_df = proc.select(df=timeseries_df, rename={"y": "y1"})
|
|
self.assertListEqual(post_df.columns.tolist(), ["label", "y1"])
|
|
|
|
# drop one column
|
|
post_df = proc.select(df=timeseries_df, exclude=["label"])
|
|
self.assertListEqual(post_df.columns.tolist(), ["y"])
|
|
|
|
# rename and drop one column
|
|
post_df = proc.select(df=timeseries_df, rename={"y": "y1"}, exclude=["label"])
|
|
self.assertListEqual(post_df.columns.tolist(), ["y1"])
|
|
|
|
# invalid columns
|
|
self.assertRaises(
|
|
QueryObjectValidationError,
|
|
proc.select,
|
|
df=timeseries_df,
|
|
columns=["abc"],
|
|
rename={"abc": "qwerty"},
|
|
)
|
|
|
|
# select renamed column by new name
|
|
self.assertRaises(
|
|
QueryObjectValidationError,
|
|
proc.select,
|
|
df=timeseries_df,
|
|
columns=["label_new"],
|
|
rename={"label": "label_new"},
|
|
)
|
|
|
|
def test_diff(self):
|
|
# overwrite column
|
|
post_df = proc.diff(df=timeseries_df, columns={"y": "y"})
|
|
self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
|
|
self.assertListEqual(series_to_list(post_df["y"]), [None, 1.0, 1.0, 1.0])
|
|
|
|
# add column
|
|
post_df = proc.diff(df=timeseries_df, columns={"y": "y1"})
|
|
self.assertListEqual(post_df.columns.tolist(), ["label", "y", "y1"])
|
|
self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 3.0, 4.0])
|
|
self.assertListEqual(series_to_list(post_df["y1"]), [None, 1.0, 1.0, 1.0])
|
|
|
|
# look ahead
|
|
post_df = proc.diff(df=timeseries_df, columns={"y": "y1"}, periods=-1)
|
|
self.assertListEqual(series_to_list(post_df["y1"]), [-1.0, -1.0, -1.0, None])
|
|
|
|
# invalid column reference
|
|
self.assertRaises(
|
|
QueryObjectValidationError,
|
|
proc.diff,
|
|
df=timeseries_df,
|
|
columns={"abc": "abc"},
|
|
)
|
|
|
|
# diff by columns
|
|
post_df = proc.diff(df=timeseries_df2, columns={"y": "y", "z": "z"}, axis=1)
|
|
self.assertListEqual(post_df.columns.tolist(), ["label", "y", "z"])
|
|
self.assertListEqual(series_to_list(post_df["z"]), [0.0, 2.0, 8.0, 6.0])
|
|
|
|
def test_compare(self):
|
|
# `difference` comparison
|
|
post_df = proc.compare(
|
|
df=timeseries_df2,
|
|
source_columns=["y"],
|
|
compare_columns=["z"],
|
|
compare_type="difference",
|
|
)
|
|
self.assertListEqual(
|
|
post_df.columns.tolist(), ["label", "y", "z", "difference__y__z",]
|
|
)
|
|
self.assertListEqual(
|
|
series_to_list(post_df["difference__y__z"]), [0.0, -2.0, -8.0, -6.0],
|
|
)
|
|
|
|
# drop original columns
|
|
post_df = proc.compare(
|
|
df=timeseries_df2,
|
|
source_columns=["y"],
|
|
compare_columns=["z"],
|
|
compare_type="difference",
|
|
drop_original_columns=True,
|
|
)
|
|
self.assertListEqual(post_df.columns.tolist(), ["label", "difference__y__z",])
|
|
|
|
# `percentage` comparison
|
|
post_df = proc.compare(
|
|
df=timeseries_df2,
|
|
source_columns=["y"],
|
|
compare_columns=["z"],
|
|
compare_type="percentage",
|
|
)
|
|
self.assertListEqual(
|
|
post_df.columns.tolist(), ["label", "y", "z", "percentage__y__z",]
|
|
)
|
|
self.assertListEqual(
|
|
series_to_list(post_df["percentage__y__z"]), [0.0, -0.5, -0.8, -0.75],
|
|
)
|
|
|
|
# `ratio` comparison
|
|
post_df = proc.compare(
|
|
df=timeseries_df2,
|
|
source_columns=["y"],
|
|
compare_columns=["z"],
|
|
compare_type="ratio",
|
|
)
|
|
self.assertListEqual(
|
|
post_df.columns.tolist(), ["label", "y", "z", "ratio__y__z",]
|
|
)
|
|
self.assertListEqual(
|
|
series_to_list(post_df["ratio__y__z"]), [1.0, 0.5, 0.2, 0.25],
|
|
)
|
|
|
|
def test_cum(self):
|
|
# create new column (cumsum)
|
|
post_df = proc.cum(df=timeseries_df, columns={"y": "y2"}, operator="sum",)
|
|
self.assertListEqual(post_df.columns.tolist(), ["label", "y", "y2"])
|
|
self.assertListEqual(series_to_list(post_df["label"]), ["x", "y", "z", "q"])
|
|
self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 3.0, 4.0])
|
|
self.assertListEqual(series_to_list(post_df["y2"]), [1.0, 3.0, 6.0, 10.0])
|
|
|
|
# overwrite column (cumprod)
|
|
post_df = proc.cum(df=timeseries_df, columns={"y": "y"}, operator="prod",)
|
|
self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
|
|
self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 6.0, 24.0])
|
|
|
|
# overwrite column (cummin)
|
|
post_df = proc.cum(df=timeseries_df, columns={"y": "y"}, operator="min",)
|
|
self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
|
|
self.assertListEqual(series_to_list(post_df["y"]), [1.0, 1.0, 1.0, 1.0])
|
|
|
|
# invalid operator
|
|
self.assertRaises(
|
|
QueryObjectValidationError,
|
|
proc.cum,
|
|
df=timeseries_df,
|
|
columns={"y": "y"},
|
|
operator="abc",
|
|
)
|
|
|
|
def test_cum_with_pivot_df_and_single_metric(self):
|
|
pivot_df = proc.pivot(
|
|
df=single_metric_df,
|
|
index=["dttm"],
|
|
columns=["country"],
|
|
aggregates={"sum_metric": {"operator": "sum"}},
|
|
flatten_columns=False,
|
|
reset_index=False,
|
|
)
|
|
cum_df = proc.cum(df=pivot_df, operator="sum", is_pivot_df=True,)
|
|
# dttm UK US
|
|
# 0 2019-01-01 5 6
|
|
# 1 2019-01-02 12 14
|
|
assert cum_df["UK"].to_list() == [5.0, 12.0]
|
|
assert cum_df["US"].to_list() == [6.0, 14.0]
|
|
assert (
|
|
cum_df["dttm"].to_list()
|
|
== to_datetime(["2019-01-01", "2019-01-02",]).to_list()
|
|
)
|
|
|
|
def test_cum_with_pivot_df_and_multiple_metrics(self):
|
|
pivot_df = proc.pivot(
|
|
df=multiple_metrics_df,
|
|
index=["dttm"],
|
|
columns=["country"],
|
|
aggregates={
|
|
"sum_metric": {"operator": "sum"},
|
|
"count_metric": {"operator": "sum"},
|
|
},
|
|
flatten_columns=False,
|
|
reset_index=False,
|
|
)
|
|
cum_df = proc.cum(df=pivot_df, operator="sum", is_pivot_df=True,)
|
|
# dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
|
|
# 0 2019-01-01 1 2 5 6
|
|
# 1 2019-01-02 4 6 12 14
|
|
assert cum_df["count_metric, UK"].to_list() == [1.0, 4.0]
|
|
assert cum_df["count_metric, US"].to_list() == [2.0, 6.0]
|
|
assert cum_df["sum_metric, UK"].to_list() == [5.0, 12.0]
|
|
assert cum_df["sum_metric, US"].to_list() == [6.0, 14.0]
|
|
assert (
|
|
cum_df["dttm"].to_list()
|
|
== to_datetime(["2019-01-01", "2019-01-02",]).to_list()
|
|
)
|
|
|
|
def test_geohash_decode(self):
|
|
# decode lon/lat from geohash
|
|
post_df = proc.geohash_decode(
|
|
df=lonlat_df[["city", "geohash"]],
|
|
geohash="geohash",
|
|
latitude="latitude",
|
|
longitude="longitude",
|
|
)
|
|
self.assertListEqual(
|
|
sorted(post_df.columns.tolist()),
|
|
sorted(["city", "geohash", "latitude", "longitude"]),
|
|
)
|
|
self.assertListEqual(
|
|
round_floats(series_to_list(post_df["longitude"]), 6),
|
|
round_floats(series_to_list(lonlat_df["longitude"]), 6),
|
|
)
|
|
self.assertListEqual(
|
|
round_floats(series_to_list(post_df["latitude"]), 6),
|
|
round_floats(series_to_list(lonlat_df["latitude"]), 6),
|
|
)
|
|
|
|
def test_geohash_encode(self):
|
|
# encode lon/lat into geohash
|
|
post_df = proc.geohash_encode(
|
|
df=lonlat_df[["city", "latitude", "longitude"]],
|
|
latitude="latitude",
|
|
longitude="longitude",
|
|
geohash="geohash",
|
|
)
|
|
self.assertListEqual(
|
|
sorted(post_df.columns.tolist()),
|
|
sorted(["city", "geohash", "latitude", "longitude"]),
|
|
)
|
|
self.assertListEqual(
|
|
series_to_list(post_df["geohash"]), series_to_list(lonlat_df["geohash"]),
|
|
)
|
|
|
|
def test_geodetic_parse(self):
|
|
# parse geodetic string with altitude into lon/lat/altitude
|
|
post_df = proc.geodetic_parse(
|
|
df=lonlat_df[["city", "geodetic"]],
|
|
geodetic="geodetic",
|
|
latitude="latitude",
|
|
longitude="longitude",
|
|
altitude="altitude",
|
|
)
|
|
self.assertListEqual(
|
|
sorted(post_df.columns.tolist()),
|
|
sorted(["city", "geodetic", "latitude", "longitude", "altitude"]),
|
|
)
|
|
self.assertListEqual(
|
|
series_to_list(post_df["longitude"]),
|
|
series_to_list(lonlat_df["longitude"]),
|
|
)
|
|
self.assertListEqual(
|
|
series_to_list(post_df["latitude"]), series_to_list(lonlat_df["latitude"]),
|
|
)
|
|
self.assertListEqual(
|
|
series_to_list(post_df["altitude"]), series_to_list(lonlat_df["altitude"]),
|
|
)
|
|
|
|
# parse geodetic string into lon/lat
|
|
post_df = proc.geodetic_parse(
|
|
df=lonlat_df[["city", "geodetic"]],
|
|
geodetic="geodetic",
|
|
latitude="latitude",
|
|
longitude="longitude",
|
|
)
|
|
self.assertListEqual(
|
|
sorted(post_df.columns.tolist()),
|
|
sorted(["city", "geodetic", "latitude", "longitude"]),
|
|
)
|
|
self.assertListEqual(
|
|
series_to_list(post_df["longitude"]),
|
|
series_to_list(lonlat_df["longitude"]),
|
|
)
|
|
self.assertListEqual(
|
|
series_to_list(post_df["latitude"]), series_to_list(lonlat_df["latitude"]),
|
|
)
|
|
|
|
def test_contribution(self):
|
|
df = DataFrame(
|
|
{
|
|
DTTM_ALIAS: [
|
|
datetime(2020, 7, 16, 14, 49),
|
|
datetime(2020, 7, 16, 14, 50),
|
|
],
|
|
"a": [1, 3],
|
|
"b": [1, 9],
|
|
}
|
|
)
|
|
with pytest.raises(QueryObjectValidationError, match="not numeric"):
|
|
proc.contribution(df, columns=[DTTM_ALIAS])
|
|
|
|
with pytest.raises(QueryObjectValidationError, match="same length"):
|
|
proc.contribution(df, columns=["a"], rename_columns=["aa", "bb"])
|
|
|
|
# cell contribution across row
|
|
processed_df = proc.contribution(
|
|
df, orientation=PostProcessingContributionOrientation.ROW,
|
|
)
|
|
self.assertListEqual(processed_df.columns.tolist(), [DTTM_ALIAS, "a", "b"])
|
|
self.assertListEqual(processed_df["a"].tolist(), [0.5, 0.25])
|
|
self.assertListEqual(processed_df["b"].tolist(), [0.5, 0.75])
|
|
|
|
# cell contribution across column without temporal column
|
|
df.pop(DTTM_ALIAS)
|
|
processed_df = proc.contribution(
|
|
df, orientation=PostProcessingContributionOrientation.COLUMN
|
|
)
|
|
self.assertListEqual(processed_df.columns.tolist(), ["a", "b"])
|
|
self.assertListEqual(processed_df["a"].tolist(), [0.25, 0.75])
|
|
self.assertListEqual(processed_df["b"].tolist(), [0.1, 0.9])
|
|
|
|
# contribution only on selected columns
|
|
processed_df = proc.contribution(
|
|
df,
|
|
orientation=PostProcessingContributionOrientation.COLUMN,
|
|
columns=["a"],
|
|
rename_columns=["pct_a"],
|
|
)
|
|
self.assertListEqual(processed_df.columns.tolist(), ["a", "b", "pct_a"])
|
|
self.assertListEqual(processed_df["a"].tolist(), [1, 3])
|
|
self.assertListEqual(processed_df["b"].tolist(), [1, 9])
|
|
self.assertListEqual(processed_df["pct_a"].tolist(), [0.25, 0.75])
|
|
|
|
def test_prophet_valid(self):
|
|
pytest.importorskip("prophet")
|
|
|
|
df = proc.prophet(
|
|
df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9
|
|
)
|
|
columns = {column for column in df.columns}
|
|
assert columns == {
|
|
DTTM_ALIAS,
|
|
"a__yhat",
|
|
"a__yhat_upper",
|
|
"a__yhat_lower",
|
|
"a",
|
|
"b__yhat",
|
|
"b__yhat_upper",
|
|
"b__yhat_lower",
|
|
"b",
|
|
}
|
|
assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
|
|
assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 3, 31)
|
|
assert len(df) == 7
|
|
|
|
df = proc.prophet(
|
|
df=prophet_df, time_grain="P1M", periods=5, confidence_interval=0.9
|
|
)
|
|
assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
|
|
assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 5, 31)
|
|
assert len(df) == 9
|
|
|
|
def test_prophet_import(self):
|
|
prophet = find_spec("prophet")
|
|
if prophet is None:
|
|
with pytest.raises(QueryObjectValidationError):
|
|
proc.prophet(
|
|
df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9
|
|
)
|
|
|
|
def test_prophet_missing_temporal_column(self):
|
|
df = prophet_df.drop(DTTM_ALIAS, axis=1)
|
|
|
|
self.assertRaises(
|
|
QueryObjectValidationError,
|
|
proc.prophet,
|
|
df=df,
|
|
time_grain="P1M",
|
|
periods=3,
|
|
confidence_interval=0.9,
|
|
)
|
|
|
|
def test_prophet_incorrect_confidence_interval(self):
|
|
self.assertRaises(
|
|
QueryObjectValidationError,
|
|
proc.prophet,
|
|
df=prophet_df,
|
|
time_grain="P1M",
|
|
periods=3,
|
|
confidence_interval=0.0,
|
|
)
|
|
|
|
self.assertRaises(
|
|
QueryObjectValidationError,
|
|
proc.prophet,
|
|
df=prophet_df,
|
|
time_grain="P1M",
|
|
periods=3,
|
|
confidence_interval=1.0,
|
|
)
|
|
|
|
def test_prophet_incorrect_periods(self):
|
|
self.assertRaises(
|
|
QueryObjectValidationError,
|
|
proc.prophet,
|
|
df=prophet_df,
|
|
time_grain="P1M",
|
|
periods=0,
|
|
confidence_interval=0.8,
|
|
)
|
|
|
|
def test_prophet_incorrect_time_grain(self):
|
|
self.assertRaises(
|
|
QueryObjectValidationError,
|
|
proc.prophet,
|
|
df=prophet_df,
|
|
time_grain="yearly",
|
|
periods=10,
|
|
confidence_interval=0.8,
|
|
)
|
|
|
|
def test_boxplot_tukey(self):
|
|
df = proc.boxplot(
|
|
df=names_df,
|
|
groupby=["region"],
|
|
whisker_type=PostProcessingBoxplotWhiskerType.TUKEY,
|
|
metrics=["cars"],
|
|
)
|
|
columns = {column for column in df.columns}
|
|
assert columns == {
|
|
"cars__mean",
|
|
"cars__median",
|
|
"cars__q1",
|
|
"cars__q3",
|
|
"cars__max",
|
|
"cars__min",
|
|
"cars__count",
|
|
"cars__outliers",
|
|
"region",
|
|
}
|
|
assert len(df) == 4
|
|
|
|
def test_boxplot_min_max(self):
|
|
df = proc.boxplot(
|
|
df=names_df,
|
|
groupby=["region"],
|
|
whisker_type=PostProcessingBoxplotWhiskerType.MINMAX,
|
|
metrics=["cars"],
|
|
)
|
|
columns = {column for column in df.columns}
|
|
assert columns == {
|
|
"cars__mean",
|
|
"cars__median",
|
|
"cars__q1",
|
|
"cars__q3",
|
|
"cars__max",
|
|
"cars__min",
|
|
"cars__count",
|
|
"cars__outliers",
|
|
"region",
|
|
}
|
|
assert len(df) == 4
|
|
|
|
def test_boxplot_percentile(self):
|
|
df = proc.boxplot(
|
|
df=names_df,
|
|
groupby=["region"],
|
|
whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
|
|
metrics=["cars"],
|
|
percentiles=[1, 99],
|
|
)
|
|
columns = {column for column in df.columns}
|
|
assert columns == {
|
|
"cars__mean",
|
|
"cars__median",
|
|
"cars__q1",
|
|
"cars__q3",
|
|
"cars__max",
|
|
"cars__min",
|
|
"cars__count",
|
|
"cars__outliers",
|
|
"region",
|
|
}
|
|
assert len(df) == 4
|
|
|
|
def test_boxplot_percentile_incorrect_params(self):
|
|
with pytest.raises(QueryObjectValidationError):
|
|
proc.boxplot(
|
|
df=names_df,
|
|
groupby=["region"],
|
|
whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
|
|
metrics=["cars"],
|
|
)
|
|
|
|
with pytest.raises(QueryObjectValidationError):
|
|
proc.boxplot(
|
|
df=names_df,
|
|
groupby=["region"],
|
|
whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
|
|
metrics=["cars"],
|
|
percentiles=[10],
|
|
)
|
|
|
|
with pytest.raises(QueryObjectValidationError):
|
|
proc.boxplot(
|
|
df=names_df,
|
|
groupby=["region"],
|
|
whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
|
|
metrics=["cars"],
|
|
percentiles=[90, 10],
|
|
)
|
|
|
|
with pytest.raises(QueryObjectValidationError):
|
|
proc.boxplot(
|
|
df=names_df,
|
|
groupby=["region"],
|
|
whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
|
|
metrics=["cars"],
|
|
percentiles=[10, 90, 10],
|
|
)
|
|
|
|
def test_resample(self):
|
|
df = timeseries_df.copy()
|
|
df.index.name = "time_column"
|
|
df.reset_index(inplace=True)
|
|
|
|
post_df = proc.resample(
|
|
df=df, rule="1D", method="ffill", time_column="time_column",
|
|
)
|
|
self.assertListEqual(
|
|
post_df["label"].tolist(), ["x", "y", "y", "y", "z", "z", "q"]
|
|
)
|
|
self.assertListEqual(post_df["y"].tolist(), [1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0])
|
|
|
|
post_df = proc.resample(
|
|
df=df, rule="1D", method="asfreq", time_column="time_column", fill_value=0,
|
|
)
|
|
self.assertListEqual(post_df["label"].tolist(), ["x", "y", 0, 0, "z", 0, "q"])
|
|
self.assertListEqual(post_df["y"].tolist(), [1.0, 2.0, 0, 0, 3.0, 0, 4.0])
|