본문 바로가기
Natural Language Processing/Model Compression

[경량화 패키지] TextBrewer scripts info.

by beeny-ds 2022. 6. 17.

Main scripts

🗂️TextBrewer
  |-🗂️src
    |-🗂️textbrewer
      |-📄configurations.py
        # teacher model 기반으로 Student model initial weight 값 설정
        # DistillationConfig 의 input 으로 들어감
        -IntermediateMatch(type:class)

      |-📄configurations.py
        # knowledge distillation 학습을 위한 hyperparameters setting
        -DistillationConfig(type:class)

      # distiller define & train script 나열(?)된 script
      |-📄distillers.py
        |-📄distiller_general.py
            -GeneralDistiller(type:class)
        |-📄distiller_multiteacher.py
            -MultiTeacherDistiller(type:class)

IntermediateMatch

출처:https://github.com/airaria/TextBrewer/blob/master/src/textbrewer/configurations.py


학습 인자 함수가 있는 scripts

# scheduler, loss_type 등 설정 함수 위치
🗂️TextBrewer
  |-🗂️src
    |-🗂️textbrewer

      # Hyperparameter setting 위한 함수를 정리한 script
      |-📄presets.py

      # 다양한 loss 함수 정의, classification의 경우, kd_ce_loss 함수 사용 / STS 데이터셋과 같이 mse loss 함수가 필요한 경우, mse_loss 관련 함수 사용
      # DistilBERT 논문에 있는 cosine similarity loss 함수도 사용 가능
      |-📄losses.py

      # knowledge distillation 학습을 위한 scheduler 설정을 위한 script
      |-📄schedulers.py

      # Teacher model weight 값을 사용한 Student model weight 값 설정을 위한 script
      |-📄projections.py
      |-📄utils.py

Main class: GeneralDistiller

💡 Teacher model의 수가 한 개일 때(single-teacher), 경량화된 Student model을 학습하기 위한 class

# GeneralDistiller class 의 parent class & 사용된 script 정리
🗂️TextBrewer
  |-🗂️src
    |-🗂️textbrewer

      # adapter 설정을 위한 class & 원하는 device 사용(apex 사용 가능) & tensorboard 사용 등 기능이 정리된 script
      |-📄distiller_utils.py

      # GeneralDistiller.train(**kwargs)했을 때, train 함수는 BasicDistiller class에서 작동
      # GeneralDistiller class의 parent class = BasicDistiller class
      |-📄distiller_basic.py

Main class: MultiTeacherDistiller

💡 Teacher model의 수가 두 개 이상일 때(multi-teacher), 경량화된 Student model을 학습하기 위한 class

  • 대신 intermediate feature matching이 지원되지 않음
    • Student model’s weight = initial weight value 사용 (정규 분포 or 균일 분포)
  • Teacher model은 모두 같은 task 관련 model이어야 됨
# GeneralDistiller class 의 parent class & 사용된 script 정리
🗂️TextBrewer
  |-🗂️src
    |-🗂️textbrewer

      # adapter 설정을 위한 class & 원하는 device 사용(apex 사용 가능) & tensorboard 사용 등 기능이 정리된 script
      |-📄distiller_utils.py

      # GeneralDistiller.train(**kwargs)했을 때, train 함수는 BasicDistiller class에서 작동
      # GeneralDistiller class의 parent class = BasicDistiller class
      |-📄distiller_basic.py

그 외의 script

🗂️TextBrewer
  |-🗂️src
    |-🗂️textbrewer

      # adapter 설정을 위한 class & 원하는 device 사용(apex 사용 가능) & tensorboard 사용 등 기능이 정리된 script
      |-📄distiller_utils.py

      # GeneralDistiller.train(**kwargs)했을 때, train 함수는 BasicDistiller class에서 작동
      # GeneralDistiller class의 parent class = BasicDistiller class
      |-📄distiller_basic.py

 

TextBrewer 패키지 관련 글

이전 글

2022.06.17 - [Natural Language Processing/Model Compression] - [경량화 패키지] TextBrewer란?

이후 글

2022.06.17 - [Natural Language Processing/Model Compression] - [경량화 패키지] TextBrewer 학습 Process

2022.06.17 - [Natural Language Processing/Model Compression] - [경량화 패키지] TextBrewer 학습 Process

2022.06.17 - [Natural Language Processing/Model Compression] - [경량화 패키지] TextBrewer 사용 후기

 

반응형

댓글