mirror of https://github.com/apache/superset.git
fix: add connection testing params for snowflake (#9272)
* fix: add connection testingt params for snowflake * Linting
This commit is contained in:
parent
724b8a3c31
commit
3682702e91
|
@ -529,7 +529,11 @@ The role and warehouse can be omitted if defaults are defined for the user, i.e.
|
||||||
|
|
||||||
Make sure the user has privileges to access and use all required
|
Make sure the user has privileges to access and use all required
|
||||||
databases/schemas/tables/views/warehouses, as the Snowflake SQLAlchemy engine does
|
databases/schemas/tables/views/warehouses, as the Snowflake SQLAlchemy engine does
|
||||||
not test for user rights during engine creation.
|
not test for user/role rights during engine creation by default. However, when
|
||||||
|
pressing the "Test Connection" button in the Create or Edit Database dialog,
|
||||||
|
user/role credentials are validated by passing `"validate_default_parameters": True`
|
||||||
|
to the `connect()` method during engine creation. If the user/role is not authorized
|
||||||
|
to access the database, an error is recorded in the Superset logs.
|
||||||
|
|
||||||
See `Snowflake SQLAlchemy <https://github.com/snowflakedb/snowflake-sqlalchemy>`_.
|
See `Snowflake SQLAlchemy <https://github.com/snowflakedb/snowflake-sqlalchemy>`_.
|
||||||
|
|
||||||
|
|
|
@ -942,3 +942,14 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
if data and type(data[0]).__name__ == "Row":
|
if data and type(data[0]).__name__ == "Row":
|
||||||
data = [tuple(row) for row in data]
|
data = [tuple(row) for row in data]
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def mutate_db_for_connection_test(database: "Database") -> None:
|
||||||
|
"""
|
||||||
|
Some databases require passing additional parameters for validating database
|
||||||
|
connections. This method makes it possible to mutate the database instance prior
|
||||||
|
to testing if a connection is ok.
|
||||||
|
|
||||||
|
:param database: instance to be mutated
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
|
@ -14,14 +14,18 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional, TYPE_CHECKING
|
||||||
from urllib import parse
|
from urllib import parse
|
||||||
|
|
||||||
from sqlalchemy.engine.url import URL
|
from sqlalchemy.engine.url import URL
|
||||||
|
|
||||||
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
|
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from superset.models.core import Database # pylint: disable=unused-import
|
||||||
|
|
||||||
|
|
||||||
class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
||||||
engine = "snowflake"
|
engine = "snowflake"
|
||||||
|
@ -77,3 +81,19 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
||||||
if tt == "TIMESTAMP":
|
if tt == "TIMESTAMP":
|
||||||
return f"""TO_TIMESTAMP('{dttm.isoformat(timespec="microseconds")}')"""
|
return f"""TO_TIMESTAMP('{dttm.isoformat(timespec="microseconds")}')"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def mutate_db_for_connection_test(database: "Database") -> None:
|
||||||
|
"""
|
||||||
|
By default, snowflake doesn't validate if the user/role has access to the chosen
|
||||||
|
database.
|
||||||
|
|
||||||
|
:param database: instance to be mutated
|
||||||
|
"""
|
||||||
|
extra = json.loads(database.extra or "{}")
|
||||||
|
engine_params = extra.get("engine_params", {})
|
||||||
|
connect_args = engine_params.get("connect_args", {})
|
||||||
|
connect_args["validate_default_parameters"] = True
|
||||||
|
engine_params["connect_args"] = connect_args
|
||||||
|
extra["engine_params"] = engine_params
|
||||||
|
database.extra = json.dumps(extra)
|
||||||
|
|
|
@ -1367,6 +1367,7 @@ class Superset(BaseSupersetView):
|
||||||
encrypted_extra=json.dumps(request.json.get("encrypted_extra", {})),
|
encrypted_extra=json.dumps(request.json.get("encrypted_extra", {})),
|
||||||
)
|
)
|
||||||
database.set_sqlalchemy_uri(uri)
|
database.set_sqlalchemy_uri(uri)
|
||||||
|
database.db_engine_spec.mutate_db_for_connection_test(database)
|
||||||
|
|
||||||
username = g.user.username if g.user is not None else None
|
username = g.user.username if g.user is not None else None
|
||||||
engine = database.get_sqla_engine(user_name=username)
|
engine = database.get_sqla_engine(user_name=username)
|
||||||
|
@ -1402,7 +1403,9 @@ class Superset(BaseSupersetView):
|
||||||
return json_error_response(_(str(e)), 400)
|
return json_error_response(_(str(e)), 400)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Unexpected error %s", e)
|
logger.error("Unexpected error %s", e)
|
||||||
return json_error_response(_("Unexpected error occurred."), 400)
|
return json_error_response(
|
||||||
|
_("Unexpected error occurred, please check your logs for details"), 400
|
||||||
|
)
|
||||||
|
|
||||||
@api
|
@api
|
||||||
@has_access_api
|
@has_access_api
|
||||||
|
|
|
@ -14,9 +14,10 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from sqlalchemy import column
|
import json
|
||||||
|
|
||||||
from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
|
from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
|
||||||
|
from superset.models.core import Database
|
||||||
from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
|
from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
|
||||||
|
|
||||||
|
|
||||||
|
@ -24,16 +25,21 @@ class SnowflakeTestCase(DbEngineSpecTestCase):
|
||||||
def test_convert_dttm(self):
|
def test_convert_dttm(self):
|
||||||
dttm = self.get_dttm()
|
dttm = self.get_dttm()
|
||||||
|
|
||||||
self.assertEqual(
|
test_cases = {
|
||||||
SnowflakeEngineSpec.convert_dttm("DATE", dttm), "TO_DATE('2019-01-02')"
|
"DATE": "TO_DATE('2019-01-02')",
|
||||||
)
|
"DATETIME": "CAST('2019-01-02T03:04:05.678900' AS DATETIME)",
|
||||||
|
"TIMESTAMP": "TO_TIMESTAMP('2019-01-02T03:04:05.678900')",
|
||||||
|
}
|
||||||
|
|
||||||
self.assertEqual(
|
for type_, expected in test_cases.items():
|
||||||
SnowflakeEngineSpec.convert_dttm("DATETIME", dttm),
|
self.assertEqual(SnowflakeEngineSpec.convert_dttm(type_, dttm), expected)
|
||||||
"CAST('2019-01-02T03:04:05.678900' AS DATETIME)",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(
|
def test_database_connection_test_mutator(self):
|
||||||
SnowflakeEngineSpec.convert_dttm("TIMESTAMP", dttm),
|
database = Database(sqlalchemy_uri="snowflake://abc")
|
||||||
"TO_TIMESTAMP('2019-01-02T03:04:05.678900')",
|
SnowflakeEngineSpec.mutate_db_for_connection_test(database)
|
||||||
|
engine_params = json.loads(database.extra or "{}")
|
||||||
|
|
||||||
|
self.assertDictEqual(
|
||||||
|
{"engine_params": {"connect_args": {"validate_default_parameters": True}}},
|
||||||
|
engine_params,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue