Rene Fichtmueller 2ca77d0aee feat: Phase 2F — Multi-Agent Integration (ADRs + Client Fallback + Tests)
- ADR-0001: Multi-Agent Coworking Architecture with LLM Gateway Orchestrator
- ADR-0002: Tier Assignment Strategy for Model Selection (cost-first escalation)
- ADR-0003: Confidence Gate Thresholds & Learning Cycle Intervals (6h/12h/24h cycles)
- ADR-0004: External Provider Fallback Chain Ordering (Cerebras → Groq → Mistral)
- Enhanced client SDK: Offline Ollama fallback, health checks, exponential backoff retry
- Integration tests: claude-code-integration.test.ts (14 test cases)
- PHASE_2F_DEPLOYMENT.md: Pre-deployment checklist, automated deploy, rollback plan
- Post-deployment verification procedures for health, client fallback, metrics
2026-04-19 21:39:44 +02:00

319 lines
10 KiB
Python

"""
trainer.py - LoRA / SFT fine-tuning using PEFT + TRL.
Prefers Apple Silicon MPS, then CUDA, with automatic CPU fallback.
Trains a LoRA adapter on top of Qwen2.5-Instruct using ChatML format,
then returns training metrics for the orchestrator to evaluate and record.
MPS notes (torch 2.x):
- device_map is NOT supported with MPS; load the full model and call
model.to("mps") explicitly after PEFT wrapping.
- gradient_checkpointing is enabled in non-reentrant mode (use_reentrant=False),
which the training config relies on for MPS support in torch 2.x+.
- use_cache must be False during training to avoid shape conflicts.
"""
from __future__ import annotations
import logging
import os
from pathlib import Path
from typing import Optional
import torch
from datasets import Dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TrainingArguments,
)
from trl import SFTConfig, SFTTrainer
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# ChatML prompt layout: a system, a user and an assistant turn, each wrapped in
# <|im_start|>role ... <|im_end|> markers and separated by a single "\n".
# The {system}/{user}/{assistant} placeholders are filled with str.format.
CHATML_TEMPLATE = "\n".join(
    (
        "<|im_start|>system\n{system}<|im_end|>",
        "<|im_start|>user\n{user}<|im_end|>",
        "<|im_start|>assistant\n{assistant}<|im_end|>",
    )
)
# Default module names handed to LoraConfig(target_modules=...) when the
# caller does not supply its own list.
QWEN_TARGET_MODULES = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
# ---------------------------------------------------------------------------
# Dataset preparation
# ---------------------------------------------------------------------------
def prepare_dataset(examples: list[dict]) -> Dataset:
    """
    Render learning_corpus rows as ChatML-formatted training text.

    Each row is read for ``system_prompt``, ``input_text`` and ``output_text``;
    values are whitespace-stripped and ``None`` counts as empty.

    * Rows with an empty ``input_text`` or ``output_text`` are dropped; the
      number dropped is logged as a warning (not silently).
    * An empty ``system_prompt`` does NOT drop the row; a generic default
      system message is substituted instead.

    Returns a ``Dataset`` built from records of the form ``{"text": <chatml>}``.
    """
    records: list[dict] = []
    dropped = 0
    for row in examples:
        # Stripped in the order system -> input -> output.
        sys_msg, user_msg, reply = (
            (row.get(key) or "").strip()
            for key in ("system_prompt", "input_text", "output_text")
        )
        if not (user_msg and reply):
            dropped += 1
            continue
        records.append(
            {
                "text": CHATML_TEMPLATE.format(
                    system=sys_msg or "You are a helpful assistant.",
                    user=user_msg,
                    assistant=reply,
                )
            }
        )
    if dropped:
        logger.warning("prepare_dataset: skipped %d rows with missing fields", dropped)
    logger.info("prepare_dataset: %d examples formatted", len(records))
    return Dataset.from_list(records)
# ---------------------------------------------------------------------------
# Device selection
# ---------------------------------------------------------------------------
def _select_device() -> str:
"""Return 'mps', 'cuda', or 'cpu' depending on availability."""
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
return "mps"
if torch.cuda.is_available():
return "cuda"
return "cpu"
def _load_model_and_tokenizer(
    base_model_path: str,
    device: str,
) -> tuple:
    """
    Load the tokenizer and base causal-LM model for LoRA training.

    Precision / placement per device:
      * cuda: bfloat16, placed with ``device_map="auto"``.
      * mps:  float16 (halves memory vs float32). Loaded without a device_map;
        the caller is expected to move it to MPS after PEFT wrapping.
      * anything else (cpu): float32, loaded without a device_map.

    If the tokenizer defines no pad token, the EOS token is reused for padding.
    ``model.config.use_cache`` is switched off because the training path
    requires it.

    Note: ``trust_remote_code=True`` lets the checkpoint execute its own Python
    code; only point ``base_model_path`` at trusted checkpoints.

    Args:
        base_model_path: Local path or hub id passed to ``from_pretrained``.
        device: Device name such as "cuda", "mps" or "cpu".

    Returns:
        ``(model, tokenizer)``.
    """
    logger.info("Loading tokenizer from %s", base_model_path)
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_path,
        trust_remote_code=True,
        padding_side="right",  # pad training batches on the right
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        logger.info("Set pad_token = eos_token (%s)", tokenizer.eos_token)

    logger.info("Loading base model from %s on device=%s", base_model_path, device)
    dtype_by_device = {
        "cuda": torch.bfloat16,
        "mps": torch.float16,  # halves memory vs float32
    }
    # One keyword spelling (``dtype=``) for every device; the CUDA branch
    # previously used the deprecated ``torch_dtype=`` spelling.
    load_kwargs: dict = {
        "dtype": dtype_by_device.get(device, torch.float32),
        "trust_remote_code": True,
    }
    if device == "cuda":
        # device_map is only used on CUDA; MPS/CPU models are moved explicitly
        # after PEFT wrapping.
        load_kwargs["device_map"] = "auto"
    model = AutoModelForCausalLM.from_pretrained(base_model_path, **load_kwargs)

    model.config.use_cache = False  # required for training
    return model, tokenizer
# ---------------------------------------------------------------------------
# LoRA configuration
# ---------------------------------------------------------------------------
def _build_lora_config(
    r: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.05,
    target_modules: Optional[list[str]] = None,
) -> LoraConfig:
    """
    Assemble the PEFT LoRA configuration for causal-LM fine-tuning.

    ``target_modules`` falls back to ``QWEN_TARGET_MODULES`` when it is None or
    an empty list. No bias parameters are trained (``bias="none"``) and the
    config is built in training mode (``inference_mode=False``).
    """
    modules = target_modules if target_modules else QWEN_TARGET_MODULES
    return LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        target_modules=modules,
        r=r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        bias="none",
    )
# ---------------------------------------------------------------------------
# Main training entry point
# ---------------------------------------------------------------------------
def run_lora_training(
    base_model_path: str,
    train_examples: list[dict],
    val_examples: list[dict],
    output_dir: str,
    task_type: Optional[str] = None,
    lora_r: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.05,
    max_seq_length: int = 2048,
    num_epochs: int = 3,
    batch_size: int = 1,
    gradient_accumulation_steps: int = 8,
    learning_rate: float = 2e-4,
    warmup_ratio: float = 0.1,
) -> dict:
    """
    Full LoRA fine-tuning run using SFTTrainer.

    Args:
        base_model_path: Path/id of the base model passed to ``from_pretrained``.
        train_examples: learning_corpus rows; at least 10 rows are required and
            at least one must survive ChatML formatting.
        val_examples: Optional validation rows. When empty, or when no row
            survives formatting, evaluation and best-model tracking are disabled.
        output_dir: Checkpoint directory (created if missing); the adapter is
            written to ``<output_dir>/adapter``.
        task_type: Free-form label; only logged and echoed in the result.
        lora_r, lora_alpha, lora_dropout: LoRA hyper-parameters.
        max_seq_length: Passed to SFTConfig as ``max_length``.
        num_epochs, batch_size, gradient_accumulation_steps, learning_rate,
        warmup_ratio: Trainer hyper-parameters.

    Returns:
        A metrics dict:
        {
            "train_loss": float,
            "eval_loss": float,      # -1.0 when no evaluation ran
            "train_runtime": float,
            "train_samples": int,
            "val_samples": int,      # 0 when no evaluation ran
            "adapter_path": str,
            "device": str,
            "task_type": Optional[str],
            "epochs": int,
        }

    Raises:
        ValueError: fewer than 10 training rows, or none valid after formatting.
        Errors from model loading / training propagate unchanged so the
        orchestrator can record failure status.
    """
    device = _select_device()
    logger.info("run_lora_training: device=%s task_type=%s output_dir=%s", device, task_type, output_dir)
    if len(train_examples) < 10:
        raise ValueError(
            f"Insufficient training data: need >= 10 examples, got {len(train_examples)}"
        )

    # Prepare datasets
    train_dataset = prepare_dataset(train_examples)
    eval_dataset = prepare_dataset(val_examples) if val_examples else None
    if eval_dataset is not None and len(eval_dataset) == 0:
        # Every validation row was dropped during formatting: give the trainer
        # None instead of an empty Dataset so it does not try to process it.
        eval_dataset = None
    has_eval = eval_dataset is not None
    if len(train_dataset) == 0:
        raise ValueError("All training examples were invalid — dataset is empty after formatting")

    # Load model
    model, tokenizer = _load_model_and_tokenizer(base_model_path, device)

    # Apply LoRA
    lora_config = _build_lora_config(
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    # Move to device AFTER PEFT wrapping (MPS requirement); CUDA placement was
    # already done by device_map="auto" at load time.
    if device in ("mps", "cpu"):
        model = model.to(device)

    # Training arguments
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    # Step-based evaluation and best-model tracking need a validation set.
    eval_strategy = "steps" if has_eval else "no"
    training_args = SFTConfig(
        output_dir=str(output_path),
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        warmup_ratio=warmup_ratio,
        eval_strategy=eval_strategy,
        eval_steps=50 if has_eval else None,
        save_strategy="steps",
        save_steps=100,  # a multiple of eval_steps, as load_best_model_at_end needs
        load_best_model_at_end=has_eval,
        metric_for_best_model="eval_loss" if has_eval else None,
        greater_is_better=False,
        logging_steps=10,
        report_to="none",  # no WandB / HF Hub logging
        dataloader_num_workers=0,  # MPS requires 0 (no multiprocessing)
        # Mixed precision is off on every device (MPS supports neither mode).
        fp16=False,
        bf16=False,
        optim="adamw_torch",  # paged_adamw_8bit requires bitsandbytes (CUDA only)
        gradient_checkpointing=True,  # non-reentrant mode; torch 2.x+ supports it on MPS
        gradient_checkpointing_kwargs={"use_reentrant": False},
        remove_unused_columns=False,
        label_names=["labels"],
        dataset_text_field="text",
        max_length=max_seq_length,
        packing=False,  # packing can cause issues with MPS
    )

    # Trainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
    )
    val_count = len(eval_dataset) if has_eval else 0
    logger.info(
        "Starting SFT training: %d train examples, %d val examples, %d epochs",
        len(train_dataset),
        val_count,
        num_epochs,
    )
    train_result = trainer.train()

    # Evaluate if possible
    eval_metrics: dict = {}
    if has_eval:
        eval_metrics = trainer.evaluate()
        logger.info("Eval metrics: %s", eval_metrics)

    # Save adapter (LoRA weights only — not the full model). Save from the model
    # object the trainer holds, i.e. the weights it finished with.
    adapter_path = str(output_path / "adapter")
    trainer.model.save_pretrained(adapter_path)
    tokenizer.save_pretrained(adapter_path)
    logger.info("Saved LoRA adapter to %s", adapter_path)

    return {
        "train_loss": round(train_result.training_loss, 4),
        "eval_loss": round(eval_metrics.get("eval_loss", -1.0), 4),
        "train_runtime": round(train_result.metrics.get("train_runtime", 0.0), 1),
        "train_samples": len(train_dataset),
        "val_samples": val_count,
        "adapter_path": adapter_path,
        "device": device,
        "task_type": task_type,
        "epochs": num_epochs,
    }