何かAIっぽいことがやってみたくて強化学習の勉強をはじめてみました。
初めから本とか読むと自分の場合。。。難しくて眠くなっちゃうので、ゲームっぽい面白そうなところからやってみようと思います。
ちょうど、PyTorchチュートリアルにスーパーマリオのプレイ動画があったのでやってみました。
pytorch.org
参考にしたサイト
qiita.com
インストール
スーパーマリオの OpenAI Gym用
スーパーマリオの OpenAI Gym環境をインストール
pip install -q gym-super-mario-bros==7.3.0
matplotlibでアニメーション出力
gif/mp4でアニメーションファイルを出力用
$ pip install matplotlib $ brew install imagemagick # gif用 $ brew install ffmpeg # mp4用
プログラム
必要なライブラリをインポート
必要なライブラリをインポートします。
import torch from torch import nn from torchvision import transforms as T from PIL import Image import numpy as np from pathlib import Path from collections import deque import random, datetime, os, copy # Gym is an OpenAI toolkit for RL import gym from gym.spaces import Box from gym.wrappers import FrameStack # NES Emulator for OpenAI Gym from nes_py.wrappers import JoypadSpace # Super Mario environment for OpenAI Gym import gym_super_mario_bros
環境を初期化
# Initialize Super Mario environment env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0") # Limit the action-space to # 0. walk right # 1. jump right env = JoypadSpace(env, [["right"], ["right", "A"]]) env.reset() next_state, reward, done, info = env.step(action=0) print(f"{next_state.shape},\n {reward},\n {done},\n {info}")
設定値
設定値を指定します。
# 試行回数 EPISODE_NUMB = 10 # 最大試行時間 MAX_TIME = 600
マリオを動かしてみる
import copy frames = [] for i in range(EPISODE_NUMB): observation = env.reset() # reset for each new trial done = False total_reward = 0 total_time = 0 time = 0 while not done and total_time < MAX_TIME: frames.append(copy.deepcopy(env.render(mode = 'rgb_array'))) action = env.action_space.sample() # select a random action next_state, reward, done, info = env.step(action) total_reward += reward total_time += 1 print('test episode:', i, 'reward:', total_reward, 'time:', total_time)
プレイ結果
import matplotlib.pyplot as plt import matplotlib.animation import numpy as np matplotlib.rcParams['animation.embed_limit'] = 20**128 plt.figure(figsize=(frames[0].shape[1] / 72.0, frames[0].shape[0] / 72.0), dpi = 72) patch = plt.imshow(frames[0]) plt.axis('off') animate = lambda i: patch.set_data(frames[i]) ani = matplotlib.animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval = 50)
[gif]でファイル出力する場合
ani.save("anim.gif", writer = 'imagemagick')
[mp4]でファイル出力する場合
ani.save('anim.mp4', writer="ffmpeg")
マリオのプレイ動画
なんとなく動いたけど、高い土管がこえられない。。。( ノД`)シクシク…
youtu.be