Skip to content

Commit

Permalink
increase num filters for value head
Browse files Browse the repository at this point in the history
  • Loading branch information
pierric committed Dec 30, 2024
1 parent a9d42f0 commit 6cf0cb6
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions py/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,26 @@ def __init__(self):
)

self.value_head = torch.nn.Sequential(
torch.nn.Conv2d(256, 1, kernel_size=1, bias=False),
torch.nn.BatchNorm2d(1),
# torch.nn.Conv2d(256, 1, kernel_size=1, bias=False),
# torch.nn.BatchNorm2d(1),
# torch.nn.Flatten(),
# torch.nn.Linear(64, 64),
# torch.nn.ReLU(inplace=False),
# torch.nn.Linear(64, 1),
# torch.nn.Tanh(),
torch.nn.Conv2d(256, 32, kernel_size=1, bias=False),
torch.nn.BatchNorm2d(32),
torch.nn.Flatten(),
torch.nn.Linear(64, 64),
torch.nn.ReLU(inplace=False),
torch.nn.Linear(64, 1),
torch.nn.Linear(32 * 8 * 8, 128),
torch.nn.ReLU(inplace=True),
torch.nn.Linear(128, 1),
torch.nn.Tanh(),
)

self.policy_head = torch.nn.Sequential(
torch.nn.Conv2d(256, 128, kernel_size=1, bias=False),
torch.nn.BatchNorm2d(128),
torch.nn.ReLU(inplace=False),
torch.nn.ReLU(inplace=True),
torch.nn.Flatten(),
torch.nn.Linear(8 * 8 * 128, 8 * 8 * 73),
)
Expand Down

0 comments on commit 6cf0cb6

Please sign in to comment.