1616#ifndef HIGHWAY_HWY_BIT_SET_H_
1717#define HIGHWAY_HWY_BIT_SET_H_
1818
19- // BitSet with fast Foreach for up to 64 and 4096 members .
19+ // Various BitSet for 64, up to 4096, or any number of bits .
2020
2121#include < stddef.h>
2222
2323#include " hwy/base.h"
2424
2525namespace hwy {
2626
27- // 64-bit specialization of std::bitset, which lacks Foreach.
27+ // 64-bit specialization of ` std::bitset` , which lacks ` Foreach` .
2828class BitSet64 {
2929 public:
30+ constexpr size_t MaxSize () const { return 64 ; }
31+
3032 // No harm if `i` is already set.
3133 void Set (size_t i) {
3234 HWY_DASSERT (i < 64 );
@@ -48,15 +50,24 @@ class BitSet64 {
4850 return (bits_ & (1ULL << i)) != 0 ;
4951 }
5052
51- // Returns true if any Get(i) would return true for i in [0, 64).
53+ // Returns true if Get(i) would return true for any i in [0, 64).
5254 bool Any () const { return bits_ != 0 ; }
5355
54- // Returns lowest i such that Get(i). Caller must ensure Any() beforehand!
56+ // Returns true if Get(i) would return true for all i in [0, 64).
57+ bool All () const { return bits_ == ~uint64_t {0 }; }
58+
59+ // Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`!
5560 size_t First () const {
5661 HWY_DASSERT (Any ());
5762 return Num0BitsBelowLS1Bit_Nonzero64 (bits_);
5863 }
5964
65+ // Returns lowest i such that `!Get(i)`. Caller must first ensure `!All()`!
66+ size_t First0 () const {
67+ HWY_DASSERT (!All ());
68+ return Num0BitsBelowLS1Bit_Nonzero64 (~bits_);
69+ }
70+
6071 // Returns uint64_t(Get(i)) << i for i in [0, 64).
6172 uint64_t Get64 () const { return bits_; }
6273
@@ -78,10 +89,226 @@ class BitSet64 {
7889 uint64_t bits_ = 0 ;
7990};
8091
81- // Two-level bitset for up to kMaxSize <= 4096 values.
92+ // Any number of bits, flat array.
93+ template <size_t kMaxSize >
94+ class BitSet {
95+ static_assert (kMaxSize != 0 , " BitSet requires non-zero size" );
96+
97+ public:
98+ constexpr size_t MaxSize () const { return kMaxSize ; }
99+
100+ // No harm if `i` is already set.
101+ void Set (size_t i) {
102+ HWY_DASSERT (i < kMaxSize );
103+ const size_t idx = i / 64 ;
104+ const size_t mod = i % 64 ;
105+ bits_[idx].Set (mod);
106+ }
107+
108+ void Clear (size_t i) {
109+ HWY_DASSERT (i < kMaxSize );
110+ const size_t idx = i / 64 ;
111+ const size_t mod = i % 64 ;
112+ bits_[idx].Clear (mod);
113+ HWY_DASSERT (!Get (i));
114+ }
115+
116+ bool Get (size_t i) const {
117+ HWY_DASSERT (i < kMaxSize );
118+ const size_t idx = i / 64 ;
119+ const size_t mod = i % 64 ;
120+ return bits_[idx].Get (mod);
121+ }
122+
123+ // Returns true if Get(i) would return true for any i in [0, kMaxSize).
124+ bool Any () const {
125+ for (const BitSet64& bits : bits_) {
126+ if (bits.Any ()) return true ;
127+ }
128+ return false ;
129+ }
130+
131+ // Returns true if Get(i) would return true for all i in [0, kMaxSize).
132+ bool All () const {
133+ for (size_t idx = 0 ; idx < kNum64 - 1 ; ++idx) {
134+ if (!bits_[idx].All ()) return false ;
135+ }
136+
137+ constexpr size_t kRemainder = kMaxSize % 64 ;
138+ if (kRemainder == 0 ) {
139+ return bits_[kNum64 - 1 ].All ();
140+ }
141+ return bits_[kNum64 - 1 ].Count () == kRemainder ;
142+ }
143+
144+ // Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`!
145+ size_t First () const {
146+ HWY_DASSERT (Any ());
147+ for (size_t idx = 0 ;; ++idx) {
148+ HWY_DASSERT (idx < kNum64 );
149+ if (bits_[idx].Any ()) return idx * 64 + bits_[idx].First ();
150+ }
151+ }
152+
153+ // Returns lowest i such that `!Get(i)`. Caller must first ensure `All()`!
154+ size_t First0 () const {
155+ HWY_DASSERT (!All ());
156+ for (size_t idx = 0 ;; ++idx) {
157+ HWY_DASSERT (idx < kNum64 );
158+ if (!bits_[idx].All ()) {
159+ const size_t first0 = idx * 64 + bits_[idx].First0 ();
160+ HWY_DASSERT (first0 < kMaxSize );
161+ return first0;
162+ }
163+ }
164+ }
165+
166+ // Calls `func(i)` for each `i` in the set. It is safe for `func` to modify
167+ // the set, but the current Foreach call is only affected if changing one of
168+ // the not yet visited BitSet64.
169+ template <class Func >
170+ void Foreach (const Func& func) const {
171+ for (size_t idx = 0 ; idx < kNum64 ; ++idx) {
172+ bits_[idx].Foreach ([idx, &func](size_t mod) { func (idx * 64 + mod); });
173+ }
174+ }
175+
176+ size_t Count () const {
177+ size_t total = 0 ;
178+ for (const BitSet64& bits : bits_) {
179+ total += bits.Count ();
180+ }
181+ return total;
182+ }
183+
184+ private:
185+ static constexpr size_t kNum64 = DivCeil(kMaxSize , size_t {64 });
186+ BitSet64 bits_[kNum64 ];
187+ };
188+
189+ // Any number of bits, flat array, atomic updates to the u64.
190+ template <size_t kMaxSize >
191+ class AtomicBitSet {
192+ static_assert (kMaxSize != 0 , " AtomicBitSet requires non-zero size" );
193+
194+ // Bits may signal something to other threads, hence relaxed is insufficient.
195+ // Acq/Rel ensures a happens-before relationship.
196+ static constexpr auto kAcq = std::memory_order_acquire;
197+ static constexpr auto kRel = std::memory_order_release;
198+
199+ public:
200+ constexpr size_t MaxSize () const { return kMaxSize ; }
201+
202+ // No harm if `i` is already set.
203+ void Set (size_t i) {
204+ HWY_DASSERT (i < kMaxSize );
205+ const size_t idx = i / 64 ;
206+ const size_t mod = i % 64 ;
207+ bits_[idx].fetch_or (1ULL << mod, kRel );
208+ }
209+
210+ void Clear (size_t i) {
211+ HWY_DASSERT (i < kMaxSize );
212+ const size_t idx = i / 64 ;
213+ const size_t mod = i % 64 ;
214+ bits_[idx].fetch_and (~(1ULL << mod), kRel );
215+ HWY_DASSERT (!Get (i));
216+ }
217+
218+ bool Get (size_t i) const {
219+ HWY_DASSERT (i < kMaxSize );
220+ const size_t idx = i / 64 ;
221+ const size_t mod = i % 64 ;
222+ return ((bits_[idx].load (kAcq ) & (1ULL << mod))) != 0 ;
223+ }
224+
225+ // Returns true if Get(i) would return true for any i in [0, kMaxSize).
226+ bool Any () const {
227+ for (const std::atomic<uint64_t >& bits : bits_) {
228+ if (bits.load (kAcq )) return true ;
229+ }
230+ return false ;
231+ }
232+
233+ // Returns true if Get(i) would return true for all i in [0, kMaxSize).
234+ bool All () const {
235+ for (size_t idx = 0 ; idx < kNum64 - 1 ; ++idx) {
236+ if (bits_[idx].load (kAcq ) != ~uint64_t {0 }) return false ;
237+ }
238+
239+ constexpr size_t kRemainder = kMaxSize % 64 ;
240+ const uint64_t last_bits = bits_[kNum64 - 1 ].load (kAcq );
241+ if (kRemainder == 0 ) {
242+ return last_bits == ~uint64_t {0 };
243+ }
244+ return PopCount (last_bits) == kRemainder ;
245+ }
246+
247+ // Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`!
248+ size_t First () const {
249+ HWY_DASSERT (Any ());
250+ for (size_t idx = 0 ;; ++idx) {
251+ HWY_DASSERT (idx < kNum64 );
252+ const uint64_t bits = bits_[idx].load (kAcq );
253+ if (bits != 0 ) {
254+ return idx * 64 + Num0BitsBelowLS1Bit_Nonzero64 (bits);
255+ }
256+ }
257+ }
258+
259+ // Returns lowest i such that `!Get(i)`. Caller must first ensure `!All()`!
260+ size_t First0 () const {
261+ HWY_DASSERT (!All ());
262+ for (size_t idx = 0 ;; ++idx) {
263+ HWY_DASSERT (idx < kNum64 );
264+ const uint64_t inv_bits = ~bits_[idx].load (kAcq );
265+ if (inv_bits != 0 ) {
266+ const size_t first0 =
267+ idx * 64 + Num0BitsBelowLS1Bit_Nonzero64 (inv_bits);
268+ HWY_DASSERT (first0 < kMaxSize );
269+ return first0;
270+ }
271+ }
272+ }
273+
274+ // Calls `func(i)` for each `i` in the set. It is safe for `func` to modify
275+ // the set, but the current Foreach call is only affected if changing one of
276+ // the not yet visited uint64_t.
277+ template <class Func >
278+ void Foreach (const Func& func) const {
279+ for (size_t idx = 0 ; idx < kNum64 ; ++idx) {
280+ uint64_t remaining_bits = bits_[idx].load (kAcq );
281+ while (remaining_bits != 0 ) {
282+ const size_t i = Num0BitsBelowLS1Bit_Nonzero64 (remaining_bits);
283+ remaining_bits &= remaining_bits - 1 ; // clear LSB
284+ func (idx * 64 + i);
285+ }
286+ }
287+ }
288+
289+ size_t Count () const {
290+ size_t total = 0 ;
291+ for (const std::atomic<uint64_t >& bits : bits_) {
292+ total += PopCount (bits.load (kAcq ));
293+ }
294+ return total;
295+ }
296+
297+ private:
298+ static constexpr size_t kNum64 = DivCeil(kMaxSize , size_t {64 });
299+ std::atomic<uint64_t > bits_[kNum64 ] = {};
300+ };
301+
302+ // Two-level bitset for up to `kMaxSize` <= 4096 values. The iterators
303+ // (`Any/First/Foreach/Count`) are more efficient than `BitSet` for sparse sets.
304+ // This comes at the cost of slightly slower mutators (`Set/Clear`).
82305template <size_t kMaxSize = 4096 >
83306class BitSet4096 {
307+ static_assert (kMaxSize != 0 , " BitSet4096 requires non-zero size" );
308+
84309 public:
310+ constexpr size_t MaxSize () const { return kMaxSize ; }
311+
85312 // No harm if `i` is already set.
86313 void Set (size_t i) {
87314 HWY_DASSERT (i < kMaxSize );
@@ -117,16 +344,38 @@ class BitSet4096 {
117344 return bits_[idx].Get (mod);
118345 }
119346
120- // Returns true if any Get(i) would return true for i in [0, 64 ).
347+ // Returns true if ` Get(i)` would return true for any i in [0, kMaxSize ).
121348 bool Any () const { return nonzero_.Any (); }
122349
123- // Returns lowest i such that Get(i). Caller must ensure Any() beforehand!
350+ // Returns true if `Get(i)` would return true for all i in [0, kMaxSize).
351+ bool All () const {
352+ // Do not check `nonzero_.All()` - that only works if `kMaxSize` is 4096.
353+ if (nonzero_.Count () != kNum64 ) return false ;
354+ return Count () == kMaxSize ;
355+ }
356+
357+ // Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`!
124358 size_t First () const {
125359 HWY_DASSERT (Any ());
126360 const size_t idx = nonzero_.First ();
127361 return idx * 64 + bits_[idx].First ();
128362 }
129363
364+ // Returns lowest i such that `!Get(i)`. Caller must first ensure `!All()`!
365+ size_t First0 () const {
366+ HWY_DASSERT (!All ());
367+ // It is likely not worthwhile to have a separate `BitSet64` for `not_all_`,
368+ // hence iterate over all u64.
369+ for (size_t idx = 0 ;; ++idx) {
370+ HWY_DASSERT (idx < kNum64 );
371+ if (!bits_[idx].All ()) {
372+ const size_t first0 = idx * 64 + bits_[idx].First0 ();
373+ HWY_DASSERT (first0 < kMaxSize );
374+ return first0;
375+ }
376+ }
377+ }
378+
130379 // Returns uint64_t(Get(i)) << i for i in [0, 64).
131380 uint64_t Get64 () const { return bits_[0 ].Get64 (); }
132381
@@ -149,8 +398,9 @@ class BitSet4096 {
149398
150399 private:
151400 static_assert (kMaxSize <= 64 * 64 , " One BitSet64 insufficient" );
401+ static constexpr size_t kNum64 = DivCeil(kMaxSize , size_t {64 });
152402 BitSet64 nonzero_;
153- BitSet64 bits_[kMaxSize / 64 ];
403+ BitSet64 bits_[kNum64 ];
154404};
155405
156406} // namespace hwy
0 commit comments