필자는 요즘 고전적인 Deep Learning 모델을 개발하기 위해 PyTorch-Lightning 을 활용한 모델 학습 및 평가 모듈을 개발하고 있다.
개발을 하며 가장 최신 버전으로 공부중인 PyTorch-Lightning 의 기능에 대해 포스팅하고자 한다.
PyTorch-Lightning 기능에 대한 포스팅은 대략 3개 정도 올릴 예정이다.
본 포스팅에서는 LightningDataModule 에 대해 다뤄보도록 하겠다.
※ 시작하기에 앞서 필자의 이전 포스팅인 [PyTorch-Lightning: v2.5.1] LightningModule class 파악 을 먼저 보기를 추천한다.
목차
1. LightningDataModule 이란
2. Method 역할 및 호출 시점
3. Train.fit() 사용 시 인자로 활용하는 방법
1. LightningDataModule 이란
PyTorch-Lightning 프레임워크의 주요 Class 로서 데이터 준비, 분할, 변환, 로딩 등 데이터 파이프라인의 모든 과정을 표준화된 방식으로 관리하는 역할을 한다.
LightningDataModule 은 코드의 재사용성, 실험의 일관성, 유지보수성, 데이터셋 교체의 유연성을 크게 높여주는 특징이 있다.
특징에 대해서 조금 더 상세하게 알아보자.
- 데이터 처리 캡슐화: 이터 다운로드, 전처리, 분할, 변환, DataLoader 생성까지 모든 과정을 하나의 클래스에 통합
- 코드 재사용성: 동일한 DataModule을 여러 프로젝트나 모델에서 재사용 가능
- 일관성 및 재현성: 데이터 분할, 변환, 준비 과정이 표준화되어 실험의 재현성이 높아짐
- 유지보수 및 확장성: 데이터 처리 로직이 한 곳에 모여 있어 유지보수가 쉽고, 새로운 데이터셋 추가도 간편
- 실험 및 배포의 용이성: DataModule만 교체하면 다양한 데이터셋으로 손쉽게 실험 가능, 코드 배포도 간단
- 하이퍼파라미터 관리: **save_hyperparameters()**로 하이퍼파라미터를 쉽게 저장/관리할 수 있음
아래 간단 예시를 참고하기 바란다.
from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader, Dataset
class CustomDataModule(LightningDataModule):
def __init__(self, data_dir, batch_size=32):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
def prepare_data(self):
# 데이터 다운로드 또는 전처리
pass
def setup(self, stage=None):
# 데이터셋 분할 및 할당
self.train_dataset = ...
self.val_dataset = ...
self.test_dataset = ...
def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=self.batch_size)
def val_dataloader(self):
return DataLoader(self.val_dataset, batch_size=self.batch_size)
def test_dataloader(self):
return DataLoader(self.test_dataset, batch_size=self.batch_size)
Lightning 프레임워크의 가장 중요한 특징은 hook 이다. (궁금하면 여기 링크를 클릭해서 확인해보라.)
링크에서 보면 알겠지만 DataHooks class 의 설명을 보면 아래의 순서로 동작한다고 한다.
.. code-block:: python
model.prepare_data()
initialize_distributed()
model.setup(stage)
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
model.predict_dataloader()
때문에 각 Method 의 역할과 호출 시점에 대해서 알아야 한다.
2. Method 역할 및 호출 시점
위에서 언급한바와 같이 Lightning 프레임워크를 의도한대로 사용하기 위해서는 각 Method 의 역할과 언제 해당 Method 가 호출되는지 명확하게 이해해야 한다.
메서드명 | 역할 및 기능 | 호출 시점 |
__init__ | 데이터 경로, 배치 크기, 변환 등 데이터 처리에 필요한 하이퍼파라미터와 속성 초기화 | DataModule 객체를 생성할 때 호출 |
prepare_data | 데이터 다운로드, 압축 해제, 토크나이즈 등 디스크에 데이터를 준비하는 작업. 멀티 GPU/TPU 환경에서도 한 번만 실행됨 | Trainer가 처음 실행될 때, 단일 프로세스/노드에서 한 번만 호출 |
setup(stage=None) | 데이터셋 로드, 분할, 변환 등. 학습/검증/테스트/예측 등 단계별로 데이터셋 준비. 모든 프로세스에서 실행됨 | 각 단계(학습, 검증, 테스트, 예측) 시작 전에, 각 프로세스에서 호출 |
train_dataloader | 학습 데이터셋을 반환하는 DataLoader 생성 | Trainer가 학습 단계에 진입할 때 호출 |
val_dataloader | 검증 데이터셋을 반환하는 DataLoader 생성 | Trainer가 검증 단계에 진입할 때 호출 |
test_dataloader | 테스트 데이터셋을 반환하는 DataLoader 생성 | Trainer가 테스트 단계에 진입할 때 호출 |
predict_dataloader | 예측 데이터셋을 반환하는 DataLoader 생성 | Trainer가 예 단계에 진입할 때 호출 |
on_before_batch_transfer | 각 배치가 디바이스로 이동되기 전에 추가적인 변환이나 증강 작업을 할 수 있음 | 각 배치가 DataLoader에서 반환되어 디바이스로 옮겨지기 직전에 호출 |
save_hyperparameters | 모델 생성 시 전달된 하이퍼파라미터를 자동으로 저장하고, 체크포인트와 연동하여 모델 재생성 및 실험 재현성을 극대화 | 직접 호출하는 시점에 실행됨 |
prepare_data Method 에서 주의할 점이 있다.
절대 self 상태를 저장하면 안 된다. (ex. self.data = ~~~)
왜냐하면 분산 학습 시 메인 프로세스에서만 실행되기 때문에 다른 프로세스에서 상태를 공유할 수 없기 때문이다.
예들 들어보자.
def prepare_data(self):
# 올바른 사용: 디스크에 데이터 저장
download_dataset(self.data_url, self.data_dir)
# 잘못된 사용: 메모리에 데이터 로드
self.raw_data = load_csv(self.data_dir) # ❌ 분산 환경에서 문제 발생
위와 같이 self.raw_data 로 상태를 저장하면 안 된다.
올바른 예시는 다음과 같다.
def prepare_data(self):
# 디스크에 데이터 저장만 수행
download_data(self.data_dir) # ✅
preprocess_and_save(self.data_dir) # ✅ (결과를 디스크에 저장)
def setup(self):
# 디스크에서 데이터 로드 + 분할
raw_data = load_from_disk(self.data_dir) # ✅
self.train_data, self.val_data = split_data(raw_data) # ✅
호출 시점이 prepare_data 가 setup 보다 우선이기 때문에 prepare_data 에서 다운로드하여 디스크에 저장한 데이터를 setup 에서 불러올 수 있다.
이 점을 주의하며 메서드를 정의해야 한다.
3. Train.fit() 사용 시 인자로 활용하는 방법
Trainer.fit() 사용하여 모델을 학습할 때 데이터셋을 LightningDataModule 클래스로 전처리 및 정의한 경우, datamodule 인자에 넣어야 된다.
datamodule 인자에 LightningDataModule 로 정의한 데이터셋을 넣으면 데이터 전처리 및 로딩이 자동으로 처리되기 때문이다.
예를 들어보겠다.
from lightning.pytorch import Trainer
trainer = Trainer()
model = MyLightningModule()
datamodule = MyLightningDataModule()
trainer.fit(model, datamodule=datamodule)
위 예시와 같이 datamodule 인자에 LightningDataModule 인스턴스를 전달하면, Trainer는 내부적으로 해당 DataModule의 setup, train_dataloader, val_dataloader 등 메서드를 자동으로 호출하여 데이터를 처리하고 모델에 공급한다.
train_dataloaders나 val_dataloaders 인자에 직접 DataLoader를 넣는 방식과는 다르며, 둘을 동시에 사용할 수 없다.
둘 중 하나만 사용해야 하며, DataModule을 사용하는 것이 더 권장되는 방식이다.
이러한 특징과 장점은 다음과 같다.
- 데이터 준비, 전처리, 분할, 로딩을 DataModule에 모듈화하여 코드가 간결해지고 유지보수가 쉬워짐
- 실험 환경이 바뀌거나 데이터셋이 변경되어도 DataModule만 교체하면 동일한 모델 코드로 다양한 실험이 가능
- Trainer가 DataModule의 메서드를 자동으로 호출하므로 fit 메서드에서 별도의 DataLoader를 직접 넘길 필요가 없음
- DataModule을 사용하면 CPU, GPU, TPU, 16-bit precision 등 다양한 환경에 자동 대응할 수 있
때문에 Trainer class 의 fit 을 통해 모델을 학습 및 평가하고자 한다면 datamodule 인자에 LightningDataModule 인스턴스를 전달하는 방식으로 코드를 구성하는걸 추천한다.
마무리,,
지금까지 PyTorch-Lightning 프레임워크의 핵심 객체인 LightningModule 과 LightningDataModule 에 대해 알아봤다.
다음 포스팅으로는 Lightning 프레임워크를 사용해서 어떻게 코드 개발을 하면 좋은지, 팀원과 협업할 때의 팁은 무엇인지 소개하도록 하겠다.
다음 포스팅이 진짜 실무자에게 필요한 내용이니 꼭 다음 포스팅 내용을 참고하기 바란다.
'Python > 패키지 훓어보기' 카테고리의 다른 글
[PyTorch-Lightning: v2.5.1] 모델 학습, 검증, 추론 프레임워크 만드는 Tips (0) | 2025.04.23 |
---|---|
[PyTorch-Lightning: v2.5.1] LightningModule class 파악 (0) | 2025.04.17 |
[PyTorch] nn.Transformer 모델 구조 상세 확인 (0) | 2025.03.12 |
[LLaMA-Factory] LoRA Adapter 확인 (0) | 2025.02.27 |
[LLaMA-Factory] Tokenizer padding_side 확인 (0) | 2025.02.22 |
댓글