LeNet网络结构

网络参数:

使用的数据集是MNIST,其中包括60000个28*28的训练样本,10000个测试样本。
代码如下:
- 导入所需的包

- 定义网络结构(LeNet)

- 定义前向传播过程

- 超参数定义

- 加载MNIST数据集


- 定义损失函数和优化方法,本次使用的是交叉熵损失函数以及随机梯度下降优化方法

- 开始训练,下图中用红框标记出来的区域,是使用pytorch训练一个网络必备的三件套,optimizer.zero_grad()方法的作用是在每次计算梯度之前首先将梯度置零,loss.backward()方法的作用就是计算loss函数对每个网络参数的梯度,optimizer.step() 方法的作用是利用计算好的梯度对网络参数进行更新。

- 运行结果展示:

还是需要多练习代码编写能力,加油⛽️。