From 9a48eb35996355e88e2c021497d2dae3e9105423 Mon Sep 17 00:00:00 2001
From: Haifeng Jin
Date: Tue, 20 May 2025 23:23:46 +0000
Subject: [PATCH 1/3] disable amp by default on cpu

---
 benchmarks/torchbench_model.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/benchmarks/torchbench_model.py b/benchmarks/torchbench_model.py
index b9dea15adf64..d782f6b7f1b4 100644
--- a/benchmarks/torchbench_model.py
+++ b/benchmarks/torchbench_model.py
@@ -373,6 +373,9 @@ def is_accelerator_tpu(self):
     return self.benchmark_experiment.accelerator == "tpu"
 
   def use_amp(self):
+    # AMP is only supported on cuda and tpu, not on cpu.
+    if self.benchmark_experiment.accelerator == "cpu":
+      return False
     return self.is_training() or self.model_name in config(
     ).dtype.force_amp_for_fp16_bf16_models
 

From e1b3a080a60b5eecf93af3164452bb25719f2efd Mon Sep 17 00:00:00 2001
From: Haifeng Jin
Date: Thu, 22 May 2025 17:34:35 +0000
Subject: [PATCH 2/3] throw a warning when AMP is not used

---
 benchmarks/torchbench_model.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/benchmarks/torchbench_model.py b/benchmarks/torchbench_model.py
index d782f6b7f1b4..55b7f5552762 100644
--- a/benchmarks/torchbench_model.py
+++ b/benchmarks/torchbench_model.py
@@ -375,6 +375,7 @@ def is_accelerator_tpu(self):
   def use_amp(self):
     # AMP is only supported on cuda and tpu, not on cpu.
     if self.benchmark_experiment.accelerator == "cpu":
+      logger.warning("AMP is not used due to running on CPU.")
       return False
     return self.is_training() or self.model_name in config(
     ).dtype.force_amp_for_fp16_bf16_models

From 86d36393567079becf6079aaeef9fd8efd09ec53 Mon Sep 17 00:00:00 2001
From: Haifeng Jin
Date: Mon, 2 Jun 2025 18:33:16 +0000
Subject: [PATCH 3/3] add unit tests

---
 test/benchmarks/test_torchbench_model.py | 33 ++++++++++++++++++++++++
 1 file changed, 33 insertions(+)
 create mode 100644 test/benchmarks/test_torchbench_model.py

diff --git a/test/benchmarks/test_torchbench_model.py b/test/benchmarks/test_torchbench_model.py
new file mode 100644
index 000000000000..6bef0682eb9f
--- /dev/null
+++ b/test/benchmarks/test_torchbench_model.py
@@ -0,0 +1,33 @@
+import unittest
+
+from benchmarks.torchbench_model import TorchBenchModel
+
+
+class MockExperiment:
+
+  def __init__(self, accelerator, test):
+    self.accelerator = accelerator
+    self.test = test
+
+
+class TorchBenchModelTest(unittest.TestCase):
+
+  def test_do_not_use_amp_on_cpu_and_warns(self):
+    experiment = MockExperiment("cpu", "train")
+    model = TorchBenchModel("torchbench or other", "super_deep_model",
+                            experiment)
+    with self.assertLogs('benchmarks.torchbench_model', level='WARNING') as cm:
+      use_amp = model.use_amp()
+    self.assertEqual(len(cm.output), 1)
+    self.assertIn("AMP is not used", cm.output[0])
+    self.assertFalse(use_amp)
+
+  def test_use_amp_on_cuda(self):
+    experiment = MockExperiment("cuda", "train")
+    model = TorchBenchModel("torchbench or other", "super_deep_model",
+                            experiment)
+    self.assertTrue(model.use_amp())
+
+
+if __name__ == '__main__':
+  unittest.main()
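
Note (not part of the patch series above): a minimal sketch of how the new CPU guard is expected to behave when exercised directly. It assumes the repository root is on PYTHONPATH so that benchmarks.torchbench_model is importable, as in the unit tests above; FakeExperiment is a hypothetical stand-in mirroring the MockExperiment used in those tests, carrying only the two attributes use_amp() reads.

# Minimal sketch (not part of the patches): exercising the new CPU guard.
# Assumes the repository root is on PYTHONPATH; FakeExperiment is a
# hypothetical stand-in for a benchmark experiment object.
import logging

from benchmarks.torchbench_model import TorchBenchModel


class FakeExperiment:

  def __init__(self, accelerator, test):
    self.accelerator = accelerator
    self.test = test


if __name__ == "__main__":
  logging.basicConfig(level=logging.WARNING)

  cpu_model = TorchBenchModel("torchbench", "super_deep_model",
                              FakeExperiment("cpu", "train"))
  # Expected: logs "AMP is not used due to running on CPU." and returns False.
  print("use_amp on cpu:", cpu_model.use_amp())

  cuda_model = TorchBenchModel("torchbench", "super_deep_model",
                               FakeExperiment("cuda", "train"))
  # Expected: returns True, since the experiment is a training run.
  print("use_amp on cuda:", cuda_model.use_amp())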