Pytorch实战:LeNet手写数字识别

LeNet网络结构

网络参数:

使用的数据集是MNIST,其中包括60000个28*28的训练样本,10000个测试样本。

代码如下:

  • 导入所需的包
  • 定义网络结构(LeNet)
  • 定义前向传播过程
  • 超参数定义
  • 加载MNIST数据集
  • 定义损失函数和优化方法,本次使用的是交叉熵损失函数以及随机梯度下降优化方法
  • 开始训练,下图中用红框标记出来的区域,是使用pytorch训练一个网络必备的三件套,optimizer.zero_grad()方法的作用是在每次计算梯度之前首先将梯度置零,loss.backward()方法的作用就是计算loss函数对每个网络参数的梯度,optimizer.step() 方法的作用是利用计算好的梯度对网络参数进行更新。
  • 运行结果展示:

还是需要多练习代码编写能力,加油⛽️。