Skip to content

Commit

Permalink
quite some small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
pierric committed Jan 30, 2025
1 parent 5c5331b commit 12f857a
Show file tree
Hide file tree
Showing 12 changed files with 933 additions and 560 deletions.
1,246 changes: 740 additions & 506 deletions notebooks/visualize_mcts.ipynb

Large diffs are not rendered by default.

58 changes: 55 additions & 3 deletions py/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from torch.utils.data import Dataset, ConcatDataset


import libsmartchess

Expand All @@ -30,9 +31,11 @@ def _get_outcome(res):
def _prepare(boards, meta, dist, outcome):
inp = np.concatenate((boards, meta), axis=-1).astype(np.float32)
inp = inp.transpose((2, 0, 1))

# turn = meta[0, 0, 0]
# if turn == 0:
# outcome = -outcome

return (
torch.from_numpy(inp),
torch.from_numpy(dist),
Expand All @@ -42,8 +45,16 @@ def _prepare(boards, meta, dist, outcome):

class ChessDataset(Dataset):
def __init__(self, trace_file, apply_mirror=False):
with open(trace_file, "r") as f:
trace = json.load(f)
if isinstance(trace_file, dict):
trace = trace_file

elif isinstance(trace_file, io.TextIOBase):
trace = json.load(trace_file)

else:
assert isinstance(trace_file, str)
with open(trace_file, "r") as f:
trace = json.load(f)

self.outcome = _get_outcome(trace["outcome"])

Expand All @@ -70,6 +81,47 @@ def __getitem__(self, idx):


class ValidationDataset(Dataset):
    """Validation set built from recorded games listed in a CSV file.

    Each CSV row is converted into a trace dict (see ``_to_trace``) and wrapped
    in a ``ChessDataset``; the per-game datasets are then concatenated so that
    samples can be indexed across all games.

    Args:
        csv_file: path (or anything ``pandas.read_csv`` accepts) of a CSV that
            has a ``moves`` column and a ``winner`` column.
        max_games: number of leading rows to load. Defaults to 10, the cap that
            used to be hard-coded; pass ``None`` to load every row.
    """

    def __init__(self, csv_file, max_games=10):
        # ``iloc[:None]`` selects all rows, so ``max_games=None`` lifts the cap.
        df = pd.read_csv(csv_file).iloc[:max_games]
        traces = [self._to_trace(m, w) for m, w in zip(df.moves, df.winner)]
        self.dataset = ConcatDataset([ChessDataset(t) for t in traces])

    def __len__(self):
        """Total number of samples over all loaded games."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the ``idx``-th sample of the concatenated dataset."""
        return self.dataset[idx]

    def _to_trace(self, pgn, winner):
        """Convert one game into the trace dict consumed by ``ChessDataset``.

        Args:
            pgn: move text of the game, parsed with ``chess.pgn.read_game``.
            winner: one of ``"white"``, ``"black"`` or ``"draw"``; any other
                value raises ``KeyError``.

        Returns:
            ``{"outcome": {"winner": ...}, "steps": [...]}`` where every step is
            ``(played_uci, 0.0, [(uci, count, 0.0), ...])``. ``count`` is 1 for
            the move actually played and 0 for every other legal move.
        """
        node = chess.pgn.read_game(io.StringIO(pgn))
        steps = []

        while node:
            # Legal moves are taken from the position *before* the move is played.
            num_acts = {m: 0 for m in node.board().legal_moves}
            node = node.next()

            if node is None:
                break

            move = node.move
            num_acts[move] = 1
            steps.append(
                (move.uci(), 0.0, [(m.uci(), c, 0.0) for m, c in num_acts.items()])
            )

        outcome = {
            "white": "White",
            "black": "Black",
            "draw": None,
        }

        return {
            "outcome": {"winner": outcome[winner]},
            "steps": steps,
        }


class ValidationDataset2(Dataset):
def __init__(self, csv_file):
df = pd.read_csv(csv_file).iloc[:10]
self.plays = [self._encode(p) for p in df.moves]
Expand Down
18 changes: 11 additions & 7 deletions py/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self):
# 8 boards (14 channels each) + meta (7 channels)
self.conv_block = torch.nn.Sequential(
torch.nn.Conv2d(
14 * 8 + 7, 256, kernel_size=3, stride=1, padding=1, bias=False
14 * 8 + 7, 256, kernel_size=5, stride=1, padding=2, bias=False
),
torch.nn.BatchNorm2d(256),
torch.nn.ReLU(inplace=False),
Expand All @@ -52,20 +52,24 @@ def __init__(self):
# torch.nn.ReLU(inplace=False),
# torch.nn.Linear(64, 1),
# torch.nn.Tanh(),
torch.nn.Conv2d(256, 32, kernel_size=1, bias=False),
torch.nn.BatchNorm2d(32),
torch.nn.Conv2d(256, 512, kernel_size=1, bias=False),
torch.nn.BatchNorm2d(512),
torch.nn.Dropout2d(p=0.5),
torch.nn.AdaptiveAvgPool2d((1, 1)),
torch.nn.Flatten(),
torch.nn.Linear(32 * 8 * 8, 128),
torch.nn.ReLU(inplace=True),
torch.nn.Linear(128, 1),
# torch.nn.Dropout(p=0.5),
torch.nn.Linear(512, 1),
# torch.nn.ReLU(inplace=True),
# torch.nn.Linear(128, 1),
torch.nn.Tanh(),
)

self.policy_head = torch.nn.Sequential(
torch.nn.Conv2d(256, 128, kernel_size=1, bias=False),
torch.nn.BatchNorm2d(128),
torch.nn.ReLU(inplace=True),
torch.nn.Flatten(),
# torch.nn.Dropout(p=0.5),
torch.nn.ReLU(inplace=True),
torch.nn.Linear(8 * 8 * 128, 8 * 8 * 73),
)

Expand Down
2 changes: 2 additions & 0 deletions rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[toolchain]
channel = "nightly"
49 changes: 49 additions & 0 deletions scripts/sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import os
import argparse
import pandas as pd


def main():
    """Print a sample of trace files: all from the latest run plus some older ones.

    Command line:
        -l/--list   text file with one trace path per line; each path must
                    contain a ``runs/<run_id>`` component pair.
        -r/--ratio  number of older files to draw, as a multiple of the number
                    of files in the latest run (default 1.0, must be > 0).

    The selected paths are written to stdout, sampled older files first.

    Raises:
        ValueError: if a path has no ``runs/<integer>`` component, or if there
            are fewer older files than the requested sample size.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-l", "--list", required=True)
    parser.add_argument("-r", "--ratio", default=1.0, type=float)
    args = parser.parse_args()

    # Explicit check instead of `assert`, which disappears under `python -O`.
    if args.ratio <= 0:
        parser.error("--ratio must be positive")

    with open(args.list, "r") as fp:
        files = [line.strip() for line in fp if line.strip()]

    if not files:
        # Nothing listed, nothing to sample; an empty DataFrame would have no
        # `run_id` column to work with.
        return

    collection = []

    for f in files:
        comps = f.split(os.sep)
        try:
            # list.index raises ValueError when "runs" is absent; the run id is
            # the path component right after it.
            run_id = int(comps[comps.index("runs") + 1])
        except (ValueError, IndexError):
            raise ValueError(
                f"cannot find a 'runs{os.sep}<run_id>' component in path: {f}"
            ) from None
        collection.append({"run_id": run_id, "file": f})

    df = pd.DataFrame(collection)

    latest_run_id = df.run_id.max()
    older_runs = df[df.run_id < latest_run_id]
    latest_runs = df[df.run_id == latest_run_id]

    # The requirement on older runs depends on the ratio, not just on the size
    # of the latest run: sampling without replacement needs n <= population.
    n_older = int(args.ratio * len(latest_runs))
    if n_older > len(older_runs):
        raise ValueError(
            f"need at least {n_older} older files to sample from, "
            f"but only {len(older_runs)} are available."
        )

    selection = pd.concat((older_runs.sample(n=n_older), latest_runs))

    for f in selection.file:
        print(f)


if __name__ == "__main__":
main()
50 changes: 25 additions & 25 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,24 +112,24 @@ def configure_optimizers(self):
optimizer = torch.optim.SGD(
self.parameters(), lr=self.config["lr"], momentum=0.9, weight_decay=3e-5
)
return optimizer
# return optimizer

# from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.lr_scheduler import OneCycleLR

# scheduler = OneCycleLR(
# optimizer=optimizer,
# max_lr=self.config["lr"],
# total_steps=self.config["steps_per_epoch"] * self.config["epochs"],
# )
scheduler = OneCycleLR(
optimizer=optimizer,
max_lr=self.config["lr"],
total_steps=self.config["steps_per_epoch"] * self.config["epochs"],
)

# return {
# "optimizer": optimizer,
# "lr_scheduler": {
# "scheduler": scheduler,
# "interval": "step",
# "frequency": 1,
# },
# }
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": "step",
"frequency": 1,
},
}


class ModelCheckpointAtEpochEnd(Callback):
Expand All @@ -140,19 +140,19 @@ def __init__(self, start, end, interval, model_compiled=True):
self.interval = interval
self.model_compiled = model_compiled

def on_train_epoch_end(self, trainer, pl_module):
epoch = trainer.current_epoch
def on_validation_end(self, trainer, pl_module):
step = trainer.global_step
callback_metrics = trainer.callback_metrics

if epoch < self.start or epoch >= self.end or (epoch + 1) % self.interval != 0:
if step < self.start or step >= self.end or (step + 1) % self.interval != 0:
return

val_loss1 = callback_metrics["val_syn_loss1/dataloader_idx_0"].item()
val_loss2 = callback_metrics["val_syn_loss2/dataloader_idx_0"].item()

path = os.path.join(
trainer.log_dir,
f"epoch:{epoch}-{val_loss1:0.3f}-{val_loss2:0.3f}.ckpt",
trainer.log_dir or ".",
f"step:{step}-{val_loss1:0.3f}-{val_loss2:0.3f}.ckpt",
)
print("saving checkpoint: ", path)

Expand Down Expand Up @@ -189,7 +189,7 @@ def main():
parser.add_argument("-c", "--last-ckpt", type=str)
parser.add_argument("-l", "--lr", type=float, default=1e-4)
parser.add_argument("-w", "--loss-weight", type=float, default=0.002)
parser.add_argument("--save-every-k", type=int, default=10)
parser.add_argument("--save-every-k", type=int, default=1)
parser.add_argument("--save-start", type=int, default=10)
parser.add_argument("--save-end", type=int, default=100)
parser.add_argument("--model-conf", type=str, default=None)
Expand Down Expand Up @@ -270,18 +270,18 @@ def main():
logger=logger,
callbacks=lightning_checkpoints,
max_epochs=config["epochs"],
log_every_n_steps=5,
log_every_n_steps=2,
# precision="16-mixed",
# val_check_interval=20,
val_check_interval=5,
)
trainer.fit(
model=module,
train_dataloaders=train_loader,
val_dataloaders=[val_syn, val_real],
)

# m = module.model._orig_mod if compile_model else module.model
# torch.save(m.state_dict(), os.path.join(trainer.log_dir, "last.ckpt"))
m = module.model._orig_mod if compile_model else module.model
torch.save(m.state_dict(), os.path.join(trainer.log_dir, "last.ckpt"))


if __name__ == "__main__":
Expand Down
12 changes: 12 additions & 0 deletions src/chess.rs
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,18 @@ impl BoardState {
})
}

/// Returns the side to move, as reported by the wrapped Python object.
///
/// Reads the object's `turn` attribute, extracts it as an `i32` and converts
/// that integer into a [`Color`].
///
/// # Panics
///
/// Panics if the attribute lookup fails or the value cannot be extracted as
/// an `i32`.
#[allow(dead_code)]
pub fn turn(&self) -> Color {
    Python::with_gil(|py| {
        let turn_attr = self
            .python_object
            .getattr(py, intern!(py, "turn"))
            .unwrap();
        let raw: i32 = turn_attr.extract(py).unwrap();
        raw.into()
    })
}

pub fn to_board(&self) -> Board {
Python::with_gil(|py| self.python_object.extract(py).unwrap())
}
Expand Down
21 changes: 12 additions & 9 deletions src/game.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ fn call_ts_model(
meta: Array3<i32>,
turn: Color,
steps: &Vec<Move>,
return_full_distr: bool,
) -> (Vec<f32>, f32) {
use tch::Tensor;

Expand Down Expand Up @@ -292,6 +293,11 @@ fn call_ts_model(

let full_distr = full_distr.to_dtype(tch::Kind::Float, true, false);

if return_full_distr {
let moves_distr = Vec::<f32>::try_from(full_distr.squeeze().exp()).unwrap();
return (moves_distr, score)
}

// rotate if the next move is black
let encoded_moves: Vec<i64> = steps
.iter()
Expand All @@ -309,18 +315,12 @@ fn call_ts_model(
(moves_distr, score)
}

#[cached(
type = "SizedCache<(bool, Vec<Move>, Color), (Vec<(Option<Move>, Color)>, Vec<f32>, f32)>",
create = "{ SizedCache::with_size(10000) }",
convert = r#"{
(argmax, state.move_stack(), node.borrow().step.1)
}"#
)]
fn _chess_ts_predict(
pub fn _chess_ts_predict(
chess: &ChessTS,
node: &ArcRefNode<(Option<Move>, Color)>,
state: &BoardState,
argmax: bool,
return_full_distr: bool,
) -> (Vec<(Option<Move>, Color)>, Vec<f32>, f32) {
let legal_moves = state.legal_moves();

Expand All @@ -336,13 +336,16 @@ fn _chess_ts_predict(
let (encoded_boards, encoded_meta) = _encode(&node, state);
let turn = node.borrow().step.1;

assert!(turn == state.turn());

let (moves_distr, score) = call_ts_model(
&chess.model,
chess.device,
encoded_boards,
encoded_meta,
turn,
&legal_moves,
return_full_distr,
);

let moves_distr = _post_process_distr(moves_distr, argmax);
Expand All @@ -361,7 +364,7 @@ impl Game<BoardState> for ChessTS {
state: &BoardState,
argmax: bool,
) -> (Vec<<BoardState as State>::Step>, Vec<f32>, f32) {
_chess_ts_predict(self, node, state, argmax)
_chess_ts_predict(self, node, state, argmax, false)
}

fn reverse_q(&self, node: &ArcRefNode<<BoardState as State>::Step>) -> bool {
Expand Down
14 changes: 14 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,19 @@ fn play_dump_search_tree(state: Py<PyCapsule>) -> PyResult<PyObject> {
})
}

// Runs one model prediction for the engine state stored in `state` and hands
// the result back to Python as `(steps, prior, outcome)`.
//
// `full_distr` is forwarded to `game::_chess_ts_predict` as its
// `return_full_distr` argument; `argmax` is always passed as `false`.
// (Plain `//` comments on purpose: `///` would become the Python docstring.)
#[pyfunction]
fn play_inference(
    state: Py<PyCapsule>,
    full_distr: bool,
) -> PyResult<(PyObject, PyObject, PyObject)> {
    Python::with_gil(|py| {
        // SAFETY: assumes the capsule was created around a
        // `RefCell<ChessEngineState>` and is still alive — TODO confirm
        // against the function that builds the capsule.
        let engine = unsafe {
            state
                .bind(py)
                .reference::<RefCell<ChessEngineState>>()
                .borrow()
        };

        let (steps, prior, outcome) = game::_chess_ts_predict(
            &engine.chess,
            &engine.cursor.arc(),
            &engine.board,
            false,
            full_distr,
        );

        // Converted in the same order as before: steps, prior, outcome.
        Ok((
            steps.into_pyobject(py)?.unbind().into_any(),
            prior.into_pyobject(py)?.unbind().into_any(),
            outcome.into_pyobject(py)?.unbind().into_any(),
        ))
    })
}

/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
Expand All @@ -238,5 +251,6 @@ fn libsmartchess(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(play_step, m)?)?;
m.add_function(wrap_pyfunction!(play_inspect, m)?)?;
m.add_function(wrap_pyfunction!(play_dump_search_tree, m)?)?;
m.add_function(wrap_pyfunction!(play_inference, m)?)?;
Ok(())
}
Loading

0 comments on commit 12f857a

Please sign in to comment.