Introduce the basic concepts of reinforcement learning (RL), and use DQN to train a model that can play Flappy Bird.
Many people have played this game, and it is quite punishing. Here is a Flappy Bird clone implemented with pygame: https://github.com/sourabhv/FlapPyBird
If pygame is not installed, install it first:

```
pip install pygame
```
Run flappy.py to start the game. If the keyboard does not respond, run the script with pythonw instead:

```
pythonw flappy.py
```
Unsupervised learning has no labels, for example clustering; supervised learning has labels, for example classification; reinforcement learning sits somewhere in between: its "labels" are accumulated through repeated trial and error.
RL involves several components, most importantly the state S, the action A, and the reward R.
Seen this way, playing the game is nothing more than starting from an initial S, executing an A, receiving an R, and moving to the next S, over and over, until a terminal S is reached:
$$ s_0,a_0,r_1,s_1,a_1,r_2,s_2,...,s_{n-1},a_{n-1},r_n,s_n $$
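In code, this interaction is just a loop. Here is a minimal sketch of it; `env.reset`, `env.step`, and `policy` are hypothetical placeholders standing in for the game wrapper and the action-selection rule introduced below:

```python
# Generic agent-environment loop (sketch): s_0 -> a_0 -> r_1, s_1 -> a_1 -> r_2, ...
terminal = False
state = env.reset()                             # hypothetical: returns the initial state s_0
while not terminal:
    action = policy(state)                      # hypothetical: picks a_t given s_t
    state, reward, terminal = env.step(action)  # hypothetical: returns s_{t+1}, r_{t+1}, done flag
```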
Define the total reward accumulated over the course of the game:
$$ R=r_1+r_2+r_3+...+r_n $$
as well as the total reward from some time step t onward:
$$ R_t=r_t+r_{t+1}+r_{t+2}+...+r_n $$
However, we cannot be completely certain about the reward each future step will bring, so it is natural to multiply in a discount factor between 0 and 1:
$$ R_t=r_t+\gamma r_{t+1}+\gamma^2 r_{t+2}+...+\gamma^{n-t} r_n $$
This yields a recurrence between the total rewards of two adjacent time steps:
$$ R_t=r_t+\gamma R_{t+1} $$
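As a quick sanity check, the discounted return can be computed backwards with exactly this recurrence. A minimal sketch with made-up rewards:

```python
gamma = 0.99
rewards = [0.1, 0.1, -1.0]      # made-up per-step rewards r_1, r_2, r_3

# walk backwards and apply R_t = r_t + gamma * R_{t+1}
R = 0.0
returns = []
for r in reversed(rewards):
    R = r + gamma * R
    returns.append(R)
returns.reverse()               # returns[i] is the discounted return starting at rewards[i]
print(returns)
```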
DQN is a commonly used reinforcement learning algorithm. Its key ingredient is the Q function (Quality, a value function), which gives the maximum total reward obtainable by executing a given A in a given S:
$$ Q(s_t,a_t)=\max R_{t+1} $$
With the Q function in hand, for the current state S we only need to compute the Q value of each A and then pick the A with the largest Q value; this is the optimal policy (the policy function):
$$ \pi(s)=\arg\max_a Q(s,a) $$
Once the Q function has converged, it also satisfies a recurrence (the Bellman equation):
$$ Q(s_t,a_t)=r_{t+1}+\gamma \max_{a_{t+1}} Q(s_{t+1},a_{t+1}) $$
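This recurrence is what turns Q learning into a regression problem: the right-hand side is used as the training target for the left-hand side. A minimal sketch of the target computation for one stored transition (the same logic appears in the training loop below); the values of r, terminal, and Q_next are made-up examples:

```python
import numpy as np

GAMMA = 0.99

# one stored transition (s, a, r, s')
r = 0.1                          # made-up immediate reward
terminal = False                 # whether s' ended the game
Q_next = np.array([0.5, 1.2])    # made-up network output Q(s', a') for both actions

# y = r                              if s' is terminal
# y = r + GAMMA * max_a' Q(s', a')   otherwise
y = r if terminal else r + GAMMA * np.max(Q_next)
print(y)                         # 0.1 + 0.99 * 1.2 = 1.288
```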
The Q function can be implemented as a neural network and trained.
For more background on reinforcement learning and DQN, see this article: https://ai.intel.com/demystifying-deep-reinforcement-learning/
The code below is adapted from the following project: https://github.com/yenchenlin/DeepLearningFlappyBird
The code in game/ is a simplified and modified version of the earlier flappy.py: the background image is removed, the bird and pipe colors are fixed, the game starts automatically, and it restarts automatically after the bird crashes. This makes it easier for the model to play by itself and collect data.
Load the libraries:
```python
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import random
import cv2
import sys
sys.path.append('game/')
import wrapped_flappy_bird as fb
from collections import deque
```
Define some hyperparameters:
```python
ACTIONS = 2                 # number of actions: do nothing / flap
GAMMA = 0.99                # reward discount factor
OBSERVE = 10000             # steps of pure observation before training starts
EXPLORE = 3000000           # steps over which epsilon is annealed
INITIAL_EPSILON = 0.1       # initial exploration rate
FINAL_EPSILON = 0.0001      # final exploration rate
REPLAY_MEMORY = 50000       # replay memory capacity
BATCH = 32                  # minibatch size
IMAGE_SIZE = 80             # frames are resized to 80x80
```
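To make the OBSERVE / EXPLORE / epsilon parameters concrete: during the first OBSERVE steps the agent only collects transitions with epsilon = INITIAL_EPSILON; afterwards epsilon is annealed linearly down to FINAL_EPSILON over EXPLORE steps. A minimal sketch of that schedule (the helper name epsilon_at is mine, not part of the original code):

```python
def epsilon_at(t):
    # hypothetical helper: same linear annealing as in the training loop below
    if t <= OBSERVE:
        return INITIAL_EPSILON
    decayed = INITIAL_EPSILON - (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE * (t - OBSERVE)
    return max(decayed, FINAL_EPSILON)

# epsilon_at(10000) -> 0.1, epsilon_at(OBSERVE + EXPLORE) -> 0.0001
```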
Define the network inputs and a few helper functions. Each state S consists of four consecutive game frames:
```python
# each state S is a stack of 4 consecutive 80x80 grayscale frames
S = tf.placeholder(dtype=tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 4], name='S')
A = tf.placeholder(dtype=tf.float32, shape=[None, ACTIONS], name='A')   # one-hot action taken
Y = tf.placeholder(dtype=tf.float32, shape=[None], name='Y')            # TD target

k_initializer = tf.truncated_normal_initializer(0, 0.01)
b_initializer = tf.constant_initializer(0.01)

def conv2d(inputs, kernel_size, filters, strides):
    return tf.layers.conv2d(inputs, kernel_size=kernel_size, filters=filters, strides=strides,
                            padding='same', kernel_initializer=k_initializer,
                            bias_initializer=b_initializer)

def max_pool(inputs):
    return tf.layers.max_pooling2d(inputs, pool_size=2, strides=2, padding='same')

def relu(inputs):
    return tf.nn.relu(inputs)
```
Define the network structure: a typical stack of convolution, pooling, and fully connected layers:
```python
h0 = max_pool(relu(conv2d(S, 8, 32, 4)))
h0 = relu(conv2d(h0, 4, 64, 2))
h0 = relu(conv2d(h0, 3, 64, 1))
h0 = tf.contrib.layers.flatten(h0)
h0 = tf.layers.dense(h0, units=512, activation=tf.nn.relu, bias_initializer=b_initializer)
Q = tf.layers.dense(h0, units=ACTIONS, bias_initializer=b_initializer, name='Q')  # one Q value per action
Q_ = tf.reduce_sum(tf.multiply(Q, A), axis=1)   # Q value of the action actually taken (A is one-hot)
loss = tf.losses.mean_squared_error(Y, Q_)
optimizer = tf.train.AdamOptimizer(1e-6).minimize(loss)
```
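For reference, with an 80×80×4 input and 'same' padding the tensor shapes evolve as follows:

```python
# S:                        (None, 80, 80, 4)
# conv 8x8, 32, stride 4:   (None, 20, 20, 32)
# max pool 2x2, stride 2:   (None, 10, 10, 32)
# conv 4x4, 64, stride 2:   (None, 5, 5, 64)
# conv 3x3, 64, stride 1:   (None, 5, 5, 64)
# flatten:                  (None, 1600)
# dense 512:                (None, 512)
# dense ACTIONS:            (None, 2)    -> one Q value per action
```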
Use a queue as the replay memory, start the game, and do nothing for the initial state:
```python
game_state = fb.GameState()
D = deque()   # replay memory

# for the first step, do nothing
do_nothing = np.zeros(ACTIONS)
do_nothing[0] = 1
img, reward, terminal = game_state.frame_step(do_nothing)

# preprocess: resize to 80x80, convert to grayscale, binarize
img = cv2.cvtColor(cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_BGR2GRAY)
_, img = cv2.threshold(img, 1, 255, cv2.THRESH_BINARY)

# the initial state stacks the same frame four times
S0 = np.stack((img, img, img, img), axis=2)
```
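The resize / grayscale / binarize steps above are repeated in the training and inference loops below. Purely as a readability sketch, they could be factored into a helper (the name preprocess is mine, not part of the original code):

```python
def preprocess(frame):
    # resize to 80x80, convert to grayscale, then binarize to black and white
    frame = cv2.cvtColor(cv2.resize(frame, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_BGR2GRAY)
    _, frame = cv2.threshold(frame, 1, 255, cv2.THRESH_BINARY)
    return frame
```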
Keep the game running and train the model:
```python
sess = tf.Session()
sess.run(tf.global_variables_initializer())

t = 0
success = 0
saver = tf.train.Saver()
epsilon = INITIAL_EPSILON
while True:
    # anneal epsilon linearly from INITIAL_EPSILON to FINAL_EPSILON over EXPLORE steps
    if epsilon > FINAL_EPSILON and t > OBSERVE:
        epsilon = INITIAL_EPSILON - (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE * (t - OBSERVE)

    # epsilon-greedy action selection
    Qv = sess.run(Q, feed_dict={S: [S0]})[0]
    Av = np.zeros(ACTIONS)
    if np.random.random() <= epsilon:
        action_index = np.random.randint(ACTIONS)
    else:
        action_index = np.argmax(Qv)
    Av[action_index] = 1

    # take the action, observe the reward and the next frame
    img, reward, terminal = game_state.frame_step(Av)
    if reward == 1:
        success += 1

    img = cv2.cvtColor(cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_BGR2GRAY)
    _, img = cv2.threshold(img, 1, 255, cv2.THRESH_BINARY)
    img = np.reshape(img, (IMAGE_SIZE, IMAGE_SIZE, 1))
    S1 = np.append(S0[:, :, 1:], img, axis=2)   # drop the oldest frame, append the newest

    # store the transition in the replay memory
    D.append((S0, Av, reward, S1, terminal))
    if len(D) > REPLAY_MEMORY:
        D.popleft()

    if t > OBSERVE:
        # sample a minibatch and train on the TD targets
        minibatch = random.sample(D, BATCH)
        S_batch = [d[0] for d in minibatch]
        A_batch = [d[1] for d in minibatch]
        R_batch = [d[2] for d in minibatch]
        S_batch_next = [d[3] for d in minibatch]
        T_batch = [d[4] for d in minibatch]

        Y_batch = []
        Q_batch_next = sess.run(Q, feed_dict={S: S_batch_next})
        for i in range(BATCH):
            if T_batch[i]:
                Y_batch.append(R_batch[i])
            else:
                Y_batch.append(R_batch[i] + GAMMA * np.max(Q_batch_next[i]))

        sess.run(optimizer, feed_dict={S: S_batch, A: A_batch, Y: Y_batch})

    S0 = S1
    t += 1

    if t > OBSERVE and t % 10000 == 0:
        saver.save(sess, './flappy_bird_dqn', global_step=t)

    if t <= OBSERVE:
        state = 'observe'
    elif t <= OBSERVE + EXPLORE:
        state = 'explore'
    else:
        state = 'train'
    print('Current Step %d Success %d State %s Epsilon %.6f Action %d Reward %f Q_MAX %f' %
          (t, success, state, epsilon, action_index, reward, np.max(Qv)))
```
Run dqn_flappy.py to train the model from scratch. At first the bird flaps around wildly and cannot get past a single pipe, but as training goes on it learns an increasingly stable policy.
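If training is interrupted, it does not have to start from zero. A minimal sketch of resuming from the most recent checkpoint, assuming the checkpoints were saved to the current directory as above; it would go right after the saver is created and before the training loop:

```python
# resume from the latest checkpoint, if any (sketch)
ckpt = tf.train.latest_checkpoint('./')
if ckpt is not None:
    saver.restore(sess, ckpt)
    t = int(ckpt.split('-')[-1])   # recover the step counter from a name like flappy_bird_dqn-10000
```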
You can also run a trained model directly with the following code:
```python
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import cv2
import sys
sys.path.append('game/')
import wrapped_flappy_bird as fb

ACTIONS = 2
IMAGE_SIZE = 80

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# restore the trained graph and weights
saver = tf.train.import_meta_graph('./flappy_bird_dqn-8500000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
graph = tf.get_default_graph()

S = graph.get_tensor_by_name('S:0')
Q = graph.get_tensor_by_name('Q/BiasAdd:0')

game_state = fb.GameState()
do_nothing = np.zeros(ACTIONS)
do_nothing[0] = 1
img, reward, terminal = game_state.frame_step(do_nothing)
img = cv2.cvtColor(cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_BGR2GRAY)
_, img = cv2.threshold(img, 1, 255, cv2.THRESH_BINARY)
S0 = np.stack((img, img, img, img), axis=2)

# always act greedily with respect to the learned Q values
while True:
    Qv = sess.run(Q, feed_dict={S: [S0]})[0]
    Av = np.zeros(ACTIONS)
    Av[np.argmax(Qv)] = 1
    img, reward, terminal = game_state.frame_step(Av)
    img = cv2.cvtColor(cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_BGR2GRAY)
    _, img = cv2.threshold(img, 1, 255, cv2.THRESH_BINARY)
    img = np.reshape(img, (IMAGE_SIZE, IMAGE_SIZE, 1))
    S0 = np.append(S0[:, :, 1:], img, axis=2)
```