pipekit/pipekit/drivers/pg.py
Paul Trowbridge ff19ae9b81 Drivers: add list_schemas() to base, PG, DB2, MSSQL
Base provides a no-op default; drivers opt in by overriding. MSSQL
scopes the lookup to a linked server / database when those qualifiers
are supplied.

Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
2026-05-02 20:56:10 -04:00

177 lines
6.9 KiB
Python

"""PostgreSQL driver (also used as a destination target)."""
from __future__ import annotations
from .base import (BrowseField, Driver, RemoteColumn, RemoteTable,
validate_identifier)
_TYPE_MAP = {
# Mostly identity — PG is the usual destination target, so mapping a PG
# source to PG dest is near-passthrough.
"smallint": "smallint", "integer": "integer", "bigint": "bigint",
"int": "integer", "int2": "smallint", "int4": "integer", "int8": "bigint",
"numeric": "numeric", "decimal": "numeric",
"real": "real", "double precision": "double precision",
"float4": "real", "float8": "double precision",
"text": "text", "varchar": "text", "char": "text", "bpchar": "text",
"character varying": "text", "character": "text",
"date": "date", "timestamp": "timestamp",
"timestamp without time zone": "timestamp",
"timestamp with time zone": "timestamptz", "timestamptz": "timestamptz",
"time": "time",
"boolean": "boolean", "bool": "boolean",
"bytea": "bytea",
"uuid": "uuid",
"json": "json", "jsonb": "jsonb",
}
def _base(type_raw: str) -> str:
return type_raw.lower().split("(", 1)[0].strip()
class PGDriver(Driver):
    """PostgreSQL implementation of the ``Driver`` interface.

    Catalog lookups are plain SQL strings executed through ``self.query``.
    The result is expected to expose ``.rows``; each row is indexed
    positionally and every cell is treated as a string (``.strip()`` is called
    on it) — NOTE(review): confirm against the ``query`` implementation.
    Schema and table names are passed through ``validate_identifier`` before
    being interpolated into SQL text.
    """

    kind = "pg"
    label = "PostgreSQL"

    # PostgreSQL reserved key words (including the "can be function or type"
    # group). A bare lower-case name from this set is not usable as a column
    # or table name without double quotes, so quote_identifier must quote it.
    # Quoting a lower-case name never changes which object it refers to, so
    # erring on the side of listing a word here is harmless.
    _RESERVED_WORDS = frozenset({
        "all", "analyse", "analyze", "and", "any", "array", "as", "asc",
        "asymmetric", "authorization", "binary", "both", "case", "cast",
        "check", "collate", "collation", "column", "concurrently",
        "constraint", "create", "cross", "current_catalog", "current_date",
        "current_role", "current_schema", "current_time",
        "current_timestamp", "current_user", "default", "deferrable", "desc",
        "distinct", "do", "else", "end", "except", "false", "fetch", "for",
        "foreign", "freeze", "from", "full", "grant", "group", "having",
        "ilike", "in", "initially", "inner", "intersect", "into", "is",
        "isnull", "join", "lateral", "leading", "left", "like", "limit",
        "localtime", "localtimestamp", "natural", "not", "notnull", "null",
        "offset", "on", "only", "or", "order", "outer", "overlaps",
        "placing", "primary", "references", "returning", "right", "select",
        "session_user", "similar", "some", "symmetric", "system_user",
        "table", "tablesample", "then", "to", "trailing", "true", "union",
        "unique", "user", "using", "variadic", "verbose", "when", "where",
        "window", "with",
    })

    def browse_fields(self) -> list[BrowseField]:
        """Return the browse qualifiers: one optional ``schema`` field
        defaulting to ``public``."""
        return [
            BrowseField(name="schema", label="Schema",
                        required=False, default="public"),
        ]

    def list_schemas(self, conn, **_) -> list[str]:
        """Return schema names in name order, skipping ``pg_catalog``,
        ``information_schema`` and anything prefixed ``pg_``.

        Extra keyword arguments are accepted and ignored.
        """
        result = self.query(
            conn,
            "SELECT schema_name FROM information_schema.schemata "
            "WHERE schema_name NOT IN ('pg_catalog','information_schema') "
            "AND schema_name NOT LIKE 'pg\\_%' ESCAPE '\\' "
            "ORDER BY schema_name")
        # Drop empty rows / empty first cells rather than emitting "" names.
        return [r[0].strip() for r in result.rows if r and r[0]]

    def list_tables(self, conn, *, schema: str | None = None) -> list[RemoteTable]:
        """List tables and views from ``information_schema.tables``.

        When *schema* is given it is validated and the listing is limited to
        it; otherwise every schema except ``pg_catalog`` and
        ``information_schema`` is included. Rows with fewer than three fields
        are skipped. ``kind`` is ``"view"`` when the catalog's table_type is
        ``VIEW`` and ``"table"`` for everything else.
        """
        where = ["table_schema NOT IN ('pg_catalog','information_schema')"]
        if schema:
            validate_identifier(schema, "schema")
            where.append(f"table_schema = '{schema}'")
        sql = (
            "SELECT table_schema, table_name, table_type "
            "FROM information_schema.tables "
            f"WHERE {' AND '.join(where)} "
            "ORDER BY table_schema, table_name"
        )
        result = self.query(conn, sql)
        tables: list[RemoteTable] = []
        for row in result.rows:
            if len(row) < 3:
                continue
            sch, name, ttype = row[0].strip(), row[1].strip(), row[2].strip()
            table_kind = "view" if ttype.upper() == "VIEW" else "table"
            tables.append(RemoteTable(
                schema=sch, name=name, kind=table_kind,
                full_name=self.qualified_table_name(name, schema=sch),
            ))
        return tables

    def get_columns(
        self, conn, table: str, *, schema: str | None = None,
    ) -> list[RemoteColumn]:
        """Return the columns of *table* in ordinal order.

        *schema* defaults to ``public``. Both names are validated before they
        are interpolated into the query. Each column's raw type is built by
        ``_format_type`` from the catalog's data type, length, precision and
        scale; the column comment (``col_description``) becomes
        ``description``, or ``None`` when empty. Rows with fewer than four
        fields are skipped.
        """
        validate_identifier(table, "table")
        if schema:
            validate_identifier(schema, "schema")
        sch = schema or "public"
        where = [f"c.table_name = '{table}'", f"c.table_schema = '{sch}'"]
        sql = (
            "SELECT c.column_name, c.data_type, c.ordinal_position, c.is_nullable, "
            " c.character_maximum_length, c.numeric_precision, c.numeric_scale, "
            " COALESCE(pg_catalog.col_description("
            " (quote_ident(c.table_schema) || '.' || quote_ident(c.table_name))::regclass, "
            " c.ordinal_position::int), '') "
            "FROM information_schema.columns c "
            f"WHERE {' AND '.join(where)} "
            "ORDER BY c.ordinal_position"
        )
        result = self.query(conn, sql)
        cols: list[RemoteColumn] = []
        for row in result.rows:
            if len(row) < 4:
                continue
            name, dtype, pos, nullable = [c.strip() for c in row[:4]]
            # Trailing fields are optional; a short row yields "" for them.
            length = row[4].strip() if len(row) > 4 else ""
            prec = row[5].strip() if len(row) > 5 else ""
            scale = row[6].strip() if len(row) > 6 else ""
            desc = row[7].strip() if len(row) > 7 else ""
            type_raw = _format_type(dtype, length, prec, scale)
            cols.append(RemoteColumn(
                name=name, type_raw=type_raw,
                position=int(pos), nullable=(nullable.upper() == "YES"),
                description=desc or None,
            ))
        return cols

    def describe_table(
        self, conn, table: str, *, schema: str | None = None,
    ) -> str | None:
        """Return the table comment (``obj_description``), or ``None``.

        *schema* defaults to ``public``. ``None`` is returned when the query
        yields no rows or the comment is empty after stripping.
        """
        validate_identifier(table, "table")
        if schema:
            validate_identifier(schema, "schema")
        sch = schema or "public"
        sql = (
            "SELECT COALESCE(pg_catalog.obj_description("
            f" (quote_ident('{sch}') || '.' || quote_ident('{table}'))::regclass, "
            " 'pg_class'), '')"
        )
        result = self.query(conn, sql)
        if not result.rows or not result.rows[0]:
            return None
        v = result.rows[0][0].strip()
        return v or None

    def qualified_table_name(
        self, table: str, *, schema: str | None = None,
    ) -> str:
        """Return ``schema.table`` with each part quoted as needed.

        *schema* defaults to ``public``.
        """
        sch = schema or "public"
        return f"{self.quote_identifier(sch)}.{self.quote_identifier(table)}"

    def quote_identifier(self, name: str) -> str:
        """Return *name* in a form that is safe to embed as an identifier.

        A name is returned bare only when it is non-empty, lower-case, made
        solely of alphanumerics and underscores, does not start with a digit
        and is not a PostgreSQL reserved word. Anything else is wrapped in
        double quotes with embedded double quotes doubled.
        """
        is_plain = (
            bool(name)
            and name.islower()
            and name.replace("_", "").isalnum()
            and not name[0].isdigit()
        )
        # Reserved words such as "order" or "user" look plain but are not
        # valid bare column/table names, so they must take the quoted path.
        if is_plain and name not in self._RESERVED_WORDS:
            return name
        return '"' + name.replace('"', '""') + '"'

    def default_expression(self, type_raw: str, column_name: str) -> str:
        """Return the select expression for a column: just the quoted name.

        *type_raw* is accepted for interface compatibility and is not used.
        """
        # PG doesn't pad char types and has honest NULLs — no shaping needed.
        return self.quote_identifier(column_name)

    def map_type(self, type_raw: str) -> str:
        """Map a source type string to the destination type.

        The base name (lower-cased, text from the first ``(`` dropped) is
        looked up in ``_TYPE_MAP``; unknown types become ``text``. For types
        that map to ``numeric`` the original ``(...)`` suffix is carried over.
        """
        base = _base(type_raw)
        mapped = _TYPE_MAP.get(base, "text")
        if mapped == "numeric" and "(" in type_raw:
            return "numeric" + type_raw[type_raw.index("("):]
        return mapped

    def build_create_table_sql(self, qualified_table: str,
                               columns: list[dict]) -> str:
        """Build a ``CREATE TABLE IF NOT EXISTS`` statement.

        Each entry of *columns* must have a ``dest_name`` key and may have a
        ``dest_type`` key; a missing or empty ``dest_type`` becomes ``text``.
        *qualified_table* and the type text are embedded as given.

        Raises:
            ValueError: if *columns* is empty, or a ``dest_type`` is
                whitespace only.
        """
        if not columns:
            raise ValueError("no columns provided for CREATE TABLE")
        lines = []
        for c in columns:
            name = c["dest_name"]
            validate_identifier(name, "dest column name")
            dtype = (c.get("dest_type") or "text").strip()
            if not dtype:
                raise ValueError(f"column {name!r} has no dest_type")
            lines.append(f" {self.quote_identifier(name)} {dtype}")
        body = ",\n".join(lines)
        return f"CREATE TABLE IF NOT EXISTS {qualified_table} (\n{body}\n);"
def _format_type(dtype: str, length: str, prec: str, scale: str) -> str:
base = dtype.lower()
if base in ("numeric", "decimal") and prec:
return f"{base}({prec},{scale or '0'})"
if base in ("character varying", "character") and length:
return f"{base}({length})"
return base