@@ -97,11 +97,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
9797
9898
9999class HunyuanVideoHistoryPatchEmbed (nn .Module ):
100- def __init__ (self , inner_dim : int ):
100+ def __init__ (self , in_channels : int , inner_dim : int ):
101101 super ().__init__ ()
102- self .proj = nn .Conv3d (16 , inner_dim , kernel_size = (1 , 2 , 2 ), stride = (1 , 2 , 2 ))
103- self .proj_2x = nn .Conv3d (16 , inner_dim , kernel_size = (2 , 4 , 4 ), stride = (2 , 4 , 4 ))
104- self .proj_4x = nn .Conv3d (16 , inner_dim , kernel_size = (4 , 8 , 8 ), stride = (4 , 8 , 8 ))
102+ self .proj = nn .Conv3d (in_channels , inner_dim , kernel_size = (1 , 2 , 2 ), stride = (1 , 2 , 2 ))
103+ self .proj_2x = nn .Conv3d (in_channels , inner_dim , kernel_size = (2 , 4 , 4 ), stride = (2 , 4 , 4 ))
104+ self .proj_4x = nn .Conv3d (in_channels , inner_dim , kernel_size = (4 , 8 , 8 ), stride = (4 , 8 , 8 ))
105105
106106 def forward (
107107 self ,
@@ -131,7 +131,7 @@ class HunyuanVideoFramepackTransformer3DModel(
131131 _no_split_modules = [
132132 "HunyuanVideoTransformerBlock" ,
133133 "HunyuanVideoSingleTransformerBlock" ,
134- "HunyuanVideoPatchEmbedForCleanLatents" , # TODO
134+ "HunyuanVideoHistoryPatchEmbed" ,
135135 "HunyuanVideoTokenRefiner" ,
136136 ]
137137
@@ -205,7 +205,7 @@ def __init__(
205205
206206 self .clean_x_embedder = None
207207 if has_clean_x_embedder :
208- self .clean_x_embedder = HunyuanVideoHistoryPatchEmbed (inner_dim )
208+ self .clean_x_embedder = HunyuanVideoHistoryPatchEmbed (in_channels , inner_dim )
209209
210210 self .use_gradient_checkpointing = False
211211
0 commit comments