Skip to content

Commit 0bb1960

Browse files
jan-wassenbergcopybara-github
authored andcommitted
Expand BitSet: add AtomicBitSet, All/First0/MaxSize members
Also expand test to cover all variants. PiperOrigin-RevId: 813195963
1 parent 7e6a088 commit 0bb1960

File tree

2 files changed

+340
-39
lines changed

2 files changed

+340
-39
lines changed

hwy/bit_set.h

Lines changed: 258 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,19 @@
1616
#ifndef HIGHWAY_HWY_BIT_SET_H_
1717
#define HIGHWAY_HWY_BIT_SET_H_
1818

19-
// BitSet with fast Foreach for up to 64 and 4096 members.
19+
// Various BitSet for 64, up to 4096, or any number of bits.
2020

2121
#include <stddef.h>
2222

2323
#include "hwy/base.h"
2424

2525
namespace hwy {
2626

27-
// 64-bit specialization of std::bitset, which lacks Foreach.
27+
// 64-bit specialization of `std::bitset`, which lacks `Foreach`.
2828
class BitSet64 {
2929
public:
30+
constexpr size_t MaxSize() const { return 64; }
31+
3032
// No harm if `i` is already set.
3133
void Set(size_t i) {
3234
HWY_DASSERT(i < 64);
@@ -48,15 +50,24 @@ class BitSet64 {
4850
return (bits_ & (1ULL << i)) != 0;
4951
}
5052

51-
// Returns true if any Get(i) would return true for i in [0, 64).
53+
// Returns true if Get(i) would return true for any i in [0, 64).
5254
bool Any() const { return bits_ != 0; }
5355

54-
// Returns lowest i such that Get(i). Caller must ensure Any() beforehand!
56+
// Returns true if Get(i) would return true for all i in [0, 64).
57+
bool All() const { return bits_ == ~uint64_t{0}; }
58+
59+
// Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`!
5560
size_t First() const {
5661
HWY_DASSERT(Any());
5762
return Num0BitsBelowLS1Bit_Nonzero64(bits_);
5863
}
5964

65+
// Returns lowest i such that `!Get(i)`. Caller must first ensure `!All()`!
66+
size_t First0() const {
67+
HWY_DASSERT(!All());
68+
return Num0BitsBelowLS1Bit_Nonzero64(~bits_);
69+
}
70+
6071
// Returns uint64_t(Get(i)) << i for i in [0, 64).
6172
uint64_t Get64() const { return bits_; }
6273

@@ -78,10 +89,226 @@ class BitSet64 {
7889
uint64_t bits_ = 0;
7990
};
8091

81-
// Two-level bitset for up to kMaxSize <= 4096 values.
92+
// Any number of bits, flat array.
93+
template <size_t kMaxSize>
94+
class BitSet {
95+
static_assert(kMaxSize != 0, "BitSet requires non-zero size");
96+
97+
public:
98+
constexpr size_t MaxSize() const { return kMaxSize; }
99+
100+
// No harm if `i` is already set.
101+
void Set(size_t i) {
102+
HWY_DASSERT(i < kMaxSize);
103+
const size_t idx = i / 64;
104+
const size_t mod = i % 64;
105+
bits_[idx].Set(mod);
106+
}
107+
108+
void Clear(size_t i) {
109+
HWY_DASSERT(i < kMaxSize);
110+
const size_t idx = i / 64;
111+
const size_t mod = i % 64;
112+
bits_[idx].Clear(mod);
113+
HWY_DASSERT(!Get(i));
114+
}
115+
116+
bool Get(size_t i) const {
117+
HWY_DASSERT(i < kMaxSize);
118+
const size_t idx = i / 64;
119+
const size_t mod = i % 64;
120+
return bits_[idx].Get(mod);
121+
}
122+
123+
// Returns true if Get(i) would return true for any i in [0, kMaxSize).
124+
bool Any() const {
125+
for (const BitSet64& bits : bits_) {
126+
if (bits.Any()) return true;
127+
}
128+
return false;
129+
}
130+
131+
// Returns true if Get(i) would return true for all i in [0, kMaxSize).
132+
bool All() const {
133+
for (size_t idx = 0; idx < kNum64 - 1; ++idx) {
134+
if (!bits_[idx].All()) return false;
135+
}
136+
137+
constexpr size_t kRemainder = kMaxSize % 64;
138+
if (kRemainder == 0) {
139+
return bits_[kNum64 - 1].All();
140+
}
141+
return bits_[kNum64 - 1].Count() == kRemainder;
142+
}
143+
144+
// Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`!
145+
size_t First() const {
146+
HWY_DASSERT(Any());
147+
for (size_t idx = 0;; ++idx) {
148+
HWY_DASSERT(idx < kNum64);
149+
if (bits_[idx].Any()) return idx * 64 + bits_[idx].First();
150+
}
151+
}
152+
153+
// Returns lowest i such that `!Get(i)`. Caller must first ensure `All()`!
154+
size_t First0() const {
155+
HWY_DASSERT(!All());
156+
for (size_t idx = 0;; ++idx) {
157+
HWY_DASSERT(idx < kNum64);
158+
if (!bits_[idx].All()) {
159+
const size_t first0 = idx * 64 + bits_[idx].First0();
160+
HWY_DASSERT(first0 < kMaxSize);
161+
return first0;
162+
}
163+
}
164+
}
165+
166+
// Calls `func(i)` for each `i` in the set. It is safe for `func` to modify
167+
// the set, but the current Foreach call is only affected if changing one of
168+
// the not yet visited BitSet64.
169+
template <class Func>
170+
void Foreach(const Func& func) const {
171+
for (size_t idx = 0; idx < kNum64; ++idx) {
172+
bits_[idx].Foreach([idx, &func](size_t mod) { func(idx * 64 + mod); });
173+
}
174+
}
175+
176+
size_t Count() const {
177+
size_t total = 0;
178+
for (const BitSet64& bits : bits_) {
179+
total += bits.Count();
180+
}
181+
return total;
182+
}
183+
184+
private:
185+
static constexpr size_t kNum64 = DivCeil(kMaxSize, size_t{64});
186+
BitSet64 bits_[kNum64];
187+
};
188+
189+
// Any number of bits, flat array, atomic updates to the u64.
190+
template <size_t kMaxSize>
191+
class AtomicBitSet {
192+
static_assert(kMaxSize != 0, "AtomicBitSet requires non-zero size");
193+
194+
// Bits may signal something to other threads, hence relaxed is insufficient.
195+
// Acq/Rel ensures a happens-before relationship.
196+
static constexpr auto kAcq = std::memory_order_acquire;
197+
static constexpr auto kRel = std::memory_order_release;
198+
199+
public:
200+
constexpr size_t MaxSize() const { return kMaxSize; }
201+
202+
// No harm if `i` is already set.
203+
void Set(size_t i) {
204+
HWY_DASSERT(i < kMaxSize);
205+
const size_t idx = i / 64;
206+
const size_t mod = i % 64;
207+
bits_[idx].fetch_or(1ULL << mod, kRel);
208+
}
209+
210+
void Clear(size_t i) {
211+
HWY_DASSERT(i < kMaxSize);
212+
const size_t idx = i / 64;
213+
const size_t mod = i % 64;
214+
bits_[idx].fetch_and(~(1ULL << mod), kRel);
215+
HWY_DASSERT(!Get(i));
216+
}
217+
218+
bool Get(size_t i) const {
219+
HWY_DASSERT(i < kMaxSize);
220+
const size_t idx = i / 64;
221+
const size_t mod = i % 64;
222+
return ((bits_[idx].load(kAcq) & (1ULL << mod))) != 0;
223+
}
224+
225+
// Returns true if Get(i) would return true for any i in [0, kMaxSize).
226+
bool Any() const {
227+
for (const std::atomic<uint64_t>& bits : bits_) {
228+
if (bits.load(kAcq)) return true;
229+
}
230+
return false;
231+
}
232+
233+
// Returns true if Get(i) would return true for all i in [0, kMaxSize).
234+
bool All() const {
235+
for (size_t idx = 0; idx < kNum64 - 1; ++idx) {
236+
if (bits_[idx].load(kAcq) != ~uint64_t{0}) return false;
237+
}
238+
239+
constexpr size_t kRemainder = kMaxSize % 64;
240+
const uint64_t last_bits = bits_[kNum64 - 1].load(kAcq);
241+
if (kRemainder == 0) {
242+
return last_bits == ~uint64_t{0};
243+
}
244+
return PopCount(last_bits) == kRemainder;
245+
}
246+
247+
// Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`!
248+
size_t First() const {
249+
HWY_DASSERT(Any());
250+
for (size_t idx = 0;; ++idx) {
251+
HWY_DASSERT(idx < kNum64);
252+
const uint64_t bits = bits_[idx].load(kAcq);
253+
if (bits != 0) {
254+
return idx * 64 + Num0BitsBelowLS1Bit_Nonzero64(bits);
255+
}
256+
}
257+
}
258+
259+
// Returns lowest i such that `!Get(i)`. Caller must first ensure `!All()`!
260+
size_t First0() const {
261+
HWY_DASSERT(!All());
262+
for (size_t idx = 0;; ++idx) {
263+
HWY_DASSERT(idx < kNum64);
264+
const uint64_t inv_bits = ~bits_[idx].load(kAcq);
265+
if (inv_bits != 0) {
266+
const size_t first0 =
267+
idx * 64 + Num0BitsBelowLS1Bit_Nonzero64(inv_bits);
268+
HWY_DASSERT(first0 < kMaxSize);
269+
return first0;
270+
}
271+
}
272+
}
273+
274+
// Calls `func(i)` for each `i` in the set. It is safe for `func` to modify
275+
// the set, but the current Foreach call is only affected if changing one of
276+
// the not yet visited uint64_t.
277+
template <class Func>
278+
void Foreach(const Func& func) const {
279+
for (size_t idx = 0; idx < kNum64; ++idx) {
280+
uint64_t remaining_bits = bits_[idx].load(kAcq);
281+
while (remaining_bits != 0) {
282+
const size_t i = Num0BitsBelowLS1Bit_Nonzero64(remaining_bits);
283+
remaining_bits &= remaining_bits - 1; // clear LSB
284+
func(idx * 64 + i);
285+
}
286+
}
287+
}
288+
289+
size_t Count() const {
290+
size_t total = 0;
291+
for (const std::atomic<uint64_t>& bits : bits_) {
292+
total += PopCount(bits.load(kAcq));
293+
}
294+
return total;
295+
}
296+
297+
private:
298+
static constexpr size_t kNum64 = DivCeil(kMaxSize, size_t{64});
299+
std::atomic<uint64_t> bits_[kNum64] = {};
300+
};
301+
302+
// Two-level bitset for up to `kMaxSize` <= 4096 values. The iterators
303+
// (`Any/First/Foreach/Count`) are more efficient than `BitSet` for sparse sets.
304+
// This comes at the cost of slightly slower mutators (`Set/Clear`).
82305
template <size_t kMaxSize = 4096>
83306
class BitSet4096 {
307+
static_assert(kMaxSize != 0, "BitSet4096 requires non-zero size");
308+
84309
public:
310+
constexpr size_t MaxSize() const { return kMaxSize; }
311+
85312
// No harm if `i` is already set.
86313
void Set(size_t i) {
87314
HWY_DASSERT(i < kMaxSize);
@@ -117,16 +344,38 @@ class BitSet4096 {
117344
return bits_[idx].Get(mod);
118345
}
119346

120-
// Returns true if any Get(i) would return true for i in [0, 64).
347+
// Returns true if `Get(i)` would return true for any i in [0, kMaxSize).
121348
bool Any() const { return nonzero_.Any(); }
122349

123-
// Returns lowest i such that Get(i). Caller must ensure Any() beforehand!
350+
// Returns true if `Get(i)` would return true for all i in [0, kMaxSize).
351+
bool All() const {
352+
// Do not check `nonzero_.All()` - that only works if `kMaxSize` is 4096.
353+
if (nonzero_.Count() != kNum64) return false;
354+
return Count() == kMaxSize;
355+
}
356+
357+
// Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`!
124358
size_t First() const {
125359
HWY_DASSERT(Any());
126360
const size_t idx = nonzero_.First();
127361
return idx * 64 + bits_[idx].First();
128362
}
129363

364+
// Returns lowest i such that `!Get(i)`. Caller must first ensure `!All()`!
365+
size_t First0() const {
366+
HWY_DASSERT(!All());
367+
// It is likely not worthwhile to have a separate `BitSet64` for `not_all_`,
368+
// hence iterate over all u64.
369+
for (size_t idx = 0;; ++idx) {
370+
HWY_DASSERT(idx < kNum64);
371+
if (!bits_[idx].All()) {
372+
const size_t first0 = idx * 64 + bits_[idx].First0();
373+
HWY_DASSERT(first0 < kMaxSize);
374+
return first0;
375+
}
376+
}
377+
}
378+
130379
// Returns uint64_t(Get(i)) << i for i in [0, 64).
131380
uint64_t Get64() const { return bits_[0].Get64(); }
132381

@@ -149,8 +398,9 @@ class BitSet4096 {
149398

150399
private:
151400
static_assert(kMaxSize <= 64 * 64, "One BitSet64 insufficient");
401+
static constexpr size_t kNum64 = DivCeil(kMaxSize, size_t{64});
152402
BitSet64 nonzero_;
153-
BitSet64 bits_[kMaxSize / 64];
403+
BitSet64 bits_[kNum64];
154404
};
155405

156406
} // namespace hwy

0 commit comments

Comments
 (0)