Skip to content

Commit

Permalink
various small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
pierric committed Feb 19, 2025
1 parent d065b68 commit f3c9678
Show file tree
Hide file tree
Showing 12 changed files with 425 additions and 3,013 deletions.
280 changes: 215 additions & 65 deletions notebooks/verify_dataset.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion py/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _get_outcome(res):


def _prepare(boards, meta, dist, outcome):
inp = np.concatenate((boards, meta), axis=-1).astype(np.float32)
inp = np.concatenate((boards.astype(np.float32), meta.astype(np.float32)), axis=-1)
inp = inp.transpose((2, 0, 1))

# turn = meta[0, 0, 0]
Expand Down
37 changes: 37 additions & 0 deletions py/debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import numpy as np
import chess


def decode_pieces(arr, color):
    """Collect the pieces of one colour encoded in a stack of occupancy planes.

    ``arr`` is indexed as ``arr[k]`` for ``k`` in 0..5, one plane per piece
    type in the order pawn, knight, bishop, rook, queen, king.  Every entry of
    a plane that equals 1 yields one piece; the first axis of the plane is
    used as the rank and the second axis as the file.

    Returns a list of ``(square, piece, promoted)`` tuples, where ``promoted``
    is always ``False``.
    """
    piece_order = (
        chess.PAWN,
        chess.KNIGHT,
        chess.BISHOP,
        chess.ROOK,
        chess.QUEEN,
        chess.KING,
    )

    decoded = []
    for plane_idx in range(len(piece_order)):
        # np.where on a 2-D mask returns (row indices, column indices).
        ranks, files = np.where(arr[plane_idx] == 1)
        occupied = [chess.square(f, r) for f, r in zip(files, ranks)]
        piece = chess.Piece(piece_order[plane_idx], color)
        for sq in occupied:
            decoded.append((sq, piece, False))

    return decoded


def decode_board(arr):
    """Rebuild a ``chess.Board`` from the piece planes of an encoded position.

    Planes 0-5 of ``arr`` are decoded as white pieces and planes 6-11 as black
    pieces (see ``decode_pieces``); any further planes are not read here.

    Returns ``None`` when the first twelve planes contain no non-zero entry.
    Otherwise returns a board created with ``fen=None`` on which only the
    decoded pieces have been placed.
    """
    # NOTE: np.array_equal(planes, 0) compares shapes before values, so it is
    # always False for a non-scalar array.  Test the values directly instead.
    if not np.any(arr[:12]):
        return None

    pieces_white = decode_pieces(arr[0:6, :, :], chess.WHITE)
    pieces_black = decode_pieces(arr[6:12, :, :], chess.BLACK)

    board = chess.Board(fen=None)
    for square, piece, promoted in pieces_white + pieces_black:
        board.set_piece_at(square, piece, promoted)

    return board
27 changes: 13 additions & 14 deletions py/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,22 @@ def __init__(self):
)

self.value_head = torch.nn.Sequential(
# torch.nn.Conv2d(256, 1, kernel_size=1, bias=False),
# torch.nn.BatchNorm2d(1),
torch.nn.Conv2d(256, 16, kernel_size=1, bias=False),
torch.nn.BatchNorm2d(16),
torch.nn.Flatten(),
torch.nn.Linear(1024, 256),
torch.nn.ReLU(inplace=False),
torch.nn.Linear(256, 1),
torch.nn.Tanh(),
# torch.nn.Conv2d(256, 64, kernel_size=1, bias=False),
# torch.nn.BatchNorm2d(64),
# # # torch.nn.AvgPool2d(kernel_size=8),
# torch.nn.Dropout2d(p=0.5),
# torch.nn.Flatten(),
# torch.nn.Linear(64, 64),
# torch.nn.Linear(64 * 8 * 8, 512),
# torch.nn.ReLU(inplace=False),
# torch.nn.Linear(64, 1),
# torch.nn.Linear(512, 1),
# torch.nn.Tanh(),
torch.nn.Conv2d(256, 64, kernel_size=1, bias=False),
torch.nn.BatchNorm2d(64),
# torch.nn.Dropout2d(p=0.5),
torch.nn.AvgPool2d(kernel_size=8),
torch.nn.Flatten(),
# torch.nn.Dropout(p=0.5),
torch.nn.Linear(64, 1),
# torch.nn.ReLU(inplace=True),
# torch.nn.Linear(128, 1),
torch.nn.Tanh(),
)

self.policy_head = torch.nn.Sequential(
Expand Down
4 changes: 2 additions & 2 deletions scripts/leader-board
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ export -f task
#set -ex
## export BLACK=''
## parallel -j3 task {#} 3 version_0 99 -d cuda --rollout=300 --temperature 0 --cpuct 1 ::: $(seq 1 100)
export BLACK="runs/017/tb_logs/chess/version_13/step:469-3.548-0.326.pt2"
export WHITEBASE="tb_logs/chess/version_3"
export BLACK="${BLACK:-runs/017/tb_logs/chess/version_13/step:469-3.548-0.326.pt2}"
export WHITEBASE="${WHITEBASE:-tb_logs/chess/version_3}"
export WHITEVER="${WHITEVER:-step:1008-3.512-0.768.pt2}"

ROLLOUT=${ROLLOUT:-20}
Expand Down
2 changes: 1 addition & 1 deletion scripts/run_batch
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
set -eux
source .env
# source .env

ARGS=$@
N=${N:-100}
Expand Down
69 changes: 65 additions & 4 deletions scripts/sample.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,30 @@
import os
import json
import argparse
import pandas as pd
from tqdm import tqdm


def main():
parser = argparse.ArgumentParser()
parser.add_argument("-l", "--list", required=True)
parser.add_argument("-r", "--ratio", default=1.0, type=float)
subparser = parser.add_subparsers(dest="command", required=True)
pmix = subparser.add_parser("mix")
pmix.add_argument("-r", "--ratio", default=1.0, type=float)
psplit = subparser.add_parser("split")
psplit.add_argument("-r", "--ratio", default=0.1, type=float)
psplit.add_argument("-v", "--val", required=True)
psplit.add_argument("-t", "--train", required=True)
args = parser.parse_args()

if args.command == "mix":
mix(args)

else:
split(args)


def mix(args):
assert args.ratio > 0

with open(args.list, "r") as fp:
Expand All @@ -18,11 +34,16 @@ def main():

for f in files:
comps = f.split(os.sep)
runs_idx = comps.index("runs")
assert runs_idx >= 0
try:
runs_idx = comps.index("runs")
except ValueError:
run_id = int(comps[0])
else:
run_id = int(comps[runs_idx + 1])

collection.append(
{
"run_id": int(comps[runs_idx + 1]),
"run_id": run_id,
"file": f,
}
)
Expand All @@ -45,5 +66,45 @@ def main():
print(f)


def split(args):
    """Write a class-balanced train/validation split of the listed game files.

    Reads one file name per non-empty line from ``args.list``, loads each file
    as JSON and labels it with ``outcome.winner``; a missing or empty outcome
    or winner is labelled ``"draw"``.  Draws are randomly down-sampled to at
    most the larger of the ``"White"`` and ``"Black"`` counts, the two
    decisive classes are shuffled, and the first ``int(len * args.ratio)``
    rows of every class go to validation while the rest go to training.

    The resulting file names are written newline-separated to ``args.train``
    and ``args.val``.

    Raises:
        KeyError: if any of the ``"draw"``, ``"White"`` or ``"Black"`` classes
            is absent from the listed files.
        AssertionError: if a class is too small to contribute at least one
            validation sample at the requested ratio.
    """
    with open(args.list, "r") as fp:
        files = [name for name in (line.strip() for line in fp) if name]

    def _get_result(f):
        # Close each game file promptly; the list may contain many of them.
        with open(f) as game_fp:
            o = json.load(game_fp)
        return (o.get("outcome") or {}).get("winner", None) or "draw"

    df = pd.DataFrame([{"filename": f, "result": _get_result(f)} for f in tqdm(files)])
    groups = df.groupby(by="result")
    nmax = groups.count().loc[["White", "Black"]].max().item()

    # Cap the draw sample at the number of draws actually available:
    # DataFrame.sample raises when asked for more rows than exist.
    draws = groups.get_group("draw")
    d = draws.sample(n=min(nmax, len(draws)))
    w = groups.get_group("White")
    b = groups.get_group("Black")

    w = w.sample(frac=1).reset_index(drop=True)
    b = b.sample(frac=1).reset_index(drop=True)

    def _split(part):
        nval = int(len(part) * args.ratio)
        assert nval > 0, "ratio too small: a class would get no validation sample"
        return part.iloc[:nval], part.iloc[nval:]

    val_d, train_d = _split(d)
    val_w, train_w = _split(w)
    val_b, train_b = _split(b)

    val = pd.concat([val_d, val_w, val_b])
    train = pd.concat([train_d, train_w, train_b])

    print(f"val split: {len(val)} samples")
    print(f"train split: {len(train)} samples")

    assert set(val.filename).isdisjoint(set(train.filename))

    # Use context managers so the outputs are flushed and closed
    # deterministically; the payload is one string, so write() is the right call.
    with open(args.train, "w") as fp:
        fp.write("\n".join(train.filename))
    with open(args.val, "w") as fp:
        fp.write("\n".join(val.filename))


if __name__ == "__main__":
main()
67 changes: 53 additions & 14 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Optional
from multiprocessing import Pool
from functools import partial
from tqdm import tqdm

import numpy as np
import chess
Expand All @@ -31,12 +32,12 @@ def __init__(self, config):
n_res_blocks=model_conf.teacher,
checkpoint=config["last_ckpt"],
inference=True,
compile=config["compile_model"],
compile=False,
)
self.model = load_model(
n_res_blocks=model_conf.student,
inference=False,
compile=config["compile_model"],
compile=False,
)

else:
Expand All @@ -45,8 +46,23 @@ def __init__(self, config):
n_res_blocks=model_conf,
checkpoint=config["last_ckpt"],
inference=False,
compile=config["compile_model"],
compile=False,
)
self.teacher = None

if config["freeze_backbone"]:
trainable_weights = []
prefixes = ["model.conv_block.", "model.res_blocks."]
for name, param in self.named_parameters():
if not any(name.startswith(p) for p in prefixes):
trainable_weights.append(name)
continue
param.requires_grad = False
print("Trainnable weights: ", trainable_weights)

if config["compile_model"]:
self.model.compile(mode="reduce-overhead", fullgraph=True)
# skip compiling the teacher, as we don't do transfer-learning yet.

def compute_loss1(self, log_dist_pred, dist_gt):
batch_size = dist_gt.shape[0]
Expand Down Expand Up @@ -123,7 +139,9 @@ def configure_optimizers(self):
optimizer = torch.optim.SGD(
self.parameters(), lr=self.config["lr"], momentum=0.9, weight_decay=3e-5
)
return optimizer

if self.config["lr_scheduler"] == "constant":
return optimizer

from torch.optim.lr_scheduler import OneCycleLR

Expand All @@ -150,14 +168,18 @@ def __init__(self, start, end, interval, model_compiled=True):
self.end = end
self.interval = interval
self.model_compiled = model_compiled
self.counting = 0

def on_validation_end(self, trainer, pl_module):
step = trainer.global_step
callback_metrics = trainer.callback_metrics
self.counting += 1

if step < self.start or step >= self.end or (step + 1) % self.interval != 0:
if step < self.start or step >= self.end or self.counting < self.interval:
return

self.counting = 0

val_loss1 = callback_metrics["val_syn_loss1/dataloader_idx_0"].item()
val_loss2 = callback_metrics["val_syn_loss2/dataloader_idx_0"].item()

Expand All @@ -168,8 +190,9 @@ def on_validation_end(self, trainer, pl_module):
print("saving checkpoint: ", path)

m = pl_module.model
if self.model_compiled:
m = m._orig_mod
## necessary only if calling torch.compile instead of model.compile
# if self.model_compiled:
# m = m._orig_mod

torch.save(m.state_dict(), path)

Expand Down Expand Up @@ -206,6 +229,9 @@ def main():
parser.add_argument("--model-conf", type=str, default=None)
parser.add_argument("--val-data", type=str, default="py/validation/sample.csv")
parser.add_argument("--train-batch-size", type=int, default=1024)
parser.add_argument(
"--lr-scheduler", type=str, choices=["constant", "onecycle"], default="constant"
)
parser.add_argument("--freeze-backbone", action="store_true")
parser.add_argument("--no-compile", action="store_true")
args = parser.parse_args()
Expand All @@ -228,7 +254,12 @@ def main():
]

with Pool(12) as p:
dss = p.map(ChessDataset, args.trace_file_for_train)
dss = list(
tqdm(
p.imap(partial(ChessDataset), args.trace_file_for_train),
total=len(args.trace_file_for_train),
)
)
train_split = ConcatDataset(dss)

# train_split, val_syn_split = random_split(dss, [0.8, 0.2])
Expand All @@ -239,11 +270,16 @@ def main():
batch_size=args.train_batch_size,
shuffle=True,
drop_last=True,
persistent_workers=True,
persistent_workers=True,
)

with Pool(12) as p:
dss = p.map(partial(ChessDataset), args.trace_file_for_val)
dss = list(
tqdm(
p.imap(partial(ChessDataset), args.trace_file_for_val),
total=len(args.trace_file_for_val),
)
)
val_syn_split = ConcatDataset(dss)

val_syn = DataLoader(
Expand All @@ -252,7 +288,7 @@ def main():
batch_size=128,
shuffle=False,
drop_last=True,
persistent_workers=True,
persistent_workers=True,
)

val_real = DataLoader(
Expand All @@ -261,7 +297,7 @@ def main():
batch_size=128,
shuffle=False,
drop_last=True,
persistent_workers=True,
persistent_workers=True,
)

config = dict(
Expand All @@ -278,6 +314,8 @@ def main():
loss_weight=args.loss_weight,
compile_model=compile_model,
model_conf=TransferConf.parse(args.model_conf) or int(args.model_conf),
freeze_backbone=args.freeze_backbone,
lr_scheduler=args.lr_scheduler,
)

module = ChessLightningModule(config)
Expand All @@ -287,7 +325,7 @@ def main():
callbacks=lightning_checkpoints,
max_epochs=config["epochs"],
log_every_n_steps=10,
precision="bf16-mixed",
# precision="bf16-mixed",
val_check_interval=20,
)
trainer.fit(
Expand All @@ -296,7 +334,8 @@ def main():
val_dataloaders=[val_syn, val_real],
)

m = module.model._orig_mod if compile_model else module.model
# m = module.model._orig_mod if compile_model else module.model
m = module.model
torch.save(m.state_dict(), os.path.join(trainer.log_dir, "last.ckpt"))


Expand Down
Loading

0 comments on commit f3c9678

Please sign in to comment.