Skip to content

Commit 02705a9

Browse files
committed
Fix for passing a path as attn_spec.
1 parent b20be32 commit 02705a9

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

models/turbine_models/custom_models/sd_inference/sd_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def __init__(
244244
"exit_on_vmfb": False,
245245
"pipeline_dir": pipeline_dir,
246246
"input_mlir": None,
247-
"attn_spec": None,
247+
"attn_spec": attn_spec,
248248
"external_weights": None,
249249
"external_weight_path": None,
250250
}

models/turbine_models/custom_models/sdxl_inference/unet.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -182,12 +182,13 @@ def export_unet_model(
182182
submodel_name = "punet"
183183
else:
184184
submodel_name = "unet"
185-
if (not decomp_attn) and use_punet:
186-
attn_spec = "punet"
187-
elif (not decomp_attn) and "gfx9" in target:
188-
attn_spec = "mfma"
189-
elif (not decomp_attn) and "gfx11" in target:
190-
attn_spec = "wmma"
185+
if not attn_spec:
186+
if (not decomp_attn) and use_punet:
187+
attn_spec = "punet"
188+
elif (not decomp_attn) and "gfx9" in target:
189+
attn_spec = "mfma"
190+
elif (not decomp_attn) and "gfx11" in target:
191+
attn_spec = "wmma"
191192
safe_name = utils.create_safe_name(
192193
hf_model_name,
193194
f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_{submodel_name}",

0 commit comments

Comments
 (0)