mirror of https://github.com/apache/superset.git
feat: script to benchmark DB migrations (#13561)
This commit is contained in: parent 3294f77ca5 · commit c1cb3619ab
@@ -109,6 +109,8 @@ geographiclib==1.50
     # via geopy
 geopy==2.0.0
     # via apache-superset
+graphlib-backport==1.0.3
+    # via apache-superset
 gunicorn==20.0.4
     # via apache-superset
 holidays==0.10.3
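Note: graphlib-backport packages Python 3.9's stdlib graphlib module for older interpreters; the new script below imports TopologicalSorter from it.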
@@ -0,0 +1,215 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import importlib.util
import logging
import re
import time
from collections import defaultdict
from inspect import getsource
from pathlib import Path
from types import ModuleType
from typing import Dict, List, Set, Type

import click
from flask_appbuilder import Model
from flask_migrate import downgrade, upgrade
from graphlib import TopologicalSorter  # pylint: disable=wrong-import-order
from sqlalchemy import inspect

from superset import db
from superset.utils.mock_data import add_sample_rows

logger = logging.getLogger(__name__)


def import_migration_script(filepath: Path) -> ModuleType:
    """
    Import migration script as if it were a module.
    """
    spec = importlib.util.spec_from_file_location(filepath.stem, filepath)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # type: ignore
    return module

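As an aside, a minimal self-contained sketch of the importlib pattern used by import_migration_script (the temp-file path and its contents are illustrative):

```python
import importlib.util
from pathlib import Path

# write a toy "migration script" to load (illustrative)
path = Path("/tmp/example_migration.py")
path.write_text("revision = 'abc123'\ndown_revision = 'def456'\n")

# load the file as a module without touching sys.path
spec = importlib.util.spec_from_file_location(path.stem, path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
print(module.revision)  # abc123
```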
def extract_modified_tables(module: ModuleType) -> Set[str]:
    """
    Extract the tables being modified by a migration script.

    This function uses a simple approach, looking at the source code of
    the migration script for patterns. It could be improved by actually
    traversing the AST.
    """

    tables: Set[str] = set()
    for function in {"upgrade", "downgrade"}:
        source = getsource(getattr(module, function))
        tables.update(re.findall(r'alter_table\(\s*"(\w+?)"\s*\)', source, re.DOTALL))
        tables.update(re.findall(r'add_column\(\s*"(\w+?)"\s*,', source, re.DOTALL))
        tables.update(re.findall(r'drop_column\(\s*"(\w+?)"\s*,', source, re.DOTALL))

    return tables

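To illustrate what those regexes pick up, a toy upgrade() body (table names are illustrative):

```python
import re

source = '''
def upgrade():
    op.add_column("dashboards", sa.Column("uuid", UUIDType()))
    with op.batch_alter_table("slices") as batch_op:
        batch_op.drop_column("params")
'''
tables = set()
tables.update(re.findall(r'alter_table\(\s*"(\w+?)"\s*\)', source, re.DOTALL))
tables.update(re.findall(r'add_column\(\s*"(\w+?)"\s*,', source, re.DOTALL))
print(tables)  # {'dashboards', 'slices'} (set order varies)
```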
def find_models(module: ModuleType) -> List[Type[Model]]:
    """
    Find all models in a migration script.
    """
    models: List[Type[Model]] = []
    tables = extract_modified_tables(module)

    # add models defined explicitly in the migration script
    queue = list(module.__dict__.values())
    while queue:
        obj = queue.pop()
        if hasattr(obj, "__tablename__"):
            tables.add(obj.__tablename__)
        elif isinstance(obj, list):
            queue.extend(obj)
        elif isinstance(obj, dict):
            queue.extend(obj.values())

    # add implicit models
    # pylint: disable=no-member, protected-access
    for obj in Model._decl_class_registry.values():
        if hasattr(obj, "__table__") and obj.__table__.fullname in tables:
            models.append(obj)

    # sort topologically so we can create entities in order and
    # maintain relationships (e.g., create a database before creating
    # a slice)
    sorter = TopologicalSorter()
    for model in models:
        inspector = inspect(model)
        dependent_tables: List[str] = []
        for column in inspector.columns.values():
            for foreign_key in column.foreign_keys:
                dependent_tables.append(foreign_key.target_fullname.split(".")[0])
        sorter.add(model.__tablename__, *dependent_tables)
    order = list(sorter.static_order())
    models.sort(key=lambda model: order.index(model.__tablename__))

    return models

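graphlib is stdlib only from Python 3.9 onwards, hence the graphlib-backport dependency added elsewhere in this diff. A minimal sketch of the ordering guarantee find_models relies on (table names illustrative):

```python
from graphlib import TopologicalSorter

sorter = TopologicalSorter()
sorter.add("slices", "dbs")  # "slices" depends on "dbs"
sorter.add("dbs")            # "dbs" has no dependencies
print(list(sorter.static_order()))  # ['dbs', 'slices']
```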
@click.command()
@click.argument("filepath")
@click.option("--limit", default=1000, help="Maximum number of entities.")
@click.option("--force", is_flag=True, help="Do not prompt for confirmation.")
@click.option("--no-auto-cleanup", is_flag=True, help="Do not remove created models.")
def main(
    filepath: str, limit: int = 1000, force: bool = False, no_auto_cleanup: bool = False
) -> None:
    auto_cleanup = not no_auto_cleanup
    session = db.session()

    print(f"Importing migration script: {filepath}")
    module = import_migration_script(Path(filepath))

    revision: str = getattr(module, "revision", "")
    down_revision: str = getattr(module, "down_revision", "")
    if not revision or not down_revision:
        raise Exception(
            "Not a valid migration script, couldn't find down_revision/revision"
        )

    print(f"Migration goes from {down_revision} to {revision}")
    current_revision = db.engine.execute(
        "SELECT version_num FROM alembic_version"
    ).scalar()
    print(f"Current version of the DB is {current_revision}")

    print("\nIdentifying models used in the migration:")
    models = find_models(module)
    model_rows: Dict[Type[Model], int] = {}
    for model in models:
        rows = session.query(model).count()
        print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})")
        model_rows[model] = rows
    session.close()

    if current_revision != down_revision:
        if not force:
            click.confirm(
                "\nRunning benchmark will downgrade the Superset DB to "
                f"{down_revision} and upgrade to {revision} again. There may "
                "be data loss in downgrades. Continue?",
                abort=True,
            )
        downgrade(revision=down_revision)

    print("Benchmarking migration")
    results: Dict[str, float] = {}
    start = time.time()
    upgrade(revision=revision)
    duration = time.time() - start
    results["Current"] = duration
    print(f"Migration on current DB took: {duration:.2f} seconds")

    min_entities = 10
    new_models: Dict[Type[Model], List[Model]] = defaultdict(list)
    while min_entities <= limit:
        downgrade(revision=down_revision)
        print(f"Running with at least {min_entities} entities of each model")
        for model in models:
            missing = min_entities - model_rows[model]
            if missing > 0:
                print(f"- Adding {missing} entities to the {model.__name__} model")
                try:
                    added_models = add_sample_rows(session, model, missing)
                except Exception:
                    session.rollback()
                    raise
                model_rows[model] = min_entities
                session.commit()

                if auto_cleanup:
                    new_models[model].extend(added_models)

        start = time.time()
        upgrade(revision=revision)
        duration = time.time() - start
        print(f"Migration for {min_entities}+ entities took: {duration:.2f} seconds")
        results[f"{min_entities}+"] = duration
        min_entities *= 10

    if auto_cleanup:
        print("Cleaning up DB")
        # delete in reverse order of creation to handle relationships
        for model, entities in list(new_models.items())[::-1]:
            session.query(model).filter(
                model.id.in_(entity.id for entity in entities)
            ).delete(synchronize_session=False)
        session.commit()

    if current_revision != revision and not force:
        click.confirm(f"\nRevert DB to {revision}?", abort=True)
        upgrade(revision=revision)
        print("Reverted")

    print("\nResults:\n")
    for label, duration in results.items():
        print(f"{label}: {duration:.2f} s")


if __name__ == "__main__":
    from superset.app import create_app

    app = create_app()
    with app.app_context():
        # pylint: disable=no-value-for-parameter
        main()
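The diff viewer doesn't show the new file's path; assuming it lands as something like scripts/benchmark_migration.py, an invocation would presumably look like this, with --force skipping the confirmation prompts and --no-auto-cleanup keeping the generated rows:

```
# path and migration filename are assumptions
python scripts/benchmark_migration.py superset/migrations/versions/<revision>_<slug>.py --limit 1000
```

Each round multiplies the entity count by 10, so --limit 1000 benchmarks the migration at 10, 100, and 1000 entities per model.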
@@ -30,7 +30,7 @@ combine_as_imports = true
 include_trailing_comma = true
 line_length = 88
 known_first_party = superset
-known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,cron_descriptor,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_jwt_extended,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,pkg_resources,polyline,prison,pyarrow,pyhive,pyparsing,pytest,pytz,redis,requests,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,werkzeug,wtforms,wtforms_json,yaml
+known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,cron_descriptor,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_jwt_extended,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,graphlib,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,pkg_resources,polyline,prison,pyarrow,pyhive,pyparsing,pytest,pytz,redis,requests,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,werkzeug,wtforms,wtforms_json,yaml
 multi_line_output = 3
 order_by_type = false
setup.py
@@ -82,6 +82,7 @@ setup(
         "flask-migrate",
         "flask-wtf",
         "geopy",
+        "graphlib-backport",
         "gunicorn>=20.0.2, <20.1",
         "humanize",
         "isodate",
@@ -103,7 +104,7 @@ setup(
         "selenium>=3.141.0",
         "simplejson>=3.15.0",
         "slackclient==2.5.0",  # PINNED! slack changes file upload api in the future versions
-        "sqlalchemy>=1.3.16, <2.0, !=1.3.21",
+        "sqlalchemy>=1.3.16, <1.4, !=1.3.21",
         "sqlalchemy-utils>=0.36.6,<0.37",
         "sqlparse==0.3.0",  # PINNED! see https://github.com/andialbrecht/sqlparse/issues/562
         "typing-extensions>=3.7.4.3,<4",  # needed to support typing.Literal on py37
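Presumably the tightened SQLAlchemy pin (<1.4 rather than <2.0) guards the new script's use of Model._decl_class_registry, which SQLAlchemy 1.4 removed in favor of the registry API.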
@@ -20,7 +20,7 @@ from typing import List

 import sqlalchemy.sql.sqltypes

-from superset.utils.data import add_data, ColumnInfo
+from superset.utils.mock_data import add_data, ColumnInfo

 COLUMN_TYPES = [
     sqlalchemy.sql.sqltypes.INTEGER(),
@@ -14,17 +14,30 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import decimal
+import json
+import logging
+import os
 import random
 import string
 import sys
 from datetime import date, datetime, time, timedelta
-from typing import Any, Callable, cast, Dict, List, Optional
+from typing import Any, Callable, cast, Dict, List, Optional, Type
+from uuid import uuid4

 import sqlalchemy.sql.sqltypes
+import sqlalchemy_utils
+from flask_appbuilder import Model
+from sqlalchemy import Column, inspect, MetaData, Table
+from sqlalchemy.orm import Session
+from sqlalchemy.sql import func
 from sqlalchemy.sql.visitors import VisitableType
 from typing_extensions import TypedDict

+from superset import db
+
+logger = logging.getLogger(__name__)
+
 ColumnInfo = TypedDict(
     "ColumnInfo",
     {
@@ -53,24 +66,35 @@ MAXIMUM_DATE = date.today()
 days_range = (MAXIMUM_DATE - MINIMUM_DATE).days


 # pylint: disable=too-many-return-statements, too-many-branches
 def get_type_generator(sqltype: sqlalchemy.sql.sqltypes) -> Callable[[], Any]:
-    if isinstance(sqltype, sqlalchemy.sql.sqltypes.INTEGER):
+    if isinstance(
+        sqltype, (sqlalchemy.sql.sqltypes.INTEGER, sqlalchemy.sql.sqltypes.Integer)
+    ):
         return lambda: random.randrange(2147483647)

     if isinstance(sqltype, sqlalchemy.sql.sqltypes.BIGINT):
         return lambda: random.randrange(sys.maxsize)

-    if isinstance(sqltype, sqlalchemy.sql.sqltypes.VARCHAR):
+    if isinstance(
+        sqltype, (sqlalchemy.sql.sqltypes.VARCHAR, sqlalchemy.sql.sqltypes.String)
+    ):
         length = random.randrange(sqltype.length or 255)
+        length = max(8, length)  # for unique values
+        length = min(100, length)  # for FAB perms
         return lambda: "".join(random.choices(string.printable, k=length))

-    if isinstance(sqltype, sqlalchemy.sql.sqltypes.TEXT):
+    if isinstance(
+        sqltype, (sqlalchemy.sql.sqltypes.TEXT, sqlalchemy.sql.sqltypes.Text)
+    ):
         length = random.randrange(65535)
+        # "practicality beats purity"
+        length = max(length, 2048)
         return lambda: "".join(random.choices(string.printable, k=length))

-    if isinstance(sqltype, sqlalchemy.sql.sqltypes.BOOLEAN):
+    if isinstance(
+        sqltype, (sqlalchemy.sql.sqltypes.BOOLEAN, sqlalchemy.sql.sqltypes.Boolean)
+    ):
         return lambda: random.choice([True, False])

     if isinstance(
@@ -87,13 +111,49 @@ def get_type_generator(sqltype: sqlalchemy.sql.sqltypes) -> Callable[[], Any]:
         )

     if isinstance(
-        sqltype, (sqlalchemy.sql.sqltypes.TIMESTAMP, sqlalchemy.sql.sqltypes.DATETIME)
+        sqltype,
+        (
+            sqlalchemy.sql.sqltypes.TIMESTAMP,
+            sqlalchemy.sql.sqltypes.DATETIME,
+            sqlalchemy.sql.sqltypes.DateTime,
+        ),
     ):
         return lambda: datetime.fromordinal(MINIMUM_DATE.toordinal()) + timedelta(
             seconds=random.randrange(days_range * 86400)
         )

-    raise Exception(f"Unknown type {sqltype}. Please add it to `get_type_generator`.")
+    if isinstance(sqltype, sqlalchemy.sql.sqltypes.Numeric):
+        # since decimal is used in some models to store time, return a value that
+        # is a reasonable timestamp
+        return lambda: decimal.Decimal(datetime.now().timestamp() * 1000)
+
+    if isinstance(sqltype, sqlalchemy.sql.sqltypes.JSON):
+        return lambda: {
+            "".join(random.choices(string.printable, k=8)): random.randrange(65535)
+            for _ in range(10)
+        }
+
+    if isinstance(
+        sqltype,
+        (
+            sqlalchemy.sql.sqltypes.BINARY,
+            sqlalchemy_utils.types.encrypted.encrypted_type.EncryptedType,
+        ),
+    ):
+        length = random.randrange(sqltype.length or 255)
+        return lambda: os.urandom(length)
+
+    if isinstance(sqltype, sqlalchemy_utils.types.uuid.UUIDType):
+        return uuid4
+
+    if isinstance(sqltype, sqlalchemy.sql.sqltypes.BLOB):
+        length = random.randrange(sqltype.length or 255)
+        return lambda: os.urandom(length)
+
+    logger.warning(
+        "Unknown type %s. Please add it to `get_type_generator`.", type(sqltype)
+    )
+    return lambda: "UNKNOWN TYPE"


 def add_data(
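A condensed, standalone sketch of the dispatch idea in get_type_generator (not Superset's actual helper; the function name is made up):

```python
import random
import string

import sqlalchemy.sql.sqltypes as st


def toy_type_generator(sqltype):
    """Return a zero-argument callable yielding random values for sqltype."""
    if isinstance(sqltype, (st.INTEGER, st.Integer)):
        return lambda: random.randrange(2147483647)
    if isinstance(sqltype, (st.VARCHAR, st.String)):
        length = min(100, max(8, sqltype.length or 255))
        return lambda: "".join(random.choices(string.printable, k=length))
    return lambda: "UNKNOWN TYPE"  # graceful fallback, as in the new code


print(toy_type_generator(st.VARCHAR(32))())  # a random 32-char string
```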
@@ -161,5 +221,75 @@ def generate_data(columns: List[ColumnInfo], num_rows: int) -> List[Dict[str, Any]]:


 def generate_column_data(column: ColumnInfo, num_rows: int) -> List[Any]:
-    func = get_type_generator(column["type"])
-    return [func() for _ in range(num_rows)]
+    gen = get_type_generator(column["type"])
+    return [gen() for _ in range(num_rows)]
+
+
+def add_sample_rows(session: Session, model: Type[Model], count: int) -> List[Model]:
+    """
+    Add entities of a given model.
+
+    :param Model model: a Superset/FAB model
+    :param int count: how many entities to generate and insert
+    """
+    inspector = inspect(model)
+
+    # select samples to copy relationship values
+    relationships = inspector.relationships.items()
+    samples = session.query(model).limit(count).all() if relationships else []
+
+    entities: List[Model] = []
+    max_primary_key: Optional[int] = None
+    for i in range(count):
+        sample = samples[i % len(samples)] if samples else None
+        kwargs = {}
+        for column in inspector.columns.values():
+            # for primary keys, keep incrementing
+            if column.primary_key:
+                if max_primary_key is None:
+                    max_primary_key = (
+                        session.query(func.max(getattr(model, column.name))).scalar()
+                        or 0
+                    )
+                max_primary_key += 1
+                kwargs[column.name] = max_primary_key
+
+            # if the column has a foreign key, copy the value from an existing entity
+            elif column.foreign_keys:
+                if sample:
+                    kwargs[column.name] = getattr(sample, column.name)
+                else:
+                    kwargs[column.name] = get_valid_foreign_key(column)
+
+            # should be an enum but it's not
+            elif column.name == "datasource_type":
+                kwargs[column.name] = "table"
+
+            # otherwise, generate a random value based on the type
+            else:
+                kwargs[column.name] = generate_value(column)
+
+        entities.append(model(**kwargs))
+
+    session.add_all(entities)
+    return entities
+
+
+def get_valid_foreign_key(column: Column) -> Any:
+    foreign_key = list(column.foreign_keys)[0]
+    table_name, column_name = foreign_key.target_fullname.split(".", 1)
+    return db.engine.execute(f"SELECT {column_name} FROM {table_name} LIMIT 1").scalar()
+
+
+def generate_value(column: Column) -> Any:
+    if hasattr(column.type, "enums"):
+        return random.choice(column.type.enums)
+
+    json_as_string = "json" in column.name.lower() and isinstance(
+        column.type, sqlalchemy.sql.sqltypes.Text
+    )
+    type_ = sqlalchemy.sql.sqltypes.JSON() if json_as_string else column.type
+    value = get_type_generator(type_)()
+    if json_as_string:
+        value = json.dumps(value)
+    return value
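Finally, a hedged usage sketch for the new add_sample_rows helper (the model choice is illustrative and assumes a configured Superset metastore):

```python
from superset import db
from superset.app import create_app
from superset.models.slice import Slice
from superset.utils.mock_data import add_sample_rows

app = create_app()
with app.app_context():
    session = db.session()
    slices = add_sample_rows(session, Slice, 100)  # generate 100 mock slices
    session.commit()
```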