语言选择: 简体中文简体中文 line EnglishEnglish

公司动态

PyTorch 17:优化器 (一)

前两节课中,我们学习了损失函数的概念以及 PyTorch 中的一系列损失函数方法,我们知道了损失函数的作用是衡量模型输出与真实标签之间的差异。在得到了 loss 函数之后,我们应该如何去更新模型参数,使得 loss 逐步降低呢?这正是优化器的工作。本节课我们开始学习优化器模块。

在学习优化器模块之前,我们先回顾一下机器学习模型训练的 5 个步骤:

我们看到,优化器是第 4 个模块,那么它的作用是什么呢?我们知道,在前一步的损失函数模块中,我们会得到一个 loss 值,即模型输出与真实标签之间的差异。有了 loss 值之后,我们一般会采用 PyTorch 中的 AutoGrid 自动梯度求导模块对模型中参数的梯度进行求导计算,之后优化器会拿到这些梯度值并采用一些优化策略去更新模型参数,使得 loss 值下降。因此,优化器的作用就是利用梯度来更新模型中的可学习参数,使得模型输出与真实标签之间的差异更小,即让 loss 值下降。

PyTorch 的优化器管理更新 模型中可学习参数 (权值或偏置) 的值,使得模型输出更接近真实标签。

  • 导数:函数在指定坐标轴上的变化率。
  • 方向导数:指定方向上的变化率。
  • 梯度:一个向量,方向为方向导数取得最大值的方向。

PyTorch 中的 Optimizer 类

基本属性

  • :优化器超参数。
  • :参数的缓存,如 的缓存。
  • :管理的参数组。
  • :记录更新次数,学习率调整中使用。

基本方法

  • :清空所管理参数的梯度 (PyTorch 特性:张量梯度不自动清零)。
  • :执行一步更新。
  • :添加参数组。
  • :获取优化器当前状态信息 字典
  • :加载状态信息字典。

下面我们来看一下优化器中的 5 种基本方法的具体使用方式:

为了方便计算,我们先设置学习率 :

输出结果:

可以看到,第一个梯度在更新之前的值为 $0.6614$,更新之后的值为 $0.6614 - 1=-0.3386$。现在,我们将学习率设置为 ,观察结果是否发生变化:

输出结果:

可以看到,第一个梯度更新后的值变为了 $0.6614 - 0.1=0.5614$。这就是 方法的一步更新。

输出结果:

可以看到,在执行 之前,我们的梯度为 $[[1., 1.],[1., 1.]]$,执行之后变为了 $[[0., 0.],[0., 0.]]$。另外,我们看到,optimizer 中管理的 的内存地址和真实的 地址是相同的,所以我们在优化器中保存的是参数的地址,而不是拷贝的参数的值,这样可以节省内存消耗。

我们同样采用上面的优化器,该优化器当前已经管理了一组参数,就是我们的 。现在我们希望再增加一组参数,并且我们将该组参数的学习率设置的更小一些 。首先,我们需要构建这样一组参数的字典,字典的 key 设置为 ,其值为新的一组参数 ;然后可以设置一些超参数,例如学习率 等。然后我们使用 将这组参数加入优化器中。

输出结果:

可以看到,在加入新参数之前,我们的优化器中只有一组参数,以一个列表形式呈现,里面只有一个字典元素。当我们使用 之后,列表中有了两个字典元素。可以看到,两组参数的学习率是不同的,所以通过这种方式我们可以为不同的参数组设置不同的学习率,这在模型拟合过程中是一种非常实用的方法。

这两个函数用于保存优化器的状态信息,通常用于断点的继续训练。

保存状态信息

输出结果:

可以看到,在更新之前, 里的值是一个空字典。在经过 10 步更新之后, 字典中有了一些值,它的 key 是 ,即参数地址,而它的值也是一个字典。其中 是动量中会使用的一些缓存信息。所以,在 中我们是通过地址去匹配参数的缓存的。然后,我们使用 对字典进行序列化,将其保存为一个 的形式,可以看到当前文件夹下多了一个 的文件。

读取状态信息

之前我们的模型已经训练了 10 次,假设我们总共需要训练 100 次,我们不希望再从头训练,而是希望能够接着之前第 10 次的状态继续训练,我们可以利用 加载前面保存的 文件,并将其读取加载到优化器中继续训练:

输出结果:

可以看到,在加载之前, 里的值是一个空字典。在使用 加载之后,我们得到了之前第 10 次的参数状态,然后继续在此基础上进行训练即可。

本节课中,我们学习了优化器 Optimizer 的概念和一些基本属性及方法。在下节课中,我们将继续学习 PyTorch 中的一些常用的优化方法 (优化器)。

下节内容:优化器 (二)

知识共享许可协议本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。 欢迎转载,并请注明来自:YEY 的博客 同时保持文章内容的完整和以上声明信息!


平台注册入口