import os
import torch
import torch.distributed as dist
# local_rank and rank are set per process by the launcher (torch.distributed.launch / torchrun)
local_rank = int(os.environ['LOCAL_RANK'])  # local_rank tells this process which of the node's processes (GPUs) it is
rank = int(os.environ['RANK'])
# Initialize the distributed environment
world_size = int(os.environ['WORLD_SIZE'])  # total number of processes, also set by the launcher
dist.init_process_group(backend='nccl', init_method='tcp://localhost:23456', world_size=world_size, rank=rank)
# Use local_rank to select this process's GPU on the local node
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
2. Wrap the model with torch.nn.parallel.DistributedDataParallel.
model = MyModel()
model.to(device)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])  # one process per GPU
train_dataset = MyDataset()
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
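The training loop below assumes an optimizer and a loss function have already been created; the SGD and cross-entropy choices here are purely illustrative placeholders, not prescribed by the setup above.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # illustrative optimizer and learning rate
criterion = torch.nn.CrossEntropyLoss()                   # illustrative loss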
for epoch in range(num_epochs):
    # Addition 2: set the sampler's epoch. DistributedSampler uses it so that all
    # processes shuffle with the same seed but get a different ordering each epoch.
    train_loader.sampler.set_epoch(epoch)
    for data, label in train_loader:
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), label)
        loss.backward()
        optimizer.step()
torch.distributed.destroy_process_group()
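Not shown above: when checkpointing a DDP model, it is common to save from rank 0 only (before destroy_process_group), since every process holds an identical replica and concurrent writes to the same file should be avoided. A minimal sketch; the file name is illustrative:
if rank == 0:
    # model.module is the original model underneath the DistributedDataParallel wrapper
    torch.save(model.module.state_dict(), 'checkpoint.pt')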
If you are running on a single machine only:
python -m torch.distributed.launch --nproc_per_node=n main.py --batch-size=n
Here --nproc_per_node is typically the number of GPUs on the node, and script arguments such as --batch-size belong after the script name so that the launcher passes them through to main.py.
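In recent PyTorch releases, torch.distributed.launch is deprecated in favor of torchrun, which sets LOCAL_RANK, RANK, and WORLD_SIZE in the environment in the same way; a roughly equivalent invocation would be:
torchrun --nproc_per_node=n main.py --batch-size=n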
If some parameters do not receive gradients in a given forward pass (for example, branches of the model that are skipped), DDP will raise an error about unused parameters; in that case wrap the model with find_unused_parameters=True:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True)