Replace pandas.DataFrame with PyArrow.Table for nullable int typing (#8733)

* Use PyArrow Table for query result serialization

* Cleanup dev comments

* Additional cleanup

* WIP: tests

* Remove explicit dtype logic from db_engine_specs

* Remove obsolete column property

* SupersetTable column types

* Port SupersetDataFrame methods to SupersetTable

* Add test for nullable boolean columns

* Support datetime values with timezone offsets

* Black formatting

* Pylint

* More linting/formatting

* Resolve issue with timezones not appearing in results

* Types

* Enable running of tests in tests/db_engine_specs

* Resolve application context errors

* Refactor and add tests for pyodbc.Row conversion

* Appease isort, regardless of isort:skip

* Re-enable RESULTS_BACKEND_USE_MSGPACK default based on benchmarks

* Dataframe typing and nits

* Renames to reduce ambiguity
Rob DiCiuccio, 2020-01-03 16:55:39 +00:00, committed by Maxime Beauchemin
parent 4f8bf2b04d, commit 6537d5ed8c
16 changed files with 438 additions and 513 deletions
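For context on the motivation (not part of the commit; the column name and values below are invented): with a NumPy-backed pandas DataFrame, an integer column that contains NULLs is silently coerced to float64, whereas a PyArrow array keeps integer typing and tracks the nulls separately. A minimal sketch:

import pandas as pd
import pyarrow as pa

# rows as a DB cursor might return them; the second value is NULL
values = [10, None]

# pandas infers float64, so 10 becomes 10.0
df = pd.DataFrame({"user_id": values})
print(df["user_id"].dtype)  # float64

# pyarrow keeps an integer type with a null slot
arr = pa.array(values)
print(arr.type)         # int64
print(arr.to_pylist())  # [10, None]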


@@ -55,13 +55,11 @@ describe('ExploreResultsButton', () => {
  const mockColumns = {
    ds: {
      is_date: true,
      is_dim: false,
      name: 'ds',
      type: 'STRING',
    },
    gender: {
      is_date: false,
      is_dim: true,
      name: 'gender',
      type: 'STRING',
    },


@@ -219,13 +219,11 @@ export const queries = [
    columns: [
      {
        is_date: true,
        is_dim: false,
        name: 'ds',
        type: 'STRING',
      },
      {
        is_date: false,
        is_dim: true,
        name: 'gender',
        type: 'STRING',
      },
@@ -233,13 +231,11 @@
    selected_columns: [
      {
        is_date: true,
        is_dim: false,
        name: 'ds',
        type: 'STRING',
      },
      {
        is_date: false,
        is_dim: true,
        name: 'gender',
        type: 'STRING',
      },
@@ -291,37 +287,31 @@ export const queryWithBadColumns = {
    selected_columns: [
      {
        is_date: true,
        is_dim: false,
        name: 'COUNT(*)',
        type: 'STRING',
      },
      {
        is_date: false,
        is_dim: true,
        name: 'this_col_is_ok',
        type: 'STRING',
      },
      {
        is_date: false,
        is_dim: true,
        name: 'a',
        type: 'STRING',
      },
      {
        is_date: false,
        is_dim: true,
        name: '1',
        type: 'STRING',
      },
      {
        is_date: false,
        is_dim: true,
        name: '123',
        type: 'STRING',
      },
      {
        is_date: false,
        is_dim: true,
        name: 'CASE WHEN 1=1 THEN 1 ELSE 0 END',
        type: 'STRING',
      },


@@ -519,7 +519,7 @@ RESULTS_BACKEND = None
# rather than JSON. This feature requires additional testing from the
# community before it is fully adopted, so this config option is provided
# in order to disable should breaking issues be discovered.
RESULTS_BACKEND_USE_MSGPACK = False
RESULTS_BACKEND_USE_MSGPACK = True
# The S3 bucket where you want to store your external hive tables created
# from CSV files. For example, 'companyname-superset'
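As a usage note (not part of the commit): since the msgpack/Arrow results backend is now on by default, it can still be switched off in a deployment's local superset_config.py override if breaking issues turn up, which is exactly what the config comment above anticipates. A minimal sketch:

# superset_config.py (assumed deployment-specific override file)
# Fall back to JSON serialization of SQL Lab results.
RESULTS_BACKEND_USE_MSGPACK = False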


@@ -14,257 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Superset utilities for pandas.DataFrame.
"""
from typing import Any, Dict, List

import pandas as pd

from superset.utils.core import JS_MAX_INTEGER


def df_to_records(dframe: pd.DataFrame) -> List[Dict[str, Any]]:
    data: List[Dict[str, Any]] = dframe.to_dict(orient="records")
    # TODO: refactor this
    for d in data:
        for k, v in list(d.items()):
            # if an int is too big for JavaScript to handle
            # convert it to a string
            if isinstance(v, int) and abs(v) > JS_MAX_INTEGER:
                d[k] = str(v)
    return data

# pylint: disable=C,R,W
""" Superset wrapper around pandas.DataFrame.

TODO(bkyryliuk): add support for the conventions like: *_dim or dim_*
dimensions, *_ts, ts_*, ds_*, *_ds - datetime, etc.
TODO(bkyryliuk): recognize integer encoded enums.
"""
import logging
from datetime import date, datetime

import numpy as np
import pandas as pd
from pandas.core.common import maybe_box_datetimelike
from pandas.core.dtypes.dtypes import ExtensionDtype

from superset.utils.core import JS_MAX_INTEGER

INFER_COL_TYPES_THRESHOLD = 95
INFER_COL_TYPES_SAMPLE_SIZE = 100


def dedup(l, suffix="__", case_sensitive=True):
    """De-duplicates a list of string by suffixing a counter

    Always returns the same number of entries as provided, and always returns
    unique values. Case sensitive comparison by default.

    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'])))
    foo,bar,bar__1,bar__2,Bar
    >>> print(
','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'], case_sensitive=False))
)
foo,bar,bar__1,bar__2,Bar__3
"""
new_l = []
seen = {}
for s in l:
s_fixed_case = s if case_sensitive else s.lower()
if s_fixed_case in seen:
seen[s_fixed_case] += 1
s += suffix + str(seen[s_fixed_case])
else:
seen[s_fixed_case] = 0
new_l.append(s)
return new_l
def is_numeric(dtype):
if hasattr(dtype, "_is_numeric"):
return dtype._is_numeric
return np.issubdtype(dtype, np.number)
class SupersetDataFrame:
# Mapping numpy dtype.char to generic database types
type_map = {
"b": "BOOL", # boolean
"i": "INT", # (signed) integer
"u": "INT", # unsigned integer
"l": "INT", # 64bit integer
"f": "FLOAT", # floating-point
"c": "FLOAT", # complex-floating point
"m": None, # timedelta
"M": "DATETIME", # datetime
"O": "OBJECT", # (Python) objects
"S": "BYTE", # (byte-)string
"U": "STRING", # Unicode
"V": None, # raw data (void)
}
def __init__(self, data, cursor_description, db_engine_spec):
data = data or []
column_names = []
dtype = None
if cursor_description:
# get deduped list of column names
column_names = dedup([col[0] for col in cursor_description])
# fix cursor descriptor with the deduped names
cursor_description = [
tuple([column_name, *list(description)[1:]])
for column_name, description in zip(column_names, cursor_description)
]
# get type for better type casting, if possible
dtype = db_engine_spec.get_pandas_dtype(cursor_description)
self.column_names = column_names
if dtype:
# put data in a 2D array so we can efficiently access each column;
# the reshape ensures the shape is 2D in case data is empty
array = np.array(data, dtype="object").reshape(-1, len(column_names))
# convert each column in data into a Series of the proper dtype; we
# need to do this because we can not specify a mixed dtype when
# instantiating the DataFrame, and this allows us to have different
# dtypes for each column.
data = {
column: pd.Series(array[:, i], dtype=dtype[column])
for i, column in enumerate(column_names)
}
self.df = pd.DataFrame(data, columns=column_names)
else:
self.df = pd.DataFrame(list(data), columns=column_names).infer_objects()
self._type_dict = {}
try:
# The driver may not be passing a cursor.description
self._type_dict = {
col: db_engine_spec.get_datatype(cursor_description[i][1])
for i, col in enumerate(column_names)
if cursor_description
}
except Exception as e:
logging.exception(e)
@property
def raw_df(self):
return self.df
@property
def size(self):
return len(self.df.index)
@property
def data(self):
return self.format_data(self.df)
@classmethod
def format_data(cls, df):
# work around for https://github.com/pandas-dev/pandas/issues/18372
data = [
dict(
(k, maybe_box_datetimelike(v))
for k, v in zip(df.columns, np.atleast_1d(row))
)
for row in df.values
]
for d in data:
for k, v in list(d.items()):
# if an int is too big for Java Script to handle
# convert it to a string
if isinstance(v, int):
if abs(v) > JS_MAX_INTEGER:
d[k] = str(v)
return data
@classmethod
def db_type(cls, dtype):
"""Given a numpy dtype, Returns a generic database type"""
if isinstance(dtype, ExtensionDtype):
return cls.type_map.get(dtype.kind)
elif hasattr(dtype, "char"):
return cls.type_map.get(dtype.char)
@classmethod
def datetime_conversion_rate(cls, data_series):
success = 0
total = 0
for value in data_series:
total += 1
try:
pd.to_datetime(value)
success += 1
except Exception:
continue
return 100 * success / total
@staticmethod
def is_date(np_dtype, db_type_str):
def looks_daty(s):
if isinstance(s, str):
return any([s.lower().startswith(ss) for ss in ("time", "date")])
return False
if looks_daty(db_type_str):
return True
if np_dtype and np_dtype.name and looks_daty(np_dtype.name):
return True
return False
@classmethod
def is_dimension(cls, dtype, column_name):
if cls.is_id(column_name):
return False
return dtype.name in ("object", "bool")
@classmethod
def is_id(cls, column_name):
return column_name.startswith("id") or column_name.endswith("id")
@classmethod
def agg_func(cls, dtype, column_name):
# consider checking for key substring too.
if cls.is_id(column_name):
return "count_distinct"
if (
hasattr(dtype, "type")
and issubclass(dtype.type, np.generic)
and is_numeric(dtype)
):
return "sum"
return None
@property
def columns(self):
"""Provides metadata about columns for data visualization.
:return: dict, with the fields name, type, is_date, is_dim and agg.
"""
if self.df.empty:
return None
columns = []
sample_size = min(INFER_COL_TYPES_SAMPLE_SIZE, len(self.df.index))
sample = self.df
if sample_size:
sample = self.df.sample(sample_size)
for col in self.df.dtypes.keys():
db_type_str = self._type_dict.get(col) or self.db_type(self.df.dtypes[col])
column = {
"name": col,
"agg": self.agg_func(self.df.dtypes[col], col),
"type": db_type_str,
"is_date": self.is_date(self.df.dtypes[col], db_type_str),
"is_dim": self.is_dimension(self.df.dtypes[col], col),
}
if not db_type_str or db_type_str.upper() == "OBJECT":
v = sample[col].iloc[0] if not sample[col].empty else None
if isinstance(v, str):
column["type"] = "STRING"
elif isinstance(v, int):
column["type"] = "INT"
elif isinstance(v, float):
column["type"] = "FLOAT"
elif isinstance(v, (datetime, date)):
column["type"] = "DATETIME"
column["is_date"] = True
column["is_dim"] = False
# check if encoded datetime
if (
column["type"] == "STRING"
and self.datetime_conversion_rate(sample[col])
> INFER_COL_TYPES_THRESHOLD
):
column.update({"is_date": True, "is_dim": False, "agg": None})
# 'agg' is optional attribute
if not column["agg"]:
column.pop("agg", None)
columns.append(column)
return columns


@@ -289,12 +289,6 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
            return type_code.upper()
        return None

    @classmethod
    def get_pandas_dtype(
        cls, cursor_description: List[tuple]
    ) -> Optional[Dict[str, str]]:
        return None

    @classmethod
    def extra_table_metadata(
        cls, database, table_name: str, schema_name: str


@@ -24,20 +24,6 @@ from sqlalchemy import literal_column

from superset.db_engine_specs.base import BaseEngineSpec
pandas_dtype_map = {
"STRING": "object",
"BOOLEAN": "object", # to support nullable bool
"INTEGER": "Int64",
"FLOAT": "float64",
"TIMESTAMP": "datetime64[ns]",
"DATETIME": "datetime64[ns]",
"DATE": "object",
"BYTES": "object",
"TIME": "object",
"RECORD": "object",
"NUMERIC": "object",
}
class BigQueryEngineSpec(BaseEngineSpec):
    """Engine spec for Google's BigQuery
@@ -209,9 +195,3 @@ class BigQueryEngineSpec(BaseEngineSpec):
            if key in kwargs:
                gbq_kwargs[key] = kwargs[key]
        pandas_gbq.to_gbq(df, **gbq_kwargs)
@classmethod
def get_pandas_dtype(cls, cursor_description: List[tuple]) -> Dict[str, str]:
return {
col[0]: pandas_dtype_map.get(col[1], "object") for col in cursor_description
}


@@ -17,6 +17,7 @@
from datetime import datetime
from typing import List, Optional, Tuple, TYPE_CHECKING

from pytz import _FixedOffset  # type: ignore
from sqlalchemy.dialects.postgresql.base import PGInspector

from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
@@ -26,6 +27,12 @@ if TYPE_CHECKING:
    from superset.models.core import Database  # pylint: disable=unused-import
# Replace psycopg2.tz.FixedOffsetTimezone with pytz, which is serializable by PyArrow
# https://github.com/stub42/pytz/blob/b70911542755aeeea7b5a9e066df5e1c87e8f2c8/src/pytz/reference.py#L25
class FixedOffsetTimezone(_FixedOffset):
pass
class PostgresBaseEngineSpec(BaseEngineSpec):
    """ Abstract class for Postgres 'like' databases """
@@ -45,6 +52,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
    @classmethod
    def fetch_data(cls, cursor, limit: int) -> List[Tuple]:
        cursor.tzinfo_factory = FixedOffsetTimezone
        if not cursor.description:
            return []
        if cls.limit_method == LimitMethod.FETCH_MANY:
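A rough illustration (not part of the commit; the value below is invented) of why the tzinfo swap matters: a pytz fixed offset can be attached to an Arrow timestamp type, so timezone-aware rows produced by the patched cursor.tzinfo_factory survive PyArrow serialization.

import datetime

import pyarrow as pa
from pytz import FixedOffset

# a tz-aware value like the ones the patched cursor.tzinfo_factory yields
value = datetime.datetime(2017, 11, 18, 21, 53, tzinfo=FixedOffset(60))

# the pytz offset converts cleanly into an Arrow timestamp type with a timezone
arr = pa.array([value], type=pa.timestamp("ns", tz=value.tzinfo))
print(arr.type)  # timestamp[ns, tz=+01:00]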


@@ -46,21 +46,6 @@ if TYPE_CHECKING:
QueryStatus = utils.QueryStatus
config = app.config
# map between Presto types and Pandas
pandas_dtype_map = {
"boolean": "object", # to support nullable bool
"tinyint": "Int64", # note: capital "I" means nullable int
"smallint": "Int64",
"integer": "Int64",
"bigint": "Int64",
"real": "float64",
"double": "float64",
"varchar": "object",
"timestamp": "datetime64[ns]",
"date": "datetime64[ns]",
"varbinary": "object",
}
def get_children(column: Dict[str, str]) -> List[Dict[str, str]]:
    """
@@ -962,9 +947,3 @@ class PrestoEngineSpec(BaseEngineSpec):
        if df.empty:
            return ""
        return df.to_dict()[field_to_return][0]
@classmethod
def get_pandas_dtype(cls, cursor_description: List[tuple]) -> Dict[str, str]:
return {
col[0]: pandas_dtype_map.get(col[1], "object") for col in cursor_description
}

superset/result_set.py (new file, 177 lines)

@@ -0,0 +1,177 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
""" Superset wrapper around pyarrow.Table.
"""
import datetime
import logging
import re
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
import numpy as np
import pandas as pd
import pyarrow as pa
from superset import db_engine_specs
def dedup(l: List[str], suffix: str = "__", case_sensitive: bool = True) -> List[str]:
"""De-duplicates a list of string by suffixing a counter
Always returns the same number of entries as provided, and always returns
unique values. Case sensitive comparison by default.
>>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'])))
foo,bar,bar__1,bar__2,Bar
>>> print(
','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'], case_sensitive=False))
)
foo,bar,bar__1,bar__2,Bar__3
"""
new_l: List[str] = []
seen: Dict[str, int] = {}
for s in l:
s_fixed_case = s if case_sensitive else s.lower()
if s_fixed_case in seen:
seen[s_fixed_case] += 1
s += suffix + str(seen[s_fixed_case])
else:
seen[s_fixed_case] = 0
new_l.append(s)
return new_l
class SupersetResultSet:
def __init__(
self,
data: List[Tuple[Any, ...]],
cursor_description: Tuple[Any, ...],
db_engine_spec: Type[db_engine_specs.BaseEngineSpec],
):
data = data or []
column_names: List[str] = []
pa_data: List[pa.Array] = []
deduped_cursor_desc: List[Tuple[Any, ...]] = []
if cursor_description:
# get deduped list of column names
column_names = dedup([col[0] for col in cursor_description])
# fix cursor descriptor with the deduped names
deduped_cursor_desc = [
tuple([column_name, *list(description)[1:]])
for column_name, description in zip(column_names, cursor_description)
]
# put data in a 2D array so we can efficiently access each column;
array = np.array(data, dtype="object")
if array.size > 0:
pa_data = [pa.array(array[:, i]) for i, column in enumerate(column_names)]
# workaround for bug converting `psycopg2.tz.FixedOffsetTimezone` tzinfo values.
# related: https://issues.apache.org/jira/browse/ARROW-5248
if pa_data:
for i, column in enumerate(column_names):
if pa.types.is_temporal(pa_data[i].type):
sample = self.first_nonempty(array[:, i])
if sample and isinstance(sample, datetime.datetime):
try:
if sample.tzinfo:
series = pd.Series(array[:, i], dtype="datetime64[ns]")
pa_data[i] = pa.Array.from_pandas(
series, type=pa.timestamp("ns", tz=sample.tzinfo)
)
except Exception as e:
logging.exception(e)
self.table = pa.Table.from_arrays(pa_data, names=column_names)
self._type_dict: Dict[str, Any] = {}
try:
# The driver may not be passing a cursor.description
self._type_dict = {
col: db_engine_spec.get_datatype(deduped_cursor_desc[i][1])
for i, col in enumerate(column_names)
if deduped_cursor_desc
}
except Exception as e:
logging.exception(e)
@staticmethod
def convert_pa_dtype(pa_dtype: pa.DataType) -> Optional[str]:
if pa.types.is_boolean(pa_dtype):
return "BOOL"
if pa.types.is_integer(pa_dtype):
return "INT"
if pa.types.is_floating(pa_dtype):
return "FLOAT"
if pa.types.is_string(pa_dtype):
return "STRING"
if pa.types.is_temporal(pa_dtype):
return "DATETIME"
return None
@staticmethod
def convert_table_to_df(table: pa.Table) -> pd.DataFrame:
return table.to_pandas(integer_object_nulls=True)
@staticmethod
def first_nonempty(items: List) -> Any:
return next((i for i in items if i), None)
@staticmethod
def is_date(db_type_str: Optional[str]) -> bool:
return db_type_str in ("DATETIME", "TIMESTAMP")
def data_type(self, col_name: str, pa_dtype: pa.DataType) -> Optional[str]:
"""Given a pyarrow data type, Returns a generic database type"""
set_type = self._type_dict.get(col_name)
if set_type:
return set_type
mapped_type = self.convert_pa_dtype(pa_dtype)
if mapped_type:
return mapped_type
return None
def to_pandas_df(self) -> pd.DataFrame:
return self.convert_table_to_df(self.table)
@property
def pa_table(self) -> pa.Table:
return self.table
@property
def size(self) -> int:
return self.table.num_rows
@property
def columns(self) -> List[Dict[str, Any]]:
if not self.table.column_names:
return []
columns = []
for col in self.table.schema:
db_type_str = self.data_type(col.name, col.type)
column = {
"name": col.name,
"type": db_type_str,
"is_date": self.is_date(db_type_str),
}
columns.append(column)
return columns
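A quick usage sketch of the new class (not part of the commit; the rows and cursor description are invented), mirroring how SQL Lab feeds cursor output into it:

from superset.db_engine_specs import BaseEngineSpec
from superset.result_set import SupersetResultSet

# rows plus a DB-API-style cursor description, invented for illustration
data = [("alice", 10), ("bob", None)]
cursor_description = (("name", "varchar"), ("score", "int"))

results = SupersetResultSet(data, cursor_description, BaseEngineSpec)
print(results.size)      # 2
print(results.columns)   # [{'name': 'name', 'type': 'VARCHAR', ...}, ...]
df = results.to_pandas_df()  # pandas view used for JSON-style records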


@@ -19,7 +19,7 @@ import uuid
from contextlib import closing
from datetime import datetime
from sys import getsizeof
from typing import Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import backoff
import msgpack
@@ -39,10 +39,11 @@ from superset import (
    results_backend_use_msgpack,
    security_manager,
)
from superset.dataframe import SupersetDataFrame
from superset.dataframe import df_to_records
from superset.db_engine_specs import BaseEngineSpec
from superset.extensions import celery_app
from superset.models.sql_lab import Query
from superset.result_set import SupersetResultSet
from superset.sql_parse import ParsedQuery
from superset.utils.core import json_iso_dttm_ser, QueryStatus, sources, zlib_compress
from superset.utils.dates import now_as_float
@@ -251,7 +252,7 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor, log_
    logger.debug(f"Query {query.id}: Fetching cursor description")
    cursor_description = cursor.description

    return SupersetDataFrame(data, cursor_description, db_engine_spec)
    return SupersetResultSet(data, cursor_description, db_engine_spec)


def _serialize_payload(
@@ -265,13 +266,13 @@
def _serialize_and_expand_data(
    cdf: SupersetDataFrame,
    result_set: SupersetResultSet,
    db_engine_spec: BaseEngineSpec,
    use_msgpack: Optional[bool] = False,
    expand_data: bool = False,
) -> Tuple[Union[bytes, str], list, list, list]:
    selected_columns: list = cdf.columns or []
    selected_columns: List[Dict] = result_set.columns
    expanded_columns: list
    expanded_columns: List[Dict]

    if use_msgpack:
        with stats_timing(
@@ -279,14 +280,17 @@
        ):
            data = (
                pa.default_serialization_context()
                .serialize(cdf.raw_df)
                .serialize(result_set.pa_table)
                .to_buffer()
                .to_pybytes()
            )
        # expand when loading data from results backend
        all_columns, expanded_columns = (selected_columns, [])
    else:
        data = cdf.data or []
        df = result_set.to_pandas_df()
        data = df_to_records(df) or []

        if expand_data:
            all_columns, data, expanded_columns = db_engine_spec.expand_data(
                selected_columns, data
@@ -356,7 +360,7 @@
            query.set_extra_json_key("progress", msg)
            session.commit()
        try:
            cdf = execute_sql_statement(
            result_set = execute_sql_statement(
                statement, query, user_name, session, cursor, log_params
            )
        except Exception as e:  # pylint: disable=broad-except
@@ -367,7 +371,7 @@
            return payload

    # Success, updating the query entry in database
    query.rows = cdf.size
    query.rows = result_set.size
    query.progress = 100
    query.set_extra_json_key("progress", None)
    if query.select_as_cta:
@@ -381,9 +385,13 @@
    query.end_time = now_as_float()

    data, selected_columns, all_columns, expanded_columns = _serialize_and_expand_data(
        cdf, db_engine_spec, store_results and results_backend_use_msgpack, expand_data
        result_set,
        db_engine_spec,
        store_results and results_backend_use_msgpack,
        expand_data,
    )
    # TODO: data should be saved separately from metadata (likely in Parquet)
    payload.update(
        {
            "status": QueryStatus.SUCCESS,


@@ -61,6 +61,7 @@ from superset import (
    event_logger,
    get_feature_flags,
    is_feature_enabled,
    result_set,
    results_backend,
    results_backend_use_msgpack,
    security_manager,
@@ -227,10 +228,10 @@ def _deserialize_results_payload(
        ds_payload = msgpack.loads(payload, raw=False)

        with stats_timing("sqllab.query.results_backend_pa_deserialize", stats_logger):
            df = pa.deserialize(ds_payload["data"])
            pa_table = pa.deserialize(ds_payload["data"])

        # TODO: optimize this, perhaps via df.to_dict, then traversing
        df = result_set.SupersetResultSet.convert_table_to_df(pa_table)
        ds_payload["data"] = dataframe.SupersetDataFrame.format_data(df) or []
        ds_payload["data"] = dataframe.df_to_records(df) or []

        db_engine_spec = query.database.db_engine_spec
        all_columns, data, expanded_columns = db_engine_spec.expand_data(


@@ -28,7 +28,7 @@ from flask import current_app
from tests.test_app import app

from superset import db, sql_lab
from superset.dataframe import SupersetDataFrame
from superset.result_set import SupersetResultSet
from superset.db_engine_specs.base import BaseEngineSpec
from superset.extensions import celery_app
from superset.models.helpers import QueryStatus
@@ -275,13 +275,13 @@ class CeleryTestCase(SupersetTestCase):
            ("d", "datetime"),
        )
        db_engine_spec = BaseEngineSpec()
        cdf = SupersetDataFrame(data, cursor_descr, db_engine_spec)
        results = SupersetResultSet(data, cursor_descr, db_engine_spec)

        with mock.patch.object(
            db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data
        ) as expand_data:
            data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
                cdf, db_engine_spec, False, True
                results, db_engine_spec, False, True
            )
        expand_data.assert_called_once()
@@ -296,13 +296,13 @@
            ("d", "datetime"),
        )
        db_engine_spec = BaseEngineSpec()
        cdf = SupersetDataFrame(data, cursor_descr, db_engine_spec)
        results = SupersetResultSet(data, cursor_descr, db_engine_spec)

        with mock.patch.object(
            db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data
        ) as expand_data:
            data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
                cdf, db_engine_spec, True
                results, db_engine_spec, True
            )
        expand_data.assert_not_called()
@@ -318,14 +318,14 @@
            ("d", "datetime"),
        )
        db_engine_spec = BaseEngineSpec()
        cdf = SupersetDataFrame(data, cursor_descr, db_engine_spec)
        results = SupersetResultSet(data, cursor_descr, db_engine_spec)
        query = {
            "database_id": 1,
            "sql": "SELECT * FROM birth_names LIMIT 100",
            "status": QueryStatus.PENDING,
        }
        serialized_data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
            cdf, db_engine_spec, use_new_deserialization
            results, db_engine_spec, use_new_deserialization
        )
        payload = {
            "query_id": 1,
@@ -351,14 +351,14 @@
            ("d", "datetime"),
        )
        db_engine_spec = BaseEngineSpec()
        cdf = SupersetDataFrame(data, cursor_descr, db_engine_spec)
        results = SupersetResultSet(data, cursor_descr, db_engine_spec)
        query = {
            "database_id": 1,
            "sql": "SELECT * FROM birth_names LIMIT 100",
            "status": QueryStatus.PENDING,
        }
        serialized_data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
            cdf, db_engine_spec, use_new_deserialization
            results, db_engine_spec, use_new_deserialization
        )
        payload = {
            "query_id": 1,


@@ -24,6 +24,7 @@ import io
import json
import logging
import os
import pytz
import random
import re
import string
@@ -44,12 +45,12 @@ from superset.models.dashboard import Dashboard
from superset.models.datasource_access_request import DatasourceAccessRequest
from superset.models.slice import Slice
from superset.models.sql_lab import Query
from superset.result_set import SupersetResultSet
from superset.utils import core as utils
from superset.views import core as views
from superset.views.database.views import DatabaseView

from .base_tests import SupersetTestCase
from .fixtures.pyodbcRow import Row


class CoreTests(SupersetTestCase):
@@ -702,18 +703,24 @@
        os.remove(filename_2)

    def test_dataframe_timezone(self):
        tz = psycopg2.tz.FixedOffsetTimezone(offset=60, name=None)
        tz = pytz.FixedOffset(60)
        data = [
            (datetime.datetime(2017, 11, 18, 21, 53, 0, 219225, tzinfo=tz),),
            (datetime.datetime(2017, 11, 18, 22, 6, 30, 61810, tzinfo=tz),),
            (datetime.datetime(2017, 11, 18, 22, 6, 30, tzinfo=tz),),
        ]
        df = dataframe.SupersetDataFrame(list(data), [["data"]], BaseEngineSpec)
        results = SupersetResultSet(list(data), [["data"]], BaseEngineSpec)
        data = df.data
        df = results.to_pandas_df()
        data = dataframe.df_to_records(df)
        json_str = json.dumps(data, default=utils.pessimistic_json_iso_dttm_ser)
        self.assertDictEqual(
            data[0], {"data": pd.Timestamp("2017-11-18 21:53:00.219225+0100", tz=tz)}
        )
        self.assertDictEqual(
            data[1], {"data": pd.Timestamp("2017-11-18 22:06:30.061810+0100", tz=tz)}
            data[1], {"data": pd.Timestamp("2017-11-18 22:06:30+0100", tz=tz)}
        )
        self.assertEqual(
            json_str,
            '[{"data": "2017-11-18T21:53:00.219225+01:00"}, {"data": "2017-11-18T22:06:30+01:00"}]',
        )

    def test_mssql_engine_spec_pymssql(self):
@@ -722,26 +729,11 @@
            (1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000)),
            (2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
        ]
        df = dataframe.SupersetDataFrame(
        results = SupersetResultSet(
            list(data), [["col1"], ["col2"], ["col3"]], MssqlEngineSpec
        )
        data = df.data
        df = results.to_pandas_df()
        self.assertEqual(len(data), 2)
        data = dataframe.df_to_records(df)
self.assertEqual(
data[0],
{"col1": 1, "col2": 1, "col3": pd.Timestamp("2017-10-19 23:39:16.660000")},
)
def test_mssql_engine_spec_odbc(self):
# Test for case when pyodbc.Row is returned (msodbc driver)
data = [
Row((1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000))),
Row((2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000))),
]
df = dataframe.SupersetDataFrame(
list(data), [["col1"], ["col2"], ["col3"]], MssqlEngineSpec
)
data = df.data
        self.assertEqual(len(data), 2)
        self.assertEqual(
            data[0],
@@ -876,14 +868,14 @@
            ("d", "datetime"),
        )
        db_engine_spec = BaseEngineSpec()
        cdf = dataframe.SupersetDataFrame(data, cursor_descr, db_engine_spec)
        results = SupersetResultSet(data, cursor_descr, db_engine_spec)
        query = {
            "database_id": 1,
            "sql": "SELECT * FROM birth_names LIMIT 100",
            "status": utils.QueryStatus.PENDING,
        }
        serialized_data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
            cdf, db_engine_spec, use_new_deserialization
            results, db_engine_spec, use_new_deserialization
        )
        payload = {
            "query_id": 1,
@@ -919,14 +911,14 @@
            ("d", "datetime"),
        )
        db_engine_spec = BaseEngineSpec()
        cdf = dataframe.SupersetDataFrame(data, cursor_descr, db_engine_spec)
        results = SupersetResultSet(data, cursor_descr, db_engine_spec)
        query = {
            "database_id": 1,
            "sql": "SELECT * FROM birth_names LIMIT 100",
            "status": utils.QueryStatus.PENDING,
        }
        serialized_data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
            cdf, db_engine_spec, use_new_deserialization
            results, db_engine_spec, use_new_deserialization
        )
        payload = {
            "query_id": 1,
@@ -953,7 +945,8 @@
            deserialized_payload = views._deserialize_results_payload(
                serialized_payload, query_mock, use_new_deserialization
            )
            payload["data"] = dataframe.SupersetDataFrame.format_data(cdf.raw_df)
            df = results.to_pandas_df()
            payload["data"] = dataframe.df_to_records(df)

            self.assertDictEqual(deserialized_payload, payload)
            expand_data.assert_called_once()


@@ -17,143 +17,35 @@
import numpy as np
import pandas as pd

from superset.dataframe import dedup, SupersetDataFrame
from superset.dataframe import df_to_records
from superset.db_engine_specs import BaseEngineSpec
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.result_set import SupersetResultSet

from .base_tests import SupersetTestCase


class SupersetDataFrameTestCase(SupersetTestCase):
    def test_dedup(self):
    def test_df_to_records(self):
self.assertEqual(dedup(["foo", "bar"]), ["foo", "bar"])
self.assertEqual(
dedup(["foo", "bar", "foo", "bar", "Foo"]),
["foo", "bar", "foo__1", "bar__1", "Foo"],
)
self.assertEqual(
dedup(["foo", "bar", "bar", "bar", "Bar"]),
["foo", "bar", "bar__1", "bar__2", "Bar"],
)
self.assertEqual(
dedup(["foo", "bar", "bar", "bar", "Bar"], case_sensitive=False),
["foo", "bar", "bar__1", "bar__2", "Bar__3"],
)
def test_get_columns_basic(self):
data = [("a1", "b1", "c1"), ("a2", "b2", "c2")] data = [("a1", "b1", "c1"), ("a2", "b2", "c2")]
cursor_descr = (("a", "string"), ("b", "string"), ("c", "string")) cursor_descr = (("a", "string"), ("b", "string"), ("c", "string"))
cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec) results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
df = results.to_pandas_df()
self.assertEqual( self.assertEqual(
cdf.columns, df_to_records(df),
[{"a": "a1", "b": "b1", "c": "c1"}, {"a": "a2", "b": "b2", "c": "c2"}],
)
def test_js_max_int(self):
data = [(1, 1239162456494753670, "c1"), (2, 100, "c2")]
cursor_descr = (("a", "int"), ("b", "int"), ("c", "string"))
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
df = results.to_pandas_df()
self.assertEqual(
df_to_records(df),
[ [
{"is_date": False, "type": "STRING", "name": "a", "is_dim": True}, {"a": 1, "b": "1239162456494753670", "c": "c1"},
{"is_date": False, "type": "STRING", "name": "b", "is_dim": True}, {"a": 2, "b": 100, "c": "c2"},
{"is_date": False, "type": "STRING", "name": "c", "is_dim": True},
], ],
) )
def test_get_columns_with_int(self):
data = [("a1", 1), ("a2", 2)]
cursor_descr = (("a", "string"), ("b", "int"))
cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec)
self.assertEqual(
cdf.columns,
[
{"is_date": False, "type": "STRING", "name": "a", "is_dim": True},
{
"is_date": False,
"type": "INT",
"name": "b",
"is_dim": False,
"agg": "sum",
},
],
)
def test_get_columns_type_inference(self):
data = [(1.2, 1), (3.14, 2)]
cursor_descr = (("a", None), ("b", None))
cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec)
self.assertEqual(
cdf.columns,
[
{
"is_date": False,
"type": "FLOAT",
"name": "a",
"is_dim": False,
"agg": "sum",
},
{
"is_date": False,
"type": "INT",
"name": "b",
"is_dim": False,
"agg": "sum",
},
],
)
def test_is_date(self):
f = SupersetDataFrame.is_date
self.assertEqual(f(np.dtype("M"), ""), True)
self.assertEqual(f(np.dtype("f"), "DATETIME"), True)
self.assertEqual(f(np.dtype("i"), "TIMESTAMP"), True)
self.assertEqual(f(None, "DATETIME"), True)
self.assertEqual(f(None, "TIMESTAMP"), True)
self.assertEqual(f(None, ""), False)
self.assertEqual(f(np.dtype(np.int32), ""), False)
def test_dedup_with_data(self):
data = [("a", 1), ("a", 2)]
cursor_descr = (("a", "string"), ("a", "string"))
cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec)
self.assertListEqual(cdf.column_names, ["a", "a__1"])
def test_int64_with_missing_data(self):
data = [(None,), (1239162456494753670,), (None,), (None,), (None,), (None,)]
cursor_descr = [("user_id", "bigint", None, None, None, None, True)]
# the base engine spec does not provide a dtype based on the cursor
# description, so the column is inferred as float64 because of the
# missing data
cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec)
np.testing.assert_array_equal(
cdf.raw_df.values.tolist(),
[[np.nan], [1.2391624564947538e18], [np.nan], [np.nan], [np.nan], [np.nan]],
)
# currently only Presto provides a dtype based on the cursor description
cdf = SupersetDataFrame(data, cursor_descr, PrestoEngineSpec)
np.testing.assert_array_equal(
cdf.raw_df.values.tolist(),
[[np.nan], [1239162456494753670], [np.nan], [np.nan], [np.nan], [np.nan]],
)
def test_pandas_datetime64(self):
data = [(None,)]
cursor_descr = [("ds", "timestamp", None, None, None, None, True)]
cdf = SupersetDataFrame(data, cursor_descr, PrestoEngineSpec)
self.assertEqual(cdf.raw_df.dtypes[0], np.dtype("<M8[ns]"))
def test_no_type_coercion(self):
data = [("a", 1), ("b", 2)]
cursor_descr = [
("one", "varchar", None, None, None, None, True),
("two", "integer", None, None, None, None, True),
]
cdf = SupersetDataFrame(data, cursor_descr, PrestoEngineSpec)
self.assertEqual(cdf.raw_df.dtypes[0], np.dtype("O"))
self.assertEqual(cdf.raw_df.dtypes[1], pd.Int64Dtype())
def test_empty_data(self):
data = []
cursor_descr = [
("one", "varchar", None, None, None, None, True),
("two", "integer", None, None, None, None, True),
]
cdf = SupersetDataFrame(data, cursor_descr, PrestoEngineSpec)
self.assertEqual(cdf.raw_df.dtypes[0], np.dtype("O"))
self.assertEqual(cdf.raw_df.dtypes[1], pd.Int64Dtype())

tests/result_set_tests.py (new file, 150 lines)

@@ -0,0 +1,150 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
import numpy as np
import pandas as pd
from superset.dataframe import df_to_records
from superset.db_engine_specs import BaseEngineSpec
from superset.result_set import dedup, SupersetResultSet
from .base_tests import SupersetTestCase
class SupersetResultSetTestCase(SupersetTestCase):
def test_dedup(self):
self.assertEqual(dedup(["foo", "bar"]), ["foo", "bar"])
self.assertEqual(
dedup(["foo", "bar", "foo", "bar", "Foo"]),
["foo", "bar", "foo__1", "bar__1", "Foo"],
)
self.assertEqual(
dedup(["foo", "bar", "bar", "bar", "Bar"]),
["foo", "bar", "bar__1", "bar__2", "Bar"],
)
self.assertEqual(
dedup(["foo", "bar", "bar", "bar", "Bar"], case_sensitive=False),
["foo", "bar", "bar__1", "bar__2", "Bar__3"],
)
def test_get_columns_basic(self):
data = [("a1", "b1", "c1"), ("a2", "b2", "c2")]
cursor_descr = (("a", "string"), ("b", "string"), ("c", "string"))
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
self.assertEqual(
results.columns,
[
{"is_date": False, "type": "STRING", "name": "a"},
{"is_date": False, "type": "STRING", "name": "b"},
{"is_date": False, "type": "STRING", "name": "c"},
],
)
def test_get_columns_with_int(self):
data = [("a1", 1), ("a2", 2)]
cursor_descr = (("a", "string"), ("b", "int"))
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
self.assertEqual(
results.columns,
[
{"is_date": False, "type": "STRING", "name": "a"},
{"is_date": False, "type": "INT", "name": "b"},
],
)
def test_get_columns_type_inference(self):
data = [
(1.2, 1, "foo", datetime(2018, 10, 19, 23, 39, 16, 660000), True),
(3.14, 2, "bar", datetime(2019, 10, 19, 23, 39, 16, 660000), False),
]
cursor_descr = (("a", None), ("b", None), ("c", None), ("d", None), ("e", None))
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
self.assertEqual(
results.columns,
[
{"is_date": False, "type": "FLOAT", "name": "a"},
{"is_date": False, "type": "INT", "name": "b"},
{"is_date": False, "type": "STRING", "name": "c"},
{"is_date": True, "type": "DATETIME", "name": "d"},
{"is_date": False, "type": "BOOL", "name": "e"},
],
)
def test_is_date(self):
is_date = SupersetResultSet.is_date
self.assertEqual(is_date("DATETIME"), True)
self.assertEqual(is_date("TIMESTAMP"), True)
self.assertEqual(is_date("STRING"), False)
self.assertEqual(is_date(""), False)
self.assertEqual(is_date(None), False)
def test_dedup_with_data(self):
data = [("a", 1), ("a", 2)]
cursor_descr = (("a", "string"), ("a", "string"))
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
column_names = [col["name"] for col in results.columns]
self.assertListEqual(column_names, ["a", "a__1"])
def test_int64_with_missing_data(self):
data = [(None,), (1239162456494753670,), (None,), (None,), (None,), (None,)]
cursor_descr = [("user_id", "bigint", None, None, None, None, True)]
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
self.assertEqual(results.columns[0]["type"], "BIGINT")
def test_nullable_bool(self):
data = [(None,), (True,), (None,), (None,), (None,), (None,)]
cursor_descr = [("is_test", "bool", None, None, None, None, True)]
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
self.assertEqual(results.columns[0]["type"], "BOOL")
df = results.to_pandas_df()
self.assertEqual(
df_to_records(df),
[
{"is_test": None},
{"is_test": True},
{"is_test": None},
{"is_test": None},
{"is_test": None},
{"is_test": None},
],
)
def test_empty_datetime(self):
data = [(None,)]
cursor_descr = [("ds", "timestamp", None, None, None, None, True)]
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
self.assertEqual(results.columns[0]["type"], "TIMESTAMP")
def test_no_type_coercion(self):
data = [("a", 1), ("b", 2)]
cursor_descr = [
("one", "varchar", None, None, None, None, True),
("two", "int", None, None, None, None, True),
]
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
self.assertEqual(results.columns[0]["type"], "VARCHAR")
self.assertEqual(results.columns[1]["type"], "INT")
def test_empty_data(self):
data = []
cursor_descr = [
("emptyone", "varchar", None, None, None, None, True),
("emptytwo", "int", None, None, None, None, True),
]
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
self.assertEqual(results.columns, [])


@@ -23,9 +23,10 @@ import prison
from superset import db, security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.dataframe import SupersetDataFrame
from superset.dataframe import df_to_records
from superset.db_engine_specs import BaseEngineSpec
from superset.models.sql_lab import Query
from superset.result_set import SupersetResultSet
from superset.utils.core import datetime_to_epoch, get_example_database

from .base_tests import SupersetTestCase
@@ -266,29 +267,29 @@ class SqlLabTests(SupersetTestCase):
            raise_on_error=True,
        )

    def test_df_conversion_no_dict(self):
    def test_ps_conversion_no_dict(self):
        cols = [["string_col", "string"], ["int_col", "int"], ["float_col", "float"]]
        data = [["a", 4, 4.0]]
        cdf = SupersetDataFrame(data, cols, BaseEngineSpec)
        results = SupersetResultSet(data, cols, BaseEngineSpec)

        self.assertEqual(len(data), cdf.size)
        self.assertEqual(len(data), results.size)
        self.assertEqual(len(cols), len(cdf.columns))
        self.assertEqual(len(cols), len(results.columns))

    def test_df_conversion_tuple(self):
    def test_pa_conversion_tuple(self):
        cols = ["string_col", "int_col", "list_col", "float_col"]
        data = [("Text", 111, [123], 1.0)]
        cdf = SupersetDataFrame(data, cols, BaseEngineSpec)
        results = SupersetResultSet(data, cols, BaseEngineSpec)

        self.assertEqual(len(data), cdf.size)
        self.assertEqual(len(data), results.size)
        self.assertEqual(len(cols), len(cdf.columns))
        self.assertEqual(len(cols), len(results.columns))

    def test_df_conversion_dict(self):
    def test_pa_conversion_dict(self):
        cols = ["string_col", "dict_col", "int_col"]
        data = [["a", {"c1": 1, "c2": 2, "c3": 3}, 4]]
        cdf = SupersetDataFrame(data, cols, BaseEngineSpec)
        results = SupersetResultSet(data, cols, BaseEngineSpec)

        self.assertEqual(len(data), cdf.size)
        self.assertEqual(len(data), results.size)
        self.assertEqual(len(cols), len(cdf.columns))
        self.assertEqual(len(cols), len(results.columns))

    def test_sqllab_viz(self):
        self.login("admin")
@@ -298,19 +299,8 @@
            "datasourceName": f"test_viz_flow_table_{random()}",
            "schema": "superset",
            "columns": [
                {"is_date": False, "type": "STRING", "name": f"viz_type_{random()}"},
                {"is_date": False, "type": "OBJECT", "name": f"ccount_{random()}"},
                {
                "is_date": False,
"type": "STRING",
"name": f"viz_type_{random()}",
"is_dim": True,
},
{
"is_date": False,
"type": "OBJECT",
"name": f"ccount_{random()}",
"is_dim": True,
"agg": "sum",
},
            ],
            "sql": """\
SELECT *