2222import pytest
2323import torchao
2424from executorch .extension .pybindings .portable_lib import ExecuTorchModule
25+
26+ from optimum .executorch import ExecuTorchModelForMaskedLM
2527from packaging .version import parse
2628from transformers import AutoTokenizer
2729from transformers .testing_utils import slow
2830
29- from optimum .executorch import ExecuTorchModelForMaskedLM
30-
3131
3232@pytest .mark .skipif (
3333 parse (torchao .__version__ ) < parse ("0.11.0.dev0" ),
@@ -70,7 +70,9 @@ def _helper_bert_fill_mask(self, recipe: str):
7070 tokenizer = AutoTokenizer .from_pretrained (model_id )
7171
7272 # Test fetching and lowering the model to ExecuTorch
73- model = ExecuTorchModelForMaskedLM .from_pretrained (model_id = model_id , recipe = recipe )
73+ model = ExecuTorchModelForMaskedLM .from_pretrained (
74+ model_id = model_id , recipe = recipe
75+ )
7476 self .assertIsInstance (model , ExecuTorchModelForMaskedLM )
7577 self .assertIsInstance (model .model , ExecuTorchModule )
7678
@@ -85,9 +87,14 @@ def _helper_bert_fill_mask(self, recipe: str):
8587 # Test inference using ExecuTorch model
8688 exported_outputs = model .forward (inputs ["input_ids" ], inputs ["attention_mask" ])
8789 predicted_masks = tokenizer .decode (exported_outputs [0 , 4 ].topk (5 ).indices )
88- logging .info (f"\n Input text:\n \t { input_text } \n Predicted masks:\n \t { predicted_masks } " )
90+ logging .info (
91+ f"\n Input text:\n \t { input_text } \n Predicted masks:\n \t { predicted_masks } "
92+ )
8993 self .assertTrue (
90- any (word in predicted_masks for word in ["capital" , "center" , "heart" , "birthplace" ]),
94+ any (
95+ word in predicted_masks
96+ for word in ["capital" , "center" , "heart" , "birthplace" ]
97+ ),
9198 f"Exported model predictions { predicted_masks } don't contain any of the most common expected words" ,
9299 )
93100
@@ -101,3 +108,7 @@ def test_bert_fill_mask(self):
101108 @pytest .mark .portable
102109 def test_bert_fill_mask_portable (self ):
103110 self ._helper_bert_fill_mask ("portable" )
111+
112+ @pytest .mark .run_slow
113+ def test_bert_fill_mask_qnn (self ):
114+ self ._helper_bert_fill_mask (recipe = "qnn_fp16_SM8650" )
0 commit comments