pipekit/pipekit/drivers/mssql.py
Paul Trowbridge 574ada5258 Initial commit: Pipekit rewrite.
Orchestration layer around the jrunner Java JDBC CLI, replacing the
previous shell-based sync system in .archive/pre-rewrite. Includes
the FastAPI + Jinja web frontend, per-driver adapters (DB2, MSSQL,
PG), wizard-driven module creation with editable dest types and
source-sourced table/column descriptions, watermark/hook CRUD,
and the engine that runs modules end-to-end.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-22 00:38:26 -04:00

229 lines
9.0 KiB
Python

"""Microsoft SQL Server driver (mssql-jdbc).
Structured qualifiers instead of the pre-rewrite dotted-string hack: each
field — linked server, database, schema — is a separate form input, and
only the ones the user fills in show up in the generated FROM clause.
"""
from __future__ import annotations
from .base import (BrowseField, Driver, RemoteColumn, RemoteTable,
validate_identifier)
_TEXT_TYPES = {"char", "varchar", "nchar", "nvarchar", "text", "ntext"}
_TYPE_MAP = {
"tinyint": "smallint", "smallint": "smallint",
"int": "integer", "integer": "integer", "bigint": "bigint",
"decimal": "numeric", "numeric": "numeric",
"money": "numeric(19,4)", "smallmoney": "numeric(10,4)",
"real": "real", "float": "double precision",
"char": "text", "varchar": "text", "nchar": "text", "nvarchar": "text",
"text": "text", "ntext": "text",
"date": "date", "datetime": "timestamp", "datetime2": "timestamp",
"smalldatetime": "timestamp", "datetimeoffset": "timestamptz",
"time": "time",
"bit": "boolean",
"binary": "bytea", "varbinary": "bytea", "image": "bytea",
"uniqueidentifier": "uuid",
}
def _base(type_raw: str) -> str:
return type_raw.lower().split("(", 1)[0].strip()
class MSSQLDriver(Driver):
    """Microsoft SQL Server driver built on the jrunner JDBC CLI.

    Table and column metadata come from INFORMATION_SCHEMA, so the same
    queries work against the connection's current database, another
    database on the same server, or a linked server.  Extended-property
    descriptions (MS_Description) are read from sys.* catalog views,
    which are only reachable on the directly-connected server.

    Each qualifier (linked server, database, schema) is a separate
    structured field; only the ones the user supplies appear in the
    generated multi-part names.
    """

    kind = "mssql"
    label = "Microsoft SQL Server"

    def browse_fields(self) -> list[BrowseField]:
        """Form inputs shown by the browse wizard; all are optional."""
        return [
            BrowseField(name="linked_server", label="Linked server",
                        required=False,
                        help="only for cross-server lookups; usually blank"),
            BrowseField(name="database", label="Database",
                        required=False,
                        help="leave blank to use the connection's current DB"),
            BrowseField(name="schema", label="Schema",
                        required=False, default="dbo"),
        ]

    def list_tables(
        self, conn, *, linked_server: str | None = None,
        database: str | None = None, schema: str | None = None,
    ) -> list[RemoteTable]:
        """List base tables and views under the given qualifiers.

        Every qualifier passes validate_identifier() before being
        interpolated into the SQL text, so the f-strings below cannot
        carry arbitrary injected text.
        """
        self._validate(linked_server, database, schema)
        prefix = self._info_schema_prefix(linked_server, database)
        where = ["TABLE_TYPE IN ('BASE TABLE','VIEW')"]
        if schema:
            where.append(f"TABLE_SCHEMA = '{schema}'")
        sql = (
            f"SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE "
            f"FROM {prefix}INFORMATION_SCHEMA.TABLES "
            f"WHERE {' AND '.join(where)} "
            f"ORDER BY TABLE_SCHEMA, TABLE_NAME"
        )
        result = self.query(conn, sql)
        tables: list[RemoteTable] = []
        for row in result.rows:
            if len(row) < 3:
                continue  # malformed CLI row; skip rather than crash
            sch, name, ttype = row[0].strip(), row[1].strip(), row[2].strip()
            kind = "view" if ttype.upper() == "VIEW" else "table"
            tables.append(RemoteTable(
                schema=sch, name=name, kind=kind,
                full_name=self.qualified_table_name(
                    name, schema=sch, database=database,
                    linked_server=linked_server),
            ))
        return tables

    def get_columns(
        self, conn, table: str, *, linked_server: str | None = None,
        database: str | None = None, schema: str | None = None,
    ) -> list[RemoteColumn]:
        """Fetch column metadata for *table*, ordered by ordinal position.

        MS_Description extended properties are attached as column
        descriptions only when no linked server is involved.
        """
        validate_identifier(table, "table")
        self._validate(linked_server, database, schema)
        prefix = self._info_schema_prefix(linked_server, database)
        where = [f"TABLE_NAME = '{table}'"]
        if schema:
            where.append(f"TABLE_SCHEMA = '{schema}'")
        sql = (
            f"SELECT COLUMN_NAME, DATA_TYPE, ORDINAL_POSITION, IS_NULLABLE, "
            f"  CHARACTER_MAXIMUM_LENGTH, NUMERIC_PRECISION, NUMERIC_SCALE "
            f"FROM {prefix}INFORMATION_SCHEMA.COLUMNS "
            f"WHERE {' AND '.join(where)} "
            f"ORDER BY ORDINAL_POSITION"
        )
        result = self.query(conn, sql)
        cols: list[RemoteColumn] = []
        for row in result.rows:
            if len(row) < 4:
                continue  # malformed CLI row; skip rather than crash
            name, dtype, pos, nullable = [c.strip() for c in row[:4]]
            # Trailing columns may be absent depending on the CLI output.
            length = row[4].strip() if len(row) > 4 else ""
            prec = row[5].strip() if len(row) > 5 else ""
            scale = row[6].strip() if len(row) > 6 else ""
            type_raw = _format_type(dtype, length, prec, scale)
            cols.append(RemoteColumn(
                name=name, type_raw=type_raw,
                position=int(pos), nullable=(nullable.upper() == "YES"),
            ))
        # Extended-property descriptions live in sys.extended_properties,
        # which isn't available over a linked-server call from this side.
        if not linked_server:
            descs = self._column_descriptions(conn, table, database=database,
                                              schema=schema or "dbo")
            for c in cols:
                c.description = descs.get(c.name) or None
        return cols

    def describe_table(
        self, conn, table: str, *, linked_server: str | None = None,
        database: str | None = None, schema: str | None = None,
    ) -> str | None:
        """Return the table's MS_Description extended property, if any.

        Returns None over a linked server (sys.* isn't reachable there).
        NOTE(review): only sys.tables is joined, so views never get a
        description here -- confirm whether that is intentional.
        """
        validate_identifier(table, "table")
        self._validate(linked_server, database, schema)
        if linked_server:
            return None
        sch = schema or "dbo"
        db_prefix = f"[{database}]." if database else ""
        sql = (
            f"SELECT CAST(ep.value AS NVARCHAR(MAX)) "
            f"FROM {db_prefix}sys.extended_properties ep "
            f"JOIN {db_prefix}sys.tables t ON t.object_id = ep.major_id "
            f"JOIN {db_prefix}sys.schemas s ON s.schema_id = t.schema_id "
            f"WHERE ep.class = 1 AND ep.minor_id = 0 "
            f"AND ep.name = 'MS_Description' "
            f"AND s.name = '{sch}' AND t.name = '{table}'"
        )
        result = self.query(conn, sql)
        if not result.rows or not result.rows[0]:
            return None
        v = result.rows[0][0].strip()
        return v or None

    def _column_descriptions(
        self, conn, table: str, *, database: str | None, schema: str,
    ) -> dict[str, str]:
        """Map column name -> MS_Description for *table* (local server only).

        Columns without a (non-empty) description are simply absent.
        """
        db_prefix = f"[{database}]." if database else ""
        sql = (
            f"SELECT c.name, CAST(ep.value AS NVARCHAR(MAX)) "
            f"FROM {db_prefix}sys.extended_properties ep "
            f"JOIN {db_prefix}sys.columns c "
            f"  ON c.object_id = ep.major_id AND c.column_id = ep.minor_id "
            f"JOIN {db_prefix}sys.tables t ON t.object_id = c.object_id "
            f"JOIN {db_prefix}sys.schemas s ON s.schema_id = t.schema_id "
            f"WHERE ep.class = 1 AND ep.name = 'MS_Description' "
            f"AND s.name = '{schema}' AND t.name = '{table}'"
        )
        result = self.query(conn, sql)
        out: dict[str, str] = {}
        for row in result.rows:
            if len(row) < 2:
                continue
            name = row[0].strip()
            desc = row[1].strip()
            if name and desc:
                out[name] = desc
        return out

    def qualified_table_name(
        self, table: str, *, linked_server: str | None = None,
        database: str | None = None, schema: str | None = None,
    ) -> str:
        """Build a bracket-quoted multi-part name for a FROM clause.

        Only supplied qualifiers appear; a linked server forces the
        four-part form, leaving the catalog slot empty when no database
        is given ("[srv]..[dbo].[t]").
        """
        parts = []
        if linked_server:
            parts.append(self.quote_identifier(linked_server))
            # quote_identifier("") is "" -> empty catalog slot after join.
            parts.append(self.quote_identifier(database or ""))
        elif database:
            parts.append(self.quote_identifier(database))
        parts.append(self.quote_identifier(schema or "dbo"))
        parts.append(self.quote_identifier(table))
        return ".".join(parts)

    def quote_identifier(self, name: str) -> str:
        """Bracket-quote *name*, doubling any closing brackets; "" -> ""."""
        if not name:
            return ""
        return "[" + name.replace("]", "]]") + "]"

    def default_expression(self, type_raw: str, column_name: str) -> str:
        """Extract expression for a column: RTRIM character data as-is."""
        col = self.quote_identifier(column_name)
        if _base(type_raw) in _TEXT_TYPES:
            return f"RTRIM({col})"
        return col

    def map_type(self, type_raw: str) -> str:
        """Map a SQL Server type string to its PostgreSQL destination type.

        Unknown types fall back to "text"; decimal/numeric keep their
        original "(precision,scale)" suffix.
        """
        base = _base(type_raw)
        mapped = _TYPE_MAP.get(base, "text")
        if mapped == "numeric" and "(" in type_raw:
            return "numeric" + type_raw[type_raw.index("("):]
        return mapped

    # ---- helpers ----

    def _validate(self, linked_server, database, schema):
        """Reject any supplied qualifier that isn't a safe identifier."""
        if linked_server:
            validate_identifier(linked_server, "linked_server")
        if database:
            validate_identifier(database, "database")
        if schema:
            validate_identifier(schema, "schema")

    def _info_schema_prefix(self, linked_server, database) -> str:
        """Catalog prefix for INFORMATION_SCHEMA queries.

        "[srv].[db]." for a linked server with a database, "[srv].." for
        a linked server without one (the valid empty-catalog four-part
        form -- the previous code emitted "[srv].[].", and "[]" is an
        invalid empty identifier in T-SQL), "[db]." for a local
        cross-database query, and "" for the current database.
        """
        if linked_server:
            if database:
                return f"[{linked_server}].[{database}]."
            return f"[{linked_server}].."
        if database:
            return f"[{database}]."
        return ""
def _format_type(dtype: str, length: str, prec: str, scale: str) -> str:
base = dtype.upper()
if base in ("DECIMAL", "NUMERIC") and prec:
return f"{base}({prec},{scale or '0'})"
if base in ("CHAR", "VARCHAR", "NCHAR", "NVARCHAR") and length and length != "-1":
return f"{base}({length})"
return base