feat: create dtype option for csv upload (#23716)

Elizabeth Thompson 2023-04-24 12:53:53 -07:00 committed by GitHub
parent 4873c0990a
commit 71106cfd97
6 changed files with 160 additions and 2 deletions

superset/db_engine_specs/redshift.py

@@ -18,12 +18,16 @@ import logging
import re
from typing import Any, Dict, Optional, Pattern, Tuple
import pandas as pd
from flask_babel import gettext as __
from sqlalchemy.types import NVARCHAR
from superset.db_engine_specs.base import BasicParametersMixin
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
from superset.errors import SupersetErrorType
from superset.models.core import Database
from superset.models.sql_lab import Query
from superset.sql_parse import Table
logger = logging.getLogger()
@@ -96,6 +100,42 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
        ),
    }
    @classmethod
    def df_to_sql(
        cls,
        database: Database,
        table: Table,
        df: pd.DataFrame,
        to_sql_kwargs: Dict[str, Any],
    ) -> None:
        """
        Upload data from a Pandas DataFrame to a database.

        For regular engines this calls the `pandas.DataFrame.to_sql` method.

        Overrides the base class so that pandas string-typed columns are
        created as nvarchar(max) columns, since Redshift does not support
        text data types.

        Note this method does not create metadata for the table.

        :param database: The database to upload the data to
        :param table: The table to upload the data to
        :param df: The dataframe with data to be uploaded
        :param to_sql_kwargs: The kwargs to be passed to the
            `pandas.DataFrame.to_sql` method
        """
        to_sql_kwargs = to_sql_kwargs or {}
        to_sql_kwargs["dtype"] = {
            # use the max size for a Redshift nvarchar, 65535
            # (the default object and string dtypes create a varchar(256))
            col_name: NVARCHAR(length=65535)
            for col_name, col_type in zip(df.columns, df.dtypes)
            if isinstance(col_type, pd.StringDtype)
        }
        super().df_to_sql(
            df=df, database=database, table=table, to_sql_kwargs=to_sql_kwargs
        )
    @staticmethod
    def _mutate_label(label: str) -> str:
        """

superset/templates/superset/form_view/csv_to_database.html

@@ -104,6 +104,10 @@
{{ lib.render_field(form.overwrite_duplicate, begin_sep_label, end_sep_label, begin_sep_field,
end_sep_field) }}
</tr>
<tr>
{{ lib.render_field(form.dtype, begin_sep_label, end_sep_label, begin_sep_field,
end_sep_field) }}
</tr>
{% endcall %}
{% call csv_macros.render_collapsable_form_group("accordion3", "Rows") %}
<tr>

superset/views/database/forms.py

@@ -140,6 +140,16 @@ class CsvToDatabaseForm(UploadToDatabaseForm):
        get_pk=lambda a: a.id,
        get_label=lambda a: a.database_name,
    )
    dtype = StringField(
        _("Column Data Types"),
        description=_(
            "A dictionary with column names and their data types"
            " if you need to change the defaults."
            ' Example: {"user_id":"integer"}'
        ),
        validators=[Optional()],
        widget=BS3TextFieldWidget(),
    )
    schema = StringField(
        _("Schema"),
        description=_("Select a schema if the database supports this"),

superset/views/database/views.py

@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import io
import json
import os
import tempfile
import zipfile
@@ -189,6 +190,7 @@ class CsvToDatabaseView(CustomFormView):
        delimiter_input = form.otherInput.data
        try:
            kwargs = {"dtype": json.loads(form.dtype.data)} if form.dtype.data else {}
            df = pd.concat(
                pd.read_csv(
                    chunksize=1000,
@@ -208,6 +210,7 @@
                    skip_blank_lines=form.skip_blank_lines.data,
                    skipinitialspace=form.skip_initial_space.data,
                    skiprows=form.skiprows.data,
                    **kwargs,
                )
            )
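
Note: a minimal sketch of the wiring above, runnable outside Superset. The form's JSON string is parsed with json.loads and forwarded to pandas.read_csv as its dtype= argument (the sample CSV content is made up):

import io
import json

import pandas as pd

form_dtype = '{"a": "string", "b": "float64"}'  # as submitted from the form
kwargs = {"dtype": json.loads(form_dtype)} if form_dtype else {}

df = pd.read_csv(io.StringIO("a,b\njohn,1\npaul,2\n"), **kwargs)
print(df.dtypes)  # a: string (extension dtype), b: float64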

tests/integration_tests/csv_upload_tests.py

@@ -20,7 +20,7 @@ import json
import logging
import os
import shutil
-from typing import Dict, Optional
+from typing import Dict, Optional, Union
from unittest import mock
@@ -129,7 +129,12 @@ def get_upload_db():
    return db.session.query(Database).filter_by(database_name=CSV_UPLOAD_DATABASE).one()

-def upload_csv(filename: str, table_name: str, extra: Optional[Dict[str, str]] = None):
+def upload_csv(
+    filename: str,
+    table_name: str,
+    extra: Optional[Dict[str, str]] = None,
+    dtype: Union[str, None] = None,
+):
    csv_upload_db_id = get_upload_db().id
    schema = utils.get_example_default_schema()
    form_data = {
@@ -145,6 +150,8 @@ def upload_csv(filename: str, table_name: str, extra: Optional[Dict[str, str]] =
form_data["schema"] = schema
if extra:
form_data.update(extra)
if dtype:
form_data["dtype"] = dtype
return get_resp(test_client, "/csvtodatabaseview/form", data=form_data)
@@ -386,6 +393,39 @@ def test_import_csv(mock_event_logger):
        data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
        assert data == [("john", 1, "x"), ("paul", 2, None)]

    # cleanup
    with get_upload_db().get_sqla_engine_with_context() as engine:
        engine.execute(f"DROP TABLE {full_table_name}")

    # with dtype
    upload_csv(
        CSV_FILENAME1,
        CSV_UPLOAD_TABLE,
        dtype='{"a": "string", "b": "float64"}',
    )
    # changing a column to a compatible dtype (e.g. object to string,
    # or int to float) should succeed, and the upload works as normal
    with test_db.get_sqla_engine_with_context() as engine:
        data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
        assert data == [("john", 1), ("paul", 2)]

    # cleanup
    with get_upload_db().get_sqla_engine_with_context() as engine:
        engine.execute(f"DROP TABLE {full_table_name}")

    # with dtype - wrong type
    resp = upload_csv(
        CSV_FILENAME1,
        CSV_UPLOAD_TABLE,
        dtype='{"a": "int"}',
    )
    # an incompatible dtype fails the upload: column "a" holds text
    # that cannot be cast to int
    fail_msg = (
        f"Unable to upload CSV file {escaped_double_quotes(CSV_FILENAME1)} "
        f"to table {escaped_double_quotes(CSV_UPLOAD_TABLE)}"
    )
    assert fail_msg in resp
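
Note: the failure asserted above originates in pandas, not Superset. Casting non-numeric text to an integer dtype raises ValueError inside read_csv, which the view reports as the upload error. A standalone sketch (sample data made up):

import io

import pandas as pd

try:
    pd.read_csv(io.StringIO("a,b\njohn,1\npaul,2\n"), dtype={"a": "int"})
except ValueError as ex:
    # e.g. "invalid literal for int() with base 10: 'john'"
    print(ex)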
@pytest.mark.usefixtures("setup_csv_upload_with_context")
@pytest.mark.usefixtures("create_excel_files")

tests/integration_tests/db_engine_specs/redshift_tests.py

@@ -14,11 +14,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest.mock as mock
from textwrap import dedent
import numpy as np
import pandas as pd
from sqlalchemy.types import NVARCHAR
from superset.db_engine_specs.redshift import RedshiftEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import Table
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.test_app import app
class TestRedshiftDbEngineSpec(TestDbEngineSpec):
@@ -183,3 +190,57 @@ psql: error: could not connect to server: Operation timed out
                },
            )
        ]
    def test_df_to_sql_no_dtype(self):
        mock_database = mock.MagicMock()
        mock_database.get_df.return_value.empty = False
        table_name = "foobar"

        data = [
            ("foo", "bar", pd.NA, None),
            ("foo", "bar", pd.NA, True),
            ("foo", "bar", pd.NA, None),
        ]
        numpy_dtype = [
            ("id", "object"),
            ("value", "object"),
            ("num", "object"),
            ("bool", "object"),
        ]
        column_names = ["id", "value", "num", "bool"]

        test_array = np.array(data, dtype=numpy_dtype)
        df = pd.DataFrame(test_array, columns=column_names)
        df.to_sql = mock.MagicMock()

        with app.app_context():
            RedshiftEngineSpec.df_to_sql(
                mock_database, Table(table=table_name), df, to_sql_kwargs={}
            )

        assert df.to_sql.call_args[1]["dtype"] == {}
    def test_df_to_sql_with_string_dtype(self):
        mock_database = mock.MagicMock()
        mock_database.get_df.return_value.empty = False
        table_name = "foobar"

        data = [
            ("foo", "bar", pd.NA, None),
            ("foo", "bar", pd.NA, True),
            ("foo", "bar", pd.NA, None),
        ]
        column_names = ["id", "value", "num", "bool"]
        df = pd.DataFrame(data, columns=column_names)
        df = df.astype(dtype={"value": "string"})
        df.to_sql = mock.MagicMock()

        with app.app_context():
            RedshiftEngineSpec.df_to_sql(
                mock_database, Table(table=table_name), df, to_sql_kwargs={}
            )

        # the varchar string length should be 65535
        dtype = df.to_sql.call_args[1]["dtype"]
        assert isinstance(dtype["value"], NVARCHAR)
        assert dtype["value"].length == 65535