본문 바로가기
Natural Language Processing/Github 훑어보기

[Code review] Sentence-Transformers 학습 Process

by beeny-ds 2022. 7. 5.

Sentence-Transformers (UKPlab)

Sentence embedding 관련 패키지 리서치 중 Sentence-Transformers Github 코드를 자세히 살펴볼 기회가 생겼다. 
해당 패키지의 구조부터 자세한 상세 기능, 그리고 BERT 모델 사용 시 HuggingFace Transformers 와의 호환성까지 살펴보려 한다.

 

Sentence-Transformers > train process

Train process

  1. DataLoader
    - InputExample class 사용하여 생성
    - 생성한 객체를 list로 감싼 뒤, PyTorch DataLoader에 인자로 넣어 사용


  2. Model definition
    - models 폴더에서 model 구조 선택
    - BERT의 경우, Transformer.py & Pooling.py 사용
    - SentenceTransformer class에서 model 구조 생성 (__init__ 메서드에서 정의)

  3. Training Function definition
    - losses 폴더에 있는 loss 함수 정의
    - loss 함수에 2번에서 정의한 model 객체를 인자로 넣어줘서 loss 객체 생성

  4. Evaluator definition
    - Evaluation 폴더에 있는 evaluate 객체 정의

  5. Model Train (.fit)
    - 2번에서 정의한 model 객체의 fit 메서드 사용해서 학습 진행
    - fit 메서드의 인자로는 3번에서 정의한 loss 객체와 4번에서 정의한 evaluate 객체을 사용
위와 같은 Process를 Diagram으로 나타내면 아래와 같다.

 

Sentence-Transformers SentenceTransformer class 학습 process

  1. 데이터 인코딩
  2. 모델 정의
  3. loss 정의
    - 이때 (2.)에서 정의한 모델 객체를 인자로 사용하여 loss 객체 생성
  4. 검증 class 정의
  5. (2.)에서 정의한 모델 객체의 fit 메서드를 이용한 학습 진행
    - 이때 (3.)에서 정의한 loss 객체를 인자로 사용
  6. (1.~5.) 반복 & final model 저장
Loss class는 nn.Module을 받아 forward 메서드가 계산되는 형식.
즉, 객체와 객체 사이를 오가며 model weight 업데이트 되는 process…

 

Example code

from sentence_transformers import InputExample, SentenceTransformer, models,losses,evaluation
from torch.utils.data import DataLoader
from torch import nn

### 1. DataLoader
train_examples = [InputExample(texts=['My first sentence', 'My second sentence'], label=0.8),
				  InputExample(texts=['Another pair', 'Unrelated sentence'], label=0.3)]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

### 2. Model definition
word_embedding_model = models.Transformer('bert-base-uncased', max_seq_length=256)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

### 3. Training Function definition
train_loss = losses.CosineSimilarityLoss(model)

### 4. Evaluator definition
evaluator = evaluation.EmbeddingSimilarityEvaluator.from_input_examples(eval_dataloader, batch_size=1, name='sts-dev')

### 5. Model Train (.fit)
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=epochs, warmup_steps=warmup_steps, evaluator=evaluator, evaluation_steps=evaluation_steps)

 


 

Comments

Sentence-Transformers 패키지 학습 Process는 처음 보는 형식이었다...
하지만 객체에 상관 없이 코드를 풀어서 보면 기존 학습 방식과 동일하다는 것을 알 수 있다.
"데이터를 모델에 넣어 계산 → 모델 output과 label 사이의 loss 계산 → 모델 업데이트"

다음 포스팅은 검증 및 예측 Process에 대해 살펴볼 예정이다.
  • 학습 Process
  • 검증 및 예측 Process
  • sentence-transformers 상세 기능
  • HuggingFace transformers와 어떻게 다른지

 

To Be Continued.....

반응형

댓글