fix: remove unneeded complexity in migration (#19022)

This commit is contained in:
Beto Dealmeida 2022-03-03 16:56:38 -08:00 committed by GitHub
parent 77063cc814
commit 50bb86d666
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -25,25 +25,22 @@ Create Date: 2021-11-11 16:41:53.266965
"""
import json
from typing import Any, Dict, List, Optional, Type
from typing import List
from uuid import uuid4
import sqlalchemy as sa
from alembic import op
from sqlalchemy import and_, inspect, or_
from sqlalchemy.engine import create_engine, Engine
from sqlalchemy.engine.url import make_url, URL
from sqlalchemy.exc import ArgumentError
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import backref, relationship, Session
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy_utils import UUIDType
from superset import app, db, db_engine_specs
from superset import app, db
from superset.connectors.sqla.models import ADDITIVE_METRIC_TYPES
from superset.extensions import encrypted_field_factory, security_manager
from superset.extensions import encrypted_field_factory
from superset.sql_parse import ParsedQuery
from superset.utils.memoized import memoized
# revision identifiers, used by Alembic.
revision = "b8d3a24d9131"
@ -78,86 +75,6 @@ class Database(Base):
)
server_cert = sa.Column(encrypted_field_factory.create(sa.Text), nullable=True)
@property
def sqlalchemy_uri_decrypted(self) -> str:
try:
url = make_url(self.sqlalchemy_uri)
except (ArgumentError, ValueError):
return "dialect://invalid_uri"
if custom_password_store:
url.password = custom_password_store(url)
else:
url.password = self.password
return str(url)
@property
def backend(self) -> str:
sqlalchemy_url = make_url(self.sqlalchemy_uri_decrypted)
return sqlalchemy_url.get_backend_name() # pylint: disable=no-member
@classmethod
@memoized
def get_db_engine_spec_for_backend(
cls, backend: str
) -> Type[db_engine_specs.BaseEngineSpec]:
engines = db_engine_specs.get_engine_specs()
return engines.get(backend, db_engine_specs.BaseEngineSpec)
@property
def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]:
return self.get_db_engine_spec_for_backend(self.backend)
def get_extra(self) -> Dict[str, Any]:
return self.db_engine_spec.get_extra_params(self)
def get_effective_user(
self, object_url: URL, user_name: Optional[str] = None,
) -> Optional[str]:
effective_username = None
if self.impersonate_user:
effective_username = object_url.username
if user_name:
effective_username = user_name
return effective_username
def get_encrypted_extra(self) -> Dict[str, Any]:
return json.loads(self.encrypted_extra) if self.encrypted_extra else {}
@memoized(watch=("impersonate_user", "sqlalchemy_uri_decrypted", "extra"))
def get_sqla_engine(self, schema: Optional[str] = None) -> Engine:
extra = self.get_extra()
sqlalchemy_url = make_url(self.sqlalchemy_uri_decrypted)
self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
effective_username = self.get_effective_user(sqlalchemy_url, "admin")
# If using MySQL or Presto for example, will set url.username
self.db_engine_spec.modify_url_for_impersonation(
sqlalchemy_url, self.impersonate_user, effective_username
)
params = extra.get("engine_params", {})
connect_args = params.get("connect_args", {})
if self.impersonate_user:
self.db_engine_spec.update_impersonation_config(
connect_args, str(sqlalchemy_url), effective_username
)
if connect_args:
params["connect_args"] = connect_args
params.update(self.get_encrypted_extra())
if DB_CONNECTION_MUTATOR:
sqlalchemy_url, params = DB_CONNECTION_MUTATOR(
sqlalchemy_url,
params,
effective_username,
security_manager,
"migration",
)
return create_engine(sqlalchemy_url, **params)
class TableColumn(Base):
@ -325,8 +242,9 @@ def after_insert(target: SqlaTable) -> None: # pylint: disable=too-many-locals
)
if not database:
return
engine = database.get_sqla_engine(schema=target.schema)
conditional_quote = engine.dialect.identifier_preparer.quote
url = make_url(database.sqlalchemy_uri)
dialect_class = url.get_dialect()
conditional_quote = dialect_class().identifier_preparer.quote
# create columns
columns = []