v1.4.1 #403

Merged · 13 commits · Jul 21, 2024
README.md: 21 changes (18 additions, 3 deletions)
@@ -23,7 +23,7 @@

Torch-Pruning (TP) is designed for structural pruning, facilitating the following features:

* **General-purpose Pruning Toolkit:** TP enables structural pruning for a wide range of deep neural networks, including [Large Language Models (LLMs)](https://github.com/VainF/Torch-Pruning/tree/master/examples/LLMs), [Segment Anything Model (SAM)](https://github.com/czg1225/SlimSAM), [Diffusion Models](https://github.com/VainF/Diff-Pruning), [Yolov7](examples/yolov7/), [yolov8](examples/yolov8/), [Vision Transformers](examples/transformers/), [Swin Transformers](examples/transformers#swin-transformers-from-hf-transformers), [BERT](examples/transformers#bert-from-hf-transformers), FasterRCNN, SSD, ResNe(X)t, ConvNext, DenseNet, ConvNext, RegNet, DeepLab, etc. Different from [torch.nn.utils.prune](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html) that zeroizes parameters through masking, Torch-Pruning deploys an algorithm called **[DepGraph](https://openaccess.thecvf.com/content/CVPR2023/html/Fang_DepGraph_Towards_Any_Structural_Pruning_CVPR_2023_paper.html)** to remove parameters physically.
* **General-purpose Pruning Toolkit:** TP enables structural pruning for a wide range of deep neural networks, including [Large Language Models (LLMs)](https://github.com/VainF/Torch-Pruning/tree/master/examples/LLMs), [Segment Anything Model (SAM)](https://github.com/czg1225/SlimSAM), [Diffusion Models](https://github.com/VainF/Diff-Pruning), [Vision Transformers](https://github.com/VainF/Isomorphic-Pruning), [ConvNext](https://github.com/VainF/Isomorphic-Pruning), [Yolov7](examples/yolov7/), [yolov8](examples/yolov8/), [Swin Transformers](examples/transformers#swin-transformers-from-hf-transformers), [BERT](examples/transformers#bert-from-hf-transformers), FasterRCNN, SSD, ResNe(X)t, DenseNet, RegNet, DeepLab, etc. Different from [torch.nn.utils.prune](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html) that zeroizes parameters through masking, Torch-Pruning deploys an algorithm called **[DepGraph](https://openaccess.thecvf.com/content/CVPR2023/html/Fang_DepGraph_Towards_Any_Structural_Pruning_CVPR_2023_paper.html)** to remove parameters physically.
* [Examples](examples): Pruning off-the-shelf models from Timm, Huggingface Transformers, Torchvision, Yolo, etc.
* [Code for reproducing paper results](reproduce): Reproduce our results from the DepGraph paper.

@@ -206,15 +206,30 @@ print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -
MACs: 1.822177768 G -> 0.487202536 G, #Params: 11.689512 M -> 3.05588 M
```
#### Global Pruning
Global pruning performs importance ranking across all layers, which has the potential to find better structures. This can be easily achieved by setting ``global_pruning=True`` in the pruner. While this strategy can offer performance advantages, it also risks over-pruning specific layers, resulting in a substantial decline in overall performance. We provide an alternative algorithm called [Isomorphic Pruning](https://arxiv.org/abs/2407.04616) to alleviate this issue, which can be enabled with ``isomorphic=True``. For more details, please see our official [codebase for ViTs and CNNs](https://github.com/VainF/Isomorphic-Pruning).
Global pruning performs importance ranking across all layers, which has the potential to find better structures. This can be easily achieved by setting ``global_pruning=True`` in the pruner. While this strategy can offer performance advantages, it also risks over-pruning specific layers, resulting in a substantial decline in overall performance. We provide an alternative algorithm called [Isomorphic Pruning](https://arxiv.org/abs/2407.04616) to alleviate this issue, which can be enabled with ``isomorphic=True``.
```python
pruner = tp.pruner.MetaPruner(
    ...
    ismorphic=True, # enable isomorphic pruning to improve global ranking
    isomorphic=True, # enable isomorphic pruning to improve global ranking
    global_pruning=True, # global pruning
)
```

<div align="center">
<img src="assets/isomorphic_pruning.png" width="96%">
</div>

#### Pruning Ratios

The default pruning ratio can be set with ``pruning_ratio``. If you want to customize the pruning ratio for specific layers or blocks, use ``pruning_ratio_dict``. Each key of the dict can be a single ``nn.Module`` or a tuple of ``nn.Module``s. In the latter case, all modules in the tuple form a ``scope`` and share the same pruning ratio, and global ranking will be performed within this scope. This is also the core idea of [Isomorphic Pruning](https://arxiv.org/abs/2407.04616).
```python
pruner = tp.pruner.MetaPruner(
    ...
    pruning_ratio=0.5, # default pruning ratio
    pruning_ratio_dict={(model.layer1, model.layer2): 0.4, model.layer3: 0.2},
    # Global ranking will be performed on layer1 and layer2, which form a scope
)
```

#### Sparse Training (Optional)
Some pruners such as [BNScalePruner](https://github.com/VainF/Torch-Pruning/blob/dd59921365d72acb2857d3d74f75c03e477060fb/torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py#L45) and [GroupNormPruner](https://github.com/VainF/Torch-Pruning/blob/dd59921365d72acb2857d3d74f75c03e477060fb/torch_pruning/pruner/algorithms/group_norm_pruner.py#L53) support sparse training. This can be easily achieved by inserting ``pruner.update_regularizer()`` and ``pruner.regularize(model)`` into your standard training loop. The pruner will accumulate the regularization gradients into ``.grad``. Sparse training is optional and may slow down training.
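Below is a minimal sketch of where these two calls could sit in a training loop. The ``train_loader``, ``optimizer``, and ``num_epochs`` names are placeholders, and the exact call placement is an assumption; please refer to the official examples for the full recipe.
```python
import torch.nn.functional as F

for epoch in range(num_epochs):
    model.train()
    pruner.update_regularizer()  # refresh the sparsity regularizer, e.g. once per epoch
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), targets)
        loss.backward()
        pruner.regularize(model)  # accumulate regularization gradients into .grad
        optimizer.step()
```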
Binary file added assets/isomorphic_pruning.png
examples/timm_models/prune_timm_models.py: 7 changes (4 additions, 3 deletions)
@@ -39,9 +39,10 @@ def main():
    num_heads = {}

    for m in model.modules():
        if hasattr(m, 'head'): #isinstance(m, nn.Linear) and m.out_features == model.num_classes:
            ignored_layers.append(model.head)
            print("Ignore classifier layer: ", m.head)
        #if hasattr(m, 'head'): #isinstance(m, nn.Linear) and m.out_features == model.num_classes:
        if isinstance(m, nn.Linear) and m.out_features == model.num_classes:
            ignored_layers.append(m)
            print("Ignore classifier layer: ", m)

        # Attention layers
        if hasattr(m, 'num_heads'):
examples/transformers/readme.md: 3 changes (3 additions, 0 deletions)
@@ -1,5 +1,8 @@
# Transformers

This example demonstrates the minimal code needed to prune Transformers, including Vision Transformers (ViT), Swin Transformers, and BERT. If you need a more comprehensive example of pruning and finetuning, please refer to the [codebase for Isomorphic Pruning](https://github.com/VainF/Isomorphic-Pruning), where detailed instructions and pre-pruned models are available.


## Pruning ViT-ImageNet-21K-ft-1K from [Timm](https://github.com/huggingface/pytorch-image-models)

### Data
tests/test_taylor_importance.py: 8 changes (5 additions, 3 deletions)
@@ -14,16 +14,18 @@ def test_taylor():
        if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
            ignored_layers.append(m) # DO NOT prune the final classifier!

    iterative_steps = 5 # progressive pruning
    iterative_steps = 1 # number of pruning iterations (1 = one-shot pruning)
    pruner = tp.pruner.MagnitudePruner(
        model,
        example_inputs,
        importance=imp,
        iterative_steps=iterative_steps,
        pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
        global_pruning=True,
        pruning_ratio=0.1, # default ratio for layers not listed in pruning_ratio_dict
        pruning_ratio_dict={model.layer1: 0.5, (model.layer2, model.layer3): 0.5},
        ignored_layers=ignored_layers,
    )

    print(model)
    base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
    for i in range(iterative_steps):
        if isinstance(imp, tp.importance.TaylorImportance):