llm-gateway/packages/fine-tuner/scripts/generate_dpo_pairs.py
Rene Fichtmueller 2ca77d0aee feat: Phase 2F — Multi-Agent Integration (ADRs + Client Fallback + Tests)
- ADR-0001: Multi-Agent Coworking Architecture with LLM Gateway Orchestrator
- ADR-0002: Tier Assignment Strategy for Model Selection (cost-first escalation)
- ADR-0003: Confidence Gate Thresholds & Learning Cycle Intervals (6h/12h/24h cycles)
- ADR-0004: External Provider Fallback Chain Ordering (Cerebras → Groq → Mistral)
- Enhanced client SDK: Offline Ollama fallback, health checks, exponential backoff retry
- Integration tests: claude-code-integration.test.ts (14 test cases)
- PHASE_2F_DEPLOYMENT.md: Pre-deployment checklist, automated deploy, rollback plan
- Post-deployment verification procedures for health, client fallback, metrics
2026-04-19 21:39:44 +02:00

338 lines
13 KiB
Python

#!/usr/bin/env python3
"""
generate_dpo_pairs.py — Generate DPO (Direct Preference Optimization) training pairs
Creates "rejected" (bad) versions of existing good blog posts, forming
{prompt, chosen, rejected} triplets for DPO fine-tuning.
Bad patterns to inject (matching fo-blog-v6 failure modes):
1. Missing intro — jumps directly into a section without hook
2. Too long — 2500+ words with repetition
3. Topic drift — switches to a generic topic after 1 paragraph
4. Repeated sections — copy-pastes paragraphs verbatim
5. No structure — wall of text without ## headers
Input:
~/transceiver-training-data/v7-generated-sft.jsonl (good outputs from generate_v7_data.py)
Output:
~/transceiver-training-data/v7-dpo-pairs.jsonl (prompt/chosen/rejected triplets)
Usage:
python3 scripts/generate_dpo_pairs.py
python3 scripts/generate_dpo_pairs.py --input v7-generated-sft.jsonl --max 100
"""
from __future__ import annotations
import argparse
import json
import logging
import random
import re
import subprocess
import time
from pathlib import Path
# Timestamped INFO-level log lines are used for all progress and error reporting.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)
# Data locations. The input is the SFT output of generate_v7_data.py; the
# output JSONL is opened in append mode by generate_dpo_pairs, so reruns add
# to it rather than overwrite it.
TRAINING_DATA_DIR = Path.home() / "transceiver-training-data"
DEFAULT_INPUT = TRAINING_DATA_DIR / "v7-generated-sft.jsonl"
OUTPUT_FILE = TRAINING_DATA_DIR / "v7-dpo-pairs.jsonl"
# Seeded at import time so the per-example random strategy choice in
# generate_dpo_pairs is reproducible for the same input file.
random.seed(42)
# ─── Rejection patterns ───────────────────────────────────────────────────────
# Strategy names handled by create_rejected_version(); generate_dpo_pairs picks
# one uniformly at random (random.choice) for each good example.
REJECTION_STRATEGIES = [
    "missing_intro",
    "too_long",
    "topic_drift",
    "repeated_sections",
    "no_structure",
]
# Generic "filler" content to inject for topic drift. It is also appended to the
# end of "too_long" rejections. Stripped by callers before use, so the leading
# and trailing newlines here are irrelevant.
DRIFT_FILLER = """
When considering optical transceivers in general, the most important factors are compatibility,
power consumption, and reach. Different vendors offer different solutions, and it's important
to evaluate your options carefully.
Generally speaking, transceivers come in many form factors: SFP, SFP+, QSFP, QSFP28, and others.
Each serves a different purpose in the network. The right choice depends on your specific requirements
including speed, distance, fiber type, and budget.
It's also worth noting that compatibility with your existing equipment is a key consideration.
Always check the vendor's compatibility matrix before purchasing any transceiver. Some vendors
are more open than others about supporting third-party optics.
In summary, careful evaluation of your requirements against available transceiver options will
lead to the best outcome for your organization's specific needs and constraints.
"""
def _drop_intro_paragraph(text: str) -> str:
    """Remove the first prose paragraph that appears before any ``##`` section.

    Leading blank lines and top-level ``# Title`` lines are kept in place. If a
    ``##`` section header is reached before any prose, nothing is removed and
    *text* is returned unchanged.
    """
    lines = text.split("\n")

    # Walk past leading blank lines and single-'#' title lines; they are kept.
    start = 0
    while start < len(lines):
        line = lines[start]
        is_title = line.startswith("#") and not line.startswith("##")
        if line.strip() and not is_title:
            break
        start += 1

    # No prose before the first '##' section (or no prose at all): no intro to drop.
    if start == len(lines) or lines[start].startswith("##"):
        return text

    # The intro paragraph runs until the next blank line or heading.
    end = start
    while end < len(lines) and lines[end].strip() and not lines[end].startswith("#"):
        end += 1
    return "\n".join(lines[:start] + lines[end:])


def _pad_with_repetition(text: str) -> str:
    """Make a post overly long by repeating its sections and appending filler.

    Section headings are preserved (the split is zero-width, just before each
    ``## `` header line). This keeps "too long" from also turning into the
    separate "no structure" failure mode.
    """
    chunks = re.split(r"^(?=## .)", text, flags=re.MULTILINE)
    if len(chunks) < 2:
        # No '##' sections at all: repeat the whole post instead.
        return text + "\n\n" + text + "\n\n" + text[:500]

    intro = chunks[0].strip()
    middle = "\n\n".join(chunk.strip() for chunk in chunks[1:])
    # Header-free, single-line excerpt so the "As we discussed above" sentence
    # does not splice a markdown heading into the middle of a line.
    excerpt = " ".join(re.sub(r"^## .+$", "", middle, flags=re.MULTILINE).split())[:300]

    parts = [
        intro,
        middle,
        "## Additional Considerations",
        middle,
        "## Further Analysis",
        "As we discussed above, " + excerpt,
        "## Summary and Recap",
        "To summarize what we have covered so far, it's important to reiterate the key points "
        "that were mentioned in the previous sections. As noted above, these considerations are "
        "critical to making the right decision for your network infrastructure.",
        DRIFT_FILLER.strip(),
    ]
    # Skip empty parts, e.g. an empty intro when the post starts with a header.
    return "\n\n".join(part for part in parts if part)


def create_rejected_version(
    good_output: str,
    strategy: str,
    topic: str,
) -> str:
    """
    Create a "bad" version of a blog post using the specified rejection strategy.

    Args:
        good_output: The good ("chosen") markdown post to degrade.
        strategy: One of "missing_intro", "too_long", "topic_drift",
            "repeated_sections", "no_structure".
        topic: Post topic. Not used by these deterministic transformations;
            kept for interface compatibility.

    Returns:
        The rejected (bad) text. For an unknown strategy, or when
        "missing_intro" finds no intro paragraph to remove, *good_output* is
        returned unchanged. The caller treats an unchanged result as a failed
        pair.
    """
    if strategy == "missing_intro":
        # Remove the opening hook so the post starts abruptly.
        return _drop_intro_paragraph(good_output) or good_output

    if strategy == "too_long":
        return _pad_with_repetition(good_output)

    if strategy == "topic_drift":
        # Keep only the first (on-topic) paragraph, then drift into generic filler.
        paragraphs = [p.strip() for p in good_output.split("\n\n") if p.strip()]
        if paragraphs:
            return paragraphs[0] + "\n\n" + DRIFT_FILLER.strip()
        return DRIFT_FILLER.strip()

    if strategy == "repeated_sections":
        # Duplicate the first '##' section verbatim right after the original.
        # Needs at least two sections so the first one has a clear end.
        headers = list(re.finditer(r"^## .+", good_output, re.MULTILINE))
        if len(headers) >= 2:
            repeat_start = headers[0].start()
            repeat_end = headers[1].start()
            repeated_section = good_output[repeat_start:repeat_end]
            return (
                good_output[:repeat_end]
                + "\n\n"
                + repeated_section
                + "\n\n"
                + good_output[repeat_end:]
            )
        # Fallback: repeat the full text.
        return good_output + "\n\n" + good_output

    if strategy == "no_structure":
        # Strip all markdown headings (levels 1-6), bullets and bold to get a wall of text.
        text = re.sub(r"^#{1,6} .+$", "", good_output, flags=re.MULTILINE)
        text = re.sub(r"^\s*[-*•]\s+", "", text, flags=re.MULTILINE)
        text = re.sub(r"\*\*(.+?)\*\*", r"\1", text)  # Remove bold
        text = re.sub(r"\n{3,}", "\n\n", text)
        return text.strip()

    return good_output  # Unknown strategy
def call_claude_for_bad_version(
system: str,
topic: str,
strategy: str,
timeout: int = 120,
) -> str | None:
"""
Alternative: Ask Claude to deliberately write a BAD version.
Use for strategies that are hard to create programmatically.
"""
bad_prompts = {
"missing_intro": (
f"Write a BAD blog post about '{topic}'. "
"DO NOT include an introduction or hook paragraph. "
"Start immediately with a technical section, skipping any context-setting. "
"The post should feel abrupt and confusing to readers who don't already know the topic."
),
"topic_drift": (
f"Write a blog post that STARTS about '{topic}' but after the first paragraph, "
"DRIFTS into generic optical transceiver advice unrelated to the specific topic. "
"The second half should be about general transceiver considerations, not the original topic."
),
"too_long_and_repetitive": (
f"Write a VERY LONG (2000+ words) blog post about '{topic}'. "
"Repeat the same information multiple times in different words. "
"Include obvious filler content. Repeat at least one entire section verbatim."
),
}
prompt = bad_prompts.get(strategy)
if not prompt:
return None
try:
result = subprocess.run(
["claude", "--print", "--system-prompt", system, "-p", prompt],
capture_output=True, text=True, timeout=timeout,
)
if result.returncode != 0 or not result.stdout.strip():
return None
return result.stdout.strip()
except Exception as exc:
logger.warning("claude error for bad version: %s", exc)
return None
def load_good_examples(input_file: Path, min_words: int = 400) -> list[dict]:
    """
    Load good SFT examples from a JSONL file.

    A record is kept only if it is a JSON object whose "output_text" is a
    string of at least *min_words* whitespace-separated words.

    Args:
        input_file: Path to the JSONL file (one JSON object per line).
        min_words: Minimum word count for "output_text" (default 400).

    Returns:
        The accepted records, in file order. Returns an empty list if the file
        does not exist.
    """
    if not input_file.exists():
        logger.error("Input file not found: %s", input_file)
        return []
    examples: list[dict] = []
    malformed = 0
    with open(input_file, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                item = json.loads(line)
            except json.JSONDecodeError:
                malformed += 1
                continue
            # Valid JSON that is not an object (e.g. a list) cannot be a record.
            if not isinstance(item, dict):
                malformed += 1
                continue
            output_text = item.get("output_text")
            # Non-string output_text would crash .split(); treat it as unusable.
            if isinstance(output_text, str) and len(output_text.split()) >= min_words:
                examples.append(item)
    if malformed:
        logger.warning("Skipped %d malformed line(s) in %s", malformed, input_file.name)
    logger.info("Loaded %d valid good examples from %s", len(examples), input_file.name)
    return examples
def generate_dpo_pairs(
    input_file: Path,
    max_pairs: int | None = None,
    output_file: Path | None = None,
) -> None:
    """
    Build DPO {prompt, chosen, rejected} records from good SFT examples.

    Each good example is the "chosen" response. The "rejected" response is
    derived from it by create_rejected_version() using a strategy picked at
    random from REJECTION_STRATEGIES. Records are appended to the output
    JSONL, so repeated runs accumulate pairs in the same file.

    Args:
        input_file: JSONL of good examples (see load_good_examples).
        max_pairs: If not None, process at most this many examples.
            0 processes none; negative values are treated as 0.
        output_file: Destination JSONL; defaults to OUTPUT_FILE.
    """
    output_file = OUTPUT_FILE if output_file is None else output_file
    output_file.parent.mkdir(parents=True, exist_ok=True)
    good_examples = load_good_examples(input_file)
    if not good_examples:
        logger.error("No good examples found — run generate_v7_data.py first")
        return
    if max_pairs is not None:
        # `is not None` so that --max 0 means "none" rather than "no limit".
        good_examples = good_examples[: max(max_pairs, 0)]
    total = len(good_examples)
    logger.info("Generating DPO pairs from %d good examples", total)
    stats = {"generated": 0, "failed": 0}
    # Distribution of strategies for pairs written in THIS run (the output file
    # is append-only and may contain records from earlier runs).
    strategy_counts: dict[str, int] = {}
    with open(output_file, "a", encoding="utf-8") as out_f:
        for i, item in enumerate(good_examples, start=1):
            system_prompt = item.get("system_prompt", "")
            input_text = item.get("input_text", "")
            good_output = item.get("output_text", "")
            # "meta" may be missing or null; an empty topic falls back to the input.
            meta = item.get("meta") or {}
            topic = str(meta.get("topic") or input_text[:60])
            # Pick a random rejection strategy for each pair.
            strategy = random.choice(REJECTION_STRATEGIES)
            logger.info("[%03d/%03d] DPO pair (%s): %s", i, total, strategy, topic[:50])
            # Every strategy in REJECTION_STRATEGIES is a deterministic text
            # transformation. call_claude_for_bad_version() is an optional
            # LLM-based alternative and is not used here.
            rejected = create_rejected_version(good_output, strategy, topic)
            if not rejected or rejected == good_output:
                logger.warning("[%03d] Could not create rejected version", i)
                stats["failed"] += 1
                continue
            # DPO format: prompt / chosen / rejected.
            # The prompt is the system + user turns in ChatML, ending where the
            # assistant turn would begin.
            dpo_record = {
                "prompt": f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{input_text}<|im_end|>\n",
                "chosen": good_output,
                "rejected": rejected,
                "meta": {
                    "topic": topic,
                    "rejection_strategy": strategy,
                    "chosen_words": len(good_output.split()),
                    "rejected_words": len(rejected.split()),
                    "category": meta.get("category", ""),
                    "dataset_version": "v7-dpo",
                },
            }
            out_f.write(json.dumps(dpo_record, ensure_ascii=False) + "\n")
            # Flush per record so progress survives an interrupted run.
            out_f.flush()
            stats["generated"] += 1
            strategy_counts[strategy] = strategy_counts.get(strategy, 0) + 1
    logger.info("DPO pairs done: generated=%d failed=%d", stats["generated"], stats["failed"])
    logger.info("Output: %s", output_file)
    logger.info("Strategy distribution: %s", strategy_counts)
def main() -> None:
    """Command-line entry point: read ``--input`` / ``--max`` and generate DPO pairs."""
    cli = argparse.ArgumentParser(
        description="Generate DPO pairs from good SFT examples",
    )
    input_help = f"Input JSONL with good examples (default: {DEFAULT_INPUT})"
    cli.add_argument("--input", type=Path, default=DEFAULT_INPUT, help=input_help)
    cli.add_argument(
        "--max", type=int, default=None, help="Maximum number of pairs to generate"
    )
    opts = cli.parse_args()
    generate_dpo_pairs(max_pairs=opts.max, input_file=opts.input)


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()