#!/usr/bin/env python3 """ train-fixes.py — Standalone LoRA training from fixes-chatml-sft.jsonl No database required. Reads directly from local JSONL files. Continues training on existing magatama-coder adapter (Qwen2.5-7B-Instruct LoRA). Usage: source .venv/bin/activate python3 scripts/train-fixes.py --dry-run # Just validate data python3 scripts/train-fixes.py # Continue training magatama-coder with fixes python3 scripts/train-fixes.py --from-scratch # Fresh LoRA (ignore existing adapter) python3 scripts/train-fixes.py --model 14b # SFT on Qwen2.5-14B (from scratch) """ from __future__ import annotations import argparse import json import logging import os import sys from pathlib import Path import torch import yaml from datasets import Dataset from peft import LoraConfig, TaskType, get_peft_model from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments from peft import PeftModel from trl import SFTConfig, SFTTrainer logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s [%(name)s] %(message)s", ) logger = logging.getLogger("train-fixes") ROOT = Path(__file__).resolve().parent.parent DATA_DIR = ROOT / "data" CONFIG_PATH = ROOT / "config" / "fixes-training.yaml" # Existing magatama-coder adapter from previous training EXISTING_ADAPTER = ROOT / "adapters" / "e1fb3d70-f681-41c5-80d6-9185c549d5d7" / "adapter" def load_config() -> dict: with open(CONFIG_PATH) as f: return yaml.safe_load(f) def load_chatml_data(path: Path) -> list[dict]: samples = [] with open(path) as f: for line in f: line = line.strip() if line: samples.append(json.loads(line)) logger.info("Loaded %d samples from %s", len(samples), path.name) return samples def load_dpo_data(path: Path) -> list[dict]: samples = [] with open(path) as f: for line in f: line = line.strip() if line: samples.append(json.loads(line)) logger.info("Loaded %d DPO pairs from %s", len(samples), path.name) return samples def get_device() -> str: if torch.backends.mps.is_available(): logger.info("Using Apple Silicon MPS") return "mps" if torch.cuda.is_available(): logger.info("Using CUDA GPU") return "cuda" logger.info("Using CPU (slow!)") return "cpu" def run_sft(config: dict, model_size: str = "7b", from_scratch: bool = False) -> None: sft_cfg = config["sft"] models = config["models"] model_id = models["primary"] if model_size == "7b" else models["secondary"] logger.info("Base model: %s", model_id) # Load data data_path = DATA_DIR / "fixes-chatml-sft.jsonl" samples = load_chatml_data(data_path) # Use all samples for training — eval disabled on MPS due to OOM during evaluation_loop # (eval loads logits for all tokens which uses too much memory on 7.6B model) train_ds = Dataset.from_list(samples) eval_ds = None logger.info("Train: %d samples (eval disabled on MPS for OOM protection)", len(samples)) # Load tokenizer tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Load model device = get_device() logger.info("Loading model to %s...", device) model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.float32 if device == "mps" else torch.float16, trust_remote_code=True, ) # Continue from existing adapter or create new LoRA use_existing = ( not from_scratch and model_size == "7b" and EXISTING_ADAPTER.exists() ) if use_existing: logger.info("Loading existing magatama-coder adapter from %s", EXISTING_ADAPTER) model = PeftModel.from_pretrained( model, str(EXISTING_ADAPTER), is_trainable=True, ) logger.info("Continuing training on existing adapter (magatama-coder)") else: logger.info("Creating fresh LoRA adapter") lora_config = LoraConfig( task_type=TaskType.CAUSAL_LM, r=sft_cfg["lora_r"], lora_alpha=sft_cfg["lora_alpha"], lora_dropout=sft_cfg["lora_dropout"], target_modules=sft_cfg["target_modules"], ) model = get_peft_model(model, lora_config) model.print_trainable_parameters() if device == "mps": model = model.to("mps") # Output dir output_dir = str(ROOT / config["output"]["adapters_dir"]) os.makedirs(output_dir, exist_ok=True) # Training args training_args = SFTConfig( output_dir=output_dir, num_train_epochs=sft_cfg["num_epochs"], per_device_train_batch_size=sft_cfg["batch_size"], gradient_accumulation_steps=sft_cfg["gradient_accumulation"], learning_rate=sft_cfg["learning_rate"], warmup_ratio=sft_cfg["warmup_ratio"], weight_decay=sft_cfg.get("weight_decay", 0.01), lr_scheduler_type=sft_cfg.get("lr_scheduler", "cosine"), logging_steps=sft_cfg.get("logging_steps", 10), save_steps=20, # Save every 20 steps so we have checkpoints if it crashes eval_strategy="no", # Disabled on MPS — causes OOM during evaluation_loop max_length=sft_cfg["max_seq_length"], dataset_text_field="text", bf16=False, fp16=False, use_cpu=(device == "cpu"), gradient_checkpointing=False, save_total_limit=5, load_best_model_at_end=False, save_safetensors=True, dataloader_num_workers=0, # Avoid multiprocessing overhead on MPS report_to="none", ) # Trainer trainer = SFTTrainer( model=model, args=training_args, train_dataset=train_ds, eval_dataset=eval_ds, processing_class=tokenizer, ) logger.info("Starting SFT training...") logger.info(" Epochs: %d", sft_cfg["num_epochs"]) logger.info(" Batch: %d × %d accum = %d effective", sft_cfg["batch_size"], sft_cfg["gradient_accumulation"], sft_cfg["batch_size"] * sft_cfg["gradient_accumulation"]) logger.info(" LR: %s", sft_cfg["learning_rate"]) logger.info(" Max seq: %d", sft_cfg["max_seq_length"]) logger.info(" Output: %s", output_dir) result = trainer.train() logger.info("Training complete!") logger.info(" Loss: %.4f", result.training_loss) logger.info(" Steps: %d", result.global_step) logger.info(" Runtime: %.0fs", result.metrics.get("train_runtime", 0)) # Save adapter trainer.save_model(output_dir) tokenizer.save_pretrained(output_dir) logger.info("Adapter saved to %s", output_dir) # Save training metadata meta = { "model_id": model_id, "dataset": str(data_path), "samples": len(train_samples), "eval_samples": len(eval_samples), "training_loss": result.training_loss, "global_step": result.global_step, "runtime_seconds": result.metrics.get("train_runtime", 0), "config": sft_cfg, } with open(os.path.join(output_dir, "training-meta.json"), "w") as f: json.dump(meta, f, indent=2, ensure_ascii=False) logger.info("Metadata saved to training-meta.json") def run_dry(config: dict) -> None: """Validate data without training.""" data_path = DATA_DIR / "fixes-chatml-sft.jsonl" samples = load_chatml_data(data_path) dpo_path = DATA_DIR / "fixes-dpo-pairs.jsonl" dpo_samples = load_dpo_data(dpo_path) # Token length analysis lengths = [len(s["text"]) / 4 for s in samples] # rough token estimate avg = sum(lengths) / len(lengths) max_len = max(lengths) under_512 = sum(1 for l in lengths if l < 512) / len(lengths) * 100 under_1024 = sum(1 for l in lengths if l < 1024) / len(lengths) * 100 print(f"\n{'═' * 50}") print(f"DRY RUN — Data Validation") print(f"{'═' * 50}") print(f"SFT samples: {len(samples)}") print(f"DPO pairs: {len(dpo_samples)}") print(f"Avg tokens: ~{avg:.0f}") print(f"Max tokens: ~{max_len:.0f}") print(f"Under 512: {under_512:.0f}%") print(f"Under 1024: {under_1024:.0f}%") print(f"Device: {get_device()}") print(f"PyTorch: {torch.__version__}") print(f"MPS available: {torch.backends.mps.is_available()}") # Validate format errors = 0 for i, s in enumerate(samples): text = s.get("text", "") if "<|im_start|>system" not in text: logger.error("Sample %d missing system tag", i) errors += 1 if "<|im_start|>assistant" not in text: logger.error("Sample %d missing assistant tag", i) errors += 1 print(f"Format errors: {errors}") print(f"Status: {'READY' if errors == 0 else 'FIX ERRORS'}") print(f"{'═' * 50}\n") def main() -> None: parser = argparse.ArgumentParser(description="Train MAGATAMA Operations AI from fixes.json") parser.add_argument("--model", choices=["7b", "14b"], default="7b", help="Model size") parser.add_argument("--dry-run", action="store_true", help="Validate data only") parser.add_argument("--from-scratch", action="store_true", help="Fresh LoRA, ignore existing adapter") parser.add_argument("--dpo", action="store_true", help="Run DPO phase (after SFT)") args = parser.parse_args() config = load_config() if args.dry_run: run_dry(config) elif args.dpo: logger.info("DPO training not yet implemented in standalone script") logger.info("Use: python3 scripts/manual_trigger.py --dpo") sys.exit(1) else: run_sft(config, args.model, from_scratch=args.from_scratch) if __name__ == "__main__": main()