import os

# Download/cache location for pretrained weights (torch.hub reads TORCH_HOME).
os.environ.update(TORCH_HOME=r'D:\Pytorch\pythonProject\vgg16')
ImageNet 数据集太大（140 多 GB），torchvision 不提供自动下载（传入 download=True 会报错），需要先手动下载数据集压缩包，放到 root 指定的目录下再加载：
import torchvision

# torchvision cannot download ImageNet (~140 GB): passing download=True raises
# an error instead of fetching the data. The dataset archives must already be
# placed under ./data_image_net manually before this call.
train_set = torchvision.datasets.ImageNet(
    root='./data_image_net',
    split='train',
    transform=torchvision.transforms.ToTensor(),
)
不预训练（weights=None）：模型参数随机初始化。
预训练（weights='DEFAULT'）：加载已经在 ImageNet 上训练好的参数（首次使用会下载到 TORCH_HOME 指定的目录）。
import torchvision

# weights=None: no pretraining, parameters are randomly initialised.
vgg16_false = torchvision.models.vgg16(weights=None)
# weights='DEFAULT': load pretrained parameters (downloaded on first use).
vgg16_true = torchvision.models.vgg16(weights='DEFAULT')  # or weights='IMAGENET1K_V1'
完整代码
import torchvision
import os

from torch import nn

# Download/cache location for pretrained weights (torch.hub reads TORCH_HOME).
os.environ['TORCH_HOME'] = r'D:\Pytorch\pythonProject\vgg16'

# ImageNet cannot be fetched with download=True (torchvision does not support
# downloading it), so the dataset is not loaded in this script.

# Randomly initialised VGG16 vs. VGG16 with pretrained weights.
vgg16_false = torchvision.models.vgg16(weights=None)
vgg16_true = torchvision.models.vgg16(weights='DEFAULT')  # or weights='IMAGENET1K_V1'
print(vgg16_true)

# Option 1: append a layer mapping the 1000 outputs to 10 classes. It must be
# added inside `classifier`: VGG.forward only runs
# features -> avgpool -> flatten -> classifier. A module added to the top-level
# model (vgg16_true.add_module(...)) would be registered but never executed.
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))

# Option 2: replace the last classifier layer so the model outputs 10 classes.
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)