""" trainer.py - LoRA / SFT fine-tuning using PEFT + TRL. Supports Apple Silicon MPS (primary) with automatic CPU fallback. Trains a LoRA adapter on top of Qwen2.5-Instruct using ChatML format, then returns training metrics for the orchestrator to evaluate and record. MPS notes (torch 2.x): - device_map is NOT supported with MPS; load the full model and call model.to("mps") explicitly after PEFT wrapping. - gradient_checkpointing is incompatible with MPS; leave disabled. - use_cache must be False during training to avoid shape conflicts. """ from __future__ import annotations import logging import os from pathlib import Path from typing import Optional import torch from datasets import Dataset from peft import LoraConfig, TaskType, get_peft_model from transformers import ( AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, ) from trl import SFTConfig, SFTTrainer logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- CHATML_TEMPLATE = ( "<|im_start|>system\n{system}<|im_end|>\n" "<|im_start|>user\n{user}<|im_end|>\n" "<|im_start|>assistant\n{assistant}<|im_end|>" ) QWEN_TARGET_MODULES = [ "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", ] # --------------------------------------------------------------------------- # Dataset preparation # --------------------------------------------------------------------------- def prepare_dataset(examples: list[dict]) -> Dataset: """ Convert learning_corpus rows to ChatML-formatted text examples. Each example dict must have: system_prompt, input_text, output_text. Rows with missing/empty fields are silently skipped. """ formatted: list[dict] = [] skipped = 0 for ex in examples: system = (ex.get("system_prompt") or "").strip() user = (ex.get("input_text") or "").strip() assistant = (ex.get("output_text") or "").strip() if not user or not assistant: skipped += 1 continue if not system: system = "You are a helpful assistant." text = CHATML_TEMPLATE.format(system=system, user=user, assistant=assistant) formatted.append({"text": text}) if skipped: logger.warning("prepare_dataset: skipped %d rows with missing fields", skipped) logger.info("prepare_dataset: %d examples formatted", len(formatted)) return Dataset.from_list(formatted) # --------------------------------------------------------------------------- # Device selection # --------------------------------------------------------------------------- def _select_device() -> str: """Return 'mps', 'cuda', or 'cpu' depending on availability.""" if torch.backends.mps.is_available() and torch.backends.mps.is_built(): return "mps" if torch.cuda.is_available(): return "cuda" return "cpu" def _load_model_and_tokenizer( base_model_path: str, device: str, ) -> tuple: """ Load tokenizer and base model for LoRA training. MPS: load in float32 (bfloat16/float16 not fully supported on MPS). CPU: float32. CUDA: bfloat16 with optional device_map="auto". """ logger.info("Loading tokenizer from %s", base_model_path) tokenizer = AutoTokenizer.from_pretrained( base_model_path, trust_remote_code=True, padding_side="right", # required for SFT with left-pad models ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token logger.info("Set pad_token = eos_token (%s)", tokenizer.eos_token) logger.info("Loading base model from %s on device=%s", base_model_path, device) if device == "cuda": model = AutoModelForCausalLM.from_pretrained( base_model_path, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True, ) elif device == "mps": # MPS: float16 halves memory vs float32; model moved to MPS after PEFT wrapping model = AutoModelForCausalLM.from_pretrained( base_model_path, dtype=torch.float16, trust_remote_code=True, ) else: # CPU: float32 model = AutoModelForCausalLM.from_pretrained( base_model_path, dtype=torch.float32, trust_remote_code=True, ) model.config.use_cache = False # required for training return model, tokenizer # --------------------------------------------------------------------------- # LoRA configuration # --------------------------------------------------------------------------- def _build_lora_config( r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, target_modules: Optional[list[str]] = None, ) -> LoraConfig: return LoraConfig( r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, bias="none", task_type=TaskType.CAUSAL_LM, target_modules=target_modules or QWEN_TARGET_MODULES, inference_mode=False, ) # --------------------------------------------------------------------------- # Main training entry point # --------------------------------------------------------------------------- def run_lora_training( base_model_path: str, train_examples: list[dict], val_examples: list[dict], output_dir: str, task_type: Optional[str] = None, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, max_seq_length: int = 2048, num_epochs: int = 3, batch_size: int = 1, gradient_accumulation_steps: int = 8, learning_rate: float = 2e-4, warmup_ratio: float = 0.1, ) -> dict: """ Full LoRA fine-tuning run using SFTTrainer. Returns a metrics dict: { "train_loss": float, "eval_loss": float, "train_runtime": float, "adapter_path": str, "device": str, } Raises on fatal errors so the orchestrator can record failure status. """ device = _select_device() logger.info("run_lora_training: device=%s task_type=%s output_dir=%s", device, task_type, output_dir) if len(train_examples) < 10: raise ValueError( f"Insufficient training data: need >= 10 examples, got {len(train_examples)}" ) # Prepare datasets train_dataset = prepare_dataset(train_examples) eval_dataset = prepare_dataset(val_examples) if val_examples else None if len(train_dataset) == 0: raise ValueError("All training examples were invalid — dataset is empty after formatting") # Load model model, tokenizer = _load_model_and_tokenizer(base_model_path, device) # Apply LoRA lora_config = _build_lora_config( r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, ) model = get_peft_model(model, lora_config) model.print_trainable_parameters() # Move to device AFTER PEFT wrapping (MPS requirement) if device in ("mps", "cpu"): model = model.to(device) # Training arguments output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) # eval_strategy requires a validation set eval_strategy = "steps" if eval_dataset and len(eval_dataset) > 0 else "no" training_args = SFTConfig( output_dir=str(output_path), num_train_epochs=num_epochs, per_device_train_batch_size=batch_size, per_device_eval_batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps, learning_rate=learning_rate, warmup_ratio=warmup_ratio, eval_strategy=eval_strategy, eval_steps=50 if eval_strategy == "steps" else None, save_strategy="steps", save_steps=100, load_best_model_at_end=(eval_strategy == "steps"), metric_for_best_model="eval_loss" if eval_strategy == "steps" else None, greater_is_better=False, logging_steps=10, report_to="none", # no WandB / HF Hub logging dataloader_num_workers=0, # MPS requires 0 (no multiprocessing) fp16=False, # MPS does not support fp16 training bf16=False, # MPS does not support bf16 training optim="adamw_torch", # paged_adamw_8bit requires bitsandbytes (CUDA only) gradient_checkpointing=True, # torch 2.x+ supports MPS with use_reentrant=False gradient_checkpointing_kwargs={"use_reentrant": False}, remove_unused_columns=False, label_names=["labels"], dataset_text_field="text", max_length=max_seq_length, packing=False, # packing can cause issues with MPS ) # Trainer trainer = SFTTrainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=tokenizer, ) logger.info( "Starting SFT training: %d train examples, %d val examples, %d epochs", len(train_dataset), len(eval_dataset) if eval_dataset else 0, num_epochs, ) train_result = trainer.train() # Evaluate if possible eval_metrics: dict = {} if eval_dataset and len(eval_dataset) > 0: eval_metrics = trainer.evaluate() logger.info("Eval metrics: %s", eval_metrics) # Save adapter (LoRA weights only — not the full model) adapter_path = str(output_path / "adapter") model.save_pretrained(adapter_path) tokenizer.save_pretrained(adapter_path) logger.info("Saved LoRA adapter to %s", adapter_path) return { "train_loss": round(train_result.training_loss, 4), "eval_loss": round(eval_metrics.get("eval_loss", -1.0), 4), "train_runtime": round(train_result.metrics.get("train_runtime", 0.0), 1), "train_samples": len(train_dataset), "val_samples": len(eval_dataset) if eval_dataset else 0, "adapter_path": adapter_path, "device": device, "task_type": task_type, "epochs": num_epochs, }