- PaperWeekly 2016.08.05 第一期
- PaperWeekly 第二期
- PaperWeekly 第三期
- PaperWeekly 第四期 - 基于强化学习的文本生成技术
- PaperWeekly 第五期 - 从 Word2Vec 到 FastText
- PaperWeekly 第六期 - 机器阅读理解
- PaperWeekly 第七期 -- 基于 Char-level 的 NMT OOV 解决方案
- PaperWeekly 第八期 - Sigdial2016 文章精选(对话系统最新研究成果)
- PaperWeekly 第九期 -- 浅谈 GAN
- PaperWeekly 第十期
- PaperWeekly 第十一期
- PaperWeekly 第十二期 - 文本摘要
- PaperWeekly 第十三期--最新文章解读
- PaperWeekly 第十四期 - TTIC 在 QA 任务上的研究进展
- PaperWeekly 第十六期 - ICLR 2017 精选
- PaperWeekly 第十七期 - 无监督/半监督 NER
- PaperWeekly 第十八期 - 提高 seq2seq 方法所生成对话的流畅度和多样性
- PaperWeekly 第十九期 - 新文解读(情感分析、机器阅读理解、知识图谱、文本分类)
- PaperWeekly 第二十期 - GAN(Generative Adversarial Nets)研究进展
- PaperWeekly 第二十一期 - 多模态机器翻译
- PaperWeekly 第二十二期 - Image Caption 任务综述
- PaperWeekly 第二十三期 - 机器写诗
- PaperWeekly 第二十四期 - GAN for NLP
- PaperWeekly 第二十五期 - 增强学习在 image caption 任务上的应用
- PaperWeekly 第二十六期 - 2016 年最值得读的 NLP paper 解读(3 篇)+在线 Chat 实录
- PaperWeekly 第二十七期 | VAE for NLP
- PaperWeekly 第 28 期 | 图像语义分割之特征整合和结构预测
- PaperWeekly 第 29 期 | 你的 Emoji 不一定是我的 Emoji
- PaperWeekly 第 30 期 | 解读 2016 年最值得读的三篇 NLP 论文 + 在线 Chat 实录
- PaperWeekly 第 31 期 | 远程监督在关系抽取中的应用
- PaperWeekly 第 32 期 | 基于知识图谱的问答系统关键技术研究 #01
- PaperWeekly 第 33 期 | 基于知识图谱的问答系统关键技术研究 #03
- PaperWeekly 第 34 期 | VAE 在 chatbot 中的应用
- PaperWeekly 第 35 期 | 如何让聊天机器人懂情感 PaperWeekly 第 35 期 | 如何让聊天机器人懂情感
- PaperWeekly 第 36 期 | Seq2Seq 有哪些不为人知的有趣应用?
- PaperWeekly 第 37 期 | 论文盘点:检索式问答系统的语义匹配模型(神经网络篇)
- PaperWeekly 第 38 期 | SQuAD 综述
- PaperWeekly 第 39 期 | 从 PM 到 GAN - LSTM 之父 Schmidhuber 横跨 22 年的怨念
- PaperWeekly 第 40 期 | 对话系统任务综述与基于 POMDP 的对话系统
- PaperWeekly 第 41 期 | 互怼的艺术:从零直达 WGAN-GP
- PaperWeekly 第 42 期 | 基于知识图谱的问答系统关键技术研究 #04
- PaperWeekly 第 43 期 | 教机器学习编程
- PaperWeekly 第 44 期 | Kaggle 求生
- PaperWeekly 第 45 期 | 词义的动态变迁
- PaperWeekly 第 46 期 | 关于远程监督,我们来推荐几篇值得读的论文
- PaperWeekly 第 47 期 | 开学啦!咱们来做完形填空:“讯飞杯”参赛历程
- 深度强化学习实战:Tensorflow 实现 DDPG - PaperWeekly 第 48 期
- 评测任务实战:中文文本分类技术实践与分享 - PaperWeekly 第 49 期
- 从 2017 年顶会论文看 Attention Model - PaperWeekly 第 50 期
- 深入浅出看懂 AlphaGo Zero - PaperWeekly 第 51 期
- PaperWeekly 第 52 期 | 更别致的词向量模型:Simpler GloVe - Part 1
- PaperWeekly 第 53 期 | 更别致的词向量模型:Simpler GloVe - Part 2
- 基于神经网络的实体识别和关系抽取联合学习 | PaperWeekly #54
深度强化学习实战:Tensorflow 实现 DDPG - PaperWeekly 第 48 期
1. 前言
本文主要讲解 DeepMind 发布在 ICLR 2016 的文章 Continuous control with deep reinforcement learning,时间稍微有点久远,但因为算法经典,还是值得去实现。
2. 环境
这次实验环境是 Openai Gym 的 Pendulum-v0,state 是 3 维连续的表示杆的位置方向信息,action 是 1 维的连续动作,大小是 -2.0 到 2.0,表示对杆施加的力和方向。目标是让杆保持直立,所以 reward 在杆保持直立不动的时候最大。笔者所用的环境为:
- Tensorflow (1.2.1)
- gym (0.9.2)
请先安装 Tensorflow 和 gym,Tensorflow 和 gym 的安装就不赘述了,下面是网络收敛后的结果。
3. 代码详解
先贴一张 DeepMind 文章中的伪代码,分析一下实现它,我们需要实现哪些东西:
4. 网络结构(model)
首先,我们需要实现一个 critic network 和一个 actor network,然后再实现一个 target critic network 和 target actor network,并且对应初始化为相同的 weights。下面来看看这部分代码怎么实现:
critic network & target critic network
上面是 critic network 的实现,critic network
是一个用神经网络去近似的一个函数,输入是 s-state,a-action,输出是 Q 函数,网络参数是,在这里我的实现和原文类似,state 经过一个全连接层得到隐藏层特征 h1,action 经过另外一个全连接层得到隐藏层特征 h2,然后特征串联在一起得到 h_concat,之后 h_concat 再经过一层全连接层得到 h3,最后 h3 经过一个没有激活函数的全连接层得到 q_output。这就简单得实现了一个 critic network。
上面是 target critic network 的实现,target critic network
网络结构和 critic network 一样,也参数初始化为一样的权重,思路是先把 critic network 的权重取出来初始化,再调用一遍 self.__create_critic_network() 创建 target network,最后把 critic network 初始化的权重赋值给 target critic network。
这样我们就得到了 critic network 和 critic target network。
actor network & actor target network
actor network 和 actor target network
的实现与 critic 几乎一样,区别在于网络结构和激活函数。
这里用了 3 层全连接层,最后激活函数是 tanh,把输出限定在 -1 到 1 之间。这样大体的网络结构就实现完了。
5. Replay Buffer & Random Process(Mechanism)
接下来,伪代码提到 replay buffer 和 random process,这部分代码比较简单也很短,主要参考了 openai 的 rllab 的实现,大家可以直接看看源码。
6. 网络更新和损失函数(Model)
用梯度下降更新网络,先需要定义我们的 loss 函数。
critic nework 更新
这里 critic 只是很简单的是一个 L2 loss。不过由于 transition 是 s, a, r, s'。要得到 y 需要一步处理,下面是预处理 transition 的代码。
训练模型是,从 Replay buffer 里取出一个 mini-batch,在经过预处理就可以更新我们的网络了,是不是很简单。y 经过下面这行代码处理得到。
actor nework 更新
actor network 的更新也很简单,我们需要求的梯度如上图,首先我们需要 critic network
对动作 a 的导数,其中 a 是由 actor network 根据状态 s 估计出来的。代码如下:
先根据 actor network 估计出 action,再用 critic network 的输出 q 对估计出来的 action 求导。
然后我们把得到的这部分梯度,和 actor network 的输出对 actor network 的权重求导的梯度,相乘就能得到最后的梯度,代码如下:
也就是说我们需要求的 policy gradient 主要由下面这一行代码求得,由于我们需要梯度下降去更新网络,所以需要加个负号:
之后就是更新我们的 target network,target network 采用 soft update 的方式去稳定网络的变化,算法如下:
就这样我们的整体网络更新需要的东西都实现了,下面是整体网络更新的代码:
总体的细节都介绍完了,希望大家有所收获。
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论