pipekit/pipekit/drivers/db2.py
Paul Trowbridge 574ada5258 Initial commit: Pipekit rewrite.
Orchestration layer around the jrunner Java JDBC CLI, replacing the
previous shell-based sync system in .archive/pre-rewrite. Includes
the FastAPI + Jinja web frontend, per-driver adapters (DB2, MSSQL,
PG), wizard-driven module creation with editable dest types and
source-sourced table/column descriptions, watermark/hook CRUD,
and the engine that runs modules end-to-end.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-22 00:38:26 -04:00

146 lines
5.5 KiB
Python

"""IBM i / DB2 for i driver (jt400)."""
from __future__ import annotations
from .base import (BrowseField, Driver, RemoteColumn, RemoteTable,
validate_identifier)
_TEXT_TYPES = {"char", "varchar", "nchar", "nvarchar", "graphic", "vargraphic",
"clob", "nclob"}
_DATE_TYPES = {"date"}
_TYPE_MAP = {
"smallint": "smallint", "integer": "integer", "int": "integer",
"bigint": "bigint",
"decimal": "numeric", "numeric": "numeric",
"real": "real", "float": "double precision", "double": "double precision",
"char": "text", "varchar": "text", "nchar": "text", "nvarchar": "text",
"graphic": "text", "vargraphic": "text", "clob": "text", "nclob": "text",
"date": "date", "time": "time", "timestamp": "timestamp",
"blob": "bytea", "binary": "bytea", "varbinary": "bytea",
"rowid": "text",
}
_SAFE_IDENT_CHARS = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_")
def _base(type_raw: str) -> str:
return type_raw.lower().split("(", 1)[0].strip()
def _needs_quoting(name: str) -> bool:
return bool(name) and (not name[0].isalpha() and name[0] != "_"
or any(c not in _SAFE_IDENT_CHARS for c in name))
class DB2Driver(Driver):
    """Driver for IBM i / DB2 for i (jt400).

    Catalog metadata is read from the QSYS2.SYSTABLES and QSYS2.SYSCOLUMNS
    views. Schema and table names are passed through ``validate_identifier``
    before being interpolated into that catalog SQL.
    """

    kind = "db2"
    label = "IBM i / DB2 for i"

    def browse_fields(self) -> list[BrowseField]:
        """Return the inputs needed to browse this source: a required schema."""
        schema_field = BrowseField(
            name="schema",
            label="Schema / library",
            required=True,
            help="e.g. RLDBF12",
        )
        return [schema_field]

    def list_tables(self, conn, *, schema: str) -> list[RemoteTable]:
        """List the tables in *schema* from QSYS2.SYSTABLES, ordered by name.

        Rows whose TABLE_TYPE is ``L`` or ``V`` are reported with kind
        ``"view"``; every other type is reported as ``"table"``. Rows with
        fewer than three columns are skipped.
        """
        validate_identifier(schema, "schema")
        sql = ("SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE "
               "FROM QSYS2.SYSTABLES "
               f"WHERE TABLE_SCHEMA = '{schema}' "
               "ORDER BY TABLE_NAME")
        found: list[RemoteTable] = []
        for row in self.query(conn, sql).rows:
            if len(row) >= 3:
                tbl_schema, tbl_name, tbl_type = (cell.strip() for cell in row[:3])
                found.append(RemoteTable(
                    schema=tbl_schema,
                    name=tbl_name,
                    kind="view" if tbl_type in ("L", "V") else "table",
                    full_name=self.qualified_table_name(tbl_name, schema=tbl_schema),
                ))
        return found

    def get_columns(self, conn, table: str, *, schema: str) -> list[RemoteColumn]:
        """Return the columns of *schema*.*table* in ordinal order.

        The column description is COLUMN_TEXT, falling back to COLUMN_HEADING;
        an empty description becomes None. A column is nullable when
        IS_NULLABLE is ``Y`` (case-insensitive). Rows with fewer than four
        values are skipped; missing trailing values are treated as "".
        """
        validate_identifier(schema, "schema")
        validate_identifier(table, "table")
        sql = ("SELECT COLUMN_NAME, DATA_TYPE, ORDINAL_POSITION, IS_NULLABLE, "
               " LENGTH, NUMERIC_PRECISION, NUMERIC_SCALE, "
               " COALESCE(COLUMN_TEXT, COLUMN_HEADING, '') "
               "FROM QSYS2.SYSCOLUMNS "
               f"WHERE TABLE_SCHEMA = '{schema}' AND TABLE_NAME = '{table}' "
               "ORDER BY ORDINAL_POSITION")

        def optional_cell(row, idx):
            # Trailing columns may be absent from a short row.
            return row[idx].strip() if len(row) > idx else ""

        columns: list[RemoteColumn] = []
        for row in self.query(conn, sql).rows:
            if len(row) < 4:
                continue
            col_name, data_type, ordinal, is_nullable = (c.strip() for c in row[:4])
            length, precision, scale, text = (optional_cell(row, i) for i in range(4, 8))
            type_raw = _format_type(data_type, length, precision, scale)
            columns.append(RemoteColumn(
                name=col_name,
                type_raw=type_raw,
                position=int(ordinal),
                nullable=is_nullable.upper() == "Y",
                description=text or None,
            ))
        return columns

    def describe_table(self, conn, table: str, *, schema: str) -> str | None:
        """Return the table's TABLE_TEXT (or LONG_COMMENT), or None if blank/absent."""
        validate_identifier(schema, "schema")
        validate_identifier(table, "table")
        sql = ("SELECT COALESCE(TABLE_TEXT, LONG_COMMENT, '') "
               "FROM QSYS2.SYSTABLES "
               f"WHERE TABLE_SCHEMA = '{schema}' AND TABLE_NAME = '{table}' "
               "FETCH FIRST 1 ROWS ONLY")
        rows = self.query(conn, sql).rows
        if not rows or not rows[0]:
            return None
        text = rows[0][0].strip()
        return text if text else None

    def qualified_table_name(self, table: str, *, schema: str) -> str:
        """Return ``schema.table`` with each part quoted only when needed."""
        return ".".join((self.quote_identifier(schema), self.quote_identifier(table)))

    def quote_identifier(self, name: str) -> str:
        """Double-quote *name* (doubling embedded quotes) when ``_needs_quoting`` says so."""
        if not _needs_quoting(name):
            return name
        escaped = name.replace('"', '""')
        return f'"{escaped}"'

    def default_expression(self, type_raw: str, column_name: str) -> str:
        """Return the SELECT expression used to read *column_name*.

        Text types are RTRIMmed; date columns holding 0001-01-01 or
        9999-12-31 are turned into NULL; any other type is selected as-is.
        """
        quoted = self.quote_identifier(column_name)
        base_type = _base(type_raw)
        if base_type in _TEXT_TYPES:
            return f"RTRIM({quoted})"
        if base_type in _DATE_TYPES:
            return (f"CASE WHEN {quoted} IN (DATE('0001-01-01'), DATE('9999-12-31')) "
                    f"THEN NULL ELSE {quoted} END")
        return quoted

    def map_type(self, type_raw: str) -> str:
        """Translate a source type to its destination type name.

        Unknown types map to ``"text"``. For numeric targets, any ``(p,s)``
        suffix on *type_raw* is carried over, e.g. ``DECIMAL(9,2)`` ->
        ``numeric(9,2)``.
        """
        mapped = _TYPE_MAP.get(_base(type_raw), "text")
        paren = type_raw.find("(")
        if mapped != "numeric" or paren < 0:
            return mapped
        return "numeric" + type_raw[paren:]
def _format_type(dtype: str, length: str, prec: str, scale: str) -> str:
base = dtype.upper()
if base in ("DECIMAL", "NUMERIC") and prec:
return f"{base}({prec},{scale or '0'})"
if base in ("CHAR", "VARCHAR", "NCHAR", "NVARCHAR",
"GRAPHIC", "VARGRAPHIC") and length:
return f"{base}({length})"
return base