代码地址如下:
http://www.demodashi.com/demo/14072.html
先上效果图:
配置的两款简单小游戏以及训练效果:
*原图像太大,被迫修改大小
→在上面的主界面中点击倒三角形状的键,屏幕上会弹出一个黑色的设置窗。在该窗口界面上,用户可以通过拖动滑块条、在框内输入具体数值两种方法设置模型参数。滑块条和编辑框互联。
→点击最小化按钮,将会复制浏览器地址到剪切板上,可以将其粘贴到浏览器中实时监测训练状况。窗口中的折线图每隔五秒从temp.db数据库中获取更新的数据并加入到折线图中,实现实时数据可视化。
→当点击关闭按钮时,若训练次数超过1000帧,将会弹出窗口询问是否保存记录。否则会因为训练次数过少、对训练没有意义而直接退出,不保存结果,以提升效率。
→点击确认
→成功保存数据库
→选择训练游戏
→开始训练(点击播放按钮)
→鼠标放在进度条上能看到具体数值
→点击切换按钮
→此时再点击播放按钮,会弹出窗口用于选择加载模型
→点击开始按钮开始训练,同时设置窗口按钮、模式转换按钮都会失效,以确保训练顺利进行。
|————MyLibrary.py 用于设置游戏中人物等类
|————run_window.py 启动主程序,包括启动界面
|————mainwindow.py 主界面程序
|————setting.py 参数调节窗口程序
|————message_box.py 消息框窗口程序
|————DQL.py 人工智能主程序,负责选择和启动游戏、启动深度强化学习内核
|————DQLBrain.py 深度强化学习内核
|————game_setting.py 存储已有游戏决策状态数、库名等信息,新游戏加入必须将相关信息也加入在其中
|————flask_tk.py 服务器文件
|————jumpMan.py 跳跳人游戏文件
|————greedySnake.py 贪吃蛇游戏文件
|————resource 窗口图片资源文件夹
|————save_networks 已得出的模型文件
|————templates
|————index.html 网页前端模板文件
|————static
|————exporting.js
|————highcharts-zh_CN.js
|————highstock.js
|————jquery.js
|————temp.db 临时数据库,用于服务器和AI端数据交互使用
|————greedy_snake.data-00000-of-00001
|————greedy_snake.index
|————greedy_snake.meta 以上三个为一个训练好的模型
|————greedy_snake.db.bak
|————greedy_snake.db.dat
|————greedy_snake.db.dir 以上三个为一个模型文件
|————setting_resource.py 设定窗口的资源文件
|————resource_message_box.py 消息框窗口的资源文件
|————resource.py 主窗口的资源文件
|————document.py 根据数据库文件自动化生成报告
整个demo主要分为四大部分:主窗口、算法和游戏内核、服务器以及管理版本数据库文件部分。
"""Launcher: show the splash screen for three seconds, then open the main window."""
import sys
from mainWindow import MAINWINDOW
from PyQt5.QtWidgets import QApplication, QSplashScreen
from PyQt5 import QtCore, QtGui, QtWidgets


def main():
    """Run the Qt application and exit the process with its return code."""
    application = QApplication(sys.argv)

    # Build and display the splash screen.
    splash_screen = QtWidgets.QSplashScreen(QtGui.QPixmap("启动界面.png"))
    splash_screen.show()

    # Keep the splash visible for 3 s while still servicing Qt events so the
    # image actually gets painted.
    stopwatch = QtCore.QElapsedTimer()
    stopwatch.start()
    while timer_running(stopwatch):
        application.processEvents()

    # Create and show the main window; the splash disappears once the main
    # window has finished loading.
    main_window = MAINWINDOW()
    main_window.show()
    splash_screen.finish(main_window)

    sys.exit(application.exec_())


def timer_running(stopwatch):
    """Return True while fewer than 3000 ms have elapsed on *stopwatch*."""
    return stopwatch.elapsed() < 3000


if __name__ == '__main__':
    main()
"""Main window of the deep reinforcement learning toolbox."""
import gameSetting
import resource
from PyQt5 import QtWidgets, QtCore, QtGui
from collections import deque
from threading import Thread
from multiprocessing import Process
import logging
import os
import shelve
import sqlite3
import socket
import pyperclip
from DQL import AI
import setting
import messageBox
import webServers
import glob
import shutil

logger = logging.getLogger(__name__)

game_start = False


class myThread(Thread):
    """Worker thread that constructs the AI and runs its training loop."""

    def __init__(self, game, model, replay_memory, timestep, setting):
        Thread.__init__(self)
        self.game = game
        self.model = model
        self.setting = setting
        self.replay_memory = replay_memory
        self.timestep = timestep

    def run(self):
        # The AI object is created here (not in __init__) so that it lives in
        # the worker thread; until this line finishes, ``self.AI`` does not exist.
        self.AI = AI(
            self.game,
            self.model,
            self.replay_memory,
            self.timestep,
            int(self.setting["Explore"]),
            float(self.setting["Initial"]),
            float(self.setting["Final"]),
            float(self.setting["Gamma"]),
            int(self.setting["Replay"]),
            int(self.setting["Batch"]),
        )
        self.AI.playGame()

    def stop(self):
        self.AI.closeGame()


class MAINWINDOW(QtWidgets.QWidget):
    """Frameless main window: start/stop training, pick a game, save projects."""

    def __init__(self, parent=None):
        # Parent class initialisation (the parent argument used to be dropped).
        super().__init__(parent)
        # State read by event handlers before any game has been started.
        self.m_drag = False
        self.state = {}
        self.program_name = ""
        # Main form.
        self.setObjectName("Form")
        self.setEnabled(True)
        self.resize(681, 397)
        self.setStyleSheet("background-color: rgb(255, 255, 255);")
        self.setWindowFlags(QtCore.Qt.FramelessWindowHint)
        # Progress bar.
        self.progressBar = QtWidgets.QProgressBar(self)
        self.progressBar.setEnabled(True)
        self.progressBar.setGeometry(QtCore.QRect(140, 348, 291, 23))
        self.progressBar.setProperty("value", 0)
        self.progressBar.setTextVisible(False)
        self.progressBar.setObjectName("progressxzBar")
        # Start/stop button.
        self.control = QtWidgets.QPushButton(self)
        self.control.setGeometry(QtCore.QRect(10, 325, 71, 71))
        self.control.setStyleSheet("border-image: url(:/bottom/resource/开始按钮.png);")
        self.control.setText("")
        self.control.setObjectName("control")
        self.control_state = False
        # Game selection combo box.
        self.game_selection = QtWidgets.QComboBox(self)
        self.game_selection.setEnabled(True)
        self.game_selection.setGeometry(QtCore.QRect(530, 343, 141, 31))
        self.game_selection.setAutoFillBackground(False)
        self.game_selection.setStyleSheet(
            "QComboBox{border-image: url(:/list/resource/下拉框.png)} \n"
            "QComboBox::drop-down {image: url(:/bottom/resource/下拉框按钮.png) }")
        self.game_selection.setEditable(False)
        self.game_selection.setInsertPolicy(QtWidgets.QComboBox.NoInsert)
        self.game_selection.setIconSize(QtCore.QSize(0, 0))
        self.game_selection.setFrame(False)
        self.game_selection.setObjectName("game_selection")
        # Mode (new project / load project) toggle button.
        self.mode = QtWidgets.QPushButton(self)
        self.mode.setGeometry(QtCore.QRect(440, 340, 71, 41))
        self.mode.setStyleSheet("border-image: url(:/bottom/resource/空白模式.png);\n")
        self.mode.setText("")
        self.mode.setObjectName("mode")
        self.mode_state = False
        # Background image.
        self.label = QtWidgets.QLabel(self)
        self.label.setGeometry(QtCore.QRect(0, 0, 681, 331))
        self.label.setStyleSheet("border-image: url(:/image/resource/Background.png);")
        self.label.setText("")
        self.label.setObjectName("label")
        # Settings button.
        self.setting = QtWidgets.QPushButton(self)
        self.setting.setGeometry(QtCore.QRect(570, 10, 31, 21))
        self.setting.setStyleSheet("border-image: url(:/bottom/resource/菜单.png);")
        self.setting.setText("")
        self.setting.setObjectName("setting")
        # "Copy server address" button (drawn with the minimise icon).
        self.pushButton_3 = QtWidgets.QPushButton(self)
        self.pushButton_3.setGeometry(QtCore.QRect(610, 10, 31, 23))
        self.pushButton_3.setStyleSheet("border-image: url(:/bottom/resource/最小化.png);")
        self.pushButton_3.setText("")
        self.pushButton_3.setObjectName("pushButton_3")
        # Close button.
        self.bottom_close = QtWidgets.QPushButton(self)
        self.bottom_close.setGeometry(QtCore.QRect(650, 10, 21, 23))
        self.bottom_close.setStyleSheet("border-image: url(:/bottom/resource/关闭.png);")
        self.bottom_close.setText("")
        self.bottom_close.setObjectName("bottom_close")
        # Populate texts / child windows, then wire up the buttons.
        self.init_window(self)
        self.connectBottom()
        QtCore.QMetaObject.connectSlotsByName(self)

    # Initialise the window contents
    def init_window(self, Form):
        """Set the title, create child windows, fill the game list, start the web server."""
        _translate = QtCore.QCoreApplication.translate
        Form.setWindowTitle(_translate("Form", "深度强化学习工具箱"))
        # Child windows.
        self.setting_form = setting.SETTING()
        self.message_box = messageBox.MESSAGE_BOX()
        # Load the list of games.
        game_setting_dict = gameSetting.getSetting()
        for i, game in enumerate(game_setting_dict.keys()):
            self.game_selection.addItem("")
            self.game_selection.setItemText(i, _translate("Form", game))
        self.game_selection.setCurrentText(_translate("Form", list(game_setting_dict.keys())[0]))
        self.game_selection.setCurrentIndex(0)
        # Start the web server in a daemon process so it dies with the GUI.
        flask_process = Process(target=webServers.start)
        flask_process.daemon = True
        flask_process.start()

    # Connect every button to its handler in one place
    def connectBottom(self):
        self.control.clicked.connect(self.loadGame)
        self.bottom_close.clicked.connect(self.closeWindow)
        self.mode.clicked.connect(self.setMode)
        self.setting.clicked.connect(self.openSetting)
        self.pushButton_3.clicked.connect(self.getIp)

    # Make the frameless window draggable
    def mousePressEvent(self, event):
        if event.button() == QtCore.Qt.LeftButton:
            self.m_drag = True
            self.m_DragPosition = event.globalPos() - self.pos()
            event.accept()
            self.setCursor(QtGui.QCursor(QtCore.Qt.OpenHandCursor))

    def mouseMoveEvent(self, QMouseEvent):
        # The original tested the constant Qt.LeftButton (always truthy);
        # test the buttons actually held down instead.
        if QMouseEvent.buttons() & QtCore.Qt.LeftButton and self.m_drag:
            self.move(QMouseEvent.globalPos() - self.m_DragPosition)
            QMouseEvent.accept()

    def mouseReleaseEvent(self, QMouseEvent):
        self.m_drag = False
        self.setCursor(QtGui.QCursor(QtCore.Qt.ArrowCursor))

    # Start button handler
    def loadGame(self):
        """Start training (first click) or close the window (second click)."""
        global game_start
        self.mode.setEnabled(False)
        self.setting.setEnabled(False)
        # Flag that a game has been started.
        game_start = True
        # control_state: False = not started yet, True = already running.
        if self.control_state:
            self.closeWindow()
            return
        # Switch the button to its "stop" appearance.
        self.control.setStyleSheet("border-image: url(:/bottom/resource/终止按钮.png);")
        self.control_state = True
        # Defaults for a brand-new project.
        self.program_name = ""
        game = self.game_selection.currentText()
        model = ""
        replay_memory = deque()
        self.actual_timestep = 0
        setting = self.setting_form.getSetting()
        # In load mode, replace the defaults with the saved project's values.
        if self.mode_state:
            program_path = QtWidgets.QFileDialog.getOpenFileName(
                self, "请选择你想要加载的项目", "../", "Model File (*.dat)")
            loaded = self._readProgram(program_path[0]) if program_path[0] else None
            if loaded is not None:
                self.program_name = loaded["name"]
                game = loaded["game"]
                model = self.program_name
                replay_memory = loaded["replay"]
                setting = loaded["setting"]
                self.actual_timestep = loaded["timestep"]
                # Show the loaded game so a later save records the right one.
                self.game_selection.setCurrentText(game)
                # Refreshing the dialog and the chart is best-effort: a failure
                # here must not prevent training from resuming.
                try:
                    self.setting_form.updateSetting(setting)
                except Exception:
                    logger.exception("could not show the loaded settings")
                try:
                    self.updateDataset(loaded["result"])
                except Exception:
                    logger.exception("could not restore the saved scores")
        # Start the game thread.
        self.game_thread = myThread(game, model, replay_memory, self.actual_timestep, setting)
        self.game_thread.start()
        # Poll the training state every five seconds.
        self.state_Timer = QtCore.QTimer()
        self.state_Timer.timeout.connect(self.updateState)
        self.state_Timer.start(5000)

    def _readProgram(self, dat_path):
        """Read a saved project; return its values as a dict, or None on failure.

        *dat_path* is the ``*.db.dat`` file picked in the dialog; the trailing
        seven characters (``.db.dat``) are stripped to get the project name.
        """
        program_name = dat_path[:-7]
        try:
            # NOTE: shelve files are pickles - only open project files you trust.
            # flag='r' avoids creating an empty database when the path is wrong.
            with shelve.open(program_name + '.db', flag='r') as f:
                return {
                    "name": program_name,
                    "game": f["game"],
                    "replay": f["replay"],
                    "setting": f["setting"],
                    "timestep": int(f["timestep"]),
                    "result": f["result"],
                }
        except Exception:
            logger.exception("could not load project %r", program_name)
            return None

    # Close the window
    def closeWindow(self):
        """Optionally save the project, clear the temporary database and close."""
        # self.state stays empty if the game never started or is still starting
        # up, in which case we quit without offering to save.
        timestep = self.state.get("TIMESTEP", 0)
        if timestep > 1000:
            # Ask whether the project should be saved.
            reply = self.message_box.exec_()
            if reply:
                # Close the game window first.
                try:
                    self.game_thread.AI.closeGame()
                except Exception:
                    logger.exception("could not close the game cleanly")
                if not self.program_name:
                    # New project: ask where to save it.
                    save_program_path = QtWidgets.QFileDialog.getSaveFileName(
                        self, "请选择你保存项目的位置", "../", "Program File(*.db)")
                    # getSaveFileName returns a (path, filter) tuple; only the
                    # path tells us whether the user cancelled.
                    if save_program_path[0]:
                        # Strip only the final extension; directories may
                        # contain dots.
                        program_name = os.path.splitext(save_program_path[0])[0]
                        self.saveProgram(program_name + '.db', 0)
                        self.saveModel(program_name)
                else:
                    # Loaded project: save back to where it came from.
                    program_name = self.program_name
                    try:
                        self.saveProgram(program_name + '.db', 1)
                    except Exception:
                        logger.exception("could not save project %r", program_name)
                    self.saveModel(program_name)
        # Empty the temporary database.
        connection = sqlite3.connect('temp.db', check_same_thread=False)
        try:
            with connection:
                connection.execute('delete from scores')
        except sqlite3.OperationalError:
            pass  # the scores table was never created (no game was started)
        finally:
            connection.close()
        # Close the main window.
        self.close()

    # Single place that writes a project file
    def saveProgram(self, save_program_path, state):
        """Write settings, scores, replay buffer and timestep to a shelve file.

        save_program_path: a path string, or the (path, filter) tuple returned
            by QFileDialog.
        state: 0 for a new project, 1 for a loaded one. Kept for callers; the
            stored timestep is the same either way because the AI's TIMESTEP
            already includes the timestep the project was loaded with.
        """
        if isinstance(save_program_path, str):
            path = save_program_path
        else:
            path = save_program_path[0]
        ai_state = self.game_thread.AI.getState()
        connection = sqlite3.connect('temp.db', check_same_thread=False)
        try:
            rows = connection.execute('select * from scores').fetchall()
        finally:
            connection.close()
        with shelve.open(path) as f:
            # Settings the AI was run with.
            f["setting"] = self.setting_form.getSetting()
            # Training state.
            f["game"] = self.game_selection.currentText()
            f["epsilon"] = ai_state["EPSILON"]
            # Times are stored in milliseconds in the project file.
            f["result"] = [[row[0] * 1000, row[1]] for row in rows]
            f["replay"] = self.game_thread.AI.getReplay()
            f["timestep"] = int(ai_state["TIMESTEP"])

    # Periodically refresh the main window
    def updateState(self):
        """Poll the AI, update the progress bar and flush pending scores."""
        try:
            state = self.game_thread.AI.getState()
        except AttributeError:
            return  # the AI has not finished starting up; try again next tick
        if "TIMESTEP" not in state:
            return  # AI exists but has not processed a frame yet
        self.state = state
        actual_timestep = self.state["TIMESTEP"]
        self.progressBar.setToolTip(
            "Timestep:" + str(actual_timestep)
            + " STATE:" + self.state["STATE"]
            + " EPSILON:" + str(self.state["EPSILON"]))
        self.progressBar.setProperty(
            "value",
            min(float(actual_timestep) / float(self.setting_form.getSetting()["Explore"]) * 100, 100))
        # Commit only once per tick (every 5 s) to keep database writes cheap.
        try:
            self.game_thread.AI.data_base.commit()
        except sqlite3.Error:
            logger.exception("could not commit scores")

    # Toggle between new-project and load-project mode
    def setMode(self):
        if not self.mode_state:
            self.mode_state = True
            self.mode.setStyleSheet("border-image: url(:/bottom/resource/加载模式.png);\n")
        else:
            self.mode_state = False
            self.mode.setStyleSheet("border-image: url(:/bottom/resource/空白模式.png);\n")

    # Copy "<local ip>:9090" to the clipboard
    def getIp(self):
        """Copy the address of the monitoring web page to the clipboard."""
        ip = '127.0.0.1'  # fallback when no outbound route is available
        try:
            # Connecting a UDP socket sends no packets; it only makes the OS
            # pick the outbound interface whose address we then read.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
                sock.connect(('8.8.8.8', 80))
                ip = sock.getsockname()[0]
        except OSError:
            logger.warning("could not determine the local IP address")
        pyperclip.copy(ip + ':9090')

    # Put a loaded project's scores back into the temporary database
    def updateDataset(self, results):
        """Insert saved [time_ms, score] pairs into the scores table of temp.db."""
        # Project files store milliseconds, the scores table stores seconds
        # (both the saver and the web server multiply by 1000 on the way out).
        rows = [(int(result[0]) // 1000, result[1]) for result in results]
        connection = sqlite3.connect('temp.db', check_same_thread=False)
        try:
            with connection:
                connection.execute(
                    'create table if not exists scores (time integer, score integer)')
                connection.executemany('insert into scores values (?,?)', rows)
        finally:
            connection.close()

    # Save the model
    def saveModel(self, program_name):
        """Copy the latest checkpoint files next to the project as ``<program_name>.<ext>``."""
        for file in glob.glob("./saved_networks/network-dqn-*"):
            postfix = file.split('.')[-1]
            try:
                shutil.copy(file, program_name + '.' + postfix)
            except OSError:
                logger.exception("could not copy %s", file)

    # Settings button handler
    def openSetting(self):
        self.setting_form.show()
"""Parameter dialog: six hyper-parameters, each with a linked slider and edit box."""
from PyQt5 import QtCore, QtGui, QtWidgets
import setting_resource

# Style sheet shared by all six sliders.
_SLIDER_STYLE = (
    "QSlider::handle:horizontal { \n"
    " image: url(:/image/resource/Handle.png);\n"
    "}\n"
    "QSlider::groove:horizontal { \n"
    " image: url(:/image/resource/Base.png);\n"
    "}\n"
)


class SETTING(QtWidgets.QWidget):
    """Dialog for Explore / Initial / Final / Gamma / Replay / Batch.

    Sliders hold integers; Initial and Final are shown as slider/1000 and
    Gamma as slider/100, the other three are shown unscaled.
    """

    def __init__(self):
        super().__init__()
        self.setObjectName("Dialog")
        self.resize(547, 402)
        self.setStyleSheet("")
        # OK and Cancel buttons.
        # NOTE(review): "设定肯定按钮" may be a transcription of "设定确定按钮";
        # confirm against the resource file before changing it.
        self.pushButton = self._makeButton("pushButton", 160, "设定肯定按钮")
        self.pushButton_2 = self._makeButton("pushButton_2", 320, "设定取消按钮")
        # Edit boxes.
        self.line_explore = self._makeLineEdit("line_explore", 60)
        self.line_initial = self._makeLineEdit("line_Initial", 100)
        self.line_final = self._makeLineEdit("line_final", 140)
        self.line_gamma = self._makeLineEdit("line_gamma", 180)
        self.line_replay = self._makeLineEdit("line_replay", 220)
        self.line_batch = self._makeLineEdit("line_batch", 260)
        # Sliders and their captions.
        self.exploreSlider = self._makeSlider("exploreSlider", 60, 10000000, 200000, minimum=200000)
        self.label = self._makeLabel("label", 60, 48)
        self.label_2 = self._makeLabel("label_2", 100, 48)
        self.initialSlider = self._makeSlider("initialSlider", 100, 1000, 0)
        self.label_3 = self._makeLabel("label_3", 140, 42)
        self.finalSlider = self._makeSlider("finalSlider", 140, 1000, 0)
        self.label_4 = self._makeLabel("label_4", 180, 42)
        self.gammaSlider = self._makeSlider("gammaSlider", 180, 100, 99)
        self.label_6 = self._makeLabel("label_6", 220, 42)
        self.replaySlider = self._makeSlider("replaySlider", 220, 100000, 50000)
        self.label_7 = self._makeLabel("label_7", 260, 36)
        self.batchSlider = self._makeSlider("batchSlider", 260, 100, 32)
        # Background image.
        self.label_5 = QtWidgets.QLabel(self)
        self.label_5.setGeometry(QtCore.QRect(0, 0, 551, 411))
        self.label_5.setStyleSheet("background-image: url(:/background/resource/设定背景.png);")
        self.label_5.setText("")
        self.label_5.setObjectName("label_5")
        # Stacking order: background first, every control above it.
        for widget in (
            self.label_5, self.pushButton, self.pushButton_2,
            self.line_explore, self.line_initial, self.line_final,
            self.line_gamma, self.line_replay, self.line_batch,
            self.exploreSlider, self.label, self.label_2, self.initialSlider,
            self.label_3, self.finalSlider, self.label_4, self.gammaSlider,
            self.label_6, self.replaySlider, self.label_7, self.batchSlider,
        ):
            widget.raise_()
        # Texts and default settings.
        self.retranslateUi(self)
        # Link edit boxes and sliders (after the texts are set, so the initial
        # setText calls do not fire the handlers).
        self.connect()
        # Button handlers.
        self.pushButton.clicked.connect(self.saveSetting)
        self.pushButton_2.clicked.connect(self.cancel)
        QtCore.QMetaObject.connectSlotsByName(self)

    # --- widget factories -------------------------------------------------

    def _makeButton(self, name, x, image):
        """Create one of the two image buttons on the bottom row."""
        button = QtWidgets.QPushButton(self)
        button.setGeometry(QtCore.QRect(x, 320, 75, 23))
        button.setStyleSheet(
            "color: rgb(255, 255, 255);\n"
            "border-image: url(:/image/resource/" + image + ".png);")
        button.setText("")
        button.setObjectName(name)
        return button

    def _makeLineEdit(self, name, y):
        """Create the edit box of the parameter row at height *y*."""
        line_edit = QtWidgets.QLineEdit(self)
        line_edit.setGeometry(QtCore.QRect(450, y, 61, 20))
        line_edit.setStyleSheet("color: rgb(0, 0, 0);")
        line_edit.setObjectName(name)
        return line_edit

    def _makeLabel(self, name, y, width):
        """Create the caption of the parameter row at height *y*."""
        label = QtWidgets.QLabel(self)
        label.setGeometry(QtCore.QRect(50, y, width, 19))
        label.setStyleSheet("color: rgb(255, 255, 255);")
        label.setObjectName(name)
        return label

    def _makeSlider(self, name, y, maximum, value, minimum=None):
        """Create the horizontal slider of the parameter row at height *y*."""
        slider = QtWidgets.QSlider(self)
        slider.setGeometry(QtCore.QRect(120, y, 300, 19))
        slider.setStyleSheet(_SLIDER_STYLE)
        if minimum is not None:
            slider.setMinimum(minimum)
        slider.setMaximum(maximum)
        slider.setProperty("value", value)
        slider.setOrientation(QtCore.Qt.Horizontal)
        slider.setObjectName(name)
        return slider

    # --- texts and defaults -----------------------------------------------

    def retranslateUi(self, Dialog):
        """Set the window title, default edit-box texts, captions and default settings."""
        _translate = QtCore.QCoreApplication.translate
        Dialog.setWindowTitle(_translate("Dialog", "设置"))
        self.line_explore.setText(_translate("Dialog", "200000"))
        self.line_initial.setText(_translate("Dialog", "0"))
        self.line_final.setText(_translate("Dialog", "0"))
        self.line_gamma.setText(_translate("Dialog", "0.99"))
        self.line_replay.setText(_translate("Dialog", "50000"))
        self.line_batch.setText(_translate("Dialog", "32"))
        self.label.setText(_translate("Dialog", "Explore:"))
        self.label_2.setText(_translate("Dialog", "Initial:"))
        self.label_3.setText(_translate("Dialog", "Final:"))
        self.label_4.setText(_translate("Dialog", "Gamma:"))
        self.label_6.setText(_translate("Dialog", "Replay:"))
        self.label_7.setText(_translate("Dialog", "Batch:"))
        # Settings returned by getSetting() until the user saves new ones.
        self.setting = {"Explore": 200000, "Initial": 0, "Final": 0,
                        "Gamma": 0.99, "Replay": 50000, "Batch": 32}

    # --- slider <-> edit box linkage --------------------------------------

    def connect(self):
        """Wire every slider to its edit box and vice versa."""
        self.exploreSlider.valueChanged.connect(self.changeLineExplore)
        self.line_explore.textChanged.connect(self.changeSliderExplore)
        self.initialSlider.valueChanged.connect(self.changeLineInitial)
        self.line_initial.textChanged.connect(self.changeSliderInitial)
        self.finalSlider.valueChanged.connect(self.changeLineFinal)
        self.line_final.textChanged.connect(self.changeSliderFinal)
        self.gammaSlider.valueChanged.connect(self.changeLineGamma)
        self.line_gamma.textChanged.connect(self.changeSliderGamma)
        self.replaySlider.valueChanged.connect(self.changeLineReplay)
        self.line_replay.textChanged.connect(self.changeSliderReplay)
        self.batchSlider.valueChanged.connect(self.changeLineBatch)
        self.line_batch.textChanged.connect(self.changeSliderBatch)

    @staticmethod
    def _setSliderFromText(slider, line_edit, scale):
        """Move *slider* to round(float(text) * scale); ignore unparsable text.

        round() rather than int(): int(0.057 * 1000) is 56, which made the
        slider and the edit box pull each other one step down.
        """
        try:
            slider.setValue(int(round(float(line_edit.text()) * scale)))
        except (ValueError, OverflowError):
            pass  # the user is still typing, or the number is out of range

    def changeLineExplore(self):
        self.line_explore.setText(str(self.exploreSlider.value()))

    def changeSliderExplore(self):
        self._setSliderFromText(self.exploreSlider, self.line_explore, 1)

    def changeLineInitial(self):
        self.line_initial.setText(str(self.initialSlider.value() / 1000))

    def changeSliderInitial(self):
        self._setSliderFromText(self.initialSlider, self.line_initial, 1000)

    def changeLineFinal(self):
        self.line_final.setText(str(self.finalSlider.value() / 1000))

    def changeSliderFinal(self):
        # The original multiplied the *text* by 1000 before float(), so this
        # slider never followed the edit box.
        self._setSliderFromText(self.finalSlider, self.line_final, 1000)

    def changeLineGamma(self):
        self.line_gamma.setText(str(self.gammaSlider.value() / 100))

    def changeSliderGamma(self):
        self._setSliderFromText(self.gammaSlider, self.line_gamma, 100)

    def changeLineReplay(self):
        self.line_replay.setText(str(self.replaySlider.value()))

    def changeSliderReplay(self):
        self._setSliderFromText(self.replaySlider, self.line_replay, 1)

    def changeLineBatch(self):
        self.line_batch.setText(str(self.batchSlider.value()))

    def changeSliderBatch(self):
        self._setSliderFromText(self.batchSlider, self.line_batch, 1)

    # --- settings access ----------------------------------------------------

    def _showSetting(self, setting):
        """Write the values of *setting* into the six edit boxes."""
        self.line_explore.setText(str(setting["Explore"]))
        self.line_final.setText(str(setting["Final"]))
        self.line_initial.setText(str(setting["Initial"]))
        self.line_gamma.setText(str(setting["Gamma"]))
        self.line_replay.setText(str(setting["Replay"]))
        self.line_batch.setText(str(setting["Batch"]))

    # Used by other modules to read the AI settings
    def getSetting(self):
        return self.setting

    # Save the settings
    def saveSetting(self):
        """Validate the edit boxes, store them as numbers and hide the dialog.

        If any box does not hold a number the previous settings are kept and
        re-displayed, and the dialog stays open.
        """
        try:
            parsed = {
                "Explore": int(float(self.line_explore.text())),
                "Initial": float(self.line_initial.text()),
                "Final": float(self.line_final.text()),
                "Gamma": float(self.line_gamma.text()),
                "Replay": int(float(self.line_replay.text())),
                "Batch": int(float(self.line_batch.text())),
            }
        except (ValueError, OverflowError):
            self._showSetting(self.setting)
            return
        self.setting = parsed
        self.hide()

    # Discard the edits
    def cancel(self):
        """Restore the last saved values in the edit boxes and hide the dialog."""
        self._showSetting(self.setting)
        self.hide()
        return 0

    # Replace the settings with those of a loaded project
    def updateSetting(self, setting):
        """Adopt *setting* (a dict with the six keys) and display it."""
        self.setting = {key: setting[key]
                        for key in ("Explore", "Initial", "Final", "Gamma", "Replay", "Batch")}
        # The original wrote to self.line_Initial, which does not exist.
        self._showSetting(setting)
DQL.py
import cv2
from DQLBrain import Brain
import numpy as np
from collections import deque
import sqlite3
import pygame
import time
import gameSetting
import importlib
# Settings shared by all games
SCREEN_X = 288
SCREEN_Y = 512
FPS = 60


class AI:
    """Glue between one game, the DQN brain and the temporary score database."""

    def __init__(self, title, model_path, replay_memory, current_timestep, explore,
                 initial_epsilon, final_epsilon, gamma, replay_size, batch_size):
        """Create the brain, the pygame window and the game named *title*.

        *title* must be a key of ``gameSetting.getSetting()``; that entry's
        "action" value is passed to Brain and its "class" value is imported as
        the game module. The remaining arguments are forwarded to Brain.
        """
        self.scores = deque()
        self.games_info = gameSetting.getSetting()
        # Connect to the temporary database and make sure the table exists.
        # The connection is also used from the GUI thread, hence
        # check_same_thread=False.
        self.data_base = sqlite3.connect('temp.db', check_same_thread=False)
        self.c = self.data_base.cursor()
        self.c.execute('create table if not exists scores (time integer, score integer)')
        # Deep reinforcement learning core.
        self.brain = Brain(self.games_info[title]["action"], model_path, replay_memory,
                           current_timestep, explore, initial_epsilon, final_epsilon,
                           gamma, replay_size, batch_size)
        # Game window.
        self.startGame(title, SCREEN_X, SCREEN_Y)
        # Load the selected game.
        game = importlib.import_module(self.games_info[title]['class'])
        self.game = game.Game(self.screen)

    def startGame(self, title, SCREEN_X, SCREEN_Y):
        """Initialise pygame and create the screen and the clock."""
        pygame.init()
        screen_size = (SCREEN_X, SCREEN_Y)
        pygame.display.set_caption(title)
        self.screen = pygame.display.set_mode(screen_size)
        self.clock = pygame.time.Clock()

    # Reduce the frame before it is fed to the network
    def preProcess(self, observation):
        """Return the frame as an (80, 80, 1) black-and-white array."""
        # Resize to 80x80 and convert the three colour channels to grey.
        observation = cv2.cvtColor(cv2.resize(observation, (80, 80)), cv2.COLOR_BGR2GRAY)
        # Every pixel brighter than 1 becomes white.
        threshold, observation = cv2.threshold(observation, 1, 255, cv2.THRESH_BINARY)
        # Trailing axis of length 1 so frames can be stacked along axis 2.
        return np.reshape(observation, (80, 80, 1))

    # Run the game
    def playGame(self):
        """Play until the game reports that its window was destroyed (terminal == -1)."""
        # Feed an arbitrary first action to obtain the initial frame.
        observation0, reward0, terminal, score = self.game.frameStep(np.array([1, 0, 0]))
        observation0 = self.preProcess(observation0)
        self.brain.setInitState(observation0[:, :, 0])
        while True:
            action = self.brain.getAction()
            next_observation, reward, terminal, score = self.game.frameStep(action)
            # The game window was destroyed.
            if terminal == -1:
                self.closeGame()
                return
            next_observation = self.preProcess(next_observation)
            self.brain.setPerception(next_observation, action, reward, terminal)
            # Record the score at the end of every round. The row is committed
            # by whoever holds the connection (the GUI timer, or closeGame).
            if terminal:
                t = int(time.time())
                # sqlite3 cannot bind numpy scalars; unwrap them if present.
                value = score.item() if hasattr(score, "item") else score
                self.c.execute("insert into scores values (?,?)", (t, value))

    # Shut the game down
    def closeGame(self):
        """Close pygame, the brain and the database connection."""
        pygame.quit()
        self.brain.close()
        time.sleep(0.5)  # give pending database writes time to finish
        try:
            # Without this, scores inserted since the last commit are lost.
            self.data_base.commit()
        except sqlite3.ProgrammingError:
            pass  # connection already closed by an earlier closeGame call
        self.data_base.close()

    # Current training state
    def getState(self):
        return self.brain.getState()

    # Current replay buffer, to be stored in the project file
    def getReplay(self):
        return self.brain.replay_memory
DQLBrain.py
observe=100
class Brain:
    """Deep Q-learning core: online and target networks, replay buffer, epsilon-greedy policy."""

    def __init__(self, actions, model_path, replay_memory=None, current_timestep=0,
                 explore=200000., initial_epsilon=0.0, final_epsilon=0.0, gamma=0.99,
                 replay_size=50000, batch_size=32):
        """Build both networks and open a TensorFlow session.

        actions: number of discrete actions (width of the output layer).
        model_path: accepted for the caller's sake.
            NOTE(review): nothing here restores a checkpoint from it - confirm
            whether ``self.saver.restore`` was meant to be called.
        replay_memory: existing replay buffer to continue from; a fresh deque
            is created when omitted (the old ``deque()`` default was shared
            between all instances).
        current_timestep: timestep already trained before this run.
        """
        # Hyper-parameters:
        # discount factor for future rewards
        self.gamma = gamma
        # number of steps to observe before training starts
        self.observe = observe
        # number of steps over which epsilon is annealed
        self.explore = explore
        # epsilon at the start
        self.initial_epsilon = initial_epsilon
        # epsilon at the end
        self.final_epsilon = final_epsilon
        # replay buffer size
        self.replay_size = replay_size
        # minibatch size
        self.batch_size = batch_size
        # steps between target-network updates
        self.update_time = 100
        self.whole_state = dict()
        # Replay buffer.
        self.replay_memory = deque() if replay_memory is None else replay_memory
        # Counters: steps in this run, steps before this run, and their sum.
        self.timestep = 0
        self.initial_timestep = current_timestep
        self.accual_timestep = self.initial_timestep + self.timestep
        # In load mode epsilon has to be recomputed from the saved timestep.
        self.epsilon = max(
            self.final_epsilon,
            self.initial_epsilon
            - (self.initial_epsilon - self.final_epsilon) / self.explore * self.accual_timestep)
        self.actions = actions
        # Online network (Q_t+1).
        (self.state_input, self.QValue,
         self.conv1_w, self.conv1_b, self.conv2_w, self.conv2_b,
         self.conv3_w, self.conv3_b, self.fc1_w, self.fc1_b,
         self.fc2_w, self.fc2_b) = self.createQNetwork()
        # Target network (Q_t).
        (self.state_inputT, self.QValueT,
         self.conv1_wT, self.conv1_bT, self.conv2_wT, self.conv2_bT,
         self.conv3_wT, self.conv3_bT, self.fc1_wT, self.fc1_bT,
         self.fc2_wT, self.fc2_bT) = self.createQNetwork()
        # Ops that copy every online variable into the target network.
        self.copyTargetQNetwork = [
            self.conv1_wT.assign(self.conv1_w), self.conv1_bT.assign(self.conv1_b),
            self.conv2_wT.assign(self.conv2_w), self.conv2_bT.assign(self.conv2_b),
            self.conv3_wT.assign(self.conv3_w), self.conv3_bT.assign(self.conv3_b),
            self.fc1_wT.assign(self.fc1_w), self.fc1_bT.assign(self.fc1_b),
            self.fc2_wT.assign(self.fc2_w), self.fc2_bT.assign(self.fc2_b),
        ]
        # Loss: mean squared error between the target y and Q(s, a).
        self.action_input = tf.placeholder("float", [None, self.actions])
        self.y_input = tf.placeholder("float", [None])
        Q_Action = tf.reduce_sum(tf.multiply(self.QValue, self.action_input), reduction_indices=1)
        self.cost = tf.reduce_mean(tf.square(self.y_input - Q_Action))
        self.optimizer = tf.train.AdamOptimizer(1e-6).minimize(self.cost)
        # Checkpoint saver and session.
        self.saver = tf.train.Saver(max_to_keep=1)
        self.session = tf.InteractiveSession()
        # Replaces the deprecated tf.initialize_all_variables().
        self.session.run(tf.global_variables_initializer())

    def createQNetwork(self):
        """Create one Q network; return its input, output and all ten variables."""
        # conv 1: 8x8, 4 -> 32 channels
        W_conv1 = self.weightVariable([8, 8, 4, 32])
        b_conv1 = self.biasVariable([32])
        # conv 2: 4x4, 32 -> 64 channels
        W_conv2 = self.weightVariable([4, 4, 32, 64])
        b_conv2 = self.biasVariable([64])
        # conv 3: 3x3, 64 -> 64 channels
        W_conv3 = self.weightVariable([3, 3, 64, 64])
        b_conv3 = self.biasVariable([64])
        # fully connected: 1600 -> 512
        W_fc1 = self.weightVariable([1600, 512])
        b_fc1 = self.biasVariable([512])
        # output: 512 -> actions
        W_fc2 = self.weightVariable([512, self.actions])
        b_fc2 = self.biasVariable([self.actions])
        # input: a stack of four 80x80 frames
        stateInput = tf.placeholder("float", [None, 80, 80, 4])
        # hidden layers
        h_conv1 = tf.nn.relu(self.conv2d(stateInput, W_conv1, 4) + b_conv1)
        # 20x20x32 -> 10x10x32
        h_pool1 = self.maxPool_2x2(h_conv1)
        h_conv2 = tf.nn.relu(self.conv2d(h_pool1, W_conv2, 2) + b_conv2)
        # stride 1: 5x5x64 -> 5x5x64
        h_conv3 = tf.nn.relu(self.conv2d(h_conv2, W_conv3, 1) + b_conv3)
        # 5x5x64 -> 1600
        h_conv3_flat = tf.reshape(h_conv3, [-1, 1600])
        h_fc1 = tf.nn.relu(tf.matmul(h_conv3_flat, W_fc1) + b_fc1)
        # output layer
        QValue = tf.matmul(h_fc1, W_fc2) + b_fc2
        return (stateInput, QValue, W_conv1, b_conv1, W_conv2, b_conv2,
                W_conv3, b_conv3, W_fc1, b_fc1, W_fc2, b_fc2)

    def trainQNetwork(self):
        """Run one optimisation step on a random minibatch from the replay buffer."""
        # Each entry is (state, action, reward, next_state, terminal).
        minibatch = random.sample(self.replay_memory, self.batch_size)
        state_batch = [data[0] for data in minibatch]
        action_batch = [data[1] for data in minibatch]
        reward_batch = [data[2] for data in minibatch]
        nextState_batch = [data[3] for data in minibatch]
        # Targets: r for terminal transitions, r + gamma * max Q_target otherwise.
        QValue_batch = self.QValueT.eval(feed_dict={self.state_inputT: nextState_batch})
        y_batch = [
            reward if data[4] else reward + self.gamma * np.max(q_values)
            for data, reward, q_values in zip(minibatch, reward_batch, QValue_batch)
        ]
        self.optimizer.run(feed_dict={self.y_input: y_batch,
                                      self.action_input: action_batch,
                                      self.state_input: state_batch})
        # Save the network every 1000 steps.
        if self.timestep % 1000 == 0:
            self.saver.save(self.session, './saved_networks/network' + '-dqn',
                            global_step=self.timestep + self.initial_timestep)
        # Refresh the target network.
        if self.timestep % self.update_time == 0:
            self.session.run(self.copyTargetQNetwork)

    def setPerception(self, nextObservation, action, reward, terminal):
        """Store one transition, train if possible and publish the training state.

        nextObservation must have a trailing axis of length 1 so that it can
        be appended to the last three frames of the current state.
        """
        new_state = np.append(self.current_state[:, :, 1:], nextObservation, axis=2)
        self.replay_memory.append((self.current_state, action, reward, new_state, terminal))
        # Keep the replay buffer bounded.
        if len(self.replay_memory) > self.replay_size:
            self.replay_memory.popleft()
        # random.sample raises when asked for more items than the buffer holds,
        # so training waits until one full batch is available.
        if self.timestep > self.observe and 0 < self.batch_size <= len(self.replay_memory):
            self.trainQNetwork()
        # Training phase shown in the main window.
        if self.timestep <= self.observe:
            state = "observe"
        elif self.timestep <= self.observe + self.explore:
            state = "explore"
        else:
            state = "train"
        self.whole_state = {"TIMESTEP": self.timestep + self.initial_timestep,
                            "STATE": state,
                            "EPSILON": self.epsilon,
                            "ACTUAL": int(self.timestep + self.initial_timestep)}
        self.current_state = new_state
        self.timestep += 1

    def _updateEpsilon(self):
        """Anneal epsilon linearly with the total timestep, never below final_epsilon."""
        # The total used to be computed once in __init__ and never refreshed,
        # so epsilon stayed at its starting value for the whole run.
        self.accual_timestep = self.initial_timestep + self.timestep
        if self.epsilon > self.final_epsilon and self.accual_timestep > self.observe:
            annealed = (self.initial_epsilon
                        - (self.initial_epsilon - self.final_epsilon)
                        / self.explore * self.accual_timestep)
            self.epsilon = max(self.final_epsilon, annealed)

    def getAction(self):
        """Return a one-hot action chosen epsilon-greedily for the current state."""
        QValue = self.QValue.eval(feed_dict={self.state_input: [self.current_state]})[0]
        action = np.zeros(self.actions)
        # epsilon-greedy: random action with probability epsilon
        if random.random() <= self.epsilon:
            action_index = random.randrange(self.actions)
        else:
            action_index = np.argmax(QValue)
        action[action_index] = 1
        self._updateEpsilon()
        return action

    def setInitState(self, observation):
        """Initialise the state with four copies of the first 2-D frame."""
        self.current_state = np.stack((observation, observation, observation, observation), axis=2)

    def weightVariable(self, shape):
        initial = tf.truncated_normal(shape, stddev=0.01)
        return tf.Variable(initial)

    def biasVariable(self, shape):
        initial = tf.constant(0.01, shape=shape)
        return tf.Variable(initial)

    def conv2d(self, x, W, stride):
        return tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding="SAME")

    def maxPool_2x2(self, x):
        return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")

    def close(self):
        self.session.close()

    def getState(self):
        return self.whole_state
主要采用highchart的API。在static文件夹中放好上述的四项文件后,在templates文件夹中写好服务器界面的代码index.html(为了方便大家学习,界面写得相当简陋hh):
<!DOCTYPE html>
<html>
<head>
  <meta charset="utf-8">
  <script src='/static/jquery.js'></script>
  <script src='/static/highstock.js'></script>
  <script src='/static/exporting.js'></script>
</head>
<body>
  <div id="container" style="min-width:310px;height:400px"></div>
  <script>
    $(function () {
      // Use the browser's time zone; with UTC the axis is eight hours off in UTC+8.
      Highcharts.setOptions({ global: { useUTC: false } });
      $.getJSON('/data', function (data) {
        // Create the chart
        $('#container').highcharts('StockChart', {
          chart: {
            events: {
              load: function () {
                var series = this.series[0];
                setInterval(function () {
                  $.getJSON('/data', function (res) {
                    // /data always returns the complete score list, so replace
                    // the series instead of appending every point again.
                    series.setData(res);
                  });
                }, 3000);
              }
            }
          },
          rangeSelector: { selected: 1 },
          title: { text: '每局分数' },
          series: [{
            name: '训练表现',
            data: data,
            tooltip: { valueDecimals: 2 }
          }]
        });
      });
    });
  </script>
</body>
</html>
同时还需要编写一个实时调用该模板的py文件:Webservice.py:
"""Tiny Flask server that publishes the training scores stored in temp.db."""
from flask import Flask, render_template, request
import sqlite3
import json

app = Flask(__name__)

# Connection to the temporary database. Requests may be served from other
# threads, hence check_same_thread=False.
data_base = sqlite3.connect('temp.db', check_same_thread=False)
c = data_base.cursor()


# Front-end template
@app.route('/')
def index():
    return render_template("index.html")


# Data source for the chart
@app.route('/data')
def data():
    """Return every score as a JSON list of [time_in_ms, score] pairs."""
    # One cursor per request: a single module-level cursor would be shared by
    # concurrently running requests.
    cursor = data_base.cursor()
    try:
        cursor.execute('select * from scores')
        rows = cursor.fetchall()
    except sqlite3.OperationalError:
        # The scores table does not exist until a game has been started.
        rows = []
    finally:
        cursor.close()
    # The table stores seconds; the chart expects milliseconds.
    return json.dumps([[row[0] * 1000, row[1]] for row in rows])


# Start the server; 0.0.0.0 listens on every network interface
def start():
    app.run(host='0.0.0.0', port=9090)
不过貌似PyQt5和tensorflow会有冲突,所以实际运行的时候会偶尔出现崩溃。另外服务器无法由外网的机器连接。如果大家知道怎么解决这些问题请在下方留言告诉我,谢谢!最后再来一次:github地址为https://github.com/qq303067814/DQLearning-Toolbox ,如果讲解中有部分还想继续了解的话可以直接查看源代码,或者在留言中提出。
训练简单小游戏的强化学习工具箱
代码地址如下:
http://www.demodashi.com/demo/14072.html
注:本文著作权归作者所有,由demo大师代发,拒绝转载,转载需要作者授权