diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 13e574d7cf..345d5a4ec6 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=C,R,W
-from collections import OrderedDict
+from collections import defaultdict, deque, OrderedDict
 from datetime import datetime
 from distutils.version import StrictVersion
 import logging
@@ -58,6 +58,51 @@ pandas_dtype_map = {
 }
 
 
+def get_children(column: Dict[str, str]) -> List[Dict[str, str]]:
+    """
+    Get the children of a complex Presto type (row or array).
+
+    For arrays, we return a single list with the base type:
+
+        >>> get_children(dict(name="a", type="ARRAY(BIGINT)"))
+        [{'name': 'a', 'type': 'BIGINT'}]
+
+    For rows, we return a list of the columns:
+
+        >>> get_children(dict(name="a", type="ROW(BIGINT,FOO VARCHAR)"))
+        [{'name': 'a._col0', 'type': 'BIGINT'}, {'name': 'a.foo', 'type': 'VARCHAR'}]
+
+    :param column: dictionary representing a Presto column
+    :return: list of dictionaries representing children columns
+    """
+    pattern = re.compile(r"(?P<type>\w+)\((?P<children>.*)\)")
+    match = pattern.match(column["type"])
+    if not match:
+        raise Exception(f"Unable to parse column type {column['type']}")
+
+    group = match.groupdict()
+    type_ = group["type"].upper()
+    children_type = group["children"]
+    if type_ == "ARRAY":
+        return [{"name": column["name"], "type": children_type}]
+    elif type_ == "ROW":
+        nameless_columns = 0
+        columns = []
+        for child in utils.split(children_type, ","):
+            parts = list(utils.split(child.strip(), " "))
+            if len(parts) == 2:
+                name, type_ = parts
+                name = name.strip('"')
+            else:
+                name = f"_col{nameless_columns}"
+                type_ = parts[0]
+                nameless_columns += 1
+            columns.append({"name": f"{column['name']}.{name.lower()}", "type": type_})
+        return columns
+    else:
+        raise Exception(f"Unknown type {type_}!")
+
+
 class PrestoEngineSpec(BaseEngineSpec):
     engine = "presto"
 
@@ -846,43 +891,79 @@ class PrestoEngineSpec(BaseEngineSpec):
         if not is_feature_enabled("PRESTO_EXPAND_DATA"):
             return columns, data, []
 
+        # insert a custom column that tracks the original row
+        columns.insert(0, {"name": "__row_id", "type": "BIGINT"})
+        for i, row in enumerate(data):
+            row["__row_id"] = i
+
+        # process each column, unnesting ARRAY types and expanding ROW types into new columns
+        to_process = deque((column, 0) for column in columns)
         all_columns: List[dict] = []
-        # Get the list of all columns (selected fields and their nested fields)
-        for column in columns:
-            if column["type"].startswith("ARRAY") or column["type"].startswith("ROW"):
-                cls._parse_structural_column(
-                    column["name"], column["type"].lower(), all_columns
-                )
-            else:
+        expanded_columns = []
+        current_array_level = None
+        while to_process:
+            column, level = to_process.popleft()
+            if column["name"] not in [column["name"] for column in all_columns]:
                 all_columns.append(column)
 
-        # Build graphs where the root node is a row or array and its children are that
-        # column's nested fields
-        row_column_hierarchy, array_column_hierarchy, expanded_columns = cls._create_row_and_array_hierarchy(
-            columns
-        )
+            # When unnesting arrays we need to keep track of how many extra rows
+            # were added for each original row. This is necessary when we expand
+            # multiple arrays, so that arrays after the first reuse the rows
+            # added by the first. Every time we change a level in the nested
+            # arrays we reinitialize this.
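+            # For example, two parallel arrays [1, 2] and [3, 4, 5] in the same
+            # original row are both unnested at level 1: the first adds one
+            # extra row, and the second reuses that row and adds just one more
+            # instead of appending two fresh rows of its own.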
+            if level != current_array_level:
+                unnested_rows: Dict[int, int] = defaultdict(int)
+                current_array_level = level
 
-        # Pull out a row's nested fields and their values into separate columns
-        ordered_row_columns = row_column_hierarchy.keys()
-        for datum in data:
-            for row_column in ordered_row_columns:
-                cls._expand_row_data(datum, row_column, row_column_hierarchy)
+            name = column["name"]
 
-        while array_column_hierarchy:
-            array_columns = list(array_column_hierarchy.keys())
-            # Determine what columns are ready to be processed.
-            array_columns_to_process, unprocessed_array_columns = cls._split_array_columns_by_process_state(
-                array_columns, array_column_hierarchy, data[0]
-            )
-            all_array_data = cls._process_array_data(
-                data, all_columns, array_column_hierarchy
-            )
-            # Consolidate the original data set and the expanded array data
-            cls._consolidate_array_data_into_data(data, all_array_data)
-            # Remove processed array columns from the graph
-            cls._remove_processed_array_columns(
-                unprocessed_array_columns, array_column_hierarchy
-            )
+            if column["type"].startswith("ARRAY("):
+                # keep processing array children; we append to the right so that
+                # multiple nested arrays are processed breadth-first
+                to_process.append((get_children(column)[0], level + 1))
+
+                # unnest array objects into new rows
+                i = 0
+                while i < len(data):
+                    row = data[i]
+                    values = row.get(name)
+                    if values:
+                        # how many extra rows do we need to unnest the data?
+                        extra_rows = len(values) - 1
+
+                        # how many rows were already added for this row?
+                        current_unnested_rows = unnested_rows[i]
+
+                        # add any necessary rows
+                        missing = extra_rows - current_unnested_rows
+                        for _ in range(missing):
+                            data.insert(i + current_unnested_rows + 1, {})
+                            unnested_rows[i] += 1
+
+                        # unnest array into rows
+                        for j, value in enumerate(values):
+                            data[i + j][name] = value
+
+                        # skip newly unnested rows
+                        i += unnested_rows[i]
+
+                    i += 1
+
+            if column["type"].startswith("ROW("):
+                # expand columns; we append them to the left so they are added
+                # immediately after the parent
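+                # e.g. a column "r" of type ROW(A VARCHAR, B BIGINT) expands
+                # into the columns "r.a" and "r.b", placed right after "r"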
+                expanded = get_children(column)
+                to_process.extendleft((column, level) for column in expanded)
+                expanded_columns.extend(expanded)
+
+                # expand row objects into new columns
+                for row in data:
+                    for value, col in zip(row.get(name) or [], expanded):
+                        row[col["name"]] = value
+
+        data = [
+            {k["name"]: row.get(k["name"], "") for k in all_columns} for row in data
+        ]
 
         return all_columns, data, expanded_columns
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 55ada27879..7a74eb4333 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -32,7 +32,7 @@ import signal
 import smtplib
 from time import struct_time
 import traceback
-from typing import List, NamedTuple, Optional, Tuple, Union
+from typing import Iterator, List, NamedTuple, Optional, Tuple, Union
 from urllib.parse import unquote_plus
 import uuid
 import zlib
@@ -1194,3 +1194,35 @@ class DatasourceName(NamedTuple):
 def get_stacktrace():
     if current_app.config.get("SHOW_STACKTRACE"):
         return traceback.format_exc()
+
+
+def split(
+    s: str, delimiter: str = " ", quote: str = '"', escaped_quote: str = r"\""
+) -> Iterator[str]:
+    """
+    A split function that is aware of quotes and parentheses.
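+
+    For example, split('a,(b,c),"d,e"', delimiter=",") yields the three
+    tokens 'a', '(b,c)' and '"d,e"', since delimiters inside parentheses
+    or quotes do not split.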
"a2", }, { + "__row_id": 1, "row_column": ["b1", ["b2"]], "row_column.nested_obj1": "b1", "row_column.nested_row": ["b2"], "row_column.nested_row.nested_obj2": "b2", }, ] + expected_expanded_cols = [ {"name": "row_column.nested_obj1", "type": "VARCHAR"}, - {"name": "row_column.nested_row", "type": "ROW"}, + {"name": "row_column.nested_row", "type": "ROW(NESTED_OBJ2 VARCHAR)"}, {"name": "row_column.nested_row.nested_obj2", "type": "VARCHAR"}, ] self.assertEqual(actual_cols, expected_cols) @@ -732,63 +772,81 @@ class DbEngineSpecsTestCase(SupersetTestCase): cols, data ) expected_cols = [ + {"name": "__row_id", "type": "BIGINT"}, {"name": "int_column", "type": "BIGINT"}, - {"name": "array_column", "type": "ARRAY"}, - {"name": "array_column.nested_array", "type": "ARRAY"}, + { + "name": "array_column", + "type": "ARRAY(ROW(NESTED_ARRAY ARRAY(ROW(NESTED_OBJ VARCHAR))))", + }, + { + "name": "array_column.nested_array", + "type": "ARRAY(ROW(NESTED_OBJ VARCHAR))", + }, {"name": "array_column.nested_array.nested_obj", "type": "VARCHAR"}, ] expected_data = [ { - "int_column": 1, - "array_column": [[[["a"], ["b"]]], [[["c"], ["d"]]]], - "array_column.nested_array": [["a"], ["b"]], + "__row_id": 0, + "array_column": [[["a"], ["b"]]], + "array_column.nested_array": ["a"], "array_column.nested_array.nested_obj": "a", + "int_column": 1, }, { - "int_column": "", + "__row_id": "", "array_column": "", - "array_column.nested_array": "", + "array_column.nested_array": ["b"], "array_column.nested_array.nested_obj": "b", + "int_column": "", }, { - "int_column": "", - "array_column": "", - "array_column.nested_array": [["c"], ["d"]], + "__row_id": "", + "array_column": [[["c"], ["d"]]], + "array_column.nested_array": ["c"], "array_column.nested_array.nested_obj": "c", + "int_column": "", }, { - "int_column": "", + "__row_id": "", "array_column": "", - "array_column.nested_array": "", + "array_column.nested_array": ["d"], "array_column.nested_array.nested_obj": "d", + "int_column": "", }, { - "int_column": 2, - "array_column": [[[["e"], ["f"]]], [[["g"], ["h"]]]], - "array_column.nested_array": [["e"], ["f"]], + "__row_id": 1, + "array_column": [[["e"], ["f"]]], + "array_column.nested_array": ["e"], "array_column.nested_array.nested_obj": "e", + "int_column": 2, }, { - "int_column": "", + "__row_id": "", "array_column": "", - "array_column.nested_array": "", + "array_column.nested_array": ["f"], "array_column.nested_array.nested_obj": "f", + "int_column": "", }, { - "int_column": "", - "array_column": "", - "array_column.nested_array": [["g"], ["h"]], + "__row_id": "", + "array_column": [[["g"], ["h"]]], + "array_column.nested_array": ["g"], "array_column.nested_array.nested_obj": "g", + "int_column": "", }, { - "int_column": "", + "__row_id": "", "array_column": "", - "array_column.nested_array": "", + "array_column.nested_array": ["h"], "array_column.nested_array.nested_obj": "h", + "int_column": "", }, ] expected_expanded_cols = [ - {"name": "array_column.nested_array", "type": "ARRAY"}, + { + "name": "array_column.nested_array", + "type": "ARRAY(ROW(NESTED_OBJ VARCHAR))", + }, {"name": "array_column.nested_array.nested_obj", "type": "VARCHAR"}, ] self.assertEqual(actual_cols, expected_cols) diff --git a/tests/utils_tests.py b/tests/utils_tests.py index df11cf5cab..5efd44d993 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -45,6 +45,7 @@ from superset.utils.core import ( parse_js_uri_path_item, parse_past_timedelta, setup_cache, + split, validate_json, zlib_compress, zlib_decompress, @@ -832,6 
+            {
+                "__row_id": 0,
+                "array_column": 1,
+                "row_column": ["a"],
+                "row_column.nested_obj": "a",
+            },
+            {
+                "__row_id": "",
+                "array_column": 2,
+                "row_column": "",
+                "row_column.nested_obj": "",
+            },
+            {
+                "__row_id": "",
+                "array_column": 3,
+                "row_column": "",
+                "row_column.nested_obj": "",
+            },
+            {
+                "__row_id": 1,
+                "array_column": 4,
+                "row_column": ["b"],
+                "row_column.nested_obj": "b",
+            },
+            {
+                "__row_id": "",
+                "array_column": 5,
+                "row_column": "",
+                "row_column.nested_obj": "",
+            },
+            {
+                "__row_id": "",
+                "array_column": 6,
+                "row_column": "",
+                "row_column.nested_obj": "",
+            },
         ]
+
         expected_expanded_cols = [{"name": "row_column.nested_obj", "type": "VARCHAR"}]
         self.assertEqual(actual_cols, expected_cols)
         self.assertEqual(actual_data, expected_data)
@@ -677,7 +710,7 @@ class DbEngineSpecsTestCase(SupersetTestCase):
         cols = [
             {
                 "name": "row_column",
-                "type": "ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR)",
+                "type": "ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR))",
             }
         ]
         data = [{"row_column": ["a1", ["a2"]]}, {"row_column": ["b1", ["b2"]]}]
@@ -685,28 +718,35 @@ class DbEngineSpecsTestCase(SupersetTestCase):
             cols, data
         )
         expected_cols = [
-            {"name": "row_column", "type": "ROW"},
-            {"name": "row_column.nested_obj1", "type": "VARCHAR"},
-            {"name": "row_column.nested_row", "type": "ROW"},
+            {"name": "__row_id", "type": "BIGINT"},
+            {
+                "name": "row_column",
+                "type": "ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR))",
+            },
+            {"name": "row_column.nested_row", "type": "ROW(NESTED_OBJ2 VARCHAR)"},
             {"name": "row_column.nested_row.nested_obj2", "type": "VARCHAR"},
+            {"name": "row_column.nested_obj1", "type": "VARCHAR"},
         ]
         expected_data = [
             {
+                "__row_id": 0,
                 "row_column": ["a1", ["a2"]],
                 "row_column.nested_obj1": "a1",
                 "row_column.nested_row": ["a2"],
                 "row_column.nested_row.nested_obj2": "a2",
             },
             {
+                "__row_id": 1,
                 "row_column": ["b1", ["b2"]],
                 "row_column.nested_obj1": "b1",
                 "row_column.nested_row": ["b2"],
                 "row_column.nested_row.nested_obj2": "b2",
             },
         ]
+
         expected_expanded_cols = [
             {"name": "row_column.nested_obj1", "type": "VARCHAR"},
-            {"name": "row_column.nested_row", "type": "ROW"},
+            {"name": "row_column.nested_row", "type": "ROW(NESTED_OBJ2 VARCHAR)"},
            {"name": "row_column.nested_row.nested_obj2", "type": "VARCHAR"},
         ]
         self.assertEqual(actual_cols, expected_cols)
@@ -732,63 +772,81 @@ class DbEngineSpecsTestCase(SupersetTestCase):
             cols, data
         )
         expected_cols = [
+            {"name": "__row_id", "type": "BIGINT"},
             {"name": "int_column", "type": "BIGINT"},
-            {"name": "array_column", "type": "ARRAY"},
-            {"name": "array_column.nested_array", "type": "ARRAY"},
+            {
+                "name": "array_column",
+                "type": "ARRAY(ROW(NESTED_ARRAY ARRAY(ROW(NESTED_OBJ VARCHAR))))",
+            },
+            {
+                "name": "array_column.nested_array",
+                "type": "ARRAY(ROW(NESTED_OBJ VARCHAR))",
+            },
             {"name": "array_column.nested_array.nested_obj", "type": "VARCHAR"},
         ]
         expected_data = [
             {
-                "int_column": 1,
-                "array_column": [[[["a"], ["b"]]], [[["c"], ["d"]]]],
-                "array_column.nested_array": [["a"], ["b"]],
+                "__row_id": 0,
+                "array_column": [[["a"], ["b"]]],
+                "array_column.nested_array": ["a"],
                 "array_column.nested_array.nested_obj": "a",
+                "int_column": 1,
             },
             {
-                "int_column": "",
+                "__row_id": "",
                 "array_column": "",
-                "array_column.nested_array": "",
+                "array_column.nested_array": ["b"],
                 "array_column.nested_array.nested_obj": "b",
+                "int_column": "",
             },
             {
-                "int_column": "",
-                "array_column": "",
-                "array_column.nested_array": [["c"], ["d"]],
+                "__row_id": "",
+                "array_column": [[["c"], ["d"]]],
+                "array_column.nested_array": ["c"],
                 "array_column.nested_array.nested_obj": "c",
+                "int_column": "",
             },
             {
-                "int_column": "",
+                "__row_id": "",
                 "array_column": "",
-                "array_column.nested_array": "",
+                "array_column.nested_array": ["d"],
                 "array_column.nested_array.nested_obj": "d",
+                "int_column": "",
             },
             {
-                "int_column": 2,
-                "array_column": [[[["e"], ["f"]]], [[["g"], ["h"]]]],
-                "array_column.nested_array": [["e"], ["f"]],
+                "__row_id": 1,
+                "array_column": [[["e"], ["f"]]],
+                "array_column.nested_array": ["e"],
                 "array_column.nested_array.nested_obj": "e",
+                "int_column": 2,
             },
             {
-                "int_column": "",
+                "__row_id": "",
                 "array_column": "",
-                "array_column.nested_array": "",
+                "array_column.nested_array": ["f"],
                 "array_column.nested_array.nested_obj": "f",
+                "int_column": "",
             },
             {
-                "int_column": "",
-                "array_column": "",
-                "array_column.nested_array": [["g"], ["h"]],
+                "__row_id": "",
+                "array_column": [[["g"], ["h"]]],
+                "array_column.nested_array": ["g"],
                 "array_column.nested_array.nested_obj": "g",
+                "int_column": "",
             },
             {
-                "int_column": "",
+                "__row_id": "",
                 "array_column": "",
-                "array_column.nested_array": "",
+                "array_column.nested_array": ["h"],
                 "array_column.nested_array.nested_obj": "h",
+                "int_column": "",
             },
         ]
         expected_expanded_cols = [
-            {"name": "array_column.nested_array", "type": "ARRAY"},
+            {
+                "name": "array_column.nested_array",
+                "type": "ARRAY(ROW(NESTED_OBJ VARCHAR))",
+            },
             {"name": "array_column.nested_array.nested_obj", "type": "VARCHAR"},
         ]
         self.assertEqual(actual_cols, expected_cols)
diff --git a/tests/utils_tests.py b/tests/utils_tests.py
index df11cf5cab..5efd44d993 100644
--- a/tests/utils_tests.py
+++ b/tests/utils_tests.py
@@ -45,6 +45,7 @@ from superset.utils.core import (
     parse_js_uri_path_item,
     parse_past_timedelta,
     setup_cache,
+    split,
     validate_json,
     zlib_compress,
     zlib_decompress,
@@ -832,6 +833,20 @@ class UtilsTestCase(unittest.TestCase):
         stacktrace = get_stacktrace()
         assert stacktrace is None
 
+    def test_split(self):
+        self.assertEqual(list(split("a b")), ["a", "b"])
+        self.assertEqual(list(split("a,b", delimiter=",")), ["a", "b"])
+        self.assertEqual(list(split("a,(b,a)", delimiter=",")), ["a", "(b,a)"])
+        self.assertEqual(
+            list(split('a,(b,a),"foo , bar"', delimiter=",")),
+            ["a", "(b,a)", '"foo , bar"'],
+        )
+        self.assertEqual(
+            list(split("a,'b,c'", delimiter=",", quote="'")), ["a", "'b,c'"]
+        )
+        self.assertEqual(list(split('a "b c"')), ["a", '"b c"'])
+        self.assertEqual(list(split(r'a "b \" c"')), ["a", r'"b \" c"'])
+
     def test_get_or_create_db(self):
         get_or_create_db("test_db", "sqlite:///superset.db")
         database = db.session.query(Database).filter_by(database_name="test_db").one()