# mirror of https://github.com/apache/superset.git
# synced 2024-09-19 20:19:37 -04:00
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import importlib.util
import logging
import re
import time
from collections import defaultdict
from inspect import getsource
from pathlib import Path
from types import ModuleType
from typing import Dict, List, Set, Type

import click
from flask import current_app
from flask_appbuilder import Model
from flask_migrate import downgrade, upgrade
from graphlib import TopologicalSorter  # pylint: disable=wrong-import-order
from progress.bar import ChargingBar
from sqlalchemy import create_engine, inspect
from sqlalchemy.ext.automap import automap_base

from superset import db
from superset.utils.mock_data import add_sample_rows

logger = logging.getLogger(__name__)

def import_migration_script(filepath: Path) -> ModuleType:
    """
    Import a migration script as if it were a module.

    :param filepath: path to the Alembic migration script
    :return: the executed module, named after the file's stem
    :raises ValueError: if no import spec/loader can be built for the path
    """
    spec = importlib.util.spec_from_file_location(filepath.stem, filepath)
    # guard on the loader as well, so no ``type: ignore`` is needed below
    if spec and spec.loader:
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module
    raise ValueError(f"No module spec found in location: `{filepath}`")

def extract_modified_tables(module: ModuleType) -> Set[str]:
    """
    Extract the tables being modified by a migration script.

    This function uses a simple approach of looking at the source code of
    the migration script looking for patterns. It could be improved by
    actually traversing the AST.

    :param module: an imported migration script
    :return: names of tables referenced by ``alter_table``/``add_column``/
        ``drop_column`` calls in its ``upgrade``/``downgrade`` functions
    """
    # compiled once and reused for both functions; the re.DOTALL flag in the
    # original was a no-op, since none of the patterns contain a bare "."
    patterns = (
        re.compile(r'alter_table\(\s*"(\w+?)"\s*\)'),
        re.compile(r'add_column\(\s*"(\w+?)"\s*,'),
        re.compile(r'drop_column\(\s*"(\w+?)"\s*,'),
    )

    tables: Set[str] = set()
    for function in ("upgrade", "downgrade"):
        source = getsource(getattr(module, function))
        for pattern in patterns:
            tables.update(pattern.findall(source))

    return tables

def find_models(module: ModuleType) -> List[Type[Model]]:
    """
    Find all models in a migration script.

    Collects every table touched by the migration (from source-scanning plus
    any model classes defined in the script itself), automaps those tables
    from the live database, and returns the mapped model classes sorted so
    that tables referenced by foreign keys come before the models that
    depend on them.
    """
    models: List[Type[Model]] = []
    tables = extract_modified_tables(module)

    # add models defined explicitly in the migration script; walk the
    # module namespace, descending into lists and dicts, and treat any
    # object with a __tablename__ attribute as a model definition
    queue = list(module.__dict__.values())
    while queue:
        obj = queue.pop()
        if hasattr(obj, "__tablename__"):
            tables.add(obj.__tablename__)
        elif isinstance(obj, list):
            queue.extend(obj)
        elif isinstance(obj, dict):
            queue.extend(obj.values())

    # build models by automapping the existing tables, instead of using current
    # code; this is needed for migrations that modify schemas (eg, add a column),
    # where the current model is out-of-sync with the existing table after a
    # downgrade
    sqlalchemy_uri = current_app.config["SQLALCHEMY_DATABASE_URI"]
    engine = create_engine(sqlalchemy_uri)
    Base = automap_base()
    # NOTE(review): prepare(engine, reflect=True) is deprecated in
    # SQLAlchemy 1.4+ in favor of prepare(autoload_with=engine) — confirm
    # the pinned SQLAlchemy version before upgrading
    Base.prepare(engine, reflect=True)
    seen = set()
    # worklist: tables discovered via foreign keys are appended while we
    # iterate, so new dependencies get mapped too
    while tables:
        table = tables.pop()
        seen.add(table)
        model = getattr(Base.classes, table)
        model.__tablename__ = table
        models.append(model)

        # add other models referenced in foreign keys
        inspector = inspect(model)
        for column in inspector.columns.values():
            for foreign_key in column.foreign_keys:
                table = foreign_key.column.table.name
                if table not in seen:
                    tables.add(table)

    # sort topologically so we can create entities in order and
    # maintain relationships (eg, create a database before creating
    # a slice)
    sorter = TopologicalSorter()
    for model in models:
        inspector = inspect(model)
        dependent_tables: List[str] = []
        for column in inspector.columns.values():
            for foreign_key in column.foreign_keys:
                # self-referential FKs would make the graph cyclic; skip them
                if foreign_key.column.table.name != model.__tablename__:
                    dependent_tables.append(foreign_key.column.table.name)
        sorter.add(model.__tablename__, *dependent_tables)
    order = list(sorter.static_order())
    models.sort(key=lambda model: order.index(model.__tablename__))

    return models

@click.command()
@click.argument("filepath")
@click.option("--limit", default=1000, help="Maximum number of entities.")
@click.option("--force", is_flag=True, help="Do not prompt for confirmation.")
@click.option("--no-auto-cleanup", is_flag=True, help="Do not remove created models.")
def main(
    filepath: str, limit: int = 1000, force: bool = False, no_auto_cleanup: bool = False
) -> None:
    """
    Benchmark a migration script.

    Times the upgrade on the current database, then repeatedly downgrades,
    pads each affected model with sample rows (10, 100, ... up to ``limit``)
    and re-times the upgrade, printing a summary of all runs.
    """
    auto_cleanup = not no_auto_cleanup
    session = db.session()

    print(f"Importing migration script: {filepath}")
    module = import_migration_script(Path(filepath))

    # a valid Alembic script must declare both revision identifiers
    revision: str = getattr(module, "revision", "")
    down_revision: str = getattr(module, "down_revision", "")
    if not revision or not down_revision:
        raise Exception(
            "Not a valid migration script, couldn't find down_revision/revision"
        )

    print(f"Migration goes from {down_revision} to {revision}")
    # read the DB's current Alembic head directly from the version table
    # NOTE(review): Engine.execute() was removed in SQLAlchemy 2.0 — confirm
    # the pinned version still supports it
    current_revision = db.engine.execute(
        "SELECT version_num FROM alembic_version"
    ).scalar()
    print(f"Current version of the DB is {current_revision}")

    # bring the DB to the migration's starting point; downgrades can lose
    # data, so prompt unless --force was given
    if current_revision != down_revision:
        if not force:
            click.confirm(
                "\nRunning benchmark will downgrade the Superset DB to "
                f"{down_revision} and upgrade to {revision} again. There may "
                "be data loss in downgrades. Continue?",
                abort=True,
            )
        downgrade(revision=down_revision)

    print("\nIdentifying models used in the migration:")
    models = find_models(module)
    # remember the pre-existing row counts so we only add the missing rows
    model_rows: Dict[Type[Model], int] = {}
    for model in models:
        rows = session.query(model).count()
        print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})")
        model_rows[model] = rows
    session.close()

    print("Benchmarking migration")
    results: Dict[str, float] = {}
    # baseline: migration timed against the DB as-is
    start = time.time()
    upgrade(revision=revision)
    duration = time.time() - start
    results["Current"] = duration
    print(f"Migration on current DB took: {duration:.2f} seconds")

    min_entities = 10
    # entities we created, kept per-model so they can be deleted afterwards
    new_models: Dict[Type[Model], List[Model]] = defaultdict(list)
    while min_entities <= limit:
        # reset to the starting revision before each padded run
        downgrade(revision=down_revision)
        print(f"Running with at least {min_entities} entities of each model")
        for model in models:
            missing = min_entities - model_rows[model]
            if missing > 0:
                entities: List[Model] = []
                print(f"- Adding {missing} entities to the {model.__name__} model")
                bar = ChargingBar("Processing", max=missing)
                try:
                    for entity in add_sample_rows(session, model, missing):
                        entities.append(entity)
                        bar.next()
                except Exception:
                    # leave the session usable for cleanup, then re-raise
                    session.rollback()
                    raise
                bar.finish()
                model_rows[model] = min_entities
                session.add_all(entities)
                session.commit()

                if auto_cleanup:
                    new_models[model].extend(entities)
        # time the upgrade against the padded DB
        start = time.time()
        upgrade(revision=revision)
        duration = time.time() - start
        print(f"Migration for {min_entities}+ entities took: {duration:.2f} seconds")
        results[f"{min_entities}+"] = duration
        min_entities *= 10

    print("\nResults:\n")
    for label, duration in results.items():
        print(f"{label}: {duration:.2f} s")

    if auto_cleanup:
        print("Cleaning up DB")
        # delete in reverse order of creation to handle relationships
        for model, entities in list(new_models.items())[::-1]:
            session.query(model).filter(
                model.id.in_(entity.id for entity in entities)
            ).delete(synchronize_session=False)
        session.commit()

    # leave the DB at the migration's target revision, with confirmation
    if current_revision != revision and not force:
        click.confirm(f"\nRevert DB to {revision}?", abort=True)
        upgrade(revision=revision)
        print("Reverted")

if __name__ == "__main__":
    from superset.app import create_app

    # run the CLI inside an application context, since find_models()
    # reads current_app.config
    with create_app().app_context():
        main()  # pylint: disable=no-value-for-parameter