Rene Fichtmueller 2ca77d0aee feat: Phase 2F — Multi-Agent Integration (ADRs + Client Fallback + Tests)
- ADR-0001: Multi-Agent Coworking Architecture with LLM Gateway Orchestrator
- ADR-0002: Tier Assignment Strategy for Model Selection (cost-first escalation)
- ADR-0003: Confidence Gate Thresholds & Learning Cycle Intervals (6h/12h/24h cycles)
- ADR-0004: External Provider Fallback Chain Ordering (Cerebras → Groq → Mistral)
- Enhanced client SDK: Offline Ollama fallback, health checks, exponential backoff retry
- Integration tests: claude-code-integration.test.ts (14 test cases)
- PHASE_2F_DEPLOYMENT.md: Pre-deployment checklist, automated deploy, rollback plan
- Post-deployment verification procedures for health, client fallback, metrics
2026-04-19 21:39:44 +02:00

291 lines
9.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
train-fixes.py — Standalone LoRA training from fixes-chatml-sft.jsonl
No database required. Reads directly from local JSONL files.
Continues training on existing magatama-coder adapter (Qwen2.5-7B-Instruct LoRA).
Usage:
    source .venv/bin/activate
    python3 scripts/train-fixes.py --dry-run        # Just validate data
    python3 scripts/train-fixes.py                  # Continue training magatama-coder with fixes
    python3 scripts/train-fixes.py --from-scratch   # Fresh LoRA (ignore existing adapter)
    python3 scripts/train-fixes.py --model 14b      # SFT on Qwen2.5-14B (from scratch)
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import sys
from pathlib import Path
import torch
import yaml
from datasets import Dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from peft import PeftModel
from trl import SFTConfig, SFTTrainer
# Root logging setup: INFO level, each record shows time, level and logger name.
logging.basicConfig(
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("train-fixes")

# ROOT is the directory two levels above this file (the parent of the
# directory that contains the script); all other paths hang off it.
ROOT = Path(__file__).resolve().parent.parent
DATA_DIR = ROOT.joinpath("data")
CONFIG_PATH = ROOT.joinpath("config", "fixes-training.yaml")

# Existing magatama-coder adapter from previous training
EXISTING_ADAPTER = ROOT.joinpath(
    "adapters",
    "e1fb3d70-f681-41c5-80d6-9185c549d5d7",
    "adapter",
)
def load_config() -> dict:
    """Read the training configuration from ``CONFIG_PATH``.

    The file is parsed with ``yaml.safe_load`` and the result is returned
    unchanged.

    Raises:
        OSError: If the config file cannot be opened.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(CONFIG_PATH, encoding="utf-8") as f:
        return yaml.safe_load(f)
def load_chatml_data(path: Path) -> list[dict]:
    """Load SFT samples from a JSONL file, one JSON document per line.

    Lines that are empty after stripping whitespace are skipped.

    Args:
        path: Location of the JSONL file.

    Returns:
        The decoded JSON values in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode explicitly as UTF-8: with the locale default, non-ASCII training
    # text is mis-decoded (or rejected) on platforms that do not use UTF-8.
    with open(path, encoding="utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        samples = [json.loads(line) for line in stripped_lines if line]
    logger.info("Loaded %d samples from %s", len(samples), path.name)
    return samples
def load_dpo_data(path: Path) -> list[dict]:
    """Load DPO pairs from a JSONL file, one JSON document per line.

    Lines that are empty after stripping whitespace are skipped.

    Args:
        path: Location of the JSONL file.

    Returns:
        The decoded JSON values in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(path, encoding="utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        samples = [json.loads(line) for line in stripped_lines if line]
    logger.info("Loaded %d DPO pairs from %s", len(samples), path.name)
    return samples
def get_device() -> str:
    """Pick the torch device name to train on and log the choice.

    MPS is checked first, then CUDA; if neither is available the function
    falls back to the CPU.

    Returns:
        One of ``"mps"``, ``"cuda"`` or ``"cpu"``.
    """
    if torch.backends.mps.is_available():
        chosen, note = "mps", "Using Apple Silicon MPS"
    elif torch.cuda.is_available():
        chosen, note = "cuda", "Using CUDA GPU"
    else:
        chosen, note = "cpu", "Using CPU (slow!)"
    logger.info(note)
    return chosen
def run_sft(config: dict, model_size: str = "7b", from_scratch: bool = False) -> None:
    """Run supervised fine-tuning of a LoRA adapter on the fixes dataset.

    Loads ``fixes-chatml-sft.jsonl`` from ``DATA_DIR``, loads the base model,
    attaches either the existing adapter or a fresh LoRA, trains with
    ``SFTTrainer`` and writes the adapter, tokenizer and a
    ``training-meta.json`` file to the configured output directory.

    Args:
        config: Parsed training configuration. The ``"sft"``, ``"models"``
            and ``"output"`` sections are read.
        model_size: ``"7b"`` selects ``config["models"]["primary"]``; any
            other value selects ``config["models"]["secondary"]``.
        from_scratch: When true, a fresh LoRA adapter is always created. The
            existing adapter is otherwise reused only for the 7b model and
            only if ``EXISTING_ADAPTER`` exists on disk.

    Raises:
        ValueError: If the dataset file contains no samples.
        KeyError: If a required configuration key is missing.
    """
    sft_cfg = config["sft"]
    models = config["models"]
    model_id = models["primary"] if model_size == "7b" else models["secondary"]
    logger.info("Base model: %s", model_id)

    # Load data
    data_path = DATA_DIR / "fixes-chatml-sft.jsonl"
    samples = load_chatml_data(data_path)
    if not samples:
        # Fail before the (expensive) model load rather than deep inside the trainer.
        raise ValueError(f"No training samples found in {data_path}")

    # Use all samples for training — eval disabled on MPS due to OOM during evaluation_loop
    # (eval loads logits for all tokens which uses too much memory on 7.6B model)
    train_ds = Dataset.from_list(samples)
    eval_ds = None
    logger.info("Train: %d samples (eval disabled on MPS for OOM protection)", len(samples))

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load model
    device = get_device()
    logger.info("Loading model to %s...", device)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32 if device == "mps" else torch.float16,
        trust_remote_code=True,
    )

    # Continue from existing adapter or create new LoRA
    use_existing = (
        not from_scratch
        and model_size == "7b"
        and EXISTING_ADAPTER.exists()
    )
    if use_existing:
        logger.info("Loading existing magatama-coder adapter from %s", EXISTING_ADAPTER)
        model = PeftModel.from_pretrained(
            model,
            str(EXISTING_ADAPTER),
            is_trainable=True,
        )
        logger.info("Continuing training on existing adapter (magatama-coder)")
    else:
        logger.info("Creating fresh LoRA adapter")
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=sft_cfg["lora_r"],
            lora_alpha=sft_cfg["lora_alpha"],
            lora_dropout=sft_cfg["lora_dropout"],
            target_modules=sft_cfg["target_modules"],
        )
        model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    if device == "mps":
        model = model.to("mps")

    # Output dir
    output_dir = str(ROOT / config["output"]["adapters_dir"])
    os.makedirs(output_dir, exist_ok=True)

    # Training args
    training_args = SFTConfig(
        output_dir=output_dir,
        num_train_epochs=sft_cfg["num_epochs"],
        per_device_train_batch_size=sft_cfg["batch_size"],
        gradient_accumulation_steps=sft_cfg["gradient_accumulation"],
        learning_rate=sft_cfg["learning_rate"],
        warmup_ratio=sft_cfg["warmup_ratio"],
        weight_decay=sft_cfg.get("weight_decay", 0.01),
        lr_scheduler_type=sft_cfg.get("lr_scheduler", "cosine"),
        logging_steps=sft_cfg.get("logging_steps", 10),
        save_steps=20,  # Save every 20 steps so we have checkpoints if it crashes
        eval_strategy="no",  # Disabled on MPS — causes OOM during evaluation_loop
        max_length=sft_cfg["max_seq_length"],
        dataset_text_field="text",
        bf16=False,
        fp16=False,
        use_cpu=(device == "cpu"),
        gradient_checkpointing=False,
        save_total_limit=5,
        load_best_model_at_end=False,
        save_safetensors=True,
        dataloader_num_workers=0,  # Avoid multiprocessing overhead on MPS
        report_to="none",
    )

    # Trainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
    )

    logger.info("Starting SFT training...")
    logger.info(" Epochs: %d", sft_cfg["num_epochs"])
    logger.info(" Batch: %d × %d accum = %d effective",
                sft_cfg["batch_size"], sft_cfg["gradient_accumulation"],
                sft_cfg["batch_size"] * sft_cfg["gradient_accumulation"])
    logger.info(" LR: %s", sft_cfg["learning_rate"])
    logger.info(" Max seq: %d", sft_cfg["max_seq_length"])
    logger.info(" Output: %s", output_dir)

    result = trainer.train()

    logger.info("Training complete!")
    logger.info(" Loss: %.4f", result.training_loss)
    logger.info(" Steps: %d", result.global_step)
    logger.info(" Runtime: %.0fs", result.metrics.get("train_runtime", 0))

    # Save adapter
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    logger.info("Adapter saved to %s", output_dir)

    # Save training metadata. The counts come from the objects that were
    # actually used: every loaded sample is a training sample, and there is
    # no eval split while eval_ds is None.
    meta = {
        "model_id": model_id,
        "dataset": str(data_path),
        "samples": len(samples),
        "eval_samples": 0 if eval_ds is None else len(eval_ds),
        "training_loss": result.training_loss,
        "global_step": result.global_step,
        "runtime_seconds": result.metrics.get("train_runtime", 0),
        "config": sft_cfg,
    }
    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened with an encoding that can represent them on every platform.
    meta_path = os.path.join(output_dir, "training-meta.json")
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, ensure_ascii=False)
    logger.info("Metadata saved to training-meta.json")
def run_dry(config: dict) -> None:
    """Validate data without training.

    Loads the SFT and DPO JSONL files from ``DATA_DIR``, prints rough length
    statistics and environment information, and checks that every SFT sample
    contains the ChatML system and assistant tags. Malformed samples and an
    empty dataset are reported instead of aborting the run.

    Args:
        config: Parsed training configuration. Currently unused; accepted so
            the signature stays parallel with ``run_sft``.
    """
    data_path = DATA_DIR / "fixes-chatml-sft.jsonl"
    samples = load_chatml_data(data_path)
    dpo_path = DATA_DIR / "fixes-dpo-pairs.jsonl"
    dpo_samples = load_dpo_data(dpo_path)

    # Normalise every sample to a string once. A sample that is not a mapping,
    # or whose "text" is missing or not a string, is treated as empty text so
    # that it shows up as a format error below rather than crashing here.
    texts = []
    for sample in samples:
        text = sample.get("text", "") if isinstance(sample, dict) else ""
        texts.append(text if isinstance(text, str) else "")

    # Token length analysis
    lengths = [len(text) / 4 for text in texts]  # rough token estimate
    total = len(lengths)

    def percent_under(limit: int) -> float:
        """Return the percentage of samples estimated below ``limit`` tokens."""
        if not total:
            return 0.0
        return sum(1 for length in lengths if length < limit) / total * 100

    avg = sum(lengths) / total if total else 0.0
    max_len = max(lengths, default=0.0)
    under_512 = percent_under(512)
    under_1024 = percent_under(1024)

    # The original separator was an empty string repeated 50 times, which
    # printed nothing; use a visible ASCII rule.
    rule = "=" * 50
    print(f"\n{rule}")
    print("DRY RUN — Data Validation")
    print(rule)
    print(f"SFT samples: {len(samples)}")
    print(f"DPO pairs: {len(dpo_samples)}")
    print(f"Avg tokens: ~{avg:.0f}")
    print(f"Max tokens: ~{max_len:.0f}")
    print(f"Under 512: {under_512:.0f}%")
    print(f"Under 1024: {under_1024:.0f}%")
    print(f"Device: {get_device()}")
    print(f"PyTorch: {torch.__version__}")
    print(f"MPS available: {torch.backends.mps.is_available()}")

    # Validate format
    errors = 0
    for i, text in enumerate(texts):
        if "<|im_start|>system" not in text:
            logger.error("Sample %d missing system tag", i)
            errors += 1
        if "<|im_start|>assistant" not in text:
            logger.error("Sample %d missing assistant tag", i)
            errors += 1

    if not samples:
        logger.error("No SFT samples found in %s", data_path)

    # An empty dataset is never ready to train on, even with zero format errors.
    ready = bool(samples) and errors == 0
    print(f"Format errors: {errors}")
    print(f"Status: {'READY' if ready else 'FIX ERRORS'}")
    print(f"{rule}\n")
def main() -> None:
    """Parse the command line and dispatch to dry-run, DPO or SFT mode.

    The configuration is loaded before any mode is selected. ``--dry-run``
    takes precedence over ``--dpo``; ``--dpo`` only logs a pointer to the
    other script and exits with status 1.
    """
    cli = argparse.ArgumentParser(description="Train MAGATAMA Operations AI from fixes.json")
    cli.add_argument("--model", choices=["7b", "14b"], default="7b", help="Model size")
    boolean_flags = (
        ("--dry-run", "Validate data only"),
        ("--from-scratch", "Fresh LoRA, ignore existing adapter"),
        ("--dpo", "Run DPO phase (after SFT)"),
    )
    for flag, help_text in boolean_flags:
        cli.add_argument(flag, action="store_true", help=help_text)
    opts = cli.parse_args()

    config = load_config()

    if opts.dry_run:
        run_dry(config)
        return

    if opts.dpo:
        logger.info("DPO training not yet implemented in standalone script")
        logger.info("Use: python3 scripts/manual_trigger.py --dpo")
        sys.exit(1)

    run_sft(config, opts.model, from_scratch=opts.from_scratch)


if __name__ == "__main__":
    main()