diff --git a/cycler.py b/cycler.py index 8400444..489ab5c 100644 --- a/cycler.py +++ b/cycler.py @@ -47,6 +47,7 @@ from itertools import product, cycle from six.moves import zip, reduce from operator import mul, add +from collections import defaultdict import copy __version__ = '0.10.0' @@ -558,3 +559,29 @@ def _cycler(label, itr): itr = (v[lab] for v in itr) return Cycler._from_iter(label, itr) + + +class OutOfStyles(StopIteration): + pass + + +def persistent_style(cyl, repeat=False): + '''Create a defaultdict mapping keys -> styles + + Parameters + ---------- + cyl : Cycler + The c + ''' + def next_style(): + try: + next(cy_iter) + except StopIteration: + raise OutOfStyles() + + if repeat: + cy_iter = cyl() + return defaultdict(lambda: next(cy_iter)) + else: + cy_iter = iter(cyl) + return defaultdict(next_style) diff --git a/test_cycler.py b/test_cycler.py index 52f65ec..039fdb8 100644 --- a/test_cycler.py +++ b/test_cycler.py @@ -2,7 +2,7 @@ import six from six.moves import zip, range -from cycler import cycler, Cycler, concat +from cycler import cycler, Cycler, concat, persistent_style, OutOfStyles import pytest from itertools import product, cycle, chain from operator import add, iadd, mul, imul @@ -341,3 +341,24 @@ def test_contains(): assert 'a' in ab assert 'b' in ab + + +@pytest.mark.parametrize('repeat', [True, False]) +def test_persistent(repeat): + a = cycler('a', range(3)) + cycler('b', range(3)) + dd = persistent_style(a, repeat=repeat) + one = dd['one'] + two = dd['two'] + three = dd['three'] + + assert one == dd['one'] + assert two == dd['two'] + assert three == dd['three'] + if not repeat: + with pytest.raises(OutOfStyles): + dd['four'] + else: + assert one == dd['four'] + assert one == dd['four'] + assert two == dd['five'] + assert three == dd['six']