diff --git a/acme/agents/jax/ppo/learning.py b/acme/agents/jax/ppo/learning.py index 101f249b9b..914c597b0b 100644 --- a/acme/agents/jax/ppo/learning.py +++ b/acme/agents/jax/ppo/learning.py @@ -14,8 +14,9 @@ """Learner for the PPO agent.""" -from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple +from typing import Dict, Iterator, List, NamedTuple, Optional, Sequence, Tuple +from absl import logging import acme from acme import types from acme.agents.jax.ppo import networks @@ -103,10 +104,18 @@ def __init__( metrics_logging_period: int = 100, pmap_axis_name: str = 'devices', obs_normalization_fns: Optional[normalization.NormalizationFns] = None, + devices: Optional[Sequence[jax.Device]] = None, ): - self.local_learner_devices = jax.local_devices() - self.num_local_learner_devices = jax.local_device_count() - self.learner_devices = jax.devices() + local_devices = jax.local_devices() + process_id = jax.process_index() + logging.info('Learner process id: %s. Devices passed: %s', process_id, + devices) + logging.info('Learner process id: %s. Local devices from JAX API: %s', + process_id, local_devices) + self.learner_devices = devices or jax.devices() + self.local_learner_devices = [d for d in self.learner_devices if d in local_devices] + self.num_local_learner_devices = len(self.local_learner_devices) + self.num_epochs = num_epochs self.num_minibatches = num_minibatches self.metrics_logging_period = metrics_logging_period