class BalanceNetAdapter(nn.Module):
def __init__(self, fds, bucket_num, bucket_start, start_update,
start_smooth, kernel, ks, sigma, momentum # 模型1的参数
):
super(BalanceNetAdapter, self).__init__()
# 模型1实例
self.model1 = resnet50(fds=fds, bucket_num=bucket_num, bucket_start=bucket_start,
start_update=start_update, start_smooth=start_smooth,
kernel=kernel, ks=ks, sigma=sigma, momentum=momentum)
# 模型2实例 后处理
self.model2 = FLSS(512*4, 1)
def forward(self, input, target=None, epoch=None):
# 如果是在训练状态
if self.training:
# 首先使用模型1进行特征提取
features = self.model1(input, target, epoch) # features.shape=(256, 2048) target=(256,1)
# 然后使用模型2
# 进行后处理
pred, loss1 = self.model2.set_forward_loss(features, target)
return pred, loss1
else:
# 如果是在验证或推理状态
features = self.model1(input)
pred = self.model2.set_forward(features)
return pred
我的代码如下
model = BalanceNetAdapter(fds=args.fds, bucket_num=args.bucket_num, bucket_start=args.bucket_start,
start_update=args.start_update, start_smooth=args.start_smooth,
kernel=args.fds_kernel, ks=args.fds_ks, sigma=args.fds_sigma, momentum=args.fds_mmt)
model = torch.nn.DataParallel(model).cuda()
在训练一个epoch结束后,这个时候的Model还是train的状态,我想更新一些自定义的参数,但是我只想用到model1
, 那我要怎么写?
feature = model.model1(inputs, targets, epoch)
如果这样写的话是会报错的,'DataParallel' object has no attribute 'model1
在使用torch.nn.DataParallel
包装模型之后,直接访问模型的子模块(例如model1
)确实会导致报错,因为DataParallel
对象并不直接暴露原始模型的内部属性或方法。为了解决这个问题,您需要通过DataParallel
对象的module
属性来访问原始模型的子模块。这是因为DataParallel
将原始模型封装在它的module
属性中。
因此,如果您想在DataParallel
包装的模型上只使用model1
子模块,您应该这样做:
feature = model.module.model1(inputs, targets, epoch)
这里的model.module
实际上是指向您原始的BalanceNetAdapter
模型,从而允许您访问其model1
子模块。
需要注意的是,当使用model.module
时,您只在一个GPU上运行model1
,而不是在DataParallel
模式下利用多个GPU。这通常适用于特定的场景,比如模型评估或特定参数的更新,其中并不需要多GPU并行处理。
总之,要在DataParallel
模型中访问原始模型的子模块,请使用model.module
属性。这使您可以像在未使用DataParallel
时那样操作模型的子模块。
那我想更新model1里面的东西,要怎么写? model.model1.module.FDS.update_last_epoch_stats(epoch)这样写对吗?
当您使用torch.nn.DataParallel
来包装模型时,要更新模型内部子模块(如model1
)的特定部分(如FDS
),您应该通过model.module
来访问原始模型,然后再访问它的子模块和组件。
如果您要更新model1
中的FDS
组件,正确的方式应该是:
model.module.model1.FDS.update_last_epoch_stats(epoch)
在这里,model.module
访问的是DataParallel
对象中封装的原始BalanceNetAdapter
模型。然后,model1
是BalanceNetAdapter
中的一个子模块,而FDS
是model1
内的一个组件或属性。
请确保FDS
和update_last_epoch_stats
确实是您model1
中的组件和方法。如果它们不存在或命名不同,需要相应地调整代码。