mirror of https://github.com/apache/superset.git
feat: Add geospatial post processing operations (#9661)
* feat: Add geospatial post processing operations * Linting * Refactor * Add tests * Improve docs * Address comments * fix latitude/longitude mixup * fix: bad refactor by pycharm
This commit is contained in:
parent
c474ea848a
commit
a52cfcd234
|
@ -34,7 +34,7 @@ flask-talisman==0.7.0 # via apache-superset (setup.py)
|
|||
flask-wtf==0.14.2 # via apache-superset (setup.py), flask-appbuilder
|
||||
flask==1.1.1 # via apache-superset (setup.py), flask-appbuilder, flask-babel, flask-caching, flask-compress, flask-jwt-extended, flask-login, flask-migrate, flask-openid, flask-sqlalchemy, flask-wtf
|
||||
geographiclib==1.50 # via geopy
|
||||
geopy==1.20.0 # via apache-superset (setup.py)
|
||||
geopy==1.21.0 # via apache-superset (setup.py)
|
||||
gunicorn==20.0.4 # via apache-superset (setup.py)
|
||||
humanize==0.5.1 # via apache-superset (setup.py)
|
||||
importlib-metadata==1.4.0 # via jsonschema, kombu
|
||||
|
|
|
@ -265,15 +265,23 @@ class ChartDataSelectOptionsSchema(ChartDataPostProcessingOperationOptionsSchema
|
|||
columns = fields.List(
|
||||
fields.String(),
|
||||
description="Columns which to select from the input data, in the desired "
|
||||
"order. If columns are renamed, the old column name should be "
|
||||
"order. If columns are renamed, the original column name should be "
|
||||
"referenced here.",
|
||||
example=["country", "gender", "age"],
|
||||
required=False,
|
||||
)
|
||||
exclude = fields.List(
|
||||
fields.String(),
|
||||
description="Columns to exclude from selection.",
|
||||
example=["my_temp_column"],
|
||||
required=False,
|
||||
)
|
||||
rename = fields.List(
|
||||
fields.Dict(),
|
||||
description="columns which to rename, mapping source column to target column. "
|
||||
"For instance, `{'y': 'y2'}` will rename the column `y` to `y2`.",
|
||||
example=[{"age": "average_age"}],
|
||||
required=False,
|
||||
)
|
||||
|
||||
|
||||
|
@ -335,12 +343,81 @@ class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema)
|
|||
aggregates = ChartDataAggregateConfigField()
|
||||
|
||||
|
||||
class ChartDataGeohashDecodeOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
    """Geohash decode operation config."""

    # Source column holding the geohash strings to decode.
    geohash = fields.String(
        required=True,
        description="Name of source column containing geohash string",
    )
    # Target columns that will receive the decoded coordinates.
    latitude = fields.String(
        required=True,
        description="Name of target column for decoded latitude",
    )
    longitude = fields.String(
        required=True,
        description="Name of target column for decoded longitude",
    )
|
||||
|
||||
|
||||
class ChartDataGeohashEncodeOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
    """Geohash encode operation config."""

    # Source columns holding the coordinates to encode.
    latitude = fields.String(
        required=True,
        description="Name of source latitude column",
    )
    longitude = fields.String(
        required=True,
        description="Name of source longitude column",
    )
    # Target column that will receive the encoded geohash.
    geohash = fields.String(
        required=True,
        description="Name of target column for encoded geohash string",
    )
|
||||
|
||||
|
||||
class ChartDataGeodeticParseOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
    """Geodetic point string parsing operation config."""

    # Source column holding geodetic point strings ("lat, lon[, alt]").
    geodetic = fields.String(
        required=True,
        description="Name of source column containing geodetic point strings",
    )
    # Target columns for the parsed components.
    latitude = fields.String(
        required=True,
        description="Name of target column for decoded latitude",
    )
    longitude = fields.String(
        required=True,
        description="Name of target column for decoded longitude",
    )
    # Altitude target is optional; when omitted, altitude data is dropped.
    altitude = fields.String(
        required=False,
        description="Name of target column for decoded altitude. If omitted, "
        "altitude information in geodetic string is ignored.",
    )
|
||||
|
||||
|
||||
class ChartDataPostProcessingOperationSchema(Schema):
    """Config for a single post processing operation."""

    # The operation to apply; restricted to the implemented post processing
    # operations (kept alphabetically sorted).
    operation = fields.String(
        description="Post processing operation type",
        required=True,
        example="aggregate",
        validate=validate.OneOf(
            choices=(
                "aggregate",
                "geodetic_parse",
                "geohash_decode",
                "geohash_encode",
                "pivot",
                "rolling",
                "select",
                "sort",
            )
        ),
    )
|
||||
|
@ -638,4 +715,7 @@ CHART_DATA_SCHEMAS = (
|
|||
ChartDataRollingOptionsSchema,
|
||||
ChartDataSelectOptionsSchema,
|
||||
ChartDataSortOptionsSchema,
|
||||
ChartDataGeohashDecodeOptionsSchema,
|
||||
ChartDataGeohashEncodeOptionsSchema,
|
||||
ChartDataGeodeticParseOptionsSchema,
|
||||
)
|
||||
|
|
|
@ -15,10 +15,12 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import geohash as geohash_lib
|
||||
import numpy as np
|
||||
from flask_babel import gettext as _
|
||||
from geopy.point import Point
|
||||
from pandas import DataFrame, NamedAgg
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
|
@ -144,10 +146,7 @@ def _append_columns(
|
|||
:return: new DataFrame with combined data from `base_df` and `append_df`
|
||||
"""
|
||||
return base_df.assign(
|
||||
**{
|
||||
target: append_df[append_df.columns[idx]]
|
||||
for idx, target in enumerate(columns.values())
|
||||
}
|
||||
**{target: append_df[source] for source, target in columns.items()}
|
||||
)
|
||||
|
||||
|
||||
|
@ -323,9 +322,12 @@ def rolling( # pylint: disable=too-many-arguments
|
|||
return df
|
||||
|
||||
|
||||
@validate_column_args("columns", "exclude", "rename")
def select(
    df: DataFrame,
    columns: Optional[List[str]] = None,
    exclude: Optional[List[str]] = None,
    rename: Optional[Dict[str, str]] = None,
) -> DataFrame:
    """
    Only select a subset of columns in the original dataset. Can be useful for
    removing unnecessary intermediate results, or renaming the columns in the
    result.

    :param df: DataFrame on which the selection will be performed.
    :param columns: Columns which to select from the DataFrame, in the desired order.
                    If left undefined, all columns will be selected. If columns are
                    renamed, the original column name should be referenced here.
    :param exclude: columns to exclude from selection. As exclusion is applied
                    before renaming, the original column name should be
                    referenced here.
    :param rename: columns which to rename, mapping source column to target column.
                   For instance, `{'y': 'y2'}` will rename the column `y` to
                   `y2`.
    :return: Subset of columns in original DataFrame
    :raises QueryObjectValidationError: If the request is incorrect
    """
    # Bug fix: the decorator previously validated a nonexistent "drop" option,
    # so `exclude` columns were never validated and an unknown exclude column
    # surfaced as a raw pandas KeyError instead of a validation error.
    # Shallow copy so the rename below never mutates the caller's DataFrame.
    df_select = df.copy(deep=False)
    if columns:
        df_select = df_select[columns]
    if exclude:
        df_select = df_select.drop(exclude, axis=1)
    if rename is not None:
        df_select = df_select.rename(columns=rename)
    return df_select
|
||||
|
@ -350,6 +358,7 @@ def select(
|
|||
@validate_column_args("columns")
|
||||
def diff(df: DataFrame, columns: Dict[str, str], periods: int = 1,) -> DataFrame:
|
||||
"""
|
||||
Calculate row-by-row difference for select columns.
|
||||
|
||||
:param df: DataFrame on which the diff will be based.
|
||||
:param columns: columns on which to perform diff, mapping source column to
|
||||
|
@ -369,6 +378,7 @@ def diff(df: DataFrame, columns: Dict[str, str], periods: int = 1,) -> DataFrame
|
|||
@validate_column_args("columns")
|
||||
def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
|
||||
"""
|
||||
Calculate cumulative sum/product/min/max for select columns.
|
||||
|
||||
:param df: DataFrame on which the cumulative operation will be based.
|
||||
:param columns: columns on which to perform a cumulative operation, mapping source
|
||||
|
@ -377,7 +387,7 @@ def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
|
|||
`y2` based on cumulative values calculated from `y`, leaving the original
|
||||
column `y` unchanged.
|
||||
:param operator: cumulative operator, e.g. `sum`, `prod`, `min`, `max`
|
||||
:return:
|
||||
:return: DataFrame with cumulated columns
|
||||
"""
|
||||
df_cum = df[columns.keys()]
|
||||
operation = "cum" + operator
|
||||
|
@ -388,3 +398,92 @@ def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
|
|||
_("Invalid cumulative operator: %(operator)s", operator=operator)
|
||||
)
|
||||
return _append_columns(df, getattr(df_cum, operation)(), columns)
|
||||
|
||||
|
||||
def geohash_decode(
    df: DataFrame, geohash: str, longitude: str, latitude: str
) -> DataFrame:
    """
    Decode a geohash column into longitude and latitude

    :param df: DataFrame containing geohash data
    :param geohash: Name of source column containing geohash location.
    :param longitude: Name of new column to be created containing longitude.
    :param latitude: Name of new column to be created containing latitude.
    :return: DataFrame with decoded longitudes and latitudes
    """
    try:
        decoded = DataFrame()
        # geohash_lib.decode yields (lat, lon) pairs; unzip them into
        # two separate columns.
        pairs = df[geohash].apply(geohash_lib.decode)
        decoded["latitude"], decoded["longitude"] = zip(*pairs)
        return _append_columns(
            df, decoded, {"latitude": latitude, "longitude": longitude}
        )
    except ValueError:
        raise QueryObjectValidationError(_("Invalid geohash string"))
|
||||
|
||||
|
||||
def geohash_encode(
    df: DataFrame, geohash: str, longitude: str, latitude: str,
) -> DataFrame:
    """
    Encode longitude and latitude into geohash

    :param df: DataFrame containing longitude and latitude data
    :param geohash: Name of new column to be created containing geohash location.
    :param longitude: Name of source column containing longitude.
    :param latitude: Name of source column containing latitude.
    :return: DataFrame with encoded geohashes
    :raises QueryObjectValidationError: If the coordinates are invalid
    """
    try:
        # Copy so assigning the "geohash" column below does not write through
        # to a view of the caller's DataFrame (SettingWithCopy hazard).
        encode_df = df[[latitude, longitude]].copy()
        encode_df.columns = ["latitude", "longitude"]
        encode_df["geohash"] = encode_df.apply(
            lambda row: geohash_lib.encode(row["latitude"], row["longitude"]), axis=1,
        )
        return _append_columns(df, encode_df, {"geohash": geohash})
    except ValueError:
        # Bug fix: the exception was previously constructed but never raised,
        # so the function silently fell through and returned None.
        raise QueryObjectValidationError(_("Invalid longitude/latitude"))
|
||||
|
||||
|
||||
def geodetic_parse(
    df: DataFrame,
    geodetic: str,
    longitude: str,
    latitude: str,
    altitude: Optional[str] = None,
) -> DataFrame:
    """
    Parse a column of geodetic point strings as understood by
    [Geopy](https://geopy.readthedocs.io/en/stable/#geopy.point.Point).

    :param df: DataFrame containing geodetic point data
    :param geodetic: Name of source column containing geodetic point string.
    :param longitude: Name of new column to be created containing longitude.
    :param latitude: Name of new column to be created containing latitude.
    :param altitude: Name of new column to be created containing altitude.
    :return: DataFrame with decoded longitudes and latitudes
    """

    def _split_point(location: str) -> Tuple[float, float, float]:
        """Split one geodetic string into (latitude, longitude, altitude)."""
        point = Point(location)  # type: ignore
        return point[0], point[1], point[2]

    try:
        parsed = DataFrame()
        parts = df[geodetic].apply(_split_point)
        parsed["latitude"], parsed["longitude"], parsed["altitude"] = zip(*parts)
        # Altitude is only copied into the result when a target column name
        # was requested by the caller.
        target_columns = {"latitude": latitude, "longitude": longitude}
        if altitude:
            target_columns["altitude"] = altitude
        return _append_columns(df, parsed, target_columns)
    except ValueError:
        raise QueryObjectValidationError(_("Invalid geodetic string"))
|
||||
|
|
|
@ -119,3 +119,17 @@ timeseries_df = DataFrame(
|
|||
index=to_datetime(["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"]),
|
||||
data={"label": ["x", "y", "z", "q"], "y": [1.0, 2.0, 3.0, 4.0]},
|
||||
)
|
||||
|
||||
# Fixture holding matching geohash / lat-lon / geodetic-string representations
# of the same two locations, used by the geospatial post processing tests.
lonlat_df = DataFrame(
    data={
        "city": ["New York City", "Sydney"],
        "geohash": ["dr5regw3pg6f", "r3gx2u9qdevk"],
        "latitude": [40.71277496, -33.85598011],
        "longitude": [-74.00597306, 151.20666526],
        "altitude": [5.5, 0.012],
        "geodetic": [
            "40.71277496, -74.00597306, 5.5km",
            "-33.85598011, 151.20666526, 12m",
        ],
    }
)
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
# under the License.
|
||||
# isort:skip_file
|
||||
import math
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from pandas import Series
|
||||
|
||||
|
@ -24,7 +24,7 @@ from superset.exceptions import QueryObjectValidationError
|
|||
from superset.utils import pandas_postprocessing as proc
|
||||
|
||||
from .base_tests import SupersetTestCase
|
||||
from .fixtures.dataframes import categories_df, timeseries_df
|
||||
from .fixtures.dataframes import categories_df, lonlat_df, timeseries_df
|
||||
|
||||
|
||||
def series_to_list(series: Series) -> List[Any]:
|
||||
|
@ -43,6 +43,19 @@ def series_to_list(series: Series) -> List[Any]:
|
|||
]
|
||||
|
||||
|
||||
def round_floats(
    floats: List[Optional[float]], precision: int
) -> List[Optional[float]]:
    """
    Round a list of optional floats to a certain precision.

    :param floats: floats to round; ``None`` entries are passed through
    :param precision: intended decimal precision
    :return: rounded floats, with ``None`` entries preserved
    """
    # Bug fix: the original guard `if val` was falsy for 0.0 and therefore
    # silently mapped legitimate zero values to None; only genuine None
    # entries should be passed through unrounded.
    return [None if val is None else round(val, precision) for val in floats]
|
||||
|
||||
|
||||
class PostProcessingTestCase(SupersetTestCase):
|
||||
def test_pivot(self):
|
||||
aggregates = {"idx_nulls": {"operator": "sum"}}
|
||||
|
@ -219,25 +232,40 @@ class PostProcessingTestCase(SupersetTestCase):
|
|||
post_df = proc.select(df=timeseries_df, columns=["label"])
|
||||
self.assertListEqual(post_df.columns.tolist(), ["label"])
|
||||
|
||||
# rename one column
|
||||
# rename and select one column
|
||||
post_df = proc.select(df=timeseries_df, columns=["y"], rename={"y": "y1"})
|
||||
self.assertListEqual(post_df.columns.tolist(), ["y1"])
|
||||
|
||||
# rename one and leave one unchanged
|
||||
post_df = proc.select(
|
||||
df=timeseries_df, columns=["label", "y"], rename={"y": "y1"}
|
||||
)
|
||||
post_df = proc.select(df=timeseries_df, rename={"y": "y1"})
|
||||
self.assertListEqual(post_df.columns.tolist(), ["label", "y1"])
|
||||
|
||||
# drop one column
|
||||
post_df = proc.select(df=timeseries_df, exclude=["label"])
|
||||
self.assertListEqual(post_df.columns.tolist(), ["y"])
|
||||
|
||||
# rename and drop one column
|
||||
post_df = proc.select(df=timeseries_df, rename={"y": "y1"}, exclude=["label"])
|
||||
self.assertListEqual(post_df.columns.tolist(), ["y1"])
|
||||
|
||||
# invalid columns
|
||||
self.assertRaises(
|
||||
QueryObjectValidationError,
|
||||
proc.select,
|
||||
df=timeseries_df,
|
||||
columns=["qwerty"],
|
||||
columns=["abc"],
|
||||
rename={"abc": "qwerty"},
|
||||
)
|
||||
|
||||
# select renamed column by new name
|
||||
self.assertRaises(
|
||||
QueryObjectValidationError,
|
||||
proc.select,
|
||||
df=timeseries_df,
|
||||
columns=["label_new"],
|
||||
rename={"label": "label_new"},
|
||||
)
|
||||
|
||||
def test_diff(self):
|
||||
# overwrite column
|
||||
post_df = proc.diff(df=timeseries_df, columns={"y": "y"})
|
||||
|
@ -288,3 +316,83 @@ class PostProcessingTestCase(SupersetTestCase):
|
|||
columns={"y": "y"},
|
||||
operator="abc",
|
||||
)
|
||||
|
||||
def test_geohash_decode(self):
|
||||
# decode lon/lat from geohash
|
||||
post_df = proc.geohash_decode(
|
||||
df=lonlat_df[["city", "geohash"]],
|
||||
geohash="geohash",
|
||||
latitude="latitude",
|
||||
longitude="longitude",
|
||||
)
|
||||
self.assertListEqual(
|
||||
sorted(post_df.columns.tolist()),
|
||||
sorted(["city", "geohash", "latitude", "longitude"]),
|
||||
)
|
||||
self.assertListEqual(
|
||||
round_floats(series_to_list(post_df["longitude"]), 6),
|
||||
round_floats(series_to_list(lonlat_df["longitude"]), 6),
|
||||
)
|
||||
self.assertListEqual(
|
||||
round_floats(series_to_list(post_df["latitude"]), 6),
|
||||
round_floats(series_to_list(lonlat_df["latitude"]), 6),
|
||||
)
|
||||
|
||||
def test_geohash_encode(self):
|
||||
# encode lon/lat into geohash
|
||||
post_df = proc.geohash_encode(
|
||||
df=lonlat_df[["city", "latitude", "longitude"]],
|
||||
latitude="latitude",
|
||||
longitude="longitude",
|
||||
geohash="geohash",
|
||||
)
|
||||
self.assertListEqual(
|
||||
sorted(post_df.columns.tolist()),
|
||||
sorted(["city", "geohash", "latitude", "longitude"]),
|
||||
)
|
||||
self.assertListEqual(
|
||||
series_to_list(post_df["geohash"]), series_to_list(lonlat_df["geohash"]),
|
||||
)
|
||||
|
||||
def test_geodetic_parse(self):
|
||||
# parse geodetic string with altitude into lon/lat/altitude
|
||||
post_df = proc.geodetic_parse(
|
||||
df=lonlat_df[["city", "geodetic"]],
|
||||
geodetic="geodetic",
|
||||
latitude="latitude",
|
||||
longitude="longitude",
|
||||
altitude="altitude",
|
||||
)
|
||||
self.assertListEqual(
|
||||
sorted(post_df.columns.tolist()),
|
||||
sorted(["city", "geodetic", "latitude", "longitude", "altitude"]),
|
||||
)
|
||||
self.assertListEqual(
|
||||
series_to_list(post_df["longitude"]),
|
||||
series_to_list(lonlat_df["longitude"]),
|
||||
)
|
||||
self.assertListEqual(
|
||||
series_to_list(post_df["latitude"]), series_to_list(lonlat_df["latitude"]),
|
||||
)
|
||||
self.assertListEqual(
|
||||
series_to_list(post_df["altitude"]), series_to_list(lonlat_df["altitude"]),
|
||||
)
|
||||
|
||||
# parse geodetic string into lon/lat
|
||||
post_df = proc.geodetic_parse(
|
||||
df=lonlat_df[["city", "geodetic"]],
|
||||
geodetic="geodetic",
|
||||
latitude="latitude",
|
||||
longitude="longitude",
|
||||
)
|
||||
self.assertListEqual(
|
||||
sorted(post_df.columns.tolist()),
|
||||
sorted(["city", "geodetic", "latitude", "longitude"]),
|
||||
)
|
||||
self.assertListEqual(
|
||||
series_to_list(post_df["longitude"]),
|
||||
series_to_list(lonlat_df["longitude"]),
|
||||
)
|
||||
self.assertListEqual(
|
||||
series_to_list(post_df["latitude"]), series_to_list(lonlat_df["latitude"]),
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue