Main scripts
🗂️TextBrewer
|-🗂️src
|-🗂️textbrewer
|-📄configurations.py
# teacher model 기반으로 Student model initial weight 값 설정
# DistillationConfig 의 input 으로 들어감
-IntermediateMatch(type:class)
|-📄configurations.py
# knowledge distillation 학습을 위한 hyperparameters setting
-DistillationConfig(type:class)
# distiller define & train script 나열(?)된 script
|-📄distillers.py
|-📄distiller_general.py
-GeneralDistiller(type:class)
|-📄distiller_multiteacher.py
-MultiTeacherDistiller(type:class)
IntermediateMatch
학습 인자 함수가 있는 scripts
# scheduler, loss_type 등 설정 함수 위치
🗂️TextBrewer
|-🗂️src
|-🗂️textbrewer
# Hyperparameter setting 위한 함수를 정리한 script
|-📄presets.py
# 다양한 loss 함수 정의, classification의 경우, kd_ce_loss 함수 사용 / STS 데이터셋과 같이 mse loss 함수가 필요한 경우, mse_loss 관련 함수 사용
# DistilBERT 논문에 있는 cosine similarity loss 함수도 사용 가능
|-📄losses.py
# knowledge distillation 학습을 위한 scheduler 설정을 위한 script
|-📄schedulers.py
# Teacher model weight 값을 사용한 Student model weight 값 설정을 위한 script
|-📄projections.py
|-📄utils.py
Main class: GeneralDistiller
💡 Teacher model의 수가 한 개일 때(single-teacher), 경량화된 Student model을 학습하기 위한 class
# GeneralDistiller class 의 parent class & 사용된 script 정리
🗂️TextBrewer
|-🗂️src
|-🗂️textbrewer
# adapter 설정을 위한 class & 원하는 device 사용(apex 사용 가능) & tensorboard 사용 등 기능이 정리된 script
|-📄distiller_utils.py
# GeneralDistiller.train(**kwargs)했을 때, train 함수는 BasicDistiller class에서 작동
# GeneralDistiller class의 parent class = BasicDistiller class
|-📄distiller_basic.py
Main class: MultiTeacherDistiller
💡 Teacher model의 수가 두 개 이상일 때(multi-teacher), 경량화된 Student model을 학습하기 위한 class
- 대신 intermediate feature matching이 지원되지 않음
- Student model’s weight = initial weight value 사용 (정규 분포 or 균일 분포)
- Teacher model은 모두 같은 task 관련 model이어야 됨
# GeneralDistiller class 의 parent class & 사용된 script 정리
🗂️TextBrewer
|-🗂️src
|-🗂️textbrewer
# adapter 설정을 위한 class & 원하는 device 사용(apex 사용 가능) & tensorboard 사용 등 기능이 정리된 script
|-📄distiller_utils.py
# GeneralDistiller.train(**kwargs)했을 때, train 함수는 BasicDistiller class에서 작동
# GeneralDistiller class의 parent class = BasicDistiller class
|-📄distiller_basic.py
그 외의 script
🗂️TextBrewer
|-🗂️src
|-🗂️textbrewer
# adapter 설정을 위한 class & 원하는 device 사용(apex 사용 가능) & tensorboard 사용 등 기능이 정리된 script
|-📄distiller_utils.py
# GeneralDistiller.train(**kwargs)했을 때, train 함수는 BasicDistiller class에서 작동
# GeneralDistiller class의 parent class = BasicDistiller class
|-📄distiller_basic.py
TextBrewer 패키지 관련 글
이전 글
2022.06.17 - [Natural Language Processing/Model Compression] - [경량화 패키지] TextBrewer란?
이후 글
2022.06.17 - [Natural Language Processing/Model Compression] - [경량화 패키지] TextBrewer 학습 Process
2022.06.17 - [Natural Language Processing/Model Compression] - [경량화 패키지] TextBrewer 학습 Process
2022.06.17 - [Natural Language Processing/Model Compression] - [경량화 패키지] TextBrewer 사용 후기
반응형
'Natural Language Processing > Model Compression' 카테고리의 다른 글
[논문 리뷰] Distilling Linguistic Context for Language Model Compression (0) | 2022.06.23 |
---|---|
[경량화 패키지] TextBrewer 사용 후기 (0) | 2022.06.17 |
[경량화 패키지] TextBrewer 학습 Process (0) | 2022.06.17 |
[경량화 패키지] TextBrewer란? (0) | 2022.06.17 |
댓글