본문 바로가기
Natural Language Processing/NLG 이모저모

LoRA 학습 코드 예시

by beeny-ds 2024. 9. 2.
LLM 에 대한 연구를 하는 사람이라면 누구나 LoRA 를 들어봤을거라 생각한다.
이번 포스팅은 LoRA 및 qLoRA 학습 코드 예시를 step by step 으로 설명하고자 한다.

※ sLLM Instruct tuning 에 관심이 깊은 사람에게 도움이 되는 글임을 유의하길 바란다.

 

목차

1. Model define

2. LoRA config define

3. Train datasets define

4. Arguments setting and Train

5. 마무리,,


 

1. Model define

LoRA 또는 qLoRA 학습을 위해서는 Model 과 Tokenizer 를 불러와야 한다.

본 예시에서는 beomi 님이 배포하신 Llama-3-Ko 모델을 foundation 으로 사용해보았다.

## load modules
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import bitsandbytes

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model

## 만약 qLoRA 를 사용하고 싶다면 아래 설정 적용
### 앞으로 if use_qlora: 가 있으면 qlora 설정으로 간주
use_qlora = True
if use_qlora:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
else:
	bnb_config = None

## load model & tokenizer
model_id = "beomi/Llama-3-Open-Ko-8B"

base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map='auto',
    trust_remote_code=True,
    use_auth_token=True,
    torch_dtype=torch.bfloat16,
    attn_implementation='flash_attention_2',    # cuda 11.6 이상의 버전에서 사용 가능
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

 

  • qlora 사용을 위해서는 transformers library 에 있는 객체인 BitsAndBytesConfig 을 설정해야 한다.
    본 예시는 nf4 로 모델의 weight 를 양자화하였다.
  • flash_attention 은 cuda 11.6 이상의 버전에서 활용 가능하다.
    만약 본인이 사용하는 cuda version 이 11.6 미만이면 attn_implementation="eager" 으로 해주자.
  • 모델의 weight 는 bf16 으로 불러와서 사용했다.
    fp32 는 gpu memory 과부화가 심하기 때문이다.

 

2. LoRA config define

peft library 에서 LoraConfig 객체를 정의해주는 것만으로 효율적인 학습이 가능하다.

이때 Model 의 어떤 Layer 에 Lora adapter 를 적용할지 정의해야 하는데 필자는 모든 layer 에 adapter 를 적용해줬다.

# LoRA adapter 추가할 Layer 정의 > 본 포스팅에서는 all layers 에 추가
def find_all_linear_names(model, load_in_4bit):
    '''
    lora target module 반환
    '''
    cls = bitsandbytes.nn.Linear4bit if load_in_4bit else (
        bitsandbytes.nn.Linear8bitLt if load_in_4bit else torch.nn.Linear)
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:
        lora_module_names.remove('lm_head')

    return list(lora_module_names)

# LoRA 사용을 위한 Config 정의
peft_config = LoraConfig(
    r=4,    # LoRA dim 정의
    lora_alpha=16,    # LoRA impact 설정 (rank 보다 클수록 학습 데이터를 더 많이 반영)
    lora_dropout=0.1,
    inference_mode=False,
    target_modules=find_all_linear_names(base_model,load_in_4bit)+["embed_tokens", "lm_head"],    # 적용 Layers 설정
    bias="none",
    task_type="CAUSAL_LM",
)

use_qlora = True
if use_qlora:    
    base_model = prepare_model_for_kbit_training(base_model, use_gradient_checkpointing=True)
    base_model = get_peft_model(base_model, peft_config)

 

  • find_all_linear_names 함수는 haotian-liu/LLaVA github 을 참고했다.
  • LoraConfig 설정 시 주의해야 할 arguments 는 r, lora_alpha, target_modules 이다.
    • r: LoRA adapter 의 rank 를 의미한다. (dim 이라 생각하면 됨)
    • lora_alpha: LoRA adapter 를 model weight 에 어느 정도로 많이 반영할지 이다.
      웬만하면 lora_alpha / r ≥ 2 가 되도록 설정하자.
    • target_modules: LoRA adapter 를 어떤 layer 에 적용할지이다.
      필자는 모든 layer 에 적용했다.
  • qlora 사용하여 모델을 학습하고자 한다면 꼭 prepare_model_for_kbit_training 으로 model 을 감싸주자.
    • gpu 에 load 되는 model 의 weight 는 양자화된 상태이지만 학습 중 연산할 때는 bf16 으로 변환된다.
      학습할 때에만 4bit 양자화된 weight 가 bf16 으로 변환되는걸 반복하기 때문에 모델에 의한 gpu memory 의 변화는 없다.

 

3. Train datasets define

모델마다 tokenizer.apply_chat_template 이 다르다. 심지어 어떤 모델은 chat_template 이 없는 것도 있다.

huggingface 에서 tokenizer_config.json 의 chat_template 을 참고하여 formatting_prompts_func 을 정의해주자.

모델을 학습할 text 를 어떻게 구성할지를 나타내주기 때문에 해당 함수는 Prompt 를 어떻게 구성할지를 의미하기도 한다.

## trl 사용하여 학습 데이터 생성을 위한 함수 정의
def formatting_prompts_func(example):
    '''
    text 정의
    '''
    output_texts = []
    for i in range(len(example['context'])):
    	user_prompt = ""    # 학습 데이터에 맞는 prompt 적용 (context | question)
        assistant_prompt = ""    # question 에 대한 answer 적용
        messages = [
            {"role": "system", "content": '당신은 인공지능 어시스턴트입니다. 묻는 말에 친절하고 정확하게 답변하세요.'},
            {"role": "user", "content": user_prompt},
            {"role": "assistant", "content": assistant_prompt}
        ]
        ## tokenizer chat_template 형식대로 text 추출
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )
        output_texts.append(text)
    return output_texts
    
## label 생성을 위한 instruction_template & response_template 정의
collator = DataCollatorForCompletionOnlyLM(instruction_template="<|begin_of_text|><|start_header_id|>system<|end_header_id|>", 
                                           response_template="<|start_header_id|>assistant<|end_header_id|>",
                                           tokenizer=tokenizer)
  • 본 예시에 사용한 foundation 모델의 chat_template 을 참고했을 때
    시작은 항상 "<|begin_of_text|><|start_header_id|>system<|end_header_id|>" 이고,
    답변 전까지는 항상 "<|start_header_id|>assistant<|end_header_id|>" 이다.
    • Label 은 response_template 이 나온 시점 이후부터 생성된다. 그 이전까지는 -100 으로 셋팅되어 학습하지 않는다.
    • 즉, 학습은 질문에 대한 답변만 학습한다.
  • 만약 foundation 모델에 chat_template 이 없으면 아래와 같이 chat_template 을 만들어주자.
    • chat_template 은 jinja template 을 사용하여 쉽게 만들어줄 수 있다.
jinja_template = ```{%- if messages[0]['role'] == 'system' %}
    {%- set loop_messages = messages[1:] %}
    {%- set system_message = messages[0]['content'] %}
{%- else %}
    {%- set loop_messages = messages %}
    {%- set system_message = '당신은 인공지능 어시스턴트입니다. 묻는 말에 친절하고 정확하게 답변하세요.' %}
{%- endif %}
{%- if not add_generation_prompt is defined %}
    {%- set add_generation_prompt = false %}
{%- endif %}
{%- for message in loop_messages %}
    {%- if loop.index0 == 0 %}
        {{- special_token_1 + 'system' + 'special_token_2\n' + system_message + '</s>\n\n'}}
    {%- endif %}
    {{- special_token_1 + message['role'] + 'special_token_2\n' + message['content'] + '</s>\n\n'}}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- special_token_1 + 'assistant' + 'special_token_2\n' }}
{%- endif %}```

tokenizer.chat_template = jinja_template

Jinja_template 을 사용하여 tokenizer 에 chat_template 을 적용하면 apply_chat_template 으로 정의한 template 을 자동으로 생성해준다. foundation 모델에 chat_template 이 없다면 직접 구성해준 뒤 학습시켜주자.


 

4. Arguments setting and Train

SFTTrainer 를 사용하여 간편하게 LoRA 를 활용한 학습을 진행해준다.

## arguments 정의
training_args = TrainingArguments(
    output_dir="./outputs",
    num_train_epochs=5,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    warmup_ratio=0.01,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=30,
    do_train=True,
    do_eval=False,
    fp16=False,
    bf16=True,
    save_strategy = 'epoch',
    seed=42,
)

## 학습을 위한 SFTTrainer 객체 적의 정의
trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,    # 학습 데이터는 Dataset type
    formatting_func=formatting_prompts_func,
    data_collator=collator,
    peft_config=peft_config,
    packing=False,
    max_seq_length=4096 ,
    tokenizer=tokenizer,
    args=training_args,
)

print(trainer.model.print_trainable_parameters())    # 전체, 학습 가능 파라미터 수 확인

trainer.train()    # 학습 시작

## 학습 완료된 모델 저장
trainer.model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
  • trainer.model.print_trainable_parameters() 를 사용하면 전체 파라미터 수와 학습할 파라미터 수(total LoRA parameters)를 확인할 수 있다.
  • SFTTrainer 는 transformers 의 Trainer 객체를 상속한다.
    상세한 코드 분석은 추후 포스팅으로 올리도록 하겠다.

LoRA 활용한 Llama-3-8B 학습 시 gpu vram

llama-3 를 lora 활용하여 학습하면 약 40~45GiB 의 gpu vram 이 필요하다.

하지만 qlora 를 활용한다면 4~6 GiB 의 gpu vram 만으로 학습할 수 있다.

하지만 웬만하면 qlora 를 활용한 학습은 추천하지 않는다. 그 이유는 lora 대비 모델의 tuning 영향이 적기 때문이다.

그렇기 때문에 성능향상이 잘 안 된다.

 

피치 못할 사정이라면 qlora 를 사용해야겠지만 웬만하면 lora 를 사용하길 권장한다.

 


 

마무리,,

오늘은 lora 학습 코드 예시에 대해 살펴보았다.

huggingface 에서 제공하는 많은 객체와 함수를 사용하면 쉽게 lora 학습을 할 수 있다.

또한 device 를 효율적으로 사용하는 것도 쉽게 가능하다.

 

다음 포스팅으로는 오늘 소개한 huggingface 코드를 상세히 훑어보도록 하겠다.

반응형

댓글