diff --git a/benchmarks/torchbench_model.py b/benchmarks/torchbench_model.py
index b9dea15adf6..55b7f555276 100644
--- a/benchmarks/torchbench_model.py
+++ b/benchmarks/torchbench_model.py
@@ -373,6 +373,10 @@ def is_accelerator_tpu(self):
     return self.benchmark_experiment.accelerator == "tpu"
 
   def use_amp(self):
+    # AMP is only supported on cuda and tpu, not on cpu.
+    if self.benchmark_experiment.accelerator == "cpu":
+      logger.warning("AMP is not used due to running on CPU.")
+      return False
     return self.is_training() or self.model_name in config(
     ).dtype.force_amp_for_fp16_bf16_models
 
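A minimal sketch (not part of this diff) of how a use_amp() decision is typically consumed by a benchmark runner, which is why CPU runs need to opt out: the flag gates a torch.autocast context, and autocast mixed precision only pays off on CUDA and XLA/TPU backends. The maybe_autocast helper and the dtype choice below are illustrative assumptions, not code from this repository.

    import contextlib
    import torch

    def maybe_autocast(model):
      # Hypothetical helper: return a no-op context when AMP is disabled
      # (e.g. CPU runs after this change), otherwise an autocast context.
      if not model.use_amp():
        return contextlib.nullcontext()
      # "xla" autocast requires torch_xla; "cuda" uses standard PyTorch AMP.
      device_type = "cuda" if model.benchmark_experiment.accelerator == "cuda" else "xla"
      return torch.autocast(device_type, dtype=torch.bfloat16)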