diff --git a/superset/dataframe.py b/superset/dataframe.py index 40e57fcea9..47683b0cc0 100644 --- a/superset/dataframe.py +++ b/superset/dataframe.py @@ -62,6 +62,12 @@ def dedup(l, suffix="__", case_sensitive=True): return new_l +def is_numeric(dtype): + if hasattr(dtype, "_is_numeric"): + return dtype._is_numeric + return np.issubdtype(dtype, np.number) + + class SupersetDataFrame(object): # Mapping numpy dtype.char to generic database types type_map = { @@ -80,21 +86,45 @@ class SupersetDataFrame(object): } def __init__(self, data, cursor_description, db_engine_spec): - column_names = [] - if cursor_description: - column_names = [col[0] for col in cursor_description] - - self.column_names = dedup(column_names) - data = data or [] - self.df = pd.DataFrame(list(data), columns=self.column_names).infer_objects() + + column_names = [] + dtype = None + if cursor_description: + # get deduped list of column names + column_names = dedup([col[0] for col in cursor_description]) + + # fix cursor descriptor with the deduped names + cursor_description = [ + tuple([column_name, *list(description)[1:]]) + for column_name, description in zip(column_names, cursor_description) + ] + + # get type for better type casting, if possible + dtype = db_engine_spec.get_pandas_dtype(cursor_description) + + self.column_names = column_names + + if dtype: + # convert each column in data into a Series of the proper dtype; we + # need to do this because we can not specify a mixed dtype when + # instantiating the DataFrame, and this allows us to have different + # dtypes for each column. + array = np.array(data) + data = { + column: pd.Series(array[:, i], dtype=dtype[column]) + for i, column in enumerate(column_names) + } + self.df = pd.DataFrame(data, columns=column_names) + else: + self.df = pd.DataFrame(list(data), columns=column_names).infer_objects() self._type_dict = {} try: # The driver may not be passing a cursor.description self._type_dict = { col: db_engine_spec.get_datatype(cursor_description[i][1]) - for i, col in enumerate(self.column_names) + for i, col in enumerate(column_names) if cursor_description } except Exception as e: @@ -183,7 +213,7 @@ class SupersetDataFrame(object): if ( hasattr(dtype, "type") and issubclass(dtype.type, np.generic) - and np.issubdtype(dtype, np.number) + and is_numeric(dtype) ): return "sum" return None diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index d89f07a020..94e9bc5f2e 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -275,6 +275,12 @@ class BaseEngineSpec: return type_code.upper() return None + @classmethod + def get_pandas_dtype( + cls, cursor_description: List[tuple] + ) -> Optional[Dict[str, str]]: + return None + @classmethod def extra_table_metadata( cls, database, table_name: str, schema_name: str diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 79c7117c79..186557c978 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -39,6 +39,21 @@ from superset.utils import core as utils QueryStatus = utils.QueryStatus +# map between Presto types and Pandas +pandas_dtype_map = { + "boolean": "bool", + "tinyint": "Int64", # note: capital "I" means nullable int + "smallint": "Int64", + "integer": "Int64", + "bigint": "Int64", + "real": "float64", + "double": "float64", + "varchar": "object", + "timestamp": "datetime64", + "date": "datetime64", + "varbinary": "object", +} + class PrestoEngineSpec(BaseEngineSpec): engine = "presto" @@ -1052,3 +1067,9 @@ class PrestoEngineSpec(BaseEngineSpec): if df.empty: return "" return df.to_dict()[field_to_return][0] + + @classmethod + def get_pandas_dtype(cls, cursor_description: List[tuple]) -> Dict[str, str]: + return { + col[0]: pandas_dtype_map.get(col[1], "object") for col in cursor_description + } diff --git a/tests/dataframe_test.py b/tests/dataframe_test.py index 1fb9b807a1..6b421e919c 100644 --- a/tests/dataframe_test.py +++ b/tests/dataframe_test.py @@ -18,6 +18,7 @@ import numpy as np from superset.dataframe import dedup, SupersetDataFrame from superset.db_engine_specs import BaseEngineSpec +from superset.db_engine_specs.presto import PrestoEngineSpec from .base_tests import SupersetTestCase @@ -108,3 +109,23 @@ class SupersetDataFrameTestCase(SupersetTestCase): cursor_descr = (("a", "string"), ("a", "string")) cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec) self.assertListEqual(cdf.column_names, ["a", "a__1"]) + + def test_int64_with_missing_data(self): + data = [(None,), (1239162456494753670,), (None,), (None,), (None,), (None,)] + cursor_descr = [("user_id", "bigint", None, None, None, None, True)] + + # the base engine spec does not provide a dtype based on the cursor + # description, so the column is inferred as float64 because of the + # missing data + cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec) + np.testing.assert_array_equal( + cdf.raw_df.values.tolist(), + [[np.nan], [1.2391624564947538e18], [np.nan], [np.nan], [np.nan], [np.nan]], + ) + + # currently only Presto provides a dtype based on the cursor description + cdf = SupersetDataFrame(data, cursor_descr, PrestoEngineSpec) + np.testing.assert_array_equal( + cdf.raw_df.values.tolist(), + [[np.nan], [1239162456494753670], [np.nan], [np.nan], [np.nan], [np.nan]], + )