-
Notifications
You must be signed in to change notification settings - Fork 341
/
Copy pathprune_llm.py
390 lines (325 loc) · 14.4 KB
/
prune_llm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
# Code adapted from
# https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py
# https://github.com/locuslab/wanda
import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))))
import argparse
import os
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from importlib.metadata import version
import time
import torch
import torch.nn as nn
from collections import defaultdict
import fnmatch
import numpy as np
import random
from datasets import load_dataset
# Set seeds for reproducibility
def set_seed(seed):
    """Seed every random number generator this script draws from.

    Seeds Python's built-in ``random`` module (used by the calibration-data
    loaders to pick window offsets via ``random.randint``), NumPy's global
    generator, and PyTorch's default generator.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)
# Minimal stand-in for a tokenizer output that only exposes `.input_ids`
class TokenizerWrapper:
    """Hold a tensor of token ids under the attribute name ``input_ids``.

    ``eval_ppl_wikitext`` reads ``testenc.input_ids``; wrapping an already
    sliced id tensor in this class lets such data be passed where a
    tokenizer output is expected (``get_c4`` does this for its validation
    split).
    """

    def __init__(self, input_ids):
        # Stored by reference: no copy, no validation.
        self.input_ids = input_ids
# Load and process the wikitext-2 dataset
def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    """Build wikitext-2 calibration windows and the tokenized test split.

    Args:
        nsamples: Number of random training windows to draw.
        seed: Seed for Python's ``random`` module, which picks window offsets.
        seqlen: Length, in tokens, of each training window.
        tokenizer: Tokenizer called with ``return_tensors='pt'``.

    Returns:
        ``(trainloader, testenc)``. ``trainloader`` is a list of
        ``(inp, tar)`` pairs. ``inp`` is a ``[:, start:start + seqlen]``
        slice of the encoded training text. ``tar`` is a copy of ``inp``
        with every position except the last set to -100. ``testenc`` is the
        tokenizer output for the test split joined with blank lines.
    """
    train_split = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    test_split = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    # Training text is joined with single spaces, test text with "\n\n".
    train_ids = tokenizer(" ".join(train_split['text']), return_tensors='pt').input_ids
    testenc = tokenizer("\n\n".join(test_split['text']), return_tensors='pt')

    random.seed(seed)
    # Largest start offset that still leaves room for a full window.
    max_start = train_ids.shape[1] - seqlen - 1
    trainloader = []
    for _ in range(nsamples):
        start = random.randint(0, max_start)
        window = train_ids[:, start:start + seqlen]
        target = window.clone()
        # -100 matches nn.CrossEntropyLoss's default ignore_index, so only
        # the final position would contribute if these targets were used.
        target[:, :-1] = -100
        trainloader.append((window, target))
    return trainloader, testenc
# Load and process the c4 dataset
def get_c4(nsamples, seed, seqlen, tokenizer):
    """Build c4 calibration windows and a truncated validation stream.

    Args:
        nsamples: Number of random training windows to draw.
        seed: Seed for Python's ``random`` module (document and offset choice).
        seqlen: Length, in tokens, of each training window.
        tokenizer: Tokenizer called with ``return_tensors='pt'``.

    Returns:
        ``(trainloader, valenc)``.
        ``trainloader`` is a list of ``(inp, tar)`` pairs. Each ``inp`` is a
        ``seqlen``-token slice of a randomly chosen training document that
        is longer than ``seqlen`` tokens. ``tar`` is ``inp`` with all but
        the last position set to -100.
        ``valenc`` is a ``TokenizerWrapper`` around the encoding of the
        first 1100 validation documents joined by spaces, truncated to
        ``256 * seqlen`` tokens.
    """
    # NOTE(review): newer `datasets` releases may reject the 'allenai--c4'
    # config name; if loading fails, try dropping that positional argument.
    traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
    valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')

    def _draw_long_document():
        """Resample training documents until one has more than seqlen tokens."""
        while True:
            doc_idx = random.randint(0, len(traindata) - 1)
            encoded = tokenizer(traindata[doc_idx]['text'], return_tensors='pt')
            if encoded.input_ids.shape[1] > seqlen:
                return encoded.input_ids

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        doc_ids = _draw_long_document()
        start = random.randint(0, doc_ids.shape[1] - seqlen - 1)
        window = doc_ids[:, start:start + seqlen]
        target = window.clone()
        target[:, :-1] = -100
        trainloader.append((window, target))

    # Validation stream: first 1100 documents, capped at 256 windows of seqlen.
    val_text = ' '.join(valdata[:1100]['text'])
    val_ids = tokenizer(val_text, return_tensors='pt').input_ids[:, :(256 * seqlen)]
    return trainloader, TokenizerWrapper(val_ids)
# Select the appropriate loader based on the dataset name
def get_loaders(name, nsamples=128, seed=0, seqlen=2048, tokenizer=None):
    """Dispatch to the dataset-specific loader matching ``name``.

    Matching is by substring, checked in order. A name containing
    ``'wikitext2'`` uses ``get_wikitext2``. Otherwise a name containing
    ``'c4'`` uses ``get_c4``.

    Returns:
        Whatever the selected loader returns: a ``(trainloader, eval_data)``
        pair.

    Raises:
        ValueError: if ``name`` matches neither dataset. Returning ``None``
            instead would make callers that unpack the result fail with an
            unhelpful ``TypeError``.
    """
    if 'wikitext2' in name:
        return get_wikitext2(nsamples, seed, seqlen, tokenizer)
    if "c4" in name:
        return get_c4(nsamples, seed, seqlen, tokenizer)
    raise ValueError(
        f"Unknown dataset {name!r}; expected a name containing 'wikitext2' or 'c4'."
    )
# Evaluate perplexity (ppl) of a model on the wikitext-2 test split
def eval_ppl(args, model, tokenizer, device=torch.device("cuda:0")):
    """Return the wikitext-2 test perplexity of ``model``.

    Args:
        args: Parsed CLI arguments. Not read here; kept for interface
            compatibility.
        model: Causal LM carrying a ``seqlen`` attribute (set by ``get_llm``).
        tokenizer: Tokenizer used to encode the test split.
        device: Device the evaluation batches are moved to.

    Returns:
        Perplexity as a Python float, computed with batch size 1.
    """
    dataset = "wikitext2"
    print(f"evaluating on {dataset}")
    testloader = get_loaders(dataset, seed=0, seqlen=model.seqlen, tokenizer=tokenizer)[1]
    # Inference only: no_grad skips building the autograd graph.
    with torch.no_grad():
        return eval_ppl_wikitext(model, testloader, 1, device)
# Evaluate perplexity (ppl) over a list of pre-sampled calibration windows
def eval_ppl_wikitext_train(model, trainloader, bs=1, device=None):
    """Compute perplexity over the input windows of a calibration loader.

    Args:
        model: Causal LM whose call result has ``.logits`` indexed as
            ``[batch, position, vocab]``. It must expose ``model.seqlen``.
        trainloader: Sequence of ``(inp, tar)`` pairs such as those built by
            ``get_wikitext2`` / ``get_c4``. Only ``inp`` is used, and each is
            reshaped to ``(1, model.seqlen)`` per window.
        bs: Number of windows evaluated per forward pass.
        device: Device the inputs are moved to.

    Returns:
        Perplexity as a Python float, weighted per window exactly as in
        ``eval_ppl_wikitext``.

    Raises:
        ValueError: if ``trainloader`` is empty.
    """
    nsamples = len(trainloader)
    if nsamples == 0:
        raise ValueError("trainloader is empty; cannot compute perplexity")
    loss_fct = nn.CrossEntropyLoss()
    nlls = []
    print(f"nsamples {nsamples}")
    for i in range(0, nsamples, bs):
        if i % 50 == 0:
            print(f"sample {i}")
        j = min(i + bs, nsamples)
        # Gather every window of this batch. Taking only trainloader[i] would
        # make the reshape below fail whenever bs > 1.
        inputs = torch.cat([trainloader[k][0] for k in range(i, j)], dim=0).to(device)
        inputs = inputs.reshape(j - i, model.seqlen)
        lm_logits = model(inputs).logits
        # Position t predicts token t + 1: drop the last logit and first label.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = inputs[:, 1:]
        loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
        # loss is a token mean; scale back to a batch NLL (by seqlen, not
        # seqlen - 1, matching eval_ppl_wikitext).
        nlls.append(loss.float() * model.seqlen * (j - i))
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
    # Release cached CUDA memory held by the forward passes.
    torch.cuda.empty_cache()
    return ppl.item()
# Evaluate perplexity (ppl) on a tokenized wikitext test stream
def eval_ppl_wikitext(model, testenc, bs=1, device=None):
    """Compute perplexity over consecutive ``model.seqlen``-token windows.

    The id tensor ``testenc.input_ids`` is cut into
    ``numel // model.seqlen`` non-overlapping windows. Any trailing
    remainder is dropped. Windows are evaluated ``bs`` at a time.

    Args:
        model: Causal LM whose call result has ``.logits`` indexed as
            ``[batch, position, vocab]``. It must expose ``model.seqlen``.
        testenc: Object with an ``input_ids`` tensor sliced as ``[:, a:b]``,
            e.g. a tokenizer output or ``TokenizerWrapper``.
        bs: Number of windows per forward pass.
        device: Device each batch is moved to.

    Returns:
        Perplexity as a Python float.
    """
    token_ids = testenc.input_ids
    seqlen = model.seqlen
    n_windows = token_ids.numel() // seqlen
    print(f"nsamples {n_windows}")

    criterion = nn.CrossEntropyLoss()
    batch_nlls = []
    for start in range(0, n_windows, bs):
        if start % 50 == 0:
            print(f"sample {start}")
        stop = min(start + bs, n_windows)
        n_batch = stop - start
        batch = token_ids[:, start * seqlen:stop * seqlen].to(device).reshape(n_batch, seqlen)
        logits = model(batch).logits
        # Position t predicts token t + 1: drop the last logit and first label.
        pred = logits[:, :-1, :].contiguous()
        labels = batch[:, 1:]
        mean_loss = criterion(pred.reshape(-1, pred.size(-1)), labels.reshape(-1))
        # Token-mean loss scaled back to a per-batch NLL (weighted by seqlen).
        batch_nlls.append(mean_loss.float() * seqlen * n_batch)

    ppl = torch.exp(torch.stack(batch_nlls).sum() / (n_windows * seqlen))
    # Release cached CUDA memory held by the forward passes.
    torch.cuda.empty_cache()
    return ppl.item()
def eval_zero_shot(model_name, model, tokenizer, task_list=("boolq", "rte", "hellaswag", "winogrande", "arc_challenge", "arc_easy", "openbookqa"),
                   num_fewshot=0, use_accelerate=False, add_special_tokens=False):
    """Run lm-eval tasks on an already-loaded model.

    Args:
        model_name: Model identifier. It is embedded in lm-eval's
            ``model_args``; if it contains "70b" or "65b", each task is
            capped at 2000 examples.
        model: Loaded model, passed through as ``pretrained_model``.
        tokenizer: Tokenizer, passed through to lm-eval.
        task_list: Iterable of fnmatch-style patterns matched against
            ``tasks.ALL_TASKS``. A tuple default avoids a shared mutable
            default; any iterable of strings is accepted.
        num_fewshot: Number of in-context examples per prompt.
        use_accelerate: Append ``use_accelerate=True`` to ``model_args``.
        add_special_tokens: Forwarded to ``evaluator.simple_evaluate``.

    Returns:
        The object returned by ``evaluator.simple_evaluate``.

    NOTE(review): the ``hf-causal-experimental`` model type and the
    ``pretrained_model`` / ``add_special_tokens`` keywords suggest a pinned
    (possibly forked) lm-eval version. Confirm before upgrading lm-eval.
    """
    from lm_eval import tasks, evaluator

    def pattern_match(patterns, source_list):
        """Return the sorted, de-duplicated entries of source_list matching any pattern."""
        task_names = set()
        for pattern in patterns:
            task_names.update(fnmatch.filter(source_list, pattern))
        # Sorted so task order does not depend on string-hash randomization.
        return sorted(task_names)

    task_names = pattern_match(task_list, tasks.ALL_TASKS)
    model_args = f"pretrained={model_name}, cache_dir=./cache"
    limit = None
    if "70b" in model_name or "65b" in model_name:
        limit = 2000
    if use_accelerate:
        model_args = f"pretrained={model_name}, cache_dir=./cache, use_accelerate=True"
    results = evaluator.simple_evaluate(
        model="hf-causal-experimental",
        model_args=model_args,
        tasks=task_names,
        num_fewshot=num_fewshot,
        batch_size=None,
        device=None,
        no_cache=True,
        limit=limit,
        description_dict={},
        decontamination_ngrams_path=None,
        check_integrity=False,
        pretrained_model=model,
        tokenizer=tokenizer,
        add_special_tokens=add_special_tokens,
    )
    return results
# Report library versions and the visible GPU count when the module loads.
for _pkg in ("torch", "transformers", "accelerate"):
    print(_pkg, version(_pkg))
del _pkg
print('# of gpus: ', torch.cuda.device_count())
def get_llm(model_name, max_seq_len=None):
    """Load ``model_name`` as an fp16 causal LM and attach a ``seqlen`` attribute.

    The model is loaded with ``low_cpu_mem_usage=True`` and
    ``device_map="auto"``. ``model.seqlen`` is the config's
    ``max_position_embeddings``, or ``min(max_seq_len, that)`` when
    ``max_seq_len`` is given. Capping it can help avoid OOM during
    perplexity evaluation.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto",
    )
    context_limit = model.config.max_position_embeddings
    if max_seq_len is None:
        model.seqlen = context_limit
    else:
        model.seqlen = min(max_seq_len, context_limit)
    return model
def _collect_pruning_structure(model):
    """Scan attention/MLP submodules for the head and channel-group info MetaPruner needs.

    Returns:
        ``(num_heads, out_channel_groups, separate_qkv)``:
        - ``num_heads`` maps each q/k/v (or fused qkv) projection to its
          head count from ``model.config``.
        - ``out_channel_groups`` maps each fused ``gate_up_proj`` to 2.
        - ``separate_qkv`` is True when attention modules expose ``q_proj``
          (value from the last attention module visited).
    """
    num_heads = {}
    out_channel_groups = {}
    separate_qkv = False
    for name, m in model.named_modules():
        if name.endswith("self_attn"):
            if hasattr(m, "q_proj"):
                separate_qkv = True
                num_heads[m.q_proj] = model.config.num_attention_heads
                num_heads[m.k_proj] = model.config.num_key_value_heads
                num_heads[m.v_proj] = model.config.num_key_value_heads
            elif hasattr(m, "qkv_proj"):
                separate_qkv = False
                num_heads[m.qkv_proj] = model.config.num_attention_heads
        if name.endswith("mlp"):
            if hasattr(m, "gate_up_proj"):
                # Fused gate/up projection: declared as 2 output-channel groups
                # so both halves are presumably pruned consistently.
                out_channel_groups[m.gate_up_proj] = 2
    return num_heads, out_channel_groups, separate_qkv


def _sync_pruned_attributes(model, separate_qkv, is_gqa):
    """Rewrite config and module attributes so they match the pruned weight shapes.

    Args:
        model: The pruned model, modified in place.
        separate_qkv: Whether attention uses separate q/k/v projections.
        is_gqa: Whether the unpruned config had fewer kv heads than query heads.

    Raises:
        ValueError: if an ``mlp`` module has neither ``gate_proj`` nor ``gate_up_proj``.
    """
    model.config.hidden_size = model.lm_head.in_features
    for name, m in model.named_modules():
        if name.endswith("self_attn"):
            if separate_qkv:
                m.hidden_size = m.q_proj.out_features
            else:
                # Fused projection; assumes q, k and v each take a third of the outputs.
                m.hidden_size = m.qkv_proj.out_features // 3
            m.num_heads = m.hidden_size // m.head_dim
            model.config.num_attention_heads = m.num_heads
            if not is_gqa:
                m.num_key_value_heads = m.num_heads
            elif separate_qkv:
                # GQA: k/v projections were registered in num_heads as well, so
                # read the surviving kv-head count back from k_proj rather than
                # keeping the stale pre-pruning value (head_dim is unchanged
                # because prune_head_dims=False).
                m.num_key_value_heads = m.k_proj.out_features // m.head_dim
                model.config.num_key_value_heads = m.num_key_value_heads
            m.num_key_value_groups = m.num_heads // m.num_key_value_heads
        elif name.endswith("mlp"):
            if hasattr(m, "gate_proj"):
                m.hidden_size = m.gate_proj.in_features
                model.config.intermediate_size = m.gate_proj.out_features
            elif hasattr(m, "gate_up_proj"):
                m.hidden_size = m.gate_up_proj.in_features
                model.config.intermediate_size = m.gate_up_proj.out_features // 2
            else:
                raise ValueError("Unknown mlp layer")
    if not is_gqa:
        model.config.num_key_value_heads = model.config.num_attention_heads


def main():
    """Structurally prune a causal LM with torch_pruning and report wikitext-2 perplexity.

    Steps:
    1. Parse CLI args, then load the model and tokenizer.
    2. Prune attention heads and hidden/intermediate channels by
       ``--pruning_ratio``, using L2 group-norm importance.
    3. Sync config attributes with the pruned shapes.
    4. Evaluate perplexity and optionally save the pruned model.

    NOTE(review): ``--nsamples``, ``--save`` and ``--eval_zero_shot`` are
    parsed but not used here.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True, help='LLaMA model')
    parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
    parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.')
    parser.add_argument('--pruning_ratio', type=float, default=0, help='Sparsity level')
    parser.add_argument('--save', type=str, default=None, help='Path to save results.')
    parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')
    parser.add_argument("--eval_zero_shot", action="store_true")
    parser.add_argument("--max_seq_len", type=int, default=None)
    args = parser.parse_args()

    # Setting seeds for reproducibility
    set_seed(args.seed)

    print(f"loading llm model {args.model}")
    model = get_llm(args.model, max_seq_len=args.max_seq_len)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)

    device = torch.device("cuda:0")
    if "30b" in args.model or "65b" in args.model:
        # 30b/65b models are spread over several GPUs via device_map; evaluate
        # on the device holding lm_head.
        device = model.hf_device_map["lm_head"]
    print("use device ", device)

    ##############
    # Pruning
    ##############
    print("----------------- Before Pruning -----------------")
    print(model)
    # A tiny prompt is used as example input for MetaPruner.
    example_inputs = torch.tensor(tokenizer.encode("Hello world.")).unsqueeze(0).to(model.device)

    import torch_pruning as tp

    num_heads, out_channel_groups, separate_qkv = _collect_pruning_structure(model)
    # Must be computed before pruning, while the config still holds the original head counts.
    is_gqa = model.config.num_attention_heads != model.config.num_key_value_heads

    head_pruning_ratio = args.pruning_ratio
    hidden_size_pruning_ratio = args.pruning_ratio
    importance = tp.importance.GroupNormImportance(p=2, group_reduction='mean')
    pruner = tp.pruner.MetaPruner(
        model,
        example_inputs=example_inputs,
        importance=importance,
        global_pruning=False,
        pruning_ratio=hidden_size_pruning_ratio,
        ignored_layers=[model.lm_head],
        num_heads=num_heads,
        prune_num_heads=True,
        prune_head_dims=False,  # keep head_dim intact so RoPE needs no changes
        head_pruning_ratio=head_pruning_ratio,
        out_channel_groups=out_channel_groups,
        round_to=4,
    )
    for group in pruner.step(interactive=True):
        group.prune()

    _sync_pruned_attributes(model, separate_qkv, is_gqa)

    print("----------------- After Pruning -----------------")
    print(model)
    print(model.config)

    del pruner
    torch.cuda.empty_cache()
    model.eval()
    num_params = sum(p.numel() for p in model.parameters())
    print(f"num_params {num_params}")
    ppl_test = eval_ppl(args, model, tokenizer, device)
    print(f"wikitext perplexity {ppl_test}")
    if args.save_model:
        model.save_pretrained(args.save_model)
        tokenizer.save_pretrained(args.save_model)


if __name__ == '__main__':
    main()