1 change: 1 addition & 0 deletions ding/policy/base_policy.py
@@ -451,6 +451,7 @@ def sync_gradients(self, model: torch.nn.Module) -> None:
else:
synchronize()


# subclasses do not need to implement the default_model method
def default_model(self) -> Tuple[str, List[str]]:
"""
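For reference, a hedged sketch (not part of this diff) of what an optional default_model override usually looks like in a concrete policy; the model name and import path below are illustrative of a DQN-style policy, not something this PR touches:

def default_model(self) -> Tuple[str, List[str]]:
    # Return the registered model name and the module path(s) it can be
    # imported from, so a default model can be built when none is supplied.
    return 'dqn', ['ding.model.template.q_learning']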
19 changes: 19 additions & 0 deletions ding/utils/pytorch_ddp_dist_helper.py
@@ -46,6 +46,25 @@ def allreduce(x: torch.Tensor) -> None:
dist.all_reduce(x)
x.div_(get_world_size())

def allreduce_with_indicator(grad: torch.Tensor, indicator: torch.Tensor) -> None:
"""
Overview:
Custom allreduce: Sum both the gradient and indicator tensors across all processes.
Then, if at least one process contributed (i.e., the summation of indicator > 0),
divide the gradient by the summed indicator. This ensures that if only a subset of
GPUs contributed a gradient, the averaging is performed based on the actual number
of contributors rather than the total number of GPUs.
Arguments:
- grad (torch.Tensor): Local gradient tensor to be reduced.
- indicator (torch.Tensor): A tensor flag (1 if the gradient is computed, 0 otherwise).
"""
# Allreduce (sum) the gradient and indicator
dist.all_reduce(grad)
dist.all_reduce(indicator)

    # Avoid division by zero: only divide when at least one process contributed.
    # If the summed indicator is 0 (extreme case), grad is left unchanged (i.e., all zeros).
    if indicator.item() > 0:
        grad.div_(indicator.item())
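For context, a minimal usage sketch (not part of this PR) showing how allreduce_with_indicator might be called per parameter when only some ranks produced a gradient this step. It assumes the helper above and an already-initialized process group; sync_sparse_gradients, model, and loss_computed are illustrative names:

import torch
import torch.nn as nn


def sync_sparse_gradients(model: nn.Module, loss_computed: bool) -> None:
    # Ranks that skipped the loss still join every collective call with zero
    # gradients, so no rank deadlocks waiting for the others.
    for param in model.parameters():
        if not param.requires_grad:
            continue
        if param.grad is None:
            param.grad = torch.zeros_like(param.data)
        indicator = torch.tensor(float(loss_computed), device=param.grad.device)
        # Sum grads and contributor counts, then divide by the number of
        # ranks that actually contributed (see allreduce_with_indicator above).
        allreduce_with_indicator(param.grad, indicator)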

3 changes: 2 additions & 1 deletion ding/worker/learner/base_learner.py
@@ -119,7 +119,8 @@ def __init__(
self._logger, _ = build_logger(
'./{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
)
self._tb_logger = None
self._tb_logger = tb_logger


self._log_buffer = {
'scalar': build_log_buffer(),
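For context, a hedged sketch (not part of this diff) of passing an externally created TensorBoard writer to the learner after this change; the constructor arguments shown are assumptions inferred from the surrounding __init__ context, not a confirmed signature:

from tensorboardX import SummaryWriter  # a torch.utils.tensorboard SummaryWriter would also fit

# Hypothetical experiment name and config path, for illustration only.
tb_logger = SummaryWriter('./my_exp/log/serial')
learner = BaseLearner(cfg.policy.learn.learner, tb_logger=tb_logger, exp_name='my_exp', instance_name='learner')
# With this change the learner stores the shared writer in self._tb_logger
# instead of hard-coding it to None.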