You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
import torch
from vit_pytorch import ViT
import torch_pruning as tp
import vit_pytorch.vit
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
def forward(self, x):
B, N, C = x.shape
x = self.norm(x)
qkv = self.to_qkv(x).reshape(B, N, 3, self.heads, self.dim_head).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = out.transpose(1, 2).reshape(B, N, -1)
return self.to_out(out)
model = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 16,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
imp = tp.importance.GroupNormImportance(p=1)
example_inputs = img
base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
num_heads = {}
ignored_layers = [model.mlp_head]
times = 0
for m in model.modules():
if isinstance(m, vit_pytorch.vit.Attention):
m.forward = forward.__get__(m, vit_pytorch.vit.Attention)
num_heads[m.to_qkv] = m.heads
if isinstance(m, vit_pytorch.vit.FeedForward):
ignored_layers.append(m.net[4]) # only prune the internal layers of FFN & Attention
pruner = tp.pruner.MetaPruner(
model,
example_inputs,
global_pruning=False, # If False, a uniform pruning ratio will be assigned to different layers.
importance=imp, # importance criterion for parameter selection
pruning_ratio=0.5, # target pruning ratio
ignored_layers=ignored_layers,
num_heads=num_heads, # number of heads in self attention
prune_num_heads=True, # reduce num_heads by pruning entire heads (default: False)
prune_head_dims= False, # reduce head_dim by pruning featrues dims of each head (default: True)
head_pruning_ratio=0.5, #args.head_pruning_ratio, # remove 50% heads, only works when prune_num_heads=True (default: 0.0)
round_to=1
)
for i, g in enumerate(pruner.step(interactive=True)):
g.prune()
head_id = 0
for m in model.modules():
if isinstance(m, vit_pytorch.vit.Attention):
print("Head #%d"%head_id)
print("[Before Pruning] Num Heads: %d, Head Dim: %d =>"%(m.heads, m.dim_head))
m.num_heads = pruner.num_heads[m.to_qkv]
m.head_dim = m.to_qkv.out_features // (3 * m.heads)
print("[After Pruning] Num Heads: %d, Head Dim: %d"%(m.heads, m.dim_head))
print()
head_id+=1
print("----------------------------------------")
print("Summary:")
pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
print("Base MACs: %.2f G, Pruned MACs: %.2f G"%(base_macs/1e9, pruned_macs/1e9))
print("Base Params: %.2f M, Pruned Params: %.2f M"%(base_params/1e6, pruned_params/1e6))
I'm trying to prune a ViT model implemented in vit_pytorch but got following error:
Summary:
Traceback (most recent call last):
File "C:\Users\Desktop\ViT model pruning\vit-pytorch\ViT_pruning.py", line 93, in <module>
pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
File "D:\Anaconda\envs\cvlface-env\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "D:\Anaconda\envs\cvlface-env\lib\site-packages\torch_pruning\utils\op_counter.py", line 35, in count_ops_and_params
_ = flops_model(example_inputs)
File "D:\Anaconda\envs\cvlface-env\lib\site-packages\torch\nn\modules\module.py", line 1212, in _call_impl
result = forward_call(*input, **kwargs)
File "C:\Users\Desktop\ViT model pruning\vit-pytorch\vit_pytorch\vit.py", line 123, in forward
x = self.transformer(x)
File "D:\Anaconda\envs\cvlface-env\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\Desktop\ViT model pruning\vit-pytorch\vit_pytorch\vit.py", line 79, in forward
x = attn(x) + x
File "D:\Anaconda\envs\cvlface-env\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\Desktop\ViT model pruning\vit-pytorch\ViT_pruning.py", line 13, in forward
qkv = self.to_qkv(x).reshape(B, N, 3, self.heads, self.dim_head).permute(2, 0, 3, 1, 4)
RuntimeError: shape '[1, 65, 3, 16, 64]' is invalid for input of size 99840
It is likely to have problems after pruning. How to solve this one ? @VainF or anyone else could you help me tp solve this problem. Thank you
The text was updated successfully, but these errors were encountered:
I'm trying to prune a ViT model implemented in vit_pytorch but got following error:
It is likely to have problems after pruning. How to solve this one ? @VainF or anyone else could you help me tp solve this problem. Thank you
The text was updated successfully, but these errors were encountered: