feat: implement csv upload configuration func for the schema enforcement (#9734)

* Implement csv upload func for schema enforcement

Implement function controlled csv upload schema

Refactor + fix tests

Fixing hive as well

* Add explore_db to the extras

Co-authored-by: bogdan kyryliuk <bogdankyryliuk@dropbox.com>
This commit is contained in:
Bogdan 2020-05-21 13:49:53 -07:00 committed by GitHub
parent 333dc8529e
commit 3e8e441bfc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 275 additions and 183 deletions

View File

@ -586,11 +586,27 @@ CSV_TO_HIVE_UPLOAD_S3_BUCKET = None
# The directory within the bucket specified above that will # The directory within the bucket specified above that will
# contain all the external tables # contain all the external tables
CSV_TO_HIVE_UPLOAD_DIRECTORY = "EXTERNAL_HIVE_TABLES/" CSV_TO_HIVE_UPLOAD_DIRECTORY = "EXTERNAL_HIVE_TABLES/"
# Function that creates upload directory dynamically based on the
# database used, user and schema provided.
CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC: Callable[
["Database", "models.User", str], Optional[str]
] = lambda database, user, schema: CSV_TO_HIVE_UPLOAD_DIRECTORY
# The namespace within hive where the tables created from # The namespace within hive where the tables created from
# uploading CSVs will be stored. # uploading CSVs will be stored.
UPLOADED_CSV_HIVE_NAMESPACE = None UPLOADED_CSV_HIVE_NAMESPACE = None
# Function that computes the allowed schemas for the CSV uploads.
# Allowed schemas will be a union of schemas_allowed_for_csv_upload
# db configuration and a result of this function.
# mypy doesn't catch that if case ensures list content being always str
ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[
["Database", "models.User"], List[str]
] = lambda database, user: [
UPLOADED_CSV_HIVE_NAMESPACE # type: ignore
] if UPLOADED_CSV_HIVE_NAMESPACE else []
# A dictionary of items that gets merged into the Jinja context for # A dictionary of items that gets merged into the Jinja context for
# SQL Lab. The existing context gets updated with this dictionary, # SQL Lab. The existing context gets updated with this dictionary,
# meaning values for existing keys get overwritten by the content of this # meaning values for existing keys get overwritten by the content of this

View File

@ -18,7 +18,6 @@
import hashlib import hashlib
import json import json
import logging import logging
import os
import re import re
from contextlib import closing from contextlib import closing
from datetime import datetime from datetime import datetime
@ -49,11 +48,11 @@ from sqlalchemy.orm import Session
from sqlalchemy.sql import quoted_name, text from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, ColumnElement, Select, TextAsFrom from sqlalchemy.sql.expression import ColumnClause, ColumnElement, Select, TextAsFrom
from sqlalchemy.types import TypeEngine from sqlalchemy.types import TypeEngine
from wtforms.form import Form
from superset import app, sql_parse from superset import app, sql_parse
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query from superset.models.sql_lab import Query
from superset.sql_parse import Table
from superset.utils import core as utils from superset.utils import core as utils
if TYPE_CHECKING: if TYPE_CHECKING:
@ -454,55 +453,26 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
df.to_sql(**kwargs) df.to_sql(**kwargs)
@classmethod @classmethod
def create_table_from_csv(cls, form: Form, database: "Database") -> None: def create_table_from_csv( # pylint: disable=too-many-arguments
cls,
filename: str,
table: Table,
database: "Database",
csv_to_df_kwargs: Dict[str, Any],
df_to_sql_kwargs: Dict[str, Any],
) -> None:
""" """
Create table from contents of a csv. Note: this method does not create Create table from contents of a csv. Note: this method does not create
metadata for the table. metadata for the table.
:param form: Parameters defining how to process data
:param database: Database model object for the target database
""" """
df = cls.csv_to_df(filepath_or_buffer=filename, **csv_to_df_kwargs,)
def _allowed_file(filename: str) -> bool:
# Only allow specific file extensions as specified in the config
extension = os.path.splitext(filename)[1].lower()
return (
extension is not None and extension[1:] in config["ALLOWED_EXTENSIONS"]
)
filename = form.csv_file.data.filename
if not _allowed_file(filename):
raise Exception("Invalid file type selected")
csv_to_df_kwargs = {
"filepath_or_buffer": filename,
"sep": form.sep.data,
"header": form.header.data if form.header.data else 0,
"index_col": form.index_col.data,
"mangle_dupe_cols": form.mangle_dupe_cols.data,
"skipinitialspace": form.skipinitialspace.data,
"skiprows": form.skiprows.data,
"nrows": form.nrows.data,
"skip_blank_lines": form.skip_blank_lines.data,
"parse_dates": form.parse_dates.data,
"infer_datetime_format": form.infer_datetime_format.data,
"chunksize": 10000,
}
df = cls.csv_to_df(**csv_to_df_kwargs)
engine = cls.get_engine(database) engine = cls.get_engine(database)
if table.schema:
df_to_sql_kwargs = { # only add schema when it is preset and non empty
"df": df, df_to_sql_kwargs["schema"] = table.schema
"name": form.name.data, if engine.dialect.supports_multivalues_insert:
"con": engine, df_to_sql_kwargs["method"] = "multi"
"schema": form.schema.data, cls.df_to_sql(df=df, con=engine, **df_to_sql_kwargs)
"if_exists": form.if_exists.data,
"index": form.index.data,
"index_label": form.index_label.data,
"chunksize": 10000,
}
cls.df_to_sql(**df_to_sql_kwargs)
@classmethod @classmethod
def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]: def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]:

View File

@ -23,18 +23,19 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from urllib import parse from urllib import parse
import pandas as pd import pandas as pd
from flask import g
from sqlalchemy import Column from sqlalchemy import Column
from sqlalchemy.engine.base import Engine from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import make_url, URL from sqlalchemy.engine.url import make_url, URL
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select from sqlalchemy.sql.expression import ColumnClause, Select
from wtforms.form import Form
from superset import app, cache, conf from superset import app, cache, conf
from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.presto import PrestoEngineSpec from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.models.sql_lab import Query from superset.models.sql_lab import Query
from superset.sql_parse import Table
from superset.utils import core as utils from superset.utils import core as utils
if TYPE_CHECKING: if TYPE_CHECKING:
@ -105,8 +106,13 @@ class HiveEngineSpec(PrestoEngineSpec):
return [] return []
@classmethod @classmethod
def create_table_from_csv( # pylint: disable=too-many-locals def create_table_from_csv( # pylint: disable=too-many-arguments, too-many-locals
cls, form: Form, database: "Database" cls,
filename: str,
table: Table,
database: "Database",
csv_to_df_kwargs: Dict[str, Any],
df_to_sql_kwargs: Dict[str, Any],
) -> None: ) -> None:
"""Uploads a csv file and creates a superset datasource in Hive.""" """Uploads a csv file and creates a superset datasource in Hive."""
@ -128,38 +134,16 @@ class HiveEngineSpec(PrestoEngineSpec):
"No upload bucket specified. You can specify one in the config file." "No upload bucket specified. You can specify one in the config file."
) )
table_name = form.name.data upload_prefix = config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"](
schema_name = form.schema.data database, g.user, table.schema
)
if config["UPLOADED_CSV_HIVE_NAMESPACE"]:
if "." in table_name or schema_name:
raise Exception(
"You can't specify a namespace. "
"All tables will be uploaded to the `{}` namespace".format(
config["HIVE_NAMESPACE"]
)
)
full_table_name = "{}.{}".format(
config["UPLOADED_CSV_HIVE_NAMESPACE"], table_name
)
else:
if "." in table_name and schema_name:
raise Exception(
"You can't specify a namespace both in the name of the table "
"and in the schema field. Please remove one"
)
full_table_name = (
"{}.{}".format(schema_name, table_name) if schema_name else table_name
)
filename = form.csv_file.data.filename
upload_prefix = config["CSV_TO_HIVE_UPLOAD_DIRECTORY"]
# Optional dependency # Optional dependency
from tableschema import Table # pylint: disable=import-error from tableschema import ( # pylint: disable=import-error
Table as TableSchemaTable,
)
hive_table_schema = Table(filename).infer() hive_table_schema = TableSchemaTable(filename).infer()
column_name_and_type = [] column_name_and_type = []
for column_info in hive_table_schema["fields"]: for column_info in hive_table_schema["fields"]:
column_name_and_type.append( column_name_and_type.append(
@ -173,13 +157,14 @@ class HiveEngineSpec(PrestoEngineSpec):
import boto3 # pylint: disable=import-error import boto3 # pylint: disable=import-error
s3 = boto3.client("s3") s3 = boto3.client("s3")
location = os.path.join("s3a://", bucket_path, upload_prefix, table_name) location = os.path.join("s3a://", bucket_path, upload_prefix, table.table)
s3.upload_file( s3.upload_file(
filename, filename,
bucket_path, bucket_path,
os.path.join(upload_prefix, table_name, os.path.basename(filename)), os.path.join(upload_prefix, table.table, os.path.basename(filename)),
) )
sql = f"""CREATE TABLE {full_table_name} ( {schema_definition} ) # TODO(bkyryliuk): support other delimiters
sql = f"""CREATE TABLE {str(table)} ( {schema_definition} )
ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS
TEXTFILE LOCATION '{location}' TEXTFILE LOCATION '{location}'
tblproperties ('skip.header.line.count'='1')""" tblproperties ('skip.header.line.count'='1')"""

View File

@ -609,7 +609,13 @@ class Database(
def get_schema_access_for_csv_upload( # pylint: disable=invalid-name def get_schema_access_for_csv_upload( # pylint: disable=invalid-name
self, self,
) -> List[str]: ) -> List[str]:
return self.get_extra().get("schemas_allowed_for_csv_upload", []) allowed_databases = self.get_extra().get("schemas_allowed_for_csv_upload", [])
if hasattr(g, "user"):
extra_allowed_databases = config["ALLOWED_USER_CSV_SCHEMA_FUNC"](
self, g.user
)
allowed_databases += extra_allowed_databases
return sorted(set(allowed_databases))
@property @property
def sqlalchemy_uri_decrypted(self) -> str: def sqlalchemy_uri_decrypted(self) -> str:

View File

@ -30,6 +30,7 @@ from superset import app, db
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.constants import RouteMethod from superset.constants import RouteMethod
from superset.exceptions import CertificateException from superset.exceptions import CertificateException
from superset.sql_parse import Table
from superset.utils import core as utils from superset.utils import core as utils
from superset.views.base import DeleteMixin, SupersetModelView, YamlExportMixin from superset.views.base import DeleteMixin, SupersetModelView, YamlExportMixin
@ -109,66 +110,116 @@ class CsvToDatabaseView(SimpleFormView):
def form_post(self, form): def form_post(self, form):
database = form.con.data database = form.con.data
schema_name = form.schema.data or "" csv_table = Table(table=form.name.data, schema=form.schema.data)
if not schema_allows_csv_upload(database, schema_name): if not schema_allows_csv_upload(database, csv_table.schema):
message = _( message = _(
'Database "%(database_name)s" schema "%(schema_name)s" ' 'Database "%(database_name)s" schema "%(schema_name)s" '
"is not allowed for csv uploads. Please contact your Superset Admin.", "is not allowed for csv uploads. Please contact your Superset Admin.",
database_name=database.database_name, database_name=database.database_name,
schema_name=schema_name, schema_name=csv_table.schema,
) )
flash(message, "danger") flash(message, "danger")
return redirect("/csvtodatabaseview/form") return redirect("/csvtodatabaseview/form")
csv_filename = form.csv_file.data.filename if "." in csv_table.table and csv_table.schema:
extension = os.path.splitext(csv_filename)[1].lower() message = _(
path = tempfile.NamedTemporaryFile( "You cannot specify a namespace both in the name of the table: "
dir=app.config["UPLOAD_FOLDER"], suffix=extension, delete=False '"%(csv_table.table)s" and in the schema field: '
'"%(csv_table.schema)s". Please remove one',
table=csv_table.table,
schema=csv_table.schema,
)
flash(message, "danger")
return redirect("/csvtodatabaseview/form")
uploaded_tmp_file_path = tempfile.NamedTemporaryFile(
dir=app.config["UPLOAD_FOLDER"],
suffix=os.path.splitext(form.csv_file.data.filename)[1].lower(),
delete=False,
).name ).name
form.csv_file.data.filename = path
try: try:
utils.ensure_path_exists(config["UPLOAD_FOLDER"]) utils.ensure_path_exists(config["UPLOAD_FOLDER"])
upload_stream_write(form.csv_file.data, path) upload_stream_write(form.csv_file.data, uploaded_tmp_file_path)
table_name = form.name.data
con = form.data.get("con") con = form.data.get("con")
database = ( database = (
db.session.query(models.Database).filter_by(id=con.data.get("id")).one() db.session.query(models.Database).filter_by(id=con.data.get("id")).one()
) )
database.db_engine_spec.create_table_from_csv(form, database) csv_to_df_kwargs = {
table = ( "sep": form.sep.data,
"header": form.header.data if form.header.data else 0,
"index_col": form.index_col.data,
"mangle_dupe_cols": form.mangle_dupe_cols.data,
"skipinitialspace": form.skipinitialspace.data,
"skiprows": form.skiprows.data,
"nrows": form.nrows.data,
"skip_blank_lines": form.skip_blank_lines.data,
"parse_dates": form.parse_dates.data,
"infer_datetime_format": form.infer_datetime_format.data,
"chunksize": 1000,
}
df_to_sql_kwargs = {
"name": csv_table.table,
"if_exists": form.if_exists.data,
"index": form.index.data,
"index_label": form.index_label.data,
"chunksize": 1000,
}
database.db_engine_spec.create_table_from_csv(
uploaded_tmp_file_path,
csv_table,
database,
csv_to_df_kwargs,
df_to_sql_kwargs,
)
# Connect table to the database that should be used for exploration.
# E.g. if hive was used to upload a csv, presto will be a better option
# to explore the table.
expore_database = database
explore_database_id = database.get_extra().get("explore_database_id", None)
if explore_database_id:
expore_database = (
db.session.query(models.Database)
.filter_by(id=explore_database_id)
.one_or_none()
or database
)
sqla_table = (
db.session.query(SqlaTable) db.session.query(SqlaTable)
.filter_by( .filter_by(
table_name=table_name, table_name=csv_table.table,
schema=form.schema.data, schema=csv_table.schema,
database_id=database.id, database_id=expore_database.id,
) )
.one_or_none() .one_or_none()
) )
if table:
table.fetch_metadata() if sqla_table:
if not table: sqla_table.fetch_metadata()
table = SqlaTable(table_name=table_name) if not sqla_table:
table.database = database sqla_table = SqlaTable(table_name=csv_table.table)
table.database_id = database.id sqla_table.database = expore_database
table.user_id = g.user.id sqla_table.database_id = database.id
table.schema = form.schema.data sqla_table.user_id = g.user.id
table.fetch_metadata() sqla_table.schema = csv_table.schema
db.session.add(table) sqla_table.fetch_metadata()
db.session.add(sqla_table)
db.session.commit() db.session.commit()
except Exception as ex: # pylint: disable=broad-except except Exception as ex: # pylint: disable=broad-except
db.session.rollback() db.session.rollback()
try: try:
os.remove(path) os.remove(uploaded_tmp_file_path)
except OSError: except OSError:
pass pass
message = _( message = _(
'Unable to upload CSV file "%(filename)s" to table ' 'Unable to upload CSV file "%(filename)s" to table '
'"%(table_name)s" in database "%(db_name)s". ' '"%(table_name)s" in database "%(db_name)s". '
"Error message: %(error_msg)s", "Error message: %(error_msg)s",
filename=csv_filename, filename=form.csv_file.data.filename,
table_name=form.name.data, table_name=form.name.data,
db_name=database.database_name, db_name=database.database_name,
error_msg=str(ex), error_msg=str(ex),
@ -178,14 +229,14 @@ class CsvToDatabaseView(SimpleFormView):
stats_logger.incr("failed_csv_upload") stats_logger.incr("failed_csv_upload")
return redirect("/csvtodatabaseview/form") return redirect("/csvtodatabaseview/form")
os.remove(path) os.remove(uploaded_tmp_file_path)
# Go back to welcome page / splash screen # Go back to welcome page / splash screen
message = _( message = _(
'CSV file "%(csv_filename)s" uploaded to table "%(table_name)s" in ' 'CSV file "%(csv_filename)s" uploaded to table "%(table_name)s" in '
'database "%(db_name)s"', 'database "%(db_name)s"',
csv_filename=csv_filename, csv_filename=form.csv_file.data.filename,
table_name=form.name.data, table_name=str(csv_table),
db_name=table.database.database_name, db_name=sqla_table.database.database_name,
) )
flash(message, "info") flash(message, "info")
stats_logger.incr("successful_csv_upload") stats_logger.incr("successful_csv_upload")

View File

@ -24,6 +24,8 @@ import io
import json import json
import logging import logging
import os import os
from typing import Dict, List, Optional
import pytz import pytz
import random import random
import re import re
@ -44,6 +46,7 @@ from superset import (
is_feature_enabled, is_feature_enabled,
) )
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.datasets.dao import DatasetDAO
from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec from superset.db_engine_specs.mssql import MssqlEngineSpec
from superset.models import core as models from superset.models import core as models
@ -769,102 +772,163 @@ class CoreTests(SupersetTestCase):
self.get_json_resp(slc_url, {"form_data": json.dumps(slc.form_data)}) self.get_json_resp(slc_url, {"form_data": json.dumps(slc.form_data)})
self.assertEqual(1, qry.count()) self.assertEqual(1, qry.count())
def test_import_csv(self): def create_sample_csvfile(self, filename: str, content: List[str]) -> None:
self.login(username="admin") with open(filename, "w+") as test_file:
table_name = "".join(random.choice(string.ascii_uppercase) for _ in range(5)) for l in content:
test_file.write(f"{l}\n")
filename_1 = "testCSV.csv" def enable_csv_upload(self, database: models.Database) -> None:
test_file_1 = open(filename_1, "w+") """Enables csv upload in the given database."""
test_file_1.write("a,b\n") database.allow_csv_upload = True
test_file_1.write("john,1\n")
test_file_1.write("paul,2\n")
test_file_1.close()
filename_2 = "testCSV2.csv"
test_file_2 = open(filename_2, "w+")
test_file_2.write("b,c,d\n")
test_file_2.write("john,1,x\n")
test_file_2.write("paul,2,y\n")
test_file_2.close()
example_db = utils.get_example_database()
example_db.allow_csv_upload = True
db_id = example_db.id
db.session.commit() db.session.commit()
add_datasource_page = self.get_resp("/databaseview/list/")
self.assertIn("Upload a CSV", add_datasource_page)
form_get = self.get_resp("/csvtodatabaseview/form")
self.assertIn("CSV to Database configuration", form_get)
def upload_csv(
self, filename: str, table_name: str, extra: Optional[Dict[str, str]] = None
):
form_data = { form_data = {
"csv_file": open(filename_1, "rb"), "csv_file": open(filename, "rb"),
"sep": ",", "sep": ",",
"name": table_name, "name": table_name,
"con": db_id, "con": utils.get_example_database().id,
"if_exists": "fail", "if_exists": "fail",
"index_label": "test_label", "index_label": "test_label",
"mangle_dupe_cols": False, "mangle_dupe_cols": False,
} }
url = "/databaseview/list/" if extra:
add_datasource_page = self.get_resp(url) form_data.update(extra)
self.assertIn("Upload a CSV", add_datasource_page) return self.get_resp("/csvtodatabaseview/form", data=form_data)
url = "/csvtodatabaseview/form" @mock.patch(
form_get = self.get_resp(url) "superset.models.core.config",
self.assertIn("CSV to Database configuration", form_get) {**app.config, "ALLOWED_USER_CSV_SCHEMA_FUNC": lambda d, u: ["admin_database"]},
)
def test_import_csv_enforced_schema(self):
if utils.get_example_database().backend == "sqlite":
# sqlite doesn't support schema / database creation
return
self.login(username="admin")
table_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
full_table_name = f"admin_database.{table_name}"
filename = "testCSV.csv"
self.create_sample_csvfile(filename, ["a,b", "john,1", "paul,2"])
try:
self.enable_csv_upload(utils.get_example_database())
# no schema specified, fail upload
resp = self.upload_csv(filename, table_name)
self.assertIn(
'Database "examples" schema "None" is not allowed for csv uploads', resp
)
# user specified schema matches the expected schema, append
success_msg = f'CSV file "{filename}" uploaded to table "{full_table_name}"'
resp = self.upload_csv(
filename,
table_name,
extra={"schema": "admin_database", "if_exists": "append"},
)
self.assertIn(success_msg, resp)
resp = self.upload_csv(
filename,
table_name,
extra={"schema": "admin_database", "if_exists": "replace"},
)
self.assertIn(success_msg, resp)
# user specified schema doesn't match, fail
resp = self.upload_csv(filename, table_name, extra={"schema": "gold"})
self.assertIn(
'Database "examples" schema "gold" is not allowed for csv uploads',
resp,
)
finally:
os.remove(filename)
def test_import_csv_explore_database(self):
if utils.get_example_database().backend == "sqlite":
# sqlite doesn't support schema / database creation
return
explore_db_id = utils.get_example_database().id
upload_db = utils.get_or_create_db(
"csv_explore_db", app.config["SQLALCHEMY_DATABASE_URI"]
)
upload_db_id = upload_db.id
extra = upload_db.get_extra()
extra["explore_database_id"] = explore_db_id
upload_db.extra = json.dumps(extra)
db.session.commit()
self.login(username="admin")
self.enable_csv_upload(DatasetDAO.get_database_by_id(upload_db_id))
table_name = "".join(random.choice(string.ascii_uppercase) for _ in range(5))
f = "testCSV.csv"
self.create_sample_csvfile(f, ["a,b", "john,1", "paul,2"])
# initial upload with fail mode
resp = self.upload_csv(f, table_name)
self.assertIn(f'CSV file "{f}" uploaded to table "{table_name}"', resp)
table = self.get_table_by_name(table_name)
self.assertEqual(table.database_id, explore_db_id)
# cleanup
db.session.delete(table)
db.session.delete(DatasetDAO.get_database_by_id(upload_db_id))
db.session.commit()
os.remove(f)
def test_import_csv(self):
self.login(username="admin")
table_name = "".join(random.choice(string.ascii_uppercase) for _ in range(5))
f1 = "testCSV.csv"
self.create_sample_csvfile(f1, ["a,b", "john,1", "paul,2"])
f2 = "testCSV2.csv"
self.create_sample_csvfile(f2, ["b,c,d", "john,1,x", "paul,2,y"])
self.enable_csv_upload(utils.get_example_database())
try: try:
success_msg_f1 = f'CSV file "{f1}" uploaded to table "{table_name}"'
# initial upload with fail mode # initial upload with fail mode
resp = self.get_resp(url, data=form_data) resp = self.upload_csv(f1, table_name)
self.assertIn( self.assertIn(success_msg_f1, resp)
f'CSV file "{filename_1}" uploaded to table "{table_name}"', resp
)
# upload again with fail mode; should fail # upload again with fail mode; should fail
form_data["csv_file"] = open(filename_1, "rb") fail_msg = f'Unable to upload CSV file "{f1}" to table "{table_name}"'
resp = self.get_resp(url, data=form_data) resp = self.upload_csv(f1, table_name)
self.assertIn( self.assertIn(fail_msg, resp)
f'Unable to upload CSV file "{filename_1}" to table "{table_name}"',
resp,
)
# upload again with append mode # upload again with append mode
form_data["csv_file"] = open(filename_1, "rb") resp = self.upload_csv(f1, table_name, extra={"if_exists": "append"})
form_data["if_exists"] = "append" self.assertIn(success_msg_f1, resp)
resp = self.get_resp(url, data=form_data)
self.assertIn(
f'CSV file "{filename_1}" uploaded to table "{table_name}"', resp
)
# upload again with replace mode # upload again with replace mode
form_data["csv_file"] = open(filename_1, "rb") resp = self.upload_csv(f1, table_name, extra={"if_exists": "replace"})
form_data["if_exists"] = "replace" self.assertIn(success_msg_f1, resp)
resp = self.get_resp(url, data=form_data)
self.assertIn(
f'CSV file "{filename_1}" uploaded to table "{table_name}"', resp
)
# try to append to table from file with different schema # try to append to table from file with different schema
form_data["csv_file"] = open(filename_2, "rb") resp = self.upload_csv(f2, table_name, extra={"if_exists": "append"})
form_data["if_exists"] = "append" fail_msg_f2 = f'Unable to upload CSV file "{f2}" to table "{table_name}"'
resp = self.get_resp(url, data=form_data) self.assertIn(fail_msg_f2, resp)
self.assertIn(
f'Unable to upload CSV file "{filename_2}" to table "{table_name}"',
resp,
)
# replace table from file with different schema # replace table from file with different schema
form_data["csv_file"] = open(filename_2, "rb") resp = self.upload_csv(f2, table_name, extra={"if_exists": "replace"})
form_data["if_exists"] = "replace" success_msg_f2 = f'CSV file "{f2}" uploaded to table "{table_name}"'
resp = self.get_resp(url, data=form_data) self.assertIn(success_msg_f2, resp)
self.assertIn(
f'CSV file "{filename_2}" uploaded to table "{table_name}"', resp table = self.get_table_by_name(table_name)
)
table = (
db.session.query(SqlaTable)
.filter_by(table_name=table_name, database_id=db_id)
.first()
)
# make sure the new column name is reflected in the table metadata # make sure the new column name is reflected in the table metadata
self.assertIn("d", table.column_names) self.assertIn("d", table.column_names)
finally: finally:
os.remove(filename_1) os.remove(f1)
os.remove(filename_2) os.remove(f2)
def test_dataframe_timezone(self): def test_dataframe_timezone(self):
tz = pytz.FixedOffset(60) tz = pytz.FixedOffset(60)