Skip to content

Commit f1e4d23

Browse files
author
Joe Jevnik
committed
Add dichotomize function for splitting a sequence by a predicate
1 parent c3a6294 commit f1e4d23

File tree

3 files changed

+76
-2
lines changed

3 files changed

+76
-2
lines changed

doc/source/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Itertoolz
1515
concatv
1616
cons
1717
count
18+
dichotomize
1819
diff
1920
drop
2021
first

toolz/itertoolz.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
'first', 'second', 'nth', 'last', 'get', 'concat', 'concatv',
1515
'mapcat', 'cons', 'interpose', 'frequencies', 'reduceby', 'iterate',
1616
'sliding_window', 'partition', 'partition_all', 'count', 'pluck',
17-
'join', 'tail', 'diff', 'topk', 'peek', 'random_sample')
17+
'join', 'tail', 'diff', 'topk', 'peek', 'random_sample',
18+
'dichotomize')
1819

1920

2021
def remove(predicate, seq):
@@ -980,3 +981,48 @@ def random_sample(prob, seq, random_state=None):
980981
if not hasattr(random_state, 'random'):
981982
random_state = Random(random_state)
982983
return filter(lambda _: random_state.random() < prob, seq)
984+
985+
986+
def _complement_iterator(it, predicate, our_queue, other_queue):
    """Yield the elements of ``it`` for which ``predicate`` is true.

    This is one half of ``dichotomize``.  Two of these generators share the
    same underlying iterator and the same pair of queues (with the queue
    roles swapped), so an element pulled by one generator but belonging to
    the other is handed over through a queue instead of being lost.

    Parameters
    ----------
    it : iterator
        The shared source iterator.  It is advanced with ``next``.
    predicate : callable
        Called once on each element this generator pulls from ``it``.
    our_queue : collections.deque
        Elements the sibling generator pulled on our behalf.
    other_queue : collections.deque
        Where elements rejected by ``predicate`` are stored for the sibling.
    """
    while True:
        # Buffered elements were pulled from ``it`` before anything that is
        # still in it, so they must be emitted first to preserve order.
        # ``popleft`` is used instead of iterating the deque: the sibling
        # may append to ``our_queue`` while we are suspended at the yield,
        # and a live deque iterator would raise
        # ``RuntimeError: deque mutated during iteration`` on resume.
        while our_queue:
            yield our_queue.popleft()

        try:
            element = next(it)
        except StopIteration:
            # The queue was just seen empty and nothing ran in between, so
            # there is nothing left for this side to produce.
            return

        if predicate(element):
            yield element
        else:
            other_queue.append(element)


def dichotomize(predicate, iterable):
    """Take a predicate and an iterable and return the pair of iterables of
    elements which do and do not satisfy the predicate. The resulting iterators
    are lazy.

    The two iterators may be consumed in any interleaving.  Each element of
    ``iterable`` is pulled, and ``predicate`` evaluated on it, exactly once;
    elements that belong to the iterator not currently being advanced are
    buffered until that iterator asks for them.  Consuming one side fully
    before touching the other therefore buffers every element of the other
    side.

    >>> def even(n):
    ...     return n & 1 == 0
    ...
    >>> evens, odds = dichotomize(even, range(10))
    >>> list(evens)
    [0, 2, 4, 6, 8]
    >>> list(odds)
    [1, 3, 5, 7, 9]
    """
    source = iter(iterable)
    true_queue = collections.deque()
    false_queue = collections.deque()

    def negated(element):
        return not predicate(element)

    return (
        _complement_iterator(source, predicate, true_queue, false_queue),
        _complement_iterator(source, negated, false_queue, true_queue),
    )

toolz/tests/test_itertoolz.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
reduceby, iterate, accumulate,
1414
sliding_window, count, partition,
1515
partition_all, take_nth, pluck, join,
16-
diff, topk, peek, random_sample)
16+
diff, topk, peek, random_sample, dichotomize)
1717
from toolz.compatibility import range, filter
1818
from operator import add, mul
1919

@@ -524,3 +524,30 @@ def test_random_sample():
524524
assert mk_rsample(b"a") == mk_rsample(u"a")
525525

526526
assert raises(TypeError, lambda: mk_rsample([]))
527+
528+
529+
def test_dichotimize():
    # Consuming one side completely and then the other must give each half
    # of range(10) in its original order.
    passing, failing = dichotomize(iseven, range(10))
    expected_even = [0, 2, 4, 6, 8]
    expected_odd = [1, 3, 5, 7, 9]
    assert list(passing) == expected_even
    assert list(failing) == expected_odd
533+
534+
535+
def test_dichotimize_interleaved_next_calls():
    # Alternate between the two iterators in uneven bursts; each side must
    # still produce its own elements in source order.
    evens, odds = dichotomize(iseven, range(10))

    schedule = (
        (evens, (0, 2)),
        (odds, (1, 3, 5)),
        (evens, (4, 6, 8)),
        (odds, (7, 9)),
    )
    for stream, expected_values in schedule:
        for expected in expected_values:
            assert next(stream) == expected

    # Both sides are now exhausted.
    assert list(evens) == []
    assert list(odds) == []

0 commit comments

Comments
 (0)