Base provides a no-op default; drivers opt in by overriding. MSSQL scopes the lookup to a linked server / database when those qualifiers are supplied. Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
177 lines
6.9 KiB
Python
177 lines
6.9 KiB
Python
"""PostgreSQL driver (also used as a destination target)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from .base import (BrowseField, Driver, RemoteColumn, RemoteTable,
|
|
validate_identifier)
|
|
|
|
|
|
_TYPE_MAP = {
|
|
# Mostly identity — PG is the usual destination target, so mapping a PG
|
|
# source to PG dest is near-passthrough.
|
|
"smallint": "smallint", "integer": "integer", "bigint": "bigint",
|
|
"int": "integer", "int2": "smallint", "int4": "integer", "int8": "bigint",
|
|
"numeric": "numeric", "decimal": "numeric",
|
|
"real": "real", "double precision": "double precision",
|
|
"float4": "real", "float8": "double precision",
|
|
"text": "text", "varchar": "text", "char": "text", "bpchar": "text",
|
|
"character varying": "text", "character": "text",
|
|
"date": "date", "timestamp": "timestamp",
|
|
"timestamp without time zone": "timestamp",
|
|
"timestamp with time zone": "timestamptz", "timestamptz": "timestamptz",
|
|
"time": "time",
|
|
"boolean": "boolean", "bool": "boolean",
|
|
"bytea": "bytea",
|
|
"uuid": "uuid",
|
|
"json": "json", "jsonb": "jsonb",
|
|
}
|
|
|
|
|
|
def _base(type_raw: str) -> str:
|
|
return type_raw.lower().split("(", 1)[0].strip()
|
|
|
|
|
|
class PGDriver(Driver):
    """PostgreSQL implementation of the ``Driver`` interface.

    Metadata is read from ``information_schema`` and the ``pg_catalog``
    comment functions. Identifiers that are interpolated into SQL text are
    first passed through ``validate_identifier``.
    """

    kind = "pg"
    label = "PostgreSQL"

    # Key words PostgreSQL lists as reserved (including those that may only
    # be used as function or type names). A lower-case identifier that
    # collides with one of these must be double-quoted to be usable as a
    # column/table name; quoting a word that did not strictly need it is
    # harmless, so the set errs on the side of inclusion.
    _RESERVED_WORDS = frozenset({
        "all", "analyse", "analyze", "and", "any", "array", "as", "asc",
        "asymmetric", "authorization", "binary", "both", "case", "cast",
        "check", "collate", "collation", "column", "concurrently",
        "constraint", "create", "cross", "current_catalog", "current_date",
        "current_role", "current_schema", "current_time",
        "current_timestamp", "current_user", "default", "deferrable", "desc",
        "distinct", "do", "else", "end", "except", "false", "fetch", "for",
        "foreign", "freeze", "from", "full", "grant", "group", "having",
        "ilike", "in", "initially", "inner", "intersect", "into", "is",
        "isnull", "join", "lateral", "leading", "left", "like", "limit",
        "localtime", "localtimestamp", "natural", "not", "notnull", "null",
        "offset", "on", "only", "or", "order", "outer", "overlaps",
        "placing", "primary", "references", "returning", "right", "select",
        "session_user", "similar", "some", "symmetric", "system_user",
        "table", "tablesample", "then", "to", "trailing", "true", "union",
        "unique", "user", "using", "variadic", "verbose", "when", "where",
        "window", "with",
    })

    def browse_fields(self) -> list[BrowseField]:
        """Return the browse qualifiers: one optional ``schema`` field
        that defaults to ``public``."""
        return [
            BrowseField(name="schema", label="Schema",
                        required=False, default="public"),
        ]

    def list_schemas(self, conn, **_) -> list[str]:
        """Return schema names, sorted, excluding ``pg_catalog``,
        ``information_schema`` and any schema whose name starts with ``pg_``.

        Extra keyword arguments are accepted and ignored.
        """
        result = self.query(
            conn,
            "SELECT schema_name FROM information_schema.schemata "
            "WHERE schema_name NOT IN ('pg_catalog','information_schema') "
            "AND schema_name NOT LIKE 'pg\\_%' ESCAPE '\\' "
            "ORDER BY schema_name")
        # Skip empty rows / empty first cells rather than failing on them.
        return [r[0].strip() for r in result.rows if r and r[0]]

    def list_tables(self, conn, *, schema: str | None = None) -> list[RemoteTable]:
        """Return tables and views, ordered by schema then name.

        When *schema* is given it is validated and the listing is limited to
        that schema; otherwise every schema except ``pg_catalog`` and
        ``information_schema`` is included. Rows whose ``table_type`` is
        ``VIEW`` are reported with kind ``"view"``; everything else is
        reported as ``"table"``.
        """
        where = ["table_schema NOT IN ('pg_catalog','information_schema')"]
        if schema:
            # The value is spliced into the SQL text, so validate it first.
            validate_identifier(schema, "schema")
            where.append(f"table_schema = '{schema}'")
        sql = (
            "SELECT table_schema, table_name, table_type "
            "FROM information_schema.tables "
            f"WHERE {' AND '.join(where)} "
            "ORDER BY table_schema, table_name"
        )
        result = self.query(conn, sql)
        tables: list[RemoteTable] = []
        for row in result.rows:
            if len(row) < 3:
                # Short row: not a complete (schema, name, type) record.
                continue
            sch, name, ttype = row[0].strip(), row[1].strip(), row[2].strip()
            kind = "view" if ttype.upper() == "VIEW" else "table"
            tables.append(RemoteTable(
                schema=sch, name=name, kind=kind,
                full_name=self.qualified_table_name(name, schema=sch),
            ))
        return tables

    def get_columns(
        self, conn, table: str, *, schema: str | None = None,
    ) -> list[RemoteColumn]:
        """Return the columns of *table* in ordinal order.

        *table* (and *schema*, when given) are validated before being placed
        in the SQL text. A missing *schema* means ``public``. Each column's
        raw type is rebuilt by ``_format_type`` from the data type, length,
        precision and scale; an empty column comment becomes ``None``.
        """
        validate_identifier(table, "table")
        if schema:
            validate_identifier(schema, "schema")
        sch = schema or "public"
        where = [f"c.table_name = '{table}'", f"c.table_schema = '{sch}'"]
        sql = (
            "SELECT c.column_name, c.data_type, c.ordinal_position, c.is_nullable, "
            " c.character_maximum_length, c.numeric_precision, c.numeric_scale, "
            " COALESCE(pg_catalog.col_description("
            " (quote_ident(c.table_schema) || '.' || quote_ident(c.table_name))::regclass, "
            " c.ordinal_position::int), '') "
            "FROM information_schema.columns c "
            f"WHERE {' AND '.join(where)} "
            "ORDER BY c.ordinal_position"
        )
        result = self.query(conn, sql)
        cols: list[RemoteColumn] = []
        for row in result.rows:
            if len(row) < 4:
                # Name, type, position and nullability are all required.
                continue
            name, dtype, pos, nullable = [c.strip() for c in row[:4]]
            # Trailing cells are optional; treat a missing cell as empty.
            length = row[4].strip() if len(row) > 4 else ""
            prec = row[5].strip() if len(row) > 5 else ""
            scale = row[6].strip() if len(row) > 6 else ""
            desc = row[7].strip() if len(row) > 7 else ""
            type_raw = _format_type(dtype, length, prec, scale)
            cols.append(RemoteColumn(
                name=name, type_raw=type_raw,
                position=int(pos), nullable=(nullable.upper() == "YES"),
                description=desc or None,
            ))
        return cols

    def describe_table(
        self, conn, table: str, *, schema: str | None = None,
    ) -> str | None:
        """Return the comment attached to *table*, or ``None`` if there is
        no comment (or the query yields no usable row).

        *table* and *schema* are validated before use; a missing *schema*
        means ``public``.
        """
        validate_identifier(table, "table")
        if schema:
            validate_identifier(schema, "schema")
        sch = schema or "public"
        sql = (
            "SELECT COALESCE(pg_catalog.obj_description("
            f" (quote_ident('{sch}') || '.' || quote_ident('{table}'))::regclass, "
            " 'pg_class'), '')"
        )
        result = self.query(conn, sql)
        if not result.rows or not result.rows[0]:
            return None
        v = result.rows[0][0].strip()
        return v or None

    def qualified_table_name(
        self, table: str, *, schema: str | None = None,
    ) -> str:
        """Return ``schema.table`` with each part quoted as needed.

        A missing *schema* means ``public``.
        """
        sch = schema or "public"
        return f"{self.quote_identifier(sch)}.{self.quote_identifier(table)}"

    def quote_identifier(self, name: str) -> str:
        """Return *name* in a form that is safe to use as an SQL identifier.

        A name is returned bare only when it is non-empty, all lower-case,
        made of alphanumerics and underscores, does not start with a digit
        and is not a reserved key word. Anything else is wrapped in double
        quotes with embedded double quotes doubled.
        """
        is_plain = (
            bool(name)
            and name.islower()
            and name.replace("_", "").isalnum()
            and not name[0].isdigit()
        )
        # A bare reserved word (e.g. ``user``, ``order``) would either fail
        # to parse or be evaluated as a keyword instead of naming the column.
        if is_plain and name not in self._RESERVED_WORDS:
            return name
        return '"' + name.replace('"', '""') + '"'

    def default_expression(self, type_raw: str, column_name: str) -> str:
        """Return the select expression for a column: just the quoted
        column name. *type_raw* is not used."""
        # PG doesn't pad char types and has honest NULLs — no shaping needed.
        return self.quote_identifier(column_name)

    def map_type(self, type_raw: str) -> str:
        """Map a source type string to a destination type.

        The name before any ``(`` is looked up in ``_TYPE_MAP``; unknown
        names map to ``text``. For numeric results the original
        parenthesised suffix (precision/scale) is carried over.
        """
        base = _base(type_raw)
        mapped = _TYPE_MAP.get(base, "text")
        if mapped == "numeric" and "(" in type_raw:
            return "numeric" + type_raw[type_raw.index("("):]
        return mapped

    def build_create_table_sql(self, qualified_table: str,
                               columns: list[dict]) -> str:
        """Build a ``CREATE TABLE IF NOT EXISTS`` statement.

        Each entry of *columns* must have a ``dest_name`` key and may have a
        ``dest_type`` key; a missing or empty type defaults to ``text``.

        Raises:
            ValueError: if *columns* is empty, or a column's ``dest_type``
                is blank after stripping whitespace.
        """
        if not columns:
            raise ValueError("no columns provided for CREATE TABLE")
        lines = []
        for c in columns:
            name = c["dest_name"]
            validate_identifier(name, "dest column name")
            dtype = (c.get("dest_type") or "text").strip()
            if not dtype:
                # Only reachable for a whitespace-only dest_type.
                raise ValueError(f"column {name!r} has no dest_type")
            # NOTE(review): dtype and qualified_table are interpolated
            # verbatim — callers must supply trusted values.
            lines.append(f" {self.quote_identifier(name)} {dtype}")
        body = ",\n".join(lines)
        return f"CREATE TABLE IF NOT EXISTS {qualified_table} (\n{body}\n);"
|
|
|
|
|
|
def _format_type(dtype: str, length: str, prec: str, scale: str) -> str:
|
|
base = dtype.lower()
|
|
if base in ("numeric", "decimal") and prec:
|
|
return f"{base}({prec},{scale or '0'})"
|
|
if base in ("character varying", "character") and length:
|
|
return f"{base}({length})"
|
|
return base
|