From 3cfbc1445650ebc58db81c629fa2e87103ad82d6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Thu, 13 Nov 2025 14:33:07 +0100
Subject: [PATCH 1/3] issue with latest transformers

---
 _unittests/ut_tasks/test_tasks.py             | 51 +------------------
 .../ut_tasks/test_tasks_feature_extraction.py | 42 +++++++++++++++
 .../ut_tasks/test_tasks_summarization.py      | 25 +++++++++
 3 files changed, 68 insertions(+), 50 deletions(-)
 create mode 100644 _unittests/ut_tasks/test_tasks_feature_extraction.py
 create mode 100644 _unittests/ut_tasks/test_tasks_summarization.py

diff --git a/_unittests/ut_tasks/test_tasks.py b/_unittests/ut_tasks/test_tasks.py
index 5f9f9e8c..9599359a 100644
--- a/_unittests/ut_tasks/test_tasks.py
+++ b/_unittests/ut_tasks/test_tasks.py
@@ -1,12 +1,7 @@
 import os
 import unittest
 import torch
-from onnx_diagnostic.ext_test_case import (
-    ExtTestCase,
-    hide_stdout,
-    has_transformers,
-    requires_transformers,
-)
+from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers
 from onnx_diagnostic.helpers.torch_helper import to_any, torch_deepcopy
 from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
 from onnx_diagnostic.torch_export_patches import torch_export_patches
@@ -216,50 +211,6 @@ def test_fill_mask(self):
             model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
         )
 
-    @hide_stdout()
-    @requires_transformers("4.53.99")
-    def test_feature_extraction_bart_base(self):
-        mid = "facebook/bart-base"
-        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
-        self.assertEqual(data["task"], "feature-extraction")
-        self.assertIn((data["size"], data["n_weights"]), [(557681664, 139420416)])
-        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
-        model(**torch_deepcopy(inputs))
-        model(**data["inputs2"])
-        with torch_export_patches(patch_transformers=True, verbose=10):
-            torch.export.export(
-                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
-            )
-
-    @hide_stdout()
-    def test_feature_extraction_tiny_bart(self):
-        mid = "hf-tiny-model-private/tiny-random-PLBartForConditionalGeneration"
-        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
-        self.assertEqual(data["task"], "text2text-generation")
-        self.assertIn((data["size"], data["n_weights"]), [(3243392, 810848)])
-        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
-        model(**inputs)
-        model(**data["inputs2"])
-        with torch_export_patches(patch_transformers=True, verbose=10):
-            torch.export.export(
-                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
-            )
-
-    @requires_transformers("4.51.999")
-    @hide_stdout()
-    def test_summarization(self):
-        mid = "facebook/bart-large-cnn"
-        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
-        self.assertEqual(data["task"], "summarization")
-        self.assertIn((data["size"], data["n_weights"]), [(1625161728, 406290432)])
-        model, inputs, _ds = data["model"], data["inputs"], data["dynamic_shapes"]
-        model(**inputs)
-        model(**data["inputs2"])
-        # with torch_export_patches(patch_transformers=True, verbose=10):
-        #     torch.export.export(
-        #         model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
-        #     )
-
     @hide_stdout()
     def test_text_classification(self):
         mid = "Intel/bert-base-uncased-mrpc"
diff --git a/_unittests/ut_tasks/test_tasks_feature_extraction.py b/_unittests/ut_tasks/test_tasks_feature_extraction.py
new file mode 100644
index 00000000..6d2f310f
--- /dev/null
+++ b/_unittests/ut_tasks/test_tasks_feature_extraction.py
@@ -0,0 +1,42 @@
+import unittest
+import torch
+from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, requires_transformers
+from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
+from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
+from onnx_diagnostic.torch_export_patches import torch_export_patches
+from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
+
+
+class TestTasksFeatureExtraction(ExtTestCase):
+    @hide_stdout()
+    @requires_transformers("4.53.99")
+    def test_feature_extraction_bart_base(self):
+        mid = "facebook/bart-base"
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "feature-extraction")
+        self.assertIn((data["size"], data["n_weights"]), [(557681664, 139420416)])
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        model(**torch_deepcopy(inputs))
+        model(**data["inputs2"])
+        with torch_export_patches(patch_transformers=True, verbose=10):
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+
+    @hide_stdout()
+    def test_feature_extraction_tiny_bart(self):
+        mid = "hf-tiny-model-private/tiny-random-PLBartForConditionalGeneration"
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "text2text-generation")
+        self.assertIn((data["size"], data["n_weights"]), [(3243392, 810848)])
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        model(**inputs)
+        model(**data["inputs2"])
+        with torch_export_patches(patch_transformers=True, verbose=10):
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/_unittests/ut_tasks/test_tasks_summarization.py b/_unittests/ut_tasks/test_tasks_summarization.py
new file mode 100644
index 00000000..17cd3ed2
--- /dev/null
+++ b/_unittests/ut_tasks/test_tasks_summarization.py
@@ -0,0 +1,25 @@
+import unittest
+from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, requires_transformers
+from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
+
+
+class TestTasksSummarization(ExtTestCase):
+    @requires_transformers("4.51.999")
+    @hide_stdout()
+    def test_summarization(self):
+        mid = "facebook/bart-large-cnn"
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "summarization")
+        self.assertIn((data["size"], data["n_weights"]), [(1625161728, 406290432)])
+        model, inputs, _ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        print(f"-- {mid}: {self.string_type(inputs, with_shape=True)}")
+        model(**inputs)
+        model(**data["inputs2"])
+        # with torch_export_patches(patch_transformers=True, verbose=10):
+        #     torch.export.export(
+        #         model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+        #     )
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)

From 5a01286879067e4a705a7dda5e2871ae1f3333c7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Thu, 13 Nov 2025 17:00:23 +0100
Subject: [PATCH 2/3] fix

---
 _unittests/ut_torch_export_patches/test_patch_torch.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/_unittests/ut_torch_export_patches/test_patch_torch.py b/_unittests/ut_torch_export_patches/test_patch_torch.py
index 5e260d61..74afca7d 100644
--- a/_unittests/ut_torch_export_patches/test_patch_torch.py
+++ b/_unittests/ut_torch_export_patches/test_patch_torch.py
@@ -341,7 +341,7 @@ def forward(self, x, ind1, ind2):
                 self.assertIn("export 0/1 specialized due to hint of 1 for dimension", str(e))
 
         dynamic_shapes = use_dyn_not_str(dynamic_string, torch.export.Dim.AUTO)
-        if has_torch("2.9"):
+        if has_torch("2.9") and not has_torch("2.9.99"):
             with self.subTest(
                 name="expected shape should be broadcastable to (>= 2.9)",
                 dynamic_shapes=dynamic_shapes,
@@ -352,6 +352,9 @@ def forward(self, x, ind1, ind2):
                     raise AssertionError("torch fixed that case")
                 except RuntimeError as e:
                     self.assertIn("expected shape should be broadcastable to", str(e))
+        elif has_torch("2.9.99"):
+            with torch.fx.experimental._config.patch(backed_size_oblivious=True):
+                torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
 
         if not has_torch("2.9"):
             with self.subTest(

From 823e59e658cf2372db6830090381b8e317dd3be7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Fri, 14 Nov 2025 01:31:49 +0100
Subject: [PATCH 3/3] fix tasks

---
 CHANGELOGS.rst                                |   1 +
 .../ut_tasks/test_tasks_feature_extraction.py |  15 +-
 .../ut_tasks/test_tasks_summarization.py      |   2 +-
 _unittests/ut_tasks/try_tasks.py              |  31 ++-
 onnx_diagnostic/tasks/feature_extraction.py   |  29 +--
 onnx_diagnostic/tasks/summarization.py        | 209 ++++++------------
 6 files changed, 128 insertions(+), 159 deletions(-)

diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 07eaeb2b..73d274b6 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
 0.8.2
 +++++
 
+* :pr:`303`: fix inputs for summarization, feature extraction tasks
 * :pr:`302`: adds helpers to analyse onnxruntime profiling
 * :pr:`297`: experiment around a higher ops ``loop_for``
 * :pr:`292`, :pr:`293`, :pr:`294`, :pr:`295`: new patches for Qwen models
diff --git a/_unittests/ut_tasks/test_tasks_feature_extraction.py b/_unittests/ut_tasks/test_tasks_feature_extraction.py
index 6d2f310f..36ec4717 100644
--- a/_unittests/ut_tasks/test_tasks_feature_extraction.py
+++ b/_unittests/ut_tasks/test_tasks_feature_extraction.py
@@ -11,11 +11,24 @@ class TestTasksFeatureExtraction(ExtTestCase):
     @hide_stdout()
     @requires_transformers("4.53.99")
     def test_feature_extraction_bart_base(self):
+        """
+        data=dict(
+            input_ids:T7s2x12,
+            attention_mask:T7s2x12,
+            past_key_values:EncoderDecoderCache(
+                self_attention_cache=DynamicCache(
+                    key_cache=#6[T1s2x12x30x64,...
+                    value_cache=#6[T1s2x12x30x64,...
+                cross_attention_cache=DynamicCache(
+                    key_cache=#6[T1s2x12x4x64
+                    value_cache=#6[T1s2x12x4x64
+        """
         mid = "facebook/bart-base"
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
         self.assertEqual(data["task"], "feature-extraction")
-        self.assertIn((data["size"], data["n_weights"]), [(557681664, 139420416)])
+        self.assertIn((data["size"], data["n_weights"]), [(409583616, 102395904)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        print(f"-- {self.string_type(inputs, with_shape=True)}")
         model(**torch_deepcopy(inputs))
         model(**data["inputs2"])
         with torch_export_patches(patch_transformers=True, verbose=10):
diff --git a/_unittests/ut_tasks/test_tasks_summarization.py b/_unittests/ut_tasks/test_tasks_summarization.py
index 17cd3ed2..87b8a8dc 100644
--- a/_unittests/ut_tasks/test_tasks_summarization.py
+++ b/_unittests/ut_tasks/test_tasks_summarization.py
@@ -10,7 +10,7 @@ def test_summarization(self):
         mid = "facebook/bart-large-cnn"
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
         self.assertEqual(data["task"], "summarization")
-        self.assertIn((data["size"], data["n_weights"]), [(1625161728, 406290432)])
+        self.assertIn((data["size"], data["n_weights"]), [(1427701760, 356925440)])
         model, inputs, _ds = data["model"], data["inputs"], data["dynamic_shapes"]
         print(f"-- {mid}: {self.string_type(inputs, with_shape=True)}")
         model(**inputs)
diff --git a/_unittests/ut_tasks/try_tasks.py b/_unittests/ut_tasks/try_tasks.py
index 4fb73b23..69ecf2c9 100644
--- a/_unittests/ut_tasks/try_tasks.py
+++ b/_unittests/ut_tasks/try_tasks.py
@@ -530,7 +530,22 @@ def test_fill_mask(self):
         print("-- outputs", string_type(output, with_shape=True, with_min_max=True))
 
     @never_test()
-    def test_feature_extraction(self):
+    def test_feature_extraction_generate(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k feature_ex
+        # https://huggingface.co/facebook/bart-base
+
+        from transformers import BartTokenizer, BartModel
+
+        tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+        model = BartModel.from_pretrained("facebook/bart-base")
+        text = "Replace me by any text you'd like."
+        encoded_input = tokenizer(text, return_tensors="pt")
+        print(f"-- {string_type(encoded_input, with_shape=True)}")
+        outputs = model(**encoded_input)
+        print(f"-- {string_type(outputs, with_shape=True)}")
+
+    @never_test()
+    def test_feature_extraction_check(self):
         # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k feature_ex
         # https://huggingface.co/google-bert/bert-base-multilingual-cased
 
@@ -541,10 +556,14 @@ def test_feature_extraction(self):
         text = "Replace me by any text you'd like."
         encoded_input = tokenizer(text, return_tensors="pt")
         sequence_length, sequence_length2 = 30, 4
-        sequence_length = 3
-        batch_size, encoder_attention_heads, encoder_ffn_dim = 1, 12, 64
-        batch_size, decoder_attention_heads, decoder_ffn_dim = 1, 12, 64
+        # sequence_length = 3
+        batch_size, encoder_attention_heads, encoder_ffn_dim = 2, 12, 64
+        __________, decoder_attention_heads, decoder_ffn_dim = 2, 12, 64
         num_hidden_layers = 6
+        encoded_input["input_ids"] = encoded_input["input_ids"].expand((batch_size, -1))
+        encoded_input["attention_mask"] = encoded_input["attention_mask"].expand(
+            (batch_size, -1)
+        )
         encoded_input["past_key_values"] = make_encoder_decoder_cache(
             make_dynamic_cache(
                 [
                     (
                         torch.randn(
                             batch_size,
                             encoder_attention_heads,
                             sequence_length,
                             encoder_ffn_dim,
                         ),
                         torch.randn(
                             batch_size,
                             encoder_attention_heads,
                             sequence_length,
                             encoder_ffn_dim,
                         ),
                     )
                     for i in range(num_hidden_layers)
                 ]
             ),
             make_dynamic_cache(
                 [
                     (
                         torch.randn(
                             batch_size,
                             decoder_attention_heads,
                             sequence_length2,
                             decoder_ffn_dim,
                         ),
                         torch.randn(
                             batch_size,
                             decoder_attention_heads,
                             sequence_length2,
                             decoder_ffn_dim,
                         ),
                     )
                     for i in range(num_hidden_layers)
                 ]
@@ -586,9 +605,9 @@ def test_feature_extraction(self):
             ),
         )
         print()
-        print("-- inputs", string_type(encoded_input, with_shape=True, with_min_max=True))
+        print("-- inputs", string_type(encoded_input, with_shape=True))
         output = model(**encoded_input)
-        print("-- outputs", string_type(output, with_shape=True, with_min_max=True))
+        print("-- outputs", string_type(output, with_shape=True))
 
     @never_test()
     def test_text_classification(self):
diff --git a/onnx_diagnostic/tasks/feature_extraction.py b/onnx_diagnostic/tasks/feature_extraction.py
index 58b1e3c5..5bb0af99 100644
--- a/onnx_diagnostic/tasks/feature_extraction.py
+++ b/onnx_diagnostic/tasks/feature_extraction.py
@@ -1,10 +1,6 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.config_helper import (
-    update_config,
-    check_hasattr,
-    default_num_hidden_layers as nhl,
-)
+from ..helpers.config_helper import update_config, check_hasattr
 from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
 
 
@@ -13,8 +9,9 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
-    check_hasattr(config, "num_hidden_layers")
-    kwargs = dict(num_hidden_layers=min(config.num_hidden_layers, nhl()))
+    check_hasattr(config, "vocab_size")
+    # Bart does not cope well with a changed number of layers; reduce the vocabulary size instead.
+    kwargs = dict(vocab_size=2056)
     update_config(config, kwargs)
     return kwargs
 
@@ -25,7 +22,8 @@ def get_inputs(
     batch_size: int,
     sequence_length: int,
     dummy_max_token_id: int,
-    sequence_length2: int = 3,
+    past_length: int = 30,
+    past_length2: int = 4,
     decoder_attention_heads: Optional[int] = None,
     encoder_attention_heads: Optional[int] = None,
     encoder_ffn_dim: Optional[int] = None,
@@ -73,13 +71,13 @@ def get_inputs(
                     torch.randn(
                         batch_size,
                         encoder_attention_heads,
-                        sequence_length,
+                        past_length,
                         encoder_ffn_dim,
                     ),
                     torch.randn(
                         batch_size,
                         encoder_attention_heads,
-                        sequence_length,
+                        past_length,
                         encoder_ffn_dim,
                     ),
                 )
@@ -92,13 +90,13 @@ def get_inputs(
                     torch.randn(
                         batch_size,
                         decoder_attention_heads,
-                        sequence_length2,
+                        past_length2,
                         decoder_ffn_dim,
                     ),
                     torch.randn(
                         batch_size,
                         decoder_attention_heads,
-                        sequence_length2,
+                        past_length2,
                         decoder_ffn_dim,
                     ),
                 )
@@ -124,7 +122,8 @@ def get_inputs(
         batch_size=batch_size + 1,
         sequence_length=sequence_length + add_second_input,
         dummy_max_token_id=dummy_max_token_id,
-        sequence_length2=sequence_length2,
+        past_length=past_length,
+        past_length2=past_length2,
         decoder_attention_heads=decoder_attention_heads,
         encoder_attention_heads=encoder_attention_heads,
         encoder_ffn_dim=encoder_ffn_dim,
@@ -146,7 +145,9 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     check_hasattr(config, "vocab_size")
     kwargs = dict(
         batch_size=2,
-        sequence_length=30,
+        sequence_length=12,
+        past_length=30,
+        past_length2=4,
         dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
     )
     for att in [
diff --git a/onnx_diagnostic/tasks/summarization.py b/onnx_diagnostic/tasks/summarization.py
index fe9c8138..e3a0e611 100644
--- a/onnx_diagnostic/tasks/summarization.py
+++ b/onnx_diagnostic/tasks/summarization.py
@@ -1,23 +1,16 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
 from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
-from ..helpers.config_helper import (
-    update_config,
-    check_hasattr,
-    _pick,
-    default_num_hidden_layers as nhl,
-)
+from ..helpers.config_helper import update_config, check_hasattr
 
 __TASK__ = "summarization"
 
 
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
-    kwargs: Dict[str, Any] = {}
-    if hasattr(config, "num_decoder_layers"):
-        config.num_decoder_layers = min(config.num_decoder_layers, 2)
-    if hasattr(config, "num_hidden_layers"):
-        config.num_hidden_layers = min(config.num_hidden_layers, nhl())
+    check_hasattr(config, "vocab_size")
+    # Bart does not cope well with a changed number of layers; reduce the vocabulary size instead.
+    kwargs = dict(vocab_size=2056)
     update_config(config, kwargs)
     return kwargs
 
@@ -25,96 +18,66 @@ def get_inputs(
 def get_inputs(
     model: torch.nn.Module,
     config: Optional[Any],
+    batch_size: int,
+    sequence_length: int,
     dummy_max_token_id: int,
-    num_key_value_heads_encoder: int,
-    num_key_value_heads_decoder: int,
-    num_hidden_layers: int,
-    head_dim_encoder: int,
-    head_dim_decoder: int,
-    batch_size: int = 2,
-    sequence_length: int = 30,
-    sequence_length2: int = 3,
+    past_length: int = 30,
+    past_length2: int = 4,
+    decoder_attention_heads: Optional[int] = None,
+    encoder_attention_heads: Optional[int] = None,
+    encoder_ffn_dim: Optional[int] = None,
+    decoder_ffn_dim: Optional[int] = None,
+    num_hidden_layers: Optional[int] = None,
     add_second_input: int = 1,
     **kwargs,  # unused
 ):
     """
-    Generates input for task ``summarization``.
-
-    :param model: model to get the missing information
-    :param config: configuration used to generate the model
-    :param head_dim_encoder: last dimension of the cache for the encoder
-    :param head_dim_decoder: last dimension of the cache for the decoder
-    :param num_key_value_heads_encoder: number of heads for the encoder
-    :param num_key_value_heads_decoder: number of heads for the decoder
-    :param dummy_max_token_id: dummy max token id
-    :param batch_size: batch size
-    :param sequence_length: sequence length
-    :param sequence_length2: new sequence length
-    :return: dictionary
-
-    Stolen inputs for one model.
+    Generates inputs for task ``summarization``.
+    Example:
 
     ::
 
-        cache_position:T7s1
-        past_key_values:EncoderDecoderCache(
-            self_attention_cache=DynamicCache(
-                key_cache=#6[T1s1x8x1x64,...],
-                value_cache=#6[T1s1x8x1x64,...]),
-            cross_attention_cache=DynamicCache(
-                key_cache=#6[T1s1x8x16x64,...],
-                value_cache=#6[T1s1x8x16x64,...])),
-        decoder_input_ids:T7s1x1,
-        encoder_outputs:dict(last_hidden_state:T1s1x16x512)
+        input_ids:T7s1x13[101,72654:A16789.23076923077],
+        token_type_ids:T7s1x13[0,0:A0.0],
+        attention_mask:T7s1x13[1,1:A1.0])
     """
     assert (
         "cls_cache" not in kwargs
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = "batch"
-    seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
-    cache_length = "cache_length_key"  # torch.export.Dim("cache_length", min=1, max=4096)
-    cache_length2 = "cache_length_val"  # torch.export.Dim("cache_length2", min=1, max=4096)
-
+    seq_length = "sequence_length"
     shapes = {
         "input_ids": {0: batch, 1: seq_length},
-        "decoder_input_ids": {0: batch, 1: "seq_ids"},
-        "attention_mask": {0: batch, 1: "seq_mask"},
-        # "cache_position": {0: batch, 1: torch.export.Dim.DYNAMIC},
-        "past_key_values": [
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)],
-            [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers * 2)],
-        ],
-        # one these is selected based on the forward method signature
-        # "encoder_last_hidden_state": {0: batch, 1: torch.export.Dim.DYNAMIC},
-        # "encoder_outputs": {0: batch, 1: torch.export.Dim.DYNAMIC},
+        "attention_mask": {0: batch, 1: seq_length},
     }
-
     inputs = dict(
         input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to(
            torch.int64
        ),
-        decoder_input_ids=torch.randint(
-            0, dummy_max_token_id, (batch_size, sequence_length2)
-        ).to(torch.int64),
         attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
-        # cache_position=torch.arange(sequence_length, sequence_length + sequence_length2)
-        # .to(torch.int64)
-        # .expand((batch_size, -1)),
-        past_key_values=make_encoder_decoder_cache(
+    )
+    if (
+        encoder_attention_heads
+        and decoder_attention_heads
+        and encoder_ffn_dim
+        and decoder_ffn_dim
+        and num_hidden_layers
+    ):
+        inputs["past_key_values"] = make_encoder_decoder_cache(
             make_dynamic_cache(
                 [
                     (
                         torch.randn(
                             batch_size,
-                            num_key_value_heads_encoder,
-                            sequence_length,
-                            head_dim_encoder,
+                            encoder_attention_heads,
+                            past_length,
+                            encoder_ffn_dim,
                         ),
                         torch.randn(
                             batch_size,
-                            num_key_value_heads_encoder,
-                            sequence_length,
-                            head_dim_encoder,
+                            encoder_attention_heads,
+                            past_length,
+                            encoder_ffn_dim,
                         ),
                     )
                     for i in range(num_hidden_layers)
@@ -125,22 +88,28 @@ def get_inputs(
                     (
                         torch.randn(
                             batch_size,
-                            num_key_value_heads_decoder,
-                            sequence_length2,
-                            head_dim_decoder,
+                            decoder_attention_heads,
+                            past_length2,
+                            decoder_ffn_dim,
                         ),
                         torch.randn(
                             batch_size,
-                            num_key_value_heads_decoder,
-                            sequence_length2,
-                            head_dim_decoder,
+                            decoder_attention_heads,
+                            past_length2,
+                            decoder_ffn_dim,
                         ),
                     )
                     for i in range(num_hidden_layers)
                 ]
             ),
-        ),
-    )
+        )
+        cache_length = "cache_length_key"
+        cache_length2 = "cache_length_val"
+        shapes["past_key_values"] = [  # type: ignore[assignment]
+            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)],
+            [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers * 2)],
+        ]
+
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
         assert (
@@ -149,15 +118,16 @@ def get_inputs(
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
-            dummy_max_token_id=dummy_max_token_id,
-            num_key_value_heads_encoder=num_key_value_heads_encoder,
-            num_key_value_heads_decoder=num_key_value_heads_decoder,
-            num_hidden_layers=num_hidden_layers,
-            head_dim_encoder=head_dim_encoder,
-            head_dim_decoder=head_dim_decoder,
             batch_size=batch_size + 1,
             sequence_length=sequence_length + add_second_input,
-            sequence_length2=sequence_length2 + 1,
+            dummy_max_token_id=dummy_max_token_id,
+            past_length=past_length,
+            past_length2=past_length2,
+            decoder_attention_heads=decoder_attention_heads,
+            encoder_attention_heads=encoder_attention_heads,
+            encoder_ffn_dim=encoder_ffn_dim,
+            decoder_ffn_dim=decoder_ffn_dim,
+            num_hidden_layers=num_hidden_layers,
             add_second_input=0,
             **kwargs,
         )["inputs"]
@@ -171,57 +141,22 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     If the configuration is None, the function selects typical dimensions.
     """
     if config is not None:
-        check_hasattr(
-            config,
-            "vocab_size",
-            "hidden_size",
-            "num_attention_heads",
-            ("num_hidden_layers", "num_layers"),
-            ("n_positions", "d_model"),
-            (
-                "num_key_value_heads",
-                "num_heads",
-                ("decoder_attention_heads", "encoder_attention_heads"),
-            ),
-        )
-    # exceptions = {
-    #     "PLBartForConditionalGeneration": (
-    #         lambda c: c.encoder_attention_heads + c.decoder_attention_heads
-    #     )
-    # }
+        check_hasattr(config, "vocab_size")
     kwargs = dict(
         batch_size=2,
-        sequence_length=30,
-        sequence_length2=3,
-        head_dim_encoder=(
-            16 if config is None else int(_pick(config, "encoder_ffn_dim") ** 0.5)
-        ),
-        head_dim_decoder=(
-            16 if config is None else int(_pick(config, "decoder_ffn_dim") ** 0.5)
-        ),
-        dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
-        num_hidden_layers=(
-            8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
-        ),
-        num_key_value_heads_encoder=(
-            16
-            if config is None
-            else _pick(
-                config,
-                "encoder_attention_heads",
-                "num_key_value_heads",
-                "num_heads",
-            )
-        ),
-        num_key_value_heads_decoder=(
-            16
-            if config is None
-            else _pick(
-                config,
-                "decoder_attention_heads",
-                "num_key_value_heads",
-                "num_heads",
-            )
-        ),
+        sequence_length=12,
+        past_length=30,
+        past_length2=4,
+        dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
     )
+    for att in [
+        "decoder_attention_heads",
+        "encoder_attention_heads",
+        "encoder_ffn_dim",
+        "decoder_ffn_dim",
+        "num_hidden_layers",
+    ]:
+        if hasattr(config, att):
+            kwargs[att] = getattr(config, att)
+    kwargs["decoder_ffn_dim"] = kwargs["encoder_ffn_dim"] = 64
     return kwargs, get_inputs
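
For reference, the ``past_key_values`` layout both reworked tasks now build can be reproduced in isolation with the two helpers used throughout this series. A minimal sketch, assuming ``onnx_diagnostic`` is installed and using the defaults above (batch size 2, 12 heads, 6 layers, per-head dimension 64 — the value ``random_input_kwargs`` forces into ``encoder_ffn_dim``/``decoder_ffn_dim``)::

    import torch
    from onnx_diagnostic.helpers.cache_helper import (
        make_dynamic_cache,
        make_encoder_decoder_cache,
    )

    batch_size, heads, head_dim, num_hidden_layers = 2, 12, 64, 6
    past_length, past_length2 = 30, 4  # self-attention vs cross-attention lengths

    def layer_kv(length):
        # one (key, value) pair per layer, each of shape (batch, heads, length, head_dim)
        return [
            (
                torch.randn(batch_size, heads, length, head_dim),
                torch.randn(batch_size, heads, length, head_dim),
            )
            for _ in range(num_hidden_layers)
        ]

    past_key_values = make_encoder_decoder_cache(
        make_dynamic_cache(layer_kv(past_length)),   # self-attention cache
        make_dynamic_cache(layer_kv(past_length2)),  # cross-attention cache
    )

The two past lengths are independent, which is why ``get_inputs`` registers two distinct dynamic dimensions (``cache_length_key`` and ``cache_length_val``) on axis 2 of the cache tensors.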
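The new test files exercise these inputs through the same export path everywhere; a sketch of that call sequence, assuming ``model``, ``inputs`` and ``ds`` come from ``get_untrained_model_with_inputs`` as in the tests above (``use_dyn_not_str`` swaps the string dimension names such as ``"batch"`` for ``torch.export`` dynamic dimensions, as its ``torch.export.Dim.AUTO`` argument in ``test_patch_torch.py`` suggests)::

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_patches
    from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

    # Patch transformers so the cache classes trace cleanly, then export with
    # the dynamic shapes produced by the task's get_inputs.
    with torch_export_patches(patch_transformers=True):
        ep = torch.export.export(
            model, (), kwargs=inputs,
            dynamic_shapes=use_dyn_not_str(ds),
            strict=False,
        )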