
Commit a682d15

Fixes dummy inputs for summarization, feature_extraction tasks (#303)

* issue with latest transformers
* fix
* fix tasks

1 parent 54c0b00 · commit a682d15
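The tasks touched here are exercised through `get_untrained_model_with_inputs`, which returns an untrained copy of a model together with the task-specific dummy inputs this commit fixes. A minimal sketch of that validation pattern, assuming onnx_diagnostic and transformers are installed (the model id and dictionary keys are taken from the tests below):

```python
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs

# Build an untrained copy of the model plus dummy inputs for its task.
data = get_untrained_model_with_inputs("facebook/bart-base", add_second_input=True)
model, inputs = data["model"], data["inputs"]

model(**inputs)           # the dummy inputs must run through the model
model(**data["inputs2"])  # a second shape set, used to validate dynamic dimensions
```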

File tree: 8 files changed, +198 -208 lines changed


CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -4,6 +4,7 @@ Change Logs
 0.8.2
 +++++
 
+* :pr:`303`: fix inputs for summarization, feature extraction tasks
 * :pr:`302`: adds helpers to analyse onnxruntime profiling
 * :pr:`297`: experiment around a higher ops ``loop_for``
 * :pr:`292`, :pr:`293`, :pr:`294`, :pr:`295`: new patches for Qwen models
```

_unittests/ut_tasks/test_tasks.py

Lines changed: 1 addition & 50 deletions
```diff
@@ -1,12 +1,7 @@
 import os
 import unittest
 import torch
-from onnx_diagnostic.ext_test_case import (
-    ExtTestCase,
-    hide_stdout,
-    has_transformers,
-    requires_transformers,
-)
+from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers
 from onnx_diagnostic.helpers.torch_helper import to_any, torch_deepcopy
 from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
 from onnx_diagnostic.torch_export_patches import torch_export_patches
@@ -216,50 +211,6 @@ def test_fill_mask(self):
             model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
         )
 
-    @hide_stdout()
-    @requires_transformers("4.53.99")
-    def test_feature_extraction_bart_base(self):
-        mid = "facebook/bart-base"
-        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
-        self.assertEqual(data["task"], "feature-extraction")
-        self.assertIn((data["size"], data["n_weights"]), [(557681664, 139420416)])
-        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
-        model(**torch_deepcopy(inputs))
-        model(**data["inputs2"])
-        with torch_export_patches(patch_transformers=True, verbose=10):
-            torch.export.export(
-                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
-            )
-
-    @hide_stdout()
-    def test_feature_extraction_tiny_bart(self):
-        mid = "hf-tiny-model-private/tiny-random-PLBartForConditionalGeneration"
-        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
-        self.assertEqual(data["task"], "text2text-generation")
-        self.assertIn((data["size"], data["n_weights"]), [(3243392, 810848)])
-        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
-        model(**inputs)
-        model(**data["inputs2"])
-        with torch_export_patches(patch_transformers=True, verbose=10):
-            torch.export.export(
-                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
-            )
-
-    @requires_transformers("4.51.999")
-    @hide_stdout()
-    def test_summarization(self):
-        mid = "facebook/bart-large-cnn"
-        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
-        self.assertEqual(data["task"], "summarization")
-        self.assertIn((data["size"], data["n_weights"]), [(1625161728, 406290432)])
-        model, inputs, _ds = data["model"], data["inputs"], data["dynamic_shapes"]
-        model(**inputs)
-        model(**data["inputs2"])
-        # with torch_export_patches(patch_transformers=True, verbose=10):
-        #     torch.export.export(
-        #         model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
-        #     )
-
     @hide_stdout()
     def test_text_classification(self):
         mid = "Intel/bert-base-uncased-mrpc"
```
Lines changed: 55 additions & 0 deletions (new file)

```diff
@@ -0,0 +1,55 @@
+import unittest
+import torch
+from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, requires_transformers
+from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
+from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
+from onnx_diagnostic.torch_export_patches import torch_export_patches
+from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
+
+
+class TestTasksFeatureExtration(ExtTestCase):
+    @hide_stdout()
+    @requires_transformers("4.53.99")
+    def test_feature_extraction_bart_base(self):
+        """
+        data=dict(
+            input_ids:T7s2x12,
+            attention_mask:T7s2x12,
+            past_key_values:EncoderDecoderCache(
+                self_attention_cache=DynamicCache(
+                    key_cache=#6[T1s2x12x30x64,...
+                    value_cache=#6[T1s2x12x30x64,...
+                cross_attention_cache=DynamicCache(
+                    key_cache=#6[T1s2x12x4x64
+                    value_cache=#6[T1s2x12x4x64
+        """
+        mid = "facebook/bart-base"
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "feature-extraction")
+        self.assertIn((data["size"], data["n_weights"]), [(409583616, 102395904)])
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        print(f"-- {self.string_type(inputs, with_shape=True)}")
+        model(**torch_deepcopy(inputs))
+        model(**data["inputs2"])
+        with torch_export_patches(patch_transformers=True, verbose=10):
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+
+    @hide_stdout()
+    def test_feature_extraction_tiny_bart(self):
+        mid = "hf-tiny-model-private/tiny-random-PLBartForConditionalGeneration"
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "text2text-generation")
+        self.assertIn((data["size"], data["n_weights"]), [(3243392, 810848)])
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        model(**inputs)
+        model(**data["inputs2"])
+        with torch_export_patches(patch_transformers=True, verbose=10):
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
```
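The docstring above uses the project's `string_type` shorthand: `T7s2x12` is an int64 tensor of shape 2x12, `T1s...` a float32 tensor, and `#6[...]` a list of six tensors (the dtype codes follow ONNX element types: 1 is float, 7 is int64). A sketch of how such an `EncoderDecoderCache` can be assembled with the cache helpers used elsewhere in this commit, assuming `make_dynamic_cache` takes a list of `(key, value)` pairs as in `try_tasks.py`:

```python
import torch
from onnx_diagnostic.helpers.cache_helper import (
    make_dynamic_cache,
    make_encoder_decoder_cache,
)

batch, heads, head_dim, layers = 2, 12, 64, 6
past_length, past_length2 = 30, 4  # self- and cross-attention past lengths

past_key_values = make_encoder_decoder_cache(
    # self-attention cache: one (key, value) pair per layer, shape 2x12x30x64
    make_dynamic_cache(
        [
            (
                torch.randn(batch, heads, past_length, head_dim),
                torch.randn(batch, heads, past_length, head_dim),
            )
            for _ in range(layers)
        ]
    ),
    # cross-attention cache: the sequence axis is the encoder side, 2x12x4x64
    make_dynamic_cache(
        [
            (
                torch.randn(batch, heads, past_length2, head_dim),
                torch.randn(batch, heads, past_length2, head_dim),
            )
            for _ in range(layers)
        ]
    ),
)
```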
Lines changed: 25 additions & 0 deletions (new file)

```diff
@@ -0,0 +1,25 @@
+import unittest
+from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, requires_transformers
+from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
+
+
+class TestTasksSummarization(ExtTestCase):
+    @requires_transformers("4.51.999")
+    @hide_stdout()
+    def test_summarization(self):
+        mid = "facebook/bart-large-cnn"
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "summarization")
+        self.assertIn((data["size"], data["n_weights"]), [(1427701760, 356925440)])
+        model, inputs, _ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        print(f"-- {mid}: {self.string_type(inputs, with_shape=True)}")
+        model(**inputs)
+        model(**data["inputs2"])
+        # with torch_export_patches(patch_transformers=True, verbose=10):
+        #     torch.export.export(
+        #         model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+        #     )
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
```

_unittests/ut_tasks/try_tasks.py

Lines changed: 25 additions & 6 deletions
```diff
@@ -530,7 +530,22 @@ def test_fill_mask(self):
         print("-- outputs", string_type(output, with_shape=True, with_min_max=True))
 
     @never_test()
-    def test_feature_extraction(self):
+    def test_feature_extraction_generate(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k feature_ex
+        # https://huggingface.co/google-bert/bert-base-multilingual-cased
+
+        from transformers import BartTokenizer, BartModel
+
+        tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+        model = BartModel.from_pretrained("facebook/bart-base")
+        text = "Replace me by any text you'd like."
+        encoded_input = tokenizer(text, return_tensors="pt")
+        print(f"-- {string_type(encoded_input, with_shape=True)}")
+        outputs = model(**encoded_input)
+        print(f"-- {string_type(outputs, with_shape=True)}")
+
+    @never_test()
+    def test_feature_extraction_check(self):
         # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k feature_ex
         # https://huggingface.co/google-bert/bert-base-multilingual-cased
 
@@ -541,10 +556,14 @@ def test_feature_extraction(self):
         text = "Replace me by any text you'd like."
         encoded_input = tokenizer(text, return_tensors="pt")
         sequence_length, sequence_length2 = 30, 4
-        sequence_length = 3
-        batch_size, encoder_attention_heads, encoder_ffn_dim = 1, 12, 64
-        batch_size, decoder_attention_heads, decoder_ffn_dim = 1, 12, 64
+        # sequence_length = 3
+        batch_size, encoder_attention_heads, encoder_ffn_dim = 2, 12, 64
+        __________, decoder_attention_heads, decoder_ffn_dim = 2, 12, 64
         num_hidden_layers = 6
+        encoded_input["input_ids"] = encoded_input["input_ids"].expand((batch_size, -1))
+        encoded_input["attention_mask"] = encoded_input["attention_mask"].expand(
+            (batch_size, -1)
+        )
         encoded_input["past_key_values"] = make_encoder_decoder_cache(
             make_dynamic_cache(
                 [
@@ -586,9 +605,9 @@ def test_feature_extraction(self):
             ),
         )
         print()
-        print("-- inputs", string_type(encoded_input, with_shape=True, with_min_max=True))
+        print("-- inputs", string_type(encoded_input, with_shape=True))
         output = model(**encoded_input)
-        print("-- outputs", string_type(output, with_shape=True, with_min_max=True))
+        print("-- outputs", string_type(output, with_shape=True))
 
     @never_test()
     def test_text_classification(self):
```
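The new `expand` calls in the hunk above grow the batch axis of the tokenizer output from 1 to `batch_size` without copying data; `-1` keeps the sequence axis unchanged. A small, self-contained illustration of that tensor trick:

```python
import torch

input_ids = torch.tensor([[101, 2023, 2003, 102]])  # shape (1, 4), batch of one
batched = input_ids.expand((2, -1))  # shape (2, 4): a broadcast view, no copy
assert batched.shape == (2, 4)
assert batched.data_ptr() == input_ids.data_ptr()  # same underlying storage
```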

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -341,7 +341,7 @@ def forward(self, x, ind1, ind2):
             self.assertIn("export 0/1 specialized due to hint of 1 for dimension", str(e))
 
         dynamic_shapes = use_dyn_not_str(dynamic_string, torch.export.Dim.AUTO)
-        if has_torch("2.9"):
+        if has_torch("2.9") and not has_torch("2.9.99"):
             with self.subTest(
                 name="expected shape should be broadcastable to (>= 2.9)",
                 dynamic_shapes=dynamic_shapes,
@@ -352,6 +352,9 @@ def forward(self, x, ind1, ind2):
                     raise AssertionError("torch fixed that case")
                 except RuntimeError as e:
                     self.assertIn("expected shape should be broadcastable to", str(e))
+        elif has_torch("2.9.99"):
+            with torch.fx.experimental._config.patch(backed_size_oblivious=True):
+                torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
 
         if not has_torch("2.9"):
             with self.subTest(
```
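On torch builds newer than 2.9 the test now exports under `backed_size_oblivious`, a private flag toggled through the `patch` context manager that `torch.fx.experimental._config` exposes. A minimal sketch of the pattern, assuming a PyTorch version recent enough to carry the flag; the toy module is hypothetical:

```python
import torch
import torch.fx.experimental._config as fx_config


class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2


inputs = (torch.randn(3, 4),)
dynamic_shapes = ({0: torch.export.Dim.AUTO},)

# backed_size_oblivious relaxes 0/1 specialization of backed symbolic sizes
with fx_config.patch(backed_size_oblivious=True):
    ep = torch.export.export(Toy(), inputs, dynamic_shapes=dynamic_shapes)
print(ep)
```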

onnx_diagnostic/tasks/feature_extraction.py

Lines changed: 15 additions & 14 deletions
```diff
@@ -1,10 +1,6 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.config_helper import (
-    update_config,
-    check_hasattr,
-    default_num_hidden_layers as nhl,
-)
+from ..helpers.config_helper import update_config, check_hasattr
 from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
 
 
@@ -13,8 +9,9 @@
 
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
-    check_hasattr(config, "num_hidden_layers")
-    kwargs = dict(num_hidden_layers=min(config.num_hidden_layers, nhl()))
+    check_hasattr(config, "vocab_size")
+    # Bart architecture does not like too much that the number of layers is changed.
+    kwargs = dict(vocab_size=2056)
     update_config(config, kwargs)
     return kwargs
 
@@ -25,7 +22,8 @@ def get_inputs(
     batch_size: int,
     sequence_length: int,
     dummy_max_token_id: int,
-    sequence_length2: int = 3,
+    past_length: int = 30,
+    past_length2: int = 4,
     decoder_attention_heads: Optional[int] = None,
     encoder_attention_heads: Optional[int] = None,
     encoder_ffn_dim: Optional[int] = None,
@@ -73,13 +71,13 @@ def get_inputs(
                 torch.randn(
                     batch_size,
                     encoder_attention_heads,
-                    sequence_length,
+                    past_length,
                     encoder_ffn_dim,
                 ),
                 torch.randn(
                     batch_size,
                     encoder_attention_heads,
-                    sequence_length,
+                    past_length,
                     encoder_ffn_dim,
                 ),
             )
@@ -92,13 +90,13 @@ def get_inputs(
                 torch.randn(
                     batch_size,
                     decoder_attention_heads,
-                    sequence_length2,
+                    past_length2,
                     decoder_ffn_dim,
                 ),
                 torch.randn(
                     batch_size,
                     decoder_attention_heads,
-                    sequence_length2,
+                    past_length2,
                     decoder_ffn_dim,
                 ),
             )
@@ -124,7 +122,8 @@ def get_inputs(
         batch_size=batch_size + 1,
         sequence_length=sequence_length + add_second_input,
         dummy_max_token_id=dummy_max_token_id,
-        sequence_length2=sequence_length2,
+        past_length=past_length,
+        past_length2=past_length2,
         decoder_attention_heads=decoder_attention_heads,
         encoder_attention_heads=encoder_attention_heads,
         encoder_ffn_dim=encoder_ffn_dim,
@@ -146,7 +145,9 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     check_hasattr(config, "vocab_size")
     kwargs = dict(
         batch_size=2,
-        sequence_length=30,
+        sequence_length=12,
+        past_length=30,
+        past_length2=4,
         dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
     )
     for att in [
```
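With the rename, `sequence_length` now counts only the tokens fed to the current call (12 by default) while `past_length` and `past_length2` size the self- and cross-attention caches (30 and 4), matching the shapes recorded in the new test's docstring. A sketch of inspecting the updated defaults, assuming a transformers install to load the Bart config:

```python
from transformers import AutoConfig
from onnx_diagnostic.tasks.feature_extraction import random_input_kwargs

config = AutoConfig.from_pretrained("facebook/bart-base")
kwargs, get_inputs_fn = random_input_kwargs(config)
# Expected keys after this commit: batch_size=2, sequence_length=12,
# past_length=30, past_length2=4, dummy_max_token_id=vocab_size - 1, ...
print(kwargs)
```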
