A Module for Tracing the Entire Training Process of Models Built with PyTorch

Published: December 31, 2023

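The trace module described here tracks loss and accuracy throughout training. Whenever it generates the loss and accuracy charts, it also saves the loss values, the accuracy values, and the model itself, so the training data is preserved and can still be displayed even if training is interrupted. The module's source file is available in this project's GitHub repository. What follows is the module's documentation, taken from the project's README. With exam week approaching, I have not had time to write a Chinese introduction, so this post is a placeholder for now.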

A PyTorch Training Process Tracing Module

Xiangnan Zhang

School of Future Technology, Beijing Institute of Technology


This module, called "trace", traces the whole training process and visualizes accuracy and loss.

All loss and accuracy values are tracked through a Statistic object. During training and testing, these values are appended to its list attributes. The model and the training data are saved and loaded together, which guarantees that the whole training process is traced. When the model and training data are saved, line charts of loss and accuracy are displayed and saved, as shown below.

[Figure: line charts of loss and accuracy]

Importing

To import this module, write:

import trace
from trace import Statistic

NOTICE: The class Statistic should be imported separately, because it is a prerequisite for the functions load_statistic() and sys_load().

Details

Statistic(path)

Objects of this class store the training and testing loss and accuracy. When you initialize your model, create a Statistic object like this:

statis=Statistic(statis_path)

Its __init__() method creates the attributes self.train_loss, self.train_accuracy, self.test_loss and self.test_accuracy. Each of them is an empty list, so you can use the .append() method to add values in your train_loop and test_loop functions, like this:

def test_loop(test_ds,model,loss_fn,statis):
    model.eval()
    ...
    statis.test_loss.append(test_loss)
    statis.test_accuracy.append(accuracy)
    ...
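
To make the attribute layout concrete, here is a minimal sketch in plain Python. StatisticSketch and record_epoch are hypothetical names used only for illustration; they mimic the four list attributes described above and are not the module's actual implementation. The metric values are made up.

```python
class StatisticSketch:
    """Hypothetical stand-in for trace.Statistic (illustration only)."""

    def __init__(self, path):
        self.path = path  # where the object would later be saved
        # One list per tracked quantity, all starting empty.
        self.train_loss = []
        self.train_accuracy = []
        self.test_loss = []
        self.test_accuracy = []


def record_epoch(statis, train_metrics, test_metrics):
    """Append one epoch's (loss, accuracy) pairs to the four lists."""
    train_loss, train_accuracy = train_metrics
    test_loss, test_accuracy = test_metrics
    statis.train_loss.append(train_loss)
    statis.train_accuracy.append(train_accuracy)
    statis.test_loss.append(test_loss)
    statis.test_accuracy.append(test_accuracy)


statis = StatisticSketch("statis.pkl")
record_epoch(statis, (0.9, 0.61), (1.0, 0.58))  # made-up numbers, epoch 1
record_epoch(statis, (0.7, 0.72), (0.8, 0.69))  # made-up numbers, epoch 2
print(len(statis.test_loss))  # number of recorded epochs
```

Because every epoch appends exactly one value to each list, the index into a list doubles as the epoch number when the charts are drawn.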

A Statistic object has two methods, self.draw() and self.save(), which draw the statistical charts and save the Statistic object as a pkl file. However, in most cases you should use the Sys object's .sys_conclude() method instead.

load_statistic(path)

This function loads a Statistic object. In most cases, however, you should use the sys_load() function instead.

sys_load(model_path,statis_path)

This function loads both the model and the Statistic object. It is highly recommended that you use this function to load these two items, because it guarantees that the model and the Statistic object are loaded together.

Sys(model,model_path,statis)

This class handles the model and the Statistic object together. Create a Sys object after the model and the Statistic object have been loaded or initialized, like this:

syst=trace.Sys(model,model_path,statis)

Then you can use the .sys_conclude() method to save both the model and the Statistic object:

syst.sys_conclude("ConvM(4_categ)")

A string that identifies the model must be passed to this method. Whenever you need to save your model and Statistic object, this method is the recommended way to do it.


After saving, the model is stored as a pth file and the Statistic object as a pkl file. These suffixes should be included in the file paths.
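
The reason for saving and loading the two files together can be illustrated with a small resume-after-interruption sketch. Everything in it is an assumption made for illustration: load_or_init, conclude and run_session are hypothetical helpers, the "model" is a plain dict, and both files are written with pickle. The real module works with an actual PyTorch model, and its internals may differ.

```python
import os
import pickle
import tempfile


def load_or_init(model_path, statis_path):
    """Return (model_state, stats): loaded if both files exist, else fresh."""
    if os.path.exists(model_path) and os.path.exists(statis_path):
        with open(model_path, "rb") as f:
            model_state = pickle.load(f)
        with open(statis_path, "rb") as f:
            stats = pickle.load(f)
    else:
        model_state = {"weight": 0.0}  # placeholder for real model parameters
        stats = {"train_loss": [], "train_accuracy": [],
                 "test_loss": [], "test_accuracy": []}
    return model_state, stats


def conclude(model_state, model_path, stats, statis_path):
    """Write both files in one call so they describe the same training state."""
    with open(model_path, "wb") as f:
        pickle.dump(model_state, f)
    with open(statis_path, "wb") as f:
        pickle.dump(stats, f)


def run_session(model_path, statis_path, epochs):
    """One training session made of fake epochs, followed by a paired save."""
    model_state, stats = load_or_init(model_path, statis_path)
    for _ in range(epochs):
        model_state["weight"] += 1.0  # pretend to train for one epoch
        stats["train_loss"].append(1.0 / model_state["weight"])
    conclude(model_state, model_path, stats, statis_path)
    return stats


with tempfile.TemporaryDirectory() as tmp:
    model_path = os.path.join(tmp, "model.pth")
    statis_path = os.path.join(tmp, "statis.pkl")
    run_session(model_path, statis_path, epochs=2)            # first session
    history = run_session(model_path, statis_path, epochs=3)  # resumed session
    print(len(history["train_loss"]))  # epochs recorded across both sessions
```

The second session continues from the saved history instead of starting empty, so the loss list covers all epochs from both sessions. If only one of the two files were restored, the curves and the model would no longer describe the same point in training.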

Deficiency

When this module is used to trace data, train_loss increases dramatically at the beginning of each new training process, as can be seen in the image above. However, I am not sure whether this is a problem with the module or with my model.

Source: https://blog.csdn.net/bgshuxuanzxn/article/details/135310883