diff --git a/superset/db_engine_specs/redshift.py b/superset/db_engine_specs/redshift.py
index 7e2717d776..27b749e418 100644
--- a/superset/db_engine_specs/redshift.py
+++ b/superset/db_engine_specs/redshift.py
@@ -18,12 +18,16 @@
 import logging
 import re
 from typing import Any, Dict, Optional, Pattern, Tuple

+import pandas as pd
 from flask_babel import gettext as __
+from sqlalchemy.types import NVARCHAR

 from superset.db_engine_specs.base import BasicParametersMixin
 from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
 from superset.errors import SupersetErrorType
+from superset.models.core import Database
 from superset.models.sql_lab import Query
+from superset.sql_parse import Table

 logger = logging.getLogger()

@@ -96,6 +100,42 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
         ),
     }

+    @classmethod
+    def df_to_sql(
+        cls,
+        database: Database,
+        table: Table,
+        df: pd.DataFrame,
+        to_sql_kwargs: Dict[str, Any],
+    ) -> None:
+        """
+        Upload data from a Pandas DataFrame to a database.
+
+        For regular engines this calls the `pandas.DataFrame.to_sql` method.
+        Overrides the base class so that pandas string columns are
+        uploaded as NVARCHAR(65535) columns, since Redshift does not
+        support unbounded text data types.
+
+        Note this method does not create metadata for the table.
+
+        :param database: The database to upload the data to
+        :param table: The table to upload the data to
+        :param df: The dataframe with data to be uploaded
+        :param to_sql_kwargs: The kwargs to be passed to the `pandas.DataFrame.to_sql` method
+        """
+        to_sql_kwargs = to_sql_kwargs or {}
+        to_sql_kwargs["dtype"] = {
+            # use the max size for a Redshift NVARCHAR (65535);
+            # the default object and string types create a varchar(256)
+            col_name: NVARCHAR(length=65535)
+            for col_name, type in zip(df.columns, df.dtypes)
+            if isinstance(type, pd.StringDtype)
+        }
+
+        super().df_to_sql(
+            df=df, database=database, table=table, to_sql_kwargs=to_sql_kwargs
+        )
+
     @staticmethod
     def _mutate_label(label: str) -> str:
         """
diff --git a/superset/templates/superset/form_view/csv_to_database_view/edit.html b/superset/templates/superset/form_view/csv_to_database_view/edit.html
index b09f9bd383..a0ae43792e 100644
--- a/superset/templates/superset/form_view/csv_to_database_view/edit.html
+++ b/superset/templates/superset/form_view/csv_to_database_view/edit.html
@@ -104,6 +104,10 @@
           {{ lib.render_field(form.overwrite_duplicate, begin_sep_label, end_sep_label, begin_sep_field,
           end_sep_field) }}

+
+          {{ lib.render_field(form.dtype, begin_sep_label, end_sep_label, begin_sep_field,
+          end_sep_field) }}
+
       {% endcall %}

       {% call csv_macros.render_collapsable_form_group("accordion3", "Rows") %}
diff --git a/superset/views/database/forms.py b/superset/views/database/forms.py
index 91ab38dc2f..99b64e38ab 100644
--- a/superset/views/database/forms.py
+++ b/superset/views/database/forms.py
@@ -140,6 +140,16 @@ class CsvToDatabaseForm(UploadToDatabaseForm):
         get_pk=lambda a: a.id,
         get_label=lambda a: a.database_name,
     )
+    dtype = StringField(
+        _("Column Data Types"),
+        description=_(
+            "A dictionary with column names and their data types"
+            " if you need to change the defaults."
+ ' Example: {"user_id":"integer"}' + ), + validators=[Optional()], + widget=BS3TextFieldWidget(), + ) schema = StringField( _("Schema"), description=_("Select a schema if the database supports this"), diff --git a/superset/views/database/views.py b/superset/views/database/views.py index 037128ee16..a9137e59ed 100644 --- a/superset/views/database/views.py +++ b/superset/views/database/views.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import io +import json import os import tempfile import zipfile @@ -189,6 +190,7 @@ class CsvToDatabaseView(CustomFormView): delimiter_input = form.otherInput.data try: + kwargs = {"dtype": json.loads(form.dtype.data)} if form.dtype.data else {} df = pd.concat( pd.read_csv( chunksize=1000, @@ -208,6 +210,7 @@ class CsvToDatabaseView(CustomFormView): skip_blank_lines=form.skip_blank_lines.data, skipinitialspace=form.skip_initial_space.data, skiprows=form.skiprows.data, + **kwargs, ) ) diff --git a/tests/integration_tests/csv_upload_tests.py b/tests/integration_tests/csv_upload_tests.py index d3b55f7bfe..97b83bb8fc 100644 --- a/tests/integration_tests/csv_upload_tests.py +++ b/tests/integration_tests/csv_upload_tests.py @@ -20,7 +20,7 @@ import json import logging import os import shutil -from typing import Dict, Optional +from typing import Dict, Optional, Union from unittest import mock @@ -129,7 +129,12 @@ def get_upload_db(): return db.session.query(Database).filter_by(database_name=CSV_UPLOAD_DATABASE).one() -def upload_csv(filename: str, table_name: str, extra: Optional[Dict[str, str]] = None): +def upload_csv( + filename: str, + table_name: str, + extra: Optional[Dict[str, str]] = None, + dtype: Union[str, None] = None, +): csv_upload_db_id = get_upload_db().id schema = utils.get_example_default_schema() form_data = { @@ -145,6 +150,8 @@ def upload_csv(filename: str, table_name: str, extra: Optional[Dict[str, str]] = form_data["schema"] = schema if extra: form_data.update(extra) + if dtype: + form_data["dtype"] = dtype return get_resp(test_client, "/csvtodatabaseview/form", data=form_data) @@ -386,6 +393,39 @@ def test_import_csv(mock_event_logger): data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall() assert data == [("john", 1, "x"), ("paul", 2, None)] + # cleanup + with get_upload_db().get_sqla_engine_with_context() as engine: + engine.execute(f"DROP TABLE {full_table_name}") + + # with dtype + upload_csv( + CSV_FILENAME1, + CSV_UPLOAD_TABLE, + dtype='{"a": "string", "b": "float64"}', + ) + + # you can change the type to something compatible, like an object to string + # or an int to a float + # file upload should work as normal + with test_db.get_sqla_engine_with_context() as engine: + data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall() + assert data == [("john", 1), ("paul", 2)] + + # cleanup + with get_upload_db().get_sqla_engine_with_context() as engine: + engine.execute(f"DROP TABLE {full_table_name}") + + # with dtype - wrong type + resp = upload_csv( + CSV_FILENAME1, + CSV_UPLOAD_TABLE, + dtype='{"a": "int"}', + ) + + # you cannot pass an incompatible dtype + fail_msg = f"Unable to upload CSV file {escaped_double_quotes(CSV_FILENAME1)} to table {escaped_double_quotes(CSV_UPLOAD_TABLE)}" + assert fail_msg in resp + @pytest.mark.usefixtures("setup_csv_upload_with_context") @pytest.mark.usefixtures("create_excel_files") diff --git a/tests/integration_tests/db_engine_specs/redshift_tests.py b/tests/integration_tests/db_engine_specs/redshift_tests.py index 
index cdfe8d16cb..2d46c73fca 100644
--- a/tests/integration_tests/db_engine_specs/redshift_tests.py
+++ b/tests/integration_tests/db_engine_specs/redshift_tests.py
@@ -14,11 +14,18 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import unittest.mock as mock
 from textwrap import dedent

+import numpy as np
+import pandas as pd
+from sqlalchemy.types import NVARCHAR
+
 from superset.db_engine_specs.redshift import RedshiftEngineSpec
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
+from superset.sql_parse import Table
 from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
+from tests.integration_tests.test_app import app


 class TestRedshiftDbEngineSpec(TestDbEngineSpec):
@@ -183,3 +190,57 @@ psql: error: could not connect to server: Operation timed out
                 },
             )
         ]
+
+    def test_df_to_sql_no_dtype(self):
+        mock_database = mock.MagicMock()
+        mock_database.get_df.return_value.empty = False
+        table_name = "foobar"
+        data = [
+            ("foo", "bar", pd.NA, None),
+            ("foo", "bar", pd.NA, True),
+            ("foo", "bar", pd.NA, None),
+        ]
+        numpy_dtype = [
+            ("id", "object"),
+            ("value", "object"),
+            ("num", "object"),
+            ("bool", "object"),
+        ]
+        column_names = ["id", "value", "num", "bool"]
+
+        test_array = np.array(data, dtype=numpy_dtype)
+
+        df = pd.DataFrame(test_array, columns=column_names)
+        df.to_sql = mock.MagicMock()
+
+        with app.app_context():
+            RedshiftEngineSpec.df_to_sql(
+                mock_database, Table(table=table_name), df, to_sql_kwargs={}
+            )
+
+        assert df.to_sql.call_args[1]["dtype"] == {}
+
+    def test_df_to_sql_with_string_dtype(self):
+        mock_database = mock.MagicMock()
+        mock_database.get_df.return_value.empty = False
+        table_name = "foobar"
+        data = [
+            ("foo", "bar", pd.NA, None),
+            ("foo", "bar", pd.NA, True),
+            ("foo", "bar", pd.NA, None),
+        ]
+        column_names = ["id", "value", "num", "bool"]
+
+        df = pd.DataFrame(data, columns=column_names)
+        df = df.astype(dtype={"value": "string"})
+        df.to_sql = mock.MagicMock()
+
+        with app.app_context():
+            RedshiftEngineSpec.df_to_sql(
+                mock_database, Table(table=table_name), df, to_sql_kwargs={}
+            )
+
+        # varchar string length should be 65535
+        dtype = df.to_sql.call_args[1]["dtype"]
+        assert isinstance(dtype["value"], NVARCHAR)
+        assert dtype["value"].length == 65535
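
Taken together, the changes flow like this: the new `dtype` form field is parsed as JSON and forwarded to `pandas.read_csv`, and any column that comes out with a pandas `string` dtype is what `RedshiftEngineSpec.df_to_sql` maps to `NVARCHAR(65535)`. Below is a minimal sketch of that flow outside Superset; the CSV content, column names, and dtype mapping are made up for illustration.

```python
# Illustrative sketch only: mirrors CsvToDatabaseView.form_post and the
# RedshiftEngineSpec.df_to_sql dtype mapping with made-up data.
import io
import json

import pandas as pd
from sqlalchemy.types import NVARCHAR

form_dtype = '{"user_id": "string", "score": "float64"}'  # what a user would enter
kwargs = {"dtype": json.loads(form_dtype)} if form_dtype else {}

# Stand-in for the uploaded CSV file; read in chunks like the view does.
csv_file = io.StringIO("user_id,score\nu1,1.5\nu2,2.0\n")
df = pd.concat(pd.read_csv(csv_file, chunksize=1000, **kwargs))

# Same comprehension as the Redshift override: only pandas string columns
# get an explicit NVARCHAR(65535) instead of the varchar(256) that
# object/string columns would otherwise receive.
to_sql_dtype = {
    col_name: NVARCHAR(length=65535)
    for col_name, col_type in zip(df.columns, df.dtypes)
    if isinstance(col_type, pd.StringDtype)
}
assert list(to_sql_dtype) == ["user_id"]
```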