utils: generalize utility to find find_constraint_name (#557)

See https://github.com/airbnb/caravel/pull/531
This commit is contained in:
Riccardo Magliocchetti 2016-06-03 18:47:51 +02:00 committed by Maxime Beauchemin
parent fe402465b1
commit 5bc50210ad
2 changed files with 21 additions and 16 deletions

View File

@ -12,27 +12,16 @@ down_revision = '956a063c52b3'
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from caravel.utils import generic_find_constraint_name
naming_convention = { naming_convention = {
"fk": "fk":
"fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
} }
def find_constraint_name(upgrade = True): def find_constraint_name(upgrade=True):
__table = 'columns' cols = {'column_name'} if upgrade else {'datasource_name'}
__cols = {'column_name'} if upgrade else {'datasource_name'} return generic_find_constraint_name(table='columns', columns=cols, referenced='datasources')
__referenced = 'datasources'
__ref_cols = {'datasource_name'} if upgrade else {'column_name'}
engine = op.get_bind().engine
m = sa.MetaData({})
t=sa.Table(__table,m, autoload=True, autoload_with=engine)
for fk in t.foreign_key_constraints:
if fk.referred_table.name == __referenced and \
set(fk.column_keys) == __cols:
return fk.name
return None
def upgrade(): def upgrade():
constraint = find_constraint_name() or 'fk_columns_column_name_datasources' constraint = find_constraint_name() or 'fk_columns_column_name_datasources'
@ -47,4 +36,3 @@ def downgrade():
naming_convention=naming_convention) as batch_op: naming_convention=naming_convention) as batch_op:
batch_op.drop_constraint(constraint, type_="foreignkey") batch_op.drop_constraint(constraint, type_="foreignkey")
batch_op.create_foreign_key('fk_columns_column_name_datasources', 'datasources', ['column_name'], ['datasource_name']) batch_op.create_foreign_key('fk_columns_column_name_datasources', 'datasources', ['column_name'], ['datasource_name'])

View File

@ -11,7 +11,9 @@ import numpy
from datetime import datetime from datetime import datetime
import parsedatetime import parsedatetime
import sqlalchemy as sa
from dateutil.parser import parse from dateutil.parser import parse
from alembic import op
from flask import flash, Markup from flask import flash, Markup
from flask_appbuilder.security.sqla import models as ab_models from flask_appbuilder.security.sqla import models as ab_models
from markdown import markdown as md from markdown import markdown as md
@ -255,3 +257,18 @@ def readfile(filepath):
with open(filepath) as f: with open(filepath) as f:
content = f.read() content = f.read()
return content return content
def generic_find_constraint_name(table, columns, referenced):
"""
Utility to find a constraint name in alembic migrations
"""
engine = op.get_bind().engine
m = sa.MetaData({})
t = sa.Table(table, m, autoload=True, autoload_with=engine)
for fk in t.foreign_key_constraints:
if fk.referred_table.name == referenced and \
set(fk.column_keys) == columns:
return fk.name
return None