- ADR-0001: Multi-Agent Coworking Architecture with LLM Gateway Orchestrator - ADR-0002: Tier Assignment Strategy for Model Selection (cost-first escalation) - ADR-0003: Confidence Gate Thresholds & Learning Cycle Intervals (6h/12h/24h cycles) - ADR-0004: External Provider Fallback Chain Ordering (Cerebras → Groq → Mistral) - Enhanced client SDK: Offline Ollama fallback, health checks, exponential backoff retry - Integration tests: claude-code-integration.test.ts (14 test cases) - PHASE_2F_DEPLOYMENT.md: Pre-deployment checklist, automated deploy, rollback plan - Post-deployment verification procedures for health, client fallback, metrics
319 lines
10 KiB
Python
319 lines
10 KiB
Python
"""
|
|
trainer.py - LoRA / SFT fine-tuning using PEFT + TRL.
|
|
|
|
Supports Apple Silicon MPS (primary) with automatic CPU fallback.
|
|
Trains a LoRA adapter on top of Qwen2.5-Instruct using ChatML format,
|
|
then returns training metrics for the orchestrator to evaluate and record.
|
|
|
|
MPS notes (torch 2.x):
|
|
- device_map is NOT supported with MPS; load the full model and call
|
|
model.to("mps") explicitly after PEFT wrapping.
|
|
- gradient_checkpointing is enabled for training, but only with use_reentrant=False.
|
|
- use_cache must be False during training to avoid shape conflicts.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import torch
|
|
from datasets import Dataset
|
|
from peft import LoraConfig, TaskType, get_peft_model
|
|
from transformers import (
|
|
AutoModelForCausalLM,
|
|
AutoTokenizer,
|
|
BitsAndBytesConfig,
|
|
TrainingArguments,
|
|
)
|
|
from trl import SFTConfig, SFTTrainer
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Constants
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# ChatML layout for one system / user / assistant exchange.
# prepare_dataset() fills the three placeholders with str.format().
CHATML_TEMPLATE = (
    "<|im_start|>system\n{system}<|im_end|>\n"
    "<|im_start|>user\n{user}<|im_end|>\n"
    "<|im_start|>assistant\n{assistant}<|im_end|>"
)
|
|
|
|
# Default LoRA target modules; _build_lora_config() falls back to this list
# when the caller does not supply target_modules.
# NOTE(review): presumably the attention (q/k/v/o) and MLP (gate/up/down)
# projection layers of the Qwen architecture — confirm against the model's
# actual module names.
QWEN_TARGET_MODULES = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Dataset preparation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def prepare_dataset(examples: list[dict]) -> Dataset:
    """
    Build a ChatML-formatted ``Dataset`` from learning_corpus rows.

    Every row is read through the keys ``system_prompt``, ``input_text`` and
    ``output_text``; missing or ``None`` values count as empty and all values
    are whitespace-stripped.

    - A row whose ``input_text`` or ``output_text`` ends up empty is dropped.
      The number of dropped rows is reported in a single warning log line.
    - An empty ``system_prompt`` is replaced by a generic default prompt.

    Returns a ``Dataset`` whose records carry one ``"text"`` field holding the
    rendered CHATML_TEMPLATE.
    """

    def _clean(row: dict, key: str) -> str:
        # ``or ""`` folds both a missing key and an explicit None into "".
        return (row.get(key) or "").strip()

    records: list[dict] = []
    dropped = 0

    for row in examples:
        system_text = _clean(row, "system_prompt")
        user_text = _clean(row, "input_text")
        assistant_text = _clean(row, "output_text")

        # Both sides of the exchange are mandatory for a usable sample.
        if not (user_text and assistant_text):
            dropped += 1
            continue

        rendered = CHATML_TEMPLATE.format(
            system=system_text or "You are a helpful assistant.",
            user=user_text,
            assistant=assistant_text,
        )
        records.append({"text": rendered})

    if dropped:
        logger.warning("prepare_dataset: skipped %d rows with missing fields", dropped)

    logger.info("prepare_dataset: %d examples formatted", len(records))
    return Dataset.from_list(records)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Device selection
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _select_device() -> str:
|
|
"""Return 'mps', 'cuda', or 'cpu' depending on availability."""
|
|
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
|
return "mps"
|
|
if torch.cuda.is_available():
|
|
return "cuda"
|
|
return "cpu"
|
|
|
|
|
|
def _load_model_and_tokenizer(
    base_model_path: str,
    device: str,
) -> tuple:
    """
    Load the tokenizer and base model used for LoRA training.

    Args:
        base_model_path: Path or identifier handed to ``from_pretrained``.
        device: ``"cuda"``, ``"mps"``; any other value takes the CPU path.

    Per-device loading:
        CUDA: bfloat16 with ``device_map="auto"``.
        MPS:  float16, no ``device_map``; the caller moves the model to the
              device after PEFT wrapping.
        CPU:  float32.

    Returns:
        ``(model, tokenizer)``. The tokenizer pads on the right and is given a
        pad token (the EOS token) if it has none. ``model.config.use_cache``
        is set to ``False`` for training.
    """
    logger.info("Loading tokenizer from %s", base_model_path)
    # NOTE(review): trust_remote_code=True allows the checkpoint's repository to
    # execute its own Python at load time; only use trusted base_model_path values.
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_path,
        trust_remote_code=True,
        padding_side="right",  # required for SFT with left-pad models
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        logger.info("Set pad_token = eos_token (%s)", tokenizer.eos_token)

    logger.info("Loading base model from %s on device=%s", base_model_path, device)

    # Only dtype / placement differ between devices, so build those kwargs once
    # and share a single from_pretrained call. All branches use the same
    # ``dtype`` keyword (the CUDA branch previously used ``torch_dtype``).
    if device == "cuda":
        load_kwargs = {"dtype": torch.bfloat16, "device_map": "auto"}
    elif device == "mps":
        # float16 halves memory vs float32; device_map is not used on MPS.
        load_kwargs = {"dtype": torch.float16}
    else:
        load_kwargs = {"dtype": torch.float32}

    model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        trust_remote_code=True,
        **load_kwargs,
    )

    model.config.use_cache = False  # required for training
    return model, tokenizer
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# LoRA configuration
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _build_lora_config(
    r: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.05,
    target_modules: Optional[list[str]] = None,
) -> LoraConfig:
    """
    Assemble the ``LoraConfig`` for a causal-LM training run.

    ``r``, ``lora_alpha`` and ``lora_dropout`` are passed through unchanged.
    When ``target_modules`` is ``None`` or empty, QWEN_TARGET_MODULES is used.
    Bias is fixed to ``"none"`` and ``inference_mode`` to ``False``.
    """
    modules = target_modules if target_modules else QWEN_TARGET_MODULES
    settings = {
        "r": r,
        "lora_alpha": lora_alpha,
        "lora_dropout": lora_dropout,
        "bias": "none",
        "task_type": TaskType.CAUSAL_LM,
        "target_modules": modules,
        "inference_mode": False,
    }
    return LoraConfig(**settings)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Main training entry point
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def run_lora_training(
    base_model_path: str,
    train_examples: list[dict],
    val_examples: list[dict],
    output_dir: str,
    task_type: Optional[str] = None,
    lora_r: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.05,
    max_seq_length: int = 2048,
    num_epochs: int = 3,
    batch_size: int = 1,
    gradient_accumulation_steps: int = 8,
    learning_rate: float = 2e-4,
    warmup_ratio: float = 0.1,
) -> dict:
    """
    Full LoRA fine-tuning run using SFTTrainer.

    Args:
        base_model_path: Base model passed to ``_load_model_and_tokenizer``.
        train_examples: Raw corpus rows; at least 10 are required.
        val_examples: Raw validation rows; may be empty. If none of them
            survive formatting, the run proceeds without evaluation.
        output_dir: Directory for checkpoints (created if missing); the
            adapter is saved to ``<output_dir>/adapter``.
        task_type: Free-form label; only logged and echoed in the result.
        lora_r, lora_alpha, lora_dropout: Forwarded to ``_build_lora_config``.
        max_seq_length: Forwarded to ``SFTConfig`` as ``max_length``.
        num_epochs, batch_size, gradient_accumulation_steps, learning_rate,
        warmup_ratio: Forwarded to ``SFTConfig``.

    Returns a metrics dict:
        {
            "train_loss": float,
            "eval_loss": float,      # -1.0 when no evaluation was run
            "train_runtime": float,
            "train_samples": int,
            "val_samples": int,
            "adapter_path": str,
            "device": str,
            "task_type": Optional[str],
            "epochs": int,
        }

    Raises:
        ValueError: fewer than 10 training rows were supplied, or no training
            row survived formatting.
        Other fatal errors propagate so the orchestrator can record failure
        status.
    """
    # Fail fast on obviously insufficient data before probing hardware or
    # loading any weights.
    if len(train_examples) < 10:
        raise ValueError(
            f"Insufficient training data: need >= 10 examples, got {len(train_examples)}"
        )

    device = _select_device()
    logger.info(
        "run_lora_training: device=%s task_type=%s output_dir=%s",
        device,
        task_type,
        output_dir,
    )

    # Prepare datasets
    train_dataset = prepare_dataset(train_examples)
    if len(train_dataset) == 0:
        raise ValueError("All training examples were invalid — dataset is empty after formatting")

    eval_dataset = prepare_dataset(val_examples) if val_examples else None
    if eval_dataset is not None and len(eval_dataset) == 0:
        # Every validation row was dropped. Do not hand an empty dataset
        # (which has no "text" column) to SFTTrainer; run without evaluation.
        logger.warning(
            "run_lora_training: no valid validation examples; evaluation disabled"
        )
        eval_dataset = None
    has_eval = eval_dataset is not None

    # Load model
    model, tokenizer = _load_model_and_tokenizer(base_model_path, device)

    # Apply LoRA
    lora_config = _build_lora_config(
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Move to device AFTER PEFT wrapping (MPS requirement). The CUDA path was
    # already placed via device_map="auto" at load time.
    if device in ("mps", "cpu"):
        model = model.to(device)

    # Training arguments
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # eval_strategy requires a validation set
    eval_strategy = "steps" if has_eval else "no"

    training_args = SFTConfig(
        output_dir=str(output_path),
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        warmup_ratio=warmup_ratio,
        eval_strategy=eval_strategy,
        eval_steps=50 if has_eval else None,
        save_strategy="steps",
        save_steps=100,
        load_best_model_at_end=has_eval,
        metric_for_best_model="eval_loss" if has_eval else None,
        greater_is_better=False,
        logging_steps=10,
        report_to="none",  # no WandB / HF Hub logging
        dataloader_num_workers=0,  # MPS requires 0 (no multiprocessing)
        fp16=False,  # MPS does not support fp16 training
        bf16=False,  # MPS does not support bf16 training
        optim="adamw_torch",  # paged_adamw_8bit requires bitsandbytes (CUDA only)
        gradient_checkpointing=True,  # torch 2.x+ supports MPS with use_reentrant=False
        gradient_checkpointing_kwargs={"use_reentrant": False},
        remove_unused_columns=False,
        label_names=["labels"],
        dataset_text_field="text",
        max_length=max_seq_length,
        packing=False,  # packing can cause issues with MPS
    )

    # Trainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
    )

    val_count = len(eval_dataset) if has_eval else 0
    logger.info(
        "Starting SFT training: %d train examples, %d val examples, %d epochs",
        len(train_dataset),
        val_count,
        num_epochs,
    )

    train_result = trainer.train()

    # Evaluate if possible
    eval_metrics: dict = {}
    if has_eval:
        eval_metrics = trainer.evaluate()
        logger.info("Eval metrics: %s", eval_metrics)

    # Save adapter (LoRA weights only — not the full model)
    adapter_path = str(output_path / "adapter")
    model.save_pretrained(adapter_path)
    tokenizer.save_pretrained(adapter_path)
    logger.info("Saved LoRA adapter to %s", adapter_path)

    return {
        "train_loss": round(train_result.training_loss, 4),
        # -1.0 is the sentinel for "no evaluation was run".
        "eval_loss": round(eval_metrics.get("eval_loss", -1.0), 4),
        "train_runtime": round(train_result.metrics.get("train_runtime", 0.0), 1),
        "train_samples": len(train_dataset),
        "val_samples": val_count,
        "adapter_path": adapter_path,
        "device": device,
        "task_type": task_type,
        "epochs": num_epochs,
    }
|