Skip to content

Commit

Permalink
fix: mypy types
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Feb 26, 2024
1 parent d57a3ae commit cf1dbe2
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ version = { attr = "torchattack.__version__" }

[tool.ruff]
line-length = 88

[tool.ruff.lint]
select = ["E", "F", "I", "N", "B", "SIM"]
ignore = ["E501"]

Expand All @@ -32,9 +34,9 @@ ignore = ["E501"]
quote-style = "single"

[tool.mypy]
disallow_any_unimported = true
no_implicit_optional = true
check_untyped_defs = true
ignore_missing_imports = true # Used as torchvision does not ship type hints
ignore_missing_imports = true # Used as torchvision does not ship type hints
# disallow_any_unimported = true
# disallow_untyped_defs = true
# warn_return_any = true
10 changes: 7 additions & 3 deletions src/torchattack/fia.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def __init__(
self.targeted = targeted
# self.lossfn = nn.CrossEntropyLoss()

# TODO: Targeted attack is not supported yet.
if self.targeted:
print('Targeted attack is not supported, using non-targeted variant.')

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Perform FIA on a batch of images.
Expand Down Expand Up @@ -86,12 +90,12 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
output_random = torch.softmax(output_random, dim=1)
loss = 0
for batch_i in range(x.shape[0]):
loss += output_random[batch_i][y[batch_i]]
loss += output_random[batch_i][y[batch_i]] # type: ignore
self.model.zero_grad()
loss.backward()
loss.backward() # type: ignore
agg_grad += self.mid_grad[0].detach()
for batch_i in range(x.shape[0]):
agg_grad[batch_i] /= agg_grad[batch_i].norm(p=2)
agg_grad[batch_i] /= agg_grad[batch_i].norm(p=2) # type: ignore
h2.remove()

# Perform FIA
Expand Down

0 comments on commit cf1dbe2

Please sign in to comment.