File tree 2 files changed +8
-7
lines changed
models/turbine_models/custom_models 2 files changed +8
-7
lines changed Original file line number Diff line number Diff line change @@ -244,7 +244,7 @@ def __init__(
244
244
"exit_on_vmfb" : False ,
245
245
"pipeline_dir" : pipeline_dir ,
246
246
"input_mlir" : None ,
247
- "attn_spec" : None ,
247
+ "attn_spec" : attn_spec ,
248
248
"external_weights" : None ,
249
249
"external_weight_path" : None ,
250
250
}
Original file line number Diff line number Diff line change @@ -182,12 +182,13 @@ def export_unet_model(
182
182
submodel_name = "punet"
183
183
else :
184
184
submodel_name = "unet"
185
- if (not decomp_attn ) and use_punet :
186
- attn_spec = "punet"
187
- elif (not decomp_attn ) and "gfx9" in target :
188
- attn_spec = "mfma"
189
- elif (not decomp_attn ) and "gfx11" in target :
190
- attn_spec = "wmma"
185
+ if not attn_spec :
186
+ if (not decomp_attn ) and use_punet :
187
+ attn_spec = "punet"
188
+ elif (not decomp_attn ) and "gfx9" in target :
189
+ attn_spec = "mfma"
190
+ elif (not decomp_attn ) and "gfx11" in target :
191
+ attn_spec = "wmma"
191
192
safe_name = utils .create_safe_name (
192
193
hf_model_name ,
193
194
f"_bs{ batch_size } _{ max_length } _{ height } x{ width } _{ precision } _{ submodel_name } " ,
You can’t perform that action at this time.
0 commit comments