A3C——一种异步强化学习方法

目录
python

一、简介二、算法细节三、代码3.1 主结构3.2 Actor Critic 网络3.3 Worker3.4 Worker并行工做四、参考git

一、简介

A3C是Google DeepMind 提出的一种解决Actor-Critic不收敛问题的算法。咱们知道DQN中很重要的一点是他具备经验池,能够下降数据之间的相关性,而A3C则提出下降数据之间的相关性的另外一种方法:异步github

简单来讲:A3C会建立多个并行的环境, 让多个拥有副结构的 agent 同时在这些并行环境上更新主结构中的参数. 并行中的 agent 们互不干扰, 而主结构的参数更新受到副结构提交更新的不连续性干扰, 因此更新的相关性被下降, 收敛性提升.
web

二、算法细节

A3C的算法实际上就是将Actor-Critic放在了多个线程中进行同步训练. 能够想象成几我的同时在玩同样的游戏, 而他们玩游戏的经验都会同步上传到一个中央大脑. 而后他们又从中央大脑中获取最新的玩游戏方法。算法

这样, 对于这几我的, 他们的好处是: 中央大脑聚集了全部人的经验, 是最会玩游戏的一个, 他们能时不时获取到中央大脑的必杀招, 用在本身的场景中.缓存

对于中央大脑的好处是: 中央大脑最怕一我的的连续性更新, 不仅基于一我的推送更新这种方式能打消这种连续性. 使中央大脑没必要像DQN,DDPG那样的记忆库也能很好的更新。
微信


为了达到这个目的,咱们要有两套体系, 能够看做中央大脑拥有 global net和他的参数, 每位玩家有一个 global net的副本 local net, 能够定时向 global net推送更新, 而后定时从 global net那获取综合版的更新.

若是在 tensorboard 中查看咱们今天要创建的体系, 这就是你会看到的。

W_0就是第0个 worker, 每一个 worker均可以分享 global_net


若是咱们调用 sync中的 pull, 这个 worker就会从 global_net中获取到最新的参数.


若是咱们调用sync中的push, 这个worker就会将本身的我的更新推送去global_net.
网络

三、代码

此次咱们也是使用连续动做环境Pendulum作例子。
app

3.1 主结构


咱们使用了 Normal distribution 来选择动做, 因此在搭建神经网络的时候, actor这边要输出动做的均值和方差. 而后放入 Normal distribution 去选择动做. 计算 actor loss的时候咱们还须要使用到 critic提供的 TD error做为 gradient ascent 的导向.

critic只须要获得他对于 state的价值就行了. 用于计算 TD error.

3.2 Actor Critic 网络

这里由于代码有点多,有些部分会使用伪代码,完整代码最后会附上连接。异步

咱们将ActorCritic合并成一整套系统, 这样方便运行.

 1# 这个 class 能够被调用生成一个 global net.
2# 也能被调用生成一个 worker 的 net, 由于他们的结构是同样的,
3# 因此这个 class 能够被重复利用.
4class ACNet(object):
5    def __init__(self, globalAC=None):
6        # 当建立 worker 网络的时候, 咱们传入以前建立的 globalAC 给这个 worker
7        if 这是 global:   # 判断当下创建的网络是 local 仍是 global
8            with tf.variable_scope('Global_Net'):
9                self._build_net()
10        else:
11            with tf.variable_scope('worker'):
12                self._build_net()
13
14            # 接着计算 critic loss 和 actor loss
15            # 用这两个 loss 计算要推送的 gradients
16
17            with tf.name_scope('sync'):  # 同步
18                with tf.name_scope('pull'):
19                    # 更新去 global
20                with tf.name_scope('push'):
21                    # 获取 global 参数
22
23    def _build_net(self):
24        # 在这里搭建 Actor 和 Critic 的网络
25        return 均值, 方差, state_value
26
27    def update_global(self, feed_dict):
28        # 进行 push 操做
29
30    def pull_global(self):
31        # 进行 pull 操做
32
33    def choose_action(self, s):
34        # 根据 s 选动做

这些只是在建立网络而已,worker还有属于本身的class, 用来执行在每一个线程里的工做.

3.3 Worker

每一个worker有本身的class, class 里面有他的工做内容work

 1class Worker(object):
2    def __init__(self, name, globalAC):
3        self.env = gym.make(GAME).unwrapped # 建立本身的环境
4        self.name = name    # 本身的名字
5        self.AC = ACNet(name, globalAC) # 本身的 local net, 并绑定上 globalAC
6
7    def work(self):
8        # s, a, r 的缓存, 用于 n_steps 更新
9        buffer_s, buffer_a, buffer_r = [], [], []
10        while not COORD.should_stop() and GLOBAL_EP < MAX_GLOBAL_EP:
11            s = self.env.reset()
12
13            for ep_t in range(MAX_EP_STEP):
14                a = self.AC.choose_action(s)
15                s_, r, done, info = self.env.step(a)
16
17                buffer_s.append(s)  # 添加各类缓存
18                buffer_a.append(a)
19                buffer_r.append(r)
20
21                # 每 UPDATE_GLOBAL_ITER 步 或者回合完了, 进行 sync 操做
22                if total_step % UPDATE_GLOBAL_ITER == 0 or done:
23                    # 得到用于计算 TD error 的 下一 state 的 value
24                    if done:
25                        v_s_ = 0   # terminal
26                    else:
27                        v_s_ = SESS.run(self.AC.v, {self.AC.s: s_[np.newaxis, :]})[00]
28
29                    buffer_v_target = []    # 下 state value 的缓存, 用于算 TD
30                    for r in buffer_r[::-1]:    # 进行 n_steps forward view
31                        v_s_ = r + GAMMA * v_s_
32                        buffer_v_target.append(v_s_)
33                    buffer_v_target.reverse()
34
35                    buffer_s, buffer_a, buffer_v_target = np.vstack(buffer_s), np.vstack(buffer_a), np.vstack(buffer_v_target)
36
37                    feed_dict = {
38                        self.AC.s: buffer_s,
39                        self.AC.a_his: buffer_a,
40                        self.AC.v_target: buffer_v_target,
41                    }
42
43                    self.AC.update_global(feed_dict)    # 推送更新去 globalAC
44                    buffer_s, buffer_a, buffer_r = [], [], []   # 清空缓存
45                    self.AC.pull_global()   # 获取 globalAC 的最新参数
46
47                s = s_
48                if done:
49                    GLOBAL_EP += 1  # 加一回合
50                    break   # 结束这回合

3.4 Worker并行工做

这里是重点,也就是Worker并行工做的计算

 1    GLOBAL_AC = ACNet(GLOBAL_NET_SCOPE)  # 创建 Global AC
2    workers = []
3    for i in range(N_WORKERS):  # 建立 worker, 以后在并行
4        workers.append(Worker(GLOBAL_AC))   # 每一个 worker 都有共享这个 global AC
5
6COORD = tf.train.Coordinator()  # Tensorflow 用于并行的工具
7
8worker_threads = []
9for worker in workers:
10    job = lambda: worker.work()
11    t = threading.Thread(target=job)    # 添加一个工做线程
12    t.start()
13    worker_threads.append(t)
14COORD.join(worker_threads)  # tf 的线程调度

电脑里CPU有几个核就能够创建多少个worker, 也就能够把它们放在CPU核数个线程中并行探索更新. 最后的学习结果能够用这个获取 moving average 的 reward 的图来归纳.

完整代码连接:

https://github.com/cristianoc20/RL_learning/tree/master/A3C

四、参考

  1. https://medium.com/emergent-future/simple-reinforcement-learning-with-tensorflow-part-8-asynchronous-actor-critic-agents-a3c-c88f72a5e9f2

  2. https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/6-3-A3C/


本文分享自微信公众号 - 计算机视觉漫谈()。
若有侵权,请联系 support@oschina.cn 删除。
本文参与“OSC源创计划”,欢迎正在阅读的你也加入,一块儿分享。

相关文章
相关标签/搜索