解决torchvision.model下载预训练模型太慢的问题

发布时间:2024年01月16日

以下代码等价,但第一种慢,第二种快:

import torchvision.models as models

self.resnet = models.resnet18(pretrained=True)  #  联网下载,慢
self.model = models.resnet18(pretrained=False)
state_dict = torch.load('src/model/resnet18-5c106cde.pth')  # 自己从网上下载.pth,快
self.model.load_state_dict(state_dict)  # 再把读出来的参数放进没有参数的模型

当pretrained=True,才会联网下载模型,否则很快,仅得到一个没训练过的模型。

.pth文件或者state_dict变量:模型参数,里面是模型每一层具体的浮点数

model:模型,不含参数

model和.pth如果是对应的,就可以用model.load_state_dict加载。注意这条语句是在模型上直接修改,不应写成model = model.load_state_dict。

所以我们可以自己在浏览器下载模型,然后加载进去。那么去哪里下载呢?Ctrl+函数打开源码自己就可以找到。

文章来源:https://blog.csdn.net/major_in_data_/article/details/135627316
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。