diff --git a/src/torchattack/fia.py b/src/torchattack/fia.py
index f645e3a..bc84c95 100644
--- a/src/torchattack/fia.py
+++ b/src/torchattack/fia.py
@@ -79,8 +79,8 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         if self.alpha is None:
             self.alpha = self.eps / self.steps
 
-        h = self.feature_layer.register_forward_hook(self.__forward_hook)
-        h2 = self.feature_layer.register_full_backward_hook(self.__backward_hook)
+        h = self.feature_layer.register_forward_hook(self.__forward_hook)  # type: ignore
+        h2 = self.feature_layer.register_full_backward_hook(self.__backward_hook)  # type: ignore
 
         # Gradient aggregation on ensembles
         agg_grad = 0
@@ -93,7 +93,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                 loss += output_random[batch_i][y[batch_i]]  # type: ignore
             self.model.zero_grad()
             loss.backward()  # type: ignore
-            agg_grad += self.mid_grad[0].detach()
+            agg_grad += self.mid_grad[0].detach()  # type: ignore
         for batch_i in range(x.shape[0]):
             agg_grad[batch_i] /= agg_grad[batch_i].norm(p=2)  # type: ignore
         h2.remove()
@@ -104,10 +104,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
             _ = self.model(self.normalize(x + delta))
 
             # Hooks are updated during forward pass
-            loss = (self.mid_output * agg_grad).sum()
+            outs = (self.mid_output * agg_grad).sum()
 
             self.model.zero_grad()
-            grad = torch.autograd.grad(loss, delta, retain_graph=False)[0]
+            grad = torch.autograd.grad(outs, delta, retain_graph=False)[0]
 
             # Update delta
             delta.data = delta.data - self.alpha * grad.sign()
@@ -143,17 +143,17 @@ def find_layer(self, feature_layer_name) -> nn.Module:
             The layer to compute feature importance.
         """
 
-        parser = feature_layer_name.split(' ')
-        m = self.model
-
-        for layer in parser:
-            if layer in m._modules:
-                m = m._modules[layer]
-                break
-            else:
-                raise ValueError(f'Layer {layer} not found in the model.')
-
-        return m
+        # for layer in feature_layer_name.split(' '):
+        #     if layer not in self.model._modules:
+        #         raise ValueError(f'Layer {layer} not found in the model.')
+        #     return self.model._modules[layer]
+
+        if feature_layer_name not in self.model._modules:
+            raise ValueError(f'Layer {feature_layer_name} not found in the model.')
+        feature_layer = self.model._modules[feature_layer_name]
+        if not isinstance(feature_layer, nn.Module):
+            raise ValueError(f'Layer {feature_layer_name} invalid.')
+        return feature_layer
 
     def __forward_hook(self, m: nn.Module, i: torch.Tensor, o: torch.Tensor):
         self.mid_output = o
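
For reference, below is a minimal, self-contained sketch (not part of the patch) of the hook pattern that the annotated lines rely on: a forward hook that stores the feature layer's output (as __forward_hook does with self.mid_output), a full backward hook that stores the gradient w.r.t. that output (as __backward_hook does with self.mid_grad), and a name lookup over the model's immediate children in the spirit of the revised find_layer. The toy model, the layer name 'features', the captured dict, and the hook function names are illustrative assumptions, not torchattack's API.

from collections import OrderedDict

import torch
import torch.nn as nn

# Hypothetical toy model with named top-level children; the revised find_layer
# only resolves immediate submodules of self.model by name.
model = nn.Sequential(
    OrderedDict(
        [
            ('features', nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())),
            ('pool', nn.AdaptiveAvgPool2d(1)),
            ('flatten', nn.Flatten()),
            ('classifier', nn.Linear(8, 10)),
        ]
    )
)

# Same lookup logic as the new find_layer
feature_layer_name = 'features'
if feature_layer_name not in model._modules:
    raise ValueError(f'Layer {feature_layer_name} not found in the model.')
feature_layer = model._modules[feature_layer_name]

captured = {}

def forward_hook(m: nn.Module, i, o):
    # Analogous to __forward_hook: keep the layer's forward output
    captured['mid_output'] = o

def backward_hook(m: nn.Module, grad_input, grad_output):
    # Analogous to __backward_hook: keep gradients w.r.t. the layer's output
    captured['mid_grad'] = grad_output

h = feature_layer.register_forward_hook(forward_hook)
h2 = feature_layer.register_full_backward_hook(backward_hook)

x = torch.rand(2, 3, 32, 32)
probs = torch.softmax(model(x), dim=1)
probs[:, 0].sum().backward()

print(captured['mid_output'].shape)   # torch.Size([2, 8, 32, 32])
print(captured['mid_grad'][0].shape)  # same shape as the feature map

h.remove()
h2.remove()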