diff --git a/requirements/development.txt b/requirements/development.txt index bba99045cc..7155b2e1a5 100644 --- a/requirements/development.txt +++ b/requirements/development.txt @@ -106,7 +106,7 @@ pylint==2.17.4 # via -r requirements/development.in python-ldap==3.4.3 # via -r requirements/development.in -requests==2.30.0 +requests==2.31.0 # via # pydruid # tableschema diff --git a/requirements/testing.in b/requirements/testing.in index 856c5272dc..b991be1040 100644 --- a/requirements/testing.in +++ b/requirements/testing.in @@ -16,7 +16,7 @@ # -r development.in -r integration.in --e file:.[bigquery,hive,presto,prophet,trino] +-e file:.[bigquery,hive,presto,prophet,trino,gsheets] docker flask-testing freezegun diff --git a/requirements/testing.txt b/requirements/testing.txt index 283b4c9fcd..c8a3221b45 100644 --- a/requirements/testing.txt +++ b/requirements/testing.txt @@ -1,4 +1,4 @@ -# SHA1:78fe89f88adf34ac75513d363d7d9d0b5cc8cd1c +# SHA1:78d0270a4f583095e0587aa21f57fc2ff7fe8b84 # # This file is autogenerated by pip-compile-multi # To update, run: @@ -12,6 +12,8 @@ # -r requirements/base.in # -r requirements/development.in # -r requirements/testing.in +apsw==3.42.0.1 + # via shillelagh cmdstanpy==1.1.0 # via prophet contourpy==1.0.7 @@ -50,6 +52,7 @@ google-auth==2.17.3 # google-cloud-core # pandas-gbq # pydata-google-auth + # shillelagh # sqlalchemy-bigquery google-auth-oauthlib==1.0.0 # via @@ -142,6 +145,8 @@ rfc3339-validator==0.1.4 # via openapi-schema-validator rsa==4.9 # via google-auth +shillelagh[gsheetsapi]==1.2.6 + # via apache-superset sqlalchemy-bigquery==1.6.1 # via apache-superset statsd==4.0.1 diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py index 777499a8f9..a9ec921188 100644 --- a/superset/db_engine_specs/gsheets.py +++ b/superset/db_engine_specs/gsheets.py @@ -14,30 +14,44 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +from __future__ import annotations + import json +import logging import re from re import Pattern -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING +import pandas as pd from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin from flask import g from flask_babel import gettext as __ from marshmallow import fields, Schema from marshmallow.exceptions import ValidationError +from requests import Session from sqlalchemy.engine import create_engine from sqlalchemy.engine.url import URL from typing_extensions import TypedDict -from superset import security_manager +from superset import db, security_manager from superset.constants import PASSWORD_MASK from superset.databases.schemas import encrypted_field_properties, EncryptedString from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.exceptions import SupersetException if TYPE_CHECKING: from superset.models.core import Database + from superset.sql_parse import Table +_logger = logging.getLogger() + +EXAMPLE_GSHEETS_URL = ( + "https://docs.google.com/spreadsheets/d/" + "1LcWZMsdCl92g7nA-D6qGRqg1T5TiHyuKJUY1u9XAnsk/edit#gid=0" +) SYNTAX_ERROR_REGEX = re.compile('SQLError: near "(?P.*?)": syntax error') @@ -57,7 +71,7 @@ class GSheetsParametersSchema(Schema): class GSheetsParametersType(TypedDict): service_account_info: str - catalog: Optional[dict[str, str]] + catalog: dict[str, str] | None class GSheetsPropertiesType(TypedDict): @@ -88,14 +102,14 @@ class GSheetsEngineSpec(ShillelaghEngineSpec): ), } - supports_file_upload = False + supports_file_upload = True @classmethod def get_url_for_impersonation( cls, url: URL, impersonate_user: bool, - username: Optional[str], + username: str | None, ) -> URL: if impersonate_user and username is not None: user = security_manager.find_user(username=username) @@ -107,9 +121,9 @@ class GSheetsEngineSpec(ShillelaghEngineSpec): @classmethod def extra_table_metadata( cls, - database: "Database", + database: Database, table_name: str, - schema_name: Optional[str], + schema_name: str | None, ) -> dict[str, Any]: with database.get_raw_connection(schema=schema_name) as conn: cursor = conn.cursor() @@ -126,9 +140,8 @@ class GSheetsEngineSpec(ShillelaghEngineSpec): def build_sqlalchemy_uri( cls, _: GSheetsParametersType, - encrypted_extra: Optional[ # pylint: disable=unused-argument - dict[str, Any] - ] = None, + encrypted_extra: None # pylint: disable=unused-argument + | (dict[str, Any]) = None, ) -> str: return "gsheets://" @@ -136,7 +149,7 @@ class GSheetsEngineSpec(ShillelaghEngineSpec): def get_parameters_from_uri( cls, uri: str, # pylint: disable=unused-argument - encrypted_extra: Optional[dict[str, Any]] = None, + encrypted_extra: dict[str, Any] | None = None, ) -> Any: # Building parameters from encrypted_extra and uri if encrypted_extra: @@ -145,7 +158,7 @@ class GSheetsEngineSpec(ShillelaghEngineSpec): raise ValidationError("Invalid service credentials") @classmethod - def mask_encrypted_extra(cls, encrypted_extra: Optional[str]) -> Optional[str]: + def mask_encrypted_extra(cls, encrypted_extra: str | None) -> str | None: if encrypted_extra is None: return encrypted_extra @@ -162,9 +175,7 @@ class GSheetsEngineSpec(ShillelaghEngineSpec): return json.dumps(config) @classmethod - def unmask_encrypted_extra( - cls, old: Optional[str], new: Optional[str] - ) -> Optional[str]: + def unmask_encrypted_extra(cls, old: str | None, new: str | None) -> str | None: """ Reuse ``private_key`` if available and unchanged. """ @@ -299,3 +310,124 @@ class GSheetsEngineSpec(ShillelaghEngineSpec): ) idx += 1 return errors + + @staticmethod + def _do_post( + session: Session, + url: str, + body: dict[str, Any], + **kwargs: Any, + ) -> dict[str, Any]: + """ + POST to the Google API. + + Helper function that handles logging and error handling. + """ + _logger.info("POST %s", url) + _logger.debug(body) + response = session.post( + url, + json=body, + **kwargs, + ) + + payload = response.json() + _logger.debug(payload) + + if "error" in payload: + raise SupersetException(payload["error"]["message"]) + + return payload + + @classmethod + def df_to_sql( # pylint: disable=too-many-locals + cls, + database: Database, + table: Table, + df: pd.DataFrame, + to_sql_kwargs: dict[str, Any], + ) -> None: + """ + Create a new sheet and update the DB catalog. + + Since Google Sheets is not a database, uploading a file is slightly different + from other traditional databases. To create a table with a given name we first + create a spreadsheet with the contents of the dataframe, and we later update the + database catalog to add a mapping between the desired table name and the URL of + the new sheet. + + If the table already exists and the user wants it replaced we clear all the + cells in the existing sheet before uploading the new data. Appending to an + existing table is not supported because we can't ensure that the schemas match. + """ + # pylint: disable=import-outside-toplevel + from shillelagh.backends.apsw.dialects.base import get_adapter_for_table_name + + # grab the existing catalog, if any + extra = database.get_extra() + engine_params = extra.setdefault("engine_params", {}) + catalog = engine_params.setdefault("catalog", {}) + + # sanity checks + spreadsheet_url = catalog.get(table.table) + if spreadsheet_url and "if_exists" in to_sql_kwargs: + if to_sql_kwargs["if_exists"] == "append": + # no way we're going to append a dataframe to a spreadsheet, that's + # never going to work + raise SupersetException("Append operation not currently supported") + if to_sql_kwargs["if_exists"] == "fail": + raise SupersetException("Table already exists") + if to_sql_kwargs["if_exists"] == "replace": + pass + + # get the Google session from the Shillelagh adapter + with cls.get_engine(database) as engine: + with engine.connect() as conn: + # any GSheets URL will work to get a working session + adapter = get_adapter_for_table_name( + conn, + spreadsheet_url or EXAMPLE_GSHEETS_URL, + ) + session = adapter._get_session() # pylint: disable=protected-access + + # clear existing sheet, or create a new one + if spreadsheet_url: + spreadsheet_id = adapter._spreadsheet_id # pylint: disable=protected-access + range_ = adapter._sheet_name # pylint: disable=protected-access + url = ( + "https://sheets.googleapis.com/v4/spreadsheets/" + f"{spreadsheet_id}/values/{range_}:clear" + ) + cls._do_post(session, url, {}) + else: + payload = cls._do_post( + session, + "https://sheets.googleapis.com/v4/spreadsheets", + {"properties": {"title": table.table}}, + ) + spreadsheet_id = payload["spreadsheetId"] + range_ = payload["sheets"][0]["properties"]["title"] + spreadsheet_url = payload["spreadsheetUrl"] + + # insert data + body = { + "range": range_, + "majorDimension": "ROWS", + "values": df.fillna("").values.tolist(), + } + url = ( + "https://sheets.googleapis.com/v4/spreadsheets/" + f"{spreadsheet_id}/values/{range_}:append" + ) + cls._do_post( + session, + url, + body, + params={"valueInputOption": "USER_ENTERED"}, + ) + + # update catalog + catalog[table.table] = spreadsheet_url + database.extra = json.dumps(extra) + db.session.add(database) + db.session.commit() diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 4709a11377..cbdacc8f34 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -3153,7 +3153,7 @@ class TestDatabaseApi(SupersetTestCase): "preferred": False, "sqlalchemy_uri_placeholder": "gsheets://", "engine_information": { - "supports_file_upload": False, + "supports_file_upload": True, "disable_ssh_tunneling": True, }, }, diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py index 899e2b0234..aa15645ddb 100644 --- a/tests/unit_tests/databases/api_test.py +++ b/tests/unit_tests/databases/api_test.py @@ -178,7 +178,7 @@ def test_database_connection( "driver": "gsheets", "engine_information": { "disable_ssh_tunneling": True, - "supports_file_upload": False, + "supports_file_upload": True, }, "expose_in_sqllab": True, "extra": '{\n "metadata_params": {},\n "engine_params": {},\n "metadata_cache_timeout": {},\n "schemas_allowed_for_file_upload": []\n}\n', @@ -249,7 +249,7 @@ def test_database_connection( "driver": "gsheets", "engine_information": { "disable_ssh_tunneling": True, - "supports_file_upload": False, + "supports_file_upload": True, }, "expose_in_sqllab": True, "force_ctas_schema": None, diff --git a/tests/unit_tests/db_engine_specs/test_gsheets.py b/tests/unit_tests/db_engine_specs/test_gsheets.py index 042e486642..7d7348c1a3 100644 --- a/tests/unit_tests/db_engine_specs/test_gsheets.py +++ b/tests/unit_tests/db_engine_specs/test_gsheets.py @@ -19,9 +19,13 @@ import json +import pandas as pd +import pytest from pytest_mock import MockFixture from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.exceptions import SupersetException +from superset.sql_parse import Table class ProgrammingError(Exception): @@ -307,3 +311,91 @@ def test_unmask_encrypted_extra_when_new_is_none() -> None: new = None assert GSheetsEngineSpec.unmask_encrypted_extra(old, new) is None + + +def test_upload_new(mocker: MockFixture) -> None: + """ + Test file upload when the table does not exist. + """ + from superset.db_engine_specs.gsheets import GSheetsEngineSpec + + mocker.patch("superset.db_engine_specs.gsheets.db") + get_adapter_for_table_name = mocker.patch( + "shillelagh.backends.apsw.dialects.base.get_adapter_for_table_name" + ) + session = get_adapter_for_table_name()._get_session() + session.post().json.return_value = { + "spreadsheetId": 1, + "spreadsheetUrl": "https://docs.example.org", + "sheets": [{"properties": {"title": "sample_data"}}], + } + + database = mocker.MagicMock() + database.get_extra.return_value = {} + + df = pd.DataFrame([1, "foo", 3.0]) + table = Table("sample_data") + + GSheetsEngineSpec.df_to_sql(database, table, df, {}) + assert database.extra == json.dumps( + {"engine_params": {"catalog": {"sample_data": "https://docs.example.org"}}} + ) + + +def test_upload_existing(mocker: MockFixture) -> None: + """ + Test file upload when the table does exist. + """ + from superset.db_engine_specs.gsheets import GSheetsEngineSpec + + mocker.patch("superset.db_engine_specs.gsheets.db") + get_adapter_for_table_name = mocker.patch( + "shillelagh.backends.apsw.dialects.base.get_adapter_for_table_name" + ) + adapter = get_adapter_for_table_name() + adapter._spreadsheet_id = 1 + adapter._sheet_name = "sheet0" + session = adapter._get_session() + session.post().json.return_value = { + "spreadsheetId": 1, + "spreadsheetUrl": "https://docs.example.org", + "sheets": [{"properties": {"title": "sample_data"}}], + } + + database = mocker.MagicMock() + database.get_extra.return_value = { + "engine_params": {"catalog": {"sample_data": "https://docs.example.org"}} + } + + df = pd.DataFrame([1, "foo", 3.0]) + table = Table("sample_data") + + with pytest.raises(SupersetException) as excinfo: + GSheetsEngineSpec.df_to_sql(database, table, df, {"if_exists": "append"}) + assert str(excinfo.value) == "Append operation not currently supported" + + with pytest.raises(SupersetException) as excinfo: + GSheetsEngineSpec.df_to_sql(database, table, df, {"if_exists": "fail"}) + assert str(excinfo.value) == "Table already exists" + + GSheetsEngineSpec.df_to_sql(database, table, df, {"if_exists": "replace"}) + session.post.assert_has_calls( + [ + mocker.call(), + mocker.call( + "https://sheets.googleapis.com/v4/spreadsheets/1/values/sheet0:clear", + json={}, + ), + mocker.call().json(), + mocker.call( + "https://sheets.googleapis.com/v4/spreadsheets/1/values/sheet0:append", + json={ + "range": "sheet0", + "majorDimension": "ROWS", + "values": [[1], ["foo"], [3.0]], + }, + params={"valueInputOption": "USER_ENTERED"}, + ), + mocker.call().json(), + ] + )