This commit is contained in:
Balthazar Rouberol 2024-05-05 02:17:19 -03:00 committed by GitHub
commit 94abe3c1e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 137 additions and 1 deletions

View File

@ -886,6 +886,12 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
Parse a row or array column
:param result: list tracking the results
"""
# We remove the (key_type, result_type) annotations for maps and arrays,
# as these types trip up the parser, and they are not reflected in SQLAlchemy
# column types
parent_data_type = re.sub(r"array\(\w+\)", "array", parent_data_type)
parent_data_type = re.sub(r"map\(\w+,\s*\w+\)", "map", parent_data_type)
formatted_parent_column_name = parent_column_name
# Quote the column name if there is a space
if " " in parent_column_name:
@ -942,7 +948,10 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
# the stack. We have run across a structural data type within the
# overall structural data type. Otherwise, we have completely parsed
# through the entire structural data type and can move on.
if not (inner_type.endswith("array") or inner_type.endswith("row")):
if (
not (inner_type.endswith("array") or inner_type.endswith("row"))
and stack
):
stack.pop()
# We have an array of row objects (i.e. array(row(...)))
elif inner_type in ("array", "row"):

View File

@ -155,3 +155,130 @@ def test_where_latest_partition(
)
assert str(actual) == expected
@pytest.mark.parametrize(
"parent_data_type, expected",
[
(
"map(varchar, varchar)",
[
{
"column_name": ".test_column",
"is_dttm": None,
"name": ".test_column",
"type": "MAP",
"type_generic": None,
}
],
),
(
"array(string)",
[
{
"column_name": "test_column",
"is_dttm": None,
"name": "test_column",
"type": "ARRAY",
"type_generic": None,
}
],
),
(
"map(string, array(string))",
[
{
"column_name": ".test_column",
"is_dttm": None,
"name": ".test_column",
"type": "MAP",
"type_generic": None,
}
],
),
(
'row("protocol" varchar, "status" integer)',
[
{
"column_name": "test_column",
"is_dttm": None,
"name": "test_column",
"type": "ROW",
"type_generic": None,
},
{
"column_name": 'test_column."protocol"',
"is_dttm": None,
"name": 'test_column."protocol"',
"type": "VARCHAR",
"type_generic": None,
},
{
"column_name": 'test_column."status"',
"is_dttm": None,
"name": 'test_column."status"',
"type": "INTEGER",
"type_generic": None,
},
],
),
(
'row("protocol" varchar, "request_headers" map(varchar, varchar))',
[
{
"column_name": "test_column",
"is_dttm": None,
"name": "test_column",
"type": "ROW",
"type_generic": None,
},
{
"column_name": 'test_column."protocol"',
"is_dttm": None,
"name": 'test_column."protocol"',
"type": "VARCHAR",
"type_generic": None,
},
{
"column_name": 'test_column."request_headers"',
"is_dttm": None,
"name": 'test_column."request_headers"',
"type": "MAP",
"type_generic": None,
},
],
),
(
'row("protocol" varchar, "request_headers" map(varchar, array(string)))',
[
{
"column_name": "test_column",
"is_dttm": None,
"name": "test_column",
"type": "ROW",
"type_generic": None,
},
{
"column_name": 'test_column."protocol"',
"is_dttm": None,
"name": 'test_column."protocol"',
"type": "VARCHAR",
"type_generic": None,
},
{
"column_name": 'test_column."request_headers"',
"is_dttm": None,
"name": 'test_column."request_headers"',
"type": "MAP",
"type_generic": None,
},
],
),
],
)
def test_parse_structural_column(parent_data_type, expected):
from superset.db_engine_specs.presto import PrestoEngineSpec as spec
accumulator = []
spec._parse_structural_column("test_column", parent_data_type, accumulator)
assert accumulator == expected