‘DataParallel‘ object has no attribute ‘model1‘

发布时间:2023年12月25日
class BalanceNetAdapter(nn.Module):
    def __init__(self, fds, bucket_num, bucket_start, start_update,
                 start_smooth, kernel, ks, sigma, momentum # 模型1的参数
                 ):
        super(BalanceNetAdapter, self).__init__()

        # 模型1实例
        self.model1 = resnet50(fds=fds, bucket_num=bucket_num, bucket_start=bucket_start,
                     start_update=start_update, start_smooth=start_smooth,
                     kernel=kernel, ks=ks, sigma=sigma, momentum=momentum)
        
        # 模型2实例   后处理
        self.model2 = FLSS(512*4, 1)  

    def forward(self, input, target=None, epoch=None):

        # 如果是在训练状态
        if self.training:
            # 首先使用模型1进行特征提取
            features = self.model1(input, target, epoch) # features.shape=(256, 2048)  target=(256,1)
            # 然后使用模型2
            # 进行后处理
            pred, loss1 = self.model2.set_forward_loss(features, target)
            return pred, loss1
        else:
            # 如果是在验证或推理状态
            features = self.model1(input)
            pred = self.model2.set_forward(features)

            return pred

我的代码如下

model = BalanceNetAdapter(fds=args.fds, bucket_num=args.bucket_num, bucket_start=args.bucket_start,
                     start_update=args.start_update, start_smooth=args.start_smooth,
                     kernel=args.fds_kernel, ks=args.fds_ks, sigma=args.fds_sigma, momentum=args.fds_mmt)
model = torch.nn.DataParallel(model).cuda() 

问题1

在训练一个epoch结束后,这个时候的Model还是train的状态,我想更新一些自定义的参数,但是我只想用到model1, 那我要怎么写?
feature = model.model1(inputs, targets, epoch)如果这样写的话是会报错的,'DataParallel' object has no attribute 'model1

解决方法

在使用torch.nn.DataParallel包装模型之后,直接访问模型的子模块(例如model1)确实会导致报错,因为DataParallel对象并不直接暴露原始模型的内部属性或方法。为了解决这个问题,您需要通过DataParallel对象的module属性来访问原始模型的子模块。这是因为DataParallel将原始模型封装在它的module属性中。

因此,如果您想在DataParallel包装的模型上只使用model1子模块,您应该这样做:

feature = model.module.model1(inputs, targets, epoch)

这里的model.module实际上是指向您原始的BalanceNetAdapter模型,从而允许您访问其model1子模块。

需要注意的是,当使用model.module时,您只在一个GPU上运行model1,而不是在DataParallel模式下利用多个GPU。这通常适用于特定的场景,比如模型评估或特定参数的更新,其中并不需要多GPU并行处理。

总之,要在DataParallel模型中访问原始模型的子模块,请使用model.module属性。这使您可以像在未使用DataParallel时那样操作模型的子模块。

问题2

那我想更新model1里面的东西,要怎么写? model.model1.module.FDS.update_last_epoch_stats(epoch)这样写对吗?

解决方法

当您使用torch.nn.DataParallel来包装模型时,要更新模型内部子模块(如model1)的特定部分(如FDS),您应该通过model.module来访问原始模型,然后再访问它的子模块和组件。

如果您要更新model1中的FDS组件,正确的方式应该是:

model.module.model1.FDS.update_last_epoch_stats(epoch)

在这里,model.module访问的是DataParallel对象中封装的原始BalanceNetAdapter模型。然后,model1BalanceNetAdapter中的一个子模块,而FDSmodel1内的一个组件或属性。

请确保FDSupdate_last_epoch_stats确实是您model1中的组件和方法。如果它们不存在或命名不同,需要相应地调整代码。

文章来源:https://blog.csdn.net/weixin_43845922/article/details/135192396
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。