38 changes: 34 additions & 4 deletions README.md
@@ -20,7 +20,7 @@ FramePack can be trained with a much larger batch size, similar to the batch siz

# Notes

Note that this GitHub repository is the only official FramePack website. We do not have any web services. All other websites are spam and fake, including but not limited to `framepack.co`, `frame_pack.co`, `framepack.net`, `frame_pack.net`, `framepack.ai`, `frame_pack.ai`, `framepack.pro`, `frame_pack.pro`, `framepack.cc`, `frame_pack.cc`,`framepackai.co`, `frame_pack_ai.co`, `framepackai.net`, `frame_pack_ai.net`, `framepackai.pro`, `frame_pack_ai.pro`, `framepackai.cc`, `frame_pack_ai.cc`, and so on. Again, they are all spam and fake. **Do not pay money or download files from any of those websites.**
Note that this GitHub repository is the only official FramePack website. We do not have any web services. All other websites are spam and fake, including but not limited to `framepack.co`, `frame_pack.co`, `framepack.net`, `frame_pack.net`, `framepack.ai`, `frame_pack.ai`, `framepack.pro`, `frame_pack.pro`, `framepack.cc`, `framepack.cc`,`framepackai.co`, `frame_pack_ai.co`, `framepackai.net`, `frame_pack_ai.net`, `framepackai.pro`, `frame_pack_ai.pro`, `framepackai.cc`, `frame_pack_ai.cc`, and so on. Again, they are all spam and fake. **Do not pay money or download files from any of those websites.**

The team is on leave between April 21 and 29. PR merging will be delayed.

@@ -44,7 +44,37 @@ In any case, you will directly see the generated frames since it is next-frame(-

# Installation

**Windows**:
## For M3 Mac Users

1. Install Python 3.10 or newer:
```bash
brew install python@3.10
```

2. Create and activate a virtual environment:
```bash
python3.10 -m venv venv
source venv/bin/activate
```

3. Install PyTorch with MPS support:
```bash
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
```

4. Install other dependencies:
```bash
pip install -r requirements_m3.txt
```

5. Run the application:
```bash
python demo_gradio.py
```

Note: On M3 Macs the application will automatically use Metal Performance Shaders (MPS) for GPU acceleration.
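
To confirm up front that your PyTorch build can actually see the Apple GPU, you can run a quick check inside the activated virtual environment. This mirrors the device selection added in `diffusers_helper/device_config.py`; the printed messages are illustrative only:

```python
import torch

# True means the demo will run on the Apple GPU via the MPS backend;
# False means it will fall back to the CPU, which is much slower.
print("MPS available:", torch.backends.mps.is_available())
print("MPS built into this torch build:", torch.backends.mps.is_built())
```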

## For Windows Users

[>>> Click Here to Download One-Click Package (CUDA 12.6 + Pytorch 2.6) <<<](https://github.com/lllyasviel/FramePack/releases/download/windows/framepack_cu126_torch26.7z)

@@ -56,7 +86,7 @@ Note that running `update.bat` is important, otherwise you may be using a previo

Note that the models will be downloaded automatically. You will download more than 30GB from HuggingFace.

**Linux**:
## For Linux Users

We recommend having an independent Python 3.10.

@@ -290,7 +320,7 @@ Below are some more examples that you may be interested in reproducing.

<img src="https://github.com/user-attachments/assets/853f4f40-2956-472f-aa7a-fa50da03ed92" width="150">

`The girl suddenly took out a sign that said cute using right hand`
`The girl suddenly took out a sign that said "cute" using right hand`

![image](https://github.com/user-attachments/assets/d51180e4-5537-4e25-a6c6-faecae28648a)

108 changes: 74 additions & 34 deletions demo_gradio.py
@@ -1,4 +1,10 @@
from diffusers_helper.hf_login import login
from diffusers_helper.device_config import (
get_device,
configure_m3_optimizations,
get_model_dtype,
handle_device_specific_errors
)

import os

@@ -27,7 +33,6 @@
from diffusers_helper.clip_vision import hf_clip_vision_encode
from diffusers_helper.bucket_tools import find_nearest_bucket


parser = argparse.ArgumentParser()
parser.add_argument('--share', action='store_true')
parser.add_argument("--server", type=str, default='0.0.0.0')
@@ -40,58 +45,93 @@

print(args)

free_mem_gb = get_cuda_free_memory_gb(gpu)
high_vram = free_mem_gb > 60
# Get device and configure optimizations
device = get_device()
high_vram, gpu_memory_preservation = configure_m3_optimizations()
model_dtype = get_model_dtype()

print(f'Free VRAM {free_mem_gb} GB')
print(f'Using device: {device}')
print(f'High-VRAM Mode: {high_vram}')
print(f'Model dtype: {model_dtype}')

# Load models with M3-optimized settings
text_encoder = LlamaModel.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
subfolder='text_encoder',
torch_dtype=model_dtype
).to(device)

text_encoder_2 = CLIPTextModel.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
subfolder='text_encoder_2',
torch_dtype=model_dtype
).to(device)

tokenizer = LlamaTokenizerFast.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
subfolder='tokenizer'
)

tokenizer_2 = CLIPTokenizer.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
subfolder='tokenizer_2'
)

vae = AutoencoderKLHunyuanVideo.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
subfolder='vae',
torch_dtype=model_dtype
).to(device)

text_encoder = LlamaModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder', torch_dtype=torch.float16).cpu()
text_encoder_2 = CLIPTextModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder_2', torch_dtype=torch.float16).cpu()
tokenizer = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer')
tokenizer_2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer_2')
vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='vae', torch_dtype=torch.float16).cpu()
feature_extractor = SiglipImageProcessor.from_pretrained(
"lllyasviel/flux_redux_bfl",
subfolder='feature_extractor'
)

feature_extractor = SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='feature_extractor')
image_encoder = SiglipVisionModel.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='image_encoder', torch_dtype=torch.float16).cpu()
image_encoder = SiglipVisionModel.from_pretrained(
"lllyasviel/flux_redux_bfl",
subfolder='image_encoder',
torch_dtype=model_dtype
).to(device)

transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained('lllyasviel/FramePackI2V_HY', torch_dtype=torch.bfloat16).cpu()
transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
'lllyasviel/FramePackI2V_HY',
torch_dtype=model_dtype
).to(device)

# Set models to eval mode
vae.eval()
text_encoder.eval()
text_encoder_2.eval()
image_encoder.eval()
transformer.eval()

# Configure memory optimizations
if not high_vram:
vae.enable_slicing()
vae.enable_tiling()

# Set transformer quality settings
transformer.high_quality_fp32_output_for_inference = True
print('transformer.high_quality_fp32_output_for_inference = True')

transformer.to(dtype=torch.bfloat16)
vae.to(dtype=torch.float16)
image_encoder.to(dtype=torch.float16)
text_encoder.to(dtype=torch.float16)
text_encoder_2.to(dtype=torch.float16)

# Disable gradients
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
text_encoder_2.requires_grad_(False)
image_encoder.requires_grad_(False)
transformer.requires_grad_(False)

# Configure model loading based on memory availability
if not high_vram:
# DynamicSwapInstaller is the same as huggingface's enable_sequential_cpu_offload but 3x faster
DynamicSwapInstaller.install_model(transformer, device=gpu)
DynamicSwapInstaller.install_model(text_encoder, device=gpu)
DynamicSwapInstaller.install_model(transformer, device=device)
DynamicSwapInstaller.install_model(text_encoder, device=device)
else:
text_encoder.to(gpu)
text_encoder_2.to(gpu)
image_encoder.to(gpu)
vae.to(gpu)
transformer.to(gpu)
text_encoder.to(device)
text_encoder_2.to(device)
image_encoder.to(device)
vae.to(device)
transformer.to(device)

stream = AsyncStream()

@@ -120,8 +160,8 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding ...'))))

if not high_vram:
fake_diffusers_current_device(text_encoder, gpu) # since we only encode one text - that is one model move and one encode, offload is same time consumption since it is also one load and one encode.
load_model_as_complete(text_encoder_2, target_device=gpu)
fake_diffusers_current_device(text_encoder, device) # since we only encode one text - that is one model move and one encode, offload is same time consumption since it is also one load and one encode.
load_model_as_complete(text_encoder_2, target_device=device)

llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)

@@ -151,7 +191,7 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'VAE encoding ...'))))

if not high_vram:
load_model_as_complete(vae, target_device=gpu)
load_model_as_complete(vae, target_device=device)

start_latent = vae_encode(input_image_pt, vae)

@@ -160,7 +200,7 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))

if not high_vram:
load_model_as_complete(image_encoder, target_device=gpu)
load_model_as_complete(image_encoder, target_device=device)

image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
@@ -213,7 +253,7 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind

if not high_vram:
unload_complete_models()
move_model_to_device_with_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
move_model_to_device_with_memory_preservation(transformer, target_device=device, preserved_memory_gb=gpu_memory_preservation)

if use_teacache:
transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
@@ -256,7 +296,7 @@ def callback(d):
negative_prompt_embeds=llama_vec_n,
negative_prompt_embeds_mask=llama_attention_mask_n,
negative_prompt_poolers=clip_l_pooler_n,
device=gpu,
device=device,
dtype=torch.bfloat16,
image_embeddings=image_encoder_last_hidden_state,
latent_indices=latent_indices,
@@ -276,8 +316,8 @@ def callback(d):
history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)

if not high_vram:
offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=8)
load_model_as_complete(vae, target_device=gpu)
offload_model_from_device_for_memory_preservation(transformer, target_device=device, preserved_memory_gb=8)
load_model_as_complete(vae, target_device=device)

real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]

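The new import block at the top of `demo_gradio.py` pulls in `handle_device_specific_errors`, but none of the hunks shown above call it. A minimal sketch of how it could be wired around a model call, assuming the goal is to clear the Metal cache and surface MPS out-of-memory errors with a friendlier message (the wrapper name and the message text are illustrative, not part of this diff):

```python
from diffusers_helper.device_config import get_device, handle_device_specific_errors

device = get_device()

def run_with_oom_handling(fn, *args, **kwargs):
    """Run a model call; if it raises an MPS out-of-memory RuntimeError,
    handle_device_specific_errors() empties the Metal cache and we re-raise
    with a hint, otherwise the original error propagates unchanged."""
    try:
        return fn(*args, **kwargs)
    except RuntimeError as e:
        if handle_device_specific_errors(e, device):
            raise RuntimeError(
                "Out of GPU memory on MPS; try a shorter video or a lower resolution."
            ) from e
        raise
```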
53 changes: 53 additions & 0 deletions diffusers_helper/device_config.py
@@ -0,0 +1,53 @@
import torch
import os

def get_device():
    """Get the optimal device for the current system."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    elif torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def configure_m3_optimizations():
    """Configure device-specific memory settings (conservative defaults on Apple Silicon)."""
    if torch.backends.mps.is_available():
        # Allow CPU fallback for ops that MPS does not implement yet.
        # PYTORCH_ENABLE_MPS_FALLBACK is the documented switch; ideally it is set
        # before torch is imported, so setting it here is best-effort.
        os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
        # Set conservative memory settings for unified-memory Apple Silicon
        high_vram = False
        gpu_memory_preservation = 8
        return high_vram, gpu_memory_preservation
    if torch.cuda.is_available():
        # Mirror the original CUDA heuristic: treat >60 GB free VRAM as high-VRAM mode.
        from diffusers_helper.memory import get_cuda_free_memory_gb
        free_mem_gb = get_cuda_free_memory_gb(torch.device("cuda"))
        return free_mem_gb > 60, 6
    # CPU-only fallback: no device memory to preserve.
    return False, 8

def get_available_memory():
    """Get available memory (GB) for the current device."""
    if torch.backends.mps.is_available():
        # MPS does not expose a free-memory query; unified memory is shared with
        # the OS, so this is a rough default rather than a measurement
        # (base M3 machines ship with as little as 8 GB).
        return 32  # GB
    elif torch.cuda.is_available():
        from diffusers_helper.memory import get_cuda_free_memory_gb
        return get_cuda_free_memory_gb(torch.device("cuda"))
    return 8  # Conservative CPU estimate

def handle_device_specific_errors(e, device):
    """Handle device-specific errors; returns True if the error was handled."""
    if device.type == "mps":
        if "out of memory" in str(e):
            # Handle MPS OOM by releasing cached Metal memory
            torch.mps.empty_cache()
            return True
    return False

def get_optimal_batch_size(device):
    """Get optimal batch size for the current device."""
    if device.type == "mps":
        return 4  # Conservative batch size for M3
    return 8  # Default batch size

def get_model_dtype():
    """Get optimal model dtype for the current device."""
    if torch.backends.mps.is_available():
        return torch.float16  # MPS has limited bfloat16 support; prefer float16
    return torch.bfloat16  # Default to bfloat16
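
For reference, the helpers above are meant to be called once at startup, which is what the revised `demo_gradio.py` does with `get_device`, `configure_m3_optimizations`, and `get_model_dtype`. A condensed usage sketch (the print formatting is illustrative only):

```python
from diffusers_helper.device_config import (
    get_device,
    configure_m3_optimizations,
    get_model_dtype,
    get_available_memory,
    get_optimal_batch_size,
)

device = get_device()                                    # mps > cuda > cpu
high_vram, gpu_memory_preservation = configure_m3_optimizations()
dtype = get_model_dtype()                                # float16 on MPS, bfloat16 otherwise
batch_size = get_optimal_batch_size(device)

print(f"device={device} | high_vram={high_vram} | "
      f"preserve={gpu_memory_preservation} GB | dtype={dtype} | "
      f"~{get_available_memory()} GB available | batch={batch_size}")
```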