fix: add connection testing params for snowflake (#9272)

* fix: add connection testingt params for snowflake

* Linting
This commit is contained in:
Ville Brofeldt 2020-03-11 06:51:57 +02:00 committed by GitHub
parent 724b8a3c31
commit 3682702e91
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 58 additions and 14 deletions

View File

@ -529,7 +529,11 @@ The role and warehouse can be omitted if defaults are defined for the user, i.e.
Make sure the user has privileges to access and use all required Make sure the user has privileges to access and use all required
databases/schemas/tables/views/warehouses, as the Snowflake SQLAlchemy engine does databases/schemas/tables/views/warehouses, as the Snowflake SQLAlchemy engine does
not test for user rights during engine creation. not test for user/role rights during engine creation by default. However, when
pressing the "Test Connection" button in the Create or Edit Database dialog,
user/role credentials are validated by passing `"validate_default_parameters": True`
to the `connect()` method during engine creation. If the user/role is not authorized
to access the database, an error is recorded in the Superset logs.
See `Snowflake SQLAlchemy <https://github.com/snowflakedb/snowflake-sqlalchemy>`_. See `Snowflake SQLAlchemy <https://github.com/snowflakedb/snowflake-sqlalchemy>`_.

View File

@ -942,3 +942,14 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
if data and type(data[0]).__name__ == "Row": if data and type(data[0]).__name__ == "Row":
data = [tuple(row) for row in data] data = [tuple(row) for row in data]
return data return data
@staticmethod
def mutate_db_for_connection_test(database: "Database") -> None:
"""
Some databases require passing additional parameters for validating database
connections. This method makes it possible to mutate the database instance prior
to testing if a connection is ok.
:param database: instance to be mutated
"""
return None

View File

@ -14,14 +14,18 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import json
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional, TYPE_CHECKING
from urllib import parse from urllib import parse
from sqlalchemy.engine.url import URL from sqlalchemy.engine.url import URL
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
if TYPE_CHECKING:
from superset.models.core import Database # pylint: disable=unused-import
class SnowflakeEngineSpec(PostgresBaseEngineSpec): class SnowflakeEngineSpec(PostgresBaseEngineSpec):
engine = "snowflake" engine = "snowflake"
@ -77,3 +81,19 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
if tt == "TIMESTAMP": if tt == "TIMESTAMP":
return f"""TO_TIMESTAMP('{dttm.isoformat(timespec="microseconds")}')""" return f"""TO_TIMESTAMP('{dttm.isoformat(timespec="microseconds")}')"""
return None return None
@staticmethod
def mutate_db_for_connection_test(database: "Database") -> None:
"""
By default, snowflake doesn't validate if the user/role has access to the chosen
database.
:param database: instance to be mutated
"""
extra = json.loads(database.extra or "{}")
engine_params = extra.get("engine_params", {})
connect_args = engine_params.get("connect_args", {})
connect_args["validate_default_parameters"] = True
engine_params["connect_args"] = connect_args
extra["engine_params"] = engine_params
database.extra = json.dumps(extra)

View File

@ -1367,6 +1367,7 @@ class Superset(BaseSupersetView):
encrypted_extra=json.dumps(request.json.get("encrypted_extra", {})), encrypted_extra=json.dumps(request.json.get("encrypted_extra", {})),
) )
database.set_sqlalchemy_uri(uri) database.set_sqlalchemy_uri(uri)
database.db_engine_spec.mutate_db_for_connection_test(database)
username = g.user.username if g.user is not None else None username = g.user.username if g.user is not None else None
engine = database.get_sqla_engine(user_name=username) engine = database.get_sqla_engine(user_name=username)
@ -1402,7 +1403,9 @@ class Superset(BaseSupersetView):
return json_error_response(_(str(e)), 400) return json_error_response(_(str(e)), 400)
except Exception as e: except Exception as e:
logger.error("Unexpected error %s", e) logger.error("Unexpected error %s", e)
return json_error_response(_("Unexpected error occurred."), 400) return json_error_response(
_("Unexpected error occurred, please check your logs for details"), 400
)
@api @api
@has_access_api @has_access_api

View File

@ -14,9 +14,10 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from sqlalchemy import column import json
from superset.db_engine_specs.snowflake import SnowflakeEngineSpec from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
from superset.models.core import Database
from tests.db_engine_specs.base_tests import DbEngineSpecTestCase from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
@ -24,16 +25,21 @@ class SnowflakeTestCase(DbEngineSpecTestCase):
def test_convert_dttm(self): def test_convert_dttm(self):
dttm = self.get_dttm() dttm = self.get_dttm()
self.assertEqual( test_cases = {
SnowflakeEngineSpec.convert_dttm("DATE", dttm), "TO_DATE('2019-01-02')" "DATE": "TO_DATE('2019-01-02')",
) "DATETIME": "CAST('2019-01-02T03:04:05.678900' AS DATETIME)",
"TIMESTAMP": "TO_TIMESTAMP('2019-01-02T03:04:05.678900')",
}
self.assertEqual( for type_, expected in test_cases.items():
SnowflakeEngineSpec.convert_dttm("DATETIME", dttm), self.assertEqual(SnowflakeEngineSpec.convert_dttm(type_, dttm), expected)
"CAST('2019-01-02T03:04:05.678900' AS DATETIME)",
)
self.assertEqual( def test_database_connection_test_mutator(self):
SnowflakeEngineSpec.convert_dttm("TIMESTAMP", dttm), database = Database(sqlalchemy_uri="snowflake://abc")
"TO_TIMESTAMP('2019-01-02T03:04:05.678900')", SnowflakeEngineSpec.mutate_db_for_connection_test(database)
engine_params = json.loads(database.extra or "{}")
self.assertDictEqual(
{"engine_params": {"connect_args": {"validate_default_parameters": True}}},
engine_params,
) )