搭建网络的步骤大体为如下:python
1.准备数据算法
2. 定义网络结构model网络
3. 定义损失函数
4. 定义优化算法 optimizer
5. 训练
5.1 准备好tensor形式的输入数据和标签(可选)
5.2 前向传播计算网络输出output和计算损失函数loss
5.3 反向传播更新参数
如下三句话一句也不能少:
5.3.1 optimizer.zero_grad() 将上次迭代计算的梯度值清0
5.3.2 loss.backward() 反向传播,计算梯度值
5.3.3 optimizer.step() 更新权值参数
5.4 保存训练集上的loss和验证集上的loss以及准确率以及打印训练信息。(可选
6. 图示训练过程当中loss和accuracy的变化状况(可选)
7. 在测试集上测试app
代码注释都写的很详细 函数
1 import torch 2 import torch.nn.functional as F 3 import matplotlib.pyplot as plt 4 5 # 1.准备数据 generate data 6 x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1) 7 print(x.shape) 8 y=x*x+0.2*torch.rand(x.size()) 9 #显示数据散点图 10 plt.scatter(x.data.numpy(),y.data.numpy()) 11 12 # 2.定义网络结构 build net 13 class Net(torch.nn.Module): 14 #n_feature:输入特征个数 n_hidden:隐藏层个数 n_output:输出层个数 15 def __init__(self,n_feature,n_hidden,n_output): 16 # super表示继承Net的父类,并同时初始化父类的参数 17 super(Net,self).__init__() 18 # nn.Linear表明线性层 表明y=w*x+b 其中w的shape为[n_hidden,n_feature] b的shape为[n_hidden] 19 # y=w^T*x+b 这里w的维度是转置前的维度 因此是反的 20 self.hidden =torch.nn.Linear(n_feature,n_hidden) 21 self.predict =torch.nn.Linear(n_hidden,n_output) 22 print(self.hidden.weight) 23 print(self.predict.weight) 24 #定义一个前向传播过程函数 25 def forward(self, x): 26 # n_feature n_hidden n_output 27 #举例(2,5,1) 2 5 1 28 # - ** - 29 # ** - - - ** - - 30 # - ** - - - ** 31 # ** - - - ** - - 32 # - ** - 33 # 输入层 隐藏层 输出层 34 x=F.relu(self.hidden(x)) 35 x=self.predict(x) 36 return x 37 # 实例化一个网络为net 38 net = Net(n_feature=1,n_hidden=10,n_output=1) 39 print(net) 40 # 3.定义损失函数 这里使用均方偏差(mean square error) 41 loss_func=torch.nn.MSELoss() 42 # 4.定义优化器 这里使用随机梯度降低 43 optimizer=torch.optim.SGD(net.parameters(),lr=0.2) 44 #定义300遍更新 每10遍显示一次 45 plt.ion() 46 # 5.训练 47 for t in range(100): 48 prediction = net(x) # input x and predict based on x 49 loss = loss_func(prediction, y) # must be (1. nn output, 2. target) 50 # 5.3反向传播三步不可少 51 optimizer.zero_grad() # clear gradients for next train 52 loss.backward() # backpropagation, compute gradients 53 optimizer.step() # apply gradients 54 55 if t % 10 == 0: 56 # plot and show learning process 57 plt.cla() 58 plt.scatter(x.data.numpy(), y.data.numpy()) 59 plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 60 plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'}) 61 plt.show() 62 plt.pause(0.1) 63 64 plt.ioff()
参考:莫烦python测试