Skip to content

Commit

Permalink
interface to python for step-by-step play
Browse files Browse the repository at this point in the history
  • Loading branch information
pierric committed Jan 10, 2025
1 parent 6c18d24 commit 63fa005
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 24 deletions.
2 changes: 1 addition & 1 deletion profiling/flamegraph-selfplay
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash

CKPT="${CKPT:-runs/001/tb_logs/chess/version_7/epoch:4-val_loss:4.325.pt}"
CKPT="${CKPT:-runs/000/last.pt}"

cargo flamegraph --bin smartchess -- -t trace.json -c ${CKPT} \
--temperature 1.0 --cpuct 2.5 --rollout-num 200 --num-steps 50
6 changes: 3 additions & 3 deletions src/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ fn debug_step(chess: game::Chess, filename: &str, target_step: usize) {
let steps = trace["steps"].as_array().unwrap();

let mut state = chess::BoardState::new();
let mut cursor = mcts::Cursor::new(mcts::Node {
let (mut cursor, _root) = mcts::Cursor::new(mcts::Node {
step: (None, chess::Color::White),
depth: 0,
q_value: 0.,
Expand Down Expand Up @@ -124,7 +124,7 @@ fn debug_trace(chess: game::Chess, filename: &str, target_step: usize, args: &Ar
let steps = trace["steps"].as_array().unwrap();

let mut state = chess::BoardState::new();
let mut cursor = mcts::Cursor::new(mcts::Node {
let (mut cursor, _root) = mcts::Cursor::new(mcts::Node {
step: (None, chess::Color::White),
depth: 0,
q_value: 0.,
Expand Down Expand Up @@ -270,7 +270,7 @@ fn bench_to_board() {
let steps = trace["steps"].as_array().unwrap();

let mut turn = chess::Color::White;
let mut cursor = mcts::Cursor::new(mcts::Node {
let (mut cursor, _root) = mcts::Cursor::new(mcts::Node {
step: (None, turn),
depth: 0,
q_value: 0.,
Expand Down
77 changes: 70 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,29 +113,89 @@ fn encode_board(view: chess::Color, board: chess::Board) -> PyResult<(PyObject,
}

#[allow(dead_code)]
struct ChessEngineState(game::ChessTS, chess::BoardState, mcts::Cursor<<chess::BoardState as game::State>::Step>);
/// Everything a step-by-step play session needs. `play_new` builds one of these
/// and hands it to Python inside a `PyCapsule`; the other `play_*` functions
/// read it back out of that capsule.
struct ChessEngineState {
// Game wrapper holding the model loaded through `tch::CModule` and its device.
chess: game::ChessTS,
// Board position; initialised with `chess::BoardState::new()` in `play_new`.
board: chess::BoardState,
// Root handle of the search tree as returned by `mcts::Cursor::new`;
// this is what `play_dump_search_tree` serialises.
root: mcts::ArcRefNode<<chess::BoardState as game::State>::Step>,
// Cursor into the search tree; used by `play_mcts` and `play_inspect`.
cursor: mcts::Cursor<<chess::BoardState as game::State>::Step>,
}


// SAFETY (NOTE: asserted by hand, not enforced): the cursor part carries mutable
// data, so this type cannot be `Send` safely without adding locks. `Send` is
// faked here only so that the state can be wrapped into a `PyCapsule`.
// NOTE(review): the `play_*` functions in this file only touch the state inside
// `Python::with_gil`; presumably that is what keeps access serialised — confirm
// that no other thread ever reaches the capsule contents.
unsafe impl Send for ChessEngineState {}

#[pyfunction]
fn new_chess_engine_state(checkpoint: &str) -> PyResult<Py<PyCapsule>> {
/// Creates a fresh engine state for step-by-step play and returns it to Python
/// wrapped in a `PyCapsule`.
///
/// `checkpoint` is the path handed to `tch::CModule::load_on_device`.
///
/// # Panics
/// The results of loading the checkpoint and of creating the capsule are both
/// unwrapped, so a failure in either step panics instead of returning an error.
fn play_new(checkpoint: &str) -> PyResult<PyObject> {
// The model is always loaded onto CUDA device 0.
let device = tch::Device::Cuda(0);
let chess = game::ChessTS {
model: tch::CModule::load_on_device(checkpoint, device).unwrap(),
device: device,
};

// Fresh board plus a root node with no move attached, zero depth and zeroed statistics.
let board = chess::BoardState::new();
let cursor = mcts::Cursor::new(mcts::Node {
let (cursor, root) = mcts::Cursor::new(mcts::Node {
step: (None::<chess::Move>, chess::Color::White),
depth: 0,
q_value: 0.,
num_act: 0,
parent: None,
children: Vec::new(),
});
// The root handle is stored next to the cursor so it lives as long as the state.
let state = ChessEngineState {
chess: chess,
board: board,
root: root,
cursor: cursor,
};

Python::with_gil(|py| {
//PyCapsule::new(py, ChessEngineState(chess, board, cursor), None).map(Bound::unbind)
});
panic!("")
// The capsule is created without a name (`None`).
let capsule = PyCapsule::new(py, state, None).unwrap();
Ok(capsule.unbind().into_any())
})
}

/// Runs `mcts::mcts` on the engine state stored in `state`.
///
/// The capsule's game, the cursor's node handle and the board are passed
/// through unchanged; `rollout` and `cpuct` are forwarded as given (`cpuct`
/// wrapped in `Some`). Nothing is returned to Python.
#[pyfunction]
fn play_mcts(state: Py<PyCapsule>, rollout: i32, cpuct: f32) {
    Python::with_gil(|py| {
        let capsule = state.bind(py);
        // SAFETY: assumes the capsule was produced by `play_new` and therefore
        // holds a `ChessEngineState` — this is not checked at runtime.
        let engine = unsafe { capsule.reference::<ChessEngineState>() };
        let exploration = Some(cpuct);
        mcts::mcts(
            &engine.chess,
            engine.cursor.arc(),
            &engine.board,
            rollout,
            exploration,
        );
    })
}

/// Reports the statistics of the node the cursor currently points at.
///
/// Returns a pair of Python objects: the node's `q_value`, and a list with one
/// `(move, num_act, q_value)` tuple per child of that node.
///
/// # Panics
/// Panics if a child carries no move (`step.0` is `None`) or if converting
/// either value into a Python object fails.
#[pyfunction]
fn play_inspect(state: Py<PyCapsule>) -> PyResult<(PyObject, PyObject)> {
    Python::with_gil(|py| {
        let capsule = state.bind(py);
        // SAFETY: assumes the capsule was produced by `play_new` and therefore
        // holds a `ChessEngineState` — this is not checked at runtime.
        let engine = unsafe { capsule.reference::<ChessEngineState>() };
        let current = engine.cursor.current();
        let q_value = current.q_value;

        // One entry per child, in the order the children are stored.
        let mut child_stats: Vec<(chess::Move, i32, f32)> =
            Vec::with_capacity(current.children.len());
        for child in current.children.iter() {
            let child = child.borrow();
            child_stats.push((child.step.0.unwrap(), child.num_act, child.q_value));
        }

        let q_obj = q_value.into_pyobject(py).unwrap().unbind().into_any();
        let stats_obj = child_stats.into_pyobject(py).unwrap().unbind().into_any();
        Ok((q_obj, stats_obj))
    })
}

#[pyfunction]
fn play_dump_search_tree(state: Py<PyCapsule>) -> PyResult<PyObject> {
Python::with_gil(|py| {
let state = unsafe { state.bind(py).reference::<ChessEngineState>() };
let json = serde_json::to_string(&*state.root.borrow()).unwrap();
let json_module = py.import("json")?;
let ret = json_module.getattr("loads")?.call1((json,))?.unbind();
Ok(ret)
})
}

/// A Python module implemented in Rust. The name of this function must match
Expand All @@ -146,6 +206,9 @@ fn libsmartchess(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(encode_board, m)?)?;
m.add_function(wrap_pyfunction!(encode_move, m)?)?;
m.add_function(wrap_pyfunction!(encode_steps, m)?)?;
m.add_function(wrap_pyfunction!(new_chess_engine_state, m)?)?;
m.add_function(wrap_pyfunction!(play_new, m)?)?;
m.add_function(wrap_pyfunction!(play_mcts, m)?)?;
m.add_function(wrap_pyfunction!(play_inspect, m)?)?;
m.add_function(wrap_pyfunction!(play_dump_search_tree, m)?)?;
Ok(())
}
13 changes: 2 additions & 11 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
#![feature(get_mut_unchecked)]

use clap::Parser;
use std::sync::Arc;
use std::cell::RefCell;

mod chess;
mod game;
Expand Down Expand Up @@ -69,19 +65,14 @@ fn main() {

let mut state = chess::BoardState::new();

// Don't move this root! It is necessary to keep the root alive:
// since the children hold only a weak back reference, the parent might
// otherwise be recycled.
let root = Arc::new(RefCell::new(mcts::Node {
let (mut cursor, _root) = mcts::Cursor::new(mcts::Node {
step: (None, chess::Color::White),
depth: 0,
q_value: 0.,
num_act: 0,
parent: None,
children: Vec::new(),
}));

let mut cursor = mcts::Cursor::from_arc(root.clone());
});
let mut outcome = None;

for i in 0..args.num_steps {
Expand Down
42 changes: 42 additions & 0 deletions src/mcts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use crate::game::{Game, State};
use rand::distributions::WeightedIndex;
use rand::{thread_rng, Rng};
use rand_distr::{Dirichlet, Distribution};
use serde::{Serialize, Serializer};
use serde::ser::{SerializeStruct, SerializeSeq};
use std::iter::Sum;
use std::sync::{Arc, Weak};
use std::cell::{RefCell, Ref, RefMut};
Expand All @@ -18,6 +20,38 @@ pub struct Node<T> {
pub children: Vec<ArcRefNode<T>>,
}

// NOTE(review): hand-written `Send`. `Node` keeps its children as `ArcRefNode<T>`
// handles, which `Cursor::new` builds as `Arc<RefCell<..>>`; such handles are not
// `Send` on their own, so the compiler would not derive this impl. Nothing in
// this impl enforces exclusive access across threads — callers must guarantee it.
unsafe impl<T: Send> Send for Node<T> {}

pub struct ChildrenList<'a, T>(&'a Vec<ArcRefNode<T>>);

impl<'a, T: Serialize> Serialize for ChildrenList<'a, T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut seq = serializer.serialize_seq(Some(self.0.len()))?;
for elem in self.0 {
seq.serialize_element(&*elem.borrow())?;
}
seq.end()
}
}

impl<T: Serialize> Serialize for Node<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut node = serializer.serialize_struct("Node", 5)?;
node.serialize_field("step", &self.step)?;
node.serialize_field("depth", &self.depth)?;
node.serialize_field("q", &self.q_value)?;
node.serialize_field("num_act", &self.num_act)?;
node.serialize_field("children", &ChildrenList(&self.children))?;
node.end()
}
}

/// Handle onto a single node of the search tree; a thin wrapper around that
/// node's shared `ArcRefNode<T>`.
pub struct Cursor<T>(ArcRefNode<T>);

fn uct(
Expand Down Expand Up @@ -237,6 +271,14 @@ where
}

impl<T> Cursor<T> {
/// Wraps `data` in a shared handle and returns a cursor on it together with
/// a second strong handle to the same node.
///
/// Keep the returned handle alive for as long as the tree is in use: the
/// children hold only a weak back reference, so the parent might otherwise
/// be recycled.
pub fn new(data: Node<T>) -> (Self, ArcRefNode<T>) {
    let root = Arc::new(RefCell::new(data));
    let cursor = Cursor(Arc::clone(&root));
    (cursor, root)
}

pub fn from_arc(arc: ArcRefNode<T>) -> Self {
Cursor(arc)
}
Expand Down
4 changes: 2 additions & 2 deletions src/play.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ fn step(choice: usize, cursor: &mut MctsCursor, state: &mut chess::BoardState, t
.children
.iter()
.map(|n: &mcts::ArcRefNode<_>| {
let n: Ref<'_, mcts::Node<_>> = n.borrow();
let n: Ref<'_, _> = n.borrow();
(n.step.0.unwrap(), n.num_act, n.q_value)
})
.collect();
Expand Down Expand Up @@ -329,7 +329,7 @@ fn main() {

let mut trace = trace::Trace::new();
let mut state = chess::BoardState::new();
let mut cursor = mcts::Cursor::new(mcts::Node {
let (mut cursor, _root) = mcts::Cursor::new(mcts::Node {
step: (None, chess::Color::White),
depth: 0,
q_value: 0.,
Expand Down

0 comments on commit 63fa005

Please sign in to comment.