Add csv upload support for BigQuery (#7756)

* Add extras_require for bigquery to setup.py

* Refactor df_to_db and add df upload capability for BigQuery

* Fix unit tests and clarify kwarg logic

* Fix flake8 errors

* Add minimum versions for bigquery dependencies

* Wrap to_gbq in try-except block and raise error if pandas-gbq is missing

* Fix linting error and make error more generic
Ville Brofeldt 2019-06-24 00:20:09 +03:00 committed by Maxime Beauchemin
parent 90d156f186
commit 1c4092c61c
5 changed files with 74 additions and 26 deletions

setup.py

@@ -108,6 +108,10 @@ setup(
         'wtforms-json',
     ],
     extras_require={
+        'bigquery': [
+            'pybigquery>=0.4.10',
+            'pandas_gbq>=0.10.0',
+        ],
         'cors': ['flask-cors>=2.0.0'],
         'hive': [
             'pyhive[hive]>=0.6.1',
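
With the new extra in place, the BigQuery upload dependencies install together with Superset, e.g. via pip install -e '.[bigquery]' from a source checkout.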

superset/db_engine_specs/base.py

@@ -230,36 +230,45 @@ class BaseEngineSpec(object):
         return parsed_query.get_query_with_new_limit(limit)
 
     @staticmethod
-    def csv_to_df(**kwargs):
+    def csv_to_df(**kwargs) -> pd.DataFrame:
         """ Read csv into Pandas DataFrame
         :param kwargs: params to be passed to DataFrame.read_csv
         :return: Pandas DataFrame containing data from csv
         """
         kwargs['filepath_or_buffer'] = \
             config['UPLOAD_FOLDER'] + kwargs['filepath_or_buffer']
         kwargs['encoding'] = 'utf-8'
         kwargs['iterator'] = True
         chunks = pd.read_csv(**kwargs)
         df = pd.DataFrame()
         df = pd.concat(chunk for chunk in chunks)
         return df
 
-    @staticmethod
-    def df_to_db(df, table, **kwargs):
+    @classmethod
+    def df_to_sql(cls, df: pd.DataFrame, **kwargs):
+        """ Upload data from a Pandas DataFrame to a database. For
+        regular engines this calls the DataFrame.to_sql() method. Can be
+        overridden for engines that don't work well with to_sql(), e.g.
+        BigQuery.
+        :param df: Dataframe with data to be uploaded
+        :param kwargs: kwargs to be passed to to_sql() method
+        """
         df.to_sql(**kwargs)
-        table.user_id = g.user.id
-        table.schema = kwargs['schema']
-        table.fetch_metadata()
-        db.session.add(table)
-        db.session.commit()
 
-    @staticmethod
-    def create_table_from_csv(form, table):
-        def _allowed_file(filename):
+    @classmethod
+    def create_table_from_csv(cls, form, table):
+        """ Create table (including metadata in backend) from contents of a csv.
+        :param form: Parameters defining how to process data
+        :param table: Metadata of new table to be created
+        """
+        def _allowed_file(filename: str) -> bool:
             # Only allow specific file extensions as specified in the config
             extension = os.path.splitext(filename)[1]
-            return extension and extension[1:] in config['ALLOWED_EXTENSIONS']
+            return extension is not None and extension[1:] in config['ALLOWED_EXTENSIONS']
 
         filename = secure_filename(form.csv_file.data.filename)
         if not _allowed_file(filename):
             raise Exception('Invalid file type selected')
-        kwargs = {
+        csv_to_df_kwargs = {
             'filepath_or_buffer': filename,
             'sep': form.sep.data,
             'header': form.header.data if form.header.data else 0,
@@ -273,10 +282,9 @@ class BaseEngineSpec(object):
             'infer_datetime_format': form.infer_datetime_format.data,
             'chunksize': 10000,
         }
-        df = BaseEngineSpec.csv_to_df(**kwargs)
-        df_to_db_kwargs = {
-            'table': table,
+        df = cls.csv_to_df(**csv_to_df_kwargs)
+        df_to_sql_kwargs = {
             'df': df,
             'name': form.name.data,
             'con': create_engine(form.con.data.sqlalchemy_uri_decrypted, echo=False),
@@ -286,8 +294,13 @@ class BaseEngineSpec(object):
             'index_label': form.index_label.data,
             'chunksize': 10000,
         }
-        BaseEngineSpec.df_to_db(**df_to_db_kwargs)
+        cls.df_to_sql(**df_to_sql_kwargs)
+
+        table.user_id = g.user.id
+        table.schema = form.schema.data
+        table.fetch_metadata()
+        db.session.add(table)
+        db.session.commit()
 
     @classmethod
     def convert_dttm(cls, target_type, dttm):
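
For reference, the chunked read in csv_to_df above boils down to this standalone sketch (file name and chunk size are placeholders):

    import pandas as pd

    # iterator=True plus a chunksize makes read_csv return a TextFileReader
    # that yields DataFrame chunks; pd.concat stitches them back together,
    # so peak memory tracks the chunk size rather than the file size.
    chunks = pd.read_csv('example.csv', encoding='utf-8',
                         iterator=True, chunksize=10000)
    df = pd.concat(chunk for chunk in chunks)

The switch from @staticmethod to @classmethod is what makes the BigQuery override below reachable: create_table_from_csv now calls cls.df_to_sql(), so the most-derived implementation wins. A toy illustration of that dispatch (class names invented for the sketch):

    class BaseSpec:
        @classmethod
        def df_to_sql(cls, df):
            print('generic DataFrame.to_sql() path')

        @classmethod
        def create_table_from_csv(cls, df):
            # cls is the class the call started on, so subclass
            # overrides of df_to_sql are picked up automatically
            cls.df_to_sql(df)

    class BigQueryLikeSpec(BaseSpec):
        @classmethod
        def df_to_sql(cls, df):
            print('to_gbq() path')

    BigQueryLikeSpec.create_table_from_csv(None)  # -> 'to_gbq() path'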

superset/db_engine_specs/bigquery.py

@@ -14,10 +14,10 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=C,R,W
 import hashlib
 import re
+import pandas as pd
 from sqlalchemy import literal_column
 from superset.db_engine_specs.base import BaseEngineSpec
@@ -86,8 +86,8 @@ class BigQueryEngineSpec(BaseEngineSpec):
         # replace non-alphanumeric characters with underscores
         label_mutated = re.sub(r'[^\w]+', '_', label_mutated)
         if label_mutated != label:
-            # add md5 hash to label to avoid possible collisions
-            label_mutated += label_hashed
+            # add first 5 chars from md5 hash to label to avoid possible collisions
+            label_mutated += label_hashed[:6]
         return label_mutated
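
Reconstructed from this hunk and the updated tests below, the full mutation is roughly the following standalone sketch (the helper name and the digit-prefix step are inferred from the test expectations):

    import hashlib
    import re

    def make_label_compatible(label: str) -> str:
        # '_' plus the full md5 digest; only the first 5 hex chars get used
        label_hashed = '_' + hashlib.md5(label.encode('utf-8')).hexdigest()
        # labels starting with a digit get an underscore prefix (see tests)
        label_mutated = '_' + label if re.match(r'^\d', label) else label
        # replace non-alphanumeric characters with underscores
        label_mutated = re.sub(r'[^\w]+', '_', label_mutated)
        if label_mutated != label:
            # add first 5 chars from md5 hash to avoid possible collisions
            label_mutated += label_hashed[:6]
        return label_mutated

    print(make_label_compatible('SUM(x)'))  # -> 'SUM_x__5f110'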
@@ -141,3 +141,34 @@ class BigQueryEngineSpec(BaseEngineSpec):
     @classmethod
     def epoch_ms_to_dttm(cls):
         return 'TIMESTAMP_MILLIS({col})'
+
+    @classmethod
+    def df_to_sql(cls, df: pd.DataFrame, **kwargs):
+        """
+        Upload data from a Pandas DataFrame to BigQuery. Calls
+        `DataFrame.to_gbq()` which requires `pandas_gbq` to be installed.
+        :param df: Dataframe with data to be uploaded
+        :param kwargs: kwargs to be passed to to_gbq() method. Requires both
+        `schema` and `name` to be present in kwargs, which are combined and
+        passed to `to_gbq()` as `destination_table`.
+        """
+        try:
+            import pandas_gbq
+        except ImportError:
+            raise Exception('Could not import the library `pandas_gbq`, which is '
+                            'required to be installed in your environment in order '
+                            'to upload data to BigQuery')
+
+        if not ('name' in kwargs and 'schema' in kwargs):
+            raise Exception('name and schema need to be defined in kwargs')
+
+        gbq_kwargs = {}
+        gbq_kwargs['project_id'] = kwargs['con'].engine.url.host
+        gbq_kwargs['destination_table'] = f"{kwargs.pop('schema')}.{kwargs.pop('name')}"
+
+        # Only pass through supported kwargs
+        supported_kwarg_keys = {'if_exists'}
+        for key in supported_kwarg_keys:
+            if key in kwargs:
+                gbq_kwargs[key] = kwargs[key]
+
+        pandas_gbq.to_gbq(df, **gbq_kwargs)
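
A hypothetical end-to-end call of the override above (the module path follows this commit's imports; project, dataset, and table names are placeholders, with the engine URL using pybigquery's bigquery://<project> form):

    import pandas as pd
    from sqlalchemy import create_engine
    from superset.db_engine_specs.bigquery import BigQueryEngineSpec

    engine = create_engine('bigquery://my-project')
    df = pd.DataFrame({'x': [1, 2, 3]})
    # 'schema' and 'name' are combined into destination_table
    # ('my_dataset.my_table'); project_id comes from the URL host.
    BigQueryEngineSpec.df_to_sql(df, con=engine, schema='my_dataset',
                                 name='my_table', if_exists='fail')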

superset/db_engine_specs/hive.py

@@ -94,8 +94,8 @@ class HiveEngineSpec(PrestoEngineSpec):
         except pyhive.exc.ProgrammingError:
             return []
 
-    @staticmethod
-    def create_table_from_csv(form, table):
+    @classmethod
+    def create_table_from_csv(cls, form, table):
         """Uploads a csv file and creates a superset datasource in Hive."""
         def convert_to_hive_type(col_type):
             """maps tableschema's types to hive types"""
tests/db_engine_specs_test.py

@@ -702,15 +702,15 @@ class DbEngineSpecsTestCase(SupersetTestCase):
         self.assertEqual(label, label_expected)
 
         label = BigQueryEngineSpec.make_label_compatible(column('SUM(x)').name)
-        label_expected = 'SUM_x__5f110b965a993675bc4953bb3e03c4a5'
+        label_expected = 'SUM_x__5f110'
         self.assertEqual(label, label_expected)
 
         label = BigQueryEngineSpec.make_label_compatible(column('SUM[x]').name)
-        label_expected = 'SUM_x__7ebe14a3f9534aeee125449b0bc083a8'
+        label_expected = 'SUM_x__7ebe1'
         self.assertEqual(label, label_expected)
 
         label = BigQueryEngineSpec.make_label_compatible(column('12345_col').name)
-        label_expected = '_12345_col_8d3906e2ea99332eb185f7f8ecb2ffd6'
+        label_expected = '_12345_col_8d390'
         self.assertEqual(label, label_expected)
 
     def test_oracle_sqla_column_name_length_exceeded(self):
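
The truncated expectations follow directly from the old full-hash values: label_hashed[:6] is an underscore plus the first five hex characters of the md5 digest, which a quick check confirms:

    import hashlib

    # the old test asserted the full digest for 'SUM(x)'; the new
    # expectation keeps only its first five characters
    digest = hashlib.md5('SUM(x)'.encode('utf-8')).hexdigest()
    print('SUM_x_' + '_' + digest[:5])  # -> 'SUM_x__5f110'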