mirror of https://github.com/apache/superset.git
feat(trino): Add functionality to upload data (#29164)
This commit is contained in:
parent
6b016da185
commit
53798c7904
|
@ -132,6 +132,7 @@ gevent = ["gevent>=23.9.1"]
|
|||
gsheets = ["shillelagh[gsheetsapi]>=1.2.18, <2"]
|
||||
hana = ["hdbcli==2.4.162", "sqlalchemy_hana==0.4.0"]
|
||||
hive = [
|
||||
"boto3",
|
||||
"pyhive[hive]>=0.6.5;python_version<'3.11'",
|
||||
"pyhive[hive_pure_sasl]>=0.7.0",
|
||||
"tableschema",
|
||||
|
@ -154,7 +155,7 @@ pinot = ["pinotdb>=0.3.3, <0.4"]
|
|||
playwright = ["playwright>=1.37.0, <2"]
|
||||
postgres = ["psycopg2-binary==2.9.6"]
|
||||
presto = ["pyhive[presto]>=0.6.5"]
|
||||
trino = ["trino>=0.328.0"]
|
||||
trino = ["boto3", "trino>=0.328.0"]
|
||||
prophet = ["prophet>=1.1.5, <2"]
|
||||
redshift = ["sqlalchemy-redshift>=0.8.1, <0.9"]
|
||||
rockset = ["rockset-sqlalchemy>=0.0.1, <1"]
|
||||
|
|
|
@ -79,6 +79,12 @@ def upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str:
|
|||
)
|
||||
|
||||
s3 = boto3.client("s3")
|
||||
|
||||
# The location is merely an S3 prefix and thus we first need to ensure that there is
|
||||
# one and only one key associated with the table.
|
||||
bucket = s3.Bucket(bucket_path)
|
||||
bucket.objects.filter(Prefix=os.path.join(upload_prefix, table.table)).delete()
|
||||
|
||||
location = os.path.join("s3a://", bucket_path, upload_prefix, table.table)
|
||||
s3.upload_file(
|
||||
filename,
|
||||
|
|
|
@ -20,9 +20,14 @@ import contextlib
|
|||
import logging
|
||||
import threading
|
||||
import time
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from flask import current_app, Flask
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
from flask import current_app, Flask, g
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.exc import NoSuchTableError
|
||||
|
@ -37,7 +42,9 @@ from superset.db_engine_specs.exceptions import (
|
|||
SupersetDBAPIOperationalError,
|
||||
SupersetDBAPIProgrammingError,
|
||||
)
|
||||
from superset.db_engine_specs.hive import upload_to_s3
|
||||
from superset.db_engine_specs.presto import PrestoBaseEngineSpec
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.sql_parse import Table
|
||||
from superset.superset_typing import ResultSetColumnType
|
||||
|
@ -452,3 +459,83 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
|||
return super().get_indexes(database, inspector, table)
|
||||
except NoSuchTableError:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def df_to_sql(
|
||||
cls,
|
||||
database: Database,
|
||||
table: Table,
|
||||
df: pd.DataFrame,
|
||||
to_sql_kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Upload data from a Pandas DataFrame to a database.
|
||||
|
||||
The data is stored via the binary Parquet format which is both less problematic
|
||||
and more performant than a text file.
|
||||
|
||||
Note this method does not create metadata for the table.
|
||||
|
||||
:param database: The database to upload the data to
|
||||
:param table: The table to upload the data to
|
||||
:param df: The Pandas Dataframe with data to be uploaded
|
||||
:param to_sql_kwargs: The `pandas.DataFrame.to_sql` keyword arguments
|
||||
:see: superset.db_engine_specs.HiveEngineSpec.df_to_sql
|
||||
"""
|
||||
|
||||
# pylint: disable=import-outside-toplevel
|
||||
|
||||
if to_sql_kwargs["if_exists"] == "append":
|
||||
raise SupersetException("Append operation not currently supported")
|
||||
|
||||
if to_sql_kwargs["if_exists"] == "fail":
|
||||
if database.has_table_by_name(table.table, table.schema):
|
||||
raise SupersetException("Table already exists")
|
||||
elif to_sql_kwargs["if_exists"] == "replace":
|
||||
with cls.get_engine(database) as engine:
|
||||
engine.execute(f"DROP TABLE IF EXISTS {str(table)}")
|
||||
|
||||
def _get_trino_type(dtype: np.dtype[Any]) -> str:
|
||||
return {
|
||||
np.dtype("bool"): "BOOLEAN",
|
||||
np.dtype("float64"): "DOUBLE",
|
||||
np.dtype("int64"): "BIGINT",
|
||||
np.dtype("object"): "VARCHAR",
|
||||
}.get(dtype, "VARCHAR")
|
||||
|
||||
with NamedTemporaryFile(
|
||||
dir=current_app.config["UPLOAD_FOLDER"],
|
||||
suffix=".parquet",
|
||||
) as file:
|
||||
pa.parquet.write_table(pa.Table.from_pandas(df), where=file.name)
|
||||
|
||||
with cls.get_engine(database) as engine:
|
||||
engine.execute(
|
||||
# pylint: disable=consider-using-f-string
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE {table} ({schema})
|
||||
WITH (
|
||||
format = 'PARQUET',
|
||||
external_location = '{location}'
|
||||
)
|
||||
""".format(
|
||||
location=upload_to_s3(
|
||||
filename=file.name,
|
||||
upload_prefix=current_app.config[
|
||||
"CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"
|
||||
](
|
||||
database,
|
||||
g.user,
|
||||
table.schema,
|
||||
),
|
||||
table=table,
|
||||
),
|
||||
schema=", ".join(
|
||||
f'"{name}" {_get_trino_type(dtype)}'
|
||||
for name, dtype in df.dtypes.items()
|
||||
),
|
||||
table=str(table),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from unittest import mock
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from superset.db_engine_specs.trino import TrinoEngineSpec
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.sql_parse import Table
|
||||
from tests.integration_tests.test_app import app
|
||||
|
||||
|
||||
def test_df_to_csv() -> None:
|
||||
with pytest.raises(SupersetException):
|
||||
TrinoEngineSpec.df_to_sql(
|
||||
mock.MagicMock(),
|
||||
Table("foobar"),
|
||||
pd.DataFrame(),
|
||||
{"if_exists": "append"},
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("superset.db_engine_specs.trino.g", spec={})
|
||||
def test_df_to_sql_if_exists_fail(mock_g):
|
||||
mock_g.user = True
|
||||
mock_database = mock.MagicMock()
|
||||
mock_database.get_df.return_value.empty = False
|
||||
with pytest.raises(SupersetException, match="Table already exists"):
|
||||
TrinoEngineSpec.df_to_sql(
|
||||
mock_database, Table("foobar"), pd.DataFrame(), {"if_exists": "fail"}
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("superset.db_engine_specs.trino.g", spec={})
|
||||
def test_df_to_sql_if_exists_fail_with_schema(mock_g):
|
||||
mock_g.user = True
|
||||
mock_database = mock.MagicMock()
|
||||
mock_database.get_df.return_value.empty = False
|
||||
with pytest.raises(SupersetException, match="Table already exists"):
|
||||
TrinoEngineSpec.df_to_sql(
|
||||
mock_database,
|
||||
Table(table="foobar", schema="schema"),
|
||||
pd.DataFrame(),
|
||||
{"if_exists": "fail"},
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("superset.db_engine_specs.trino.g", spec={})
|
||||
@mock.patch("superset.db_engine_specs.trino.upload_to_s3")
|
||||
def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g):
|
||||
config = app.config.copy()
|
||||
app.config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"]: lambda *args: "" # noqa: F722
|
||||
mock_upload_to_s3.return_value = "mock-location"
|
||||
mock_g.user = True
|
||||
mock_database = mock.MagicMock()
|
||||
mock_database.get_df.return_value.empty = False
|
||||
mock_execute = mock.MagicMock(return_value=True)
|
||||
mock_database.get_sqla_engine.return_value.__enter__.return_value.execute = (
|
||||
mock_execute
|
||||
)
|
||||
table_name = "foobar"
|
||||
|
||||
with app.app_context():
|
||||
TrinoEngineSpec.df_to_sql(
|
||||
mock_database,
|
||||
Table(table=table_name),
|
||||
pd.DataFrame(),
|
||||
{"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
|
||||
)
|
||||
|
||||
mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {table_name}")
|
||||
app.config = config
|
||||
|
||||
|
||||
@mock.patch("superset.db_engine_specs.trino.g", spec={})
|
||||
@mock.patch("superset.db_engine_specs.trino.upload_to_s3")
|
||||
def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
|
||||
config = app.config.copy()
|
||||
app.config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"]: lambda *args: "" # noqa: F722
|
||||
mock_upload_to_s3.return_value = "mock-location"
|
||||
mock_g.user = True
|
||||
mock_database = mock.MagicMock()
|
||||
mock_database.get_df.return_value.empty = False
|
||||
mock_execute = mock.MagicMock(return_value=True)
|
||||
mock_database.get_sqla_engine.return_value.__enter__.return_value.execute = (
|
||||
mock_execute
|
||||
)
|
||||
table_name = "foobar"
|
||||
schema = "schema"
|
||||
|
||||
with app.app_context():
|
||||
TrinoEngineSpec.df_to_sql(
|
||||
mock_database,
|
||||
Table(table=table_name, schema=schema),
|
||||
pd.DataFrame(),
|
||||
{"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
|
||||
)
|
||||
|
||||
mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {schema}.{table_name}")
|
||||
app.config = config
|
Loading…
Reference in New Issue