# d2lzh_pytorch 模块跳转链接
import torch
import numpy as np
import sys
sys.path.append("路径")
import d2lzh_pytorch as d2l
'''
-----------------------------生成人工数据集
样本数n=200
特征数=3
三阶多项式y=1.2x-3.4x^2+5.6x^3+5+ε
'''
# Synthetic data set: n_train + n_test scalar inputs x drawn from a standard
# normal, expanded to the polynomial features (x, x^2, x^3).
n_train, n_test = 100, 100
true_w, true_b = [1.2, -3.4, 5.6], 5
sample_features = torch.randn(n_train + n_test, 1)
poly_sample_features = torch.cat(
    (sample_features, sample_features ** 2, sample_features ** 3),
    dim=1,
)
# Noise-free targets: each feature column weighted by its coefficient in
# true_w (accumulated left to right), plus the bias true_b.
labels = sum(
    weight * column
    for weight, column in zip(true_w, poly_sample_features.T)
) + true_b
# Add Gaussian observation noise (mean 0, standard deviation 0.01).
labels += torch.tensor(
    np.random.normal(loc=0, scale=0.01, size=labels.size()),
    dtype=torch.float,
)
'''
-----------------------------------------------------定义,训练和测试模型
'''
# Training hyper-parameters shared by every experiment below.
num_epochs = 100  # number of passes over the training set
loss = torch.nn.MSELoss()  # mean squared error criterion
'''
以下函数思路设计:
1. 参数传入训练数据集样本,测试数据集样本,训练标签,测试标签
2. 设计网络,网络输入特征计算格式
3. 设计数据读取,读取训练数据集
4. 循环更新迭代步骤,目的是为了优化w,b
5. 每个epoch的小批量循环结束后,使用全量训练数据集和测试数据集,在已经更新好的w,b的基础上进行损失计算
6. 画图
'''
def fit_and_plot(train_features, test_features, train_labels, test_labels, label):
    """Fit a single linear layer with mini-batch SGD and plot the loss curves.

    Uses the module-level ``num_epochs`` (number of epochs) and ``loss``
    (criterion). After every epoch the loss on the full training set and on
    the full test set is recorded; the two curves are then drawn with
    ``d2l.semilogy`` and the learned weight and bias are printed.

    Args:
        train_features: 2-D tensor of training inputs; its last dimension
            sets the input width of the linear layer.
        test_features: 2-D tensor of test inputs with the same width.
        train_labels: Training targets, one per training row.
        test_labels: Test targets, one per test row.
        label: Forwarded positionally to ``d2l.semilogy`` right after the
            y-axis label. NOTE(review): presumably a figure title for a
            locally modified ``semilogy`` - confirm against that helper's
            signature.
    """
    net = torch.nn.Linear(train_features.shape[-1], 1)
    # At most 10 samples per batch, fewer when the training set is smaller.
    batch_size = min(10, train_labels.shape[0])
    dataset = torch.utils.data.TensorDataset(train_features, train_labels)
    train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

    # The layer outputs shape (n, 1); reshape the labels once up front so the
    # full-batch evaluation compares tensors of equal shape.
    train_labels = train_labels.view(-1, 1)
    test_labels = test_labels.view(-1, 1)

    train_ls, test_ls = [], []
    for _ in range(num_epochs):
        for X, y in train_iter:
            y_hat = net(X)
            l = loss(y_hat, y.view(y_hat.size()))
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
        # Evaluation only needs the loss values, so skip building autograd
        # graphs for these full-batch forward passes.
        with torch.no_grad():
            train_ls.append(loss(net(train_features), train_labels).item())
            test_ls.append(loss(net(test_features), test_labels).item())
    print('final epoch: train loss', train_ls[-1], 'test loss', test_ls[-1])
    d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss', label,
                 range(1, num_epochs + 1), test_ls, ['train', 'test'])
    print('weight:', net.weight.data,
          '\nbias:', net.bias.data)
'''
-----------------------------------------------------------------三阶多项式函数拟合(正常)
'''
# Normal fit: cubic features, first n_train rows for training and the
# remaining rows for testing.
fit_and_plot(
    train_features=poly_sample_features[:n_train],
    test_features=poly_sample_features[n_train:],
    train_labels=labels[:n_train],
    test_labels=labels[n_train:],
    label='正常',
)
'''
----------------------------------------------------------------线性函数拟合(欠拟合)
'''
# Under-fitting: only the raw linear feature x is given to the model although
# the labels come from a cubic polynomial.
fit_and_plot(
    train_features=sample_features[:n_train],
    test_features=sample_features[n_train:],
    train_labels=labels[:n_train],
    test_labels=labels[n_train:],
    label='欠拟合',
)
'''
----------------------------------------------------------------训练样本不足(过拟合)
'''
# Over-fitting: train on just two samples. Evaluate on the held-out rows
# [n_train:], as the other experiments do, so the test set never contains the
# two samples the model was trained on.
fit_and_plot(poly_sample_features[0:2, :], poly_sample_features[n_train:, :], labels[0:2], labels[n_train:], '过拟合')