
Commit fee6195

Implement CenteredClip in averager
1 parent b84f62b commit fee6195

5 files changed: 137 additions, 11 deletions


hivemind/averaging/accumulators.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
import dataclasses
from abc import ABC
from typing import Callable, Optional

import torch


class AccumulatorBase(ABC):
    def accumulate_part(self, tensor: torch.Tensor, weight: float) -> None:
        ...

    def reduce(self) -> torch.Tensor:
        ...


AccumulatorFactory = Callable[[torch.Size, int], AccumulatorBase]


class MeanAccumulator(AccumulatorBase):
    def __init__(self, part_shape: torch.Size, _n_peers: int):
        self._accumulator = torch.zeros(part_shape)
        self._denominator = 0.0

    def accumulate_part(self, tensor_part: torch.Tensor, weight: float) -> None:
        self._accumulator.add_(tensor_part, alpha=weight)
        self._denominator += weight

    def reduce(self) -> torch.Tensor:
        return self._accumulator.div_(self._denominator)


class CenteredClipAccumulator(AccumulatorBase):
    def __init__(self, part_shape: torch.Size, n_peers: int, **kwargs):
        self._kwargs = kwargs

        self._tensors = torch.empty((n_peers, *part_shape))
        self._weights = torch.empty(n_peers)
        self._index = 0

    def accumulate_part(self, tensor_part: torch.Tensor, weight: float) -> None:
        self._tensors[self._index] = tensor_part
        self._weights[self._index] = weight
        self._index += 1

    def reduce(self) -> torch.Tensor:
        clipped = centered_clip(self._tensors, self._weights, **self._kwargs)
        return clipped.result


@dataclasses.dataclass(frozen=True)
class CenteredClipResult:
    result: torch.Tensor
    n_clipped: torch.Tensor
    last_step_delta: torch.Tensor


def centered_clip(
    input_tensors: torch.Tensor,
    weights: torch.Tensor,
    tau: float = 1.0,
    n_iters: int = 20,
    stop_delta: Optional[float] = None,
) -> CenteredClipResult:
    """
    Optimized implementation of CenteredClip from [Karimireddy, 2021].
    Intended to be used in a decentralized fashion as in [Gorbunov, 2021].

    :param stop_delta: Stop iterations early if the ``L_inf`` norm of the last step is less than ``stop_delta``.
        Note: if this option is used, the step norm calculations may increase the time per iteration by ~25%.

    References:

    [Karimireddy, 2021] Karimireddy, Sai Praneeth, Lie He, and Martin Jaggi. "Learning from history for byzantine
        robust optimization." International Conference on Machine Learning. PMLR, 2021.

    [Gorbunov, 2021] Gorbunov, Eduard, Alexander Borzunov, Michael Diskin, and Max Ryabinin.
        "Secure Distributed Training at Scale." arXiv preprint arXiv:2106.11257 (2021).
    """

    with torch.no_grad():
        n_peers = input_tensors.shape[0]
        result_shape = input_tensors.shape[1:]

        input_tensors = input_tensors.flatten(start_dim=1)
        weights /= weights.sum()

        # This finds medians faster than torch.median() and torch.quantile(q=0.5),
        # see https://github.com/pytorch/pytorch/issues/51450
        sorted_tensors = input_tensors.sort(dim=0).values
        result = sorted_tensors[n_peers // 2].clone()
        delta = None

        diff = torch.sub(input_tensors, result, out=sorted_tensors)  # Reuse memory from `sorted_tensors`
        for _ in range(n_iters):
            norms = diff.norm(dim=1)
            coeffs = weights * torch.minimum(torch.tensor(1.0), tau / norms)

            if stop_delta is not None:
                prev_diff = result.copy_(diff[0])  # Reuse memory from `result`

            # We only need to update `diff` (not `result`) between iterations
            diff.addmm_(-coeffs.repeat(n_peers, 1), diff)

            if stop_delta is not None:
                delta = prev_diff.sub_(diff[0]).abs_().max()
                if delta < stop_delta:
                    break
        torch.sub(input_tensors[0], diff[0], out=result)

        return CenteredClipResult(
            result=result.reshape(result_shape), n_clipped=(tau < norms).sum(), last_step_delta=delta
        )
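
For orientation, here is a minimal standalone sketch of what the new centered_clip routine computes, assuming the module is importable as hivemind.averaging.accumulators; the peer count, the outlier, and tau below are illustrative values, not defaults taken from the commit:

    import torch

    from hivemind.averaging.accumulators import centered_clip

    torch.manual_seed(0)

    # Nine "honest" peers hold similar vectors; one outlier tries to skew the aggregate.
    honest = torch.randn(9, 16)
    outlier = torch.full((1, 16), 100.0)
    inputs = torch.cat([honest, outlier])
    weights = torch.ones(10)
    plain_mean = inputs.mean(dim=0)

    clipped = centered_clip(inputs, weights, tau=1.0, n_iters=20)

    print("honest mean:  ", honest.mean(dim=0)[:4])
    print("centered clip:", clipped.result[:4])
    print("plain mean:   ", plain_mean[:4])
    print("peers clipped:", int(clipped.n_clipped))

Because every step toward a peer is clipped to norm tau, the outlier's contribution per iteration is bounded by its weight times tau, so clipped.result stays close to the honest mean while the plain mean is dragged toward the outlier.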

hivemind/averaging/allreduce.py

Lines changed: 4 additions & 1 deletion
@@ -4,6 +4,7 @@
 
 import torch
 
+from hivemind.averaging.accumulators import AccumulatorFactory
 from hivemind.averaging.partition import AllreduceException, TensorPartContainer, TensorPartReducer
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
@@ -58,6 +59,7 @@ def __init__(
         tensors: Sequence[torch.Tensor],
         ordered_peer_ids: Sequence[PeerID],
         peer_fractions: Tuple[float, ...],
+        accumulator_factory: AccumulatorFactory,
         weights: Optional[Sequence[float]] = None,
         modes: Optional[Sequence[AveragingMode]] = None,
         gathered: Optional[Dict[PeerID, Any]] = None,
@@ -97,7 +99,8 @@ def __init__(
         self.tensor_part_reducer = TensorPartReducer(
             tuple(part.shape for part in self.parts_for_local_averaging),
             len(self.sender_peer_ids),
-            self.sender_weights,
+            weights=self.sender_weights,
+            accumulator_factory=accumulator_factory,
         )
 
     def __repr__(self):
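
AllReduceRunner itself only forwards accumulator_factory down to TensorPartReducer, so any callable matching the AccumulatorFactory signature (part_shape, n_peers) can be passed in. A sketch of that contract, with CenteredClip hyperparameters bound via functools.partial (the shapes, weights, and tau=10.0 are arbitrary illustrations):

    import functools

    import torch

    from hivemind.averaging.accumulators import CenteredClipAccumulator, MeanAccumulator

    # Extra keyword arguments are bound ahead of time so the factory still takes (part_shape, n_peers).
    robust_factory = functools.partial(CenteredClipAccumulator, tau=10.0)

    part_shape, n_peers = torch.Size([4]), 3
    accumulator = robust_factory(part_shape, n_peers)

    # The reducer feeds one part per sender, then calls reduce() once everyone has contributed.
    for _ in range(n_peers):
        accumulator.accumulate_part(torch.randn(part_shape), weight=1.0)
    print(accumulator.reduce())

    # MeanAccumulator follows the same protocol and reproduces the previous weighted-average behaviour.
    mean_acc = MeanAccumulator(part_shape, n_peers)
    for sender in range(n_peers):
        mean_acc.accumulate_part(torch.full(part_shape, float(sender)), weight=1.0)
    print(mean_acc.reduce())  # (0 + 1 + 2) / 3 = 1.0 in every coordinate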

hivemind/averaging/averager.py

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,7 @@
 import numpy as np
 import torch
 
+from hivemind.averaging.accumulators import AccumulatorFactory, MeanAccumulator
 from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
@@ -112,6 +113,7 @@ def __init__(
         compression: CompressionBase = NoCompression(),
         state_compression: CompressionBase = NoCompression(),
         tensor_infos: Optional[Sequence[CompressionInfo]] = None,
+        accumulator_factory: AccumulatorFactory = MeanAccumulator,
         bandwidth: Optional[float] = None,
         min_vector_size: int = 0,
         auxiliary: bool = False,
@@ -170,6 +172,7 @@ def __init__(
             compression=compression,
             part_size_bytes=part_size_bytes,
             min_vector_size=min_vector_size,
+            accumulator_factory=accumulator_factory,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
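
At the averager level the default accumulator_factory=MeanAccumulator keeps the previous weighted-mean behaviour, and passing another factory switches the reduction rule for every all-reduce round. A rough usage sketch, assuming the usual DecentralizedAverager constructor arguments (averaged_tensors, dht, prefix, target_group_size); only accumulator_factory is introduced by this commit, and tau=5.0 is an arbitrary example:

    import functools

    import torch

    import hivemind
    from hivemind.averaging.accumulators import CenteredClipAccumulator
    from hivemind.averaging.averager import DecentralizedAverager

    dht = hivemind.DHT(start=True)

    averager = DecentralizedAverager(
        averaged_tensors=[torch.zeros(16)],
        dht=dht,
        prefix="my_experiment",
        target_group_size=4,
        accumulator_factory=functools.partial(CenteredClipAccumulator, tau=5.0),
        start=True,
    )

    # averager.step() would now reduce each tensor part with CenteredClip instead of a plain weighted mean.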

hivemind/averaging/partition.py

Lines changed: 15 additions & 9 deletions
@@ -8,6 +8,7 @@
 import numpy as np
 import torch
 
+from hivemind.averaging.accumulators import AccumulatorFactory
 from hivemind.compression import CompressionBase, CompressionInfo, NoCompression
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import amap_in_executor
@@ -171,16 +172,23 @@ class TensorPartReducer:
     :note: even if local peer is not sending data, local parts will be used for shape information
     """
 
-    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int, weights: Optional[Sequence[float]] = None):
+    def __init__(
+        self,
+        part_shapes: Sequence[torch.Size],
+        num_senders: int,
+        *,
+        weights: Optional[Sequence[float]],
+        accumulator_factory: AccumulatorFactory,
+    ):
         self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
         self.weights = tuple(weights or (1 for _ in range(num_senders)))
         assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
         assert all(isinstance(weight, (int, float)) for weight in self.weights)
         self.current_part_index = -1  # index in local_parts of the part that should be loaded next
         self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
-        self.accumulator = None  # this will contain the sum of current tensor part from group peers
-        self.denominator = 0.0  # total weight accumulated from all peers for current part
         self.current_part_future = asyncio.Future()
+        self.accumulator_factory = accumulator_factory
+        self.accumulator = None
         self.finished = asyncio.Event()
         self.reset_accumulators()
 
@@ -194,8 +202,7 @@ def reset_accumulators(self):
         self.current_part_index += 1
         self.current_part_accumulated_from = 0
         self.current_part_future = asyncio.Future()
-        self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
-        self.denominator = 0.0
+        self.accumulator = self.accumulator_factory(self.part_shapes[self.current_part_index], self.num_senders)
 
     async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
         """Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
@@ -211,21 +218,20 @@ async def accumulate_part(self, sender_index: int, part_index: int, tensor_part:
 
         current_part_future = self.current_part_future
 
-        self.accumulator.add_(tensor_part, alpha=self.weights[sender_index])
-        self.denominator += self.weights[sender_index]
+        self.accumulator.accumulate_part(tensor_part, self.weights[sender_index])
         self.current_part_accumulated_from += 1
 
         assert self.current_part_accumulated_from <= self.num_senders
         if self.current_part_accumulated_from == self.num_senders:
-            current_part_future.set_result(self.accumulator.div_(self.denominator))
+            current_part_future.set_result(self.accumulator.reduce())
             self.reset_accumulators()
         return await current_part_future
 
     def finalize(self):
         if not self.finished.is_set():
             if hasattr(self, "current_part_future"):
                 self.current_part_future.cancel()
-            del self.accumulator
+            self.accumulator = None
             self.finished.set()
 
     def __del__(self):
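
Because TensorPartReducer now builds one accumulator per tensor part through accumulator_factory and only ever calls accumulate_part() and reduce() on it, other robust aggregation rules can be plugged in by implementing the same two-method interface. A hypothetical example (MedianAccumulator is not part of this commit):

    import torch

    from hivemind.averaging.accumulators import AccumulatorBase


    class MedianAccumulator(AccumulatorBase):
        """Coordinate-wise median of the received parts, ignoring the per-peer weights."""

        def __init__(self, part_shape: torch.Size, n_peers: int):
            self._parts = torch.empty((n_peers, *part_shape))
            self._index = 0

        def accumulate_part(self, tensor_part: torch.Tensor, weight: float) -> None:
            self._parts[self._index] = tensor_part
            self._index += 1

        def reduce(self) -> torch.Tensor:
            return self._parts[: self._index].median(dim=0).values

reset_accumulators() would instantiate it as accumulator_factory(part_shape, num_senders) for every part, and reduce() runs once all senders have contributed.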

tests/test_allreduce.py

Lines changed: 3 additions & 1 deletion
@@ -7,6 +7,7 @@
 import torch
 
 from hivemind import Quantile8BitQuantization, aenumerate
+from hivemind.averaging.accumulators import MeanAccumulator
 from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer
 from hivemind.compression import deserialize_torch_tensor
@@ -119,7 +120,7 @@ async def wait_synchronously():
 @pytest.mark.asyncio
 async def test_reducer(num_senders: int, num_parts: int, synchronize_prob: float):
     tensor_part_shapes = [torch.Size([i]) for i in range(num_parts)]
-    reducer = TensorPartReducer(tensor_part_shapes, num_senders)
+    reducer = TensorPartReducer(tensor_part_shapes, num_senders, weights=None, accumulator_factory=MeanAccumulator)
 
     local_tensors_by_sender = [[torch.randn(i) for i in range(num_parts)] for j in range(num_senders)]
 
@@ -196,6 +197,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
         tensors=[x.clone() for x in tensors_by_peer[p2p.peer_id]],
         ordered_peer_ids=peers,
         peer_fractions=peer_fractions,
+        accumulator_factory=MeanAccumulator,
         modes=peer_modes,
         weights=averaging_weights,
         part_size_bytes=part_size_bytes,
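
The updated test exercises TensorPartReducer through its new keyword-only arguments. An untested standalone sketch of the same call pattern outside pytest (the reducer's internal synchronization is unchanged by this commit; shapes and values are arbitrary):

    import asyncio

    import torch

    from hivemind.averaging.accumulators import MeanAccumulator
    from hivemind.averaging.partition import TensorPartReducer


    async def main():
        # Two senders reduce a single 3-element part; weights=None means equal weights.
        reducer = TensorPartReducer((torch.Size([3]),), num_senders=2, weights=None, accumulator_factory=MeanAccumulator)
        parts = [torch.tensor([0.0, 0.0, 0.0]), torch.tensor([2.0, 4.0, 6.0])]
        averaged = await asyncio.gather(
            reducer.accumulate_part(sender_index=0, part_index=0, tensor_part=parts[0]),
            reducer.accumulate_part(sender_index=1, part_index=0, tensor_part=parts[1]),
        )
        print(averaged[0])  # expected: tensor([1., 2., 3.]) for both senders
        reducer.finalize()


    asyncio.run(main())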
