Infinite loop when pruning CogVideoX #456

Open
ybWang0820 opened this issue Feb 8, 2025 · 1 comment

ybWang0820 commented Feb 8, 2025

Hi! I tried to use Torch-Pruning to prune CogVideoX, a transformer-based video generation model, but execution gets stuck when initializing the pruner:

pruner = tp.pruner.MetaPruner(
        cogvideox,
        example_inputs,
        importance=imp,
        global_pruning=False,
        num_heads=num_heads,
        pruning_ratio=args.pruning_ratio,
        ignored_layers=ignored_layers,
        prune_num_heads=True,
        prune_head_dims=False,
        head_pruning_ratio=args.pruning_ratio,
        round_to=4,
    )

Specifically, there is an infinite loop in the _fix_dependency_graph_non_recursive function.
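
To confirm where the process is stuck, one option is Python's standard-library faulthandler module, which can dump the stack of every thread after a timeout. The sketch below is only a debugging aid, not part of Torch-Pruning; the helper name dump_stack_if_stuck and the timeout values are hypothetical.

```python
import faulthandler
import sys
from contextlib import contextmanager


@contextmanager
def dump_stack_if_stuck(timeout_s=60.0, repeat=True, file=None):
    """Dump all thread stacks if the wrapped block runs longer than timeout_s.

    Hypothetical helper: it only reports where the code is, it does not
    interrupt it. `file` must be a real file object with a file descriptor.
    """
    target = sys.stderr if file is None else file
    # Arm a watchdog thread that writes the tracebacks to `target`.
    faulthandler.dump_traceback_later(timeout_s, repeat=repeat, file=target)
    try:
        yield
    finally:
        # Disarm the watchdog once the block finishes (or raises).
        faulthandler.cancel_dump_traceback_later()
```

Wrapping the tp.pruner.MetaPruner(...) call in `with dump_stack_if_stuck(60):` should print a traceback every 60 seconds while the constructor is running. If successive dumps keep showing the same frame inside _fix_dependency_graph_non_recursive, that supports the infinite-loop reading over a merely slow graph build.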

This is the script I used:

import torch_pruning as tp
import torch
import torchvision
from torchvision import transforms
from tqdm import tqdm
import os
from glob import glob
from PIL import Image
import accelerate
import utils
from transformers import T5EncoderModel, AutoTokenizer
from models.CogVideoX import CogVideoXTransformer3DModel

pretrained_text_encoder_path = 'xxx'
pretrained_tokenizer_path = 'xxx'

import argparse
parser = argparse.ArgumentParser()
#parser.add_argument("--dataset", type=str, required=True)
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--save_path", type=str, required=True)
parser.add_argument("--pruning_ratio", type=float, default=0.3)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--pruning_type", type=str, default='random', choices=['random', 'l1', 'l2', 'taylor', 'hessian'])

args = parser.parse_args()

batch_size = args.batch_size

if __name__=='__main__':
    torch_device = 'cuda' if torch.cuda.is_available() else "cpu"

    text_encoder = T5EncoderModel.from_pretrained(pretrained_text_encoder_path, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(pretrained_tokenizer_path, use_fast=False)

    # load model
    print("############ model initializing ##########")
    # config = CogVideoXTransformer3DModel.load_config(f'{args.model_path}/config.json')
    # print(config)
    # cogvideox = CogVideoXTransformer3DModel.from_config(config)
    cogvideox = CogVideoXTransformer3DModel.from_pretrained(args.model_path)
    cogvideox = cogvideox.to(dtype=torch.bfloat16)
    # m, u = cogvideox.load_state_dict(torch.load(cfg.pretrained_base_path, map_location="cpu", weights_only=True), strict=False)
    # print('Missing keys:', m)
    # print('Unexpected keys:', u)

    # set to cuda
    text_encoder.to(torch_device)
    cogvideox.to(torch_device)

    # prepare dummy inputs
    height = 960 // 2
    width = 720 // 2
    num_frames = 49
    timestep = torch.ones((1,)).long().to(torch_device)
    with torch.no_grad():
        image_rotary_emb = cogvideox._prepare_rotary_positional_embeddings(height // 8 // cogvideox.config.patch_size, width // 8 // cogvideox.config.patch_size, num_frames // 4 + 1, torch_device)

        text_tokens_and_mask = tokenizer(
            '',
            max_length=226,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        input_ids = text_tokens_and_mask["input_ids"].to(torch_device)
        attention_mask = text_tokens_and_mask["attention_mask"].to(torch_device)

        prompt_embeds = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )["last_hidden_state"].detach()
        prompt_embeds = prompt_embeds[:, None].squeeze(1)

    example_inputs = {
        'hidden_states': torch.randn(1, num_frames // 4 + 1, 33, height // 8, width // 8).to(torch_device, dtype=torch.bfloat16),
        'timestep': timestep,
        'encoder_hidden_states': prompt_embeds,
        'image_rotary_emb': image_rotary_emb,
        'guidance_scale': torch.full([1], 1, device=torch_device, dtype=torch.float32)   # It's a CFG distillation version of CogVideoX
    }

    # Select importance criterion
    if args.pruning_type == 'random':
        imp = tp.importance.RandomImportance()
    elif args.pruning_type == 'l1':
        imp = tp.importance.GroupNormImportance(p=1)
    elif args.pruning_type == 'l2':
        imp = tp.importance.GroupNormImportance(p=2)
    elif args.pruning_type == 'taylor':
        imp = tp.importance.GroupTaylorImportance()
    elif args.pruning_type == 'hessian':
        imp = tp.importance.GroupHessianImportance()
    else:
        raise NotImplementedError(f"Pruning type {args.pruning_type} not implemented")

    # Get baseline metrics
    base_macs, base_params = tp.utils.count_ops_and_params(cogvideox, example_inputs)
    
    print("\n============ Before Pruning ============")
    print(f"Parameters: {base_params/1e6:.2f}M")
    print(f"MACs: {base_macs/1e9:.2f}G")

    cogvideox.zero_grad()
    cogvideox.eval()

    ignored_layers = [cogvideox.proj_out]

    from diffusers.models.attention import Attention
    num_heads = {}
    for m in cogvideox.modules():
        if isinstance(m, Attention):
            num_heads[m.to_q] = m.heads
            num_heads[m.to_k] = m.heads
            num_heads[m.to_v] = m.heads

    # Initialize the pruner; execution hangs here
    pruner = tp.pruner.MetaPruner(
        cogvideox,
        example_inputs,
        importance=imp,
        global_pruning=False,
        num_heads=num_heads,
        pruning_ratio=args.pruning_ratio,
        ignored_layers=ignored_layers,
        prune_num_heads=True,
        prune_head_dims=False,
        head_pruning_ratio=args.pruning_ratio,
        round_to=4,
    )
    
    for g in pruner.step(interactive=True):
        g.prune()
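
A side note on the script's --pruning_type handling: the accepted values and the importance branches are easier to keep in step when both come from one table. The sketch below is a minimal illustration, not the author's code. The names IMPORTANCE_FACTORIES, build_parser and make_importance are hypothetical, and the factories return strings as stand-ins so the example runs without torch_pruning installed.

```python
import argparse

# Stand-in factories so the sketch runs without torch_pruning.
# In the script above each entry would build the real object instead,
# e.g. "l1": lambda: tp.importance.GroupNormImportance(p=1).
IMPORTANCE_FACTORIES = {
    "random": lambda: "RandomImportance()",
    "l1": lambda: "GroupNormImportance(p=1)",
    "l2": lambda: "GroupNormImportance(p=2)",
    "taylor": lambda: "GroupTaylorImportance()",
    "hessian": lambda: "GroupHessianImportance()",
}


def build_parser():
    """Build a parser whose choices are exactly the table's keys."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pruning_type",
        type=str,
        default="random",
        choices=sorted(IMPORTANCE_FACTORIES),
    )
    return parser


def make_importance(name):
    """Look up and call the factory registered for `name`."""
    return IMPORTANCE_FACTORIES[name]()
```

With this layout argparse rejects an unsupported value at parse time, so the final NotImplementedError branch is no longer needed.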

Does torch_pruning currently support pruning CogVideoX, or is there something wrong with my code?
