본문 바로가기
Python/Study

[소개] ONNX 란?

by beeny-ds 2022. 7. 27.

들어가며..

 Deep learning 모델을 서빙해 본 사람이라면 ONNX를 들어봤으리라 생각한다. ONNX는 다른 DNN 프레임워크 보다 추론 속도가 빠르고 성능도 거의 똑같이 유지된다. 그 외에도 여러 장점들이 존재하기 때문에 많은 데이터 사이언티스트 또는 데이터 분석가들은 ONNX에 대해 알아두면 큰 도움이 될 것이다.

 


 

ONNX 란?

ONNX는 Open Neural Network Exchange의 줄인 말로서 이름과 같이 다른 DNN 프레임워크 환경(ex Tensorflow, PyTorch, etc..)에서 만들어진 모델들을 서로 호환되게 사용할 수 있도록 만들어진 공유 플랫폼이다.
ps. ONNX 또한 DNN 프레임워크라고 부른다.

ONNX는 다음과 같은 장점을 갖는다.

  • 장점 ① : Framework Interoperability
    • 특정 환경에서 생성된 모델을 다른 환경으로 import하여 자유롭게 사용
      ex) Tensorflow에서 모델을 학습 시킨 뒤, 모바일로 옮겨서 사용
  • 장점 ② : Shared Optimization
    • 하드웨어 설계시 ONNX representation을 기준으로 최적화를 하면 되기 때문에 효율적
      ex) JSON 포맷이 정보 표현을 위해 여러 개발자들 사이에서 합의되어 사용하듯 ONNX라는 합의된 DNN 모델 포맷이 존재한다고 생각하면 됨

장점 ②에 의해 TVM, TensorRT와 같은 Deep learning compiler로 변환할 때 DNN 프레임워크를 ONNX 형태로 변환한 뒤 DL_Compiler 형태로 변환한다.
ex) PyTorch model → ONNX model → TensorRT engine


 

Simple view of ONNX

PyTorch 모델 → ONNX로 export 하기

PyTorch 모델을 ONNX 그래프로 export하는 전체 과정을 도식화

출처: https://yunmorning.tistory.com/17

  1. PyTorch 모델 & Input을 인자로 torch.onnx.export 함수 호출
  2. PyTorch의 JIT 컴파일러인 TorchScript를 통해서 trace or script 생성
  3. Trace or script는 PyTorch의 nn.Module을 상속하는 모델의 forward 함수에서 실행되는 코드들에 대한 IR(Intermediate Representation)을 담고 있다.
    • forward propagation 시에 호출되는 함수 및 연산들에 대한 최적화된 그래프 생성
  4. 생성된 trace/script는 ONNX exporter를 통해서 ONNX IR로 변환되고 여기에서 한번 더 graph optimization이 이루어진다.
  5. 최종적으로 생성된 ONNX 그래프는 .onnx 포맷으로 저장된다.

 

Tracing vs Scripting

출처: https://yunmorning.tistory.com/17


 

Limitations

  • PyTorch의 JIT compiler가 완벽하지 않다보니 파이썬으로 구현한 모델에 대해서 완벽하게 support 하지 못한다.
    • 현재까지는 tuple, list, Variable만이 input / output으로 지원되는 상황
    • dictionary, string은 일부만 지원 (dynamic loop up 불가능)
  • PyTorch와 ONNX의 backend 구현에 차이가 있다보니 모델 구조에 따라서 학습 성능에 문제가 있을 수 있다.

 

필자의 의견

PyTorch는 데이터 전문가들이 다양한 실험을 쉽게 할 수 있도록 각종 API를 개발하고 있다. torch를 이용한 ONNX 변환도 그 중 일부다. 시간이 지남에 따라 Limitations는 줄어들 것이다.

필자가 실제로 ONNX 변환하여 실험한 결과 BERT 모델(SKT KoBERT) 기준, CPU 환경에서는 대략 6배가 빨라졌다. 실험한 Task는 Classification으로 단순한 감성 분석(NSMC)이었다. GPU 환경에서도 테스트 해봤는데 기억이 잘 안나 몇 배가 빨라진지는 pass 하도록 하겠다.

Text를 넣으면 간단한 분류가 되는 API를 만들고자 한다면, API가 동작하는 환경이 열악하다면(ex. can use only cpu) ONNX 변환을 process에 넣어 개발하는 것을 적극 추천한다. 더 나아가면 Deep learning Compiler 사용도 추천한다.

 

To Be Continued

반응형

댓글