Skip to content

Commit 104aa69

Browse files
committed
fix variant loading issues.
1 parent 8f0786a commit 104aa69

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

tests/pipelines/test_pipelines_common.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2065,7 +2065,16 @@ def is_nan(tensor):
20652065
for component_name in model_components_pipe:
20662066
pipe_component = model_components_pipe[component_name]
20672067
pipe_loaded_component = model_components_pipe_loaded[component_name]
2068-
for p1, p2 in zip(pipe_component.parameters(), pipe_loaded_component.parameters()):
2068+
2069+
model_loaded_params = dict(pipe_loaded_component.named_parameters())
2070+
model_original_params = dict(pipe_component.named_parameters())
2071+
2072+
for name, p1 in model_original_params.items():
2073+
# Skip tied weights that aren't saved with variants (transformers v5 behavior)
2074+
if name not in model_loaded_params:
2075+
continue
2076+
2077+
p2 = model_loaded_params[name]
20692078
# nan check for luminanext (mps).
20702079
if not (is_nan(p1) and is_nan(p2)):
20712080
self.assertTrue(torch.equal(p1, p2))

0 commit comments

Comments
 (0)