diff --git a/setup.py b/setup.py
index f227a2c..9149b1b 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
   name = 'x-clip',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.3.0',
+  version = '0.4.0',
   license='MIT',
   description = 'X-CLIP',
   author = 'Phil Wang',
diff --git a/x_clip/visual_ssl.py b/x_clip/visual_ssl.py
index 3a21705..2743a19 100644
--- a/x_clip/visual_ssl.py
+++ b/x_clip/visual_ssl.py
@@ -91,8 +91,8 @@ def nt_xent_loss(queries, keys, temperature = 0.1):
     logits = logits[~mask].reshape(n, n - 1)
     logits /= temperature

-    labels = torch.cat(((torch.arange(b, device=device) + b - 1), torch.arange(b, device=device)), dim=0)
-    loss = F.cross_entropy(logits, labels, reduction='sum')
+    labels = torch.cat(((torch.arange(b, device = device) + b - 1), torch.arange(b, device=device)), dim=0)
+    loss = F.cross_entropy(logits, labels, reduction = 'sum')
     loss /= n
     return loss

@@ -105,20 +105,27 @@ def loss_fn(x, y):

 # MLP class for projector and predictor

-class MLP(nn.Module):
-    def __init__(self, dim, projection_size, hidden_size = None):
-        super().__init__()
-        hidden_size = default(hidden_size, dim)
+def MLP(dim, projection_size, hidden_size = None):
+    hidden_size = default(hidden_size, dim)

-        self.net = nn.Sequential(
-            nn.Linear(dim, hidden_size),
-            nn.BatchNorm1d(hidden_size),
-            nn.ReLU(inplace=True),
-            nn.Linear(hidden_size, projection_size)
-        )
+    return nn.Sequential(
+        nn.Linear(dim, hidden_size),
+        nn.BatchNorm1d(hidden_size),
+        nn.ReLU(inplace = True),
+        nn.Linear(hidden_size, projection_size)
+    )

-    def forward(self, x):
-        return self.net(x)
+def SimSiamMLP(dim, projection_size, hidden_size = 4096):
+    return nn.Sequential(
+        nn.Linear(dim, hidden_size, bias = False),
+        nn.BatchNorm1d(hidden_size),
+        nn.ReLU(inplace = True),
+        nn.Linear(hidden_size, hidden_size, bias=False),
+        nn.BatchNorm1d(hidden_size),
+        nn.ReLU(inplace = True),
+        nn.Linear(hidden_size, projection_size, bias = False),
+        nn.BatchNorm1d(projection_size, affine = False)
+    )

 # a wrapper class for the base neural network
 # will manage the interception of the hidden layer output
@@ -159,7 +166,7 @@ def _register_hook(self):
     @singleton('projector')
     def _get_projector(self, hidden):
         _, dim = hidden.shape
-        projector = MLP(dim, self.projection_size, self.projection_hidden_size)
+        projector = SimSiamMLP(dim, self.projection_size, self.projection_hidden_size)
         return projector.to(hidden)

     def get_representation(self, x):
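
For reviewers, a quick illustrative check of the new SimSiam-style projector added above. This is a sketch, not part of the diff: it assumes torch is installed and that `SimSiamMLP` is importable from `x_clip.visual_ssl` once this change lands.

```python
# Illustrative sketch only (not part of the diff). Assumes torch is installed
# and that SimSiamMLP is importable from x_clip.visual_ssl after this change.
import torch
from x_clip.visual_ssl import SimSiamMLP

# build the new 3-layer, bias-free projector for 512-dim hidden features
projector = SimSiamMLP(dim = 512, projection_size = 256)

hidden = torch.randn(8, 512)   # dummy batch of intercepted hidden-layer outputs
out = projector(hidden)
print(out.shape)               # torch.Size([8, 256])

# the final BatchNorm1d is affine = False, so it has no learnable scale/shift
final_bn = projector[-1]
print(final_bn.weight is None, final_bn.bias is None)  # True True
```

Only `_get_projector` is switched to `SimSiamMLP`; the two-layer `MLP` helper stays in the module, so any other call sites are unaffected by this diff.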