mirror of https://github.com/apache/superset.git
feat: create dtype option for csv upload (#23716)
This commit is contained in:
parent
4873c0990a
commit
71106cfd97
|
@ -18,12 +18,16 @@ import logging
|
|||
import re
|
||||
from typing import Any, Dict, Optional, Pattern, Tuple
|
||||
|
||||
import pandas as pd
|
||||
from flask_babel import gettext as __
|
||||
from sqlalchemy.types import NVARCHAR
|
||||
|
||||
from superset.db_engine_specs.base import BasicParametersMixin
|
||||
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
|
||||
from superset.errors import SupersetErrorType
|
||||
from superset.models.core import Database
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.sql_parse import Table
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
@ -96,6 +100,42 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
|
|||
),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def df_to_sql(
|
||||
cls,
|
||||
database: Database,
|
||||
table: Table,
|
||||
df: pd.DataFrame,
|
||||
to_sql_kwargs: Dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Upload data from a Pandas DataFrame to a database.
|
||||
|
||||
For regular engines this calls the `pandas.DataFrame.to_sql` method.
|
||||
Overrides the base class to allow for pandas string types to be
|
||||
used as nvarchar(max) columns, as redshift does not support
|
||||
text data types.
|
||||
|
||||
Note this method does not create metadata for the table.
|
||||
|
||||
:param database: The database to upload the data to
|
||||
:param table: The table to upload the data to
|
||||
:param df: The dataframe with data to be uploaded
|
||||
:param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method
|
||||
"""
|
||||
to_sql_kwargs = to_sql_kwargs or {}
|
||||
to_sql_kwargs["dtype"] = {
|
||||
# uses the max size for redshift nvarchar(65335)
|
||||
# the default object and string types create a varchar(256)
|
||||
col_name: NVARCHAR(length=65535)
|
||||
for col_name, type in zip(df.columns, df.dtypes)
|
||||
if isinstance(type, pd.StringDtype)
|
||||
}
|
||||
|
||||
super().df_to_sql(
|
||||
df=df, database=database, table=table, to_sql_kwargs=to_sql_kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _mutate_label(label: str) -> str:
|
||||
"""
|
||||
|
|
|
@ -104,6 +104,10 @@
|
|||
{{ lib.render_field(form.overwrite_duplicate, begin_sep_label, end_sep_label, begin_sep_field,
|
||||
end_sep_field) }}
|
||||
</tr>
|
||||
<tr>
|
||||
{{ lib.render_field(form.dtype, begin_sep_label, end_sep_label, begin_sep_field,
|
||||
end_sep_field) }}
|
||||
</tr>
|
||||
{% endcall %}
|
||||
{% call csv_macros.render_collapsable_form_group("accordion3", "Rows") %}
|
||||
<tr>
|
||||
|
|
|
@ -140,6 +140,16 @@ class CsvToDatabaseForm(UploadToDatabaseForm):
|
|||
get_pk=lambda a: a.id,
|
||||
get_label=lambda a: a.database_name,
|
||||
)
|
||||
dtype = StringField(
|
||||
_("Column Data Types"),
|
||||
description=_(
|
||||
"A dictionary with column names and their data types"
|
||||
" if you need to change the defaults."
|
||||
' Example: {"user_id":"integer"}'
|
||||
),
|
||||
validators=[Optional()],
|
||||
widget=BS3TextFieldWidget(),
|
||||
)
|
||||
schema = StringField(
|
||||
_("Schema"),
|
||||
description=_("Select a schema if the database supports this"),
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
|
@ -189,6 +190,7 @@ class CsvToDatabaseView(CustomFormView):
|
|||
delimiter_input = form.otherInput.data
|
||||
|
||||
try:
|
||||
kwargs = {"dtype": json.loads(form.dtype.data)} if form.dtype.data else {}
|
||||
df = pd.concat(
|
||||
pd.read_csv(
|
||||
chunksize=1000,
|
||||
|
@ -208,6 +210,7 @@ class CsvToDatabaseView(CustomFormView):
|
|||
skip_blank_lines=form.skip_blank_lines.data,
|
||||
skipinitialspace=form.skip_initial_space.data,
|
||||
skiprows=form.skiprows.data,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from unittest import mock
|
||||
|
||||
|
@ -129,7 +129,12 @@ def get_upload_db():
|
|||
return db.session.query(Database).filter_by(database_name=CSV_UPLOAD_DATABASE).one()
|
||||
|
||||
|
||||
def upload_csv(filename: str, table_name: str, extra: Optional[Dict[str, str]] = None):
|
||||
def upload_csv(
|
||||
filename: str,
|
||||
table_name: str,
|
||||
extra: Optional[Dict[str, str]] = None,
|
||||
dtype: Union[str, None] = None,
|
||||
):
|
||||
csv_upload_db_id = get_upload_db().id
|
||||
schema = utils.get_example_default_schema()
|
||||
form_data = {
|
||||
|
@ -145,6 +150,8 @@ def upload_csv(filename: str, table_name: str, extra: Optional[Dict[str, str]] =
|
|||
form_data["schema"] = schema
|
||||
if extra:
|
||||
form_data.update(extra)
|
||||
if dtype:
|
||||
form_data["dtype"] = dtype
|
||||
return get_resp(test_client, "/csvtodatabaseview/form", data=form_data)
|
||||
|
||||
|
||||
|
@ -386,6 +393,39 @@ def test_import_csv(mock_event_logger):
|
|||
data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
|
||||
assert data == [("john", 1, "x"), ("paul", 2, None)]
|
||||
|
||||
# cleanup
|
||||
with get_upload_db().get_sqla_engine_with_context() as engine:
|
||||
engine.execute(f"DROP TABLE {full_table_name}")
|
||||
|
||||
# with dtype
|
||||
upload_csv(
|
||||
CSV_FILENAME1,
|
||||
CSV_UPLOAD_TABLE,
|
||||
dtype='{"a": "string", "b": "float64"}',
|
||||
)
|
||||
|
||||
# you can change the type to something compatible, like an object to string
|
||||
# or an int to a float
|
||||
# file upload should work as normal
|
||||
with test_db.get_sqla_engine_with_context() as engine:
|
||||
data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
|
||||
assert data == [("john", 1), ("paul", 2)]
|
||||
|
||||
# cleanup
|
||||
with get_upload_db().get_sqla_engine_with_context() as engine:
|
||||
engine.execute(f"DROP TABLE {full_table_name}")
|
||||
|
||||
# with dtype - wrong type
|
||||
resp = upload_csv(
|
||||
CSV_FILENAME1,
|
||||
CSV_UPLOAD_TABLE,
|
||||
dtype='{"a": "int"}',
|
||||
)
|
||||
|
||||
# you cannot pass an incompatible dtype
|
||||
fail_msg = f"Unable to upload CSV file {escaped_double_quotes(CSV_FILENAME1)} to table {escaped_double_quotes(CSV_UPLOAD_TABLE)}"
|
||||
assert fail_msg in resp
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_csv_upload_with_context")
|
||||
@pytest.mark.usefixtures("create_excel_files")
|
||||
|
|
|
@ -14,11 +14,18 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import unittest.mock as mock
|
||||
from textwrap import dedent
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sqlalchemy.types import NVARCHAR
|
||||
|
||||
from superset.db_engine_specs.redshift import RedshiftEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.sql_parse import Table
|
||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
from tests.integration_tests.test_app import app
|
||||
|
||||
|
||||
class TestRedshiftDbEngineSpec(TestDbEngineSpec):
|
||||
|
@ -183,3 +190,57 @@ psql: error: could not connect to server: Operation timed out
|
|||
},
|
||||
)
|
||||
]
|
||||
|
||||
def test_df_to_sql_no_dtype(self):
|
||||
mock_database = mock.MagicMock()
|
||||
mock_database.get_df.return_value.empty = False
|
||||
table_name = "foobar"
|
||||
data = [
|
||||
("foo", "bar", pd.NA, None),
|
||||
("foo", "bar", pd.NA, True),
|
||||
("foo", "bar", pd.NA, None),
|
||||
]
|
||||
numpy_dtype = [
|
||||
("id", "object"),
|
||||
("value", "object"),
|
||||
("num", "object"),
|
||||
("bool", "object"),
|
||||
]
|
||||
column_names = ["id", "value", "num", "bool"]
|
||||
|
||||
test_array = np.array(data, dtype=numpy_dtype)
|
||||
|
||||
df = pd.DataFrame(test_array, columns=column_names)
|
||||
df.to_sql = mock.MagicMock()
|
||||
|
||||
with app.app_context():
|
||||
RedshiftEngineSpec.df_to_sql(
|
||||
mock_database, Table(table=table_name), df, to_sql_kwargs={}
|
||||
)
|
||||
|
||||
assert df.to_sql.call_args[1]["dtype"] == {}
|
||||
|
||||
def test_df_to_sql_with_string_dtype(self):
|
||||
mock_database = mock.MagicMock()
|
||||
mock_database.get_df.return_value.empty = False
|
||||
table_name = "foobar"
|
||||
data = [
|
||||
("foo", "bar", pd.NA, None),
|
||||
("foo", "bar", pd.NA, True),
|
||||
("foo", "bar", pd.NA, None),
|
||||
]
|
||||
column_names = ["id", "value", "num", "bool"]
|
||||
|
||||
df = pd.DataFrame(data, columns=column_names)
|
||||
df = df.astype(dtype={"value": "string"})
|
||||
df.to_sql = mock.MagicMock()
|
||||
|
||||
with app.app_context():
|
||||
RedshiftEngineSpec.df_to_sql(
|
||||
mock_database, Table(table=table_name), df, to_sql_kwargs={}
|
||||
)
|
||||
|
||||
# varchar string length should be 65535
|
||||
dtype = df.to_sql.call_args[1]["dtype"]
|
||||
assert isinstance(dtype["value"], NVARCHAR)
|
||||
assert dtype["value"].length == 65535
|
||||
|
|
Loading…
Reference in New Issue