# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Utilities to auto-generate an Entity-Relationship Diagram (ERD) from the
SQLAlchemy models and render it into a PlantUML file.
"""

import json
import os
from collections import defaultdict
from collections.abc import Iterable
from typing import Any, Optional

import click
import jinja2

from superset import db

GROUPINGS: dict[str, Iterable[str]] = {
    "Core": [
        "css_templates",
        "dynamic_plugin",
        "favstar",
        "dashboards",
        "slices",
        "user_attribute",
        "embedded_dashboards",
        "annotation",
        "annotation_layer",
        "tag",
        "tagged_object",
    ],
    "System": ["ssh_tunnels", "keyvalue", "cache_keys", "key_value", "logs"],
    "Alerts & Reports": ["report_recipient", "report_execution_log", "report_schedule"],
    "Inherited from Flask App Builder (FAB)": [
        "ab_user",
        "ab_permission",
        "ab_permission_view",
        "ab_view_menu",
        "ab_role",
        "ab_register_user",
    ],
    "SQL Lab": ["query", "saved_query", "tab_state", "table_schema"],
    "Data Assets": [
        "dbs",
        "table_columns",
        "sql_metrics",
        "tables",
        "row_level_security_filters",
        "sl_tables",
        "sl_datasets",
        "sl_columns",
        "database_user_oauth2_tokens",
    ],
}

# Table name to group name mapping (the inverse of GROUPINGS, for easy lookup).
# Built with a comprehension so no loop variables leak into the module
# namespace; if a table were listed under several groups, the last one wins,
# exactly as with an explicit loop.
TABLE_TO_GROUP_MAP: dict[str, str] = {
    table: group for group, tables in GROUPINGS.items() for table in tables
}


def sort_data_structure(data):  # type: ignore
    """
    Return a copy of ``data`` in which every mapping has its keys sorted.

    The copy is made through a JSON round-trip, so ``data`` must be
    JSON-serializable and the result contains only JSON types (tuples come
    back as lists, dict keys become strings). List order is preserved.
    """
    sorted_json = json.dumps(data, sort_keys=True)
    return json.loads(sorted_json)


def introspect_sqla_model(mapper: Any, seen: set[str]) -> dict[str, Any]:
    """
    Introspects a SQLAlchemy model and returns a data structure that can be
    passed to a jinja2 template.

    Parameters:
    -----------
    mapper: SQLAlchemy model mapper
    seen: set of relationship identifiers already emitted, each being the two
        table names sorted and joined with "-". It is updated in place, so
        sharing one set across calls draws each table pair only once in the
        whole diagram (a relationship and its back-reference collapse into a
        single edge). Consequently, several distinct relationships between the
        same two tables are reduced to the first one encountered.

    Returns:
    --------
    Dict[str, Any]: data structure for the jinja2 template with the keys
        ``class_name``, ``table_name``, ``fields`` and ``relationships``
        (all mapping keys sorted)
    """
    table_name = mapper.persist_selectable.name
    model_info: dict[str, Any] = {
        "class_name": mapper.class_.__name__,
        "table_name": table_name,
        # Columns and their (stringified) SQL types
        "fields": [
            {"field_name": column.key, "type": str(column.type)}
            for column in mapper.columns
        ],
        "relationships": [],
    }

    for attr, relationship in mapper.relationships.items():
        related_table = relationship.mapper.persist_selectable.name
        # Order-independent id: A->B and B->A map to the same key
        relationship_id = "-".join(sorted([table_name, related_table]))
        if relationship_id in seen:
            continue
        seen.add(relationship_id)

        # PlantUML connector: one-to-many by default, flipped for many-to-one
        squiggle = (
            "}|--||" if relationship.direction.name == "MANYTOONE" else "||--|{"
        )
        relationship_info: dict[str, str] = {
            "relationship_name": attr,
            "related_model": relationship.mapper.class_.__name__,
            "type": relationship.direction.name,
            "related_table": related_table,
        }
        # Many-to-many is identified by the presence of a secondary table
        if relationship.secondary is not None:
            squiggle = "}|--|{"
            relationship_info["type"] = "many-to-many"
            relationship_info["secondary_table"] = relationship.secondary.name
        relationship_info["squiggle"] = squiggle
        model_info["relationships"].append(relationship_info)

    return sort_data_structure(model_info)  # type: ignore


def introspect_models() -> dict[str, list[dict[str, Any]]]:
    """
    Introspects SQLAlchemy models and returns a data structure that can be
    passed to a jinja2 template for rendering an ERD.

    Models are visited in a stable order (table name, then class module and
    qualified name). ``registry.mappers`` has no guaranteed iteration order,
    and ``introspect_sqla_model`` keeps only the first relationship seen per
    table pair. A stable order therefore makes both the model order and the
    relationship attribution reproducible, so regenerating the diagram does
    not produce spurious diffs.

    Returns:
    --------
    Dict[str, List[Dict[str, Any]]]: model data keyed by group name; tables
        missing from GROUPINGS are placed under "Uncategorized Models"
    """
    data: dict[str, list[dict[str, Any]]] = defaultdict(list)
    seen_relationships: set[str] = set()

    mappers = sorted(
        db.Model.registry.mappers,
        key=lambda m: (
            m.persist_selectable.name,
            m.class_.__module__,
            m.class_.__qualname__,
        ),
    )
    for mapper in mappers:
        group_name = (
            TABLE_TO_GROUP_MAP.get(mapper.persist_selectable.name)
            or "Uncategorized Models"
        )
        data[group_name].append(introspect_sqla_model(mapper, seen_relationships))
    return data


def generate_erd(file_path: str) -> None:
    """
    Generates a PlantUML ERD of the models/database

    The ``erd.template.puml`` jinja2 template is loaded from the directory
    containing this module and rendered with the introspected model data.

    Parameters:
    -----------
    file_path: str
        File path to write the ERD to (written as UTF-8)
    """
    data = introspect_models()
    templates_path = os.path.dirname(__file__)
    env = jinja2.Environment(loader=jinja2.FileSystemLoader(templates_path))
    template = env.get_template("erd.template.puml")
    rendered = template.render(data=data)

    # Explicit encoding so the output does not depend on the platform locale
    with open(file_path, "w", encoding="utf-8") as f:
        click.secho(f"Writing to {file_path}...", fg="green")
        f.write(rendered)


@click.command()
@click.option(
    "--output",
    "-o",
    type=click.Path(dir_okay=False, writable=True),
    help="File to write the ERD to",
)
def erd(output: Optional[str] = None) -> None:
    """
    Generates a PlantUML ERD of the models/database

    Parameters:
    -----------
    output: str, optional
        File to write the ERD to, defaults to erd.puml next to this module
        if not provided
    """
    module_dir = os.path.dirname(__file__)
    output = output or os.path.join(module_dir, "erd.puml")

    # Imported here rather than at module top, presumably to avoid building
    # the Flask app on import; generation runs inside the app context.
    from superset.app import create_app

    app = create_app()
    with app.app_context():
        generate_erd(output)


if __name__ == "__main__":
    erd()