深度学习——训练过程实时可视化损失函数走势(附代码)

发布时间:2024年01月24日

深度学习训练过程实时可视化损失函数可以帮助我们更好地了解模型的训练情况,从而做出更合理的训练决策。

一、可视化损失函数好处

帮助我们判断模型是否在正确的方向上训练。如果损失函数一直在下降,说明模型正在逐渐拟合训练数据。如果损失函数一直在上升,说明模型可能存在过拟合或欠拟合的问题。

帮助我们确定训练是否已经收敛。如果损失函数的下降趋势已经平缓,说明模型已经收敛。如果损失函数仍然在下降,说明模型还可以继续训练。

帮助我们调整训练参数。如果损失函数的下降趋势不理想,我们可以调整训练参数,例如学习率、批处理尺寸等,以提高模型的训练效果。

二、可视化损失函数代码

2.1 边训练边可视化

2.1.1 代码

下面给出了一个完整的的训练过程代码,并在中间加了试试可视化代码,学者参考代码架构,在自己代码中将损失函数添加到列表loss_list中,再自定义调整多少个epoch绘画一个损失函数点。具体代码见下:

import torch
import torch.optim as optim  # 导入优化器模块
import matplotlib.pyplot as plt

# 定义损失函数
def loss_fn(y_true, y_pred):
    return torch.mean((y_true - y_pred)**2)

# 定义模型
model = torch.nn.Linear(10, 1)

# 定义训练数据
x_train = torch.randn(1000, 10)
y_train = torch.randn(1000, 1)

# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 使用 Adam 优化器,学习率为 0.001

# 定义损失函数
loss_list = []

# 开始训练
for epoch in range(10000):
    # 前向传播
    y_pred = model(x_train)

    # 计算损失
    loss = loss_fn(y_train, y_pred)
    loss_list.append(loss.item())

    # 反向传播
    loss.backward()

    # 更新参数
    optimizer.step()

    # 展示损失
    if epoch % 10 == 0:    # 10 个epoch绘画一个损失函数点,可以自定义
        print(f"epoch {epoch}: loss {loss.item()}")
        # 更新损失曲线
        plt.cla()
        plt.plot(loss_list)
        plt.xlabel("epoch")
        plt.ylabel("loss")
        plt.pause(0.01)

# 绘制损失曲线
plt.show()

2.1.2 实时可视化效果

运行上面代码后,会弹出一个窗口,实时绘制损失函数的走势,如下:

在这里插入图片描述

2.2 训练完后绘画损失函数

2.2.1 代码

具体代码见下:

import matplotlib.pyplot as plt
import torch
import torch.optim as optim  # 导入优化器模块

# 定义损失函数
def loss_fn(y_true, y_pred):
    return torch.mean((y_true - y_pred)**2)

# 定义模型
model = torch.nn.Linear(10, 1)

# 定义训练数据
x_train = torch.randn(1000, 10)
y_train = torch.randn(1000, 1)

# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 使用 Adam 优化器,学习率为 0.001

# 定义损失函数
loss_list = []

# 开始训练
for epoch in range(10000):
    # 前向传播
    y_pred = model(x_train)

    # 计算损失
    loss = loss_fn(y_train, y_pred)
    loss_list.append(loss.item())

    # 反向传播
    loss.backward()

    # 更新参数
    optimizer.step()

    # 展示损失
    if epoch % 10 == 0:
        print(f"epoch {epoch}: loss {loss.item()}")

# 绘制损失曲线
plt.plot(loss_list)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.show()

2.2.2 训练完绘制效果

该模式下,要训练完窗口才会弹出来。

在这里插入图片描述

三、总结

以上就是深度学习训练过程实时可视化损失函数走势的方法,当然还可以通过其它方法实时可视化损失函数走势,比如:

(1)使用图表工具绘制损失函数的曲线。这种方法直观易懂,可以帮助我们快速了解损失函数的变化趋势。

(2)使用日志文件记录损失函数的值。这种方法可以让我们在训练结束后再进行分析,但不如图表工具实时性强。

(3)使用可视化工具直接在训练过程中显示损失函数的值。这种方法最直观,但需要使用特定的工具。

上面方法希望能帮到你,学者灵活使用代码。

总结不易,扫下方二维码关注 视觉研坊,学习更多最新开源资源,多多支持,谢谢!

文章来源:https://blog.csdn.net/qq_40280673/article/details/135775282
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。