Hi! I tried to use Torch-Pruning to prune CogVideoX, a transformer-based video generation model, but the script hangs when initializing the pruner:
```python
pruner = tp.pruner.MetaPruner(
    cogvideox,
    example_inputs,
    importance=imp,
    global_pruning=False,
    num_heads=num_heads,
    pruning_ratio=args.pruning_ratio,
    ignored_layers=ignored_layers,
    prune_num_heads=True,
    prune_head_dims=False,
    head_pruning_ratio=args.pruning_ratio,
    round_to=4,
)
```
Specifically, there is an infinite loop in the `_fix_dependency_graph_non_recursive` function.
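To confirm which loop the process is spinning in, and to give the maintainers a concrete stack trace, Python's standard `faulthandler` module can periodically dump every thread's stack while a slow call runs. This is a minimal sketch; the helper name `dump_stacks_if_slow` is hypothetical and not part of Torch-Pruning.

```python
import faulthandler
import sys
from contextlib import contextmanager
from typing import Iterator, Optional, TextIO


@contextmanager
def dump_stacks_if_slow(timeout_s: float, file: Optional[TextIO] = None) -> Iterator[None]:
    """Dump all Python thread stacks every `timeout_s` seconds until the body finishes.

    faulthandler writes to a raw file descriptor, so the default target is the
    process's original stderr (sys.__stderr__) rather than a possibly
    redirected sys.stderr.
    """
    out = file if file is not None else sys.__stderr__
    faulthandler.dump_traceback_later(timeout_s, repeat=True, file=out)
    try:
        yield
    finally:
        # Disarm the timer whether the body returned normally or raised.
        faulthandler.cancel_dump_traceback_later()


with dump_stacks_if_slow(60):
    total = sum(range(10))  # finishes instantly, so nothing is dumped
print(total)
```

Wrapping the `tp.pruner.MetaPruner(...)` call in `with dump_stacks_if_slow(120):` would print a traceback every two minutes while it hangs, showing which lines inside `_fix_dependency_graph_non_recursive` keep executing.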
This is the script I used:
```python
import argparse
import os
from glob import glob

import accelerate
import torch
import torch_pruning as tp
import torchvision
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from transformers import T5EncoderModel, AutoTokenizer

import utils
from models.CogVideoX import CogVideoXTransformer3DModel

pretrained_text_encoder_path = 'xxx'
pretrained_tokenizer_path = 'xxx'

parser = argparse.ArgumentParser()
#parser.add_argument("--dataset", type=str, required=True)
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--save_path", type=str, required=True)
parser.add_argument("--pruning_ratio", type=float, default=0.3)
parser.add_argument("--batch_size", type=int, default=128)
# choices now match the branches of the importance selection below
parser.add_argument("--pruning_type", type=str, default='random', choices=['random', 'l1', 'l2', 'taylor', 'hessian'])
args = parser.parse_args()
batch_size = args.batch_size

if __name__ == '__main__':
    torch_device = 'cuda' if torch.cuda.is_available() else "cpu"
    text_encoder = T5EncoderModel.from_pretrained(pretrained_text_encoder_path, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(pretrained_tokenizer_path, use_fast=False)

    # load model
    print(f"############ model initializing ##########")
    # config = CogVideoXTransformer3DModel.load_config(f'{args.model_path}/config.json')
    # print(config)
    # cogvideox = CogVideoXTransformer3DModel.from_config(config)
    cogvideox = CogVideoXTransformer3DModel.from_pretrained(args.model_path)
    cogvideox = cogvideox.to(dtype=torch.bfloat16)
    # m, u = cogvideox.load_state_dict(torch.load(cfg.pretrained_base_path, map_location="cpu", weights_only=True), strict=False)
    # print('Missing keys:', m)
    # print('Unexpected keys:', u)

    # set to cuda
    text_encoder.to(torch_device)
    cogvideox.to(torch_device)

    # prepare dummy inputs
    height = 960 // 2
    width = 720 // 2
    num_frames = 49
    timestep = torch.ones((1,)).long().to(torch_device)
    with torch.no_grad():
        image_rotary_emb = cogvideox._prepare_rotary_positional_embeddings(
            height // 8 // cogvideox.config.patch_size,
            width // 8 // cogvideox.config.patch_size,
            num_frames // 4 + 1,
            torch_device,
        )
    text_tokens_and_mask = tokenizer(
        '',
        max_length=226,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    input_ids = text_tokens_and_mask["input_ids"].to(torch_device)
    attention_mask = text_tokens_and_mask["attention_mask"].to(torch_device)
    prompt_embeds = text_encoder(
        input_ids=input_ids,
        attention_mask=attention_mask,
    )["last_hidden_state"].detach()
    prompt_embeds = prompt_embeds[:, None].squeeze(1)

    example_inputs = {
        'hidden_states': torch.randn(1, num_frames // 4 + 1, 33, height // 8, width // 8).to(torch_device, dtype=torch.bfloat16),
        'timestep': timestep,
        'encoder_hidden_states': prompt_embeds,
        'image_rotary_emb': image_rotary_emb,
        'guidance_scale': torch.full([1], 1, device=torch_device, dtype=torch.float32),  # It's a CFG distillation version of CogVideoX
    }

    # Select importance criterion
    if args.pruning_type == 'random':
        imp = tp.importance.RandomImportance()
    elif args.pruning_type == 'l1':
        imp = tp.importance.GroupNormImportance(p=1)
    elif args.pruning_type == 'l2':
        imp = tp.importance.GroupNormImportance(p=2)
    elif args.pruning_type == 'taylor':
        imp = tp.importance.GroupTaylorImportance()
    elif args.pruning_type == 'hessian':
        imp = tp.importance.GroupHessianImportance()
    else:
        raise NotImplementedError(f"Pruning type {args.pruning_type} not implemented")

    # Get baseline metrics
    base_macs, base_params = tp.utils.count_ops_and_params(cogvideox, example_inputs)
    print("\n============ Before Pruning ============")
    print(f"Parameters: {base_params/1e6:.2f}M")
    print(f"MACs: {base_macs/1e9:.2f}G")

    cogvideox.zero_grad()
    cogvideox.eval()
    ignored_layers = [cogvideox.proj_out]

    from diffusers.models.attention import Attention
    num_heads = {}
    for m in cogvideox.modules():
        if isinstance(m, Attention):
            num_heads[m.to_q] = m.heads
            num_heads[m.to_k] = m.heads
            num_heads[m.to_v] = m.heads

    # Initialize pruner: this is where it hangs
    pruner = tp.pruner.MetaPruner(
        cogvideox,
        example_inputs,
        importance=imp,
        global_pruning=False,
        num_heads=num_heads,
        pruning_ratio=args.pruning_ratio,
        ignored_layers=ignored_layers,
        prune_num_heads=True,
        prune_head_dims=False,
        head_pruning_ratio=args.pruning_ratio,
        round_to=4,
    )
    for g in pruner.step(interactive=True):
        g.prune()
```
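One cheap script-side check before constructing the pruner is that every entry in the `num_heads` mapping can actually be split into that many heads. The helper below is hypothetical: it does not reproduce any Torch-Pruning validation, and it is not known to be related to the hang. It is duck-typed so it runs without torch; `FakeLinear` stands in for `torch.nn.Linear`, and the layer sizes are illustrative only.

```python
from typing import Any, Dict, List


def check_num_heads(num_heads: Dict[Any, int]) -> List[str]:
    """Report entries of a {linear_layer: heads} mapping that cannot be split into heads.

    Each key only needs an `out_features` attribute, which torch.nn.Linear has.
    """
    problems: List[str] = []
    for layer, heads in num_heads.items():
        out_features = getattr(layer, "out_features", None)
        if out_features is None:
            problems.append(f"{type(layer).__name__}: no out_features attribute")
        elif heads <= 0 or out_features % heads != 0:
            problems.append(
                f"{type(layer).__name__}: out_features={out_features} "
                f"is not divisible by heads={heads}"
            )
    return problems


class FakeLinear:
    """Stand-in for torch.nn.Linear in this self-contained example."""

    def __init__(self, out_features: int) -> None:
        self.out_features = out_features


ok = {FakeLinear(3072): 48, FakeLinear(1920): 30}
bad = {FakeLinear(1000): 48}
print(check_num_heads(ok))   # []
print(check_num_heads(bad))  # ['FakeLinear: out_features=1000 is not divisible by heads=48']
```

In the script, `check_num_heads(num_heads)` could be called just before `tp.pruner.MetaPruner(...)`. An empty list only rules out this one misconfiguration. It says nothing about whether the dependency-graph pass will terminate.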
Does Torch-Pruning currently support pruning CogVideoX, or is there something wrong with my code?
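For context, here is a generic sketch of how a non-recursive, queue-based pass over a dependency graph stays finite. It is not Torch-Pruning's code, whose internals are not reproduced here. Each node is expanded at most once, tracked in a `visited` set. If a pass can re-enqueue a node it has already processed, any cycle in the graph keeps the queue non-empty forever, which is one common way such a function hangs. The function `propagate_non_recursive` and the toy graph are hypothetical.

```python
from collections import deque
from typing import Dict, Hashable, Iterable, Set


def propagate_non_recursive(
    graph: Dict[Hashable, Iterable[Hashable]], roots: Iterable[Hashable]
) -> Set[Hashable]:
    """Collect every node reachable from `roots` using an explicit worklist.

    The `visited` set guarantees termination even when `graph` has cycles.
    """
    visited: Set[Hashable] = set()
    queue = deque(roots)
    while queue:
        node = queue.popleft()
        if node in visited:
            continue  # already expanded; expanding it again is how loops start
        visited.add(node)
        for child in graph.get(node, ()):
            if child not in visited:
                queue.append(child)
    return visited


# Toy graph with a cycle a -> b -> c -> a, plus an exit edge c -> d.
toy = {"a": ["b"], "b": ["c"], "c": ["a", "d"]}
print(sorted(propagate_non_recursive(toy, ["a"])))  # ['a', 'b', 'c', 'd']
```

Dropping the `visited` checks from this loop would make it run forever on `toy`, because `a` would keep being re-added through `c`.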