"""Introspect source systems — browse tables, fetch columns, generate queries and DDL.""" import csv import io import os import re import subprocess import tempfile from dataclasses import dataclass from config import get_config from engine.db import get_connection @dataclass class RemoteTable: schema: str name: str table_type: str linked_server: str = None linked_db: str = None @property def full_name(self) -> str: if self.linked_server: return f"[{self.linked_server}].[{self.linked_db}].{self.schema}.{self.name}" return f"{self.schema}.{self.name}" @property def type_label(self) -> str: mapping = { "BASE TABLE": "Table", "VIEW": "View", "P": "Table", "L": "View", "T": "Table", "V": "View", } return mapping.get(self.table_type, self.table_type) def to_dict(self) -> dict: return {"schema": self.schema, "name": self.name, "table_type": self.table_type, "type_label": self.type_label, "full_name": self.full_name, "linked_server": self.linked_server, "linked_db": self.linked_db} @dataclass class RemoteColumn: name: str data_type: str position: int nullable: bool = True def to_dict(self) -> dict: return {"name": self.name, "data_type": self.data_type, "position": self.position, "nullable": self.nullable} # --------------------------------------------------------------------------- # JDBC type to PostgreSQL type mapping # --------------------------------------------------------------------------- TYPE_MAP_PG = { # integers "int": "integer", "integer": "integer", "smallint": "smallint", "bigint": "bigint", "tinyint": "smallint", # floats "float": "double precision", "real": "real", "double": "double precision", # decimal "decimal": "numeric", "numeric": "numeric", "money": "numeric(19,4)", "smallmoney": "numeric(10,4)", # strings "varchar": "text", "char": "text", "nvarchar": "text", "nchar": "text", "text": "text", "ntext": "text", "character": "text", # dates "date": "date", "datetime": "timestamp", "datetime2": "timestamp", "smalldatetime": "timestamp", "timestamp": "timestamp", "timestamptz": "timestamptz", # boolean "bit": "boolean", # binary "binary": "bytea", "varbinary": "bytea", "image": "bytea", # uuid "uniqueidentifier": "uuid", } def map_type_pg(source_type: str) -> str: """Map a source column type to a PostgreSQL type.""" base = source_type.lower().split("(")[0].strip() return TYPE_MAP_PG.get(base, "text") # --------------------------------------------------------------------------- # jrunner query helper # --------------------------------------------------------------------------- def _resolve_password(password: str) -> str: """Resolve a password — if it starts with $, look up the env var.""" if password and password.startswith("$"): return os.environ.get(password[1:], "") return password or "" def run_jrunner_query(connection_id: int, sql: str) -> str: """Run a query via jrunner in CSV mode and return raw output.""" conn = get_connection(connection_id) if not conn: raise ValueError(f"Connection {connection_id} not found") cfg = get_config() jrunner = cfg["jrunner_path"] password = _resolve_password(conn["password"]) with tempfile.NamedTemporaryFile(mode="w", suffix=".sql", delete=False) as f: f.write(sql) sql_path = f.name try: result = subprocess.run( [jrunner, "-scu", conn["jdbc_url"], "-scn", conn["username"] or "", "-scp", password, "-sq", sql_path, "-f", "csv"], capture_output=True, text=True, timeout=60, ) if result.returncode != 0: raise RuntimeError(f"jrunner error: {result.stderr or result.stdout}") return result.stdout finally: os.unlink(sql_path) def _parse_csv(output: str) -> list[list[str]]: """Parse CSV output from jrunner, skipping the header.""" reader = csv.reader(io.StringIO(output)) header = next(reader, None) if not header: return [] return [row for row in reader if row] # --------------------------------------------------------------------------- # Table browsing # --------------------------------------------------------------------------- def _detect_source_type(jdbc_url: str) -> str: """Detect source type from JDBC URL.""" url = jdbc_url.lower() if "as400" in url: return "as400" if "sqlserver" in url: return "sqlserver" if "postgresql" in url: return "postgresql" if "clickhouse" in url: return "clickhouse" if "mysql" in url: return "mysql" return "unknown" def fetch_tables(connection_id: int, schema_filter: str = None) -> list[RemoteTable]: """Fetch list of tables and views from a source connection.""" conn = get_connection(connection_id) if not conn: raise ValueError(f"Connection {connection_id} not found") source_type = _detect_source_type(conn["jdbc_url"]) linked_server = None linked_db = None if source_type == "as400": sql = ( "SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE " "FROM QSYS2.SYSTABLES " "WHERE TABLE_SCHEMA NOT LIKE 'Q%' " ) if schema_filter: sql += f"AND TABLE_SCHEMA = '{schema_filter}' " sql += "ORDER BY TABLE_SCHEMA, TABLE_NAME" elif source_type == "sqlserver": # Parse schema_filter formats: # "LINKED.DB" -> linked server + database # "LINKED.DB.SCHEMA" -> linked server + database + schema # ".DB" -> database only (no linked server) # ".DB.SCHEMA" -> database + schema # "SCHEMA" -> schema only (current database) linked_schema = None local_db = None if schema_filter and "." in schema_filter: parts = schema_filter.split(".") if parts[0] == "": # Starts with dot: ".DB" or ".DB.SCHEMA" local_db = parts[1] if len(parts) > 1 else None linked_schema = parts[2] if len(parts) > 2 else None elif len(parts) == 2: linked_server, linked_db = parts elif len(parts) >= 3: linked_server, linked_db, linked_schema = parts[0], parts[1], parts[2] if linked_server: sql = ( f"SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE " f"FROM [{linked_server}].[{linked_db}].INFORMATION_SCHEMA.TABLES " f"WHERE TABLE_TYPE IN ('BASE TABLE','VIEW') " ) if linked_schema: sql += f"AND TABLE_SCHEMA = '{linked_schema}' " sql += "ORDER BY TABLE_SCHEMA, TABLE_NAME" elif local_db: sql = ( f"SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE " f"FROM [{local_db}].INFORMATION_SCHEMA.TABLES " f"WHERE TABLE_TYPE IN ('BASE TABLE','VIEW') " ) if linked_schema: sql += f"AND TABLE_SCHEMA = '{linked_schema}' " sql += "ORDER BY TABLE_SCHEMA, TABLE_NAME" else: sql = ( "SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE " "FROM INFORMATION_SCHEMA.TABLES " "WHERE TABLE_TYPE IN ('BASE TABLE','VIEW') " ) if schema_filter: sql += f"AND TABLE_SCHEMA = '{schema_filter}' " sql += "ORDER BY TABLE_SCHEMA, TABLE_NAME" elif source_type == "postgresql": sql = ( "SELECT table_schema, table_name, table_type " "FROM information_schema.tables " "WHERE table_schema NOT IN ('pg_catalog','information_schema') " ) if schema_filter: sql += f"AND table_schema = '{schema_filter}' " sql += "ORDER BY table_schema, table_name" elif source_type == "clickhouse": sql = ( "SELECT database AS TABLE_SCHEMA, name AS TABLE_NAME, engine AS TABLE_TYPE " "FROM system.tables " "WHERE database NOT IN ('system','INFORMATION_SCHEMA','information_schema') " ) if schema_filter: sql += f"AND database = '{schema_filter}' " sql += "ORDER BY database, name" elif source_type == "mysql": sql = ( "SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE " "FROM INFORMATION_SCHEMA.TABLES " "WHERE TABLE_SCHEMA NOT IN ('mysql','information_schema','performance_schema','sys') " ) if schema_filter: sql += f"AND TABLE_SCHEMA = '{schema_filter}' " sql += "ORDER BY TABLE_SCHEMA, TABLE_NAME" else: # Generic fallback — INFORMATION_SCHEMA is widely supported sql = ( "SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE " "FROM INFORMATION_SCHEMA.TABLES " "ORDER BY TABLE_SCHEMA, TABLE_NAME" ) # For database-only queries, store the db in linked_db so downstream can reference it effective_db = linked_db if linked_server else (local_db if source_type == "sqlserver" else None) rows = _parse_csv(run_jrunner_query(connection_id, sql)) return [RemoteTable(schema=r[0].strip(), name=r[1].strip(), table_type=r[2].strip(), linked_server=linked_server if source_type == "sqlserver" else None, linked_db=effective_db) for r in rows if len(r) >= 3] def fetch_columns(connection_id: int, schema: str, table: str, linked_server: str = None, linked_db: str = None) -> list[RemoteColumn]: """Fetch column metadata for a specific table.""" conn = get_connection(connection_id) if not conn: raise ValueError(f"Connection {connection_id} not found") source_type = _detect_source_type(conn["jdbc_url"]) if source_type == "as400": sql = ( f"SELECT COLUMN_NAME, DATA_TYPE, ORDINAL_POSITION " f"FROM QSYS2.SYSCOLUMNS " f"WHERE TABLE_SCHEMA = '{schema}' AND TABLE_NAME = '{table}' " f"ORDER BY ORDINAL_POSITION" ) elif source_type == "clickhouse": sql = ( f"SELECT name, type, position() " f"FROM system.columns " f"WHERE database = '{schema}' AND table = '{table}' " f"ORDER BY position" ) elif source_type == "sqlserver" and linked_server and linked_db: sql = ( f"SELECT COLUMN_NAME, DATA_TYPE, ORDINAL_POSITION " f"FROM [{linked_server}].[{linked_db}].INFORMATION_SCHEMA.COLUMNS " f"WHERE TABLE_SCHEMA = '{schema}' AND TABLE_NAME = '{table}' " f"ORDER BY ORDINAL_POSITION" ) elif source_type == "sqlserver" and linked_db: sql = ( f"SELECT COLUMN_NAME, DATA_TYPE, ORDINAL_POSITION " f"FROM [{linked_db}].INFORMATION_SCHEMA.COLUMNS " f"WHERE TABLE_SCHEMA = '{schema}' AND TABLE_NAME = '{table}' " f"ORDER BY ORDINAL_POSITION" ) else: # Works for SQL Server, PostgreSQL, MySQL sql = ( f"SELECT COLUMN_NAME, DATA_TYPE, ORDINAL_POSITION " f"FROM INFORMATION_SCHEMA.COLUMNS " f"WHERE TABLE_SCHEMA = '{schema}' AND TABLE_NAME = '{table}' " f"ORDER BY ORDINAL_POSITION" ) rows = _parse_csv(run_jrunner_query(connection_id, sql)) return [RemoteColumn(name=r[0].strip(), data_type=r[1].strip(), position=int(r[2].strip())) for r in rows if len(r) >= 3] # --------------------------------------------------------------------------- # Query and DDL generation # --------------------------------------------------------------------------- _IDENTIFIER_RE = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') def _needs_quoting(name: str) -> bool: """Check if a column name needs quoting (has spaces, special chars, etc.).""" return not _IDENTIFIER_RE.match(name) def _safe_alias(name: str) -> str: """Generate a safe lowercase alias for a column name. Replaces special characters with underscores and strips leading/trailing underscores. If the result still needs quoting, wraps in double quotes. """ alias = re.sub(r'[^a-z0-9_]', '_', name.lower()) alias = re.sub(r'_+', '_', alias).strip('_') if not alias or not _IDENTIFIER_RE.match(alias): alias = f'"{alias}"' return alias def generate_select(connection_id: int, schema: str, table: str, columns: list[RemoteColumn] = None, linked_server: str = None, linked_db: str = None) -> str: """Generate a SELECT query from column metadata.""" if columns is None: columns = fetch_columns(connection_id, schema, table, linked_server=linked_server, linked_db=linked_db) conn = get_connection(connection_id) source_type = _detect_source_type(conn["jdbc_url"]) text_types = {"varchar", "char", "nvarchar", "nchar", "character", "text", "ntext"} lines = ["SELECT"] for i, col in enumerate(columns): prefix = " ," if i > 0 else " " alias = _safe_alias(col.name) # Quote source column name if it contains special characters # SQL Server uses [brackets], others use "double quotes" if _needs_quoting(col.name): if source_type == "sqlserver": col_ref = f"[{col.name}]" else: col_ref = f'"{col.name}"' else: col_ref = col.name base_type = col.data_type.lower().split("(")[0].strip() # RTRIM text columns for SQL Server and AS/400 (padded char fields) if base_type in text_types and source_type in ("sqlserver", "as400"): expr = f"RTRIM({col_ref})" lines.append(f"{prefix}{expr:<35} AS {alias}") else: lines.append(f"{prefix}{col_ref:<35} AS {alias}") lines.append("FROM") if linked_server and linked_db: lines.append(f" [{linked_server}].[{linked_db}].{schema}.{table}") elif linked_db: lines.append(f" [{linked_db}].{schema}.{table}") else: lines.append(f" {schema}.{table}") return "\n".join(lines) def generate_dest_ddl(dest_table: str, columns: list[RemoteColumn]) -> str: """Generate CREATE TABLE DDL for the destination (PostgreSQL).""" schema_table = dest_table lines = [f"CREATE TABLE IF NOT EXISTS {schema_table} ("] col_lines = [] for col in columns: pg_type = map_type_pg(col.data_type) col_name = _safe_alias(col.name) col_lines.append(f" {col_name:<30} {pg_type}") lines.append(",\n".join(col_lines)) lines.append(");") return "\n".join(lines) def propose_module(connection_id: int, schema: str, table: str, dest_schema: str = None, linked_server: str = None, linked_db: str = None) -> dict: """ Given a source table, propose a complete module config: - source_query (auto-generated SELECT with RTRIM) - dest_table - dest_ddl (CREATE TABLE for destination) - suggested merge_strategy - suggested merge_key (first column) - suggested watermark_column (if DEX_ROW_TS or similar found) """ columns = fetch_columns(connection_id, schema, table, linked_server=linked_server, linked_db=linked_db) source_query = generate_select(connection_id, schema, table, columns, linked_server=linked_server, linked_db=linked_db) # Propose destination table name if dest_schema is None: dest_schema = "public" dest_table = f"{dest_schema}.{table.lower()}" # Generate DDL dest_ddl = generate_dest_ddl(dest_table, columns) # Suggest merge strategy based on columns present col_names_lower = [c.name.lower() for c in columns] timestamp_col = None for candidate in ["dex_row_ts", "modified_date", "updated_at", "last_modified", "modifieddate", "changedate"]: if candidate in col_names_lower: timestamp_col = candidate break merge_key = columns[0].name.lower() if columns else None if timestamp_col: strategy = "incremental" else: strategy = "full" return { "name": table.lower(), "source_query": source_query, "dest_table": dest_table, "dest_ddl": dest_ddl, "columns": [c.to_dict() for c in columns], "merge_strategy": strategy, "merge_key": merge_key, "watermark_column": timestamp_col, }