Source code for trinity.utils.lora_utils

import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoConfig, AutoModelForCausalLM


[docs] def create_dummy_lora( model_path: str, checkpoint_job_dir: str, lora_rank: int, lora_alpha: int, target_modules: str, ) -> str: config = AutoConfig.from_pretrained(model_path) model = AutoModelForCausalLM.from_config(config) lora_config = { "task_type": TaskType.CAUSAL_LM, "r": lora_rank, "lora_alpha": lora_alpha, "target_modules": target_modules, "bias": "none", } peft_model = get_peft_model(model, LoraConfig(**lora_config)) peft_model.save_pretrained(f"{checkpoint_job_dir}/dummy_lora") del model, peft_model torch.cuda.empty_cache() return f"{checkpoint_job_dir}/dummy_lora"