6.4 序列标注
在上一节中,我们使用LSTM网络,并在最后的活性值之上堆叠一个softmax层,构建了一个序列分类模型。在此基础上,现在开始处理一个难度更大的问题——序列标注(sequence labelling)。该设置问题与序列分类不同,因为它需要对输入序列的每一帧都预测一个类别。
例如,考虑手写文字识别。每个单词都是一个字母序列,我们当然可以单独对每个字母进行分类,但人类的语言具有很强的结构性,这一点是可以善加利用的。如果查看一些手写体样本,会发现有些字符是很难单独识别的,如n、m和u。然而,依据其近邻的字母构成的上下文来识别,就会容易许多。在本节中,我们将通过RNN来利用字母之间的依赖性,并构建一个比较稳健的OCR(Optical Character Recognition,光学字符识别)系统。
6.4.1 OCR数据集
作为一个序列标注问题的例子,我们先了解一下由MIT的口语系统研究组的Rob Kassel收集的,并由斯坦福大学人工智能实验室的Ben Taskar预处理的OCR数据集。该数据集包含了大量单独的手写字母,每个样本对应一幅16×8像素的二值图像。这些字母被组合为一些序列,且每个序列都对应一个单词。整个数据集共包含约6800个、长度至多为14的单词。
下面给出三个该OCR数据集中的序列样本。这几个单词分别为cafeteria、puzzlement和unexpected。这些单词的首字母并未包含在数据集中,因为它们都是大写的。所有序列都被填充为最大长度14。为了简化工作量,该数据集中仅包含小写字母,这正是一些单词中不含首字母的原因。
该数据集可从http://ai.stanford.edu/~btaskar/ocr/上获取,它对应于一个用gzip压缩的、内容用Tab分隔的文本文件,该文件可利用Python的csv模块直接读取。该文件中每行都表示该数据集中一个字母的属性,如ID号、标签、像素值、单词中下一个字母的ID号等。
首先对那些下一个字母的ID值进行排序,以便能够按照正确的顺序读取每个单词中的字母。然后继续收集字母,直到下一个ID对应的字段未被设置为止。出现这种情况时,我们开始读取一个新的序列。读取完目标字母及其数据像素后,用零图像对序列进行填充,以使其能够纳入两个较大的包含目标字母和所有像素数据的NumPy数组中。
6.4.2 时间步之间共享的softmax层
现在,数据和目标数组中都包含了序列,每个目标字母对应于一个图像帧。为了每帧数据获取一个预测结果的最简单的方法是对RNN进行扩展,在每个字母的输出之上添加一个softmax分类器。这非常类似于上一节序列分类问题中所采用的模型,唯一的区别在于分类器是对每帧数据而非整个序列进行评估的。
下面具体实现序列标注方法。首先,需要计算序列的长度。在上一节中,该工作已经完成,因此这里不再赘述。
现在进入预测部分,这是与序列分类模型存在主要差别的地方。要将一个softmax层添加到所有帧上有两种方法:或者为所有帧添加几个不同的分类器,或者令所有帧共享同一个分类器。由于对第3个字母进行分类并不比对第8个字母分类难度更大,所以采取后一种方式是比较合理的。按照这种方式,分类器权值在训练中被调整的次数更多,因为需要对单词中的每个字母进行训练。
要在TensorFlow中实现一个共享层,我们需要运用一点小技巧。一个全连接层的权值矩阵的维数始终为batch_size*in_size*out_size,但现在有两个输入维batch_size和sequence_steps,我们希望在这两个维度上对权值矩阵进行更新。
要解决这个问题,可以令这一层的输入(本例中文RNN的输出活性值)扁平为形状batch_size*sequence_steps*in_size。按照这种方式,对于权值矩阵而言,它看起来就像是一个较大的批数据。当然,还必须对结果的形状进行调整,即反扁平化(unflatten)。
相比于序列分类,这里的代价和误差函数的变动都很小,即对序列中的每一帧,如今都有了一个预测-目标对,因此必须在相应的维度上进行平均。然而,tf.reduce_mean()在这里无法使用,因为它要依据张量的长度(即序列的最大长度)进行归一化,而我们希望按照之前计算的实际序列长度进行归一化。因此,可手工调用tf.reduce_sum()和一个除法运算来获得正确的均值。
与代价函数类似,我们也必须对误差函数进行调整。现在,tf.argmax()针对的是轴2而非轴1。然后,对各帧进行填充,并依据序列的实际长度计算均值。最后的tf.reduce_mean()对批数据中的所有单词取均值。
TensorFlow的自动导数计算的一大优点是可像对序列分类问题那样对该模型使用相同的优化运算,我们所要做的仅仅是将新的代价函数代入。从现在开始,我们将对所有的RNN运用梯度裁剪,因为这种措施能够防止训练发散,同时不会产生任何负面影响。
6.4.3 训练模型
现在可将到目前为止介绍的所有部分整合到一起,开始训练模型。通过上一节的学习,相信读者对导入和配置参数已经非常熟悉。我们利用get_dataset()下载手写体图像并进行预处理,这也正是将小写字母编码为独热编码向量的地方。经过编码之后,随机打乱数据的顺序,以便在划分训练集和测试集时得到一个无偏的结果。
当用1000个单词训练之后,我们的模型在测试集上的错误率已降至约9%。这个结果不算太差,但仍有提升的空间。
我们目前使用的模型与用于序列分类的模型非常相似。笔者是有意而为之的,目的是帮助读者了解为将已有模型用于解决新问题,应做何种修改。在另一个问题上的有效解决方案对于一个新问题也极有可能比预想的要有效。然而,我们完全可以做得更好!下一节将尝试利用一种更高级的循环神经网络架构改进现有结果。
6.4.4 双向RNN
如何对用RNN加softmax架构在OCR数据集上得到的结果进行改进?不妨重新审视一下使用RNN的动机。我们为OCR数据集选择这一架构的原因在于单词中的相邻字母之间存在依赖关系(或互信息)。RNN会将关于在同一单词之前全部输入的信息保存到隐含活性值中。
如果能够想到这一点,就会意识到在模型中循环连接对于前几个字母的分类是没有太大帮助的,因为网络尚无大量输入以从中推断出额外的信息。在序列分类任务中,这并不是一个问题,因为网络在决策之前能够看到所有的帧。在序列标注任务中,可利用双向RNN(bidirectional RNN)克服RNN的这个缺陷,这项技术在若干分类任务中都保持着最高的水平。
双向RNN的思想非常简单。它共有两个RNN观测输入序列,一个按照通常的顺序从左端读取单词,而另一个按照相反的顺序从右端读取单词。这样,在每个时间步,就可得到两个输出活性值。在将它们送入共享的softmax层之前,可将两者拼接在一起。利用这种架构,分类器便可从每个字母获取完整的单词信息。
双向循环神经网络
那么如何用TensorFlow实现双向RNN?实际上TensorFlow中已有了一个实现版本——tf.model.rnn.bidirectional_rnn。但是,我们希望学习如何自行构建复杂模型,因此下面来实现这种模型。笔者将引导你完成各个步骤。首先,将预测属性划分到两个函数中,以便眼下只关注较少的内容。
上面的_shared_softmax()函数的实现比较容易:在之前的预测属性中,我们已经有了相关代码。区别在于现在是从传入该函数的张量data推断输入尺寸。依照这种方式,可在必要时复用其他架构的函数,然后可以利用相同的扁平化技巧在所有的时间步中共享同一个softmax层。
下面进入真正有趣的环节——实现双向RNN。如你所见,我们利用rnn.dynamic_rnn创建了两个RNN。前向网络对我们而言非常熟悉,但后向网络是全新的。
我们并不将数据送入后向RNN,而是首先将序列反转。这样做要比实现一个新的用于反向传递的RNN运算更加容易。TensorFlow提供了tf.reverse_sequence()函数,它可帮助我们完成对所使用的帧数据中至多sequence_lengths帧的反转操作。请注意,在本书撰写之时,该函数要求sequence_lengths参数为int64类型的张量。在未来的版本中,极有可能也支持该参数为int32类型[1],且只需传入self.length即可。
这里也使用了scope参数,为什么需要它?第3章曾解释过,数据流图中的节点是拥有名称的。scope是rnn_dynamic_cell所使用的变量scope的名称,其默认值为RNN。现在由于我们有两个参数不同的RNN,所以它们需要有不同的域。
将反转的序列送入后向RNN后,我们再次将网络的输出反转,以与前向输出对齐。然后沿着RNN的神经元输出的维度将这两个张量拼接在一起,并将其返回。例如,当批数据尺寸为50,每个RNN有300个隐藏单元,所有单词至多包含14个字母时,所得到张量的形状为50×14×600。
非常酷,这样我们就亲手构建了自己的第一个由多个RNN组成的架构!下面来检查利用上一节的训练代码能够使这个模型达到何种性能。通过比较两个预测误差图,可以看出,双向模型具有更优的性能。在接收1000个单词之后,它在测试集上对字母的识别错误率已经低至4%。
总结一下,在本节中,我们学习了如何利用RNN完成序列标注任务,并了解了该任务与序列分类任务的差异,即我们希望得到一个能够接收RNN的输出并为所有时间步所共享的分类器。
通过增加第二个从后向前访问序列的RNN,并将每个时间步的输出进行整合,模型的性能能够得到显著提升,这是因为在对每个字母进行分类时,整个序列的信息都是可用的。
在下一节中,我们将介绍如何用非监督的方式训练RNN模型,以实现语言的学习。
[1] 从TensorFlow10.0起,seq_lengths的类型可为int32。——译者注
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论