This document explains how to add support for new internal models in TeleFuser, including computing model hashes, adding configurations, and implementing necessary converters.
TeleFuser uses a Hash-based automatic recognition mechanism to determine model types. To integrate a new model into the system, you need to:
- Implement the model class (inherit from
BaseModel) - Implement the
state_dict_converterconverter - Use
weight_viewer.pyto compute model hash - Add configuration and test validation
Create a model class inheriting from BaseModel (or appropriate base class based on model type):
# telefuser/models/my_custom_dit.py
import torch
import torch.nn as nn
from telefuser.core.base_model import BaseModel
class MyCustomDiT(BaseModel):
def __init__(
self,
in_channels=16,
out_channels=16,
hidden_size=2048,
num_layers=32,
# ... other parameters
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
# ... model definition
def forward(self, x, t, context, **kwargs):
# Forward logic
pass
@classmethod
def state_dict_converter(cls):
"""Return state dict converter class"""
return MyCustomDiTStateDictConverterModels can optionally implement a from_pretrained class method for convenient model loading. This method provides a unified interface for loading models in pipeline examples:
# telefuser/models/hunyuan_video_text_encoder.py
class TextEncoder(nn.Module):
"""Text encoder using LLM for HunyuanVideo."""
def __init__(
self,
text_encoder_type: str,
max_length: int,
text_encoder_precision: str,
text_encoder_path: str,
# ... other parameters with internal defaults
):
super().__init__()
# ... initialization logic
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
torch_dtype: torch.dtype = torch.bfloat16,
**kwargs,
) -> "TextEncoder":
"""Load TextEncoder from pretrained checkpoint.
Args:
pretrained_model_name_or_path: Path to text encoder model
torch_dtype: Model precision (default: bfloat16)
**kwargs: Ignored for compatibility
Returns:
Loaded TextEncoder instance
"""
# Determine precision from torch_dtype
precision = "bf16" if torch_dtype == torch.bfloat16 else "fp16"
# All internal parameters are set with sensible defaults
return cls(
text_encoder_type="llm",
max_length=1000,
text_encoder_precision=precision,
text_encoder_path=pretrained_model_name_or_path,
tokenizer_type="llm",
# ... other internal defaults
)Key principles for from_pretrained:
- Only expose essential parameters like
pretrained_model_name_or_pathandtorch_dtype - Set all other parameters internally with sensible defaults
- Accept
**kwargsfor compatibility but ignore unknown parameters - Return a fully initialized model instance
Note: If from_pretrained is not implemented, you can still use ModuleManager.load_model() with hash-based auto-recognition, or manually instantiate the model and add it via add_module().
# telefuser/models/hunyuan_video_vae.py
class HunyuanVideoVAE(nn.Module):
"""HunyuanVideo VAE for video encoding/decoding."""
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
torch_dtype: torch.dtype = torch.bfloat16,
**kwargs,
) -> "HunyuanVideoVAE":
"""Load HunyuanVideoVAE from pretrained checkpoint.
Args:
pretrained_model_name_or_path: Path to VAE checkpoint directory
torch_dtype: Model precision (default: bfloat16)
**kwargs: Ignored for compatibility
Returns:
Loaded HunyuanVideoVAE instance
"""
# Load config from JSON
config_path = os.path.join(pretrained_model_name_or_path, "config.json")
with open(config_path, "r") as f:
config = json.load(f)
# Create model with config
model = cls(
in_channels=config.get("in_channels", 3),
out_channels=config.get("out_channels", 3),
# ... other config parameters
)
# Load state dict
state_dict = load_state_dict(os.path.join(pretrained_model_name_or_path, "diffusion_pytorch_model.safetensors"))
model.load_state_dict(state_dict, strict=False)
return model.to(dtype=torch_dtype)Note: Tiling/slicing settings should be handled at runtime by the VAE stage, not during model initialization.
The converter is responsible for transforming weights from different source formats to internal format:
# telefuser/models/my_custom_dit.py
class MyCustomDiTStateDictConverter:
"""
Convert state_dict from different sources to internal format
"""
@staticmethod
def from_official(state_dict):
"""
Convert from Civitai/Direct format
Args:
state_dict: Original state dictionary
Returns:
Converted state_dict, or (state_dict, extra_kwargs) tuple
"""
# Create key mapping
rename_dict = {
"input_blocks.0.0.weight": "conv_in.weight",
"input_blocks.0.0.bias": "conv_in.bias",
# ... more mappings
}
converted_state_dict = {}
for old_key, new_key in rename_dict.items():
if old_key in state_dict:
converted_state_dict[new_key] = state_dict[old_key]
# Return extra_kwargs if model parameters need to be inferred from weights
extra_kwargs = {
"hidden_size": 2048, # Infer from weights or hard-code
"num_layers": 32,
}
return converted_state_dict, extra_kwargs
@staticmethod
def from_diffusers(state_dict):
"""Convert from Diffusers format"""
# Similar implementation
passUse the built-in weight_viewer.py tool to analyze the model:
# Quick hash retrieval
python tools/viewer/weight_viewer.py /path/to/your/model.safetensors --quietOutput example:
Total parameters: 14.02B
Files: 1
hash with shape: 4c3523c69fb7b24cf2db147a715b277f
Record the hash with shape value, which will be added to the configuration.
For more detailed analysis (view model structure to help implement StateDictConverter):
# View complete structure and export
python tools/viewer/weight_viewer.py /path/to/your/model.safetensors \
--max-depth 10 \
--export model_structure.jsonAdvantages of using weight_viewer:
- Automatically handles sharded models (using wildcards
model-*.safetensors) - Displays parameter statistics and data type distribution
- Automatically merges structurally identical modules (like transformer blocks)
- Exports JSON for further analysis
If the model is split into multiple files:
# Automatically merge all shards and compute hash
python tools/viewer/weight_viewer.py "/path/to/model-*.safetensors" --quietNote: When adding to model_config.py, ensure the hash is based on merged complete weights.
Edit telefuser/core/model_config.py to add model configuration.
First, get information from weight_viewer output:
$ python tools/viewer/weight_viewer.py /path/to/my_model.safetensors --quiet
Total parameters: 6.91B
Files: 1
hash with shape: a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6Then add configuration:
from ..models.my_custom_dit import MyCustomDiT
model_loader_configs = [
# ... existing configurations ...
# MyCustomDiT - Standard version (from weight_viewer: hash=a1b2c3d4...)
# Parameters: 6.91B
(
None, # hash without shape (optional, for non-strict matching)
"a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6", # hash with shape (from weight_viewer)
["my_custom_dit"], # model_name (for fetch_module)
[MyCustomDiT], # model_class
"official", # model_resource: "official" or "diffusers"
),
]If the same model has multiple variants (e.g., FP8 version):
# Analyze FP8 version
$ python tools/viewer/weight_viewer.py /path/to/my_model_fp8.safetensors --quiet
Total parameters: 6.91B
Files: 1
hash with shape: b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6q7 # Different hash!Add to configuration:
# MyCustomDiT - Standard version (hash: a1b2c3d4...)
(
None,
"a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6",
["my_custom_dit"],
[MyCustomDiT],
"official",
),
# MyCustomDiT - FP8 version (hash: b2c3d4e5...)
# Note: FP8 quantized weights
(
None,
"b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6q7",
["my_custom_dit"],
[MyCustomDiT],
"official",
),Tip: If variants have different tensor shapes (like pruned models), consider using non-strict matching (only using keys_hash).
Configuration field description:
| Field | Type | Description |
|---|---|---|
keys_hash |
str | None |
Hash based only on key names (without shape). For variants where shape may change |
keys_hash_with_shape |
str |
Hash including key names and shape. Strict matching, recommended priority |
model_names |
list[str] |
Model identifier name list, used for fetch_module() |
model_classes |
list[type] |
Corresponding model class list |
model_resource |
str |
Weight source format: "official" or "diffusers" |
Create a test script to verify model loading:
# tests/test_my_custom_model_loading.py
import torch
import pytest
from telefuser.core.module_manager import ModuleManager
def test_my_custom_dit_loading():
"""Test MyCustomDiT model loading"""
module_manager = ModuleManager(device="cpu")
# Test auto-recognition
module_manager.load_model(
"/path/to/your/model.safetensors",
torch_dtype=torch.bfloat16
)
# Verify model can be fetched
model = module_manager.fetch_module("my_custom_dit")
assert model is not None
# Verify model type
from telefuser.models.my_custom_dit import MyCustomDiT
assert isinstance(model, MyCustomDiT)
print("✓ MyCustomDiT loading test passed!")
if __name__ == "__main__":
test_my_custom_dit_loading()Run tests:
pytest tests/test_my_custom_model_loading.py -vWhen creating pipeline examples, use the from_pretrained interface and add_module pattern:
import os
import torch
from telefuser.utils.logging import logger
from telefuser.core.module_manager import ModuleManager
from telefuser.models.hunyuan_video_vae import HunyuanVideoVAE
from telefuser.models.hunyuan_video_text_encoder import HunyuanVideoTextEncoder
def get_pipeline(model_root: str = "/path/to/models"):
"""Create and initialize pipeline with all models."""
module_manager = ModuleManager(device="cpu")
# 1. Load VAE using from_pretrained
vae_path = os.path.join(model_root, "vae")
logger.info(f"Loading VAE from {vae_path}")
vae = HunyuanVideoVAE.from_pretrained(vae_path, torch_dtype=torch.bfloat16)
module_manager.add_module(vae, name="vae")
# 2. Load TextEncoder using from_pretrained
text_encoder_path = os.path.join(model_root, "text_encoder", "llm")
logger.info(f"Loading TextEncoder from {text_encoder_path}")
text_encoder = HunyuanVideoTextEncoder.from_pretrained(text_encoder_path, torch_dtype=torch.bfloat16)
module_manager.add_module(text_encoder, name="text_encoder")
# 3. Load other models similarly...
# transformer = HunyuanVideoDiT.from_pretrained(transformer_path, torch_dtype=torch.bfloat16)
# module_manager.add_module(transformer, name="hunyuan_video_dit")
# 4. Create and initialize pipeline
# pipe = HunyuanVideo15Pipeline(device="cuda", torch_dtype=torch.bfloat16)
# pipe.init(module_manager, pipe_config)
return pipe- Use
from_pretrainedfor all model loading - This provides a consistent interface - Only expose model path externally - All other parameters should be internal defaults
- Use
add_modulewith meaningful names - Names like"vae","text_encoder","hunyuan_video_dit"are used by pipeline stages to fetch modules - Let stages handle runtime settings - Tiling, slicing, and other runtime configurations should be handled by pipeline stages, not during model initialization
| Module Type | Recommended Name | Used By |
|---|---|---|
| VAE | "vae" |
HunyuanVideoVAEStage |
| Text Encoder | "text_encoder" |
HunyuanVideoTextEncodingStage |
| DiT/Transformer | "hunyuan_video_dit" |
HunyuanVideoDenoisingStage |
| Vision Encoder (I2V) | "vision_encoder" |
HunyuanVideoImageEncodingStage |
| Upsampler (SR) | "upsampler" |
HunyuanVideoUpsamplerStage |
| Scheduler | "scheduler" |
Pipeline init |
Some model variants (like FP8 quantized, pruned versions) may have different tensor shapes:
# Main version (strict matching)
(
None, # No non-strict hash needed
"a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6",
["my_model"],
[MyModel],
"official",
),
# FP8 version (different shape, use non-strict matching)
(
"q7r8s9t0u1v2w3x4y5z6a7b8c9d0e1f2", # Only key hash
None, # Don't use shape hash (because shape differs)
["my_model_fp8"],
[MyModelFP8], # May need different class
"official",
),Some model files contain multiple components (like VAE encoder + decoder):
# Separate components in state_dict_converter
@staticmethod
def from_official(state_dict):
encoder_dict = {}
decoder_dict = {}
for key, value in state_dict.items():
if key.startswith("encoder."):
encoder_dict[key[8:]] = value # Remove "encoder." prefix
elif key.startswith("decoder."):
decoder_dict[key[8:]] = value
# Return merged dict, handle in model class
combined_dict = {
"encoder": encoder_dict,
"decoder": decoder_dict,
}
return combined_dict, {"has_separate_components": True}If models may come from different sources (Civitai, HuggingFace Diffusers):
class MyModelStateDictConverter:
@staticmethod
def from_official(state_dict):
# Civitai format conversion
return convert_official_format(state_dict)
@staticmethod
def from_diffusers(state_dict):
# Diffusers format conversion
return convert_diffusers_format(state_dict)Then specify the correct model_resource in configuration.
# View all keys and shapes
python tools/viewer/weight_viewer.py /path/to/model.safetensors --show-all
# Export as JSON for programmatic processing
python tools/viewer/weight_viewer.py /path/to/model.safetensors --export model_info.jsonfrom telefuser.utils.model_weight import load_state_dict, hash_state_dict_keys
from telefuser.core.model_config import model_loader_configs
sd = load_state_dict("/path/to/model.safetensors")
hash_with_shape = hash_state_dict_keys(sd, with_shape=True)
hash_without_shape = hash_state_dict_keys(sd, with_shape=False)
print(f"Model hash (with shape): {hash_with_shape}")
print(f"Model hash (without shape): {hash_without_shape}")
# Check if in configuration
found = False
for config in model_loader_configs:
keys_hash, keys_hash_with_shape, model_names, model_classes, resource = config
if keys_hash_with_shape == hash_with_shape:
print(f"✓ Found match (strict): {model_names}")
found = True
elif keys_hash == hash_without_shape:
print(f"✓ Found match (non-strict): {model_names}")
found = True
if not found:
print("✗ No matching configuration found!")
print(f"Add this to model_config.py:")
print(f' (None, "{hash_with_shape}", ["your_model_name"], [YourModelClass], "official"),')from telefuser.models.my_custom_dit import MyCustomDiT
from telefuser.utils.model_weight import load_state_dict
sd = load_state_dict("/path/to/model.safetensors")
converter = MyCustomDiT.state_dict_converter()
converted, extra_kwargs = converter.from_official(sd)
print(f"Extra kwargs: {extra_kwargs}")
print(f"Converted keys: {list(converted.keys())[:10]}")
# Try initialization
model = MyCustomDiT(**extra_kwargs)
model.load_state_dict(converted, strict=False) # Test with non-strict mode first
print("✓ Model initialized successfully!")# After modifying configuration, quickly verify hash matches
python -c "
from telefuser.core.module_manager import ModuleManager
mm = ModuleManager(device='cpu')
mm.load_model('/path/to/your/model.safetensors')
print('✓ Configuration is correct!')
print(f'Loaded models: {mm.module_name}')
"-
Keep configurations organized
- Group by model type
- Keep different variants of same model together
- Add comments explaining version differences
-
Use strict matching when possible
- Provide
keys_hash_with_shapewhenever possible - Only use non-strict matching when shape may vary
- Provide
-
Document variants in detail
# Wan2.1 T2V 14B - FP8 per-channel quantized # Note: This version has scaled weights for FP8 inference ( None, "4cf556355bc7e9b6545b38f4930f60b1", ["wan_video_dit"], [WanModel], "official", ),
-
Test all variants
- Original version
- FP8 quantized version
- Pruned version
- Different source formats (Civitai vs Diffusers)
-
Naming conventions
- Use lowercase with underscores for
model_names - Prefix indicates model family:
wan_video_,qwen_image_,flashvsr_
- Use lowercase with underscores for
-
Make full use of weight_viewer
# Analyze model before adding configuration python tools/viewer/weight_viewer.py /path/to/model.safetensors --export model_info.json # Compare differences between versions python tools/viewer/weight_viewer.py /path/to/model_v1.safetensors --export v1.json python tools/viewer/weight_viewer.py /path/to/model_v2.safetensors --export v2.json diff v1.json v2.json
Refer to the following files for complete implementation:
- Model implementation:
telefuser/models/wan_video_dit.py - Configuration definition:
telefuser/core/model_config.py(WanModel related configurations) - Usage example:
examples/wan_video/wan21_14b_image_to_video_h100.py
After completing model integration, you can optimize inference performance and memory usage through the following methods.
TeleFuser's ops module provides high-performance neural network operator implementations. Reusing these operators in new models achieves optimal performance:
| Operator | Usage | Performance Optimization |
|---|---|---|
RMSNorm / LayerNorm |
Normalization layers | tf_kernel > Triton > PyTorch |
FeedForward |
Feed-forward networks | Supports GEGLU/SwiGLU |
attention |
Attention computation | Flash Attention 2/3/4, SageAttention |
LinearFP8 |
Quantized linear layers | FP8 inference |
from telefuser.ops.normalization import RMSNorm
from telefuser.ops.ffn import FeedForward
from telefuser.ops.attention import attention
from telefuser.core.config import AttentionConfig, AttnImplType
class MyTransformerBlock(nn.Module):
def __init__(self, dim, num_heads):
super().__init__()
self.norm1 = RMSNorm(dim)
self.norm2 = RMSNorm(dim)
self.ffn = FeedForward(dim=dim, mult=4, activation_fn="geglu")
self.attention_config = AttentionConfig.dense_attention(AttnImplType.FLASH_ATTN_2)For detailed documentation, see Ops Module Documentation.
For large models or long sequence generation, various parallel strategies are available:
from telefuser.core.config import ParallelConfig
# Ulysses sequence parallelism (2 GPU)
config = ParallelConfig(
device_ids=[0, 1],
sp_ulysses_degree=2,
)
pipe_config.dit_config.parallel_config = config
pipe_config.enable_denoising_parallel = True
# CFG + Ulysses (4 GPU)
config = ParallelConfig(
device_ids=[0, 1, 2, 3],
cfg_degree=2,
sp_ulysses_degree=2,
)| Strategy | Use Case | Description |
|---|---|---|
| Ulysses | Medium-length sequences | All-to-All communication |
| Ring | Extra-long sequences | P2P communication, supports arbitrary length |
| USP | Large-scale parallelism | Ulysses + Ring combination |
| CFG Parallel | CFG acceleration | Positive/negative prompt parallel computation |
| Pipeline Parallel | Large model inference | Layer distribution across GPUs |
For detailed configuration, see Parallel Inference Guide.
Use tools/convert/converter.py to quantize models and significantly reduce memory usage:
FP8 Quantization (recommended):
python tools/convert/converter.py \
--source /path/to/model/ \
--output /path/to/output \
--linear_dtype fp8 \
--non_linear_dtype torch.bfloat16 \
--model_type wan_dit \
--quantized \
--single_fileINT8 Quantization:
python tools/convert/converter.py \
--source /path/to/model/ \
--output /path/to/output \
--linear_dtype torch.int8 \
--model_type wan_dit \
--quantized \
--single_fileSupported quantization types: int8, fp8, nvfp4, mxfp4, mxfp6, mxfp8.
For detailed usage, see tools/convert/README.md.
When GPU memory is insufficient, use CPU offloading to temporarily move model weights to CPU:
from telefuser.core.config import OffloadConfig, WeightOffloadType
# Async CPU offload (recommended)
pipe_config.dit_config.offload_config = OffloadConfig(
offload_type=WeightOffloadType.ASYNC_CPU_OFFLOAD,
pin_cpu_memory=True,
prefetch_size=1,
)| Strategy | Memory Savings | Speed Impact | Use Case |
|---|---|---|---|
NO_CPU_OFFLOAD |
None | Fastest | Sufficient VRAM |
MODEL_CPU_OFFLOAD |
~50% | Medium | Moderately constrained |
ASYNC_CPU_OFFLOAD |
~60-70% | Low | 8-16GB VRAM |
SEQUENTIAL_CPU_OFFLOAD |
Maximum | Slowest | <8GB VRAM |
For detailed configuration, see CPU Offloading Guide.
Here's a complete optimization configuration example:
from telefuser.core.config import (
ParallelConfig,
AttentionConfig,
AttnImplType,
OffloadConfig,
WeightOffloadType,
)
# Multi-GPU + Attention optimization + Offloading
pipe_config.dit_config.parallel_config = ParallelConfig(
device_ids=[0, 1],
sp_ulysses_degree=2,
)
pipe_config.dit_config.attention_config = AttentionConfig.dense_attention(
AttnImplType.FLASH_ATTN_2
)
pipe_config.dit_config.offload_config = OffloadConfig(
offload_type=WeightOffloadType.ASYNC_CPU_OFFLOAD,
)
pipe_config.enable_denoising_parallel = True- Model Loading User Guide
- Hash Configuration Management Guide
- Ops Module Documentation - Neural network operator implementations (activations, normalization layers, attention, etc.)
- Parallel Inference Guide - Multi-GPU inference configuration
- CPU Offloading Guide - Memory optimization strategies