diff --git a/closed_form_factorization.py b/closed_form_factorization.py index 8f6ab62a4..b4220355c 100644 --- a/closed_form_factorization.py +++ b/closed_form_factorization.py @@ -21,7 +21,7 @@ modulate = { k[0]: k[1] for k in G.named_parameters() - if "affine" in k[0] and "torgb" not in k[0] and "weight" in k[0] + if "affine" in k[0] and "torgb" not in k[0] and "weight" in k[0] or ("torgb" in k[0] and "b4" in k[0] and "weight" in k[0] and "affine" in k[0]) } weight_mat = []