
Add SliceOp & Support Phi-3
VainF committed Nov 17, 2024
1 parent 94140cc commit 66f8cbb
Showing 10 changed files with 256 additions and 21 deletions.
41 changes: 35 additions & 6 deletions examples/LLMs/prune_llama.py → examples/LLMs/prune_llm.py
@@ -291,11 +291,22 @@ def main():
inputs = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(model.device)
import torch_pruning as tp
num_heads = {}
out_channel_groups = {}
seperate_qkv = False
for name, m in model.named_modules():
if name.endswith("self_attn"):
num_heads[m.q_proj] = model.config.num_attention_heads
num_heads[m.k_proj] = model.config.num_key_value_heads
num_heads[m.v_proj] = model.config.num_key_value_heads
if hasattr(m, "q_proj"):
seperate_qkv = True
num_heads[m.q_proj] = model.config.num_attention_heads
num_heads[m.k_proj] = model.config.num_key_value_heads
num_heads[m.v_proj] = model.config.num_key_value_heads
elif hasattr(m, "qkv_proj"):
seperate_qkv = False
num_heads[m.qkv_proj] = model.config.num_attention_heads
if name.endswith('mlp'):
if hasattr(m, "gate_up_proj"):
out_channel_groups[m.gate_up_proj] = 2

_is_gqa = model.config.num_attention_heads != model.config.num_key_value_heads
head_pruning_ratio = args.pruning_ratio
hidden_size_pruning_ratio = args.pruning_ratio
@@ -311,27 +322,45 @@ def main():
prune_num_heads=True,
prune_head_dims=False, # we do not prune head dims so that we don't need to prune the RoPE
head_pruning_ratio=head_pruning_ratio,
out_channel_groups=out_channel_groups
)
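For fused-projection models such as Phi-3, the registration above is the key change: head counts are attached to the fused qkv_proj, and gate_up_proj is declared as two interleaved output-channel groups so its gate and up halves are pruned together. A minimal, self-contained sketch of the same pattern (names not shown in this diff, such as MetaPruner, GroupNormImportance, pruning_ratio, and ignored_layers, are assumptions):

```python
# Illustrative sketch only; constructor arguments outside this diff are assumptions.
import torch_pruning as tp

def build_llm_pruner(model, example_inputs, pruning_ratio=0.5):
    config = model.config
    num_heads, out_channel_groups = {}, {}
    for name, m in model.named_modules():
        if name.endswith("self_attn"):
            if hasattr(m, "qkv_proj"):                       # fused Q/K/V, e.g. Phi-3
                num_heads[m.qkv_proj] = config.num_attention_heads
            else:                                            # separate projections, e.g. Llama
                num_heads[m.q_proj] = config.num_attention_heads
                num_heads[m.k_proj] = config.num_key_value_heads
                num_heads[m.v_proj] = config.num_key_value_heads
        if name.endswith("mlp") and hasattr(m, "gate_up_proj"):
            out_channel_groups[m.gate_up_proj] = 2           # prune the gate and up halves in lockstep
    return tp.pruner.MetaPruner(
        model, example_inputs,
        importance=tp.importance.GroupNormImportance(p=2),
        pruning_ratio=pruning_ratio,
        ignored_layers=[model.lm_head],                      # keep the output vocabulary intact
        num_heads=num_heads,
        prune_num_heads=True,
        prune_head_dims=False,                               # keep head_dim so RoPE is untouched
        head_pruning_ratio=pruning_ratio,
        out_channel_groups=out_channel_groups,
    )
```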


#with torch.no_grad():
# with importance.compute_importance(model):
# calibration_data = "We recommend at least a 1TB hard drive for 4 channels, more if you plan on using 8MP \/ 4K cameras.\nDahua's Lite Series network video recorders offer excellent performance and high recording quality for IP video surveillance applications. For applications where details are critical for identification, this professional NVR provides a powerful processor with up to 4K resolution. Additionally, the NVR features a mouse shortcut operation menu, remote management and control, center storage, edge storage, and back up storage."
# calibration_data = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(model.device)
# _ = model(calibration_data)
pruner.step()

#group = pruner.DG.get_pruning_group(model.model.layers[31].mlp.gate_up_proj, tp.prune_linear_out_channels, idxs=list(range(16384)))
#print(group)

for g in pruner.step(interactive=True):
print(g)
g.prune()

# Update model attributes
model.config.hidden_size = model.lm_head.in_features
for name, m in model.named_modules():
if name.endswith("self_attn"):
m.hidden_size = m.q_proj.out_features
if seperate_qkv:
m.hidden_size = m.q_proj.out_features
else:
m.hidden_size = m.qkv_proj.out_features // 3
m.num_heads = m.hidden_size // m.head_dim
model.config.num_attention_heads = m.num_heads
#m.head_dim = m.q_proj.out_features // m.num_heads
if not _is_gqa:
m.num_key_value_heads = m.num_heads
m.num_key_value_groups = m.num_heads // m.num_key_value_heads
elif name.endswith("mlp"):
model.config.intermediate_size = m.gate_proj.out_features
if hasattr(m, "gate_proj"):
m.hidden_size = m.gate_proj.out_features
elif hasattr(m, "gate_up_proj"):
m.hidden_size = m.gate_up_proj.in_features
else:
raise ValueError("Unknown mlp layer")

if not _is_gqa:
model.config.num_key_value_heads = model.config.num_attention_heads
print("----------------- After Pruning -----------------")
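After pruner.step() and the attribute update above, a cheap sanity check is a single forward pass to confirm that the rewritten config and module shapes are consistent. This is an illustrative addition, not part of the script:

```python
# Hypothetical post-pruning check using the script's existing model/tokenizer objects.
import torch

with torch.no_grad():
    probe = torch.tensor(tokenizer.encode("Hello world")).unsqueeze(0).to(model.device)
    logits = model(probe).logits
assert logits.shape[-1] == model.config.vocab_size            # vocabulary is untouched
assert model.config.hidden_size == model.lm_head.in_features  # config matches the pruned weights
print("pruned hidden_size:", model.config.hidden_size)
```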
116 changes: 114 additions & 2 deletions examples/LLMs/readme.md
@@ -13,7 +13,7 @@ pip install transformers datasets
### Llama-3 8B

```bash
python prune_llama.py --model meta-llama/Meta-Llama-3-8B --pruning_ratio 0.5
python prune_llm.py --model meta-llama/Meta-Llama-3-8B --pruning_ratio 0.5
```

<details>
@@ -120,7 +120,7 @@ wikitext perplexity 552648.25
### Llama-2 7B

```bash
python prune_llama.py --model meta-llama/Llama-2-7b-hf --pruning_ratio 0.5
python prune_llm.py --model meta-llama/Llama-2-7b-hf --pruning_ratio 0.5
```


@@ -224,3 +224,115 @@ wikitext perplexity 8479.0673828125
</details>


### microsoft/Phi-3-mini-4k-instruct

```bash
python prune_llm.py --model microsoft/Phi-3-mini-4k-instruct --pruning_ratio 0.5
```


<details>
<summary>Output:</summary>

```
----------------- Before Pruning -----------------
Phi3ForCausalLM(
(model): Phi3Model(
(embed_tokens): Embedding(32064, 3072, padding_idx=32000)
(embed_dropout): Dropout(p=0.0, inplace=False)
(layers): ModuleList(
(0-31): 32 x Phi3DecoderLayer(
(self_attn): Phi3Attention(
(o_proj): Linear(in_features=3072, out_features=3072, bias=False)
(qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
(rotary_emb): Phi3RotaryEmbedding()
)
(mlp): Phi3MLP(
(gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
(down_proj): Linear(in_features=8192, out_features=3072, bias=False)
(activation_fn): SiLU()
)
(input_layernorm): Phi3RMSNorm()
(resid_attn_dropout): Dropout(p=0.0, inplace=False)
(resid_mlp_dropout): Dropout(p=0.0, inplace=False)
(post_attention_layernorm): Phi3RMSNorm()
)
)
(norm): Phi3RMSNorm()
)
(lm_head): Linear(in_features=3072, out_features=32064, bias=False)
)
----------------- After Pruning -----------------
Token indices sequence length is longer than the specified maximum sequence length for this model (2824490 > 4096). Running this sequence through the model will result in indexing errors
Phi3ForCausalLM(
(model): Phi3Model(
(embed_tokens): Embedding(32064, 1536, padding_idx=32000)
(embed_dropout): Dropout(p=0.0, inplace=False)
(layers): ModuleList(
(0-31): 32 x Phi3DecoderLayer(
(self_attn): Phi3Attention(
(o_proj): Linear(in_features=1536, out_features=1536, bias=False)
(qkv_proj): Linear(in_features=1536, out_features=4608, bias=False)
(rotary_emb): Phi3RotaryEmbedding()
)
(mlp): Phi3MLP(
(gate_up_proj): Linear(in_features=1536, out_features=8192, bias=False)
(down_proj): Linear(in_features=4096, out_features=1536, bias=False)
(activation_fn): SiLU()
)
(input_layernorm): Phi3RMSNorm()
(resid_attn_dropout): Dropout(p=0.0, inplace=False)
(resid_mlp_dropout): Dropout(p=0.0, inplace=False)
(post_attention_layernorm): Phi3RMSNorm()
)
)
(norm): Phi3RMSNorm()
)
(lm_head): Linear(in_features=1536, out_features=32064, bias=False)
)
Phi3Config {
"_name_or_path": "microsoft/Phi-3-mini-4k-instruct",
"architectures": [
"Phi3ForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"auto_map": {
"AutoConfig": "microsoft/Phi-3-mini-4k-instruct--configuration_phi3.Phi3Config",
"AutoModelForCausalLM": "microsoft/Phi-3-mini-4k-instruct--modeling_phi3.Phi3ForCausalLM"
},
"bos_token_id": 1,
"embd_pdrop": 0.0,
"eos_token_id": 32000,
"hidden_act": "silu",
"hidden_size": 1536,
"initializer_range": 0.02,
"intermediate_size": 8192,
"max_position_embeddings": 4096,
"model_type": "phi3",
"num_attention_heads": 16,
"num_hidden_layers": 32,
"num_key_value_heads": 16,
"original_max_position_embeddings": 4096,
"pad_token_id": 32000,
"resid_pdrop": 0.0,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"rope_theta": 10000.0,
"sliding_window": 2047,
"tie_word_embeddings": false,
"torch_dtype": "float16",
"transformers_version": "4.36.2",
"use_cache": true,
"vocab_size": 32064
}
num_params 1004570112
evaluating on wikitext2
nsamples 83
sample 0
sample 50
wikitext perplexity 92795.3984375
```

</details>
2 changes: 0 additions & 2 deletions examples/torchvision_models/torchvision_global_pruning.py
@@ -16,8 +16,6 @@
)
from torchvision.models.detection.fcos import fcos_resnet50_fpn
from torchvision.models.detection.keypoint_rcnn import keypointrcnn_resnet50_fpn
from torchvision.models.detection.mask_rcnn import maskrcnn_resnet50_fpn_v2
from torchvision.models.detection.retinanet import retinanet_resnet50_fpn_v2
from torchvision.models.alexnet import alexnet

from torchvision.models.vision_transformer import (
2 changes: 0 additions & 2 deletions examples/torchvision_models/torchvision_pruning.py
@@ -16,8 +16,6 @@
)
from torchvision.models.detection.fcos import fcos_resnet50_fpn
from torchvision.models.detection.keypoint_rcnn import keypointrcnn_resnet50_fpn
from torchvision.models.detection.mask_rcnn import maskrcnn_resnet50_fpn_v2
from torchvision.models.detection.retinanet import retinanet_resnet50_fpn_v2
from torchvision.models.alexnet import alexnet

from torchvision.models.vision_transformer import (
5 changes: 4 additions & 1 deletion examples/transformers/prune_timm_vit.py
@@ -151,7 +151,7 @@ def main():
prune_num_heads=args.prune_num_heads, # reduce num_heads by pruning entire heads (default: False)
prune_head_dims=not args.prune_num_heads, # reduce head_dim by pruning feature dims of each head (default: True)
head_pruning_ratio=0.5, #args.head_pruning_ratio, # remove 50% heads, only works when prune_num_heads=True (default: 0.0)
round_to=2
round_to=1
)

if isinstance(imp, (tp.importance.GroupTaylorImportance, tp.importance.GroupHessianImportance)):
@@ -206,6 +206,9 @@ def main():
print("Base Loss: %.4f, Pruned Loss: %.4f"%(loss_ori, loss_pruned))
print("Base Accuracy: %.4f, Pruned Accuracy: %.4f"%(acc_ori, acc_pruned))

latency_mean, latency_std = tp.utils.benchmark.measure_latency(model, example_inputs=torch.randn(16,3,224,224).to(device), repeat=300)
print("Latency: %.4f ms, Std: %.4f ms"%(latency_mean, latency_std))

if args.save_as is not None:
print("Saving the pruned model to %s..."%args.save_as)
os.makedirs(os.path.dirname(args.save_as), exist_ok=True)
16 changes: 16 additions & 0 deletions torch_pruning/_helpers.py
@@ -96,6 +96,22 @@ def __call__(self, idxs: _HybridIndex):

return new_idxs

class _SliceIndexMapping(object):
def __init__(self, dim, start, step, end, reverse=False):
self.start = start
self.step = step
self.end = end
self.reverse = reverse
self.dim = dim

def __call__(self, idxs: _HybridIndex):

if self.reverse:
new_idxs = [ _HybridIndex(idx=i.idx * self.step + self.start, root_idx=i.root_idx) for i in idxs]
else:
new_idxs = [ _HybridIndex(idx=(i.idx - self.start) // self.step, root_idx=i.root_idx) for i in idxs if (i.idx >= self.start and i.idx < self.end and (i.idx-self.start)%self.step==0) ]
return new_idxs
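The two modes are meant to be inverses over the indices that survive the slice: with reverse=False an index in the original tensor is translated into the sliced tensor's coordinate system (indices the slice drops are discarded), and with reverse=True a sliced index is mapped back to its original position. A small round trip, assuming this class and _HybridIndex are importable from torch_pruning._helpers:

```python
# Illustrative only: model a slice x[..., 2:10:2] on the channel dim (kept channels 2, 4, 6, 8).
from torch_pruning._helpers import _HybridIndex, _SliceIndexMapping

fwd = _SliceIndexMapping(dim=1, start=2, step=2, end=10, reverse=False)  # original -> sliced
rev = _SliceIndexMapping(dim=1, start=2, step=2, end=10, reverse=True)   # sliced -> original

idxs = [_HybridIndex(idx=i, root_idx=i) for i in range(12)]
sliced = fwd(idxs)                  # kept channels become sliced indices 0, 1, 2, 3
restored = rev(sliced)              # ... and map back to original channels 2, 4, 6, 8
print([i.idx for i in sliced], [i.idx for i in restored])
```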

class _SplitIndexMapping(object):
def __init__(self, offset, reverse=False):
self.offset = offset
52 changes: 47 additions & 5 deletions torch_pruning/dependency.py
@@ -294,6 +294,7 @@ def __init__(self):
ops.OPTYPE.UNBIND: ops.UnbindPruner(),
ops.OPTYPE.EXPAND: ops.ExpandPruner(),
ops.OPTYPE.CUSTOMIZED: ops.CustomizedPruner(), # just a placeholder
ops.OPTYPE.SLICE: ops.SlicePruner(),
}
self.REGISTERED_PRUNERS = function.PrunerBox.copy() # shallow copy
self.REGISTERED_PRUNERS.update(_dummy_pruners) # merge dummy pruners
@@ -511,7 +512,7 @@ def _fix_dependency_graph_non_recursive(dep, idxs, *args):
)

_fix_dependency_graph_non_recursive(*group[0])

# merge pruning ops
merged_group = Group() # craft a new group for merging
for dep, idxs in group.items:
@@ -827,6 +828,7 @@ def create_node_if_not_exists(grad_fn):

# 1. link grad_fns and modules
if module is None: # a new module

if not hasattr(grad_fn, "name"):
# we treat all unknown modules as element-wise operations by default,
# which does not modify the #dimension/#channel of features.
@@ -853,6 +855,12 @@ def create_node_if_not_exists(grad_fn):
elif "view" in grad_fn.name().lower() or 'reshape' in grad_fn.name().lower():
module = ops._ReshapeOp(self._op_id)
self._op_id+=1
elif "slice" in grad_fn.name().lower() and "copyslices" not in grad_fn.name().lower():
if hasattr(grad_fn, '_saved_start') and hasattr(grad_fn, '_saved_end') and hasattr(grad_fn, '_saved_step') and hasattr(grad_fn, '_saved_dim'):
module = ops._SliceOp(self._op_id, grad_fn)
else: # for old versions of PyTorch, we cannot handle the slice operation
module = ops._ElementWiseOp(self._op_id, grad_fn.name())
self._op_id+=1
else:
# treat other ops as element-wise ones, like Add, Sub, Div, Mul.
module = ops._ElementWiseOp(self._op_id, grad_fn.name())
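The branch above keys off grad_fn.name() and falls back to element-wise handling when the saved slice parameters are unavailable. A throwaway probe (illustrative only; the _saved_* attributes are exposed only on reasonably recent PyTorch builds) shows what autograd records for a slice:

```python
import torch

x = torch.randn(4, 16, requires_grad=True)
y = x[:, 3:11:2]                        # slice along dim 1 with a step of 2
fn = y.grad_fn
print(fn.name())                        # e.g. 'SliceBackward0'
for attr in ("_saved_dim", "_saved_start", "_saved_end", "_saved_step"):
    print(attr, getattr(fn, attr, None))    # None on versions that do not expose them
```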
@@ -924,6 +932,32 @@ def update_index_mapping(self):
self._update_unbind_index_mapping(node)
if node.type == ops.OPTYPE.EXPAND and torch.__version__ >= "1.8":
self._update_expand_index_mapping(node)
if node.type == ops.OPTYPE.SLICE:
self._update_slice_index_mapping(node)


def _update_slice_index_mapping(self, slice_node: Node):
if slice_node.type != ops.OPTYPE.SLICE:
return
grad_fn = slice_node.grad_fn
if hasattr(grad_fn, '_saved_self_sym_sizes'):
if len(grad_fn._saved_self_sym_sizes)==4 and grad_fn._saved_dim != 1:
return
elif len(grad_fn._saved_self_sym_sizes)==3 and grad_fn._saved_dim != 2:
return

start, step, end, dim = slice_node.module.start, slice_node.module.step, slice_node.module.end, slice_node.module.dim
for node in slice_node.inputs:
for dep in slice_node.dependencies:
if dep.target == node:
dep.index_mapping[0] = _helpers._SliceIndexMapping(
dim=dim, start=start, end=end, step=step, reverse=True
)
for dep in node.dependencies:
if dep.target == slice_node:
dep.index_mapping[0] = _helpers._SliceIndexMapping(
dim=dim, start=start, end=end, step=step, reverse=False
)

def _init_shape_information(self):
for module, node in self.module2node.items():
@@ -1111,10 +1145,18 @@ def _update_concat_index_mapping(self, cat_node: Node):
def _update_split_index_mapping(self, split_node: Node):
if split_node.type != ops.OPTYPE.SPLIT:
return

if hasattr(split_node.grad_fn, '_saved_dim') and split_node.grad_fn._saved_dim != 1: # this only works for Pytorch>=1.12
return


if hasattr(split_node.grad_fn, '_saved_dim'): # this only works for Pytorch>=1.12

# There is an issue in some PyTorch versions where _saved_dim holds an uninitialized value like 118745347895359,
# so we check that _saved_dim is valid (< len(_saved_self_sym_sizes), or below a nominal bound like 20).
if hasattr(split_node.grad_fn, '_saved_self_sym_sizes'):
if split_node.grad_fn._saved_dim<len(split_node.grad_fn._saved_self_sym_sizes) and split_node.grad_fn._saved_dim != 1:
return
else:
THRESHOLD = 20
if split_node.grad_fn._saved_dim<THRESHOLD and split_node.grad_fn._saved_dim>=0 and split_node.grad_fn._saved_dim != 1:
return
offsets = split_node.module.offsets

if offsets is None:
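The _saved_dim validity check above can be read as a small predicate: only skip the split index mapping when the recorded dimension is trustworthy and is not the channel dimension. A hypothetical standalone version, for illustration only (not part of the library API):

```python
def _should_update_split_mapping(grad_fn, channel_dim=1, nominal_bound=20):
    """Return False only when _saved_dim is trustworthy and differs from the channel dim."""
    dim = getattr(grad_fn, "_saved_dim", None)
    if dim is None:                     # PyTorch < 1.12 does not expose the saved dim
        return True
    sizes = getattr(grad_fn, "_saved_self_sym_sizes", None)
    if sizes is not None:               # a trustworthy dim must index into the recorded shape
        return not (dim < len(sizes) and dim != channel_dim)
    # without the shape, treat implausibly large values as uninitialized garbage
    return not (0 <= dim < nominal_bound and dim != channel_dim)
```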