ものづくりのブログ

うちのネコを題材にしたものづくりができたらいいなと思っていろいろ奮闘してます。

強化学習でマリオに挑戦 #1

何かAIっぽいことがやってみたくて強化学習の勉強をはじめてみました。
初めから本とか読むと自分の場合。。。難しくて眠くなっちゃうので、ゲームっぽい面白そうなところからやってみようと思います。
ちょうど、PyTorchチュートリアルにスーパーマリオのプレイ動画があったのでやってみました。
pytorch.org
参考にしたサイト
qiita.com
f:id:a1026302:20210111032405j:plain

インストール

スーパーマリオの OpenAI Gym用

スーパーマリオの OpenAI Gym環境をインストール
pip install -q gym-super-mario-bros==7.3.0

matplotlibでアニメーション出力

gif/mp4でアニメーションファイルを出力用
$ pip install matplotlib
$ brew install imagemagick # gif用
$ brew install ffmpeg # mp4用

プログラム

必要なライブラリをインポート

必要なライブラリをインポートします。

import torch
from torch import nn
from torchvision import transforms as T
from PIL import Image
import numpy as np
from pathlib import Path
from collections import deque
import random, datetime, os, copy

# Gym is an OpenAI toolkit for RL
import gym
from gym.spaces import Box
from gym.wrappers import FrameStack

# NES Emulator for OpenAI Gym
from nes_py.wrappers import JoypadSpace

# Super Mario environment for OpenAI Gym
import gym_super_mario_bros
環境を初期化
# Initialize Super Mario environment
env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0")

# Limit the action-space to
#   0. walk right
#   1. jump right
env = JoypadSpace(env, [["right"], ["right", "A"]])

env.reset()
next_state, reward, done, info = env.step(action=0)
print(f"{next_state.shape},\n {reward},\n {done},\n {info}")
設定値

設定値を指定します。

# 試行回数
EPISODE_NUMB = 10

# 最大試行時間
MAX_TIME = 600

マリオを動かしてみる

import copy

frames = []
for i in range(EPISODE_NUMB):
    observation = env.reset()  # reset for each new trial
    done = False
    total_reward = 0
    total_time = 0
    time = 0
    while not done and total_time < MAX_TIME: 
        frames.append(copy.deepcopy(env.render(mode = 'rgb_array')))
        action = env.action_space.sample()  # select a random action
        next_state, reward, done, info = env.step(action)
        total_reward += reward
        total_time += 1
    print('test episode:', i, 'reward:', total_reward, 'time:', total_time)

プレイ結果

import matplotlib.pyplot as plt
import matplotlib.animation
import numpy as np

matplotlib.rcParams['animation.embed_limit'] = 20**128
plt.figure(figsize=(frames[0].shape[1] / 72.0, frames[0].shape[0] / 72.0), dpi = 72)
patch = plt.imshow(frames[0])
plt.axis('off')
animate = lambda i: patch.set_data(frames[i])
ani = matplotlib.animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval = 50)
[gif]でファイル出力する場合
ani.save("anim.gif", writer = 'imagemagick')
[mp4]でファイル出力する場合
ani.save('anim.mp4', writer="ffmpeg")

マリオのプレイ動画

なんとなく動いたけど、高い土管がこえられない。。。( ノД`)シクシク…
youtu.be