Full v8 training pipeline for the optical networking blog model:
- train_blog_v8.py: SFT (LoRA r=64, 5 epochs) + DPO (2 epochs) on Qwen2.5-14B-Instruct.
  Fixed for trl 1.2.x: SFTConfig instead of TrainingArguments, processing_class= instead of tokenizer=, eval_strategy= instead of the deprecated evaluation_strategy=
- consolidate_v8_dataset.py: weighted merge of all data sources (820 effective SFT / 235 DPO)
- crawl_v8_sources.py: APNIC/RIPE Labs/potaroo/Cloudflare crawler with balanced div extraction
- process_v6_blogs.py: converts 101 real v6 TIP blog outputs into SFT + DPO pairs
- label_v7_quality.py: Claude-judged quality labels → v8 quality DPO pairs
- parse_real_posts.py: parses blog.fichtmueller.org Ghost CMS HTML → gold SFT records
- run_v8_pipeline.sh: autopilot (consolidate → SFT → DPO → GGUF → Ollama)
- blog-v8-training.yaml: training config reference

Dataset breakdown: 19 real posts ×3 + 196 v7-gen + 28 v6blogs ×2 + 135 external ×1.5
395 lines
14 KiB
Python
395 lines
14 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
consolidate_v8_dataset.py — Alle v8 Datenquellen zusammenführen
|
||
|
||
Merged alle JSONL-Quellen in ein einzelnes Training-Dataset:
|
||
Tier 1 (weight 3.0): Renes echte Blog-Posts (Gold Standard)
|
||
Tier 2 (weight 1.0): v7-generierte Blogs, RIPE/APNIC Ingest
|
||
Tier 3 (weight 1.5): Externe gecrawlte + umgeschriebene Posts
|
||
|
||
Für SFT-Training: Dupliziert High-Weight Beispiele entsprechend
|
||
Für DPO-Training: Merged alle DPO-Pair-Dateien
|
||
|
||
Output:
|
||
~/transceiver-training-data/v8-sft-merged.jsonl (SFT, gewichtet)
|
||
~/transceiver-training-data/v8-dpo-merged.jsonl (DPO, alle Pairs)
|
||
|
||
Usage:
|
||
python3 scripts/consolidate_v8_dataset.py
|
||
python3 scripts/consolidate_v8_dataset.py --no-weight # flat, no duplication
|
||
python3 scripts/consolidate_v8_dataset.py --stats-only # nur Statistiken
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import json
|
||
import math
|
||
import random
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
# Every input source is read from, and both merged outputs are written to,
# this directory.
DATA_DIR = Path.home().joinpath("transceiver-training-data")
OUTPUT_SFT = DATA_DIR.joinpath("v8-sft-merged.jsonl")
OUTPUT_DPO = DATA_DIR.joinpath("v8-dpo-merged.jsonl")

# Seed the module-level RNG at import time so the random sampling and the
# final shuffles performed by this script repeat identically between runs.
random.seed(42)
|
||
|
||
# ─── SFT Sources ──────────────────────────────────────────────────────────────
# One entry per JSONL file under DATA_DIR. Keys:
#   file              file name relative to DATA_DIR
#   weight            multiplier handed to the weighting step of the SFT merge
#   tier              quality tier, copied into each merged record's meta
#   description       human-readable note
#   optional          (only when set) a missing file is reported as a skip
#   needs_conversion  (only when set) try legacy-schema mapping while loading
SFT_SOURCES: list[dict[str, Any]] = [
    # Tier 1 — gold: Rene's real posts.
    dict(
        file="v8-real-posts-sft.jsonl",
        weight=3.0,
        tier=1,
        description="Rene's real blog posts (human written)",
    ),
    # Tier 2 — good: Claude-generated, validated topics.
    dict(
        file="v7-generated-sft.jsonl",
        weight=1.0,
        tier=2,
        description="v7 Claude-generated optical/networking blogs",
    ),
    dict(
        file="v7-ripe-apnic-sft.jsonl",
        weight=1.0,
        tier=2,
        description="RIPE/APNIC BGP & routing content (v7 ingest)",
    ),
    # Tier 2 — real v6 TIP outputs (good length, real style).
    dict(
        file="v8-v6blogs-sft.jsonl",
        weight=2.0,
        tier=2,
        description="Real fo-blog-v6 outputs (700-1100w, actual optical networking voice)",
        optional=True,
    ),
    # Tier 3 — external: crawled, then rewritten by Claude.
    dict(
        file="v8-external-sft.jsonl",
        weight=1.5,
        tier=3,
        description="External: APNIC Blog / RIPE Labs / potaroo.net / Cloudflare",
        optional=True,
    ),
    # Tier 2 — legacy pre-v7 curated datasets (lower priority; their schema
    # may differ, hence needs_conversion).
    dict(
        file="nanog-ripe-labs-content.jsonl",
        weight=0.5,
        tier=2,
        description="NANOG/RIPE Labs curated (pre-v7)",
        optional=True,
        needs_conversion=True,
    ),
    dict(
        file="rir-infrastructure-data.jsonl",
        weight=0.5,
        tier=2,
        description="RIR infrastructure data (pre-v7)",
        optional=True,
        needs_conversion=True,
    ),
]
|
||
|
||
# ─── DPO Sources ──────────────────────────────────────────────────────────────
# One entry per JSONL file of preference pairs under DATA_DIR. DPO pairs are
# merged without weighting; "optional" marks files whose absence is reported
# as a skip rather than as missing.
DPO_SOURCES: list[dict[str, Any]] = [
    {
        "file": "v7-dpo-pairs.jsonl",
        "description": "v7 DPO pairs (5 rejection strategies, synthetic)",
    },
    {
        "file": "v8-v6blogs-dpo.jsonl",
        "description": "Real v6 too-long posts as rejected + Claude-rewritten chosen (real failures)",
        "optional": True,
    },
    {
        "file": "v8-quality-dpo.jsonl",
        # NOTE(review): the v8 pipeline notes describe these labels as
        # Claude-judged (label_v7_quality.py), not human-scored — confirm.
        "description": "v8 real quality labels (Claude-judged good/bad scoring)",
        "optional": True,
    },
]
|
||
|
||
# ─── Required record fields ───────────────────────────────────────────────────
# Keys an SFT record must already carry to be accepted without conversion.
REQUIRED_SFT_FIELDS = set("system_prompt input_text output_text".split())
# Keys every DPO preference pair must carry.
REQUIRED_DPO_FIELDS = set("prompt chosen rejected".split())
|
||
|
||
|
||
def load_sft_file(path: Path, needs_conversion: bool = False) -> list[dict]:
    """Load SFT records from a JSONL file, optionally mapping legacy schemas.

    A line is kept as-is when it is a JSON object containing every key in
    ``REQUIRED_SFT_FIELDS``. Otherwise, when ``needs_conversion`` is true, the
    object is passed to ``_try_convert_legacy_sft`` and the result is kept if
    one is returned. Blank lines are ignored. Lines that are not valid JSON, or
    are valid JSON but not an object, are skipped and counted; a one-line
    warning with their line numbers is printed so the data loss is visible.

    Args:
        path: JSONL file to read (UTF-8).
        needs_conversion: Attempt legacy-schema mapping for objects that lack
            the standard SFT fields.

    Returns:
        The accepted records in file order. No quality filtering happens here.
    """
    records: list[dict] = []
    malformed_lines: list[int] = []

    with open(path, encoding="utf-8") as f:
        for lineno, raw in enumerate(f, 1):
            line = raw.strip()
            if not line:
                continue
            try:
                item = json.loads(line)
            except json.JSONDecodeError:
                malformed_lines.append(lineno)
                continue
            # Valid JSON can still be a list/string/number; .keys() below
            # would raise on those, so treat them as malformed too.
            if not isinstance(item, dict):
                malformed_lines.append(lineno)
                continue

            if REQUIRED_SFT_FIELDS.issubset(item.keys()):
                records.append(item)
                continue

            if not needs_conversion:
                continue

            # Try to map legacy schemas → standard
            converted = _try_convert_legacy_sft(item)
            if converted:
                records.append(converted)

    if malformed_lines:
        shown = ", ".join(str(n) for n in malformed_lines[:5])
        more = ", ..." if len(malformed_lines) > 5 else ""
        print(f" WARN {path}: skipped {len(malformed_lines)} malformed line(s) "
              f"(lines {shown}{more})")
    return records
|
||
|
||
|
||
def _try_convert_legacy_sft(item: dict) -> dict | None:
|
||
"""Try to map legacy JSONL formats to standard SFT schema."""
|
||
# Schema: {instruction, input, output}
|
||
if "instruction" in item and "output" in item:
|
||
return {
|
||
"system_prompt": item.get("instruction", ""),
|
||
"input_text": item.get("input", ""),
|
||
"output_text": item["output"],
|
||
"meta": item.get("meta", {}),
|
||
}
|
||
# Schema: {prompt, completion}
|
||
if "prompt" in item and "completion" in item:
|
||
return {
|
||
"system_prompt": "",
|
||
"input_text": item["prompt"],
|
||
"output_text": item["completion"],
|
||
"meta": item.get("meta", {}),
|
||
}
|
||
# Schema: {messages: [{role, content}]}
|
||
if "messages" in item:
|
||
msgs = item["messages"]
|
||
sys_msg = next((m["content"] for m in msgs if m.get("role") == "system"), "")
|
||
user_msg = next((m["content"] for m in msgs if m.get("role") == "user"), "")
|
||
asst_msg = next((m["content"] for m in msgs if m.get("role") == "assistant"), "")
|
||
if user_msg and asst_msg:
|
||
return {
|
||
"system_prompt": sys_msg,
|
||
"input_text": user_msg,
|
||
"output_text": asst_msg,
|
||
"meta": item.get("meta", {}),
|
||
}
|
||
return None
|
||
|
||
|
||
def validate_sft_record(record: dict, min_words: int = 200) -> bool:
    """Return True if an SFT record passes the quality gate.

    A record passes when ``output_text`` is a string with at least
    ``min_words`` whitespace-separated words and ``input_text`` is a
    non-blank string. Missing, ``null`` or otherwise non-string values fail
    the gate instead of raising.

    Args:
        record: Record in the standard SFT schema.
        min_words: Minimum number of words required in ``output_text``.

    Returns:
        True if the record should be kept, False otherwise.
    """
    output = record.get("output_text", "")
    if not isinstance(output, str) or len(output.split()) < min_words:
        return False
    prompt = record.get("input_text")
    # A whitespace-only prompt carries no request; treat it as missing.
    return isinstance(prompt, str) and bool(prompt.strip())
|
||
|
||
|
||
def duplicate_by_weight(
    records: list[dict],
    weight: float,
    rng: random.Random | None = None,
) -> list[dict]:
    """Resample ``records`` so their count approximates ``len(records) * weight``.

    * ``weight == 1.0``: ``records`` is returned unchanged (same list object).
    * ``weight > 1.0``: ``floor(weight)`` full copies, plus
      ``int(len(records) * frac)`` records sampled without replacement for the
      fractional part (e.g. 1.5 → every record once plus a random half).
    * ``0 < weight < 1.0``: a random subset of ``int(len(records) * weight)``
      records, i.e. the source is down-weighted. This matches the
      ``int(count * weight)`` "effective" figure that ``print_stats`` reports.
    * ``weight <= 0``: an empty list.

    Repeated entries are the same dict objects, not copies.

    Args:
        records: Records to resample.
        weight: Target multiplier for the number of records.
        rng: Random source for sampling; defaults to the module-level
            ``random`` functions.

    Returns:
        The resampled list.
    """
    if weight == 1.0:
        return records
    if weight <= 0 or not records:
        # Also guards negative weights, where floor() < 0 would still leave a
        # positive fractional remainder to sample from.
        return []

    sampler = rng if rng is not None else random

    # Integer duplications + randomly sampled fractional remainder. For
    # weight < 1 there are zero full copies and only the sample remains.
    full_copies = math.floor(weight)
    remainder = weight - full_copies
    result = records * full_copies

    n_extra = int(len(records) * remainder)
    if n_extra > 0:
        result.extend(sampler.sample(records, min(n_extra, len(records))))
    return result
|
||
|
||
|
||
def merge_sft(apply_weights: bool = True) -> dict[str, int]:
    """Merge every SFT source into ``OUTPUT_SFT``.

    For each entry in ``SFT_SOURCES``:
    1. Load the file, with legacy conversion if ``needs_conversion`` is set.
    2. Drop records that fail ``validate_sft_record``.
    3. Resample the survivors with ``duplicate_by_weight``, but only when
       ``apply_weights`` is true and the weight is not 1.0.
    4. Tag each record's ``meta`` with ``source_file`` and ``tier``, unless
       those keys are already present.

    All records are then shuffled and written as UTF-8 JSONL, overwriting
    ``OUTPUT_SFT``. A missing file is reported and skipped, even when the
    source is required; it does not abort the merge.

    Args:
        apply_weights: When False every source is taken once (flat merge).

    Returns:
        Source file name → number of valid records *before* weighting, for
        every source that was found.
    """
    stats: dict[str, int] = {}
    all_records: list[dict] = []

    for source in SFT_SOURCES:
        file_name = source["file"]
        path = DATA_DIR / file_name
        if not path.exists():
            if source.get("optional"):
                print(f" SKIP (optional, not found): {file_name}")
            else:
                print(f" MISSING (required): {file_name}")
            continue

        records = load_sft_file(path, source.get("needs_conversion", False))
        valid = [r for r in records if validate_sft_record(r)]
        invalid = len(records) - len(valid)
        if invalid:
            print(f" {file_name}: {len(records)} loaded, {invalid} dropped (quality)")

        # Tag with source metadata. Weighting only repeats or subsamples these
        # same dict objects, so tagging the valid records covers every output.
        for r in valid:
            # Source JSON may carry "meta": null (or another non-dict); replace
            # it so setdefault cannot fail.
            if not isinstance(r.get("meta"), dict):
                r["meta"] = {}
            r["meta"].setdefault("source_file", file_name)
            r["meta"].setdefault("tier", source["tier"])

        weight = source["weight"] if apply_weights else 1.0
        if weight != 1.0:
            weighted = duplicate_by_weight(valid, weight)
            print(f" Tier {source['tier']} | {file_name}: {len(valid)} → {len(weighted)} (×{weight})")
        else:
            weighted = valid
            print(f" Tier {source['tier']} | {file_name}: {len(valid)}")

        all_records.extend(weighted)
        stats[file_name] = len(valid)

    # Shuffle to mix tiers
    random.shuffle(all_records)

    OUTPUT_SFT.parent.mkdir(parents=True, exist_ok=True)
    with open(OUTPUT_SFT, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(r, ensure_ascii=False) + "\n" for r in all_records)

    print(f"\nSFT merged: {len(all_records)} examples → {OUTPUT_SFT}")
    return stats
|
||
|
||
|
||
def _is_valid_dpo_pair(item: object) -> bool:
    """Return True if ``item`` is a usable plain-text DPO preference pair."""
    if not isinstance(item, dict) or not REQUIRED_DPO_FIELDS.issubset(item.keys()):
        return False
    chosen, rejected = item["chosen"], item["rejected"]
    # The word-count checks need plain text; non-string values (e.g.
    # message lists) are rejected here instead of raising on .split().
    if not (isinstance(chosen, str) and isinstance(rejected, str)):
        return False
    # chosen must differ from rejected and be substantive (> 50 words);
    # rejected only needs to be non-trivial (> 10 words).
    return (chosen != rejected
            and len(chosen.split()) > 50
            and len(rejected.split()) > 10)


def _load_dpo_pairs(path: Path) -> tuple[list[dict], int]:
    """Read a DPO JSONL file and return (valid pairs, non-blank lines dropped)."""
    pairs: list[dict] = []
    dropped = 0
    with open(path, encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            try:
                item = json.loads(line)
            except json.JSONDecodeError:
                dropped += 1
                continue
            if _is_valid_dpo_pair(item):
                pairs.append(item)
            else:
                dropped += 1
    return pairs, dropped


def merge_dpo() -> dict[str, int]:
    """Merge every DPO source into ``OUTPUT_DPO``.

    Each existing file in ``DPO_SOURCES`` is read by ``_load_dpo_pairs``.
    Pairs that pass ``_is_valid_dpo_pair`` are pooled, shuffled, and written
    as UTF-8 JSONL, overwriting ``OUTPUT_DPO``. A missing file is reported and
    skipped, even when the source is required; it does not abort the merge.

    Returns:
        Source file name → number of valid pairs, for every source found.
    """
    stats: dict[str, int] = {}
    all_pairs: list[dict] = []

    for source in DPO_SOURCES:
        path = DATA_DIR / source["file"]
        if not path.exists():
            if source.get("optional"):
                print(f" SKIP (optional): {source['file']}")
            else:
                print(f" MISSING: {source['file']}")
            continue

        pairs, dropped = _load_dpo_pairs(path)
        if dropped:
            print(f" {source['file']}: {dropped} line(s) dropped (malformed or failed pair checks)")
        print(f" {source['file']}: {len(pairs)} valid pairs")
        all_pairs.extend(pairs)
        stats[source["file"]] = len(pairs)

    random.shuffle(all_pairs)

    OUTPUT_DPO.parent.mkdir(parents=True, exist_ok=True)
    with open(OUTPUT_DPO, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(p, ensure_ascii=False) + "\n" for p in all_pairs)

    print(f"\nDPO merged: {len(all_pairs)} pairs → {OUTPUT_DPO}")
    return stats
|
||
|
||
|
||
def _count_nonblank_lines(path: Path) -> int:
    """Return the number of non-blank lines in the UTF-8 text file at ``path``."""
    with open(path, encoding="utf-8") as f:
        return sum(1 for line in f if line.strip())


def _sft_file_summary(path: Path) -> tuple[int, int]:
    """Scan an SFT JSONL file once; return (non-blank lines, average word count).

    A record's word count is ``meta["word_count"]`` when present and truthy,
    otherwise the number of words in ``output_text``. Lines that cannot be
    parsed that way, or whose word count is not a number, still count as lines
    but are left out of the average.
    """
    count = 0
    word_counts: list[float] = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            count += 1
            try:
                r = json.loads(line)
                wc = r.get("meta", {}).get("word_count") or len(r.get("output_text", "").split())
            except (ValueError, AttributeError, TypeError):
                # ValueError covers json.JSONDecodeError; AttributeError and
                # TypeError cover non-object lines and non-dict/non-str fields.
                continue
            # A non-numeric word_count would make sum() below raise.
            if isinstance(wc, (int, float)) and wc:
                word_counts.append(wc)
    avg_wc = int(sum(word_counts) / len(word_counts)) if word_counts else 0
    return count, avg_wc


def print_stats() -> None:
    """Print per-source and merged-output statistics without merging anything.

    SFT "effective" counts are estimates: non-blank lines × source weight,
    before validation drops any records.
    """
    print("\n=== v8 Dataset Status ===\n")

    print("SFT Sources:")
    sft_total = 0
    for source in SFT_SOURCES:
        path = DATA_DIR / source["file"]
        if path.exists():
            count, avg_wc = _sft_file_summary(path)
            effective = int(count * source["weight"])
            sft_total += effective
            status = f"✓ {count:4d} examples (×{source['weight']} → {effective:4d} effective, avg {avg_wc}w)"
        else:
            status = "✗ NOT FOUND" + (" (optional)" if source.get("optional") else " ⚠ REQUIRED")
        print(f" [{source['tier']}] {source['file']}: {status}")

    print(f"\n Total effective SFT: {sft_total} examples")

    print("\nDPO Sources:")
    dpo_total = 0
    for source in DPO_SOURCES:
        path = DATA_DIR / source["file"]
        if path.exists():
            count = _count_nonblank_lines(path)
            dpo_total += count
            print(f" ✓ {source['file']}: {count} pairs")
        else:
            optional = " (optional)" if source.get("optional") else " ⚠ REQUIRED"
            print(f" ✗ {source['file']}: NOT FOUND{optional}")

    print(f"\n Total DPO: {dpo_total} pairs")

    # Merged files
    print("\nMerged Outputs:")
    for out in [OUTPUT_SFT, OUTPUT_DPO]:
        if out.exists():
            count = _count_nonblank_lines(out)
            size_mb = out.stat().st_size / 1_048_576
            print(f" ✓ {out.name}: {count} lines ({size_mb:.1f} MB)")
        else:
            print(f" ✗ {out.name}: not yet generated")
|
||
|
||
|
||
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point.

    Prints dataset statistics. Unless ``--stats-only`` is given, it then merges
    SFT and/or DPO data and prints the statistics again.

    Args:
        argv: Arguments to parse; ``None`` (the default) lets ``argparse`` read
            ``sys.argv[1:]``.

    Raises:
        SystemExit: On invalid arguments, including ``--sft-only`` combined
            with ``--dpo-only``, which together would select nothing to merge.
    """
    parser = argparse.ArgumentParser(description="Merge v8 training datasets")
    parser.add_argument("--no-weight", action="store_true",
                        help="Flat merge without duplication by weight")
    parser.add_argument("--stats-only", action="store_true",
                        help="Print statistics only, do not merge")
    # Both --*-only flags at once would skip both merges; let argparse reject it.
    scope = parser.add_mutually_exclusive_group()
    scope.add_argument("--sft-only", action="store_true", help="Only merge SFT")
    scope.add_argument("--dpo-only", action="store_true", help="Only merge DPO")
    args = parser.parse_args(argv)

    print_stats()
    if args.stats_only:
        return
    print()

    if not args.dpo_only:
        print("=== Merging SFT ===")
        merge_sft(apply_weights=not args.no_weight)

    if not args.sft_only:
        print("\n=== Merging DPO ===")
        merge_dpo()

    print("\n=== Final Stats ===")
    print_stats()


if __name__ == "__main__":
    main()
|