pytorch集智-6手写数字加法机-迁移学习

发布时间:2024年01月20日

1 概述

迁移学习概念:将已经训练好的识别某些信息的网络拿去经过训练识别另外不同类别的信息

优越性:提高了训练模型利用率,解决了数据缺失的问题(对于新的预测场景,不需要大量的数据,只需要少量数据即可实现训练,可用于数据点很少的场景)

如何实现:将训练好的一个网络拿来和另一个网络连起来去训练即可实现迁移

训练方式:按是否改变源网络参数可分两类,分别是可改变和不可改变

2 案例 南非贫困预测

2.1 背景

南非存在贫困,1990-2021贫困人口从56%下降到43%,但下降的贫困人口数量和国际人道主义援助资源并不对应,而且大量资金援助一定程度加剧了贫富差距。可以看下具体哪些地区需要援助

2.2 方法

一个方法:夜光光亮遥感数据和人类gdp相关性经实验可达0.8-0.9,但夜光遥感和贫富没太大相关性:夜间光照月亮表示该地区越富有,但越安并不表示该地区越贫穷,也可能无人居住。

另一个方法:光亮遥感数据无法准确预测地区贫穷程度,但卫星遥感数据大体可以做到,判定依据有街道混乱程度等。如果要用深度网络训练,还需要对卫星遥感数据的图片标注贫困程度。非洲能获取到的贫困数据很少,但深度网络需要的数据量很大

最终方法:用迁移学习,将前两种方法合起来,见下图

3 案例2

3.1 背景

任务:区分图像里动物是蚂蚁还是蜜蜂,像素均为224x224

难点:只有244个图像,样本太少不足训练大型卷积网络,准确率只有50%左右

3.2 解决方案

解决方案:resnet与模型迁移,即用已训练好的物体分类的网络加全连接用来区分蚂蚁与蜜蜂

resnet:残差网络,对物体分类有较高精度

3.3 代码实现

3.3.1 准备数据

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as pyplot
import time
import copy
import os

data_path = 'pytorch/jizhi/figure_plus/data'
image_size = 224

class TranNet():
    def __init__(self):
        super(TranNet, self).__init__()
        
        self.train_dataset = datasets.ImageFolder(os.path.join(data_path, 'train'), transforms.Compose([
            transforms.RandomSizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]))
        self.verify_dataset = datasets.ImageFolder(os.path.join(data_path, 'verify'), transforms.Compose([
            transforms.Scale(256),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]))
        self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=4, shuffle=True, num_workers=4)
        self.verify_loader = torch.utils.data.DataLoader(self.verify_dataset, batch_size=4, shuffle=True, num_workers=4)
        self.num_classes = len(self.train_dataset.classes)
    
    def exec(self):
        ...

def main():
    TranNet().exec()

if __name__ == '__main__':
    main()

3.3.2 模型迁移

    def exec(self):
        self.model_prepare()
        
    def model_prepare(self):
        net = models.resnet18(pretrained=True)
        
        # float net values
        num_features = net.fc.in_features
        net.fc = nn.Linear(num_features, 2)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)
        
        # fixed net values
        '''
        for param in net.parameters():
            param.requires_grad = False
        num_features = net.fc.in_features
        net.fc = nn.Linear(num_features, 2)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.fc.parameters(), lr = 0.001, momentum=0.9)
        '''

3.3.3 gpu加速

特点:gpu速度快,但内存低,所以尽量减少在gpu中存储的数据,只用来计算就好

    def model_prepare(self):
        # jusge whether GPU
        use_cuda = torch.cuda.is_available()
        dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
        itype = torch.cuda.LongTensor if use_cuda else torch.LongTensor
        
        net = models.resnet18(pretrained=True)
        net = net.cuda() if use_cuda else net

文章来源:https://blog.csdn.net/peter6768/article/details/135712687
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。