DataParallel
- multi-thread 방식
- 코드 한줄로 구현 가능
model = nn.DataParallel(model)
방식
1. 배치 데이터를 4개의 GPU에 나눔
2. 모델을 4개의 GPU에 복사
3. 4개의 GPU에서 forward 수행
4. 결과값을 GPU-1에 다시 모아 gradient 계산
5. 계산된 gradient를 4개의 GPU에 다시 나눔
6 .4개의 GPU에서 backward 수행
단점
- output을 GPU-1에 모아 계산하기 때문에 GPU-1의 사용량이 많아집니다.
- Python은 GLI(Global Interpreter Lock)때문에 multi-thread로는 성능 향상이 어렵습니다.
- GLI : 하나의 thread에만 자원을 허락하고 Lock을 걸어 다른 thread의 실행을 막는 것
DIstributedDataParallel
- multi-process 방식
- DataParallel과 다르게 각 프로세스의 output을 한 GPU에 모아 학습시키지 않고 backward 과정에서 내부적으로 통신하는 trigger를 작동시켜 gradient를 동기화
- dataloader 사용 시 DistributedSampler 필요
코드
0. 라이브러리 및 모델
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
def setup(rank, world_size):
# 각 프로세스가 다른 프로세스들과 통신할 수 있도록 분산환경 설정
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(10, 10)
self.relu = nn.ReLU()
self.net2 = nn.Linear(10, 5)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
1. 메인함수
def run_demo(demo_fn, world_size):
mp.spawn(demo_fn,
args=(world_size,),
nprocs=world_size,
join=True)
- pytorch의 multiprocess모듈에서 각 프로세스에 train함수를 호출합니다.
- train 함수는 gpu index(world_size)를 인자로 받습니다.
2. train함수
def demo_basic(rank, world_size):
setup(rank, world_size)
model = ToyModel().to(rank)
ddp_model = DDP(model, device_ids=[rank])
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10))
labels = torch.randn(20, 5).to(rank)
loss_fn(outputs, labels).backward()
optimizer.step()
cleanup()
- setup을 통해 분산환경을 설정합니다.
- to.(rank)로 모델을 GPU로 이동하고
- 모델을 DDP 모듈로 감쌉니다.
- loss를 계산하고 backward를 통해 DDP 내부적으로 gradient 동기화 통신이 진행됩니다.
- GPU가 4개인 경우 world_size = 4, rank = [0,1,2,3]
단점
- 모델에서 사용하지 않는 parameter가 있으면 에러가 뜬다고 합니다.
- 이 경우에는 nvidia의 apex를 받아서 써야합니다. (pip install apex안됨)
'Pytorch' 카테고리의 다른 글
[Pytorch] 다양한 Learning Rate Scheduler (0) | 2022.06.26 |
---|---|
[Pytorch] Onnx로 모델 export하기 (0) | 2022.06.12 |
[Pytorch] 유용한 method (view,reshape,squeeze,permute,stack,repeat,gather...) (0) | 2022.01.16 |