Skip to content

Commit 1fb2d5d

Browse files
committed
add qnn tests
1 parent 471c1ac commit 1fb2d5d

File tree

3 files changed

+42
-13
lines changed

3 files changed

+42
-13
lines changed

tests/models/test_modeling_bert.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@
2222
import pytest
2323
import torchao
2424
from executorch.extension.pybindings.portable_lib import ExecuTorchModule
25+
26+
from optimum.executorch import ExecuTorchModelForMaskedLM
2527
from packaging.version import parse
2628
from transformers import AutoTokenizer
2729
from transformers.testing_utils import slow
2830

29-
from optimum.executorch import ExecuTorchModelForMaskedLM
30-
3131

3232
@pytest.mark.skipif(
3333
parse(torchao.__version__) < parse("0.11.0.dev0"),
@@ -70,7 +70,9 @@ def _helper_bert_fill_mask(self, recipe: str):
7070
tokenizer = AutoTokenizer.from_pretrained(model_id)
7171

7272
# Test fetching and lowering the model to ExecuTorch
73-
model = ExecuTorchModelForMaskedLM.from_pretrained(model_id=model_id, recipe=recipe)
73+
model = ExecuTorchModelForMaskedLM.from_pretrained(
74+
model_id=model_id, recipe=recipe
75+
)
7476
self.assertIsInstance(model, ExecuTorchModelForMaskedLM)
7577
self.assertIsInstance(model.model, ExecuTorchModule)
7678

@@ -85,9 +87,14 @@ def _helper_bert_fill_mask(self, recipe: str):
8587
# Test inference using ExecuTorch model
8688
exported_outputs = model.forward(inputs["input_ids"], inputs["attention_mask"])
8789
predicted_masks = tokenizer.decode(exported_outputs[0, 4].topk(5).indices)
88-
logging.info(f"\nInput text:\n\t{input_text}\nPredicted masks:\n\t{predicted_masks}")
90+
logging.info(
91+
f"\nInput text:\n\t{input_text}\nPredicted masks:\n\t{predicted_masks}"
92+
)
8993
self.assertTrue(
90-
any(word in predicted_masks for word in ["capital", "center", "heart", "birthplace"]),
94+
any(
95+
word in predicted_masks
96+
for word in ["capital", "center", "heart", "birthplace"]
97+
),
9198
f"Exported model predictions {predicted_masks} don't contain any of the most common expected words",
9299
)
93100

@@ -101,3 +108,7 @@ def test_bert_fill_mask(self):
101108
@pytest.mark.portable
102109
def test_bert_fill_mask_portable(self):
103110
self._helper_bert_fill_mask("portable")
111+
112+
@pytest.mark.run_slow
113+
def test_bert_fill_mask_qnn(self):
114+
self._helper_bert_fill_mask(recipe="qnn_fp16_SM8650")

tests/models/test_modeling_cvt.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
import pytest
2222
import torch
2323
from executorch.extension.pybindings.portable_lib import ExecuTorchModule
24-
from transformers import AutoConfig, AutoModelForImageClassification
25-
from transformers.testing_utils import slow
2624

2725
from optimum.executorch import ExecuTorchModelForImageClassification
26+
from transformers import AutoConfig, AutoModelForImageClassification
27+
from transformers.testing_utils import slow
2828

2929
from ..utils import check_close_recursively
3030

@@ -58,11 +58,15 @@ def _helper_cvt_image_classification(self, recipe: str):
5858
pixel_values = torch.rand(batch_size, num_channels, height, width)
5959

6060
# Test fetching and lowering the model to ExecuTorch
61-
et_model = ExecuTorchModelForImageClassification.from_pretrained(model_id=model_id, recipe=recipe)
61+
et_model = ExecuTorchModelForImageClassification.from_pretrained(
62+
model_id=model_id, recipe=recipe
63+
)
6264
self.assertIsInstance(et_model, ExecuTorchModelForImageClassification)
6365
self.assertIsInstance(et_model.model, ExecuTorchModule)
6466

65-
eager_model = AutoModelForImageClassification.from_pretrained(model_id).eval().to("cpu")
67+
eager_model = (
68+
AutoModelForImageClassification.from_pretrained(model_id).eval().to("cpu")
69+
)
6670
with torch.no_grad():
6771
eager_output = eager_model(pixel_values)
6872
et_output = et_model.forward(pixel_values)
@@ -80,3 +84,8 @@ def test_cvt_image_classification(self):
8084
@pytest.mark.portable
8185
def test_cvt_image_classification_portable(self):
8286
self._helper_cvt_image_classification(recipe="portable")
87+
88+
@slow
89+
@pytest.mark.run_slow
90+
def test_cvt_image_classification_qnn(self):
91+
self._helper_cvt_image_classification(recipe="qnn_fp16_SM8650")

tests/models/test_modeling_deit.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
import pytest
2222
import torch
2323
from executorch.extension.pybindings.portable_lib import ExecuTorchModule
24-
from transformers import AutoConfig, AutoModelForImageClassification
25-
from transformers.testing_utils import slow
2624

2725
from optimum.executorch import ExecuTorchModelForImageClassification
26+
from transformers import AutoConfig, AutoModelForImageClassification
27+
from transformers.testing_utils import slow
2828

2929
from ..utils import check_close_recursively
3030

@@ -58,11 +58,15 @@ def _helper_deit_image_classification(self, recipe: str):
5858
pixel_values = torch.rand(batch_size, num_channels, height, width)
5959

6060
# Test fetching and lowering the model to ExecuTorch
61-
et_model = ExecuTorchModelForImageClassification.from_pretrained(model_id=model_id, recipe=recipe)
61+
et_model = ExecuTorchModelForImageClassification.from_pretrained(
62+
model_id=model_id, recipe=recipe
63+
)
6264
self.assertIsInstance(et_model, ExecuTorchModelForImageClassification)
6365
self.assertIsInstance(et_model.model, ExecuTorchModule)
6466

65-
eager_model = AutoModelForImageClassification.from_pretrained(model_id).eval().to("cpu")
67+
eager_model = (
68+
AutoModelForImageClassification.from_pretrained(model_id).eval().to("cpu")
69+
)
6670
with torch.no_grad():
6771
eager_output = eager_model(pixel_values)
6872
et_output = et_model.forward(pixel_values)
@@ -80,3 +84,8 @@ def test_deit_image_classification(self):
8084
@pytest.mark.portable
8185
def test_deit_image_classification_portable(self):
8286
self._helper_deit_image_classification(recipe="portable")
87+
88+
@slow
89+
@pytest.mark.run_slow
90+
def test_deit_image_classification_qnn(self):
91+
self._helper_deit_image_classification(recipe="qnn_fp16_SM8650")

0 commit comments

Comments
 (0)