[db migration] change datasources-clusters foreign key to cluster_id (#8576)

* [db migration] change datasources foreign key to cluster_id

* address pr comments

* address pr comment, fix ci
Authored by serenajiang on 2020-01-13 11:02:36 -08:00; committed by John Bodley
Parent: d9e7db69fe
Commit: 1f6f4ed879
9 changed files with 184 additions and 42 deletions

View File

@@ -45,7 +45,7 @@ from sqlalchemy import (
Text,
UniqueConstraint,
)
from sqlalchemy.orm import backref, relationship, RelationshipProperty, Session
from sqlalchemy.orm import backref, relationship, Session
from sqlalchemy_utils import EncryptedType
from superset import conf, db, security_manager
@@ -222,7 +222,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
session = db.session
ds_list = (
session.query(DruidDatasource)
.filter(DruidDatasource.cluster_name == self.cluster_name)
.filter(DruidDatasource.cluster_id == self.id)
.filter(DruidDatasource.datasource_name.in_(datasource_names))
)
ds_map = {ds.name: ds for ds in ds_list}
@@ -468,7 +468,7 @@ class DruidDatasource(Model, BaseDatasource):
"""ORM object referencing Druid datasources (tables)"""
__tablename__ = "datasources"
__table_args__ = (UniqueConstraint("datasource_name", "cluster_name"),)
__table_args__ = (UniqueConstraint("datasource_name", "cluster_id"),)
type = "druid"
query_language = "json"
@@ -484,11 +484,9 @@ class DruidDatasource(Model, BaseDatasource):
is_hidden = Column(Boolean, default=False)
filter_select_enabled = Column(Boolean, default=True) # override default
fetch_values_from = Column(String(100))
cluster_name = Column(
String(250), ForeignKey("clusters.cluster_name"), nullable=False
)
cluster_id = Column(Integer, ForeignKey("clusters.id"), nullable=False)
cluster = relationship(
"DruidCluster", backref="datasources", foreign_keys=[cluster_name]
"DruidCluster", backref="datasources", foreign_keys=[cluster_id]
)
owners = relationship(
owner_class, secondary=druiddatasource_user, backref="druiddatasources"
@@ -499,7 +497,7 @@ class DruidDatasource(Model, BaseDatasource):
"is_hidden",
"description",
"default_endpoint",
"cluster_name",
"cluster_id",
"offset",
"cache_timeout",
"params",
@@ -511,7 +509,15 @@ class DruidDatasource(Model, BaseDatasource):
export_children = ["columns", "metrics"]
@property
def database(self) -> RelationshipProperty:
def cluster_name(self) -> str:
cluster = (
self.cluster
or db.session.query(DruidCluster).filter_by(id=self.cluster_id).one()
)
return cluster.cluster_name
@property
def database(self) -> DruidCluster:
return self.cluster
@property
@@ -608,17 +614,13 @@ class DruidDatasource(Model, BaseDatasource):
db.session.query(DruidDatasource)
.filter(
DruidDatasource.datasource_name == d.datasource_name,
DruidCluster.cluster_name == d.cluster_name,
DruidDatasource.cluster_id == d.cluster_id,
)
.first()
)
def lookup_cluster(d: DruidDatasource) -> Optional[DruidCluster]:
return (
db.session.query(DruidCluster)
.filter_by(cluster_name=d.cluster_name)
.one()
)
return db.session.query(DruidCluster).filter_by(id=d.cluster_id).first()
return import_datasource.import_datasource(
db.session, i_datasource, lookup_cluster, lookup_datasource, import_time
@@ -1615,12 +1617,7 @@ class DruidDatasource(Model, BaseDatasource):
def query_datasources_by_name(
cls, session: Session, database: Database, datasource_name: str, schema=None
) -> List["DruidDatasource"]:
return (
session.query(cls)
.filter_by(cluster_name=database.id)
.filter_by(datasource_name=datasource_name)
.all()
)
return []
def external_metadata(self) -> List[Dict]:
self.merge_flag = True
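For context, a small usage sketch of the reworked relationship after this change (not part of the commit): it assumes an active Superset app/session, and the names druid_test / druid_ds_1 are the fixture names used in the tests further down.
# Usage sketch only; assumes an app context and the models shown in the diff above.
from superset import db
from superset.connectors.druid.models import DruidCluster, DruidDatasource
cluster = (
    db.session.query(DruidCluster)
    .filter_by(cluster_name="druid_test")
    .one_or_none()
)
if cluster is not None:
    # The backref still exposes cluster.datasources, but the join key is now
    # the integer cluster_id column rather than cluster_name.
    for ds in cluster.datasources:
        print(ds.datasource_name, ds.cluster_id, ds.cluster_name)
    # Direct filters switch from cluster_name to cluster_id:
    ds_1 = (
        db.session.query(DruidDatasource)
        .filter_by(cluster_id=cluster.id, datasource_name="druid_ds_1")
        .one_or_none()
    )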

View File

@@ -341,7 +341,7 @@ class DruidDatasourceModelView(DatasourceModelView, DeleteMixin, YamlExportMixin
with db.session.no_autoflush:
query = db.session.query(models.DruidDatasource).filter(
models.DruidDatasource.datasource_name == datasource.datasource_name,
models.DruidDatasource.cluster_name == datasource.cluster.id,
models.DruidDatasource.cluster_id == datasource.cluster_id,
)
if db.session.query(query.exists()).scalar():
raise Exception(get_datasource_exist_error_msg(datasource.full_name))

View File

@@ -0,0 +1,124 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""datasource_cluster_fk
Revision ID: e96dbf2cfef0
Revises: 817e1c9b09d0
Create Date: 2020-01-08 01:17:40.127610
"""
import sqlalchemy as sa
from alembic import op
from superset import db
from superset.utils.core import (
generic_find_fk_constraint_name,
generic_find_uq_constraint_name,
)
# revision identifiers, used by Alembic.
revision = "e96dbf2cfef0"
down_revision = "817e1c9b09d0"
def upgrade():
bind = op.get_bind()
insp = sa.engine.reflection.Inspector.from_engine(bind)
# Add cluster_id column
with op.batch_alter_table("datasources") as batch_op:
batch_op.add_column(sa.Column("cluster_id", sa.Integer()))
# Update cluster_id values
metadata = sa.MetaData(bind=bind)
datasources = sa.Table("datasources", metadata, autoload=True)
clusters = sa.Table("clusters", metadata, autoload=True)
statement = datasources.update().values(
cluster_id=sa.select([clusters.c.id])
.where(datasources.c.cluster_name == clusters.c.cluster_name)
.as_scalar()
)
bind.execute(statement)
with op.batch_alter_table("datasources") as batch_op:
# Drop cluster_name column
fk_constraint_name = generic_find_fk_constraint_name(
"datasources", {"cluster_name"}, "clusters", insp
)
uq_constraint_name = generic_find_uq_constraint_name(
"datasources", {"cluster_name", "datasource_name"}, insp
)
batch_op.drop_constraint(fk_constraint_name, type_="foreignkey")
batch_op.drop_constraint(uq_constraint_name, type_="unique")
batch_op.drop_column("cluster_name")
# Add constraints to cluster_id column
batch_op.alter_column("cluster_id", existing_type=sa.Integer, nullable=False)
batch_op.create_unique_constraint(
"uq_datasources_cluster_id", ["cluster_id", "datasource_name"]
)
batch_op.create_foreign_key(
"fk_datasources_cluster_id_clusters", "clusters", ["cluster_id"], ["id"]
)
def downgrade():
bind = op.get_bind()
insp = sa.engine.reflection.Inspector.from_engine(bind)
# Add cluster_name column
with op.batch_alter_table("datasources") as batch_op:
batch_op.add_column(sa.Column("cluster_name", sa.String(250)))
# Update cluster_name values
metadata = sa.MetaData(bind=bind)
datasources = sa.Table("datasources", metadata, autoload=True)
clusters = sa.Table("clusters", metadata, autoload=True)
statement = datasources.update().values(
cluster_name=sa.select([clusters.c.cluster_name])
.where(datasources.c.cluster_id == clusters.c.id)
.as_scalar()
)
bind.execute(statement)
with op.batch_alter_table("datasources") as batch_op:
# Drop cluster_id column
fk_constraint_name = generic_find_fk_constraint_name(
"datasources", {"id"}, "clusters", insp
)
uq_constraint_name = generic_find_uq_constraint_name(
"datasources", {"cluster_id", "datasource_name"}, insp
)
batch_op.drop_constraint(fk_constraint_name, type_="foreignkey")
batch_op.drop_constraint(uq_constraint_name, type_="unique")
batch_op.drop_column("cluster_id")
# Add constraints to cluster_name column
batch_op.alter_column(
"cluster_name", existing_type=sa.String(250), nullable=False
)
batch_op.create_unique_constraint(
"uq_datasources_cluster_name", ["cluster_name", "datasource_name"]
)
batch_op.create_foreign_key(
"fk_datasources_cluster_name_clusters",
"clusters",
["cluster_name"],
["cluster_name"],
)
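The backfill in upgrade() relies on a correlated scalar subquery. Below is a standalone sketch of that same pattern (not part of the migration): the toy tables and the in-memory SQLite engine are assumptions for illustration only, and the SQLAlchemy 1.x select([...]) / as_scalar() style matches the code above.
# Standalone illustration of the correlated-subquery backfill used in upgrade().
import sqlalchemy as sa
engine = sa.create_engine("sqlite://")
metadata = sa.MetaData()
clusters = sa.Table(
    "clusters",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("cluster_name", sa.String(250), unique=True),
)
datasources = sa.Table(
    "datasources",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("cluster_name", sa.String(250)),
    sa.Column("cluster_id", sa.Integer),
)
metadata.create_all(engine)
with engine.begin() as conn:
    conn.execute(clusters.insert(), [{"id": 1, "cluster_name": "druid_test"}])
    conn.execute(
        datasources.insert(),
        [{"id": 10, "cluster_name": "druid_test", "cluster_id": None}],
    )
    # For every datasources row, pick the clusters.id whose cluster_name
    # matches that row's cluster_name (a correlated scalar subquery).
    conn.execute(
        datasources.update().values(
            cluster_id=sa.select([clusters.c.id])
            .where(datasources.c.cluster_name == clusters.c.cluster_name)
            .as_scalar()
        )
    )
    assert conn.execute(sa.select([datasources.c.cluster_id])).scalar() == 1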

View File

@@ -35,7 +35,7 @@ from email.mime.text import MIMEText
from email.utils import formatdate
from enum import Enum
from time import struct_time
from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Set, Tuple, Union
from urllib.parse import unquote_plus
import bleach
@@ -46,7 +46,8 @@ import parsedatetime
import sqlalchemy as sa
from dateutil.parser import parse
from dateutil.relativedelta import relativedelta
from flask import current_app, flash, g, Markup, render_template
from flask import current_app, flash, Flask, g, Markup, render_template
from flask_appbuilder import SQLA
from flask_appbuilder.security.sqla.models import User
from flask_babel import gettext as __, lazy_gettext as _
from sqlalchemy import event, exc, select, Text
@@ -487,7 +488,9 @@ def readfile(file_path: str) -> Optional[str]:
return content
def generic_find_constraint_name(table, columns, referenced, db):
def generic_find_constraint_name(
table: str, columns: Set[str], referenced: str, db: SQLA
):
"""Utility to find a constraint name in alembic migrations"""
t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine)
@@ -496,7 +499,9 @@ def generic_find_constraint_name(table, columns, referenced, db):
return fk.name
def generic_find_fk_constraint_name(table, columns, referenced, insp):
def generic_find_fk_constraint_name(
table: str, columns: Set[str], referenced: str, insp
):
"""Utility to find a foreign-key constraint name in alembic migrations"""
for fk in insp.get_foreign_keys(table):
if (
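The hunk above cuts off the body of generic_find_fk_constraint_name. Here is a minimal, illustrative sketch of an Inspector-based lookup consistent with how the migration calls it (matching on the referenced table and its columns); it is not a verbatim copy of Superset's helper.
# Illustrative sketch only; the real helper lives in superset.utils.core.
from typing import Optional, Set
from sqlalchemy.engine.reflection import Inspector
def find_fk_constraint_name(
    table: str, columns: Set[str], referenced: str, insp: Inspector
) -> Optional[str]:
    """Name of the FK on `table` that points at `columns` of `referenced`."""
    for fk in insp.get_foreign_keys(table):
        # get_foreign_keys() yields dicts with "name", "referred_table",
        # "referred_columns", and "constrained_columns" keys.
        if fk["referred_table"] == referenced and set(fk["referred_columns"]) == columns:
            return fk["name"]
    return None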

View File

@@ -64,6 +64,7 @@ class SupersetTestCase(TestCase):
@classmethod
def create_druid_test_objects(cls):
# create druid cluster and druid datasources
with app.app_context():
session = db.session
cluster = (
@@ -75,11 +76,11 @@ class SupersetTestCase(TestCase):
session.commit()
druid_datasource1 = DruidDatasource(
datasource_name="druid_ds_1", cluster_name="druid_test"
datasource_name="druid_ds_1", cluster=cluster
)
session.add(druid_datasource1)
druid_datasource2 = DruidDatasource(
datasource_name="druid_ds_2", cluster_name="druid_test"
datasource_name="druid_ds_2", cluster=cluster
)
session.add(druid_datasource2)
session.commit()

View File

@@ -23,7 +23,12 @@ import yaml
from tests.test_app import app
from superset import db
from superset.connectors.druid.models import DruidColumn, DruidDatasource, DruidMetric
from superset.connectors.druid.models import (
DruidColumn,
DruidDatasource,
DruidMetric,
DruidCluster,
)
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.utils.core import get_example_database
from superset.utils.dict_import_export import export_to_dict
@@ -87,11 +92,15 @@ class DictImportExportTests(SupersetTestCase):
return table, dict_rep
def create_druid_datasource(self, name, id=0, cols_names=[], metric_names=[]):
name = "{0}{1}".format(NAME_PREFIX, name)
cluster_name = "druid_test"
cluster = self.get_or_create(
DruidCluster, {"cluster_name": cluster_name}, db.session
)
name = "{0}{1}".format(NAME_PREFIX, name)
params = {DBREF: id, "database_name": cluster_name}
dict_rep = {
"cluster_name": cluster_name,
"cluster_id": cluster.id,
"datasource_name": name,
"id": id,
"params": json.dumps(params),
@@ -102,7 +111,7 @@ class DictImportExportTests(SupersetTestCase):
datasource = DruidDatasource(
id=id,
datasource_name=name,
cluster_name=cluster_name,
cluster_id=cluster.id,
params=json.dumps(params),
)
for col_name in cols_names:
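This test file leans on a get_or_create helper from the base test case. A plausible, simplified sketch of such a helper follows; the actual SupersetTestCase method may differ in details.
# Sketch of the assumed helper; on SupersetTestCase it is an instance method.
def get_or_create(cls, criteria, session):
    """Return the first row of `cls` matching `criteria`, creating it if absent."""
    obj = session.query(cls).filter_by(**criteria).one_or_none()
    if obj is None:
        obj = cls(**criteria)
        session.add(obj)
        session.commit()
    return obj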

View File

@@ -131,9 +131,7 @@ class DruidTests(SupersetTestCase):
)
if cluster:
for datasource in (
db.session.query(DruidDatasource)
.filter_by(cluster_name=cluster.cluster_name)
.all()
db.session.query(DruidDatasource).filter_by(cluster_id=cluster.id).all()
):
db.session.delete(datasource)
@@ -358,9 +356,7 @@ class DruidTests(SupersetTestCase):
)
if cluster:
for datasource in (
db.session.query(DruidDatasource)
.filter_by(cluster_name=cluster.cluster_name)
.all()
db.session.query(DruidDatasource).filter_by(cluster_id=cluster.id).all()
):
db.session.delete(datasource)

View File

@@ -25,7 +25,12 @@ from sqlalchemy.orm.session import make_transient
from tests.test_app import app
from superset.utils.dashboard_import_export import decode_dashboards
from superset import db, security_manager
from superset.connectors.druid.models import DruidColumn, DruidDatasource, DruidMetric
from superset.connectors.druid.models import (
DruidColumn,
DruidDatasource,
DruidMetric,
DruidCluster,
)
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
@@ -119,11 +124,16 @@ class ImportExportTests(SupersetTestCase):
return table
def create_druid_datasource(self, name, id=0, cols_names=[], metric_names=[]):
params = {"remote_id": id, "database_name": "druid_test"}
cluster_name = "druid_test"
cluster = self.get_or_create(
DruidCluster, {"cluster_name": cluster_name}, db.session
)
params = {"remote_id": id, "database_name": cluster_name}
datasource = DruidDatasource(
id=id,
datasource_name=name,
cluster_name="druid_test",
cluster_id=cluster.id,
params=json.dumps(params),
)
for col_name in cols_names:

View File

@@ -238,7 +238,7 @@ class RolePermissionTests(SupersetTestCase):
datasource = DruidDatasource(
datasource_name="tmp_datasource",
cluster=druid_cluster,
cluster_name="druid_test",
cluster_id=druid_cluster.id,
)
session.add(datasource)
session.commit()