05.用PyTorch实现线性回归_哔哩哔哩_bilibili
最近在参考 B站 刘二大人 学习PyTorch,上传一些自己参考学习编写的代码,供交流使用。
import torch
import matplotlib.pyplot as plt
import numpy as np
#定义数据集
x_data=torch.Tensor([[1.0],[2.0],[3.0]])
y_data=torch.Tensor([[2.0],[4.0],[6.0]])
#定义模型
class LinearModel(torch.nn.Module):
def __init__(self):
super(LinearModel,self).__init__()#调用父类的init
self.linear=torch.nn.Linear(1,1)
#构造一个linear的对象,指定输入特征的维度和输出特征的维度
#Module实现了魔法函数__call__(),call()里面有一条语句是要调用forward()。
# 因此新写的类中需要重写forward()覆盖掉父类中的forward()
#call函数的另一个作用是可以直接在对象后面加(),例如实例化的model对象,和实例化的linear对象
def forward(self,x):
y_pred=self.linear(x)
return y_pred
model=LinearModel()#模型实例化
#构造损失函数和优化器
criterion=torch.nn.MSELoss(size_average=False)#损失函数,不求均值
#model.parameters()会扫描module中的所有成员,如果成员中有相应权重,那么都会将结果加到要训练的参数集合上
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)#SGD优化器
#训练过程
epochs=[]
losses=[]
for epoch in range(100):
y_pred=model(x_data)
loss=criterion(y_pred,y_data)
print(epoch,loss.item())
epochs.append(epoch)
losses.append(loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()
#打印权重值和偏置值
print('w=',model.linear.weight.item())
print('b=',model.linear.bias.item())
#测试数据
x_test=torch.Tensor([[4.0]])
y_test=model(x_test)
print('y_pred=',y_test.data.item())
#显示训练结果
plt.plot(epochs,losses)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()
#保存损失函数的值
np.save('SGD.npy', losses)
训练使用SGD作为优化器。在训练过程中,损失函数变化如下:
将不同优化器训练过程中损失函数值保存,作图,代码如下:
import matplotlib.pyplot as plt
import numpy as np
losses1=np.load('Adagrad.npy')
losses2=np.load('Adam.npy')
losses3=np.load('Adamax.npy')
losses4=np.load('ASGD.npy')
losses5=np.load('RMSprop.npy')
losses6=np.load('Rprop.npy')
losses7=np.load('SGD.npy')
plt.plot(losses1, label='Adagrad')
plt.plot(losses2, label='Adam')
plt.plot(losses3, label='Adamax')
plt.plot(losses4, label='ASGD')
plt.plot(losses5, label='RMSprop')
plt.plot(losses6, label='Rprop')
plt.plot(losses7, label='SGD')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Loss Curves for Different Optimizers')
plt.legend()
plt.show()
训练效果如下:?参考:PyTorch 深度学习实践 第5讲_model call 深度学习-CSDN博客