- ADR-0001: Multi-Agent Coworking Architecture with LLM Gateway Orchestrator
- ADR-0002: Tier Assignment Strategy for Model Selection (cost-first escalation)
- ADR-0003: Confidence Gate Thresholds & Learning Cycle Intervals (6h/12h/24h cycles)
- ADR-0004: External Provider Fallback Chain Ordering (Cerebras → Groq → Mistral)
- Enhanced client SDK: Offline Ollama fallback, health checks, exponential backoff retry
- Integration tests: claude-code-integration.test.ts (14 test cases)
- PHASE_2F_DEPLOYMENT.md: Pre-deployment checklist, automated deploy, rollback plan
- Post-deployment verification procedures for health, client fallback, metrics
291 lines
9.7 KiB
Python
291 lines
9.7 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
train-fixes.py — Standalone LoRA training from fixes-chatml-sft.jsonl
|
||
No database required. Reads directly from local JSONL files.
|
||
|
||
Continues training on existing magatama-coder adapter (Qwen2.5-7B-Instruct LoRA).
|
||
|
||
Usage:
|
||
source .venv/bin/activate
|
||
python3 scripts/train-fixes.py --dry-run # Just validate data
|
||
python3 scripts/train-fixes.py # Continue training magatama-coder with fixes
|
||
python3 scripts/train-fixes.py --from-scratch # Fresh LoRA (ignore existing adapter)
|
||
python3 scripts/train-fixes.py --model 14b # SFT on Qwen2.5-14B (from scratch)
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import json
|
||
import logging
|
||
import os
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
import torch
|
||
import yaml
|
||
from datasets import Dataset
|
||
from peft import LoraConfig, TaskType, get_peft_model
|
||
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
|
||
from peft import PeftModel
|
||
from trl import SFTConfig, SFTTrainer
|
||
|
||
# Configure the root logger once for the whole script: INFO level with
# timestamp, level and logger name on every record.
logging.basicConfig(
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("train-fixes")

# Project layout, resolved relative to this file (two directory levels up).
ROOT = Path(__file__).resolve().parent.parent
DATA_DIR = ROOT.joinpath("data")
CONFIG_PATH = ROOT.joinpath("config", "fixes-training.yaml")

# Adapter written by the previous magatama-coder training run; used as the
# starting point when continuing training.
EXISTING_ADAPTER = ROOT.joinpath(
    "adapters",
    "e1fb3d70-f681-41c5-80d6-9185c549d5d7",
    "adapter",
)
|
||
|
||
|
||
def load_config() -> dict:
    """Load the training configuration from ``CONFIG_PATH``.

    The file is parsed with ``yaml.safe_load`` and the result is returned
    unchanged.

    Returns:
        The parsed YAML document.

    Raises:
        OSError: If the config file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the machine's
    # locale encoding.
    with open(CONFIG_PATH, encoding="utf-8") as f:
        return yaml.safe_load(f)
|
||
|
||
|
||
def _read_jsonl(path: Path) -> list[dict]:
    """Parse *path* as JSON Lines, ignoring blank lines."""
    # JSON text is UTF-8; decode explicitly rather than via the locale default.
    with open(path, encoding="utf-8") as f:
        stripped_lines = (raw.strip() for raw in f)
        return [json.loads(entry) for entry in stripped_lines if entry]


def load_chatml_data(path: Path) -> list[dict]:
    """Load SFT samples from a JSON Lines file.

    Args:
        path: JSONL file with one JSON document per non-blank line.

    Returns:
        The decoded documents, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    samples = _read_jsonl(path)
    logger.info("Loaded %d samples from %s", len(samples), path.name)
    return samples


def load_dpo_data(path: Path) -> list[dict]:
    """Load DPO preference pairs from a JSON Lines file.

    Args:
        path: JSONL file with one JSON document per non-blank line.

    Returns:
        The decoded documents, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    samples = _read_jsonl(path)
    logger.info("Loaded %d DPO pairs from %s", len(samples), path.name)
    return samples
|
||
|
||
|
||
def get_device() -> str:
    """Pick the compute device for training.

    Backends are probed in order of preference: Apple Silicon MPS first,
    then CUDA. When neither reports itself available the CPU is used.

    Returns:
        ``"mps"``, ``"cuda"`` or ``"cpu"``.
    """
    # (device name, availability probe, log message), highest priority first.
    accelerators = (
        ("mps", torch.backends.mps.is_available, "Using Apple Silicon MPS"),
        ("cuda", torch.cuda.is_available, "Using CUDA GPU"),
    )
    for device_name, is_available, message in accelerators:
        if is_available():
            logger.info(message)
            return device_name

    logger.info("Using CPU (slow!)")
    return "cpu"
|
||
|
||
|
||
def run_sft(config: dict, model_size: str = "7b", from_scratch: bool = False) -> None:
    """Run LoRA supervised fine-tuning on ``fixes-chatml-sft.jsonl``.

    The base model is ``config["models"]["primary"]`` for ``model_size ==
    "7b"`` and ``config["models"]["secondary"]`` otherwise. Training
    continues from ``EXISTING_ADAPTER`` only when ``from_scratch`` is false,
    the model size is ``"7b"`` and the adapter directory exists; in every
    other case a fresh LoRA adapter is built from ``config["sft"]``.

    On success the adapter, the tokenizer and a ``training-meta.json``
    summary are written to ``ROOT / config["output"]["adapters_dir"]``.

    Args:
        config: Parsed training configuration. Reads the ``sft``, ``models``
            and ``output`` sections.
        model_size: ``"7b"`` selects the primary model; any other value
            selects the secondary model.
        from_scratch: When true, ignore the existing adapter and start a
            fresh LoRA.

    Raises:
        KeyError: If a required configuration key is missing.
        ValueError: If the SFT dataset contains no samples.
    """
    sft_cfg = config["sft"]
    models = config["models"]

    model_id = models["primary"] if model_size == "7b" else models["secondary"]
    logger.info("Base model: %s", model_id)

    # Load data
    data_path = DATA_DIR / "fixes-chatml-sft.jsonl"
    samples = load_chatml_data(data_path)
    if not samples:
        # Fail before the (expensive) model load rather than deep inside the trainer.
        raise ValueError(f"No training samples found in {data_path}")

    # Use all samples for training — eval disabled on MPS due to OOM during evaluation_loop
    # (eval loads logits for all tokens which uses too much memory on 7.6B model)
    train_ds = Dataset.from_list(samples)
    eval_ds = None

    logger.info("Train: %d samples (eval disabled on MPS for OOM protection)", len(samples))

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    # Load model
    device = get_device()
    logger.info("Loading model to %s...", device)

    # float32 on MPS, float16 on every other device (CUDA and CPU).
    # NOTE(review): fp16 weights are combined with fp16=False below — confirm
    # this is the intended precision setup for CUDA/CPU runs.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32 if device == "mps" else torch.float16,
        trust_remote_code=True,
    )

    # Continue from existing adapter or create new LoRA
    use_existing = (
        not from_scratch
        and model_size == "7b"
        and EXISTING_ADAPTER.exists()
    )

    if use_existing:
        logger.info("Loading existing magatama-coder adapter from %s", EXISTING_ADAPTER)
        model = PeftModel.from_pretrained(
            model,
            str(EXISTING_ADAPTER),
            is_trainable=True,
        )
        logger.info("Continuing training on existing adapter (magatama-coder)")
    else:
        logger.info("Creating fresh LoRA adapter")
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=sft_cfg["lora_r"],
            lora_alpha=sft_cfg["lora_alpha"],
            lora_dropout=sft_cfg["lora_dropout"],
            target_modules=sft_cfg["target_modules"],
        )
        model = get_peft_model(model, lora_config)

    model.print_trainable_parameters()

    if device == "mps":
        model = model.to("mps")

    # Output dir
    output_dir = str(ROOT / config["output"]["adapters_dir"])
    os.makedirs(output_dir, exist_ok=True)

    # Training args
    training_args = SFTConfig(
        output_dir=output_dir,
        num_train_epochs=sft_cfg["num_epochs"],
        per_device_train_batch_size=sft_cfg["batch_size"],
        gradient_accumulation_steps=sft_cfg["gradient_accumulation"],
        learning_rate=sft_cfg["learning_rate"],
        warmup_ratio=sft_cfg["warmup_ratio"],
        weight_decay=sft_cfg.get("weight_decay", 0.01),
        lr_scheduler_type=sft_cfg.get("lr_scheduler", "cosine"),
        logging_steps=sft_cfg.get("logging_steps", 10),
        save_steps=20,  # Save every 20 steps so we have checkpoints if it crashes
        eval_strategy="no",  # Disabled on MPS — causes OOM during evaluation_loop
        max_length=sft_cfg["max_seq_length"],
        dataset_text_field="text",
        bf16=False,
        fp16=False,
        use_cpu=(device == "cpu"),
        gradient_checkpointing=False,
        save_total_limit=5,
        load_best_model_at_end=False,
        save_safetensors=True,
        dataloader_num_workers=0,  # Avoid multiprocessing overhead on MPS
        report_to="none",
    )

    # Trainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
    )

    logger.info("Starting SFT training...")
    logger.info(" Epochs: %d", sft_cfg["num_epochs"])
    logger.info(" Batch: %d × %d accum = %d effective",
                sft_cfg["batch_size"], sft_cfg["gradient_accumulation"],
                sft_cfg["batch_size"] * sft_cfg["gradient_accumulation"])
    logger.info(" LR: %s", sft_cfg["learning_rate"])
    logger.info(" Max seq: %d", sft_cfg["max_seq_length"])
    logger.info(" Output: %s", output_dir)

    result = trainer.train()

    logger.info("Training complete!")
    logger.info(" Loss: %.4f", result.training_loss)
    logger.info(" Steps: %d", result.global_step)
    logger.info(" Runtime: %.0fs", result.metrics.get("train_runtime", 0))

    # Save adapter
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    logger.info("Adapter saved to %s", output_dir)

    # Save training metadata
    meta = {
        "model_id": model_id,
        "dataset": str(data_path),
        # Every loaded sample is used for training; there is no eval split.
        "samples": len(samples),
        "eval_samples": len(eval_ds) if eval_ds is not None else 0,
        "training_loss": result.training_loss,
        "global_step": result.global_step,
        "runtime_seconds": result.metrics.get("train_runtime", 0),
        "config": sft_cfg,
    }
    # ensure_ascii=False can emit non-ASCII characters, so write as UTF-8
    # instead of relying on the locale encoding.
    with open(os.path.join(output_dir, "training-meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, ensure_ascii=False)
    logger.info("Metadata saved to training-meta.json")
|
||
|
||
|
||
def run_dry(config: dict) -> None:
    """Validate data without training.

    Loads the SFT samples and DPO pairs from ``DATA_DIR``, prints rough
    token-length statistics and environment details, then checks that every
    SFT sample's ``text`` contains both a ChatML system tag and an assistant
    tag. The status line reports READY only when at least one sample was
    loaded and no format errors were found.

    Args:
        config: Parsed training configuration. Currently unused.
    """
    data_path = DATA_DIR / "fixes-chatml-sft.jsonl"
    samples = load_chatml_data(data_path)

    dpo_path = DATA_DIR / "fixes-dpo-pairs.jsonl"
    dpo_samples = load_dpo_data(dpo_path)

    # Token length analysis (rough estimate: 4 characters per token).
    # A sample without "text" counts as length 0 here and is reported by the
    # format validation below instead of crashing with a KeyError.
    lengths = [len(s.get("text", "")) / 4 for s in samples]
    total = len(lengths)
    if total:
        avg = sum(lengths) / total
        max_len = max(lengths)
        under_512 = sum(1 for n in lengths if n < 512) / total * 100
        under_1024 = sum(1 for n in lengths if n < 1024) / total * 100
    else:
        # Nothing to average; report zeros rather than dividing by zero.
        logger.error("No SFT samples found in %s", data_path.name)
        avg = max_len = under_512 = under_1024 = 0.0

    print(f"\n{'═' * 50}")
    print("DRY RUN — Data Validation")
    print(f"{'═' * 50}")
    print(f"SFT samples: {len(samples)}")
    print(f"DPO pairs: {len(dpo_samples)}")
    print(f"Avg tokens: ~{avg:.0f}")
    print(f"Max tokens: ~{max_len:.0f}")
    print(f"Under 512: {under_512:.0f}%")
    print(f"Under 1024: {under_1024:.0f}%")
    print(f"Device: {get_device()}")
    print(f"PyTorch: {torch.__version__}")
    print(f"MPS available: {torch.backends.mps.is_available()}")

    # Validate format
    errors = 0
    for i, s in enumerate(samples):
        text = s.get("text", "")
        if "<|im_start|>system" not in text:
            logger.error("Sample %d missing system tag", i)
            errors += 1
        if "<|im_start|>assistant" not in text:
            logger.error("Sample %d missing assistant tag", i)
            errors += 1

    # An empty dataset is never trainable, even with zero format errors.
    ready = bool(samples) and errors == 0

    print(f"Format errors: {errors}")
    print(f"Status: {'READY' if ready else 'FIX ERRORS'}")
    print(f"{'═' * 50}\n")
|
||
|
||
|
||
def main() -> None:
    """Parse command-line options and dispatch to the requested mode.

    ``--dry-run`` takes precedence and only validates the data. ``--dpo`` is
    not implemented in this script: it logs a pointer to the alternative
    command and exits with status 1. With neither flag, SFT training runs.
    """
    cli = argparse.ArgumentParser(description="Train MAGATAMA Operations AI from fixes.json")
    cli.add_argument("--model", choices=["7b", "14b"], default="7b", help="Model size")
    cli.add_argument("--dry-run", action="store_true", help="Validate data only")
    cli.add_argument("--from-scratch", action="store_true", help="Fresh LoRA, ignore existing adapter")
    cli.add_argument("--dpo", action="store_true", help="Run DPO phase (after SFT)")
    opts = cli.parse_args()

    # The configuration is loaded up front for every mode.
    settings = load_config()

    if opts.dry_run:
        run_dry(settings)
        return

    if opts.dpo:
        logger.info("DPO training not yet implemented in standalone script")
        logger.info("Use: python3 scripts/manual_trigger.py --dpo")
        sys.exit(1)

    run_sft(settings, opts.model, from_scratch=opts.from_scratch)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|