Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tp.utils.count_ops_and_params Error #450

Open
DRVu16 opened this issue Dec 30, 2024 · 0 comments
Open

tp.utils.count_ops_and_params Error #450

DRVu16 opened this issue Dec 30, 2024 · 0 comments

Comments

@DRVu16
Copy link

DRVu16 commented Dec 30, 2024

import torch
from vit_pytorch import ViT
import torch_pruning as tp 
import vit_pytorch.vit

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

def forward(self, x):
    B, N, C = x.shape
    x = self.norm(x)

    qkv = self.to_qkv(x).reshape(B, N, 3, self.heads, self.dim_head).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)

    dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

    attn = self.attend(dots)
    attn = self.dropout(attn)

    out = torch.matmul(attn, v)
    out = out.transpose(1, 2).reshape(B, N, -1)
    return self.to_out(out)

    
model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 16,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

imp = tp.importance.GroupNormImportance(p=1)


example_inputs = img
base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)

num_heads = {}
ignored_layers = [model.mlp_head]
times = 0

for m in model.modules():
    if isinstance(m, vit_pytorch.vit.Attention):
        m.forward = forward.__get__(m, vit_pytorch.vit.Attention)
        num_heads[m.to_qkv] = m.heads
    if isinstance(m, vit_pytorch.vit.FeedForward):
        ignored_layers.append(m.net[4]) # only prune the internal layers of FFN & Attention

    pruner = tp.pruner.MetaPruner(
        model, 
        example_inputs, 
        global_pruning=False, # If False, a uniform pruning ratio will be assigned to different layers.
        importance=imp, # importance criterion for parameter selection
        pruning_ratio=0.5, # target pruning ratio
        ignored_layers=ignored_layers,
        num_heads=num_heads, # number of heads in self attention
        prune_num_heads=True, # reduce num_heads by pruning entire heads (default: False)
        prune_head_dims= False, # reduce head_dim by pruning featrues dims of each head (default: True)
        head_pruning_ratio=0.5, #args.head_pruning_ratio, # remove 50% heads, only works when prune_num_heads=True (default: 0.0)
        round_to=1
    )


for i, g in enumerate(pruner.step(interactive=True)):
    g.prune()

head_id = 0
for m in model.modules():
    if isinstance(m, vit_pytorch.vit.Attention):
        print("Head #%d"%head_id)
        print("[Before Pruning] Num Heads: %d, Head Dim: %d =>"%(m.heads, m.dim_head))
        m.num_heads = pruner.num_heads[m.to_qkv]
        m.head_dim = m.to_qkv.out_features // (3 * m.heads)
        print("[After Pruning] Num Heads: %d, Head Dim: %d"%(m.heads, m.dim_head))
        print()
        head_id+=1

print("----------------------------------------")
print("Summary:")
pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
print("Base MACs: %.2f G, Pruned MACs: %.2f G"%(base_macs/1e9, pruned_macs/1e9))
print("Base Params: %.2f M, Pruned Params: %.2f M"%(base_params/1e6, pruned_params/1e6))

I'm trying to prune a ViT model implemented in vit_pytorch but got following error:

Summary:
Traceback (most recent call last):
  File "C:\Users\Desktop\ViT model pruning\vit-pytorch\ViT_pruning.py", line 93, in <module>
    pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
  File "D:\Anaconda\envs\cvlface-env\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "D:\Anaconda\envs\cvlface-env\lib\site-packages\torch_pruning\utils\op_counter.py", line 35, in count_ops_and_params
    _ = flops_model(example_inputs)
  File "D:\Anaconda\envs\cvlface-env\lib\site-packages\torch\nn\modules\module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "C:\Users\Desktop\ViT model pruning\vit-pytorch\vit_pytorch\vit.py", line 123, in forward
    x = self.transformer(x)
  File "D:\Anaconda\envs\cvlface-env\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Desktop\ViT model pruning\vit-pytorch\vit_pytorch\vit.py", line 79, in forward
    x = attn(x) + x
  File "D:\Anaconda\envs\cvlface-env\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Desktop\ViT model pruning\vit-pytorch\ViT_pruning.py", line 13, in forward
    qkv = self.to_qkv(x).reshape(B, N, 3, self.heads, self.dim_head).permute(2, 0, 3, 1, 4)
RuntimeError: shape '[1, 65, 3, 16, 64]' is invalid for input of size 99840

It is likely to have problems after pruning. How to solve this one ? @VainF or anyone else could you help me tp solve this problem. Thank you

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant