Python/PyTorch

FSDP 쉽게 설명하기

beeny-ds 2025. 5. 11. 13:17
작년부터 모델 학습 시 FSDP 를 많이 사용해왔다.
한정된 자원 내에서 Large 모델을 학습할 때 가장 효율적인 방법이 FSDP 라고 생각하기 때문이다.

본 포스팅에서는 Large 모델을 한정된 자원 내에서 효율적으로 학습할 수 있는 Accelerator 인 FSDP 에 대해 다룬다.
모델 학습을 해본 사람이라면 누구나 이해할 수 있도록 쉽게 풀어쓰고자 한다.

※ 시작하기에 앞서 필자의 이전 포스팅을 먼저 보는걸 추천한다.

목차

1. FSDP 효율

2. [NCCL] All-Gather & Reduce-Scatter 설명

3. FSDP 연산 방식


1. FSDP 효율

FSDP(Fully Sharded Data Parallel) 는 PyTorch 네이티브 분산 학습 기술로, 모델 파라미터/그래디언트/옵티마이저 상태를 완전 분할하여 적은 자원에서 큰 모델을 학습할 수 있는 기술이다.

핵심 메커니즘은 NCCL 중 All-Gather, Reduce-Scatter 을 사용한다. 이에 대한 설명은 아래서 확인하자.

 

FSDP 의 효과를 예시로 보여주면 다음과 같다.

 

a. 상황 가정

  • 모델 VRAM 요구량: 12GB (단일 GPU 기준)
  • 학습 환경: 4x GPU (각 GPU VRAM 용량 48GB 이상 가정)
  • 옵티마이저: Adam (파라미터의 2배 메모리 사용)

 

b. DDP 사용 시: GPU 당 사용량

구성 요소 GPU 당 VRAM 사용량 설명
모델 복제본 12 GB 전체 모델 복제
그래디언트 12 GB 전체 그래디언트 저장
옵티마이저 상태 24 GB Adam (파라미터 2배)
총 VRAM 사용량 48 GB 단일 GPU 기준
  • DDP 사용 시 GPU 당 모델 관련 VRAM 만 48 GB 필요.
  • 데이터셋을 VRAM 에 Load 할 여분의 리소스가 있어야 학습 가능.
  • 즉, GPU VRAM = 48GB 면 모델 학습 불가능.

 

c. FSDP 사용 시 : GPU 당 사용량

구성 요소 GPU 당 VRAM 사용량 설명
모델 샤드 3 GB 12 GB / 4 GPU 분할
그래디언트 3 GB 샤드별 부분 저장
옵티마이저 상태 6 GB 샤드별 Adam 상태 저장
임시 All-Gather a GB 순전파/역전파 시에만 연산을 위해 사용.
a: 가장 큰 레이어의 VRAM 크기
총 VRAM 사용량 12+a GB 단일 GPU 기준
  • 4등분 샤딩으로 GPU 당 모델 관련 VRAM 만 12 GB 필요.
    • 샤딩: 모델 파라미터를 분할 관리하는 핵심 메커니즘으로 뜻 그대로 분할함을 의미.
  • 단, 순전파/역전파 시 임시적으로 All-Gather 를 위해 모델 내 가장 큰 Layer 의 VRAM 인 a GB 만큼 여분 필요.
  • 12+a + b(데이터셋 리소스)GB 만큼의 GPU VRAM 이면 모델 학습 가능.
    • 넉넉잡아도 24 GB 정도면 충분히 학습 가능.
    • 평시 VRAM 사용량을 1/N으로 줄이되, 필요 시에만 일시적 메모리 증가를 허용.

실제 사례로 70B 모델을 FSDP 사용하여 4개의 GPU 로 학습했을 때 다음과 같다.

구성 GPU VRAM 값 설명
전체 모델 크기 140 GB (FP16) 단일 GPU 기준
최대 Layer 크기 7 GB 70B 모델의 Attention Layer
피크 메모리/GPU 35 GB(샤드) + 7 GB(Layer) = 42 GB 140GB 전체가 아닌 Layer 단위 관리

 

PyTorch 공식 문서에서 7B 이상 모델 학습 시 FSDP 기술을 권장한다고 하니 본 포스팅을 끝까지 읽고 FSDP 에 대해 이해하기 바란다.

 

그럼 GPU VRAM 을 효율적으로 사용할 수 있는 FSDP 에 대한 이론을 상세히 알아보기 전, 먼저 NCCL 의 All-Gather 와 Reduce-Scatter 에 대해 알아보자.


2. [NCCL] All-Gather & Reduce-Scatter 설명

먼저 NCCL(NVIDIA Collective Communications Library) 은 다중 GPU 환경에서 고성능 통신을 제공하는 라이브러리이다.

All-Gather, Reduce-Scatter, All-Reduce 연산이 분산 학습의 핵심을 이룬다.

그 중 All-Gather 와 Reduce-Scatter 는 대규모 분산 학습에서 핵심적인 통신 연산으로, 각각 데이터 수집과 분산 축소 기능을 수행한다.

 

a. All-Gather : 전체 데이터 수집

NCCL이 지원하는 기능으로, 아 그림과 같이 각 GPU에 분산되어 있는 데이터를 각각의 GPU가 전체 데이터를 합쳐주는 것이다.

All-Gather

연산 정의

  • 입력: 각 랭크(GPU)가 N개의 데이터 보유
  • 출력: 모든 랭크가 N × 랭크 수 크기의 버퍼에 전체 데이터 저장
  • 데이터 정렬: 랭크 인덱스 순으로 연속 저장

GPU-2개 예시

랭크 입력 데이터 출력 데이터
0 [A, B] [A, B, C, D]
1 [C, D] [A, B, C, D]

 

b. Reduce-Scatter : 분산 축소 연산

아래 그림과 같이 각각의 GPU에 있는 데이터를 합치고 나눠서 분산 시키는 기능이다.

Reduce-Scatter

연산 정의

  • Reduce: 모든 랭크의 데이터를 지정 연산(합, 최대 등)으로 축소
  • Scatter: 결과를 균등 분할하여 각 랭크에 할당

GPU-2 예시 (Sum 연산)

랭크 입력 데이터 출력 데이터
0 [A0, A1, A2, A3] [A0+B0, A1+B1]
1 [B0, B1, B2, B3] [A2+B2, A3+B3]

 

 

FSDP 는 All-Gather 와 Reduce-Scatter 두 가지 NCCL 연산을 모두 사용하여 모델을 학습한다.

그렇다면 FSDP 에서 All-Gather 와 Reduce-Scatter 연산을 통해 어떻게 GPU VRAM 을 효율적으로 사용하는지 알아보자.


3. FSDP 연산 방식

FSDP 에서는 아래 그림과 같이 모델 레이어의 파라미터를 GPU 수(N)로 분할(=샤딩)한다.

출처: FSDP Paper

위 그림과 같이 수평으로 분할한다.

직관적인 이해를 위해 단일 레이어의 수평 분할 예시를 참고해보자.

# 원본 레이어 구조
layer = nn.Linear(4096, 4096)  # (4096x4096) 크기의 가중치 행렬

단일 레이어가 4096 차원으로 구성되어 있다고 하면 GPU-4개 환경에서의 수평 분할했을 때 아래와 같이 된다.

GPU 소유 파라미터 범위 실제 저장 형태
0 가중치 [0:1024, :] (1024x4096) 텐서
1 가중치 [1024:2048, :] (1024x4096) 텐서
2 가중치 [2048:3072, :] (1024x4096) 텐서
3 가중치 [3072:4096, :] (1024x4096) 텐서

 

필자는 처음 FSDP 를 공부할 때 이런 의문이 들었다.

  • Q: 이렇게 레이어를 수평 분할하면 연산을 어떻게 해??

이에 대한 답이 위에서 설명한 All-Gather 와 Reduce-Scatter 연산이다.

모델 학습 시 각 단계별로 확인해보자.

 

a. 순전파 단계

# All-Gather로 전체 파라미터 일시적 복원
full_weight = all_gather(sharded_weights)  # (4096x4096) 완전체
output = input @ full_weight  # 완전한 행렬 연산 수행
del full_weight  # 즉시 메모리 해제
  • All-Gather 통해 필요한 레이어의 파라미터만 일시적 복원.
  • 복원된 레이어를 사용하여 연산 수행.
  • 메모리 효율성을 위해 복원된 파라미터를 즉시 해제.

 

b. 역전파 단계

# All-Gather로 파라미터 재복원 → 그래디언트 계산
# Reduce-Scatter로 분산 처리
local_grad = reduce_scatter(global_grad)  # 각 GPU당 (1024x4096) 그래디언트
  • All-Gather 통해 필요한 레이어의 파라미터 다시 일시적 복원.
  • 복원된 레이어를 사용하여 그래디언트 계산.
  • Reduce-Scatter 통해 각 GPU 는 자신의 샤드에 해당하는 그래디언트만 보유.

 

c. 옵티마이저 단계

optimizer.step()  # 해당 GPU의 샤드에 해당하는 weight 갱신
  • 각 GPU는 로컬 샤드에 대해서만 파라미터 업데이트 수행.

 

이러한 단계로 모델 학습이 이루어지기 때문에 GPU VRAM 을 절약할 수 있다.

이를 그림으로 확인하면 이해하기 더 쉬울거다.

 

ⅰ 순전파/역전파 시 All-Gather & 메모리 해제

출처: https://alnova2.tistory.com/1471

  • 4개의 GPU에 Weight가 각각 샤딩.
  • 순전파/역전파를 위해 각 Node 에 있는 Weight 를 각각의 GPU 에 모아서 계산.
  • 계산 후 원래 샤드 되어 있는 Weight 만 남기고 해제.
  • 하나의 GPU 에서 메모리 요구 = 각 GPU 에 저장되어야 하는 샤드 모델의 크기 + 각 Weight 가 합쳐졌을 때의 메모리 크기

 

순전파/역전파 시 All-Gather 연산 후 Reduce-Scatter

출처: https://alnova2.tistory.com/1471

  • Loss.backward후 Gradient가 각각 다름.
  • 이는 All-Gather 에서 합쳐진 Weight 는 동일하나 각 GPU 별로 서로 다른 Mini-Batch 가 동작하기 때문.

 

즉, FSDP 에서 VRAM 을 절약하는 방법을 요약하면 다음과 같다.

  • 지속적 샤딩: 업데이트 후에도 파라미터가 분산 상태 유지.
  • 동적 수집: 순전파/역전파 시에만 임시 복원.

마무리,,

지금까지 FSDP 를 사용하면 왜 적은 자원에서도 큰 모델을 학습할 수 있는지 그 방법에 대해 알아봤다.
필자의 개인적인 생각은 자신이 Deep Learning 전문가라면 적어도 자신이 사용하는 방법론에 대한 이론적 이해는 있어야 한다고 생각한다. 

필자는 요즘 DeepSeek 와 AI-Agent 를 위한 MCP 에 대한 공부 필요를 느끼고 있다.

그래서 다음 포스팅에서는 해당 주제를 공부한 뒤 포스팅하도록 하겠다.