@@ -1825,24 +1825,22 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
     lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
     lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
+    has_time_projection_weight = any(
+        k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
+    )
 
-    diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))]
-    if diff_keys:
-        for diff_k in diff_keys:
-            param = original_state_dict[diff_k]
-            # The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are upto 1e-3,
-            # and 2 of them are about 1.6e-2 [the case with AccVideo lora]). The low magnitudes mostly correspond
-            # to norm layers. Ignoring them is the best option at the moment until a better solution is found. It
-            # is okay to ignore because they do not affect the model output in a significant manner.
-            threshold = 1.6e-2
-            absdiff = param.abs().max() - param.abs().min()
-            all_zero = torch.all(param == 0).item()
-            all_absdiff_lower_than_threshold = absdiff < threshold
-            if all_zero or all_absdiff_lower_than_threshold:
-                logger.debug(
-                    f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold."
-                )
-                original_state_dict.pop(diff_k)
+    for key in list(original_state_dict.keys()):
+        if key.endswith((".diff", ".diff_b")) and "norm" in key:
+            # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
+            # in future if needed and they are not zeroed.
+            original_state_dict.pop(key)
+            logger.debug(f"Removing {key} key from the state dict as it is a norm diff key. This is unsupported.")
+
+        if "time_projection" in key and not has_time_projection_weight:
+            # AccVideo lora has diff bias keys but not the weight keys. This causes a weird problem where
+            # our lora config adds the time proj lora layers, but we don't have the weights for them.
+            # CausVid lora has the weight keys and the bias keys.
+            original_state_dict.pop(key)
 
     # For the `diff_b` keys, we treat them as lora_bias.
     # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
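For reference, the new pruning pass can be exercised in isolation. The sketch below mirrors the logic of the added lines; the key names in `toy_state_dict` are hypothetical (not taken from any real checkpoint), and plain floats stand in for tensors, since the filtering only inspects key names:

# Standalone sketch of the pruning pass above; runnable without torch.
toy_state_dict = {
    "blocks.0.norm_q.diff": 0.0,               # norm diff key: always dropped
    "time_projection.1.diff_b": 0.0,           # bias-only time_projection (AccVideo-style): dropped
    "blocks.0.attn.to_q.lora_A.weight": 1.0,   # ordinary LoRA key: kept
}

has_time_projection_weight = any(
    k.startswith("time_projection") and k.endswith(".weight") for k in toy_state_dict
)

for key in list(toy_state_dict.keys()):
    if key.endswith((".diff", ".diff_b")) and "norm" in key:
        toy_state_dict.pop(key)
    if "time_projection" in key and not has_time_projection_weight:
        toy_state_dict.pop(key)

print(sorted(toy_state_dict))  # ['blocks.0.attn.to_q.lora_A.weight']

Because the toy dict has no `time_projection.*.weight` key, `has_time_projection_weight` is False and the bias-only `time_projection` entry is removed, matching the AccVideo case described in the diff's comments; a CausVid-style checkpoint, which carries the weight keys as well, would keep its `time_projection` entries.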