@@ -1,12 +1,7 @@
 import os
 import unittest
 import torch
-from onnx_diagnostic.ext_test_case import (
-    ExtTestCase,
-    hide_stdout,
-    has_transformers,
-    requires_transformers,
-)
+from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers
 from onnx_diagnostic.helpers.torch_helper import to_any, torch_deepcopy
 from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
 from onnx_diagnostic.torch_export_patches import torch_export_patches
@@ -216,50 +211,6 @@ def test_fill_mask(self): |
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )

-    @hide_stdout()
-    @requires_transformers("4.53.99")
-    def test_feature_extraction_bart_base(self):
-        mid = "facebook/bart-base"
-        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
-        self.assertEqual(data["task"], "feature-extraction")
-        self.assertIn((data["size"], data["n_weights"]), [(557681664, 139420416)])
-        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
-        model(**torch_deepcopy(inputs))
-        model(**data["inputs2"])
-        with torch_export_patches(patch_transformers=True, verbose=10):
-            torch.export.export(
-                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
-            )
-
-    @hide_stdout()
-    def test_feature_extraction_tiny_bart(self):
-        mid = "hf-tiny-model-private/tiny-random-PLBartForConditionalGeneration"
-        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
-        self.assertEqual(data["task"], "text2text-generation")
-        self.assertIn((data["size"], data["n_weights"]), [(3243392, 810848)])
-        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
-        model(**inputs)
-        model(**data["inputs2"])
-        with torch_export_patches(patch_transformers=True, verbose=10):
-            torch.export.export(
-                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
-            )
-
-    @requires_transformers("4.51.999")
-    @hide_stdout()
-    def test_summarization(self):
-        mid = "facebook/bart-large-cnn"
-        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
-        self.assertEqual(data["task"], "summarization")
-        self.assertIn((data["size"], data["n_weights"]), [(1625161728, 406290432)])
-        model, inputs, _ds = data["model"], data["inputs"], data["dynamic_shapes"]
-        model(**inputs)
-        model(**data["inputs2"])
-        # with torch_export_patches(patch_transformers=True, verbose=10):
-        #     torch.export.export(
-        #         model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
-        #     )
-
     @hide_stdout()
     def test_text_classification(self):
         mid = "Intel/bert-base-uncased-mrpc"