diff --git a/README.md b/README.md index 44c4820..732708b 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ [![lint](https://github.com/daisylab-bit/torchattack/actions/workflows/lint.yml/badge.svg)](https://github.com/daisylab-bit/torchattack/actions/workflows/lint.yml) [![GitHub release (latest by date)](https://img.shields.io/github/v/release/daisylab-bit/torchattack)](https://github.com/daisylab-bit/torchattack/releases/latest) -A set of adversarial attacks implemented in PyTorch. _For internal use._ +A set of adversarial attacks implemented in PyTorch. For internal use, no support guaranteed. ```shell # Install from github source @@ -18,22 +18,25 @@ python -m pip install git+https://gitee.com/daisylab-bit/torchattack ## Usage ```python +import torch +from torchattack import FGSM, MIFGSM from torchvision.models import resnet50 from torchvision.transforms import transforms -from torchattack import FGSM, MIFGSM +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Load a model model = resnet50(weights='DEFAULT') +model = model.to(device) -# Define transforms (you are responsible for normalizing the data if needed) -transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +# Define normalization (you are responsible for normalizing the data if needed) +normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Initialize an attack -attack = FGSM(model, transform, eps=0.03) +attack = FGSM(model, normalize, device) # Initialize an attack with extra params -attack = MIFGSM(model, transform, eps=0.03, steps=10, decay=1.0) +attack = MIFGSM(model, normalize, device, eps=0.03, steps=10, decay=1.0) ``` Check out [`torchattack.utils.run_attack`](src/torchattack/utils.py) for a simple example.