博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
8. Actor-Critic、DDPG、A3C
阅读量:5119 次
发布时间:2019-06-13

本文共 5475 字,大约阅读时间需要 18 分钟。

我们知道,学术中很多时候一般是先有了牛逼算法A,再有了牛逼算法B。但A,B算法一般都有缺点,于是有一天有人将两者整合,结合了两者优点,避免了两者缺点,皆大欢喜,喜大普奔。但对于AC算法来说其架构可以追溯到三、四十年前。 最早由Witten在1977年提出了类似AC算法的方法,然后Barto, Sutton和Anderson等大牛在1983年左右引入了actor-critic架构。但由于AC算法的研究难度和一些历史偶然因素,之后学界开始将研究重点转向value-based方法。之后的一段时间里value-based方法和policy-based方法都有了蓬勃的发展。前者比较典型的有TD系的方法。经典的Sarsa, Q-learning等都属于此列;后者比如经典的REINFORCE算法。之后AC算法结合了两者的发展红利,其理论和实践再次有了长足的发展。直到深度学习时代,AC方法结合了DNN作为函数拟合工具,产生了化学反应,出现了DDPG,A3C这样一批先进算法,以及其它基于它们的一些改进和变体。可以看到,这是一个先分后合的圆满故事。

---------------------
作者:ariesjzj
来源:CSDN
原文:https://blog.csdn.net/jinzhuojun/article/details/72851548
版权声明:本文为博主原创文章,转载请附上博文链接!

Actor-Critic

上一章讲到的蒙特卡罗策略梯度REINFORCE算法更新需要完整的状态序列,而且是每个序列单独对策略$\theta$进行更新,不太容易收敛。因此本章讨论策略(Policy Based)和价值(Value Based)相结合的方法Actor-Critic算法,解耦成生成动作并与环境交互的Actor,和评估Actor的表现并指导Actor下一阶段动作的Critic,用期望值Q代替蒙特卡洛采样得到的G实现单步更新,增大了学习效率。

对于Actor-Critic算法,我们需要

Actor,策略函数的近似,参数$\theta$,

\[{\pi _\theta }\left( {s,a} \right) = P\left( {a|s,\theta } \right)\]

Critic,价值函数的近似,参数$w$,

\[\begin{array}{l}

\hat v(s,w) \approx {v_\pi }(s)\\
\hat q(s,a,w) \approx {q_\pi }(s,a)
\end{array}\]

我们用两套神经网络来代替,流程是Critic通过Q网络计算状态的最优价值$v_t$,而Actor利用$v_t$这个最优价值迭代更新策略函数的参数,进而选择动作,并得到反馈到新的状态,Critic使用反馈和新的状态更新Q网络参数$\theta$, 在后面Critic会使用新的网络参数$w$来帮Actor计算状态的最优价值。

关于Critic的评估指标,可以基于状态价值V函数,动作价值函数Q函数,TD误差,优势函数等。

下面按照基于TD误差更新的Actor-Critic算法流程


1. 初始化Actor网络的策略函数参数$\theta$,Critic网络的价值函数参数$w$,

2. 对每一个episode:

3.  从$s_0$开始,对episode中的每一步:

A.    在Actor网络中输入状态$s$,输出动作$a$,

B.    采取动作$a$,得到新的状态$s'$和即时奖励$r$

C.    在Critic网络中分别输入状态$s$和$s'$,得到值函数输出$v(s)$和$v(s')$

D.    计算TD误差$\delta  = r + \gamma v(s') - v(s)$

E.    更新Critic网络参数$w$,通过均方差损失函数梯度更新

\[w \leftarrow w + \beta \delta \phi (s,a)\]

F.    更新Actor网络参数$\theta$,

\[\theta  \leftarrow \theta  + \alpha {\nabla _\theta }\log {\pi _\theta }\left( {

{s_t},{a_t}} \right)\delta \]

4. 重复以上步骤,从许多个episode中的每一步中不断学习。


DDPG

Actor-Critic 涉及到了两个神经网络,而且每次都是在连续状态中更新参数,每次参数更新前后都存在相关性,导致神经网络只能片面的看待问题,甚至导致神经网络学不到东西。为了解决这个问题,和之前我们讲到的DQN类似,Google DeepMind引入经验回放和双网络的方法来改进Actor-Critic难收敛的问题,提出了Deep Deterministic Policy Gradient[1]。

 DDPG算法


1. 初始化Actor当前网络$Q^a$的参数$\theta$,Actor目标网络$Q'^a$的参数$\theta'$,Critic当前网络$Q^c$的参数$w$,Critic目标网络$Q'^c$的参数$w'$,空的经验回放的集合D

2. 对每一个episode:

3.  从$s_0$开始,对episode中的每一步:

A.    在Actor当前网络$Q^a$中输入状态$s$,得到动作$a = {\pi _\theta }\left( {\phi (s)} \right) + N$,

B.    执行动作$a$,得到新的状态$s'$和即时奖励$r$,是否终止状态$is\_end$

C.    将$\left\{ {\phi (s),a,r,\phi (s'),is\_end} \right\}$五元组存入经验回放集合D

D.    从经验回放集合D中采样m个样本$\left\{ {\phi ({s_j}),{a_j},{r_j},\phi (s{'_j}),is\_en{d_j}} \right\}$,$j = 1,2,...,m$来更新两个网络参数:

a).      根据Actor目标网络$Q'^a$,依据采样样本中下一状态$s'$的最优下一动作$a'$,

\[a' = {\pi _{\theta '}}\left( {\phi \left( {s'} \right)} \right)\]

b).      根据Critic目标网络$Q'^c$,依据$s'$,$a'$计算当前目标Q值

\[{y_t} = \left\{ \begin{array}{l}

{r_t}{\rm{ is\_en}}{
{\rm{d}}_t}{\rm{ is true}}\\
{r_t} + \gamma Q'\left( {\phi \left( {s{'_t}} \right),a',w'} \right){\rm{ is\_en}}{
{\rm{d}}_t}{\rm{ is false}}
\end{array} \right.\]

c).      通过均方差损失函数梯度,更新Critic当前网络$Q^c$参数$w$

\[Loss = \frac{1}{m}\sum\nolimits_i {

{
{\left( {
{y_i} - Q\left( {\phi \left( {
{s_i}} \right),{a_i}|w} \right)} \right)}^2}} \]

d).      通过抽样策略梯度,更新Actor当前网络$Q^a$参数$\theta$,

\[J(\theta ) =  - \frac{1}{m}\sum\limits_{i = 1}^m {Q({s_i},{a_i},w)} \]

E.    如果t达到设定的目标网络参数更新频率,则更新Actor目标网络$Q^a$和Critic目标网络$Q^c$参数

\[\begin{array}{l}

w' \leftarrow \tau w + (1 - \tau )w'\\
\theta ' \leftarrow \tau \theta + (1 - \tau )\theta '
\end{array}\]


 

A3C

Deepmind克服了一些经验回放相关性过强的问题,提出了A3C算法,在多个环境样本中异构平行执行多个线程。该论文另辟蹊径,发明了“平民版”的DRL算法,证明了我们这些架不起集群,买不起GPU的穷人也照样能玩转前沿高端科技,而且效果还不比前者差。

传统经验认为,online的RL算法在和DNN简单结合后会不稳定。主要原因是观察数据往往波动很大且前后sample相互关联。像Neural fitted Q iteration和TRPO方法通过将经验数据batch,或者像DQN中通过experience replay memory对之随机采样,这些方法有效解决了前面所说的两个问题,但是也将算法限定在了off-policy方法中。本文提出了另一种思路,即通过创建多个agent,在多个环境实例中并行且异步的执行和学习。于是,通过这种方式,在DNN下,解锁了一大批online/offline的RL算法(如Sarsa, AC, Q-learning)。它还有个潜在的好处是不那么依赖于GPU或大型分布式系统。A3C可以跑在一个多核CPU上。总得来说,这篇论文更多是工程上的设计和优化。

其中异步训练框架是最大的优化。首先有一个Global Network,它是一个公共的神经网络模型,包含了Actor和Critic两套网络。下面有n个worker线程,每个线程里有和公共的神经网络一样的网络结构,每个线程会独立的和环境进行交互得到经验数据,这些线程之间互不干扰,独立运行。每个线程和环境交互到一定量的数据后,就计算在自己线程里的神经网络损失函数的梯度,但是这些梯度却并不更新自己线程里的神经网络,而是去更新公共的神经网络。也就是n个线程会独立的使用累积的梯度分别更新公共部分的神经网络模型参数。每隔一段时间,线程会将自己的神经网络的参数更新为公共神经网络的参数,进而指导后面的环境交互。可见,公共部分的网络模型就是我们要学习的模型,而线程里的网络模型主要是用于和环境交互使用的,这些线程里的模型可以帮助线程更好的和环境交互,拿到高质量的数据帮助公共部分的网络模型更快收敛。

 

 A3C的每一个线程的AC算法流程


1. 假设公共网络的参数为$\theta$,

2. 从$s_0$开始,重置本线程对应的参数$\theta'=\theta$,梯度更新量,$d\theta  \leftarrow 0$,

A.   根据当前Q网络$Q(s,a;\theta)$,利用对应的$\varepsilon$−Greedy策略选择动作a

B.   执行动作a,得到下一状态s'和奖励r,是否是终止状态

C.   根据目标网络$Q(\theta)$,依据$s'$,$a'$计算当前目标Q值

\[y = \left\{ {\begin{array}{*{20}{c}}

r&{
{\rm{for terminal }}s'}\\
{r + \gamma {
{\max }_{a'}}Q\left( {s',a';{\theta ^ - }} \right)}&{
{\rm{for non - terminal }}s'}
\end{array}} \right.\]

D.  累积参数$\theta$的梯度$d\theta$

\[d\theta  \leftarrow d\theta  + \frac{

{\partial {
{\left( {y - Q\left( {s,a;\theta } \right)} \right)}^2}}}{
{\partial \theta }}\]

E.   $T \leftarrow T+1$,$t \leftarrow t+1$

F.   if 全局共享的迭代轮数T mod 目标网络更新次数 ==0,then

    更新目标网络参数

G.   if 线程内单次迭代次数t mod 异步更新次数 ==0 or s是终止状态 then

    利用累积的$d\theta$更新参数$\theta$

    重置累积的$d\theta$

3. until 全局共享的迭代轮数T>全局最大迭代次数


 

 

[1] Lillicrap T P, Hunt J J, Pritzel A, et al. Continuous control with deep reinforcement learning[J]. international conference on learning representations, 2016.
 
[2] Mnih V, Badia A P, Mirza M, et al. Asynchronous methods for deep reinforcement learning[J]. international conference on machine learning, 2016: 1928-1937.

转载于:https://www.cnblogs.com/yijuncheng/p/10509691.html

你可能感兴趣的文章
[BZOJ5248] 2018九省联考 D1T1 一双木棋 | 博弈论 状压DP
查看>>
super 小记
查看>>
C语言实现<读取>和<写入> *.ini文件(转)
查看>>
【架构】Linux的架构(architecture)
查看>>
从解决Cocos2dx-2.x arm64 Crash 来看C的奇技淫巧
查看>>
ASM 图解
查看>>
Java快捷键
查看>>
Wasserstein距离 和 Lipschitz连续
查看>>
Python Day14 JavaScript
查看>>
关于java Date和时区的问题 (转)
查看>>
通过表单展示不一样的页面(input对象)
查看>>
Windows 7 SP1 加速了系统 还是 SQL 2008拖慢了系统 ? SQL Server 2012 初体验
查看>>
centos 开机自启设定:
查看>>
组件基础(插槽slot)—Vue学习笔记
查看>>
Gensim进阶教程:训练word2vec与doc2vec模型
查看>>
插头DP小结
查看>>
Springboot通过cors解决跨域问题(解决spring security oath2的/oauth/token跨域问题)
查看>>
博客开通了
查看>>
D. Jzzhu and Cities
查看>>
UVA1279,Asteroid Rangers,星际游击队,好烦的最小生成树
查看>>