feat: Add geospatial post processing operations (#9661)

* feat: Add geospatial post processing operations

* Linting

* Refactor

* Add tests

* Improve docs

* Address comments

* fix latitude/longitude mixup

* fix: bad refactor by pycharm
Ville Brofeldt 2020-04-28 20:15:16 +03:00 committed by GitHub
parent c474ea848a
commit a52cfcd234
5 changed files with 322 additions and 21 deletions

View File

@@ -34,7 +34,7 @@ flask-talisman==0.7.0 # via apache-superset (setup.py)
flask-wtf==0.14.2 # via apache-superset (setup.py), flask-appbuilder
flask==1.1.1 # via apache-superset (setup.py), flask-appbuilder, flask-babel, flask-caching, flask-compress, flask-jwt-extended, flask-login, flask-migrate, flask-openid, flask-sqlalchemy, flask-wtf
geographiclib==1.50 # via geopy
geopy==1.20.0 # via apache-superset (setup.py)
geopy==1.21.0 # via apache-superset (setup.py)
gunicorn==20.0.4 # via apache-superset (setup.py)
humanize==0.5.1 # via apache-superset (setup.py)
importlib-metadata==1.4.0 # via jsonschema, kombu

View File

@@ -265,15 +265,23 @@ class ChartDataSelectOptionsSchema(ChartDataPostProcessingOperationOptionsSchema
columns = fields.List(
fields.String(),
description="Columns which to select from the input data, in the desired "
"order. If columns are renamed, the old column name should be "
"order. If columns are renamed, the original column name should be "
"referenced here.",
example=["country", "gender", "age"],
required=False,
)
exclude = fields.List(
fields.String(),
description="Columns to exclude from selection.",
example=["my_temp_column"],
required=False,
)
rename = fields.List(
fields.Dict(),
description="columns which to rename, mapping source column to target column. "
"For instance, `{'y': 'y2'}` will rename the column `y` to `y2`.",
example=[{"age": "average_age"}],
required=False,
)
@@ -335,12 +343,81 @@ class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema)
aggregates = ChartDataAggregateConfigField()
class ChartDataGeohashDecodeOptionsSchema(
ChartDataPostProcessingOperationOptionsSchema
):
"""
Geohash decode operation config.
"""
geohash = fields.String(
description="Name of source column containing geohash string", required=True,
)
latitude = fields.String(
description="Name of target column for decoded latitude", required=True,
)
longitude = fields.String(
description="Name of target column for decoded longitude", required=True,
)
class ChartDataGeohashEncodeOptionsSchema(
ChartDataPostProcessingOperationOptionsSchema
):
"""
Geohash encode operation config.
"""
latitude = fields.String(
description="Name of source latitude column", required=True,
)
longitude = fields.String(
description="Name of source longitude column", required=True,
)
geohash = fields.String(
description="Name of target column for encoded geohash string", required=True,
)
class ChartDataGeodeticParseOptionsSchema(
ChartDataPostProcessingOperationOptionsSchema
):
"""
Geodetic point string parsing operation config.
"""
geodetic = fields.String(
description="Name of source column containing geodetic point strings",
required=True,
)
latitude = fields.String(
description="Name of target column for decoded latitude", required=True,
)
longitude = fields.String(
description="Name of target column for decoded longitude", required=True,
)
altitude = fields.String(
description="Name of target column for decoded altitude. If omitted, "
"altitude information in geodetic string is ignored.",
required=False,
)
class ChartDataPostProcessingOperationSchema(Schema):
operation = fields.String(
description="Post processing operation type",
required=True,
validate=validate.OneOf(
choices=("aggregate", "pivot", "rolling", "select", "sort")
choices=(
"aggregate",
"geodetic_parse",
"geohash_decode",
"geohash_encode",
"pivot",
"rolling",
"select",
"sort",
)
),
example="aggregate",
)
@@ -638,4 +715,7 @@ CHART_DATA_SCHEMAS = (
ChartDataRollingOptionsSchema,
ChartDataSelectOptionsSchema,
ChartDataSortOptionsSchema,
ChartDataGeohashDecodeOptionsSchema,
ChartDataGeohashEncodeOptionsSchema,
ChartDataGeodeticParseOptionsSchema,
)
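For context, a chart data request would exercise these schemas through its post_processing list. A minimal sketch of such a payload, assuming (as elsewhere in the chart data API) that each entry pairs the operation name with an options dict validated by the corresponding schema:

post_processing = [
    {
        "operation": "geohash_decode",
        "options": {
            "geohash": "geohash",
            "latitude": "latitude",
            "longitude": "longitude",
        },
    },
]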

View File

@@ -15,10 +15,12 @@
# specific language governing permissions and limitations
# under the License.
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import geohash as geohash_lib
import numpy as np
from flask_babel import gettext as _
from geopy.point import Point
from pandas import DataFrame, NamedAgg
from superset.exceptions import QueryObjectValidationError
@@ -144,10 +146,7 @@ def _append_columns(
:return: new DataFrame with combined data from `base_df` and `append_df`
"""
return base_df.assign(
**{
target: append_df[append_df.columns[idx]]
for idx, target in enumerate(columns.values())
}
**{target: append_df[source] for source, target in columns.items()}
)
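The refactored comprehension matches source columns in append_df by name rather than by position. A toy sketch of the resulting semantics (made-up frames, not part of the commit):

import pandas as pd

base_df = pd.DataFrame({"city": ["New York City"]})
append_df = pd.DataFrame({"latitude": [40.7128], "longitude": [-74.0060]})
columns = {"latitude": "lat", "longitude": "lon"}  # {source: target}

combined = base_df.assign(
    **{target: append_df[source] for source, target in columns.items()}
)
# combined now has columns: city, lat, lon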
@@ -323,9 +322,12 @@ def rolling( # pylint: disable=too-many-arguments
return df
@validate_column_args("columns", "rename")
@validate_column_args("columns", "drop", "rename")
def select(
df: DataFrame, columns: List[str], rename: Optional[Dict[str, str]] = None
df: DataFrame,
columns: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
rename: Optional[Dict[str, str]] = None,
) -> DataFrame:
"""
Only select a subset of columns in the original dataset. Can be useful for
@@ -333,15 +335,21 @@ select(
:param df: DataFrame from which to select columns.
:param columns: Columns which to select from the DataFrame, in the desired order.
If columns are renamed, the old column name should be referenced
here.
If left undefined, all columns will be selected. If columns are
renamed, the original column name should be referenced here.
:param exclude: columns to exclude from selection. If columns are renamed, the new
column name should be referenced here.
:param rename: columns which to rename, mapping source column to target column.
For instance, `{'y': 'y2'}` will rename the column `y` to
`y2`.
:return: Subset of columns in original DataFrame
:raises QueryObjectValidationError: If the request is incorrect
"""
df_select = df[columns]
df_select = df.copy(deep=False)
if columns:
df_select = df_select[columns]
if exclude:
df_select = df_select.drop(exclude, axis=1)
if rename is not None:
df_select = df_select.rename(columns=rename)
return df_select
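A quick usage sketch of the extended signature, using a standalone toy frame that mirrors the timeseries fixture in the tests below:

import pandas as pd
from superset.utils import pandas_postprocessing as proc

df = pd.DataFrame({"label": ["x", "y"], "y": [1.0, 2.0]})
proc.select(df=df, columns=["label"])                  # keep only `label`
proc.select(df=df, exclude=["label"])                  # drop `label`, keep the rest
proc.select(df=df, columns=["y"], rename={"y": "y1"})  # select `y`, rename it to `y1`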
@@ -350,6 +358,7 @@ def select(
@validate_column_args("columns")
def diff(df: DataFrame, columns: Dict[str, str], periods: int = 1,) -> DataFrame:
"""
Calculate row-by-row difference for select columns.
:param df: DataFrame on which the diff will be based.
:param columns: columns on which to perform diff, mapping source column to
@@ -369,6 +378,7 @@ def diff(df: DataFrame, columns: Dict[str, str], periods: int = 1,) -> DataFrame
@validate_column_args("columns")
def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
"""
Calculate cumulative sum/product/min/max for select columns.
:param df: DataFrame on which the cumulative operation will be based.
:param columns: columns on which to perform a cumulative operation, mapping source
@@ -377,7 +387,7 @@ def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
`y2` based on cumulative values calculated from `y`, leaving the original
column `y` unchanged.
:param operator: cumulative operator, e.g. `sum`, `prod`, `min`, `max`
:return:
:return: DataFrame with cumulative columns
"""
df_cum = df[columns.keys()]
operation = "cum" + operator
@@ -388,3 +398,92 @@ def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
_("Invalid cumulative operator: %(operator)s", operator=operator)
)
return _append_columns(df, getattr(df_cum, operation)(), columns)
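The mapping semantics described in the cum docstring, as a standalone sketch (toy frame, using the module import from the tests):

import pandas as pd
from superset.utils import pandas_postprocessing as proc

df = pd.DataFrame({"y": [1.0, 2.0, 3.0]})
post_df = proc.cum(df=df, columns={"y": "y2"}, operator="sum")
# `y` is left unchanged; `y2` holds the cumulative sums 1.0, 3.0, 6.0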
def geohash_decode(
df: DataFrame, geohash: str, longitude: str, latitude: str
) -> DataFrame:
"""
Decode a geohash column into longitude and latitude
:param df: DataFrame containing geohash data
:param geohash: Name of source column containing geohash location.
:param longitude: Name of new column to be created containing longitude.
:param latitude: Name of new column to be created containing latitude.
:return: DataFrame with decoded longitudes and latitudes
"""
try:
lonlat_df = DataFrame()
lonlat_df["latitude"], lonlat_df["longitude"] = zip(
*df[geohash].apply(geohash_lib.decode)
)
return _append_columns(
df, lonlat_df, {"latitude": latitude, "longitude": longitude}
)
except ValueError:
raise QueryObjectValidationError(_("Invalid geohash string"))
def geohash_encode(
df: DataFrame, geohash: str, longitude: str, latitude: str,
) -> DataFrame:
"""
Encode longitude and latitude into geohash
:param df: DataFrame containing longitude and latitude data
:param geohash: Name of new column to be created containing geohash location.
:param longitude: Name of source column containing longitude.
:param latitude: Name of source column containing latitude.
:return: DataFrame with encoded geohash string
"""
try:
encode_df = df[[latitude, longitude]]
encode_df.columns = ["latitude", "longitude"]
encode_df["geohash"] = encode_df.apply(
lambda row: geohash_lib.encode(row["latitude"], row["longitude"]), axis=1,
)
return _append_columns(df, encode_df, {"geohash": geohash})
except ValueError:
QueryObjectValidationError(_("Invalid longitude/latitude"))
def geodetic_parse(
df: DataFrame,
geodetic: str,
longitude: str,
latitude: str,
altitude: Optional[str] = None,
) -> DataFrame:
"""
Parse a column containing a geodetic point string (see
[Geopy](https://geopy.readthedocs.io/en/stable/#geopy.point.Point)).
:param df: DataFrame containing geodetic point data
:param geodetic: Name of source column containing geodetic point string.
:param longitude: Name of new column to be created containing longitude.
:param latitude: Name of new column to be created containing latitude.
:param altitude: Name of new column to be created containing altitude.
:return: DataFrame with decoded longitudes, latitudes and altitudes
"""
def _parse_location(location: str) -> Tuple[float, float, float]:
"""
Parse a string containing a geodetic point and return latitude, longitude
and altitude
"""
point = Point(location) # type: ignore
return point[0], point[1], point[2]
try:
geodetic_df = DataFrame()
(
geodetic_df["latitude"],
geodetic_df["longitude"],
geodetic_df["altitude"],
) = zip(*df[geodetic].apply(_parse_location))
columns = {"latitude": latitude, "longitude": longitude}
if altitude:
columns["altitude"] = altitude
return _append_columns(df, geodetic_df, columns)
except ValueError:
raise QueryObjectValidationError(_("Invalid geodetic string"))

View File

@@ -119,3 +119,17 @@ timeseries_df = DataFrame(
index=to_datetime(["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"]),
data={"label": ["x", "y", "z", "q"], "y": [1.0, 2.0, 3.0, 4.0]},
)
lonlat_df = DataFrame(
{
"city": ["New York City", "Sydney"],
"geohash": ["dr5regw3pg6f", "r3gx2u9qdevk"],
"latitude": [40.71277496, -33.85598011],
"longitude": [-74.00597306, 151.20666526],
"altitude": [5.5, 0.012],
"geodetic": [
"40.71277496, -74.00597306, 5.5km",
"-33.85598011, 151.20666526, 12m",
],
}
)

View File

@@ -16,7 +16,7 @@
# under the License.
# isort:skip_file
import math
from typing import Any, List
from typing import Any, List, Optional
from pandas import Series
@@ -24,7 +24,7 @@ from superset.exceptions import QueryObjectValidationError
from superset.utils import pandas_postprocessing as proc
from .base_tests import SupersetTestCase
from .fixtures.dataframes import categories_df, timeseries_df
from .fixtures.dataframes import categories_df, lonlat_df, timeseries_df
def series_to_list(series: Series) -> List[Any]:
@@ -43,6 +43,19 @@ def series_to_list(series: Series) -> List[Any]:
]
def round_floats(
floats: List[Optional[float]], precision: int
) -> List[Optional[float]]:
"""
Round list of floats to certain precision
:param floats: floats to round
:param precision: intended decimal precision
:return: rounded floats
"""
return [round(val, precision) if val is not None else None for val in floats]
class PostProcessingTestCase(SupersetTestCase):
def test_pivot(self):
aggregates = {"idx_nulls": {"operator": "sum"}}
@@ -219,25 +232,40 @@ class PostProcessingTestCase(SupersetTestCase):
post_df = proc.select(df=timeseries_df, columns=["label"])
self.assertListEqual(post_df.columns.tolist(), ["label"])
# rename one column
# rename and select one column
post_df = proc.select(df=timeseries_df, columns=["y"], rename={"y": "y1"})
self.assertListEqual(post_df.columns.tolist(), ["y1"])
# rename one and leave one unchanged
post_df = proc.select(
df=timeseries_df, columns=["label", "y"], rename={"y": "y1"}
)
post_df = proc.select(df=timeseries_df, rename={"y": "y1"})
self.assertListEqual(post_df.columns.tolist(), ["label", "y1"])
# drop one column
post_df = proc.select(df=timeseries_df, exclude=["label"])
self.assertListEqual(post_df.columns.tolist(), ["y"])
# rename and drop one column
post_df = proc.select(df=timeseries_df, rename={"y": "y1"}, exclude=["label"])
self.assertListEqual(post_df.columns.tolist(), ["y1"])
# invalid columns
self.assertRaises(
QueryObjectValidationError,
proc.select,
df=timeseries_df,
columns=["qwerty"],
columns=["abc"],
rename={"abc": "qwerty"},
)
# select renamed column by new name
self.assertRaises(
QueryObjectValidationError,
proc.select,
df=timeseries_df,
columns=["label_new"],
rename={"label": "label_new"},
)
def test_diff(self):
# overwrite column
post_df = proc.diff(df=timeseries_df, columns={"y": "y"})
@@ -288,3 +316,83 @@ class PostProcessingTestCase(SupersetTestCase):
columns={"y": "y"},
operator="abc",
)
def test_geohash_decode(self):
# decode lon/lat from geohash
post_df = proc.geohash_decode(
df=lonlat_df[["city", "geohash"]],
geohash="geohash",
latitude="latitude",
longitude="longitude",
)
self.assertListEqual(
sorted(post_df.columns.tolist()),
sorted(["city", "geohash", "latitude", "longitude"]),
)
self.assertListEqual(
round_floats(series_to_list(post_df["longitude"]), 6),
round_floats(series_to_list(lonlat_df["longitude"]), 6),
)
self.assertListEqual(
round_floats(series_to_list(post_df["latitude"]), 6),
round_floats(series_to_list(lonlat_df["latitude"]), 6),
)
def test_geohash_encode(self):
# encode lon/lat into geohash
post_df = proc.geohash_encode(
df=lonlat_df[["city", "latitude", "longitude"]],
latitude="latitude",
longitude="longitude",
geohash="geohash",
)
self.assertListEqual(
sorted(post_df.columns.tolist()),
sorted(["city", "geohash", "latitude", "longitude"]),
)
self.assertListEqual(
series_to_list(post_df["geohash"]), series_to_list(lonlat_df["geohash"]),
)
def test_geodetic_parse(self):
# parse geodetic string with altitude into lon/lat/altitude
post_df = proc.geodetic_parse(
df=lonlat_df[["city", "geodetic"]],
geodetic="geodetic",
latitude="latitude",
longitude="longitude",
altitude="altitude",
)
self.assertListEqual(
sorted(post_df.columns.tolist()),
sorted(["city", "geodetic", "latitude", "longitude", "altitude"]),
)
self.assertListEqual(
series_to_list(post_df["longitude"]),
series_to_list(lonlat_df["longitude"]),
)
self.assertListEqual(
series_to_list(post_df["latitude"]), series_to_list(lonlat_df["latitude"]),
)
self.assertListEqual(
series_to_list(post_df["altitude"]), series_to_list(lonlat_df["altitude"]),
)
# parse geodetic string into lon/lat
post_df = proc.geodetic_parse(
df=lonlat_df[["city", "geodetic"]],
geodetic="geodetic",
latitude="latitude",
longitude="longitude",
)
self.assertListEqual(
sorted(post_df.columns.tolist()),
sorted(["city", "geodetic", "latitude", "longitude"]),
)
self.assertListEqual(
series_to_list(post_df["longitude"]),
series_to_list(lonlat_df["longitude"]),
)
self.assertListEqual(
series_to_list(post_df["latitude"]), series_to_list(lonlat_df["latitude"]),
)