To implement distributed training in PyTorch, you can use the tools and functions provided by the torch.distributed package. The following simple example shows how to set up and run distributed training in PyTorch:
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # Initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def train(rank, world_size):
    setup(rank, world_size)

    # Create the model and optimizer; MyModel is a user-defined nn.Module
    model = MyModel()
    model = DDP(model)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Load the data; MyDataset is a user-defined Dataset.
    # DistributedSampler gives each process its own shard of the dataset.
    train_dataset = MyDataset()
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, sampler=train_sampler)

    # Training loop
    for epoch in range(10):
        train_sampler.set_epoch(epoch)  # reshuffle the shards each epoch
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()

    cleanup()


if __name__ == '__main__':
    world_size = 4
    mp.spawn(train, args=(world_size,), nprocs=world_size)
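The snippet assumes user-defined MyModel and MyDataset classes. As a minimal, hypothetical sketch (not part of the original example, just stand-ins so the code runs end to end), they could look like this:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

class MyModel(nn.Module):
    # Hypothetical tiny classifier: 20 input features, 10 classes.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(20, 10)

    def forward(self, x):
        # Return log-probabilities, as expected by F.nll_loss
        return F.log_softmax(self.fc(x), dim=1)

class MyDataset(Dataset):
    # Hypothetical random dataset matching the model's input/output shapes.
    def __init__(self, n=1024):
        self.x = torch.randn(n, 20)
        self.y = torch.randint(0, 10, (n,))

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]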
In this example, torch.multiprocessing.spawn is called in the main block to launch multiple processes, one per rank, each of which runs the train function. Inside train, we first initialize the process group with setup, then create the model, optimizer, and data loader. The model is wrapped in a DistributedDataParallel object so that gradients are synchronized across processes, and torch.utils.data.distributed.DistributedSampler assigns each process its own portion of the data. Finally, we run the training loop and clean up the process group once training is finished.
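The example above uses the gloo backend, which runs on CPU. If GPUs are available, a common adaptation is to switch to the nccl backend and pin each process to its own GPU. Below is a minimal sketch of such a variant (train_gpu is our own name, not part of the original example; it assumes one GPU per rank and reuses the MyModel/MyDataset placeholders from above):

def train_gpu(rank, world_size):
    # Same rendezvous settings as setup(), but with the nccl backend for GPU training
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = MyModel().to(rank)               # place this process's replica on its own GPU
    model = DDP(model, device_ids=[rank])    # gradients are all-reduced across ranks
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    train_dataset = MyDataset()
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, sampler=train_sampler)

    for epoch in range(10):
        train_sampler.set_epoch(epoch)
        for data, target in train_loader:
            data, target = data.to(rank), target.to(rank)   # move each batch to the local GPU
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()

    dist.destroy_process_group()

The launch code stays the same: mp.spawn(train_gpu, args=(world_size,), nprocs=world_size), with world_size set to the number of available GPUs.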