"""PostgreSQL driver (also used as a destination target)."""

from __future__ import annotations

from .base import (BrowseField, Driver, RemoteColumn, RemoteTable,
                   validate_identifier)

_TYPE_MAP = {
    # Mostly identity — PG is the usual destination target, so mapping a PG
    # source to PG dest is near-passthrough.
    "smallint": "smallint",
    "integer": "integer",
    "bigint": "bigint",
    "int": "integer",
    "int2": "smallint",
    "int4": "integer",
    "int8": "bigint",
    "numeric": "numeric",
    "decimal": "numeric",
    "real": "real",
    "double precision": "double precision",
    "float4": "real",
    "float8": "double precision",
    "text": "text",
    "varchar": "text",
    "char": "text",
    "bpchar": "text",
    "character varying": "text",
    "character": "text",
    "date": "date",
    "timestamp": "timestamp",
    "timestamp without time zone": "timestamp",
    "timestamp with time zone": "timestamptz",
    "timestamptz": "timestamptz",
    "time": "time",
    # information_schema.columns.data_type spells time types out in full;
    # without these entries time columns fell through to "text".
    "time without time zone": "time",
    "time with time zone": "timetz",
    "timetz": "timetz",
    "interval": "interval",
    "boolean": "boolean",
    "bool": "boolean",
    "bytea": "bytea",
    "uuid": "uuid",
    "json": "json",
    "jsonb": "jsonb",
}

# Key words PostgreSQL classifies as "reserved" or "reserved (can be function
# or type)". None of them is accepted as a bare table/column name, so
# quote_identifier must always quote them even though they look like plain
# lower-case identifiers.
_RESERVED_WORDS = frozenset({
    "all", "analyse", "analyze", "and", "any", "array", "as", "asc",
    "asymmetric", "authorization", "binary", "both", "case", "cast",
    "check", "collate", "collation", "column", "concurrently",
    "constraint", "create", "cross", "current_catalog", "current_date",
    "current_role", "current_schema", "current_time", "current_timestamp",
    "current_user", "default", "deferrable", "desc", "distinct", "do",
    "else", "end", "except", "false", "fetch", "for", "foreign", "freeze",
    "from", "full", "grant", "group", "having", "ilike", "in", "initially",
    "inner", "intersect", "into", "is", "isnull", "join", "lateral",
    "leading", "left", "like", "limit", "localtime", "localtimestamp",
    "natural", "not", "notnull", "null", "offset", "on", "only", "or",
    "order", "outer", "overlaps", "placing", "primary", "references",
    "returning", "right", "select", "session_user", "similar", "some",
    "symmetric", "system_user", "table", "tablesample", "then", "to",
    "trailing", "true", "union", "unique", "user", "using", "variadic",
    "verbose", "when", "where", "window", "with",
})


def _base(type_raw: str) -> str:
    """Return the lower-cased type name that precedes any "(" modifier."""
    return type_raw.lower().split("(", 1)[0].strip()


def _normalize_type(type_raw: str) -> str:
    """Lower-case *type_raw*, drop every "(...)" modifier, collapse spaces.

    Unlike _base, text after the modifier is kept, so
    "timestamp(3) with time zone" becomes "timestamp with time zone".
    An unbalanced "(" drops everything after it, as _base does.
    """
    text = type_raw.lower()
    while "(" in text:
        head, _, rest = text.partition("(")
        _, _, tail = rest.partition(")")
        text = f"{head} {tail}"
    return " ".join(text.split())


class PGDriver(Driver):
    """PostgreSQL driver: catalog browsing, type mapping and DDL generation."""

    kind = "pg"
    label = "PostgreSQL"

    def browse_fields(self) -> list[BrowseField]:
        """Return the browse filters this driver accepts (an optional schema)."""
        return [
            BrowseField(name="schema", label="Schema", required=False,
                        default="public"),
        ]

    def list_schemas(self, conn, **_) -> list[str]:
        """Return user schema names in sorted order.

        pg_catalog, information_schema and every schema whose name starts
        with ``pg_`` are excluded. Extra keyword arguments are ignored.
        """
        result = self.query(
            conn,
            "SELECT schema_name FROM information_schema.schemata "
            "WHERE schema_name NOT IN ('pg_catalog','information_schema') "
            "AND schema_name NOT LIKE 'pg\\_%' ESCAPE '\\' "
            "ORDER BY schema_name")
        # Result rows are indexed and .strip()ped as strings throughout this
        # driver; empty rows / empty first values are skipped.
        return [r[0].strip() for r in result.rows if r and r[0]]

    def list_tables(self, conn, *,
                    schema: str | None = None) -> list[RemoteTable]:
        """Return tables and views outside the system schemas.

        Args:
            conn: Connection handle passed through to ``self.query``.
            schema: If given, only this schema is listed. It is checked with
                validate_identifier because it is interpolated into the SQL.

        Returns:
            One RemoteTable per row; ``kind`` is "view" when
            information_schema reports table_type VIEW, otherwise "table".
        """
        where = ["table_schema NOT IN ('pg_catalog','information_schema')"]
        if schema:
            validate_identifier(schema, "schema")
            where.append(f"table_schema = '{schema}'")
        sql = (
            "SELECT table_schema, table_name, table_type "
            "FROM information_schema.tables "
            f"WHERE {' AND '.join(where)} "
            "ORDER BY table_schema, table_name"
        )
        result = self.query(conn, sql)
        tables: list[RemoteTable] = []
        for row in result.rows:
            if len(row) < 3:
                continue  # malformed / short row
            sch, name, ttype = row[0].strip(), row[1].strip(), row[2].strip()
            kind = "view" if ttype.upper() == "VIEW" else "table"
            tables.append(RemoteTable(
                schema=sch,
                name=name,
                kind=kind,
                full_name=self.qualified_table_name(name, schema=sch),
            ))
        return tables

    def get_columns(
        self, conn, table: str, *, schema: str | None = None,
    ) -> list[RemoteColumn]:
        """Return the columns of *table* in ordinal order.

        Args:
            conn: Connection handle passed through to ``self.query``.
            table: Table name; validated before being interpolated into SQL.
            schema: Schema name (validated); defaults to "public".

        Returns:
            RemoteColumn entries whose ``type_raw`` is rebuilt by
            _format_type (numeric precision/scale, character lengths) and
            whose ``description`` is the column comment, or None if empty.
        """
        validate_identifier(table, "table")
        if schema:
            validate_identifier(schema, "schema")
        sch = schema or "public"
        where = [f"c.table_name = '{table}'", f"c.table_schema = '{sch}'"]
        sql = (
            "SELECT c.column_name, c.data_type, c.ordinal_position, c.is_nullable, "
            " c.character_maximum_length, c.numeric_precision, c.numeric_scale, "
            " COALESCE(pg_catalog.col_description("
            " (quote_ident(c.table_schema) || '.' || quote_ident(c.table_name))::regclass, "
            " c.ordinal_position::int), '') "
            "FROM information_schema.columns c "
            f"WHERE {' AND '.join(where)} "
            "ORDER BY c.ordinal_position"
        )
        result = self.query(conn, sql)
        cols: list[RemoteColumn] = []
        for row in result.rows:
            if len(row) < 4:
                continue  # malformed / short row
            name, dtype, pos, nullable = [c.strip() for c in row[:4]]
            # Trailing fields are optional; missing ones are treated as "".
            length = row[4].strip() if len(row) > 4 else ""
            prec = row[5].strip() if len(row) > 5 else ""
            scale = row[6].strip() if len(row) > 6 else ""
            desc = row[7].strip() if len(row) > 7 else ""
            type_raw = _format_type(dtype, length, prec, scale)
            cols.append(RemoteColumn(
                name=name,
                type_raw=type_raw,
                position=int(pos),
                nullable=(nullable.upper() == "YES"),
                description=desc or None,
            ))
        return cols

    def describe_table(
        self, conn, table: str, *, schema: str | None = None,
    ) -> str | None:
        """Return the table comment of *table*, or None when it has none.

        *table* and *schema* (default "public") are validated before being
        interpolated into SQL. The ``::regclass`` cast makes PostgreSQL raise
        an error if the table does not exist.
        """
        validate_identifier(table, "table")
        if schema:
            validate_identifier(schema, "schema")
        sch = schema or "public"
        sql = (
            "SELECT COALESCE(pg_catalog.obj_description("
            f" (quote_ident('{sch}') || '.' || quote_ident('{table}'))::regclass, "
            " 'pg_class'), '')"
        )
        result = self.query(conn, sql)
        if not result.rows or not result.rows[0]:
            return None
        v = result.rows[0][0].strip()
        return v or None

    def qualified_table_name(
        self, table: str, *, schema: str | None = None,
    ) -> str:
        """Return ``schema.table`` with each part quoted as needed.

        *schema* defaults to "public".
        """
        sch = schema or "public"
        return f"{self.quote_identifier(sch)}.{self.quote_identifier(table)}"

    def quote_identifier(self, name: str) -> str:
        """Return *name* as an SQL identifier, quoting only when required.

        A name is left bare only if it is lower-case, made of alphanumerics
        and underscores, does not start with a digit, and is not a reserved
        key word (bare reserved words such as ``order`` or ``user`` are
        syntax errors). Anything else is wrapped in double quotes with
        embedded double quotes doubled.
        """
        if (name
                and name.islower()
                and name.replace("_", "").isalnum()
                and not name[0].isdigit()
                and name not in _RESERVED_WORDS):
            return name
        return '"' + name.replace('"', '""') + '"'

    def default_expression(self, type_raw: str, column_name: str) -> str:
        """Return the select expression for *column_name* (just the name)."""
        # PG doesn't pad char types and has honest NULLs — no shaping needed.
        return self.quote_identifier(column_name)

    def map_type(self, type_raw: str) -> str:
        """Map a source PostgreSQL type to the destination PostgreSQL type.

        The full type name, with "(...)" modifiers removed, is looked up
        first. This keeps multi-word types such as
        "timestamp(3) with time zone" intact. Otherwise the name before the
        first "(" is used, and unknown types fall back to "text". Types that
        map to numeric keep their original "(...)" suffix.
        """
        mapped = _TYPE_MAP.get(_normalize_type(type_raw))
        if mapped is None:
            mapped = _TYPE_MAP.get(_base(type_raw), "text")
        if mapped == "numeric" and "(" in type_raw:
            return "numeric" + type_raw[type_raw.index("("):]
        return mapped

    def build_create_table_sql(self, qualified_table: str,
                               columns: list[dict]) -> str:
        """Build a ``CREATE TABLE IF NOT EXISTS`` statement.

        Args:
            qualified_table: Already-quoted table name, inserted verbatim.
            columns: Dicts with "dest_name" (validated and quoted) and an
                optional "dest_type" (defaults to "text"). NOTE(review):
                dest_type is inserted verbatim, so it is assumed to come
                from map_type or another trusted source.

        Raises:
            ValueError: If *columns* is empty, or a dest_type is blank.
        """
        if not columns:
            raise ValueError("no columns provided for CREATE TABLE")
        lines = []
        for c in columns:
            name = c["dest_name"]
            validate_identifier(name, "dest column name")
            dtype = (c.get("dest_type") or "text").strip()
            if not dtype:
                raise ValueError(f"column {name!r} has no dest_type")
            lines.append(f"  {self.quote_identifier(name)} {dtype}")
        body = ",\n".join(lines)
        return f"CREATE TABLE IF NOT EXISTS {qualified_table} (\n{body}\n);"


def _format_type(dtype: str, length: str, prec: str, scale: str) -> str:
    """Rebuild a type string from information_schema column fields.

    numeric/decimal get "(prec,scale)" when a precision is present (scale
    defaults to 0); character types get "(length)" when a length is present;
    everything else is the lower-cased data type alone.
    """
    base = dtype.lower()
    if base in ("numeric", "decimal") and prec:
        return f"{base}({prec},{scale or '0'})"
    if base in ("character varying", "character") and length:
        return f"{base}({length})"
    return base