- ADR-0001: Multi-Agent Coworking Architecture with LLM Gateway Orchestrator - ADR-0002: Tier Assignment Strategy for Model Selection (cost-first escalation) - ADR-0003: Confidence Gate Thresholds & Learning Cycle Intervals (6h/12h/24h cycles) - ADR-0004: External Provider Fallback Chain Ordering (Cerebras → Groq → Mistral) - Enhanced client SDK: Offline Ollama fallback, health checks, exponential backoff retry - Integration tests: claude-code-integration.test.ts (14 test cases) - PHASE_2F_DEPLOYMENT.md: Pre-deployment checklist, automated deploy, rollback plan - Post-deployment verification procedures for health, client fallback, metrics
536 lines
18 KiB
Python
536 lines
18 KiB
Python
#!/usr/bin/env python3
"""
train_blog_v7.py — fo-blog-v7 standalone training script

Two-phase training:
  Phase 1 (SFT): LoRA fine-tuning on curated blog examples with anchored system prompt
  Phase 2 (DPO): Preference learning from (chosen, rejected) blog pairs

Key improvements over v6:
  - 350+ diverse training examples (250 generated + 100 RIPE/APNIC/NOG)
  - DPO phase teaches the model to prefer structured, length-constrained output
  - System prompt and user input both enforce 700-1000w and mandatory structure
  - LoRA r=32 (doubled from v6's 16) for stronger signal with more diverse data

Usage:
  cd packages/fine-tuner
  python3 scripts/train_blog_v7.py --phase sft
  python3 scripts/train_blog_v7.py --phase dpo
  python3 scripts/train_blog_v7.py --phase sft --epochs 3
  python3 scripts/train_blog_v7.py --phase sft --dry-run
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
import re
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
# Configure the root logger once at import time so every message emitted by
# this script is INFO-level or above and prefixed with a wall-clock timestamp.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
)
# Module-level logger used by all functions in this file.
logger = logging.getLogger(__name__)
|
|
|
|
# ─── Paths ────────────────────────────────────────────────────────────────────

# Directory two levels above this file; adapter output directories live under it.
FINE_TUNER_ROOT = Path(__file__).parent.parent
# Training data is read from a fixed directory in the current user's home.
TRAINING_DATA_DIR = Path.home() / "transceiver-training-data"

# Primary SFT datasets (JSONL). Missing files are skipped with a warning by
# load_sft_examples().
SFT_DATA_FILES = [
    TRAINING_DATA_DIR / "v7-generated-sft.jsonl",
    TRAINING_DATA_DIR / "v7-ripe-apnic-sft.jsonl",
]

# Optional: include high-quality examples from v5/v6 dataset (700w+ only)
# NOTE(review): load_sft_examples() keeps supplemental posts of 200+ words,
# not 700+ — confirm which threshold is intended.
SUPPLEMENTAL_FILES = [
    TRAINING_DATA_DIR / "blog-fichtmueller-posts.jsonl",
]

# Preference pairs (prompt / chosen / rejected) for the DPO phase.
DPO_DATA_FILE = TRAINING_DATA_DIR / "v7-dpo-pairs.jsonl"

# Hugging Face model id of the base model that both phases fine-tune.
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
# Output directories for the SFT and DPO phases respectively.
ADAPTER_OUTPUT_DIR = FINE_TUNER_ROOT / "adapters" / "fo-blog-v7"
DPO_OUTPUT_DIR = FINE_TUNER_ROOT / "adapters" / "fo-blog-v7-dpo"

# ChatML layout used to render one (system, user, assistant) training example
# into a single text string.
CHATML_TEMPLATE = (
    "<|im_start|>system\n{system}<|im_end|>\n"
    "<|im_start|>user\n{user}<|im_end|>\n"
    "<|im_start|>assistant\n{assistant}<|im_end|>"
)
|
|
|
|
|
|
# ─── Data loading ─────────────────────────────────────────────────────────────
|
|
|
|
def _read_jsonl_records(fpath: Path) -> tuple[list[dict], int]:
    """Parse a JSONL file into a list of JSON objects.

    Blank lines are ignored. Lines that are not valid JSON, or whose top-level
    value is not an object, are counted rather than raised.

    Returns:
        (records, skipped) where ``skipped`` is the number of rejected lines.
    """
    records: list[dict] = []
    skipped = 0
    with open(fpath, encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            try:
                item = json.loads(line)
            except json.JSONDecodeError:
                skipped += 1
                continue
            if not isinstance(item, dict):
                # e.g. a bare list or string: calling .get() on it would crash.
                skipped += 1
                continue
            records.append(item)
    return records, skipped


def _as_clean_text(value: object) -> str:
    """Return ``value`` stripped if it is a string, otherwise an empty string."""
    return value.strip() if isinstance(value, str) else ""


def load_sft_examples(
    min_output_words: int = 400,
    max_output_words: int = 1500,
) -> tuple[list[dict], list[dict]]:
    """
    Load and validate SFT training examples.

    Reads every existing file in SFT_DATA_FILES and SUPPLEMENTAL_FILES, then
    shuffles deterministically and splits 90/10.

    Args:
        min_output_words: Primary examples whose output has fewer words are
            dropped.
        max_output_words: Primary examples whose output has more words are
            truncated (not dropped) to at most ``min(1100, max_output_words)``
            words.

    Returns:
        (train_examples, val_examples), each a list of
        {system_prompt, input_text, output_text} dicts. ``val_examples`` may
        be empty when very few examples are available.

    Raises:
        ValueError: If no usable example was found in any file.
    """
    # Over-long outputs are cut well below the limit; never above the limit
    # itself when a caller passes a small max_output_words.
    truncate_target = min(1100, max_output_words)
    # Supplemental posts use their own, more lenient minimum length.
    supplemental_min_words = 200

    all_examples: list[dict] = []

    for fpath in SFT_DATA_FILES:
        if not fpath.exists():
            logger.warning("SFT data file not found (skipping): %s", fpath)
            continue
        count_before = len(all_examples)
        records, skipped = _read_jsonl_records(fpath)
        for item in records:
            system = _as_clean_text(item.get("system_prompt"))
            user = _as_clean_text(item.get("input_text"))
            output = _as_clean_text(item.get("output_text"))

            # All three fields are mandatory for primary data.
            if not (system and user and output):
                continue

            n_words = len(output.split())
            if n_words < min_output_words:
                continue
            if n_words > max_output_words:
                output = _truncate_to_words(output, truncate_target)

            all_examples.append({
                "system_prompt": system,
                "input_text": user,
                "output_text": output,
            })
        logger.info(" Loaded %d examples from %s", len(all_examples) - count_before, fpath.name)
        if skipped:
            logger.warning("Skipped %d malformed line(s) in %s", skipped, fpath.name)

    # Load supplemental (real blog posts) — these are gold standard, include all
    # that pass the lenient length check; missing prompts get generic fallbacks.
    for fpath in SUPPLEMENTAL_FILES:
        if not fpath.exists():
            continue
        count_before = len(all_examples)
        records, skipped = _read_jsonl_records(fpath)
        for item in records:
            output = _as_clean_text(item.get("output_text"))
            if len(output.split()) < supplemental_min_words:
                continue
            all_examples.append({
                # `or` also covers keys that are present but empty / non-string.
                "system_prompt": _as_clean_text(item.get("system_prompt"))
                or "You are an expert technical writer for optical networking.",
                "input_text": _as_clean_text(item.get("input_text"))
                or "Write a technical blog post.",
                "output_text": output,
            })
        logger.info(" Supplemental: %d examples from %s", len(all_examples) - count_before, fpath.name)
        if skipped:
            logger.warning("Skipped %d malformed line(s) in %s", skipped, fpath.name)

    if not all_examples:
        raise ValueError(
            "No training examples found! Run generate_v7_data.py first:\n"
            " python3 scripts/generate_v7_data.py"
        )

    logger.info("Total SFT examples: %d", len(all_examples))

    # Shuffle and split 90/10. A private Random(42) yields the same order as
    # seeding the global generator with 42, without reseeding it for
    # everything else in the process.
    random.Random(42).shuffle(all_examples)
    split = max(1, int(len(all_examples) * 0.9))
    train = all_examples[:split]
    val = all_examples[split:]

    logger.info("Train: %d | Val: %d", len(train), len(val))
    return train, val
|
|
|
|
|
|
def _truncate_to_words(text: str, max_words: int) -> str:
|
|
"""Truncate text at a sentence boundary near max_words."""
|
|
words = text.split()
|
|
if len(words) <= max_words:
|
|
return text
|
|
# Find sentence boundary near max_words
|
|
partial = " ".join(words[:max_words])
|
|
# Find last sentence end
|
|
match = re.search(r"[.!?]\s*$", partial)
|
|
if match:
|
|
return partial[:match.end()].strip()
|
|
return partial + "."
|
|
|
|
|
|
def load_dpo_examples() -> list[dict]:
    """Load DPO (chosen/rejected) training pairs.

    Reads DPO_DATA_FILE as JSONL and keeps every JSON object that has
    non-empty ``prompt``, ``chosen`` and ``rejected`` values. Unusable lines
    are skipped and reported in a single warning.

    Returns:
        The list of kept records (the parsed JSON objects, unmodified).

    Raises:
        FileNotFoundError: If DPO_DATA_FILE does not exist.
        ValueError: If the file contains no usable pair.
    """
    if not DPO_DATA_FILE.exists():
        raise FileNotFoundError(
            f"DPO data not found: {DPO_DATA_FILE}\n"
            "Run: python3 scripts/generate_dpo_pairs.py"
        )

    examples: list[dict] = []
    malformed = 0   # not valid JSON, or not a JSON object
    incomplete = 0  # valid object but missing/empty prompt, chosen or rejected

    with open(DPO_DATA_FILE, encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            try:
                item = json.loads(line)
            except json.JSONDecodeError:
                malformed += 1
                continue
            if not isinstance(item, dict):
                malformed += 1
                continue
            if item.get("prompt") and item.get("chosen") and item.get("rejected"):
                examples.append(item)
            else:
                incomplete += 1

    if malformed or incomplete:
        logger.warning(
            "Skipped %d malformed and %d incomplete line(s) in %s",
            malformed, incomplete, DPO_DATA_FILE.name,
        )
    logger.info("Loaded %d DPO pairs from %s", len(examples), DPO_DATA_FILE.name)

    # Fail here with a clear message, like load_sft_examples() does, rather
    # than handing an empty dataset to the trainer.
    if not examples:
        raise ValueError(
            f"No valid DPO pairs found in {DPO_DATA_FILE}\n"
            "Run: python3 scripts/generate_dpo_pairs.py"
        )
    return examples
|
|
|
|
|
|
# ─── SFT Training ─────────────────────────────────────────────────────────────
|
|
|
|
def run_sft(
    num_epochs: int = 4,
    lora_r: int = 32,
    lora_alpha: int = 64,
    lora_dropout: float = 0.05,
    max_seq_length: int = 2048,
    batch_size: int = 1,
    gradient_accumulation: int = 8,
    learning_rate: float = 1.5e-4,
    dry_run: bool = False,
) -> dict:
    """Run Phase 1: LoRA SFT training.

    Loads the SFT examples, renders them with CHATML_TEMPLATE, attaches a LoRA
    adapter to BASE_MODEL, trains it with TRL's SFTTrainer and saves the
    adapter plus tokenizer to ``ADAPTER_OUTPUT_DIR / "adapter"``.

    Args:
        num_epochs: Number of training epochs.
        lora_r: LoRA rank.
        lora_alpha: LoRA alpha.
        lora_dropout: LoRA dropout.
        max_seq_length: Passed to SFTConfig as ``max_length``.
        batch_size: Per-device batch size for both training and evaluation.
        gradient_accumulation: Gradient accumulation steps.
        learning_rate: Learning rate.
        dry_run: If True, only load the data and log what would be trained.

    Returns:
        A metrics dict. For a dry run: ``{"dry_run": True, "train_samples": n}``.
        Otherwise losses, runtime, sample counts, hyper-parameters, adapter
        path and device. ``eval_loss`` is ``-1.0`` when there was no
        validation set.
    """
    logger.info("=" * 60)
    logger.info(" fo-blog-v7 Phase 1: SFT Training")
    logger.info("=" * 60)

    train_examples, val_examples = load_sft_examples()

    if dry_run:
        logger.info("DRY RUN: Would train on %d examples for %d epochs", len(train_examples), num_epochs)
        logger.info(" base_model=%s, lora_r=%d, lr=%s", BASE_MODEL, lora_r, learning_rate)
        return {"dry_run": True, "train_samples": len(train_examples)}

    # The ML stack is imported only after the dry-run return, so a dry run
    # does not need these packages.
    import torch
    from datasets import Dataset
    from peft import LoraConfig, TaskType, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import SFTConfig, SFTTrainer

    # Detect device: MPS is checked first, then CUDA, then CPU.
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    logger.info("Device: %s", device)

    # Prepare dataset
    def format_example(ex: dict) -> dict:
        """Render one example into a single ChatML-formatted ``text`` field."""
        text = CHATML_TEMPLATE.format(
            system=ex["system_prompt"],
            user=ex["input_text"],
            assistant=ex["output_text"],
        )
        return {"text": text}

    train_dataset = Dataset.from_list([format_example(ex) for ex in train_examples])
    # No eval dataset at all (rather than an empty one) when the split left
    # nothing for validation.
    eval_dataset = Dataset.from_list([format_example(ex) for ex in val_examples]) if val_examples else None
    has_eval = eval_dataset is not None

    logger.info("Dataset: %d train, %d eval", len(train_dataset), len(eval_dataset) if has_eval else 0)

    # Load model
    logger.info("Loading tokenizer and model: %s", BASE_MODEL)
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True, padding_side="right")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): the MPS/CPU branches pass `dtype=` while the CUDA branch
    # passes `torch_dtype=`. Which spelling is accepted depends on the
    # installed transformers version — confirm and unify.
    model_kwargs = {
        "trust_remote_code": True,
    }
    if device == "mps":
        model_kwargs["dtype"] = torch.float16
    elif device == "cuda":
        model_kwargs["torch_dtype"] = torch.bfloat16
        model_kwargs["device_map"] = "auto"
    else:
        model_kwargs["dtype"] = torch.float32

    model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, **model_kwargs)
    model.config.use_cache = False

    # Apply LoRA to the attention and MLP projection layers.
    lora_config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        inference_mode=False,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # On CUDA placement was already requested via device_map="auto"; on the
    # other devices the model is moved explicitly.
    if device in ("mps", "cpu"):
        model = model.to(device)

    # Training args
    ADAPTER_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    eval_strategy = "steps" if has_eval else "no"

    training_args = SFTConfig(
        output_dir=str(ADAPTER_OUTPUT_DIR),
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation,
        learning_rate=learning_rate,
        warmup_ratio=0.1,
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        eval_strategy=eval_strategy,
        eval_steps=50 if eval_strategy == "steps" else None,
        save_strategy="steps",
        save_steps=50,
        # Best-checkpoint selection is only enabled when evaluation runs.
        load_best_model_at_end=(eval_strategy == "steps"),
        metric_for_best_model="eval_loss" if eval_strategy == "steps" else None,
        greater_is_better=False,
        logging_steps=10,
        report_to="none",
        dataloader_num_workers=0,
        fp16=False,
        bf16=False,
        optim="adamw_torch",
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        remove_unused_columns=False,
        label_names=["labels"],
        dataset_text_field="text",
        max_length=max_seq_length,
        packing=False,
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
    )

    logger.info("Starting SFT training...")
    result = trainer.train()

    eval_metrics: dict = {}
    if has_eval:
        eval_metrics = trainer.evaluate()

    # Save adapter (and tokenizer) where run_dpo() expects to find them.
    adapter_path = ADAPTER_OUTPUT_DIR / "adapter"
    adapter_path.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(str(adapter_path))
    tokenizer.save_pretrained(str(adapter_path))
    logger.info("Saved SFT adapter to %s", adapter_path)

    metrics = {
        "phase": "sft",
        "train_loss": round(result.training_loss, 4),
        # -1.0 marks "no evaluation was run".
        "eval_loss": round(eval_metrics.get("eval_loss", -1.0), 4),
        "train_runtime_s": round(result.metrics.get("train_runtime", 0), 1),
        "train_samples": len(train_dataset),
        "val_samples": len(eval_dataset) if has_eval else 0,
        "epochs": num_epochs,
        "lora_r": lora_r,
        "adapter_path": str(adapter_path),
        "device": device,
    }
    logger.info("SFT metrics: %s", metrics)
    return metrics
|
|
|
|
|
|
# ─── DPO Training ─────────────────────────────────────────────────────────────
|
|
|
|
def run_dpo(
    num_epochs: int = 1,
    beta: float = 0.1,
    learning_rate: float = 5e-5,
    max_seq_length: int = 2048,
    batch_size: int = 1,
    gradient_accumulation: int = 4,
    dry_run: bool = False,
) -> dict:
    """Run Phase 2: DPO preference training.

    Loads the DPO pairs, loads BASE_MODEL with the SFT adapter saved by
    Phase 1 at ``ADAPTER_OUTPUT_DIR / "adapter"``, continues training that
    adapter with TRL's DPOTrainer and saves the result (plus tokenizer) to
    ``DPO_OUTPUT_DIR / "adapter"``.

    Args:
        num_epochs: Number of training epochs.
        beta: DPO beta, passed to DPOConfig.
        learning_rate: Learning rate.
        max_seq_length: Passed to DPOConfig as ``max_length``.
        batch_size: Per-device training batch size.
        gradient_accumulation: Gradient accumulation steps.
        dry_run: If True, only load the data and log what would be trained.

    Returns:
        A metrics dict. For a dry run: ``{"dry_run": True, "dpo_pairs": n}``.
        Otherwise training loss, runtime, pair count, adapter path and device.

    Raises:
        FileNotFoundError: If the SFT adapter directory does not exist (the
            DPO data loader may raise as well).
    """
    logger.info("=" * 60)
    logger.info(" fo-blog-v7 Phase 2: DPO Training")
    logger.info("=" * 60)

    dpo_examples = load_dpo_examples()

    if dry_run:
        logger.info("DRY RUN: Would train DPO on %d pairs", len(dpo_examples))
        return {"dry_run": True, "dpo_pairs": len(dpo_examples)}

    # The ML stack is imported only after the dry-run return, so a dry run
    # does not need these packages.
    import torch
    from datasets import Dataset
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import DPOConfig, DPOTrainer

    # Same preference order as the SFT phase: MPS, then CUDA, then CPU.
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    logger.info("Device: %s", device)

    sft_adapter_path = ADAPTER_OUTPUT_DIR / "adapter"
    if not sft_adapter_path.exists():
        raise FileNotFoundError(
            f"SFT adapter not found: {sft_adapter_path}\n"
            "Run Phase 1 first: python3 scripts/train_blog_v7.py --phase sft"
        )

    # Load base model + SFT adapter. The tokenizer comes from the adapter
    # directory, where Phase 1 saved it.
    logger.info("Loading base model with SFT adapter for DPO...")
    tokenizer = AutoTokenizer.from_pretrained(str(sft_adapter_path), trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if device == "mps":
        model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, dtype=torch.float16, trust_remote_code=True)
    else:
        model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True)

    # is_trainable=True is required: PeftModel.from_pretrained() loads adapters
    # frozen (inference mode) by default, which would leave DPO with no
    # trainable parameters.
    model = PeftModel.from_pretrained(model, str(sft_adapter_path), is_trainable=True)
    model.config.use_cache = False
    if device in ("mps", "cpu"):
        model = model.to(device)

    # DPO dataset: keep only the three columns the trainer consumes.
    dataset = Dataset.from_list([{
        "prompt": ex["prompt"],
        "chosen": ex["chosen"],
        "rejected": ex["rejected"],
    } for ex in dpo_examples])

    logger.info("DPO dataset: %d pairs", len(dataset))

    DPO_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    dpo_config = DPOConfig(
        output_dir=str(DPO_OUTPUT_DIR),
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation,
        learning_rate=learning_rate,
        beta=beta,
        max_length=max_seq_length,
        max_prompt_length=512,
        fp16=False,
        bf16=False,
        optim="adamw_torch",
        logging_steps=5,
        save_steps=50,
        report_to="none",
        dataloader_num_workers=0,
        remove_unused_columns=False,
        label_names=[],
        loss_type="sigmoid",
    )

    trainer = DPOTrainer(
        model=model,
        # No separate reference model is loaded.
        # NOTE(review): per the TRL docs, with a PEFT model and ref_model=None
        # the reference log-probs come from the model with its adapter
        # disabled, i.e. the base model rather than the SFT policy — confirm
        # this is the intended reference.
        ref_model=None,
        args=dpo_config,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    logger.info("Starting DPO training...")
    result = trainer.train()

    # Save DPO adapter
    dpo_adapter_path = DPO_OUTPUT_DIR / "adapter"
    model.save_pretrained(str(dpo_adapter_path))
    tokenizer.save_pretrained(str(dpo_adapter_path))
    logger.info("Saved DPO adapter to %s", dpo_adapter_path)

    metrics = {
        "phase": "dpo",
        "train_loss": round(result.training_loss, 4),
        "train_runtime_s": round(result.metrics.get("train_runtime", 0), 1),
        "dpo_pairs": len(dataset),
        "adapter_path": str(dpo_adapter_path),
        "device": device,
    }
    logger.info("DPO metrics: %s", metrics)
    return metrics
|
|
|
|
|
|
# ─── CLI ──────────────────────────────────────────────────────────────────────
|
|
|
|
def main(argv: Optional[list[str]] = None) -> None:
    """Command-line entry point: parse arguments and run the chosen phase.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour.

    Exits with argparse's usage error (status 2) when ``--epochs``,
    ``--lora-r`` or ``--lr`` is not positive.
    """
    parser = argparse.ArgumentParser(description="fo-blog-v7 training script")
    parser.add_argument("--phase", choices=["sft", "dpo"], required=True, help="Training phase")
    parser.add_argument("--epochs", type=int, default=None, help="Override epoch count")
    parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (SFT only)")
    parser.add_argument("--lr", type=float, default=None, help="Override learning rate")
    parser.add_argument("--dry-run", action="store_true", help="Validate data without training")
    args = parser.parse_args(argv)

    # Reject nonsensical overrides up front. Previously `--epochs 0` or
    # `--lr 0` were falsy and silently replaced by the phase default.
    if args.epochs is not None and args.epochs < 1:
        parser.error("--epochs must be a positive integer")
    if args.lr is not None and args.lr <= 0:
        parser.error("--lr must be a positive number")
    if args.lora_r < 1:
        parser.error("--lora-r must be a positive integer")

    if args.phase == "sft":
        metrics = run_sft(
            num_epochs=args.epochs if args.epochs is not None else 4,
            lora_r=args.lora_r,
            learning_rate=args.lr if args.lr is not None else 1.5e-4,
            dry_run=args.dry_run,
        )
    else:
        metrics = run_dpo(
            num_epochs=args.epochs if args.epochs is not None else 1,
            learning_rate=args.lr if args.lr is not None else 5e-5,
            dry_run=args.dry_run,
        )

    # Summary of whatever the phase returned.
    print("\n" + "=" * 60)
    print(f" fo-blog-v7 {args.phase.upper()} — DONE")
    print("=" * 60)
    for k, v in metrics.items():
        print(f" {k}: {v}")

    # Follow-up hints are only shown after a real (non-dry) run.
    if args.phase == "sft" and not args.dry_run:
        print()
        print(" Next steps:")
        print(" 1. Generate DPO pairs:")
        print(" python3 scripts/generate_dpo_pairs.py")
        print(" 2. Run DPO phase:")
        print(" python3 scripts/train_blog_v7.py --phase dpo")
        print(" 3. Convert + deploy:")
        print(" python3 scripts/merge_and_convert.py --version v7")
    elif args.phase == "dpo" and not args.dry_run:
        print()
        print(" Next steps:")
        print(" 1. Merge + GGUF convert:")
        print(" python3 scripts/merge_and_convert.py --version v7")
        print(" 2. Test in Ollama:")
        print(' ollama run fo-blog-v7 "Write a 700-1000w blog post about..."')
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|