Sentence-Transformers (UKPlab)
Sentence embedding 관련 패키지 리서치 중 Sentence-Transformers Github 코드를 자세히 살펴볼 기회가 생겼다.
해당 패키지의 구조부터 자세한 상세 기능, 그리고 BERT 모델 사용 시 HuggingFace Transformers 와의 호환성까지 살펴보려 한다.
Sentence-Transformers > train process
- DataLoader
- InputExample class 사용하여 생성
- 생성한 객체를 list로 감싼 뒤, PyTorch DataLoader에 인자로 넣어 사용 - Model definition
- models 폴더에서 model 구조 선택
- BERT의 경우, Transformer.py & Pooling.py 사용
- SentenceTransformer class에서 model 구조 생성 (__init__ 메서드에서 정의) - Training Function definition
- losses 폴더에 있는 loss 함수 정의
- loss 함수에 2번에서 정의한 model 객체를 인자로 넣어줘서 loss 객체 생성 - Evaluator definition
- Evaluation 폴더에 있는 evaluate 객체 정의 - Model Train (.fit)
- 2번에서 정의한 model 객체의 fit 메서드 사용해서 학습 진행
- fit 메서드의 인자로는 3번에서 정의한 loss 객체와 4번에서 정의한 evaluate 객체을 사용
위와 같은 Process를 Diagram으로 나타내면 아래와 같다.
- 데이터 인코딩
- 모델 정의
- loss 정의
- 이때 (2.)에서 정의한 모델 객체를 인자로 사용하여 loss 객체 생성 - 검증 class 정의
- (2.)에서 정의한 모델 객체의 fit 메서드를 이용한 학습 진행
- 이때 (3.)에서 정의한 loss 객체를 인자로 사용 - (1.~5.) 반복 & final model 저장
Loss class는 nn.Module을 받아 forward 메서드가 계산되는 형식.
즉, 객체와 객체 사이를 오가며 model weight 업데이트 되는 process…
Example code
from sentence_transformers import InputExample, SentenceTransformer, models,losses,evaluation
from torch.utils.data import DataLoader
from torch import nn
### 1. DataLoader
train_examples = [InputExample(texts=['My first sentence', 'My second sentence'], label=0.8),
InputExample(texts=['Another pair', 'Unrelated sentence'], label=0.3)]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
### 2. Model definition
word_embedding_model = models.Transformer('bert-base-uncased', max_seq_length=256)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
### 3. Training Function definition
train_loss = losses.CosineSimilarityLoss(model)
### 4. Evaluator definition
evaluator = evaluation.EmbeddingSimilarityEvaluator.from_input_examples(eval_dataloader, batch_size=1, name='sts-dev')
### 5. Model Train (.fit)
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=epochs, warmup_steps=warmup_steps, evaluator=evaluator, evaluation_steps=evaluation_steps)
Comments
Sentence-Transformers 패키지 학습 Process는 처음 보는 형식이었다...
하지만 객체에 상관 없이 코드를 풀어서 보면 기존 학습 방식과 동일하다는 것을 알 수 있다.
"데이터를 모델에 넣어 계산 → 모델 output과 label 사이의 loss 계산 → 모델 업데이트"
다음 포스팅은 검증 및 예측 Process에 대해 살펴볼 예정이다.
학습 Process- 검증 및 예측 Process
- sentence-transformers 상세 기능
- HuggingFace transformers와 어떻게 다른지
To Be Continued.....
반응형
'Natural Language Processing > Github 훑어보기' 카테고리의 다른 글
[Code review] Sentence-Transformers 비교 hug/trans (0) | 2022.07.08 |
---|---|
[Code review] Sentence-Transformers 상세 기능 (0) | 2022.07.08 |
[Code review] Sentence-Transformers 검증 및 예측 Process (0) | 2022.07.06 |
[Code review] Sentence-Transformers 훑어보기: 구조 (0) | 2022.07.04 |
댓글