refactor: pass all properties to validate_parameters (#21487)

This commit is contained in:
Elizabeth Thompson 2022-10-03 17:48:54 -07:00 committed by GitHub
parent 4417c6e3e2
commit e98943e580
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 109 additions and 74 deletions

View File

@ -120,6 +120,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"allow_cvas",
"allow_dml",
"backend",
"driver",
"force_ctas_schema",
"impersonate_user",
"masked_encrypted_extra",
@ -269,6 +270,9 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
if new_model.parameters:
item["parameters"] = new_model.parameters
if new_model.driver:
item["driver"] = new_model.driver
return self.response(201, id=new_model.id, result=item)
except DatabaseInvalidError as ex:
return self.response_422(message=ex.normalized_messages())

View File

@ -38,8 +38,8 @@ BYPASS_VALIDATION_ENGINES = {"bigquery"}
class ValidateDatabaseParametersCommand(BaseCommand):
def __init__(self, parameters: Dict[str, Any]):
self._properties = parameters.copy()
def __init__(self, properties: Dict[str, Any]):
self._properties = properties.copy()
self._model: Optional[Database] = None
def run(self) -> None:
@ -66,9 +66,7 @@ class ValidateDatabaseParametersCommand(BaseCommand):
)
# perform initial validation
errors = engine_spec.validate_parameters( # type: ignore
self._properties.get("parameters", {})
)
errors = engine_spec.validate_parameters(self._properties) # type: ignore
if errors:
event_logger.log_with_context(action="validation_error", engine=engine)
raise InvalidParametersError(errors)

View File

@ -1685,6 +1685,10 @@ class BasicParametersType(TypedDict, total=False):
encryption: bool
class BasicPropertiesType(TypedDict):
parameters: BasicParametersType
class BasicParametersMixin:
"""
Mixin for configuring DB engine specs via a dictionary.
@ -1762,7 +1766,7 @@ class BasicParametersMixin:
@classmethod
def validate_parameters(
cls, parameters: BasicParametersType
cls, properties: BasicPropertiesType
) -> List[SupersetError]:
"""
Validates any number of parameters, for progressive validation.
@ -1773,6 +1777,7 @@ class BasicParametersMixin:
errors: List[SupersetError] = []
required = {"host", "port", "username", "database"}
parameters = properties.get("parameters", {})
present = {key for key in parameters if parameters.get(key, ())}
missing = sorted(required - present)

View File

@ -34,7 +34,7 @@ from typing_extensions import TypedDict
from superset.constants import PASSWORD_MASK
from superset.databases.schemas import encrypted_field_properties, EncryptedString
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
from superset.db_engine_specs.exceptions import SupersetDBAPIDisconnectionError
from superset.errors import SupersetError, SupersetErrorType
from superset.sql_parse import Table
@ -450,7 +450,8 @@ class BigQueryEngineSpec(BaseEngineSpec):
@classmethod
def validate_parameters(
cls, parameters: BigQueryParametersType # pylint: disable=unused-argument
cls,
properties: BasicPropertiesType, # pylint: disable=unused-argument
) -> List[SupersetError]:
return []

View File

@ -58,6 +58,10 @@ class GSheetsParametersType(TypedDict):
catalog: Dict[str, str]
class GSheetsPropertiesType(TypedDict):
parameters: GSheetsParametersType
class GSheetsEngineSpec(SqliteEngineSpec):
"""Engine for Google spreadsheets"""
@ -208,9 +212,10 @@ class GSheetsEngineSpec(SqliteEngineSpec):
@classmethod
def validate_parameters(
cls,
parameters: GSheetsParametersType,
properties: GSheetsPropertiesType,
) -> List[SupersetError]:
errors: List[SupersetError] = []
parameters = properties.get("parameters", {})
encrypted_credentials = parameters.get("service_account_info") or "{}"
# On create the encrypted credentials are a string,

View File

@ -32,6 +32,7 @@ from sqlalchemy.engine.url import URL
from typing_extensions import TypedDict
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BasicPropertiesType
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
@ -242,7 +243,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
@classmethod
def validate_parameters(
cls, parameters: SnowflakeParametersType
cls, properties: BasicPropertiesType
) -> List[SupersetError]:
errors: List[SupersetError] = []
required = {
@ -253,6 +254,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
"role",
"password",
}
parameters = properties.get("parameters", {})
present = {key for key in parameters if parameters.get(key, ())}
missing = sorted(required - present)

View File

@ -244,8 +244,11 @@ class Database(
@property
def backend(self) -> str:
sqlalchemy_url = make_url_safe(self.sqlalchemy_uri_decrypted)
return sqlalchemy_url.get_backend_name()
return self.url_object.get_backend_name()
@property
def driver(self) -> str:
return self.url_object.get_driver_name()
@property
def masked_encrypted_extra(self) -> Optional[str]:
@ -253,14 +256,12 @@ class Database(
@property
def parameters(self) -> Dict[str, Any]:
db_engine_spec = self.db_engine_spec
# Database parameters are a dictionary of values that are used to make up
# the sqlalchemy_uri
# When returning the parameters we should use the masked SQLAlchemy URI and the
# masked ``encrypted_extra`` to prevent exposing sensitive credentials.
masked_uri = make_url_safe(self.sqlalchemy_uri)
masked_encrypted_extra = db_engine_spec.mask_encrypted_extra(
self.encrypted_extra
)
masked_encrypted_extra = self.masked_encrypted_extra
encrypted_config = {}
if masked_encrypted_extra is not None:
try:
@ -270,7 +271,7 @@ class Database(
try:
# pylint: disable=useless-suppression
parameters = db_engine_spec.get_parameters_from_uri( # type: ignore
parameters = self.db_engine_spec.get_parameters_from_uri( # type: ignore
masked_uri,
encrypted_extra=encrypted_config,
)

View File

@ -421,28 +421,32 @@ def test_validate(is_port_open, is_hostname_valid):
is_hostname_valid.return_value = True
is_port_open.return_value = True
parameters = {
"host": "localhost",
"port": 5432,
"username": "username",
"password": "password",
"database": "dbname",
"query": {"sslmode": "verify-full"},
properties = {
"parameters": {
"host": "localhost",
"port": 5432,
"username": "username",
"password": "password",
"database": "dbname",
"query": {"sslmode": "verify-full"},
}
}
errors = BasicParametersMixin.validate_parameters(parameters)
errors = BasicParametersMixin.validate_parameters(properties)
assert errors == []
def test_validate_parameters_missing():
parameters = {
"host": "",
"port": None,
"username": "",
"password": "",
"database": "",
"query": {},
properties = {
"parameters": {
"host": "",
"port": None,
"username": "",
"password": "",
"database": "",
"query": {},
}
}
errors = BasicParametersMixin.validate_parameters(parameters)
errors = BasicParametersMixin.validate_parameters(properties)
assert errors == [
SupersetError(
message=(
@ -459,15 +463,17 @@ def test_validate_parameters_missing():
def test_validate_parameters_invalid_host(is_hostname_valid):
is_hostname_valid.return_value = False
parameters = {
"host": "localhost",
"port": None,
"username": "username",
"password": "password",
"database": "dbname",
"query": {"sslmode": "verify-full"},
properties = {
"parameters": {
"host": "localhost",
"port": None,
"username": "username",
"password": "password",
"database": "dbname",
"query": {"sslmode": "verify-full"},
}
}
errors = BasicParametersMixin.validate_parameters(parameters)
errors = BasicParametersMixin.validate_parameters(properties)
assert errors == [
SupersetError(
message="One or more parameters are missing: port",
@ -490,15 +496,17 @@ def test_validate_parameters_port_closed(is_port_open, is_hostname_valid):
is_hostname_valid.return_value = True
is_port_open.return_value = False
parameters = {
"host": "localhost",
"port": 5432,
"username": "username",
"password": "password",
"database": "dbname",
"query": {"sslmode": "verify-full"},
properties = {
"parameters": {
"host": "localhost",
"port": 5432,
"username": "username",
"password": "password",
"database": "dbname",
"query": {"sslmode": "verify-full"},
}
}
errors = BasicParametersMixin.validate_parameters(parameters)
errors = BasicParametersMixin.validate_parameters(properties)
assert errors == [
SupersetError(
message="The port is closed.",

View File

@ -22,6 +22,7 @@ from typing import Any
from uuid import UUID
import pytest
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
@ -53,6 +54,7 @@ def test_post_with_uuid(
def test_password_mask(
mocker: MockFixture,
app: Any,
session: Session,
client: Any,
@ -92,6 +94,10 @@ def test_password_mask(
session.add(database)
session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
mocker.patch("superset.utils.log.DBEventLogger.log")
response = client.get("/api/v1/database/1")
assert (
response.json["result"]["parameters"]["service_account_info"]["private_key"]

View File

@ -134,7 +134,6 @@ def test_database_parameters_schema_mixin_invalid_engine(
try:
dummy_schema.load(payload)
except ValidationError as err:
print(err.messages)
assert err.messages == {
"_schema": ['Engine "dummy_engine" is not a valid engine.']
}

View File

@ -33,14 +33,16 @@ class ProgrammingError(Exception):
def test_validate_parameters_simple() -> None:
from superset.db_engine_specs.gsheets import (
GSheetsEngineSpec,
GSheetsParametersType,
GSheetsPropertiesType,
)
parameters: GSheetsParametersType = {
"service_account_info": "",
"catalog": {},
properties: GSheetsPropertiesType = {
"parameters": {
"service_account_info": "",
"catalog": {},
}
}
errors = GSheetsEngineSpec.validate_parameters(parameters)
errors = GSheetsEngineSpec.validate_parameters(properties)
assert errors == [
SupersetError(
message="Sheet name is required",
@ -56,7 +58,7 @@ def test_validate_parameters_catalog(
) -> None:
from superset.db_engine_specs.gsheets import (
GSheetsEngineSpec,
GSheetsParametersType,
GSheetsPropertiesType,
)
g = mocker.patch("superset.db_engine_specs.gsheets.g")
@ -71,15 +73,17 @@ def test_validate_parameters_catalog(
ProgrammingError("Unsupported table: https://www.google.com/"),
]
parameters: GSheetsParametersType = {
"service_account_info": "",
"catalog": {
"private_sheet": "https://docs.google.com/spreadsheets/d/1/edit",
"public_sheet": "https://docs.google.com/spreadsheets/d/1/edit#gid=1",
"not_a_sheet": "https://www.google.com/",
},
properties: GSheetsPropertiesType = {
"parameters": {
"service_account_info": "",
"catalog": {
"private_sheet": "https://docs.google.com/spreadsheets/d/1/edit",
"public_sheet": "https://docs.google.com/spreadsheets/d/1/edit#gid=1",
"not_a_sheet": "https://www.google.com/",
},
}
}
errors = GSheetsEngineSpec.validate_parameters(parameters) # ignore: type
errors = GSheetsEngineSpec.validate_parameters(properties) # ignore: type
assert errors == [
SupersetError(
@ -146,7 +150,7 @@ def test_validate_parameters_catalog_and_credentials(
) -> None:
from superset.db_engine_specs.gsheets import (
GSheetsEngineSpec,
GSheetsParametersType,
GSheetsPropertiesType,
)
g = mocker.patch("superset.db_engine_specs.gsheets.g")
@ -161,15 +165,17 @@ def test_validate_parameters_catalog_and_credentials(
ProgrammingError("Unsupported table: https://www.google.com/"),
]
parameters: GSheetsParametersType = {
"service_account_info": "",
"catalog": {
"private_sheet": "https://docs.google.com/spreadsheets/d/1/edit",
"public_sheet": "https://docs.google.com/spreadsheets/d/1/edit#gid=1",
"not_a_sheet": "https://www.google.com/",
},
properties: GSheetsPropertiesType = {
"parameters": {
"service_account_info": "",
"catalog": {
"private_sheet": "https://docs.google.com/spreadsheets/d/1/edit",
"public_sheet": "https://docs.google.com/spreadsheets/d/1/edit#gid=1",
"not_a_sheet": "https://www.google.com/",
},
}
}
errors = GSheetsEngineSpec.validate_parameters(parameters) # ignore: type
errors = GSheetsEngineSpec.validate_parameters(properties) # ignore: type
assert errors == [
SupersetError(
message=(