diff --git a/src/pythae/trainers/base_trainer/base_trainer.py b/src/pythae/trainers/base_trainer/base_trainer.py index 073deede..606d5019 100644 --- a/src/pythae/trainers/base_trainer/base_trainer.py +++ b/src/pythae/trainers/base_trainer/base_trainer.py @@ -87,11 +87,12 @@ def __init__( device = self._setup_devices() else: - device = ( - "cuda" - if torch.cuda.is_available() and not self.training_config.no_cuda - else "cpu" - ) + if torch.cuda.is_available() and not self.training_config.no_cuda: + device = "cuda" + elif torch.backends.mps.is_available() and not self.training_config.no_cuda: + device = "mps" + else: + device = "cpu" self.amp_context = ( torch.autocast("cuda") @@ -174,8 +175,9 @@ def _setup_devices(self): device = "cpu" else: - torch.cuda.set_device(self.local_rank) - device = torch.device("cuda", self.local_rank) + if not device == "mps": + torch.cuda.set_device(self.local_rank) + device = torch.device("cuda", self.local_rank) if not dist.is_initialized(): dist.init_process_group(