Some things you need to know about data-parallel distributed training

Reference link: pytorch(分布式)数据并行个人实践总结——DataParallel/DistributedDataParallel - fnangle - 博客园 (cnblogs.com)

Other tutorials

Reference code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.utils.data as D
import torch.distributed as dist
import argparse

# Build the model; it is moved to this process's GPU and wrapped in
# DistributedDataParallel after the process group is initialised below
resnet = torchvision.models.resnet18(num_classes=10)
# Convert its BatchNorm layers to SyncBatchNorm so batch statistics are synchronised across GPUs
resnet = nn.SyncBatchNorm.convert_sync_batchnorm(resnet)

# Build a toy dataset
class Dataset(D.Dataset):
    def __init__(self, data, y):
        self.data = data
        self.y = y

    def __getitem__(self, index):
        # Apply image transforms here if needed, e.g.:
        # transforms = torchvision.transforms.Compose([...])
        # data = transforms(self.data[index])
        return self.data[index], self.y[index]

    def __len__(self):
        return len(self.data)


data = torch.randn(size=(200, 3, 224, 224))
y = torch.randint(0, 10, (200,))
dataset = Dataset(data, y)


def get_argparse():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=0)  # each process gets its own GPU; torch.distributed.launch passes this argument in
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--world_size', type=int, default=2)  # total number of GPUs = nnodes * nproc_per_node
    parser.add_argument('--epochs', type=int, default=100)
    return parser.parse_args()


args = get_argparse()
dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=args.local_rank)

# Bind this process to its own GPU, move the model there and wrap it in
# DistributedDataParallel so gradients are synchronised across processes
torch.cuda.set_device(args.local_rank)
resnet = resnet.cuda(args.local_rank)
resnet = nn.parallel.DistributedDataParallel(resnet, device_ids=[args.local_rank])

sampler = D.DistributedSampler(dataset, rank=args.local_rank, shuffle=True)  # always use a DistributedSampler so the GPUs do not receive overlapping data
dataloader = D.DataLoader(dataset, batch_size=args.batch_size // args.world_size, sampler=sampler)  # remember to divide batch_size by world_size

opt = torch.optim.SGD(resnet.parameters(), lr=1e-3)

for ep in range(args.epochs):
    # set_epoch must be called once per epoch before iterating the dataloader,
    # otherwise every epoch uses the same shuffling order on every GPU
    sampler.set_epoch(ep)
    for x, y in dataloader:
        x = x.cuda()
        y = y.cuda()

        pred = resnet(x)
        loss = F.cross_entropy(pred, y)
        print(f'rank: {args.local_rank}, loss: {loss.item()}')

        opt.zero_grad()
        loss.backward()
        opt.step()
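
A quick sanity check (not from the original post, just a sketch using the variables defined above): DistributedSampler partitions the dataset indices across ranks, so the index lists printed by the two processes should not overlap and should each cover roughly half of the 200 samples.

# Hypothetical sanity check: list the sample indices assigned to this rank
indices = list(sampler)
print(f'rank {args.local_rank}: {len(indices)} samples, first few indices: {indices[:5]}')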

Shell script:

CUDA_VISIBLE_DEVICES=0,1 python3 -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --master_port=23456 distributed_tester.py [--args=...]

or

CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 --nnodes=1 --master_port=23456 distributed_tester.py [--args=...]

Note that when launching with torchrun, --local_rank is no longer passed to the script. To get the local rank, read it from the environment with local_rank = int(os.environ['LOCAL_RANK']) and use that in place of args.local_rank.
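
A minimal sketch (not part of the original post) of how the top of the script changes under torchrun: the --local_rank command-line argument goes away, and the local rank comes from the LOCAL_RANK environment variable that torchrun exports for every process; the training code itself stays the same.

import os
import argparse
import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=8)  # no --local_rank argument any more
args = parser.parse_args()

# torchrun sets LOCAL_RANK, RANK and WORLD_SIZE, so the process group
# can be initialised from the environment alone
local_rank = int(os.environ['LOCAL_RANK'])
dist.init_process_group(backend='nccl', init_method='env://')
torch.cuda.set_device(local_rank)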