"""
data_collector.py - Training data pipeline from PostgreSQL.

Pulls high-confidence approved outputs, human-edited preference pairs, and
low-confidence negatives from the llm_gateway database.

All queries are parameterised; no external data is trusted.
"""

from __future__ import annotations

import logging
import uuid
from typing import Optional

import psycopg2
import psycopg2.extras

logger = logging.getLogger(__name__)

# Default confidence floor for SFT positives. Shared by the collector and the
# corpus statistics so both agree on what counts as "available".
DEFAULT_MIN_CONFIDENCE = 7.5

# Row-eligibility predicates. They are module constants (never user input) and
# are shared between the collectors and get_corpus_stats so that the reported
# "available" counts cannot drift from what the collectors actually return.
_POSITIVE_PREDICATE = """
        status = 'approved'
    AND confidence_score >= %(min_confidence)s
    AND used_in_training IS NULL
    AND system_prompt IS NOT NULL
    AND input_text IS NOT NULL
    AND output_text IS NOT NULL
"""

_PREFERENCE_PREDICATE = """
        human_edited = TRUE
    AND edited_output IS NOT NULL
    AND edited_output <> output_text
    AND used_in_dpo_training IS NULL
    AND input_text IS NOT NULL
    AND output_text IS NOT NULL
"""

_MARK_SFT_SQL = """
    UPDATE learning_corpus
    SET used_in_training = %(run_id)s
    WHERE id = ANY(%(ids)s::uuid[])
"""

_MARK_DPO_SQL = """
    UPDATE learning_corpus
    SET used_in_dpo_training = %(run_id)s
    WHERE id = ANY(%(ids)s::uuid[])
"""


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

def _fetch_rows(
    conn: psycopg2.extensions.connection,
    sql: str,
    params: dict,
    task_type: Optional[str],
    order_by: str,
) -> list[dict]:
    """
    Run a SELECT against learning_corpus and return the rows as plain dicts.

    Args:
        conn: Open psycopg2 connection.
        sql: Query text that ends inside its WHERE clause, so further
            ``AND ...`` conditions can be appended.
        params: Named parameters for *sql*; must include ``limit``.
            The mapping is copied, never mutated.
        task_type: When not None, restrict the result to this task_type.
        order_by: Constant ORDER BY expression (never caller-supplied data).

    Returns:
        One dict per row, in query order.

    This helper neither commits nor rolls back; the caller owns the
    transaction.
    """
    query_params = dict(params)
    if task_type is not None:
        sql += " AND task_type = %(task_type)s"
        query_params["task_type"] = task_type
    sql += f" ORDER BY {order_by} LIMIT %(limit)s"

    with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
        cur.execute(sql, query_params)
        return [dict(row) for row in cur.fetchall()]


def _stamp_rows(
    conn: psycopg2.extensions.connection,
    sql: str,
    example_ids: list[str | uuid.UUID],
    run_id: str,
) -> int:
    """
    Execute one of the constant ``_MARK_*_SQL`` updates and commit it.

    IDs are passed as a bound array parameter and are never formatted into
    the SQL text. They are normalised with ``str()`` so ``uuid.UUID`` objects
    are accepted as well as strings.

    Returns:
        The number of rows the UPDATE affected (``cursor.rowcount``).

    Raises:
        psycopg2.Error: Re-raised after rolling back, so the connection is
            not left in an aborted transaction.
    """
    ids = [str(example_id) for example_id in example_ids]
    try:
        with conn.cursor() as cur:
            cur.execute(sql, {"run_id": run_id, "ids": ids})
            updated = cur.rowcount
        conn.commit()
    except psycopg2.Error:
        # A failed statement aborts the open transaction; roll back so the
        # caller can keep using the connection, then surface the error.
        conn.rollback()
        raise

    distinct_ids = len(set(ids))
    if updated != distinct_ids:
        logger.warning(
            "run_id=%s: %d distinct ids supplied but %d rows updated",
            run_id,
            distinct_ids,
            updated,
        )
    return updated


# ---------------------------------------------------------------------------
# Positive examples (SFT)
# ---------------------------------------------------------------------------

def collect_positive_examples(
    conn: psycopg2.extensions.connection,
    task_type: Optional[str],
    min_confidence: float = DEFAULT_MIN_CONFIDENCE,
    limit: int = 500,
) -> list[dict]:
    """
    Pull high-confidence, approved outputs from the learning_corpus table.

    Filters:
      - confidence_score >= min_confidence
      - status = 'approved'
      - used_in_training IS NULL  (not yet consumed)
      - system_prompt, input_text and output_text are all non-NULL
      - Optionally scoped to a single task_type

    Rows are ordered by confidence_score (highest first) with ``id`` as a
    tiebreaker, so the set cut off by ``limit`` is deterministic.

    Returns:
        Up to ``limit`` rows as dicts with keys id, task_type, input_text,
        output_text, system_prompt, confidence_score, created_at.
    """
    sql = f"""
        SELECT id, task_type, input_text, output_text,
               system_prompt, confidence_score, created_at
        FROM learning_corpus
        WHERE {_POSITIVE_PREDICATE}
    """
    result = _fetch_rows(
        conn,
        sql,
        {"min_confidence": min_confidence, "limit": limit},
        task_type,
        order_by="confidence_score DESC, id",
    )
    logger.info(
        "collect_positive_examples: task_type=%s, min_confidence=%.1f → %d rows",
        task_type,
        min_confidence,
        len(result),
    )
    return result


# ---------------------------------------------------------------------------
# Preference pairs (DPO)
# ---------------------------------------------------------------------------

def collect_preference_pairs(
    conn: psycopg2.extensions.connection,
    task_type: Optional[str],
    limit: int = 200,
) -> list[dict]:
    """
    Pull human-edited output pairs for DPO training.

    A valid preference pair requires:
      - human_edited = TRUE
      - edited_output IS NOT NULL and differs from output_text
      - used_in_dpo_training IS NULL
      - input_text and output_text are non-NULL

    Rows are ordered newest first (created_at) with ``id`` as a tiebreaker,
    so the set cut off by ``limit`` is deterministic.

    Returns:
        Up to ``limit`` rows as dicts with keys id, task_type, input_text,
        output_text, edited_output, system_prompt, created_at.
    """
    sql = f"""
        SELECT id, task_type, input_text, output_text,
               edited_output, system_prompt, created_at
        FROM learning_corpus
        WHERE {_PREFERENCE_PREDICATE}
    """
    result = _fetch_rows(
        conn,
        sql,
        {"limit": limit},
        task_type,
        order_by="created_at DESC, id",
    )
    logger.info(
        "collect_preference_pairs: task_type=%s → %d pairs",
        task_type,
        len(result),
    )
    return result


# ---------------------------------------------------------------------------
# Negative examples (optional, for debugging / contrastive studies)
# ---------------------------------------------------------------------------

def collect_negative_examples(
    conn: psycopg2.extensions.connection,
    task_type: Optional[str],
    max_confidence: float = 4.0,
    limit: int = 200,
) -> list[dict]:
    """
    Pull low-confidence outputs — useful for contrastive analysis and
    understanding failure modes, but NOT included in SFT datasets directly.

    Filters:
      - confidence_score <= max_confidence
      - status is 'rejected' or 'reviewed'
      - input_text and output_text are non-NULL
      - Optionally scoped to a single task_type

    Rows are ordered by confidence_score (lowest first) with ``id`` as a
    tiebreaker, so the set cut off by ``limit`` is deterministic.

    Returns:
        Up to ``limit`` rows as dicts with keys id, task_type, input_text,
        output_text, system_prompt, confidence_score, created_at.
    """
    sql = """
        SELECT id, task_type, input_text, output_text,
               system_prompt, confidence_score, created_at
        FROM learning_corpus
        WHERE confidence_score <= %(max_confidence)s
          AND status IN ('rejected', 'reviewed')
          AND input_text IS NOT NULL
          AND output_text IS NOT NULL
    """
    result = _fetch_rows(
        conn,
        sql,
        {"max_confidence": max_confidence, "limit": limit},
        task_type,
        order_by="confidence_score ASC, id",
    )
    logger.info(
        "collect_negative_examples: task_type=%s, max_confidence=%.1f → %d rows",
        task_type,
        max_confidence,
        len(result),
    )
    return result


# ---------------------------------------------------------------------------
# Mark consumed examples
# ---------------------------------------------------------------------------

def mark_as_used(
    conn: psycopg2.extensions.connection,
    example_ids: list[str],
    run_id: str,
) -> None:
    """
    Stamp consumed SFT examples with the run_id so they are not selected again.

    Uses a single parameterised UPDATE; never formats IDs into SQL strings.
    The update is committed on success and rolled back (then re-raised) on a
    database error. An empty or None ``example_ids`` is a no-op.

    The logged count is the number of rows actually updated, which can be
    lower than ``len(example_ids)`` if some ids do not exist.
    """
    if not example_ids:
        return

    updated = _stamp_rows(conn, _MARK_SFT_SQL, example_ids, run_id)
    logger.info(
        "mark_as_used: stamped %d examples with run_id=%s",
        updated,
        run_id,
    )


def mark_as_used_dpo(
    conn: psycopg2.extensions.connection,
    example_ids: list[str],
    run_id: str,
) -> None:
    """
    Stamp consumed DPO preference pairs with the run_id.

    Same contract as mark_as_used, but writes used_in_dpo_training: a single
    parameterised UPDATE, committed on success, rolled back and re-raised on
    a database error, and a no-op for empty or None ``example_ids``.
    """
    if not example_ids:
        return

    updated = _stamp_rows(conn, _MARK_DPO_SQL, example_ids, run_id)
    logger.info(
        "mark_as_used_dpo: stamped %d preference pairs with run_id=%s",
        updated,
        run_id,
    )


# ---------------------------------------------------------------------------
# Corpus statistics
# ---------------------------------------------------------------------------

def get_corpus_stats(
    conn: psycopg2.extensions.connection,
    min_confidence: float = DEFAULT_MIN_CONFIDENCE,
) -> dict:
    """
    Return a snapshot of the learning corpus useful for trigger decisions.

    Args:
        conn: Open psycopg2 connection.
        min_confidence: Confidence floor used for ``available_positive``;
            pass the same value given to collect_positive_examples.

    Returns:
        A dict of the form::

            {
                "by_task_type": {
                    <task_type>: {"total": int, "available_positive": int},
                    ...
                },
                "dpo_pairs_available": int,
            }

        ``total`` counts rows with non-NULL input_text and output_text.
        ``available_positive`` and ``dpo_pairs_available`` use the same
        predicates as collect_positive_examples and collect_preference_pairs
        respectively, so they count exactly the rows those collectors could
        return (ignoring ``limit``).
    """
    task_sql = f"""
        SELECT task_type,
               COUNT(*) AS total,
               COUNT(*) FILTER (WHERE {_POSITIVE_PREDICATE}) AS available_positive
        FROM learning_corpus
        WHERE input_text IS NOT NULL
          AND output_text IS NOT NULL
        GROUP BY task_type
    """
    dpo_sql = f"""
        SELECT COUNT(*) AS dpo_count
        FROM learning_corpus
        WHERE {_PREFERENCE_PREDICATE}
    """

    stats: dict = {"by_task_type": {}, "dpo_pairs_available": 0}

    with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
        cur.execute(task_sql, {"min_confidence": min_confidence})
        stats["by_task_type"] = {
            row["task_type"]: {
                "total": row["total"],
                "available_positive": row["available_positive"],
            }
            for row in cur.fetchall()
        }

        cur.execute(dpo_sql)
        row = cur.fetchone()
        stats["dpo_pairs_available"] = int(row["dpo_count"]) if row else 0

    return stats