"""
FunctionGemma SFT fine-tuning script.

Runs TRL SFTTrainer for FunctionGemma with two modes:
1) LoRA (recommended): faster, lower memory, less overfit
2) Full-parameter: higher cost, maximal capacity

Usage:
    # LoRA (default)
    python -m src.train \
        --model_path /path/to/model \
        --dataset_path ./data/training_data.json \
        --bf16

    # Full-parameter
    python -m src.train \
        --model_path /path/to/model \
        --dataset_path ./data/training_data.json \
        --no-use-lora \
        --bf16
"""
| |
|
import argparse
import json
import logging
import os
import zlib
from datetime import datetime
from pathlib import Path
from typing import Optional

import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    BitsAndBytesConfig,
)
from trl import SFTTrainer, SFTConfig
| |
|
| | |
| | PROJECT_ROOT = Path(__file__).resolve().parent.parent |
| | DEFAULT_DATA_PATH = PROJECT_ROOT / "data" / "training_data.json" |
| | DEFAULT_OUTPUT_DIR = PROJECT_ROOT / "runs" |
| |
|
| | logging.basicConfig( |
| | level=logging.INFO, |
| | format='%(asctime)s - %(levelname)s - %(message)s' |
| | ) |
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
def parse_args(argv: Optional[list] = None):
    """Parse CLI arguments.

    Args:
        argv: Optional explicit argument list. Defaults to None, in which case
              argparse falls back to ``sys.argv[1:]`` — identical to the old
              behavior, but an explicit list makes the parser unit-testable.

    Returns:
        argparse.Namespace with all training options.
    """
    parser = argparse.ArgumentParser(description="FunctionGemma SFT fine-tuning (LoRA / full)")

    # --- Model / tokenizer ---
    parser.add_argument(
        "--model_path",
        type=str,
        default="google/functiongemma-270m-it",
        help="Model path or HF model id"
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default=None,
        help="Tokenizer path (defaults to model_path)"
    )

    # --- Data ---
    parser.add_argument(
        "--dataset_path",
        type=str,
        default=str(DEFAULT_DATA_PATH),
        help="Training dataset path"
    )
    parser.add_argument(
        "--val_split",
        type=float,
        default=0.1,
        help="Validation split ratio"
    )

    # --- Output ---
    parser.add_argument(
        "--output_dir",
        type=str,
        default=str(DEFAULT_OUTPUT_DIR),
        help="Root output directory"
    )
    parser.add_argument(
        "--run_name",
        type=str,
        default=None,
        help="Run name for logging and saving"
    )

    # --- Mode: LoRA on by default, --no-use-lora flips to full fine-tune ---
    parser.add_argument(
        "--use_lora",
        action="store_true",
        default=True,
        help="Enable LoRA (recommended). Add --no-use-lora for full-parameter finetune"
    )
    parser.add_argument("--no-use-lora", dest="use_lora", action="store_false", help="Disable LoRA, run full-parameter finetune")

    # --- LoRA hyperparameters ---
    parser.add_argument("--lora_r", type=int, default=16, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")
    parser.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout")
    parser.add_argument(
        "--target_modules",
        type=str,
        nargs="+",
        default=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        help="Target modules for LoRA"
    )

    # --- Training schedule ---
    parser.add_argument("--num_train_epochs", type=int, default=6, help="Training epochs (official rec: 8)")
    parser.add_argument("--max_steps", type=int, default=-1, help="Max training steps (-1 to use epochs)")
    parser.add_argument("--per_device_train_batch_size", type=int, default=4, help="Train batch size per device")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=2, help="Eval batch size")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Grad accumulation steps")
    parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay")
    parser.add_argument("--warmup_ratio", type=float, default=0.0, help="Warmup ratio (constant scheduler usually skips warmup)")
    parser.add_argument("--max_seq_length", type=int, default=2048, help="Max sequence length (model supports up to 32768)")
    parser.add_argument("--lr_scheduler_type", type=str, default="constant", help="LR scheduler type (default constant)")

    # --- Precision / memory ---
    parser.add_argument("--bf16", action="store_true", help="Use BF16")
    parser.add_argument("--fp16", action="store_true", help="Use FP16")
    parser.add_argument("--use_4bit", action="store_true", help="Enable 4-bit quant (QLoRA)")
    parser.add_argument("--use_8bit", action="store_true", help="Enable 8-bit quant")
    parser.add_argument("--use_flash_attention", action="store_true", help="Enable Flash Attention 2")
    parser.add_argument("--gradient_checkpointing", action="store_true", help="Enable gradient checkpointing")

    # --- Logging / checkpointing ---
    parser.add_argument("--logging_steps", type=int, default=10, help="Log every N steps")
    parser.add_argument("--save_steps", type=int, default=100, help="Save checkpoint every N steps")
    parser.add_argument("--eval_steps", type=int, default=100, help="Eval every N steps")
    parser.add_argument("--save_total_limit", type=int, default=3, help="Max checkpoints to keep")

    # --- Misc ---
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Resume from checkpoint")
    parser.add_argument("--push_to_hub", action="store_true", help="Push to Hugging Face Hub")
    parser.add_argument("--hub_model_id", type=str, default=None, help="Hub model id")

    return parser.parse_args(argv)
| |
|
| |
|
def _stringify_tool_call_args(messages, sample_idx=None):
    """Coerce every tool_call's function.arguments to a JSON string, in place.

    The chat template expects `arguments` as a serialized JSON string; raw
    dicts coming from the dataset are dumped here. When ``sample_idx`` is
    given, a non-string value is additionally logged as an error (used for
    the defensive second pass).
    """
    for msg in messages:
        if not msg.get('tool_calls'):
            continue
        for tc in msg['tool_calls']:
            fn = tc.get('function')
            if fn is not None and 'arguments' in fn and not isinstance(fn['arguments'], str):
                if sample_idx is not None:
                    logger.error(f"Sample {sample_idx} arguments not string: {type(fn['arguments'])}")
                fn['arguments'] = json.dumps(fn['arguments'])


def _assistant_message_from_expected(expected):
    """Build an assistant message from an 'expected' label.

    Returns a tool-call message when function_name + arguments are present,
    a plain-content message for a refusal-style response, or None when the
    expected format is unrecognized.
    """
    function_name = expected.get('function_name')
    arguments = expected.get('arguments')
    response = expected.get('response', '')

    if function_name is not None and arguments is not None:
        arguments_str = json.dumps(arguments) if isinstance(arguments, dict) else str(arguments)
        # Deterministic synthetic call id. The previous builtin hash() is
        # randomized per process (PYTHONHASHSEED), so ids changed every run;
        # crc32 is stable across runs and platforms.
        call_id = f"call_{zlib.crc32((function_name + arguments_str).encode('utf-8')) % 1000000}"
        return {
            "role": "assistant",
            "content": None,
            "tool_calls": [{
                "id": call_id,
                "type": "function",
                "function": {
                    "name": function_name,
                    "arguments": arguments_str,
                },
            }],
        }
    if function_name is None and arguments is None and response:
        return {"role": "assistant", "content": response}
    return None


def load_and_prepare_dataset(dataset_path: str, val_split: float = 0.1, seed: int = 42):
    """Load and normalize dataset structure for SFT.

    Supports two record shapes:
      * {"input": {"messages": [...], "tools": [...]}, "expected": {...}, "id": ...}
        -- an assistant turn is appended from "expected" when missing.
      * {"messages": [...], ...} -- passed through with tool-call arguments
        stringified.

    Args:
        dataset_path: Path to a JSON file containing a list of samples.
        val_split: Fraction held out for evaluation (<= 0 disables the split).
        seed: Seed for the train/test split (default 42 preserves the old
              hard-coded behavior).

    Returns:
        (train_dataset, eval_dataset); eval_dataset is None when val_split <= 0.
    """
    logger.info(f"Loading dataset: {dataset_path}")

    with open(dataset_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    logger.info(f"Dataset size: {len(data)} samples")

    processed_data = []
    for idx, item in enumerate(data):
        if 'input' in item and 'messages' in item['input']:
            # JSON round-trip acts as a deep copy so source records stay untouched.
            messages = json.loads(json.dumps(item['input']['messages']))
            _stringify_tool_call_args(messages)

            # Append the assistant turn from the "expected" label when missing.
            # Guard `messages` being empty: indexing messages[-1] used to raise
            # IndexError on such samples.
            if item.get('expected') and (not messages or messages[-1]['role'] != 'assistant'):
                assistant_msg = _assistant_message_from_expected(item['expected'])
                if assistant_msg is not None:
                    messages.append(assistant_msg)
                else:
                    logger.warning(f"Unknown expected format: {item['expected']}")

            processed_item = {'messages': messages}
            if 'tools' in item['input']:
                processed_item['tools'] = item['input']['tools']
            if 'id' in item:
                processed_item['id'] = item['id']

            # Defensive re-check: any arguments still non-string are logged and fixed.
            _stringify_tool_call_args(processed_item['messages'], sample_idx=idx)
            processed_data.append(processed_item)

        elif 'messages' in item:
            messages = json.loads(json.dumps(item['messages']))
            _stringify_tool_call_args(messages)
            item_copy = dict(item)
            item_copy['messages'] = messages
            processed_data.append(item_copy)
        else:
            logger.warning(f"Skip malformed item: {item.get('id', 'unknown')}")

    logger.info(f"Processed dataset size: {len(processed_data)}")

    # Sanity stats: count messages carrying tool_calls and flag any stragglers.
    tool_calls_count = 0
    for item in processed_data:
        for msg in item['messages']:
            if msg.get('tool_calls'):
                tool_calls_count += 1
                for tc in msg['tool_calls']:
                    fn = tc.get('function')
                    if fn is not None and 'arguments' in fn and not isinstance(fn['arguments'], str):
                        logger.error(f"Found non-string arguments: {type(fn['arguments'])}")
    logger.info(f"Messages containing tool_calls: {tool_calls_count}")

    dataset = Dataset.from_list(processed_data)

    if val_split > 0:
        splits = dataset.train_test_split(test_size=val_split, seed=seed)
        train_dataset = splits['train']
        eval_dataset = splits['test']
        logger.info(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
    else:
        train_dataset = dataset
        eval_dataset = None
        logger.info(f"Train: {len(train_dataset)}, no eval split")

    return train_dataset, eval_dataset
| |
|
| |
|
def get_quantization_config(use_4bit: bool, use_8bit: bool):
    """Return a BitsAndBytesConfig for the requested precision, or None.

    4-bit (QLoRA-style NF4 with double quantization) takes precedence over
    8-bit when both flags are set; with neither flag, no config is built.
    """
    if not (use_4bit or use_8bit):
        return None

    if use_4bit:
        logger.info("Using 4-bit quantization (QLoRA)")
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

    logger.info("Using 8-bit quantization")
    return BitsAndBytesConfig(load_in_8bit=True)
| |
|
| |
|
def load_model_and_tokenizer(args):
    """Load the tokenizer and causal-LM model according to the CLI args.

    Applies, in order: tokenizer pad-token fallback, optional bitsandbytes
    quantization, dtype selection, optional Flash Attention 2, k-bit training
    preparation, and gradient checkpointing.
    """
    logger.info(f"Loading model: {args.model_path}")

    tokenizer_source = args.tokenizer_path or args.model_path
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_source,
        trust_remote_code=True,
        padding_side="right",
    )

    # Some checkpoints ship without a pad token; reuse EOS so padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    quant_cfg = get_quantization_config(args.use_4bit, args.use_8bit)
    quantized = args.use_4bit or args.use_8bit

    model_kwargs = {
        "trust_remote_code": True,
        "device_map": "auto",
    }
    if quant_cfg:
        model_kwargs["quantization_config"] = quant_cfg

    # Explicit dtype only applies without bitsandbytes quantization
    # (the quant config controls compute dtype in that case).
    if not quantized:
        if args.bf16:
            model_kwargs["torch_dtype"] = torch.bfloat16
        elif args.fp16:
            model_kwargs["torch_dtype"] = torch.float16

    if args.use_flash_attention:
        model_kwargs["attn_implementation"] = "flash_attention_2"
        logger.info("Using Flash Attention 2")

    model = AutoModelForCausalLM.from_pretrained(args.model_path, **model_kwargs)

    if quantized:
        model = prepare_model_for_kbit_training(model)

    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
        logger.info("Enabled gradient checkpointing")

    logger.info(f"Model parameters: {model.num_parameters():,}")
    return model, tokenizer
| |
|
| |
|
def get_lora_config(args):
    """Translate the CLI args into a peft LoraConfig for causal-LM training."""
    logger.info(f"LoRA config: r={args.lora_r}, alpha={args.lora_alpha}, dropout={args.lora_dropout}")
    logger.info(f"Target modules: {args.target_modules}")

    config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        bias="none",
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=args.target_modules,
    )
    return config
| |
|
| |
|
def formatting_func(example):
    """Identity pass-through for SFTTrainer.

    Each example already carries the chat-format payload expected downstream:

        {
            "messages": [
                {"role": "developer", "content": "..."},
                {"role": "user", "content": "..."},
                {"role": "assistant", "tool_calls": [...]} or {"role": "assistant", "content": "..."}
            ],
            "tools": [...]
        }

    so no reformatting is needed; the example is returned unchanged.
    NOTE(review): not referenced by main() in this file — presumably kept for
    external callers; confirm before removing.
    """
    return example
| |
|
| |
|
def main():
    """End-to-end training entry point: parse args, prepare data/model, train, save."""
    args = parse_args()

    # Default run name encodes the actual tuning mode (previously it always
    # said "lora", even for full-parameter runs) plus a timestamp.
    if args.run_name is None:
        mode_tag = "lora" if args.use_lora else "full"
        args.run_name = f"functiongemma-{mode_tag}-{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    output_dir = os.path.join(args.output_dir, args.run_name)
    os.makedirs(output_dir, exist_ok=True)

    logger.info("=" * 60)
    logger.info("FunctionGemma SFT training")
    logger.info("=" * 60)
    logger.info(f"Output dir: {output_dir}")

    # Persist the exact configuration next to the run outputs for reproducibility.
    config_path = os.path.join(output_dir, "training_config.json")
    with open(config_path, 'w') as f:
        json.dump(vars(args), f, indent=2)
    logger.info(f"Config saved to: {config_path}")

    train_dataset, eval_dataset = load_and_prepare_dataset(
        args.dataset_path,
        args.val_split
    )

    model, tokenizer = load_model_and_tokenizer(args)

    if args.use_lora:
        logger.info("=" * 60)
        logger.info("LoRA fine-tuning mode")
        logger.info("=" * 60)
        lora_config = get_lora_config(args)
    else:
        logger.info("=" * 60)
        logger.info("Full-parameter fine-tuning mode")
        logger.info("Warning: full fine-tuning needs more memory and time!")
        logger.info("=" * 60)
        lora_config = None

    training_args = SFTConfig(
        output_dir=output_dir,
        run_name=args.run_name,

        # Sequence handling
        max_length=args.max_seq_length,
        packing=False,

        # Schedule / batching
        num_train_epochs=args.num_train_epochs,
        max_steps=args.max_steps,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,

        # Optimization
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.lr_scheduler_type,
        optim="adamw_torch_fused",

        # Precision
        bf16=args.bf16,
        fp16=args.fp16,

        # Logging / evaluation / checkpoints
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps if eval_dataset else None,
        eval_strategy="steps" if eval_dataset else "no",
        save_total_limit=args.save_total_limit,
        load_best_model_at_end=True if eval_dataset else False,

        # Misc
        seed=args.seed,
        report_to=["tensorboard"],

        # Hub
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,

        # Memory
        gradient_checkpointing=args.gradient_checkpointing,
        gradient_checkpointing_kwargs={"use_reentrant": False} if args.gradient_checkpointing else None,
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        peft_config=lora_config,
    )

    # Count parameters on trainer.model: with peft_config set, SFTTrainer wraps
    # the base model in a PeftModel internally, so counting on the original
    # `model` reference would report full-model trainable params even in LoRA mode.
    trained_model = trainer.model
    trainable_params = sum(p.numel() for p in trained_model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in trained_model.parameters())
    trainable_percentage = 100 * trainable_params / total_params if total_params > 0 else 0

    logger.info("=" * 60)
    logger.info("Model parameter stats:")
    logger.info(f"  Total params: {total_params:,}")
    logger.info(f"  Trainable params: {trainable_params:,}")
    logger.info(f"  Trainable ratio: {trainable_percentage:.2f}%")
    logger.info(f"  Mode: {'LoRA' if args.use_lora else 'Full fine-tune'}")
    logger.info("=" * 60)

    logger.info("Start training...")
    # resume_from_checkpoint=None is equivalent to a fresh trainer.train().
    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

    logger.info("Saving final model...")
    final_model_path = os.path.join(output_dir, "final_model")
    trainer.save_model(final_model_path)
    tokenizer.save_pretrained(final_model_path)

    logger.info("=" * 60)
    logger.info("Training done.")
    logger.info(f"Model saved at: {final_model_path}")

    if args.use_lora:
        lora_path = os.path.join(output_dir, "lora_adapter")
        # Save via trainer.model (the PeftModel): the base `model` reference
        # would dump full weights instead of just the adapter.
        trainer.model.save_pretrained(lora_path)
        tokenizer.save_pretrained(lora_path)
        logger.info(f"LoRA adapter saved to: {lora_path}")
        logger.info("")
        logger.info("Usage:")
        logger.info(f"  1. LoRA adapter: {lora_path}")
        logger.info(f"  2. Merge adapters with your base model before inference")
    else:
        logger.info("")
        logger.info("Usage:")
        logger.info(f"  Use model directly from: {final_model_path}")

    logger.info("=" * 60)
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|