RustでAI その2
概要
前回 の続き,だいぶ間があいちゃったのでとりあえず雑だけど投稿してみる
数手先読み
- 数手先読みする
- SEARCH_DEPTHを定義
- 1.8.0ではRefCellにOrdが定義されていないのでRefNode(RefCell<Node>)として定義
- RefNodeはタプル構造体なので,borrow, borrow_mut, into_innerが呼び出されたときに中身のRefCellの同名メソッドをそのまま呼ぶ関数を定義
#[derive(Eq, PartialEq, Debug)]
struct RefNode(RefCell<Node>)
/// Partial ordering for `RefNode`, defined in terms of its total ordering.
///
/// Delegating to `Ord::cmp` keeps the two orderings from ever disagreeing; the comparison
/// operators (`<`, `<=`, `>`, `>=`) come from the trait's default methods.
///
/// # Panics
/// Comparing borrows both wrapped cells, so it panics if either one is mutably borrowed.
impl PartialOrd for RefNode {
    #[inline]
    fn partial_cmp(&self, other: &RefNode) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// Total ordering for `RefNode`: the ordering of the wrapped `Node`s.
impl Ord for RefNode {
    #[inline]
    fn cmp(&self, other: &RefNode) -> Ordering {
        // Both cells are only borrowed immutably for the duration of the comparison.
        let (lhs, rhs) = (self.0.borrow(), other.0.borrow());
        Ord::cmp(&*lhs, &*rhs)
    }
}
impl RefNode {
fn borrow(&self) -> Ref<Node> {
self.0.borrow()
}
fn borrow_mut(&self) -> RefMut<Node> {
self.0.borrow_mut()
}
fn into_inner(self) -> Node {
self.0.into_inner()
}
}
定義できたら,元記事を参考に幅優先探索を実装
探索後,一番スコアが高いノードの親ノードまでトラバース.この時点の親ノードは,Rc<RefNode>となっているので,Rc::try_unwrapでRcをはがしてからRefNodeの中身をinto_innerで取り出す
search関数はこんな感じ
/// Breadth-first look-ahead without pruning.
///
/// Starting from `self.head`, every node of the current level is expanded once per entry of
/// `DX`/`DY`/`OUTPUT`, for `SEARCH_DEPTH` levels, so each level is `DX.len()` times larger
/// than the previous one. The best leaf (top of the max-heap) is then followed back through
/// its `parent` links, and the node directly below the head on that path is returned by value.
///
/// # Panics
/// Panics if `self.head` is `None`, if the final heap is empty, if the walk towards the head
/// reaches a node without a parent, or if the returned node is still shared when unwrapped.
fn search(&mut self) -> Node {
let mut heap = BinaryHeap::new();
heap.push(self.head.clone().unwrap());
for _ in 0..SEARCH_DEPTH {
// Copy the current level into `tmp`, then empty `heap` so it can collect the next level.
let mut tmp = heap.clone();
heap.clear();
while let Some(current_node) = tmp.pop() {
// One child per direction; index `i` selects the matching DX/DY/OUTPUT entries.
for i in 0..DX.len() {
let next_node = Node::new(Some(current_node.clone()));
next_node.borrow_mut().players[self.p].x1 += DX[i];
next_node.borrow_mut().players[self.p].y1 += DY[i];
next_node.borrow_mut().output = OUTPUT[i].to_string();
// The child starts with its parent's score; the evaluation of this move is added on top.
let score = next_node.borrow_mut().eval_player(self.p);
next_node.borrow_mut().score += score;
heap.push(next_node.clone());
}
}
}
// The top of the heap is the best leaf found by the search.
let mut node = heap.pop().unwrap();
// Walk up the parent links until the parent is the head; `node` is then the head's child.
loop {
let next_node;
// `take()` removes the parent link, so nodes on this path stop holding their ancestors.
match node.borrow_mut().parent.take() {
Some(parent) => {
next_node = parent;
}
None => { panic!(); }
}
if next_node == self.head.clone().unwrap() {
break;
}
node = next_node;
}
// Drop the remaining leaves (and the ancestors only they keep alive) so that `node` is no
// longer shared; `Rc::try_unwrap` below only succeeds for a uniquely owned value.
heap.clear();
Rc::try_unwrap(node).ok().unwrap().into_inner()
}
ビームサーチ版
ヒープをビーム幅に縮小する以外は同じ
/// Beam-search look-ahead.
///
/// Identical to the breadth-first version except that only the `BEAM_WIDTH` best nodes of a
/// level are expanded. Returns, by value, the child of `self.head` that lies on the path to
/// the best leaf after `SEARCH_DEPTH` levels.
///
/// # Panics
/// Panics if `self.head` is `None`, if the final frontier is empty, if the walk towards the
/// head reaches a node without a parent, or if the returned node is still shared.
fn search(&mut self) -> Node {
    let root = self.head.clone().unwrap();
    let mut frontier = BinaryHeap::new();
    frontier.push(root.clone());

    for _ in 0..SEARCH_DEPTH {
        // Move at most BEAM_WIDTH of the best candidates into the beam, discard the rest.
        let mut beam = BinaryHeap::new();
        while beam.len() < BEAM_WIDTH {
            match frontier.pop() {
                Some(candidate) => beam.push(candidate),
                None => break,
            }
        }
        frontier.clear();

        // Expand every beam member once per direction.
        while let Some(current) = beam.pop() {
            for dir in 0..DX.len() {
                let child = Node::new(Some(current.clone()));
                {
                    let mut state = child.borrow_mut();
                    state.players[self.p].x1 += DX[dir];
                    state.players[self.p].y1 += DY[dir];
                    state.output = OUTPUT[dir].to_string();
                    let delta = state.eval_player(self.p);
                    state.score += delta;
                }
                frontier.push(child);
            }
        }
    }

    // Best leaf first, then climb until the parent is the root.
    let mut best = frontier.pop().unwrap();
    loop {
        let parent = match best.borrow_mut().parent.take() {
            Some(link) => link,
            None => panic!(),
        };
        if parent == root {
            break;
        }
        best = parent;
    }

    // Release every other leaf so that `best` is uniquely owned before unwrapping it.
    frontier.clear();
    Rc::try_unwrap(best).ok().unwrap().into_inner()
}
最終版
use std::io;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::rc::Rc;
use std::cell::{Ref, RefMut, RefCell};
/// Writes a formatted line to standard error, accepting the same arguments as `println!`.
/// Write failures are ignored.
macro_rules! print_err {
    ($($arg:tt)*) => {{
        use std::io::Write;
        let _ = writeln!(::std::io::stderr(), $($arg)*);
    }};
}
/// Trims `$x` and parses it as the type named by `$t`, panicking if the text is not valid.
macro_rules! parse_input {
    ($x:expr, $t:ident) => {
        <$t as ::std::str::FromStr>::from_str($x.trim()).unwrap()
    };
}
// Parallel tables: index `i` gives the x offset, the y offset and the command text of one move.
const DX: [i32; 4] = [1, 0, -1, 0];
const DY: [i32; 4] = [0, 1, 0, -1];
const OUTPUT: [&'static str; 4] = ["RIGHT", "DOWN", "LEFT", "UP"];
// const MAX_PLAYER_NUM: i32 = 4;
// Grid size: valid x values are 0..COL, valid y values are 0..ROW.
const COL: i32 = 30;
const ROW: i32 = 20;
// Number of levels the search expands below the current state.
const SEARCH_DEPTH: usize = 100;
// Maximum number of nodes of one level that the beam search expands.
const BEAM_WIDTH: usize = 20;
/// Newtype around `RefCell<Node>`; it exists so that `Ord`/`PartialOrd` can be implemented
/// for shared, mutable search nodes.
#[derive(Eq, PartialEq, Debug)]
struct RefNode(RefCell<Node>);
/// Optional shared reference to a node; `None` means "no node".
type Link = Option<Rc<RefNode>>;
/// State kept between turns.
struct Game {
/// Most recent state built by `input`; the node the search starts from. `None` before the first turn.
head: Link,
/// Number of players, parsed from the first input line of the first turn.
n: usize,
/// Index into `players` of the player whose moves the search explores.
p: usize,
}
/// One state in the search tree. Nodes are ordered by `score` only.
#[derive(Eq, PartialEq, Debug, Clone)]
struct Node {
/// Accumulated evaluation: copied from the parent on creation, then adjusted by the search.
score: i32,
/// Command text of the move that produced this node; empty for nodes not created by a move.
output: String,
/// Node this one was created from, if any.
parent: Link,
/// Per-player state, indexed by player number.
players: Vec<Player>,
}
/// Position data and blocked-cell map of one player.
#[derive(Eq, PartialEq, Debug, Clone)]
struct Player {
/// First coordinate pair of the player's input line (presumably the starting cell - confirm
/// against the game's input specification).
x0: i32,
y0: i32,
/// Second coordinate pair of the input line; this is the position the search moves.
x1: i32,
y1: i32,
/// Cells this player may not enter, indexed as `locked_field[x][y]`.
locked_field: [[bool; ROW as usize]; COL as usize],
}
/// Nodes are totally ordered by `score`; every other field is ignored.
impl Ord for Node {
    fn cmp(&self, other: &Self) -> Ordering {
        let (mine, theirs) = (self.score, other.score);
        mine.cmp(&theirs)
    }
}
/// Partial ordering of nodes; always `Some`, and identical to the total ordering by `score`.
impl PartialOrd for Node {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Partial ordering for `RefNode`, defined in terms of its total ordering.
///
/// `partial_cmp` delegates to `Ord::cmp`, so the two orderings cannot drift apart, and the
/// comparison operators are supplied by the trait's default methods.
///
/// # Panics
/// Comparing borrows both wrapped cells, so it panics if either one is mutably borrowed.
impl PartialOrd for RefNode {
    #[inline]
    fn partial_cmp(&self, other: &RefNode) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// Total ordering for `RefNode`: compares the wrapped `Node`s.
impl Ord for RefNode {
    #[inline]
    fn cmp(&self, other: &RefNode) -> Ordering {
        let this = self.0.borrow();
        let that = other.0.borrow();
        Ord::cmp(&*this, &*that)
    }
}
impl RefNode {
fn borrow(&self) -> Ref<Node> {
self.0.borrow()
}
fn borrow_mut(&self) -> RefMut<Node> {
self.0.borrow_mut()
}
fn into_inner(self) -> Node {
self.0.into_inner()
}
}
impl Game {
fn new() -> Self {
Game {
head: None,
n: 0,
p: 0,
}
}
fn input(&mut self) {
let new_node = Node::new(self.head.clone());
// input game data
let mut input_line = String::new();
io::stdin().read_line(&mut input_line).unwrap();
if self.head.is_none() {
// initialize game data & create new players
let inputs = input_line.split(" ").collect::<Vec<_>>();
self.n = parse_input!(inputs[0], usize);
self.p = parse_input!(inputs[1], usize);
new_node.borrow_mut().players = vec![Player::new(); self.n as usize];
}
// input player data
for i in 0..self.n as usize {
let mut input_line = String::new();
io::stdin().read_line(&mut input_line).unwrap();
let inputs = input_line.split(" ").collect::<Vec<_>>();
let ref mut player = new_node.borrow_mut().players[i];
player.x0 = parse_input!(inputs[0], i32);
player.y0 = parse_input!(inputs[1], i32);
player.x1 = parse_input!(inputs[2], i32);
player.y1 = parse_input!(inputs[3], i32);
}
for i in 0..self.n as usize {
let x0 = new_node.borrow().players[i].x0;
let y0 = new_node.borrow().players[i].y0;
let x1 = new_node.borrow().players[i].x1;
let y1 = new_node.borrow().players[i].y1;
for j in 0..self.n as usize {
new_node.borrow_mut().players[j].locked_field[x0 as usize][y0 as usize] = true;
new_node.borrow_mut().players[j].locked_field[x1 as usize][y1 as usize] = true;
}
}
self.head = Some(new_node);
}
fn search(&mut self) -> Node {
let mut heap = BinaryHeap::new();
heap.push(self.head.clone().unwrap());
for _ in 0..SEARCH_DEPTH {
let mut tmp = BinaryHeap::new();
for _ in 0..BEAM_WIDTH {
if let Some(data) = heap.pop() {
tmp.push(data);
} else {
break;
}
}
heap.clear();
while let Some(current_node) = tmp.pop() {
for i in 0..DX.len() {
let next_node = Node::new(Some(current_node.clone()));
next_node.borrow_mut().players[self.p].x1 += DX[i];
next_node.borrow_mut().players[self.p].y1 += DY[i];
next_node.borrow_mut().output = OUTPUT[i].to_string();
let score = next_node.borrow_mut().eval_player(self.p);
next_node.borrow_mut().score += score;
heap.push(next_node.clone());
}
}
}
// the top of heap is the best result after search
let mut node = heap.pop().unwrap();
// traverse until the next of head
loop {
let next_node;
match node.borrow_mut().parent.take() {
Some(parent) => {
next_node = parent;
}
None => { panic!(); }
}
if next_node == self.head.clone().unwrap() {
break;
}
node = next_node;
}
heap.clear();
Rc::try_unwrap(node).ok().unwrap().into_inner()
}
}
impl Node {
    /// Allocates a new shared node.
    ///
    /// With a parent, the node copies the parent's `score` and `players` and links back to
    /// it; without one, it starts with score 0 and no players. `output` is always empty.
    fn new(parent: Option<Rc<RefNode>>) -> Rc<RefNode> {
        let node = match parent {
            // create child
            Some(origin) => {
                let (inherited_score, inherited_players) = {
                    let source = origin.borrow();
                    (source.score, source.players.clone())
                };
                Node {
                    score: inherited_score,
                    output: String::new(),
                    parent: Some(origin),
                    players: inherited_players,
                }
            }
            // create the first
            None => Node {
                score: 0,
                output: String::new(),
                parent: None,
                players: vec![],
            },
        };
        Rc::new(RefNode(RefCell::new(node)))
    }

    /// Evaluates the current `(x1, y1)` of player `p`.
    ///
    /// Returns 0 and locks the cell when the player can move there and the score is not
    /// negative; returns -1 otherwise and leaves the field untouched.
    fn eval_player(&mut self, p: usize) -> i32 {
        let (x, y) = (self.players[p].x1, self.players[p].y1);
        let still_alive = self.players[p].can_move(x, y) && self.score >= 0;
        if !still_alive {
            return -1;
        }
        self.players[p].locked_field[x as usize][y as usize] = true;
        0
    }
}
impl Player {
    /// Creates a player with all coordinates set to -1 and no locked cells.
    fn new() -> Self {
        let empty_field = [[false; ROW as usize]; COL as usize];
        Player {
            x0: -1,
            y0: -1,
            x1: -1,
            y1: -1,
            locked_field: empty_field,
        }
    }

    /// Returns `true` when `(x, y)` lies inside the grid and the cell is not locked.
    fn can_move(&self, x: i32, y: i32) -> bool {
        let inside_grid = x >= 0 && x < COL && y >= 0 && y < ROW;
        // The field is only indexed after the range check has passed.
        inside_grid && !self.locked_field[x as usize][y as usize]
    }
}
/// Turn loop: read the current state, search, print the chosen move; repeats forever.
fn main() {
    let mut game = Game::new();
    loop {
        game.input();
        println!("{}", game.search().output);
    }
}
感想
コンパイラがきちんと指摘してくれるの,それはそうなんだけど,アルゴリズムの部分とかが間違ってると結局正常には動かなくて逆にうるさく指摘される分そちらに注力してしまってアルゴリズムに集中できないような気がする.ただコンパイラのチェックを通るとその部分に対しては変なバグが発生することがなくて安心.
まぁ慣れればそこで間違えることはなくなるんだろう(実際今回も書いてくうちにどうすればいいかなんとなくわかってミスは減った(気がする)). けどやっぱ学習曲線が険しいなぁと.
rubyのpryとか,Haskellのghciみたいに,型とかライフタイムとか借用周りとか今どうなってるのか見れるREPLが欲しい.
たぶんもっと慣れてる人は,そのへんをほとんど無意識に適用できるようになってる気がするので,そういう人達の知見をグラフィカルな形で初心者にもわかりやすく説明してくれるツールがほしいんじゃ
ライフタイムとか,所有権がどこに行くかとかどこにあるかとかアニメーションで表示したりできないのかなと.
まとめると,Rust書いてて面白いし安全に極振りしてるの重要視されないかもだけど実は大事なのでもっとメジャーになるといいなと思いました.