9. 现代循环神经网络navigate_next 9.8. 束搜索
Quick search
code
Show Source
MXNet PyTorch Jupyter 记事本 课程 GitHub English
动手学深度学习
Table Of Contents
  • 前言
  • 安装
  • 符号
  • 1. 引言
  • 2. 预备知识
    • 2.1. 数据操作
    • 2.2. 数据预处理
    • 2.3. 线性代数
    • 2.4. 微积分
    • 2.5. 自动微分
    • 2.6. 概率
    • 2.7. 查阅文档
  • 3. 线性神经网络
    • 3.1. 线性回归
    • 3.2. 线性回归的从零开始实现
    • 3.3. 线性回归的简洁实现
    • 3.4. softmax回归
    • 3.5. 图像分类数据集
    • 3.6. softmax回归的从零开始实现
    • 3.7. softmax回归的简洁实现
  • 4. 多层感知机
    • 4.1. 多层感知机
    • 4.2. 多层感知机的从零开始实现
    • 4.3. 多层感知机的简洁实现
    • 4.4. 模型选择、欠拟合和过拟合
    • 4.5. 权重衰减
    • 4.6. 暂退法(Dropout)
    • 4.7. 前向传播、反向传播和计算图
    • 4.8. 数值稳定性和模型初始化
    • 4.9. 环境和分布偏移
    • 4.10. 实战Kaggle比赛:预测房价
  • 5. 深度学习计算
    • 5.1. 层和块
    • 5.2. 参数管理
    • 5.3. 延后初始化
    • 5.4. 自定义层
    • 5.5. 读写文件
    • 5.6. GPU
  • 6. 卷积神经网络
    • 6.1. 从全连接层到卷积
    • 6.2. 图像卷积
    • 6.3. 填充和步幅
    • 6.4. 多输入多输出通道
    • 6.5. 汇聚层
    • 6.6. 卷积神经网络(LeNet)
  • 7. 现代卷积神经网络
    • 7.1. 深度卷积神经网络(AlexNet)
    • 7.2. 使用块的网络(VGG)
    • 7.3. 网络中的网络(NiN)
    • 7.4. 含并行连结的网络(GoogLeNet)
    • 7.5. 批量规范化
    • 7.6. 残差网络(ResNet)
    • 7.7. 稠密连接网络(DenseNet)
  • 8. 循环神经网络
    • 8.1. 序列模型
    • 8.2. 文本预处理
    • 8.3. 语言模型和数据集
    • 8.4. 循环神经网络
    • 8.5. 循环神经网络的从零开始实现
    • 8.6. 循环神经网络的简洁实现
    • 8.7. 通过时间反向传播
  • 9. 现代循环神经网络
    • 9.1. 门控循环单元(GRU)
    • 9.2. 长短期记忆网络(LSTM)
    • 9.3. 深度循环神经网络
    • 9.4. 双向循环神经网络
    • 9.5. 机器翻译与数据集
    • 9.6. 编码器-解码器架构
    • 9.7. 序列到序列学习(seq2seq)
    • 9.8. 束搜索
  • 10. 注意力机制
    • 10.1. 注意力提示
    • 10.2. 注意力汇聚:Nadaraya-Watson 核回归
    • 10.3. 注意力评分函数
    • 10.4. Bahdanau 注意力
    • 10.5. 多头注意力
    • 10.6. 自注意力和位置编码
    • 10.7. Transformer
  • 11. 优化算法
    • 11.1. 优化和深度学习
    • 11.2. 凸性
    • 11.3. 梯度下降
    • 11.4. 随机梯度下降
    • 11.5. 小批量随机梯度下降
    • 11.6. 动量法
    • 11.7. AdaGrad算法
    • 11.8. RMSProp算法
    • 11.9. Adadelta
    • 11.10. Adam算法
    • 11.11. 学习率调度器
  • 12. 计算性能
    • 12.1. 编译器和解释器
    • 12.2. 异步计算
    • 12.3. 自动并行
    • 12.4. 硬件
    • 12.5. 多GPU训练
    • 12.6. 多GPU的简洁实现
    • 12.7. 参数服务器
  • 13. 计算机视觉
    • 13.1. 图像增广
    • 13.2. 微调
    • 13.3. 目标检测和边界框
    • 13.4. 锚框
    • 13.5. 多尺度目标检测
    • 13.6. 目标检测数据集
    • 13.7. 单发多框检测(SSD)
    • 13.8. 区域卷积神经网络(R-CNN)系列
    • 13.9. 语义分割和数据集
    • 13.10. 转置卷积
    • 13.11. 全卷积网络
    • 13.12. 风格迁移
    • 13.13. 实战 Kaggle 比赛:图像分类 (CIFAR-10)
    • 13.14. 实战Kaggle比赛:狗的品种识别(ImageNet Dogs)
  • 14. 自然语言处理:预训练
    • 14.1. 词嵌入(word2vec)
    • 14.2. 近似训练
    • 14.3. 用于预训练词嵌入的数据集
    • 14.4. 预训练word2vec
    • 14.5. 全局向量的词嵌入(GloVe)
    • 14.6. 子词嵌入
    • 14.7. 词的相似性和类比任务
    • 14.8. 来自Transformers的双向编码器表示(BERT)
    • 14.9. 用于预训练BERT的数据集
    • 14.10. 预训练BERT
  • 15. 自然语言处理:应用
    • 15.1. 情感分析及数据集
    • 15.2. 情感分析:使用循环神经网络
    • 15.3. 情感分析:使用卷积神经网络
    • 15.4. 自然语言推断与数据集
    • 15.5. 自然语言推断:使用注意力
    • 15.6. 针对序列级和词元级应用微调BERT
    • 15.7. 自然语言推断:微调BERT
  • 16. 附录:深度学习工具
    • 16.1. 使用Jupyter Notebook
    • 16.2. 使用Amazon SageMaker
    • 16.3. 使用Amazon EC2实例
    • 16.4. 选择服务器和GPU
    • 16.5. 为本书做贡献
    • 16.6. d2l API 文档
  • 参考文献
动手学深度学习
Table Of Contents
  • 前言
  • 安装
  • 符号
  • 1. 引言
  • 2. 预备知识
    • 2.1. 数据操作
    • 2.2. 数据预处理
    • 2.3. 线性代数
    • 2.4. 微积分
    • 2.5. 自动微分
    • 2.6. 概率
    • 2.7. 查阅文档
  • 3. 线性神经网络
    • 3.1. 线性回归
    • 3.2. 线性回归的从零开始实现
    • 3.3. 线性回归的简洁实现
    • 3.4. softmax回归
    • 3.5. 图像分类数据集
    • 3.6. softmax回归的从零开始实现
    • 3.7. softmax回归的简洁实现
  • 4. 多层感知机
    • 4.1. 多层感知机
    • 4.2. 多层感知机的从零开始实现
    • 4.3. 多层感知机的简洁实现
    • 4.4. 模型选择、欠拟合和过拟合
    • 4.5. 权重衰减
    • 4.6. 暂退法(Dropout)
    • 4.7. 前向传播、反向传播和计算图
    • 4.8. 数值稳定性和模型初始化
    • 4.9. 环境和分布偏移
    • 4.10. 实战Kaggle比赛:预测房价
  • 5. 深度学习计算
    • 5.1. 层和块
    • 5.2. 参数管理
    • 5.3. 延后初始化
    • 5.4. 自定义层
    • 5.5. 读写文件
    • 5.6. GPU
  • 6. 卷积神经网络
    • 6.1. 从全连接层到卷积
    • 6.2. 图像卷积
    • 6.3. 填充和步幅
    • 6.4. 多输入多输出通道
    • 6.5. 汇聚层
    • 6.6. 卷积神经网络(LeNet)
  • 7. 现代卷积神经网络
    • 7.1. 深度卷积神经网络(AlexNet)
    • 7.2. 使用块的网络(VGG)
    • 7.3. 网络中的网络(NiN)
    • 7.4. 含并行连结的网络(GoogLeNet)
    • 7.5. 批量规范化
    • 7.6. 残差网络(ResNet)
    • 7.7. 稠密连接网络(DenseNet)
  • 8. 循环神经网络
    • 8.1. 序列模型
    • 8.2. 文本预处理
    • 8.3. 语言模型和数据集
    • 8.4. 循环神经网络
    • 8.5. 循环神经网络的从零开始实现
    • 8.6. 循环神经网络的简洁实现
    • 8.7. 通过时间反向传播
  • 9. 现代循环神经网络
    • 9.1. 门控循环单元(GRU)
    • 9.2. 长短期记忆网络(LSTM)
    • 9.3. 深度循环神经网络
    • 9.4. 双向循环神经网络
    • 9.5. 机器翻译与数据集
    • 9.6. 编码器-解码器架构
    • 9.7. 序列到序列学习(seq2seq)
    • 9.8. 束搜索
  • 10. 注意力机制
    • 10.1. 注意力提示
    • 10.2. 注意力汇聚:Nadaraya-Watson 核回归
    • 10.3. 注意力评分函数
    • 10.4. Bahdanau 注意力
    • 10.5. 多头注意力
    • 10.6. 自注意力和位置编码
    • 10.7. Transformer
  • 11. 优化算法
    • 11.1. 优化和深度学习
    • 11.2. 凸性
    • 11.3. 梯度下降
    • 11.4. 随机梯度下降
    • 11.5. 小批量随机梯度下降
    • 11.6. 动量法
    • 11.7. AdaGrad算法
    • 11.8. RMSProp算法
    • 11.9. Adadelta
    • 11.10. Adam算法
    • 11.11. 学习率调度器
  • 12. 计算性能
    • 12.1. 编译器和解释器
    • 12.2. 异步计算
    • 12.3. 自动并行
    • 12.4. 硬件
    • 12.5. 多GPU训练
    • 12.6. 多GPU的简洁实现
    • 12.7. 参数服务器
  • 13. 计算机视觉
    • 13.1. 图像增广
    • 13.2. 微调
    • 13.3. 目标检测和边界框
    • 13.4. 锚框
    • 13.5. 多尺度目标检测
    • 13.6. 目标检测数据集
    • 13.7. 单发多框检测(SSD)
    • 13.8. 区域卷积神经网络(R-CNN)系列
    • 13.9. 语义分割和数据集
    • 13.10. 转置卷积
    • 13.11. 全卷积网络
    • 13.12. 风格迁移
    • 13.13. 实战 Kaggle 比赛:图像分类 (CIFAR-10)
    • 13.14. 实战Kaggle比赛:狗的品种识别(ImageNet Dogs)
  • 14. 自然语言处理:预训练
    • 14.1. 词嵌入(word2vec)
    • 14.2. 近似训练
    • 14.3. 用于预训练词嵌入的数据集
    • 14.4. 预训练word2vec
    • 14.5. 全局向量的词嵌入(GloVe)
    • 14.6. 子词嵌入
    • 14.7. 词的相似性和类比任务
    • 14.8. 来自Transformers的双向编码器表示(BERT)
    • 14.9. 用于预训练BERT的数据集
    • 14.10. 预训练BERT
  • 15. 自然语言处理:应用
    • 15.1. 情感分析及数据集
    • 15.2. 情感分析:使用循环神经网络
    • 15.3. 情感分析:使用卷积神经网络
    • 15.4. 自然语言推断与数据集
    • 15.5. 自然语言推断:使用注意力
    • 15.6. 针对序列级和词元级应用微调BERT
    • 15.7. 自然语言推断:微调BERT
  • 16. 附录:深度学习工具
    • 16.1. 使用Jupyter Notebook
    • 16.2. 使用Amazon SageMaker
    • 16.3. 使用Amazon EC2实例
    • 16.4. 选择服务器和GPU
    • 16.5. 为本书做贡献
    • 16.6. d2l API 文档
  • 参考文献

9.8. 束搜索¶
Open the notebook in Colab
Open the notebook in Colab
Open the notebook in Colab
Open the notebook in Colab
Open the notebook in SageMaker Studio Lab

在 9.7节中,我们逐个预测输出序列, 直到预测序列中出现特定的序列结束词元“<eos>”。 本节将首先介绍贪心搜索(greedy search)策略, 并探讨其存在的问题,然后对比其他替代策略: 穷举搜索(exhaustive search)和束搜索(beam search)。

在正式介绍贪心搜索之前,我们使用与 9.7节中 相同的数学符号定义搜索问题。 在任意时间步\(t'\),解码器输出\(y_{t'}\)的概率取决于 时间步\(t'\)之前的输出子序列\(y_1, \ldots, y_{t'-1}\) 和对输入序列的信息进行编码得到的上下文变量\(\mathbf{c}\)。 为了量化计算代价,用\(\mathcal{Y}\)表示输出词表, 其中包含“<eos>”, 所以这个词汇集合的基数\(\left|\mathcal{Y}\right|\)就是词表的大小。 我们还将输出序列的最大词元数指定为\(T'\)。 因此,我们的目标是从所有\(\mathcal{O}(\left|\mathcal{Y}\right|^{T'})\)个 可能的输出序列中寻找理想的输出。 当然,对于所有输出序列,在“<eos>”之后的部分(非本句) 将在实际输出中丢弃。

9.8.1. 贪心搜索¶

首先,让我们看看一个简单的策略:贪心搜索, 该策略已用于 9.7节的序列预测。 对于输出序列的每一时间步\(t'\), 我们都将基于贪心搜索从\(\mathcal{Y}\)中找到具有最高条件概率的词元,即:

(9.8.1)¶\[y_{t'} = \operatorname*{argmax}_{y \in \mathcal{Y}} P(y \mid y_1, \ldots, y_{t'-1}, \mathbf{c})\]

一旦输出序列包含了“<eos>”或者达到其最大长度\(T'\),则输出完成。

../_images/s2s-prob1.svg

图9.8.1 在每个时间步,贪心搜索选择具有最高条件概率的词元¶

如 图9.8.1中, 假设输出中有四个词元“A”“B”“C”和“<eos>”。 每个时间步下的四个数字分别表示在该时间步 生成“A”“B”“C”和“<eos>”的条件概率。 在每个时间步,贪心搜索选择具有最高条件概率的词元。 因此,将在 图9.8.1中 预测输出序列“A”“B”“C”和“<eos>”。 这个输出序列的条件概率是 \(0.5\times0.4\times0.4\times0.6 = 0.048\)。

那么贪心搜索存在的问题是什么呢? 现实中,最优序列(optimal sequence)应该是最大化 \(\prod_{t'=1}^{T'} P(y_{t'} \mid y_1, \ldots, y_{t'-1}, \mathbf{c})\) 值的输出序列,这是基于输入序列生成输出序列的条件概率。 然而,贪心搜索无法保证得到最优序列。

../_images/s2s-prob2.svg

图9.8.2 在时间步2,选择具有第二高条件概率的词元“C”(而非最高条件概率的词元)¶

图9.8.2中的另一个例子阐述了这个问题。 与 图9.8.1不同,在时间步\(2\)中, 我们选择 图9.8.2中的词元“C”, 它具有第二高的条件概率。 由于时间步\(3\)所基于的时间步\(1\)和\(2\)处的输出子序列已从 图9.8.1中的“A”和“B”改变为 图9.8.2中的“A”和“C”, 因此时间步\(3\)处的每个词元的条件概率也在 图9.8.2中改变。 假设我们在时间步\(3\)选择词元“B”, 于是当前的时间步\(4\)基于前三个时间步的输出子序列“A”“C”和“B”为条件, 这与 图9.8.1中的“A”“B”和“C”不同。 因此,在 图9.8.2中的时间步\(4\)生成 每个词元的条件概率也不同于 图9.8.1中的条件概率。 结果, 图9.8.2中的输出序列 “A”“C”“B”和“<eos>”的条件概率为 \(0.5\times0.3 \times0.6\times0.6=0.054\), 这大于 图9.8.1中的贪心搜索的条件概率。 这个例子说明:贪心搜索获得的输出序列 “A”“B”“C”和“<eos>” 不一定是最佳序列。

9.8.2. 穷举搜索¶

如果目标是获得最优序列, 我们可以考虑使用穷举搜索(exhaustive search): 穷举地列举所有可能的输出序列及其条件概率, 然后计算输出条件概率最高的一个。

虽然我们可以使用穷举搜索来获得最优序列, 但其计算量\(\mathcal{O}(\left|\mathcal{Y}\right|^{T'})\)可能高的惊人。 例如,当\(|\mathcal{Y}|=10000\)和\(T'=10\)时, 我们需要评估\(10000^{10} = 10^{40}\)序列, 这是一个极大的数,现有的计算机几乎不可能计算它。 然而,贪心搜索的计算量 \(\mathcal{O}(\left|\mathcal{Y}\right|T')\) 通它要显著地小于穷举搜索。 例如,当\(|\mathcal{Y}|=10000\)和\(T'=10\)时, 我们只需要评估\(10000\times10=10^5\)个序列。

9.8.3. 束搜索¶

那么该选取哪种序列搜索策略呢? 如果精度最重要,则显然是穷举搜索。 如果计算成本最重要,则显然是贪心搜索。 而束搜索的实际应用则介于这两个极端之间。

束搜索(beam search)是贪心搜索的一个改进版本。 它有一个超参数,名为束宽(beam size)\(k\)。 在时间步\(1\),我们选择具有最高条件概率的\(k\)个词元。 这\(k\)个词元将分别是\(k\)个候选输出序列的第一个词元。 在随后的每个时间步,基于上一时间步的\(k\)个候选输出序列, 我们将继续从\(k\left|\mathcal{Y}\right|\)个可能的选择中 挑出具有最高条件概率的\(k\)个候选输出序列。

../_images/beam-search.svg

图9.8.3 束搜索过程(束宽:2,输出序列的最大长度:3)。候选输出序列是\(A\)、\(C\)、\(AB\)、\(CE\)、\(ABD\)和\(CED\)¶

图9.8.3演示了束搜索的过程。 假设输出的词表只包含五个元素: \(\mathcal{Y} = \{A, B, C, D, E\}\), 其中有一个是“<eos>”。 设置束宽为\(2\),输出序列的最大长度为\(3\)。 在时间步\(1\),假设具有最高条件概率 \(P(y_1 \mid \mathbf{c})\)的词元是\(A\)和\(C\)。 在时间步\(2\),我们计算所有\(y_2 \in \mathcal{Y}\)为:

(9.8.2)¶\[\begin{split}\begin{aligned}P(A, y_2 \mid \mathbf{c}) = P(A \mid \mathbf{c})P(y_2 \mid A, \mathbf{c}),\\ P(C, y_2 \mid \mathbf{c}) = P(C \mid \mathbf{c})P(y_2 \mid C, \mathbf{c}),\end{aligned}\end{split}\]

从这十个值中选择最大的两个, 比如\(P(A, B \mid \mathbf{c})\)和\(P(C, E \mid \mathbf{c})\)。 然后在时间步\(3\),我们计算所有\(y_3 \in \mathcal{Y}\)为:

(9.8.3)¶\[\begin{split}\begin{aligned}P(A, B, y_3 \mid \mathbf{c}) = P(A, B \mid \mathbf{c})P(y_3 \mid A, B, \mathbf{c}),\\P(C, E, y_3 \mid \mathbf{c}) = P(C, E \mid \mathbf{c})P(y_3 \mid C, E, \mathbf{c}),\end{aligned}\end{split}\]

从这十个值中选择最大的两个, 即\(P(A, B, D \mid \mathbf{c})\)和\(P(C, E, D \mid \mathbf{c})\), 我们会得到六个候选输出序列: (1)\(A\);(2)\(C\);(3)\(A,B\);(4)\(C,E\);(5)\(A,B,D\);(6)\(C,E,D\)。

最后,基于这六个序列(例如,丢弃包括“<eos>”和之后的部分), 我们获得最终候选输出序列集合。 然后我们选择其中条件概率乘积最高的序列作为输出序列:

(9.8.4)¶\[\frac{1}{L^\alpha} \log P(y_1, \ldots, y_{L}\mid \mathbf{c}) = \frac{1}{L^\alpha} \sum_{t'=1}^L \log P(y_{t'} \mid y_1, \ldots, y_{t'-1}, \mathbf{c}),\]

其中\(L\)是最终候选序列的长度, \(\alpha\)通常设置为\(0.75\)。 因为一个较长的序列在 (9.8.4) 的求和中会有更多的对数项, 因此分母中的\(L^\alpha\)用于惩罚长序列。

束搜索的计算量为\(\mathcal{O}(k\left|\mathcal{Y}\right|T')\), 这个结果介于贪心搜索和穷举搜索之间。 实际上,贪心搜索可以看作一种束宽为\(1\)的特殊类型的束搜索。 通过灵活地选择束宽,束搜索可以在正确率和计算代价之间进行权衡。

9.8.4. 小结¶

  • 序列搜索策略包括贪心搜索、穷举搜索和束搜索。

  • 贪心搜索所选取序列的计算量最小,但精度相对较低。

  • 穷举搜索所选取序列的精度最高,但计算量最大。

  • 束搜索通过灵活选择束宽,在正确率和计算代价之间进行权衡。

9.8.5. 练习¶

  1. 我们可以把穷举搜索看作一种特殊的束搜索吗?为什么?

  2. 在 9.7节的机器翻译问题中应用束搜索。 束宽是如何影响预测的速度和结果的?

  3. 在 8.5节中,我们基于用户提供的前缀, 通过使用语言模型来生成文本。这个例子中使用了哪种搜索策略?可以改进吗?

Discussions

Table Of Contents

  • 9.8. 束搜索
    • 9.8.1. 贪心搜索
    • 9.8.2. 穷举搜索
    • 9.8.3. 束搜索
    • 9.8.4. 小结
    • 9.8.5. 练习
Previous
9.7. 序列到序列学习(seq2seq)
Next
10. 注意力机制