返回介绍

6.5 预测编码

发布于 2024-02-05 23:12:36 字数 10216 浏览 0 评论 0 收藏 0

我们已经学习了如何利用RNN对影评中的情绪进行分类,以及如何识别手写单词。这些应用都是有监督的,即需要一个带标注信息的数据集。另外一种有趣的学习设置是预测编码(predictive coding),目的是通过向RNN输入大量序列,训练它预测序列的下一帧的能力。

以文本为例,预测一个句子中下一个单词的似然被称为语言建模(language modelling)。为什么预测句子中的下一个单词是有用的?有一类应用被称为识别语言。例如,假设希望构建一个手写文字识别器,目标是将手写文字图像转换为键入的文字。虽然可以尝试从输入图像恢复所有的单词,但如果能够预知下一个单词的概率分布,无疑能够缩小候选单词的考虑范围。基本上,这便是盲目的形状识别与阅读的区别。

除了提升模型对涉及自然语言的任务的处理性能,为了生成文本,也可以依据网络所认为的下一个单词的分布进行抽样。训练结束后,可将一个种子单词(seed word)送入RNN,然后观察它所预测的下一个单词。之后,将最可能的单词送回RNN作为接下来的输入,以观察它认为接下来应是什么。重复这个步骤,便可生成与训练数据看上去非常类似的新内容。

从一个循环语言模型进行种子采样

这里的有趣之处在于预测编码将训练网络对任意序列的所有重要信息进行压缩。一个句子中的下一个单词通常与前一个单词及其顺序及关系有关。因此,能够精确预测自然语言中下一个字符的网络也需要捕捉语法和语言规则。

6.5.1 字符级语言建模

下面利用RNN构建一个预测编码语言模型。我们用稍多于26个独热编码字符表示字母、一些标点符号和空格,而不将词向量作为输入。

对于单词级的语言建模和字符级的语言建模,哪个方法更优尚不清楚。字符级的建模方法之美在于网络不仅能学会如何构词,还可以学会如何拼写。此外,采用这种方法时,与尺寸为300的词向量或独热编码的单词相比,网络的输入维数更低。此外,还有一个好处,即不必再去考虑那些未知的单词,因为它们是由网络已知的字母构成的。从理论上讲,这甚至允许网络发明一些新的单词。

Andrew Karpathy在2015年将RNN应用于字符级语言建模,并自动生成了一些令人惊叹的莎士比亚剧本、Linux内核和驱动代码以及包括正确的标记语法的维基百科文章。这个项目的源码可从Github获取https://github.com/karpathy/char-rnn。下面从机器学习文献的摘要上训练一个类似的模型,看看能否生成一些多少有一定合理性的新摘要。

6.5.2 ArXiv摘要API

ArXiv.org是一个托管了来自计算机科学、数学、物理学和生物学等领域的许多研究论文的在线库。如果一直在追踪机器学习相关研究,可能对该网站早有耳闻。幸运的是,这个平台提供了一个基于Web的可用于检索文献的API。下面来编写一个依据给定搜索查询,从ArXiv获取摘要的类。

在构造方法中,首先检查是否有之前的摘要转储文件可用。如果有,则直接使用,而无需再次调用ArXiv API。你可以想象更为复杂的检查已有文件与新类别、新关键词是否匹配的逻辑,但就目前而言,执行新的查询时,将旧的转储文件删除或转移已经足够用了。如果没有转储文件可用,则调用_fetch_all()方法,并将它所生成的行写入磁盘。

由于所感兴趣的是机器学习论文,所以只在Machine Learning、Neural and Evolutionary Computing和Optimization and Control三个类别内进行搜索。我们进一步限制只返回那些元数据中包含单词neural、network或deep的结果,这样可以获取到约7MB的文本,这样的数据量对于训练一个简单的RNN语言模型已经足够大了。尽管使用更多的数据通常会得到更好的结果,但我们并不希望在看到结果之前用数小时等待训练结束。你尽可以使用更多的搜索查询,并用更多的数据来训练模型。

_fetch_all()方法基本上完成的是分页功能。每次查询时,这个API仅返回一定数量的摘要,我们可指定一个偏移量,用于获取比如第2页、第3页的结果。可以看到,我们能为下一个函数_fetch_page()传入一个指定了页面尺寸的参数。理论上,可以将页面尺寸设为一个很大的数,并尝试一次性得到全部结果。然而,实际上这种做法会严重影响查询的效率。页面的获取容错性更强,而且更重要的是,不会为ArXiv API增加过大的负载。

这里完成了实际的抓取,结果为XML格式,利用流行而强大的BeautifulSoup库来提取摘要。如果尚未安装该库,可通过执行命令sudo-H pip3 install beautifulsoup4来安装它。BeautifulSoup会为我们解析XML结果,这样便可遍历那些感兴趣的标签。首先查看对应于文章的<entry>标签,并从其内部读取包含摘要文本的<summary>标签。

6.5.3 数据预处理

数据预处理的相关代码如下:

6.5.4 预测编码模型

现在已介绍了整个流程:定义了任务,编写了一个解析器用于获取数据集,下面利用TensorFlow实现神经网络模型。由于对于预测编码而言,需要尝试预测输入序列中的下一个字符,所以模型只有一个输入,即构造方法的sequence参数。

此外,构造方法接收一个参数对象,用于修改重要的选项,并使实验可复现。第3个参数initial=None是循环连接层的初始内部活性值。虽然希望TensorFlow将隐状态初始化为零张量,但今后在需要从所学习到的语言模型进行采样时定义它会更加方便。

在上面的示例代码中,可以看到我们的模型所要实现的大致功能。如果初看上去觉得难以理解,请不必担心,与上一章模型相比,我们只是希望更多地突出这个模型的某些价值。

从数据处理开始。前面提到过,这个模型只接收一个序列块作为输入。首先,我们利用它构造输入数据和目标序列,这是引入时域差的地方,因为在时间步t,模型应有st作为输入,st+1作为输出。获取数据或目标的一种简便方法是对所提供的序列进行切片处理,并将第一帧或最后一帧分别切除。

切片运算是通过tf.slice()实现的,该函数的参数包括要切片的序列、一个包含各维起始索引的元组以及一个包含各维大小的元组。sizes-1意味着保持那个维度上从起始索引到终止索引的所有元素不变。由于希望对帧数据进行切片,所以只需关心第2维。

我们还为目标序列定义了两个前面已经讨论过的属性:mask是一个尺寸为batch_size×max_length的张量,其分量非0即1,具体取哪个值取决于相关帧是否被使用。为得到每个序列的长度,属性length沿时间轴对mask求和。

请注意,mask和length属性对于数据序列也是合法的,因为从概念上讲,它们与目标序列的长度相同。然而,我们并不在数据序列上计算这两个属性,因为它仍然包含着并不需要的最后一帧,而对它是没有下一个字母可预测的。将数据张量的最后一帧切除,但除了主要包含填充的帧外,它并不包含大多数序列实际上的最后一帧。这也正是下面用mask对代价函数进行掩膜处理的原因。

下面定义由一个循环神经网络和一个共享的softmax层构成的实际网络,具体方法与上一节序列标注任务中使用的结构类似。这里不再展示用于共享的softmax层的代码(可从上一节找到相关代码)。

上述神经网络代码中新增的部分是我们希望同时获得的预测和最后的循环活性值。在此之前,仅返回预测值,但最后的活性值可使我们更有效地生成一些序列。由于仅希望为循环神经网络构建一次数据流图,因此有一个属性forward用于返回由那两个张量构成的元组,而prediction和state的目的仅仅是便于外部访问。

模型的下一部分是代价函数和评价函数。在每个时间步,模型都会从词汇表中预测下一个字母。这是一个分类问题,我们相应地采用交叉熵代价函数,也可以很容易地计算字符预测错误率。

logprob属性是新增的,它刻画了模型在对数空间为正确的下一个字母所分配的概率。基本上,可以认为这是变换到对数空间并取均值后的负交叉熵。将结果变换回线性空间,便会得到所谓的混淆度(perplexity),这是一种用于评价语言模型性能的常见度量。

混淆度的定义为,直观地表示了模型在每个时间步必须猜测的选项数目。

对于完美的模型而言,混淆度为1,而始终对每个类别都输出相同概率的模型的混淆度为n。只要模型为下一个字母分配一个零概率,混淆度甚至会变为无穷大。为防止这种极端情况出现,可将预测概率箝位在一个很小的正数和1之间。

上述三个属性都会在所有序列的各帧上取平均。对于固定长度序列,结果将为一个tf.reduce_mean(),但在处理变长序列时,必须格外小心。首先,通过与掩膜相乘,屏蔽掉填充的帧。然后,沿着帧尺寸进行聚合。由于上述这三个函数都与目标值做了乘法,每帧只有一个元素集,我们利用tf.reduce_sum()函数将各帧聚合为一个标量。

接下来,希望利用序列的实际长度对每个序列中的各帧取平均。为了避免在空序列时除数为0,我们使用每个序列长度的最大值和1。最后,利用tf.reduce_mean()对批数据中的样本取平均。

下面直接开始训练模型。请注意,我们并未定义optimize运算,它始终与之前本章在序列分类或序列标签任务中所使用的运算一致。

6.5.5 训练模型

在对语言模型采样之前,必须将已经构建好的模块进行整合,包括数据集、预处理步骤和网络模型。下面编写一个对这些步骤进行整合的类,将新引入的混淆度度量打印出来,并周期性地将训练进展保存下来。这个检查点不但对于以后继续训练非常有用,而且还便于加载模型以用于采样(稍后将进行)。

构造方法、__call__()、_optimization()和_evaluation()都比较容易理解。我们加载数据集,为数据流图定义输入,在经过预处理的数据集上训练模型,并追踪对数几率,在相邻两次训练epoch之间的评价时间上使用它们计算并打印混淆度。

在_init_or_load_session()中,引入了一个tf.train.Saver(),用于将数据流图中所有tf.Variable()的当前值保存到检查点文件中。实际的点检查(checkpointing)是在_evalution()内完成的,在这里我们创建这个类并寻找已有的检查点文件以便加载。tf.train.get_checkpoint_state()会从检查点文件所在目录中查找TensorFlow的元数据文件。在本书撰写之时,它只包含最新生成的检查点文件。

检查点文件是通过一个可指定的数字(在本例中为epoch数)预先准备。在加载检查点文件时,利用Python的正则表达式包re提取epoch数。点检查的逻辑实现后,便可开始训练。下面是具体的配置:

为了运行这段代码,可调用Training(get_params())()。在笔者的笔记本电脑上,完成20个epoch需要大约1小时的时间。在训练过程中,模型一共看到了20 epochs*200 batches*100 examples*50 characters=20M个字母。

从上图可以看出,模型在混淆度约为1.5/字母时收敛,这意味着利用这个模型时,每个字母只需1.5位,从而可实现文本的压缩。

如果使用单词级的语言模型,则需要依据单词数而非字符数取平均。作为一种粗略的估计,可以将它乘以每个单词中的平均字符数。

6.5.6 生成相似序列

完成上述所有工作后,便可利用训练好的模型生成新的序列。我们将编写一个功能与Training类相似的较小的类,实现从磁盘加载最新的模型检查点,并定义一些占位符,以将数据输入数据流图。当然,这次并不训练模型,只是用它生成新数据。

在构造方法中,我们创建了一个预处理类的实例,后面利用它将当前生成的序列转化为一个NumPy向量,以输入数据流图。这时的占位符sequence对每批数据只预留了一个序列的空间,因为不希望每次生成多个序列。

这里序列的长度被设为2,下面做一解释。前面介绍过,我们的模型将除最后的字符外的所有字符作为输入,而将除首字符外的所有字符作为目标。我们将当前文本最后的字符和作为序列的任意第二个字符输入到模型中,网络将为第一个字符预测一个结果,将第二个字符用作目标值,但由于并不是训练模型,因此它将被忽略。

你可能会对只将当前文本最后的字符传入网络感到疑惑。这里采用的技巧是准备获取循环神经网络最后的活性值,并用它对网络下一次运行时的状态进行初始化。为此,需要利用模型的初始状态参数。对于使用过的GRUCell,该状态是一个尺寸为rnn_layers*rnn_units的向量。

__call__()函数定义了用于采样文本序列的逻辑。我们从一个采样种子开始,每次预测一个字符,并总是将当前文本送入网络。使用相同的预处理类将当前文本转换为填充后的NumPy块,然后将它们送入网络。由于在批数据中只有一个序列和一个输出帧,因此只关心索引[0,0]处的预测结果。之后,利用后面将要介绍的_sample()函数对softmax输出进行采样。

那么如何对网络输出进行采样?前文曾提到过,可选取序列最优的预测,并将其作为下一帧传入网络来生成序列。实际上,并非只选择最可能的下一帧,而是也从RNN输出的概率分布中随机抽样。按照这种方式,那些具有高输出概率的单词更可能被选中,但输出概率低的单词也是有可能被选中的。这样就可得到更多动态生成的序列。否则,可能是一次又一次地生成相同的平均句子。

要手工控制这个生成过程的有效性有一种简单的机制。例如,如果总是随机选择下一个单词(并将网络输出完全忽略),将得到非常新且独一无二的句子,但它们可能会没有任何意义。如果总是选择将网络最可能的输出作为下一个单词,则将得到大量虽常见但无意义的单词,如the、a等。

对这种行为进行控制的方式是引入一个温度参数T。利用该参数使softmax层的输出分布预测更相似或更为不同。这样会分别导致生成更有趣但有随机性的序列,以及更多合理但乏味的序列。其工作方式是在线性空间对输出进行缩放,然后将它们变换至指数空间并再次归一化:

由于网络已经输出了一个softmax分布,则可通过运用自然对数将其撤销。我们不必将归一化操作撤销,因为会再次将结果归一化。之后,将每个值除以所选择的温度值,并重新应用softmax函数。

下面通过调用Sampling(get_params())('We',500))运行上述代码,使网络生成一段新的摘要。虽然你一定能够看出这段文字绝非出自人手,但网络从样本中学习到的结果还是让人感到吃惊。

我们并未告知RNN什么是空间,但它却捕捉到了数据内部的统计依赖性,在所生成的文本中相应地放置了空格。即使在一些网络自己生成的并不存在的单词之间,空格的安排看上去也非常合理。此外,那些单词中的元音和辅音的搭配都很合理,这是从样例文本中学习到的另一种抽象特征。

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文