????????在PyTorch中,.pt、.pth和.pth.tar都是保存训练好的模型的文件格式。主要区别在于:
使用torch.save进行保存,保存时传入保存的状态,名称
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, 'model_best.pth.tar')
1、通过torch.load()函数加载
checkpoint_path = "/home/user/msh/Project/SimCLR-master_old/runs/Jan03_19-04-59_user-X10DRi/checkpoint_0100.pth.tar"
checkpoint = torch.load(checkpoint_path)
print(checkpoint.keys())
运行结果如下:
dict_keys(['epoch', 'arch', 'state_dict', 'optimizer'])
2、epoch存放的是训练的轮次,arch存放的是模型的名称,optimizer存放是优化器具体的参数,
epoch = checkpoint['epoch']
print(epoch)
arch = checkpoint['arch']
print(arch)
optimizer = checkpoint['optimizer']
运行结果:
100
resnet18
3、state_dict.keys()存放的是模型每一层结构的名称
state_dict = checkpoint['state_dict']
print(state_dict.keys())
4、使用:先初始化模型,创建一个对象,然后使用load_state_dict()函数加载参数
model = ResNetSimCLR(arch,160)
model.load_state_dict(state_dict)