mirror of https://github.com/apache/superset.git
feat: create dtype option for csv upload (#23716)
This commit is contained in:
parent
4873c0990a
commit
71106cfd97
|
@ -18,12 +18,16 @@ import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any, Dict, Optional, Pattern, Tuple
|
from typing import Any, Dict, Optional, Pattern, Tuple
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
from flask_babel import gettext as __
|
from flask_babel import gettext as __
|
||||||
|
from sqlalchemy.types import NVARCHAR
|
||||||
|
|
||||||
from superset.db_engine_specs.base import BasicParametersMixin
|
from superset.db_engine_specs.base import BasicParametersMixin
|
||||||
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
|
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
|
||||||
from superset.errors import SupersetErrorType
|
from superset.errors import SupersetErrorType
|
||||||
|
from superset.models.core import Database
|
||||||
from superset.models.sql_lab import Query
|
from superset.models.sql_lab import Query
|
||||||
|
from superset.sql_parse import Table
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
@ -96,6 +100,42 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def df_to_sql(
|
||||||
|
cls,
|
||||||
|
database: Database,
|
||||||
|
table: Table,
|
||||||
|
df: pd.DataFrame,
|
||||||
|
to_sql_kwargs: Dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Upload data from a Pandas DataFrame to a database.
|
||||||
|
|
||||||
|
For regular engines this calls the `pandas.DataFrame.to_sql` method.
|
||||||
|
Overrides the base class to allow for pandas string types to be
|
||||||
|
used as nvarchar(max) columns, as redshift does not support
|
||||||
|
text data types.
|
||||||
|
|
||||||
|
Note this method does not create metadata for the table.
|
||||||
|
|
||||||
|
:param database: The database to upload the data to
|
||||||
|
:param table: The table to upload the data to
|
||||||
|
:param df: The dataframe with data to be uploaded
|
||||||
|
:param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method
|
||||||
|
"""
|
||||||
|
to_sql_kwargs = to_sql_kwargs or {}
|
||||||
|
to_sql_kwargs["dtype"] = {
|
||||||
|
# uses the max size for redshift nvarchar(65335)
|
||||||
|
# the default object and string types create a varchar(256)
|
||||||
|
col_name: NVARCHAR(length=65535)
|
||||||
|
for col_name, type in zip(df.columns, df.dtypes)
|
||||||
|
if isinstance(type, pd.StringDtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
super().df_to_sql(
|
||||||
|
df=df, database=database, table=table, to_sql_kwargs=to_sql_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _mutate_label(label: str) -> str:
|
def _mutate_label(label: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -104,6 +104,10 @@
|
||||||
{{ lib.render_field(form.overwrite_duplicate, begin_sep_label, end_sep_label, begin_sep_field,
|
{{ lib.render_field(form.overwrite_duplicate, begin_sep_label, end_sep_label, begin_sep_field,
|
||||||
end_sep_field) }}
|
end_sep_field) }}
|
||||||
</tr>
|
</tr>
|
||||||
|
<tr>
|
||||||
|
{{ lib.render_field(form.dtype, begin_sep_label, end_sep_label, begin_sep_field,
|
||||||
|
end_sep_field) }}
|
||||||
|
</tr>
|
||||||
{% endcall %}
|
{% endcall %}
|
||||||
{% call csv_macros.render_collapsable_form_group("accordion3", "Rows") %}
|
{% call csv_macros.render_collapsable_form_group("accordion3", "Rows") %}
|
||||||
<tr>
|
<tr>
|
||||||
|
|
|
@ -140,6 +140,16 @@ class CsvToDatabaseForm(UploadToDatabaseForm):
|
||||||
get_pk=lambda a: a.id,
|
get_pk=lambda a: a.id,
|
||||||
get_label=lambda a: a.database_name,
|
get_label=lambda a: a.database_name,
|
||||||
)
|
)
|
||||||
|
dtype = StringField(
|
||||||
|
_("Column Data Types"),
|
||||||
|
description=_(
|
||||||
|
"A dictionary with column names and their data types"
|
||||||
|
" if you need to change the defaults."
|
||||||
|
' Example: {"user_id":"integer"}'
|
||||||
|
),
|
||||||
|
validators=[Optional()],
|
||||||
|
widget=BS3TextFieldWidget(),
|
||||||
|
)
|
||||||
schema = StringField(
|
schema = StringField(
|
||||||
_("Schema"),
|
_("Schema"),
|
||||||
description=_("Select a schema if the database supports this"),
|
description=_("Select a schema if the database supports this"),
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import io
|
import io
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import zipfile
|
import zipfile
|
||||||
|
@ -189,6 +190,7 @@ class CsvToDatabaseView(CustomFormView):
|
||||||
delimiter_input = form.otherInput.data
|
delimiter_input = form.otherInput.data
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
kwargs = {"dtype": json.loads(form.dtype.data)} if form.dtype.data else {}
|
||||||
df = pd.concat(
|
df = pd.concat(
|
||||||
pd.read_csv(
|
pd.read_csv(
|
||||||
chunksize=1000,
|
chunksize=1000,
|
||||||
|
@ -208,6 +210,7 @@ class CsvToDatabaseView(CustomFormView):
|
||||||
skip_blank_lines=form.skip_blank_lines.data,
|
skip_blank_lines=form.skip_blank_lines.data,
|
||||||
skipinitialspace=form.skip_initial_space.data,
|
skipinitialspace=form.skip_initial_space.data,
|
||||||
skiprows=form.skiprows.data,
|
skiprows=form.skiprows.data,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
|
@ -129,7 +129,12 @@ def get_upload_db():
|
||||||
return db.session.query(Database).filter_by(database_name=CSV_UPLOAD_DATABASE).one()
|
return db.session.query(Database).filter_by(database_name=CSV_UPLOAD_DATABASE).one()
|
||||||
|
|
||||||
|
|
||||||
def upload_csv(filename: str, table_name: str, extra: Optional[Dict[str, str]] = None):
|
def upload_csv(
|
||||||
|
filename: str,
|
||||||
|
table_name: str,
|
||||||
|
extra: Optional[Dict[str, str]] = None,
|
||||||
|
dtype: Union[str, None] = None,
|
||||||
|
):
|
||||||
csv_upload_db_id = get_upload_db().id
|
csv_upload_db_id = get_upload_db().id
|
||||||
schema = utils.get_example_default_schema()
|
schema = utils.get_example_default_schema()
|
||||||
form_data = {
|
form_data = {
|
||||||
|
@ -145,6 +150,8 @@ def upload_csv(filename: str, table_name: str, extra: Optional[Dict[str, str]] =
|
||||||
form_data["schema"] = schema
|
form_data["schema"] = schema
|
||||||
if extra:
|
if extra:
|
||||||
form_data.update(extra)
|
form_data.update(extra)
|
||||||
|
if dtype:
|
||||||
|
form_data["dtype"] = dtype
|
||||||
return get_resp(test_client, "/csvtodatabaseview/form", data=form_data)
|
return get_resp(test_client, "/csvtodatabaseview/form", data=form_data)
|
||||||
|
|
||||||
|
|
||||||
|
@ -386,6 +393,39 @@ def test_import_csv(mock_event_logger):
|
||||||
data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
|
data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
|
||||||
assert data == [("john", 1, "x"), ("paul", 2, None)]
|
assert data == [("john", 1, "x"), ("paul", 2, None)]
|
||||||
|
|
||||||
|
# cleanup
|
||||||
|
with get_upload_db().get_sqla_engine_with_context() as engine:
|
||||||
|
engine.execute(f"DROP TABLE {full_table_name}")
|
||||||
|
|
||||||
|
# with dtype
|
||||||
|
upload_csv(
|
||||||
|
CSV_FILENAME1,
|
||||||
|
CSV_UPLOAD_TABLE,
|
||||||
|
dtype='{"a": "string", "b": "float64"}',
|
||||||
|
)
|
||||||
|
|
||||||
|
# you can change the type to something compatible, like an object to string
|
||||||
|
# or an int to a float
|
||||||
|
# file upload should work as normal
|
||||||
|
with test_db.get_sqla_engine_with_context() as engine:
|
||||||
|
data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
|
||||||
|
assert data == [("john", 1), ("paul", 2)]
|
||||||
|
|
||||||
|
# cleanup
|
||||||
|
with get_upload_db().get_sqla_engine_with_context() as engine:
|
||||||
|
engine.execute(f"DROP TABLE {full_table_name}")
|
||||||
|
|
||||||
|
# with dtype - wrong type
|
||||||
|
resp = upload_csv(
|
||||||
|
CSV_FILENAME1,
|
||||||
|
CSV_UPLOAD_TABLE,
|
||||||
|
dtype='{"a": "int"}',
|
||||||
|
)
|
||||||
|
|
||||||
|
# you cannot pass an incompatible dtype
|
||||||
|
fail_msg = f"Unable to upload CSV file {escaped_double_quotes(CSV_FILENAME1)} to table {escaped_double_quotes(CSV_UPLOAD_TABLE)}"
|
||||||
|
assert fail_msg in resp
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("setup_csv_upload_with_context")
|
@pytest.mark.usefixtures("setup_csv_upload_with_context")
|
||||||
@pytest.mark.usefixtures("create_excel_files")
|
@pytest.mark.usefixtures("create_excel_files")
|
||||||
|
|
|
@ -14,11 +14,18 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
import unittest.mock as mock
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from sqlalchemy.types import NVARCHAR
|
||||||
|
|
||||||
from superset.db_engine_specs.redshift import RedshiftEngineSpec
|
from superset.db_engine_specs.redshift import RedshiftEngineSpec
|
||||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||||
|
from superset.sql_parse import Table
|
||||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||||
|
from tests.integration_tests.test_app import app
|
||||||
|
|
||||||
|
|
||||||
class TestRedshiftDbEngineSpec(TestDbEngineSpec):
|
class TestRedshiftDbEngineSpec(TestDbEngineSpec):
|
||||||
|
@ -183,3 +190,57 @@ psql: error: could not connect to server: Operation timed out
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def test_df_to_sql_no_dtype(self):
|
||||||
|
mock_database = mock.MagicMock()
|
||||||
|
mock_database.get_df.return_value.empty = False
|
||||||
|
table_name = "foobar"
|
||||||
|
data = [
|
||||||
|
("foo", "bar", pd.NA, None),
|
||||||
|
("foo", "bar", pd.NA, True),
|
||||||
|
("foo", "bar", pd.NA, None),
|
||||||
|
]
|
||||||
|
numpy_dtype = [
|
||||||
|
("id", "object"),
|
||||||
|
("value", "object"),
|
||||||
|
("num", "object"),
|
||||||
|
("bool", "object"),
|
||||||
|
]
|
||||||
|
column_names = ["id", "value", "num", "bool"]
|
||||||
|
|
||||||
|
test_array = np.array(data, dtype=numpy_dtype)
|
||||||
|
|
||||||
|
df = pd.DataFrame(test_array, columns=column_names)
|
||||||
|
df.to_sql = mock.MagicMock()
|
||||||
|
|
||||||
|
with app.app_context():
|
||||||
|
RedshiftEngineSpec.df_to_sql(
|
||||||
|
mock_database, Table(table=table_name), df, to_sql_kwargs={}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert df.to_sql.call_args[1]["dtype"] == {}
|
||||||
|
|
||||||
|
def test_df_to_sql_with_string_dtype(self):
|
||||||
|
mock_database = mock.MagicMock()
|
||||||
|
mock_database.get_df.return_value.empty = False
|
||||||
|
table_name = "foobar"
|
||||||
|
data = [
|
||||||
|
("foo", "bar", pd.NA, None),
|
||||||
|
("foo", "bar", pd.NA, True),
|
||||||
|
("foo", "bar", pd.NA, None),
|
||||||
|
]
|
||||||
|
column_names = ["id", "value", "num", "bool"]
|
||||||
|
|
||||||
|
df = pd.DataFrame(data, columns=column_names)
|
||||||
|
df = df.astype(dtype={"value": "string"})
|
||||||
|
df.to_sql = mock.MagicMock()
|
||||||
|
|
||||||
|
with app.app_context():
|
||||||
|
RedshiftEngineSpec.df_to_sql(
|
||||||
|
mock_database, Table(table=table_name), df, to_sql_kwargs={}
|
||||||
|
)
|
||||||
|
|
||||||
|
# varchar string length should be 65535
|
||||||
|
dtype = df.to_sql.call_args[1]["dtype"]
|
||||||
|
assert isinstance(dtype["value"], NVARCHAR)
|
||||||
|
assert dtype["value"].length == 65535
|
||||||
|
|
Loading…
Reference in New Issue