diff --git a/mimo/connection/connection.py b/mimo/connection/connection.py index 5cb6479..50148c7 100644 --- a/mimo/connection/connection.py +++ b/mimo/connection/connection.py @@ -11,5 +11,8 @@ def __init__(self, name, threshold=10): def __len__(self): return len(self.entities) + def clear(self): + self.entities.clear() + def is_full(self): return len(self.entities) >= self.threshold diff --git a/mimo/test_helper.py b/mimo/test_helper.py new file mode 100644 index 0000000..c351958 --- /dev/null +++ b/mimo/test_helper.py @@ -0,0 +1,23 @@ +from .connection.input import Input +from .connection.output import Output +from .connection.connection_set import ConnectionSet + + +class TestHelper: + def __init__(self, stream): + ins = [Input(in_) for in_ in stream.ins] + outs = [Output(out) for out in stream.outs] + self.ins = ConnectionSet(ins) + self.outs = ConnectionSet(outs) + self.stream = stream + + def run(self, ins): + for key, value in ins.items(): + self.ins[key].clear() + self.ins[key].extend(value) + outs = {out.name: [] for out in self.outs} + while self.stream.run(self.ins, self.outs): + for out in self.outs: + outs[out.name].extend(out.entities) + out.clear() + return outs diff --git a/mimo/workflow/workflow.py b/mimo/workflow/workflow.py index 8a6deb9..6dd1aad 100644 --- a/mimo/workflow/workflow.py +++ b/mimo/workflow/workflow.py @@ -49,7 +49,7 @@ def run(self): continue for input in inputs: input.extend(output.entities) - output.entities.clear() + output.clear() for input_id in input_ids: stream_ids = self.graph.get_children(input_id) stack.extend(stream_id for stream_id in stream_ids if stream_id not in stacked) diff --git a/tests/test_stream/test_stream.py b/tests/test_stream/test_stream.py index 263782a..3e3aceb 100644 --- a/tests/test_stream/test_stream.py +++ b/tests/test_stream/test_stream.py @@ -15,12 +15,12 @@ def test_run(self): outs = ConnectionSet([output]) input.entities.extend([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]) - stream.run(ins, outs) + self.assertTrue(stream.run(ins, outs)) self.assertEqual([2, 4, 6, 8, 10], list(output.entities)) - stream.run(ins, outs) + self.assertTrue(stream.run(ins, outs)) self.assertEqual([2, 4, 6, 8, 10, 12], list(output.entities)) output.entities.clear() - stream.run(ins, outs) + self.assertFalse(stream.run(ins, outs)) self.assertEqual([14, 16, 18, 0], list(output.entities)) diff --git a/tests/test_test_helper.py b/tests/test_test_helper.py new file mode 100644 index 0000000..338b384 --- /dev/null +++ b/tests/test_test_helper.py @@ -0,0 +1,23 @@ +import unittest + +from mimo import Stream +from mimo.test_helper import TestHelper + + +class TestTestHelper(unittest.TestCase): + def test_run(self): + stream = Stream(ins=['a'], outs=['b'], fn=fn) + helper = TestHelper(stream) + + self.assertEqual({'b': [2, 4, 6, 8, 10, 12, 14, 16, 18, 0]}, + helper.run({'a': [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]})) + + +def fn(ins, outs, state): + while len(ins.a) > 0: + if not outs.b.push(2 * ins.a.pop()): + return True + + +if __name__ == '__main__': + unittest.main()