您當前的位置:首頁 > 舞蹈

讓神經網路自學畫圈打叉(Tic Tac Toe)遊戲

作者:由 MatthewYan 發表于 舞蹈時間:2022-09-06

之前我用監督式學習(Supervised Learning)訓練出來的神經網路五子棋程式效果很糟糕(見此文章

MatthewYan:訓練卷積神經網路下五子棋

),於是嘗試強化學習(Reinforcement learning)。因為五子棋棋盤為15x15,也就意味著最多每一步有225種可能的走法,很可能訓練很久也不會得到理想效果,所以這次我沒有直接從五子棋入手,而是用一個極簡版的“三子棋”——畫圈打叉遊戲來測試我的想法。

畫圈打叉遊戲大家都不陌生,就是兩名玩家在一個3x3的格子板上畫圈(O)或者打叉(X)。先將圈或者叉連城一條線(3連子)的玩家勝利,如果到了最後不能走了,就判定為平局(draw)。

這次試驗用了Q-learning(一種強化學習演算法)的方法,利用神經網路來充當Q-table,讓計算機和自己對戰,然後將收集到的資料簡單處理,然後訓練神經網路。再將訓練過的神經網路和自己對戰,繼續收集處理資料訓練神經網路,反覆以上過程多次(2500次)得到一個具有一定“棋力”神經網路。

聽起來很簡單,但是如何獲取和處理用來訓練神經網路的資料呢?

在Q-learning中, 我們需要一系列的狀態(States),行為(Actions),以及每個行為發生後產生的獎勵(Rewards。注意,雖然說是“獎勵”,這個值如果為負我們就看作是懲罰)。

舉個“慄?子”,假設你是神經網路,你從來沒有玩過一個叫“乒乓”的電腦遊戲。那麼所謂某一時刻的狀態(state)就是你眼睛觀察到的該時刻的一個遊戲場景(相當於一張圖片)。

讓神經網路自學畫圈打叉(Tic Tac Toe)遊戲

一個狀態

行為(action)則是你在觀察到這個狀態後所做出的反應,“乒乓”遊戲裡,你可以選擇向上(UP)或者向下(DOWN)移動球板,還可以選擇不動(STAY)。那麼每一時刻你能做出的行為就有3種可能。

讓神經網路自學畫圈打叉(Tic Tac Toe)遊戲

用神經網路玩“乒乓”遊戲

例子到此結束!

這樣,在我們的實驗中,輸入神經網路的input data就是畫圈打叉遊戲的棋盤(可看作3x3的灰度畫素圖),神經網路的輸出則是(9個可能actions的機率,實際遊戲中則取最大機率的action來執行)。

接下來我簡單說一下(我就不在本文詳細講了,因為本文只關注Q-learning在TicTacToe裡的應用)最難理解且在強化學習裡面最關鍵的部分,獎勵值!

由於我們不太容易在每次action之後都得到一個獎勵值(我們叫它instant reward:即時獎勵),我們在每次遊戲結束後給勝利方一個正獎勵值,給失敗方一個負獎勵值(懲罰),再透過簡單的一個discount函式來給剛剛那一局遊戲所有的(State,Action)附一個Reward值,具體見連結(英文原文: Deep Q-Learning with Keras and Gym)。

讓神經網路自學畫圈打叉(Tic Tac Toe)遊戲

這是一回合遊戲結束後產生的相應獎勵

動手吧(訓練一個只有一層200個神經元的小神經網路):

讓神經網路自學畫圈打叉(Tic Tac Toe)遊戲

用時6分30秒,2400多次訓練(每次訓練大約幾十局對局),loss十分接近0

實驗結果是神經網路具有一定的“棋力”,但是我們還是可以戰勝它。我分析了一下原因,其實很簡單,我們一直在讓神經網路跟自己對局,但卻沒有考慮到當神經網路訓練到一定程度以後,可能會陷入一個“區域性最優”的情況,我們可以理解為“神經網路自我感覺良好”,結果當人類玩家除了一招它沒有“想到”的走法,神經網路則無法做出很好的反應。解決方法也很簡單,只需要在self-paly的時候加入點隨機因素來避免每次action都是由神經網路產生的就好了。這相當於我們平衡了“開發(exploitation)”和“探索(exploration)”(想法來自於蒙特卡洛樹搜尋Monte Carlo Tree Search)。

最後,我把原始碼釋出在了Github上,有興趣的朋友可以拿我程式碼改進。

(更新: 新程式碼增加了隨機因素,訓練出來的神經網路比前代要強(不過還是有點缺陷),但訓練時間也增加了好幾倍!!)

https://

github。com/18369766918/

Q-Learning-Tic-Tac-Toe

18369766918/Q-Learning-Tic-Tac-Toe

附上增加了隨機探索後的Loss變化圖:

讓神經網路自學畫圈打叉(Tic Tac Toe)遊戲

Loss

讓神經網路自學畫圈打叉(Tic Tac Toe)遊戲

隨機率變化曲線(每100次訓練減少10%)