DQN(Deep Q-learning)入门教程(四)之Q-learning Play Flappy Bird

在上一篇博客中,咱们详细的对Q-learning的算法流程进行了介绍。同时咱们使用了\(\epsilon-贪婪法\)防止陷入局部最优。html

那么咱们能够想一下,最后咱们获得的结果是什么样的呢?由于咱们考虑到了全部的(\(\epsilon-贪婪法\)致使的)状况,所以最终咱们将会获得一张以下的Q-Table表。python

Q-Table \(a_1\) \(a_2\)
\(s_1\) \(q(s_1,a_1)\) \(q(s_1,a_2)\)
\(s_2\) \(q(s_2,a_1)\) \(q(s_2,a_2)\)
\(s_3\) \(q(s_3,a_1)\) \(q(s_3,a_2)\)

当agent运行到某一个场景\(s\)时,会去查询已经训练好的Q-Table,而后从中选择一个最大的\(q\)对应的action。git

训练内容

这一次,咱们将对Flappy-bird游戏进行训练。这个游戏的介绍我就很少说了,能够看一下维基百科的介绍。es6

游戏就是控制一只🐦穿越管道,而后能够得到分数,对于小鸟来讲,他只有两个动做,跳or不跳,而咱们的目标就是使小鸟穿越管道得到更多的分数。github

前置准备

由于咱们的目标是来学习“强化学习”的,因此咱们不可能说本身去弄一个Flappy-bird(固然本身弄也能够),这里咱们直接使用一个已经写好的Flappy-bird。算法

PyGame-Learning-Environment,是一个Python的强化学习环境,简称PLE,下面时他Github上面的介绍:app

PyGame Learning Environment (PLE) is a learning environment, mimicking the Arcade Learning Environment interface, allowing a quick start to Reinforcement Learning in Python. The goal of PLE is allow practitioners to focus design of models and experiments instead of environment design.less

PLE hopes to eventually build an expansive library of games.dom

而后关于FlappyBird的文档介绍在这里,文档的介绍仍是蛮清楚的。安装步骤以下所示,推荐在Pipenv的环境下安装,不过你也能够直接clone个人代码而后而后根据reademe的步骤进行使用。函数

git clone https://github.com/ntasfi/PyGame-Learning-Environment.git
cd PyGame-Learning-Environment/
pip install -e .

须要的库以下:

  • pygame
  • numpy
  • pillow

函数说明

官方文档有几个的函数在这里说下,由于等下咱们须要用到。

  • getGameState():得到游戏当前的状态,返回值为一个字典:

    1. player y position.
    2. players velocity.
    3. next pipe distance to player
    4. next pipe top y position
    5. next pipe bottom y position
    6. next next pipe distance to player
    7. next next pipe top y position
    8. next next pipe bottom y position

    部分数据表示以下:

  • reset_game():从新开始游戏

  • act(action):在游戏中执行一个动做,参数为动做,返回执行后的分数。

  • game_over():假如游戏结束,则返回True,否者返回False。

  • getActionSet():得到游戏的动做集合。

咱们的窗体大小默认是288*512,其中鸟的速度在-20到10之间(最小速度我并不知道,可是通过观察,并无小于-20的状况,而最大的速度在源代码里面已经说明好了为10)

Coding Time

在前面咱们说,经过getGameState()函数,咱们能够得到几个关于环境的数据,在这里咱们选择以下的数据:

  • next_pipe_dist_to_player:
  • player_y与next_pipe_top_y的差值
  • 🐦的速度

可是咱们能够想想,next_pipe_dist_to_player一共会有多少种的取值:由于窗体大小为288*512,则取值的范围大约是0~288,也就是说它大约有288个取值,而关于player_y与next_pipe_top_y的差值,则大概有1024个取值。这样很难让模型收敛,所以咱们将数值进行简化。其中简化的思路来自:GitHub

首先咱们建立一个Agent类,而后逐渐向里面添加功能。

class Agent():

    def __init__(self, action_space):
        # 得到游戏支持的动做集合
        self.action_set = action_space

        # 建立q-table
        self.q_table = np.zeros((6, 6, 6, 2))

        # 学习率
        self.alpha = 0.7
        # 励衰减因子
        self.gamma = 0.8
        # 贪婪率
        self.greedy = 0.8

至于为何q-table的大小是(6,6,6,2),其中的3个6分别表明next_pipe_dist_to_playerplayer_y与next_pipe_top_y的差值🐦的速度,其中的2表明动做的个数。也就是说,表格中的state一共有$6 \times6 \times 6 $种,表格的大小为\(6 \times6 \times 6 \times 2\)

缩小状态值的范围

咱们定义一个函数get_state(s),这个函数专门提取游戏中的状态,而后返回进行简化的状态数据:

def get_state(self, state):
        """
        提取游戏state中咱们须要的数据
        :param state: 游戏state
        :return: 返回提取好的数据
        """
        return_state = np.zeros((3,), dtype=int)
        dist_to_pipe_horz = state["next_pipe_dist_to_player"]
        dist_to_pipe_bottom = state["player_y"] - state["next_pipe_top_y"]
        velocity = state['player_vel']
        if velocity < -15:
            velocity_category = 0
        elif velocity < -10:
            velocity_category = 1
        elif velocity < -5:
            velocity_category = 2
        elif velocity < 0:
            velocity_category = 3
        elif velocity < 5:
            velocity_category = 4
        else:
            velocity_category = 5

        if dist_to_pipe_bottom < 8:  # very close or less than 0
            height_category = 0
        elif dist_to_pipe_bottom < 20:  # close
            height_category = 1
        elif dist_to_pipe_bottom < 50:  # not close
            height_category = 2
        elif dist_to_pipe_bottom < 125:  # mid
            height_category = 3
        elif dist_to_pipe_bottom < 250:  # far
            height_category = 4
        else:
            height_category = 5

        # make a distance category
        if dist_to_pipe_horz < 8:  # very close
            dist_category = 0
        elif dist_to_pipe_horz < 20:  # close
            dist_category = 1
        elif dist_to_pipe_horz < 50:  # not close
            dist_category = 2
        elif dist_to_pipe_horz < 125:  # mid
            dist_category = 3
        elif dist_to_pipe_horz < 250:  # far
            dist_category = 4
        else:
            dist_category = 5

        return_state[0] = height_category
        return_state[1] = dist_category
        return_state[2] = velocity_category
        return return_state

更新Q-table

更新的数学公式以下:

\[{\displaystyle Q^{new}(s_{t},a_{t})\leftarrow \underbrace {Q(s_{t},a_{t})} _{\text{旧的值}}+\underbrace {\alpha } _{\text{学习率}}\cdot \overbrace {{\bigg (}\underbrace {\underbrace {r_{t}} _{\text{奖励}}+\underbrace {\gamma } _{\text{奖励衰减因子}}\cdot \underbrace {\max _{a}Q(s_{t+1},a)} _{\text{estimate of optimal future value}}} _{\text{new value (temporal difference target)}}-\underbrace {Q(s_{t},a_{t})} _{\text{旧的值}}{\bigg )}} ^{\text{temporal difference}}} \]

下面是更新Q-table的函数代码:

def update_q_table(self, old_state, current_action, next_state, r):
    """

    :param old_state: 执行动做前的状态
    :param current_action: 执行的动做
    :param next_state: 执行动做后的状态
    :param r: 奖励
    :return:
    """
    next_max_value = np.max(self.q_table[next_state[0], next_state[1], next_state[2]])

    self.q_table[old_state[0], old_state[1], old_state[2], current_action] = (1 - self.alpha) * self.q_table[
        old_state[0], old_state[1], old_state[2], current_action] + self.alpha * (r + next_max_value)

选择最佳的动做

而后咱们就是根据q-table对应的Q值选择最大的那一个,其中第一个表明(也就是0)跳跃,第2个表明不执行任何操做。

选择的示意图以下:

代码以下所示:

def get_best_action(self, state, greedy=False):
    """
    得到最佳的动做
    :param state: 状态
    :是否使用ϵ-贪婪法
    :return: 最佳动做
    """
	
    # 得到q值
    jump = self.q_table[state[0], state[1], state[2], 0]
    no_jump = self.q_table[state[0], state[1], state[2], 1]
    # 是否执行策略
    if greedy:
        if np.random.rand(1) < self.greedy:
            return np.random.choice([0, 1])
        else:
            if jump > no_jump:
                return 0
            else:
                return 1
    else:
        if jump > no_jump:
            return 0
        else:
            return 1

更新\(\epsilon\)

这个比较简单,从前面的博客中,咱们知道\(\epsilon\)是随着训练次数的增长而减小的,有不少种策略能够选择,这里乘以\(0.95\)吧。

def update_greedy(self):
    self.greedy *= 0.95

执行动做

在官方文档中,若是小鸟没有死亡奖励为0,越过一个管道,奖励为1,死亡奖励为-1,咱们稍微的对其进行改变:

def act(self, p, action):
    """
    执行动做
    :param p: 经过p来向游戏发出动做命令
    :param action: 动做
    :return: 奖励
    """
    # action_set表示游戏动做集(119,None),其中119表明跳跃
    r = p.act(self.action_set[action])
    if r == 0:
        r = 1
    if r == 1:
        r = 10
    else:
        r = -1000
    return r

main函数

最后咱们就能够执行main函数了。

if __name__ == "__main__":
    # 训练次数
    episodes = 2000_000000
    # 实例化游戏对象
    game = FlappyBird()
    # 相似游戏的一个接口,能够为咱们提供一些功能
    p = PLE(game, fps=30, display_screen=False)
    # 初始化
    p.init()
    # 实例化Agent,将动做集传进去
    agent = Agent(p.getActionSet())
    max_score = 0
	
    for episode in range(episodes):
        # 重置游戏
        p.reset_game()
        # 得到状态
        state = agent.get_state(game.getGameState())
        agent.update_greedy()
        while True:
            # 得到最佳动做
            action = agent.get_best_action(state)
            # 而后执行动做得到奖励
            reward = agent.act(p, action)
            # 得到执行动做以后的状态
            next_state = agent.get_state(game.getGameState())
            # 更新q-table
            agent.update_q_table(state, action, next_state, reward)
            # 得到当前分数
            current_score = p.score()
            state = next_state
            if p.game_over():
                max_score = max(current_score, max_score)
                print('Episodes: %s, Current score: %s, Max score: %s' % (episode, current_score, max_score))
                # 保存q-table
                if current_score > 300:
                    np.save("{}_{}.npy".format(current_score, episode), agent.q_table)
                break

部分的训练的结果以下:

总结

emm,说实话,我也不知道结果会怎么样,由于训练的时间比较长,我不想放在个人电脑上面跑,而后我就放在树莓派上面跑,可是树莓派性能比较低,致使训练的速度比较慢。可是,我仍是以为个人方法有点问题,get_state()函数中简化的方法,我感受不是特别的合理,若是各位有好的见解,能够在评论区留言哦,而后共同窗习。

项目地址:https://github.com/xiaohuiduan/flappy-bird-q-learning

参考

相关文章
相关标签/搜索