diff --git a/README.md b/README.md index 2c9ce5f..5e2072c 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ There are two core components in MiMo; the `Stream` and the `Workflow`. Streams ### Streams -Implementing a stream can be done through inheriting a sub-class from the `Stream` class or creating a `Stream`class with a custom function as the `fn` parameter. The following code shows the same implementation of a stream that will produce the numbers from 0 to 99. +Implementing a stream can be done through inheriting a sub-class from the `Stream` class or creating a `Stream` class with a custom function as the `fn` parameter. The following code shows two implementations of a stream that will produce the numbers from 0 to 99. ```python @@ -28,36 +28,26 @@ class MyStream(Stream): IN = [] OUT = ['entity'] - - def __init__(self): - super().__init__() - self.iterator = None - def run(self, ins, outs): - if self.iterator is None: - self.iterator = iter(range(100)) - for item in self.iterator: - if not outs.entity.pueh(item): - return True + async def run(self, ins, outs): + for item in iter(range(100)): + await outs.entity.push(item) # Method 2 (constructor) my_stream = Stream(outs=['entity], fn=my_stream_fn) -def my_stream_fn(ins, outs, state): - if 'iterator' not in state: - state['iterator'] = iter(range(100)) - for item in state['iterator']: - if not outs.entity.push(item): - return True +async def my_stream_fn(ins, outs, state): + for item in iter(range(100)): + await outs.entity.push(item) ``` There are a few things to note about the `run` function. -1. It takes two parameters, `ins` and `outs`, that contain the input streams and the output streams. The names of the input and output streams are defined by the `IN` and `OUT` member variables and accessing the input and output streams can be done through the attributes. From the example above, accessing the `entity` output stream can be done with `outs.entity`. -2. Input streams can be popped and peeked. Input streams haven't been used in the above example, but the entities in the stream can be accessed one at a time with the functions `pop` and `peek`. Popping an entity will remove it from the input stream, and peeking will look at the top-most entity without removing it from the stream. -2. Output streams can be pushed. Pushing an entity to an output stream will make it available to any connected downstream streams. The `push` function return a boolean to indicate whether the stream is full or not (`True` if still pushable). A full stream ca still be pushed to, but users can make their custom streams back-pressure aware by testing this value. -3. The return value is a boolean. If a stream did not fully complete it's task (possibly due to back-pressure), then it should return `True` to indicate that it can be run again after downstream streams have completed. Otherwise a `False` (or equivalent like `None`) will prevent further execution of the stream until new input is available. +1. It must be asynchronous, ie. it must be defined wth the `async def` keywords. +2. It takes two parameters, `ins` and `outs`, that contain the input streams and the output streams. The names of the input and output streams are defined by the `IN` and `OUT` member variables or overridden using the `ins` and `outs` of the initialisation function. Accessing the input and output streams can be done through the attributes. From the example above, accessing the `entity` output stream can be done with `outs.entity`. +3. Input streams can be popped and peeked and this must be done using the `await` keyword. Input streams haven't been used in the above example, but the entities in the stream can be accessed one at a time with the functions `pop` and `peek`. Popping an entity will remove it from the input stream, and peeking will look at the top-most entity without removing it from the stream. Input streams can also be iterated using the `async for` looping construct. +4. Output streams can be pushed and must also use the `await` keyword. Pushing an entity to an output stream will make it available to any connected downstream streams. ### Workflows diff --git a/examples/join.py b/examples/join.py index 2e005b0..ac96e32 100644 --- a/examples/join.py +++ b/examples/join.py @@ -1,4 +1,4 @@ -from mimo import Workflow, Stream +from mimo import Workflow, Stream, azip def main(): @@ -14,36 +14,30 @@ def main(): workflow.run() -def stream1(ins, outs, state): +async def stream1(ins, outs, state): """ Generates integers from 0 to 99. """ - if 'iterator' not in state: - state['iterator'] = iter(range(100)) - iterator = state['iterator'] - for item in iterator: - if not outs.a.push(item): - return True + for item in iter(range(100)): + await outs.a.push(item) + outs.a.close() -def stream2(ins, outs, state): +async def stream2(ins, outs, state): """ Generates integers from 99 to 0. """ - if 'iterator' not in state: - state['iterator'] = iter(range(99, -1, -1)) - iterator = state['iterator'] - for item in iterator: - if not outs.b.push(item): - return True + for item in iter(range(99, -1, -1)): + await outs.b.push(item) + outs.b.close() -def stream3(ins, outs, state): +async def stream3(ins, outs, state): """ Divide incoming entities by 10 and print to stdout """ - while len(ins.c) > 0 and len(ins.d) > 0: - sys.stdout.write('{}\n'.format(ins.c.pop() + ins.d.pop())) + async for c, d in azip(ins.c, ins.d): + sys.stdout.write('{}\n'.format(c + d)) if __name__ == '__main__': import sys diff --git a/examples/linear.py b/examples/linear.py index a8733ac..d43e7a0 100644 --- a/examples/linear.py +++ b/examples/linear.py @@ -13,35 +13,30 @@ def main(): workflow.run() -def stream1(ins, outs, state): +async def stream1(ins, outs, state): """ Generates integers from 0 to 99. """ - if 'iterator' not in state: - state['iterator'] = iter(range(100)) - iterator = state['iterator'] - for item in iterator: - if not outs.a.push(item): - return True + for item in iter(range(100)): + await outs.a.push(item) + outs.a.close() -def stream2(ins, outs, state): +async def stream2(ins, outs, state): """ Multiplies the integers by 2. """ - while len(ins.b) > 0: - item = ins.b.pop() - if not outs.c.push(item * 2): - break - return len(ins.b) > 0 + async for item in ins.b: + await outs.c.push(item * 2) + outs.c.close() -def stream3(ins, outs, state): +async def stream3(ins, outs, state): """ Print incoming entities to stdout """ - while len(ins.d) > 0: - print(ins.d.pop()) + async for item in ins.d: + print(item) if __name__ == '__main__': import sys diff --git a/examples/multi_output.py b/examples/multi_output.py index bdc2edc..4cdb3ed 100644 --- a/examples/multi_output.py +++ b/examples/multi_output.py @@ -3,43 +3,42 @@ def main(): workflow = Workflow(10) - step1 = workflow.add_stream(Stream(outs=['a'], fn=stream1)) - step2 = workflow.add_stream(Stream(['b'], fn=stream2)) - step3 = workflow.add_stream(Stream(['c'], fn=stream3)) + step1 = workflow.add_stream(Stream(outs=['a', 'b'], fn=stream1)) + step2 = workflow.add_stream(Stream(['c'], fn=stream2)) + step3 = workflow.add_stream(Stream(['d'], fn=stream3)) - step1.pipe(step2) - step1.pipe(step3) + step1.pipe(step2, 'a') + step1.pipe(step3, 'b') print(str(workflow)) workflow.run() -def stream1(ins, outs, state): +async def stream1(ins, outs, state): """ - Generates integers from 0 to 99. + Generates one stream of integers from 0 to 99 and another from 100 to 1 """ - if 'iterator' not in state: - state['iterator'] = iter(range(100)) - iterator = state['iterator'] - for item in iterator: - if not outs.a.push(item): - return True + for item in iter(range(100)): + await outs.a.push(item) + await outs.b.push(100 - item) + outs.a.close() + outs.b.close() -def stream2(ins, outs, state): +async def stream2(ins, outs, state): """ Multiply incoming entities by 2 and print to stdout """ - while len(ins.b) > 0: - sys.stdout.write('{}\n'.format(2 * ins.b.pop())) + async for item in ins.c: + sys.stdout.write('{}\n'.format(2 * item)) -def stream3(ins, outs, state): +async def stream3(ins, outs, state): """ Divide incoming entities by 10 and print to stdout """ - while len(ins.c) > 0: - sys.stdout.write('{}\n'.format(ins.c.pop() / 10)) + async for item in ins.d: + sys.stdout.write('{}\n'.format(item / 10)) if __name__ == '__main__': import sys diff --git a/examples/split_output.py b/examples/split_output.py index bdc2edc..ea0638a 100644 --- a/examples/split_output.py +++ b/examples/split_output.py @@ -14,32 +14,29 @@ def main(): workflow.run() -def stream1(ins, outs, state): +async def stream1(ins, outs, state): """ Generates integers from 0 to 99. """ - if 'iterator' not in state: - state['iterator'] = iter(range(100)) - iterator = state['iterator'] - for item in iterator: - if not outs.a.push(item): - return True + for item in iter(range(100)): + await outs.a.push(item) + outs.a.close() -def stream2(ins, outs, state): +async def stream2(ins, outs, state): """ Multiply incoming entities by 2 and print to stdout """ - while len(ins.b) > 0: - sys.stdout.write('{}\n'.format(2 * ins.b.pop())) + async for item in ins.b: + sys.stdout.write('{}\n'.format(2 * item)) -def stream3(ins, outs, state): +async def stream3(ins, outs, state): """ Divide incoming entities by 10 and print to stdout """ - while len(ins.c) > 0: - sys.stdout.write('{}\n'.format(ins.c.pop() / 10)) + async for item in ins.c: + sys.stdout.write('{}\n'.format(item / 10)) if __name__ == '__main__': import sys diff --git a/mimo/__init__.py b/mimo/__init__.py index 283777b..b99b8ef 100644 --- a/mimo/__init__.py +++ b/mimo/__init__.py @@ -1,2 +1,4 @@ from .stream import Stream from .workflow import Workflow + +from .asynctools import * diff --git a/mimo/asynctools/__init__.py b/mimo/asynctools/__init__.py new file mode 100644 index 0000000..36dc84f --- /dev/null +++ b/mimo/asynctools/__init__.py @@ -0,0 +1,10 @@ +from .range import AsynchronousRange +from .zip import AsynchronousZip + + +def azip(*iterables): + return AsynchronousZip(*iterables) + + +def arange(fr, to=None, step=1): + return AsynchronousRange(fr, to, step) diff --git a/mimo/asynctools/range.py b/mimo/asynctools/range.py new file mode 100644 index 0000000..816b774 --- /dev/null +++ b/mimo/asynctools/range.py @@ -0,0 +1,15 @@ +class AsynchronousRange: + def __init__(self, fr, to=None, step=1): + if to is None: + self._iterator = iter(range(fr)) + else: + self._iterator = iter(range(fr, to, step)) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self._iterator) + except StopIteration: + raise StopAsyncIteration diff --git a/mimo/asynctools/zip.py b/mimo/asynctools/zip.py new file mode 100644 index 0000000..9708fef --- /dev/null +++ b/mimo/asynctools/zip.py @@ -0,0 +1,12 @@ +class AsynchronousZip: + def __init__(self, *iterables): + self._iterables = iterables + + def __aiter__(self): + return self + + async def __anext__(self): + res = [] + for iterator in self._iterables: + res.append(await iterator.__anext__()) + return tuple(res) diff --git a/mimo/connection/connection.py b/mimo/connection/connection.py deleted file mode 100644 index 50148c7..0000000 --- a/mimo/connection/connection.py +++ /dev/null @@ -1,18 +0,0 @@ -from collections import deque - - -class Connection: - def __init__(self, name, threshold=10): - self.entities = deque() - - self.name = name - self.threshold = threshold - - def __len__(self): - return len(self.entities) - - def clear(self): - self.entities.clear() - - def is_full(self): - return len(self.entities) >= self.threshold diff --git a/mimo/connection/input.py b/mimo/connection/input.py index 02d1a64..f68130f 100644 --- a/mimo/connection/input.py +++ b/mimo/connection/input.py @@ -1,13 +1,73 @@ -from .connection import Connection +import asyncio -class Input(Connection): - def peek(self): - return self.entities[0] +class ConnectionClosed(Exception): + pass - def pop(self): - return self.entities.popleft() - def extend(self, entities): - self.entities.extend(entities) - return len(self.entities) < self.threshold +class Input(asyncio.Queue): + def __init__(self, name, maxsize=0, loop=None): + super().__init__(maxsize, loop=loop) + self.name = name + self._closed = False + + async def __aiter__(self): + return self + + async def __anext__(self): + try: + return await self.pop() + except ConnectionClosed: + raise StopAsyncIteration + + async def push(self, item): + if self._closed: + raise ConnectionClosed + return await self.put(item) + + async def peek(self): + while self.empty(): + if self._closed and len(self._putters) == 0: + raise ConnectionClosed + + getter = self._loop.create_future() + self._getters.append(getter) + try: + await getter + except: + getter.cancel() + if not self.empty() and not getter.cancelled(): + self._wakeup_next(self._getters) + raise + return self._queue[0] + + async def pop(self): + while self.empty(): + if self._closed and len(self._putters) == 0: + raise ConnectionClosed + + getter = self._loop.create_future() + self._getters.append(getter) + try: + await getter + except: + getter.cancel() + if not self.empty() and not getter.cancelled(): + self._wakeup_next(self._getters) + raise + return self.get_nowait() + + def close(self): + self._closed = True + self._maxsize = 0 + while self._putters: + putter = self._putters.popleft() + if not putter.done(): + putter.set_result(None) + while self._getters: + getter = self._getters.popleft() + if not getter.done(): + if self.empty(): + getter.set_exception(ConnectionClosed) + else: + getter.set_result(None) diff --git a/mimo/connection/output.py b/mimo/connection/output.py index 40a0cf4..b14ac48 100644 --- a/mimo/connection/output.py +++ b/mimo/connection/output.py @@ -1,12 +1,15 @@ -from .connection import Connection +class Output: + def __init__(self, name): + self.name = name + self._connections = [] + async def push(self, entity): + for connection in self._connections: + await connection.push(entity) -class Output(Connection): - def push(self, entity): - """ - Add an entity to the end of the connection. Return if connection can still be pushed to. - :param entity: - :return: - """ - self.entities.append(entity) - return len(self.entities) < self.threshold + def pipe(self, connection): + self._connections.append(connection) + + def close(self): + for connection in self._connections: + connection.close() diff --git a/mimo/stream/stream.py b/mimo/stream/stream.py index 5a1bead..ea200de 100644 --- a/mimo/stream/stream.py +++ b/mimo/stream/stream.py @@ -8,7 +8,7 @@ class Stream: IN = [] OUT = [] - def __init__(self, ins=None, outs=None, fn=None, name=None): + def __init__(self, ins=None, outs=None, fn=None, name=None, state=None): """ Initialise a stream. Streams can be sub-classed to alter the behaviour or customised directly. If sub-classing a stream, the class members `IN` and `OUT` define the names of the input and output entities. @@ -24,12 +24,11 @@ def __init__(self, ins=None, outs=None, fn=None, name=None): :param name: name of the stream :param fn: run function """ - self.state = {} - self.ins = self.IN if ins is None else ins self.outs = self.OUT if outs is None else outs self.fn = fn self.name = type(self).__name__ if name is None else name + self.state = {} if state is None else state def run(self, ins, outs): """ diff --git a/mimo/test_helper.py b/mimo/test_helper.py index c351958..939d0dc 100644 --- a/mimo/test_helper.py +++ b/mimo/test_helper.py @@ -1,23 +1,27 @@ +import asyncio + from .connection.input import Input from .connection.output import Output from .connection.connection_set import ConnectionSet class TestHelper: - def __init__(self, stream): + def __init__(self, stream, timeout=5): ins = [Input(in_) for in_ in stream.ins] outs = [Output(out) for out in stream.outs] + self.sinks = [Input(out) for out in stream.outs] + for out, sink in zip(outs, self.sinks): + out.pipe(sink) self.ins = ConnectionSet(ins) self.outs = ConnectionSet(outs) self.stream = stream + self._timeout = timeout + self._loop = asyncio.get_event_loop() - def run(self, ins): + def run(self, ins={}): for key, value in ins.items(): - self.ins[key].clear() - self.ins[key].extend(value) - outs = {out.name: [] for out in self.outs} - while self.stream.run(self.ins, self.outs): - for out in self.outs: - outs[out.name].extend(out.entities) - out.clear() - return outs + self.ins[key]._queue.extend(value) + self.ins[key].close() + task = self._loop.create_task(self.stream.run(self.ins, self.outs)) + self._loop.run_until_complete(asyncio.wait_for(task, self._timeout, loop=self._loop)) + return {sink.name: list(sink._queue) for sink in self.sinks} diff --git a/mimo/workflow/node.py b/mimo/workflow/node.py index 2c97193..5d0baca 100644 --- a/mimo/workflow/node.py +++ b/mimo/workflow/node.py @@ -40,4 +40,5 @@ def pipe(self, step, output=None, input=None): input_id = step.input_ids[input] self.workflow.graph.add_edge(output_id, input_id) + self.workflow.outputs[output_id].pipe(self.workflow.inputs[input_id]) return step diff --git a/mimo/workflow/workflow.py b/mimo/workflow/workflow.py index 6dd1aad..85a0254 100644 --- a/mimo/workflow/workflow.py +++ b/mimo/workflow/workflow.py @@ -1,4 +1,5 @@ -from collections import deque +import asyncio + from uuid import uuid4 from lhc.graph import NPartiteGraph from mimo.connection.input import Input @@ -32,31 +33,11 @@ def __str__(self): return '\n'.join(res) def run(self): - stacked = set(self._get_head_streams()) | set(self._get_streams_with_input()) - stack = deque(stacked) - while len(stack) > 0: - stream_id = stack.popleft() - stacked.remove(stream_id) - stream = self.streams[stream_id] - paused = stream.run(self.input_sets[stream_id], self.output_sets[stream_id]) - - output_ids = self.graph.get_children(stream_id) - for output_id in output_ids: - output = self.outputs[output_id] - input_ids = self.graph.get_children(output_id) - inputs = [self.inputs[input_id] for input_id in input_ids] - if any(input.is_full() for input in inputs): - continue - for input in inputs: - input.extend(output.entities) - output.clear() - for input_id in input_ids: - stream_ids = self.graph.get_children(input_id) - stack.extend(stream_id for stream_id in stream_ids if stream_id not in stacked) - stacked.update(stream_ids) - if paused: - stack.append(stream_id) - stacked.add(stream_id) + loop = asyncio.get_event_loop() + tasks = [stream.run(self.input_sets[stream_id], self.output_sets[stream_id]) + for stream_id, stream in self.streams.items()] + loop.run_until_complete(asyncio.gather(*tasks)) + loop.close() def add_stream(self, stream): stream_id, input_ids, output_ids = self._get_identifiers(stream) @@ -64,18 +45,6 @@ def add_stream(self, stream): self._add_edges(stream_id, input_ids, output_ids) return Node(self, stream_id, input_ids, output_ids) - def _get_head_streams(self): - for stream_id in self.graph.partitions[0]: - if len(self.graph.get_parents(stream_id)) == 0: - yield stream_id - - def _get_streams_with_input(self): - inputs = self.inputs - for stream_id in self.graph.partitions[0]: - input_ids = self.graph.get_parents(stream_id) - if any(len(inputs[input_id]) > 0 for input_id in input_ids): - yield stream_id - def _get_identifiers(self, stream): return str(uuid4())[:8],\ {name: str(uuid4())[:8] for name in stream.ins},\ @@ -84,7 +53,7 @@ def _get_identifiers(self, stream): def _add_vertices(self, stream, stream_id, input_ids, output_ids): self.streams[stream_id] = stream self.input_sets[stream_id] = ConnectionSet(Input(name, self.threshold) for name in stream.ins) - self.output_sets[stream_id] = ConnectionSet(Output(name, self.threshold) for name in stream.outs) + self.output_sets[stream_id] = ConnectionSet(Output(name) for name in stream.outs) self.graph.add_vertex(stream_id, 0) for input, input_id in input_ids.items(): diff --git a/setup.py b/setup.py index 3451cbf..e7575d3 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( name='mimo', - version='1.0.6', + version='1.0.7', author='Liam H. Childs', author_email='liam.h.childs@gmail.com', packages=find_packages(exclude=['tests']), diff --git a/tests/test_asynctools.py b/tests/test_asynctools.py new file mode 100644 index 0000000..6bd3a23 --- /dev/null +++ b/tests/test_asynctools.py @@ -0,0 +1,85 @@ +import asyncio +import unittest + +from mimo import arange, azip + + +class TestRange(unittest.TestCase): + def setUp(self): + self.loop = asyncio.get_event_loop() + + def test_range(self): + iterator = arange(10) + + future = self.loop.create_task(iterator.__anext__()) + self.loop.run_until_complete(future) + self.assertEqual(0, future.result()) + + future = self.loop.create_task(iterator.__anext__()) + self.loop.run_until_complete(future) + self.assertEqual(1, future.result()) + + future = self.loop.create_task(iterator.__anext__()) + self.loop.run_until_complete(future) + self.assertEqual(2, future.result()) + + future = self.loop.create_task(iterator.__anext__()) + self.loop.run_until_complete(future) + self.assertEqual(3, future.result()) + + future = self.loop.create_task(iterator.__anext__()) + self.loop.run_until_complete(future) + self.assertEqual(4, future.result()) + + future = self.loop.create_task(iterator.__anext__()) + self.loop.run_until_complete(future) + self.assertEqual(5, future.result()) + + future = self.loop.create_task(iterator.__anext__()) + self.loop.run_until_complete(future) + self.assertEqual(6, future.result()) + + future = self.loop.create_task(iterator.__anext__()) + self.loop.run_until_complete(future) + self.assertEqual(7, future.result()) + + future = self.loop.create_task(iterator.__anext__()) + self.loop.run_until_complete(future) + self.assertEqual(8, future.result()) + + future = self.loop.create_task(iterator.__anext__()) + self.loop.run_until_complete(future) + self.assertEqual(9, future.result()) + + future = self.loop.create_task(iterator.__anext__()) + self.assertRaises(StopAsyncIteration, self.loop.run_until_complete, future) + + def test_zip(self): + iterator = azip(arange(5), arange(5, 10), arange(10, 15)) + + future = self.loop.create_task(iterator.__anext__()) + self.loop.run_until_complete(future) + self.assertEqual((0, 5, 10), future.result()) + + future = self.loop.create_task(iterator.__anext__()) + self.loop.run_until_complete(future) + self.assertEqual((1, 6, 11), future.result()) + + future = self.loop.create_task(iterator.__anext__()) + self.loop.run_until_complete(future) + self.assertEqual((2, 7, 12), future.result()) + + future = self.loop.create_task(iterator.__anext__()) + self.loop.run_until_complete(future) + self.assertEqual((3, 8, 13), future.result()) + + future = self.loop.create_task(iterator.__anext__()) + self.loop.run_until_complete(future) + self.assertEqual((4, 9, 14), future.result()) + + future = self.loop.create_task(iterator.__anext__()) + self.assertRaises(StopAsyncIteration, self.loop.run_until_complete, future) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_connection/test_connection.py b/tests/test_connection/test_connection.py deleted file mode 100644 index 2868d49..0000000 --- a/tests/test_connection/test_connection.py +++ /dev/null @@ -1,16 +0,0 @@ -import unittest - -from mimo.connection.connection import Connection - - -class TestConnection(unittest.TestCase): - def test_is_full(self): - connection = Connection('a', 5) - - self.assertFalse(connection.is_full()) - connection.entities.extend([1, 2, 3, 4, 5]) - self.assertTrue(connection.is_full()) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_connection/test_connection_set.py b/tests/test_connection/test_connection_set.py deleted file mode 100644 index 094b38c..0000000 --- a/tests/test_connection/test_connection_set.py +++ /dev/null @@ -1,20 +0,0 @@ -import unittest - -from mimo.connection.output import Output -from mimo.connection.connection_set import ConnectionSet - - -class TestConnectionSet(unittest.TestCase): - def test_items_and_attributes(self): - a = Output('a') - b = Output('b') - connection_set = ConnectionSet([a, b]) - - self.assertEqual(a, connection_set.a) - self.assertEqual(b, connection_set.b) - self.assertEqual(a, connection_set['a']) - self.assertEqual(b, connection_set['b']) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_connection/test_input.py b/tests/test_connection/test_input.py index 309d5bd..8934afd 100644 --- a/tests/test_connection/test_input.py +++ b/tests/test_connection/test_input.py @@ -1,29 +1,47 @@ +import asyncio import unittest -from mimo.connection.input import Input +from mimo.connection.input import Input, ConnectionClosed class TestInput(unittest.TestCase): + def setUp(self): + self.loop = asyncio.get_event_loop() + + def test_push(self): + connection = Input('a') + + self.loop.run_until_complete(connection.push(0)) + + self.assertIn(0, connection._queue) + + def test_push_closed(self): + connection = Input('a') + connection.close() + + task = self.loop.create_task(connection.push(0)) + + self.assertRaises(ConnectionClosed, self.loop.run_until_complete, task) + def test_peek(self): connection = Input('a') + connection._queue.extend((0, 1, 2)) - self.assertRaises(IndexError, connection.peek) - connection.entities.append(1) - self.assertEqual(1, connection.peek()) + task = self.loop.create_task(connection.peek()) + self.loop.run_until_complete(task) + + self.assertEqual(0, task.result()) + self.assertEqual([0, 1, 2], list(connection._queue)) def test_pop(self): connection = Input('a') + connection._queue.extend((0, 1, 2)) - self.assertRaises(IndexError, connection.pop) - connection.entities.extend((1, 2, 3)) - self.assertEqual(1, connection.pop()) - self.assertEqual([2, 3], list(connection.entities)) - - def test_extend(self): - connection = Input('a', 3) + task = self.loop.create_task(connection.pop()) + self.loop.run_until_complete(task) - self.assertFalse(connection.extend((1, 2, 3, 4))) - self.assertEqual([1, 2, 3, 4], list(connection.entities)) + self.assertEqual(0, task.result()) + self.assertEqual([1, 2], list(connection._queue)) if __name__ == '__main__': diff --git a/tests/test_connection/test_output.py b/tests/test_connection/test_output.py deleted file mode 100644 index 9ca6302..0000000 --- a/tests/test_connection/test_output.py +++ /dev/null @@ -1,19 +0,0 @@ -import unittest - -from mimo.connection.output import Output - - -class TestOutput(unittest.TestCase): - def test_push(self): - connection = Output('a', 3) - - self.assertTrue(connection.push(1)) - self.assertEqual([1], list(connection.entities)) - self.assertTrue(connection.push(2)) - self.assertEqual([1, 2], list(connection.entities)) - self.assertFalse(connection.push(3)) - self.assertEqual([1, 2, 3], list(connection.entities)) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_stream/test_stream.py b/tests/test_stream/test_stream.py index 3e3aceb..b8fa797 100644 --- a/tests/test_stream/test_stream.py +++ b/tests/test_stream/test_stream.py @@ -1,33 +1,20 @@ import unittest from mimo import Stream -from mimo.connection.connection_set import ConnectionSet -from mimo.connection.input import Input -from mimo.connection.output import Output +from mimo.test_helper import TestHelper class TestStream(unittest.TestCase): def test_run(self): stream = Stream(['a'], ['b'], fn=fn) - input = Input('a', 5) - ins = ConnectionSet([input]) - output = Output('b', 5) - outs = ConnectionSet([output]) - input.entities.extend([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]) - - self.assertTrue(stream.run(ins, outs)) - self.assertEqual([2, 4, 6, 8, 10], list(output.entities)) - self.assertTrue(stream.run(ins, outs)) - self.assertEqual([2, 4, 6, 8, 10, 12], list(output.entities)) - output.entities.clear() - self.assertFalse(stream.run(ins, outs)) - self.assertEqual([14, 16, 18, 0], list(output.entities)) - - -def fn(ins, outs, state): - while len(ins.a) > 0: - if not outs.b.push(2 * ins.a.pop()): - return True + helper = TestHelper(stream) + + self.assertEqual({'b': [2, 4, 6, 8, 10, 12, 14, 16, 18, 0]}, helper.run({'a': [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]})) + + +async def fn(ins, outs, state): + async for item in ins.a: + await outs.b.push(2 * item) if __name__ == '__main__': diff --git a/tests/test_test_helper.py b/tests/test_test_helper.py index 338b384..020d745 100644 --- a/tests/test_test_helper.py +++ b/tests/test_test_helper.py @@ -1,3 +1,4 @@ +import asyncio import unittest from mimo import Stream @@ -12,11 +13,20 @@ def test_run(self): self.assertEqual({'b': [2, 4, 6, 8, 10, 12, 14, 16, 18, 0]}, helper.run({'a': [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]})) + def test_run_timeout(self): + stream = Stream(ins=['a'], outs=['b'], fn=will_timeout) + helper = TestHelper(stream, timeout=1) -def fn(ins, outs, state): - while len(ins.a) > 0: - if not outs.b.push(2 * ins.a.pop()): - return True + self.assertRaises(asyncio.TimeoutError, helper.run) + + +async def fn(ins, outs, state): + async for item in ins.a: + await outs.b.push(2 * item) + + +async def will_timeout(ins, outs, state): + await asyncio.sleep(10) if __name__ == '__main__': diff --git a/tests/test_workflow/test_workflow.py b/tests/test_workflow/test_workflow.py index 14216ab..978df95 100644 --- a/tests/test_workflow/test_workflow.py +++ b/tests/test_workflow/test_workflow.py @@ -7,15 +7,14 @@ class TestWorkflow(unittest.TestCase): def test_run(self): workflow = Workflow() stream1 = Stream(outs=['entity'], fn=iterator_stream) - stream2 = Stream(ins=['entity'], fn=collect_stream) + stream2 = Stream(ins=['entity'], fn=collect_stream, state=[]) step1 = workflow.add_stream(stream1) step2 = workflow.add_stream(stream2) step1.pipe(step2) workflow.run() - self.assertIn('collection', stream2.state) - self.assertEqual(list(range(100)), stream2.state['collection']) + self.assertEqual(list(range(100)), stream2.state) def test_add_stream(self): stream = Stream(['a'], ['b']) @@ -30,49 +29,16 @@ def test_add_stream(self): self.assertEqual(set(workflow.inputs), set(node.input_ids.values())) self.assertEqual(set(workflow.outputs), set(node.output_ids.values())) - def test_get_head_streams(self): - workflow = Workflow() - stream1 = Stream(outs=['a']) - stream2 = Stream(outs=['b']) - stream3 = Stream(['c', 'd'], ['e']) - step1 = workflow.add_stream(stream1) - step2 = workflow.add_stream(stream2) - step3 = workflow.add_stream(stream3) - step1.pipe(step3, input='c') - step2.pipe(step3, input='d') - - heads = set(workflow._get_head_streams()) - - self.assertEqual({step1.stream_id, step2.stream_id}, heads) - - def test_get_streams_with_input(self): - workflow = Workflow() - stream1 = Stream(['a']) - stream2 = Stream(['b']) - step1 = workflow.add_stream(stream1) - workflow.add_stream(stream2) - - workflow.input_sets[step1.stream_id].a.entities.extend([1, 2, 3]) - nodes_with_input = set(workflow._get_streams_with_input()) - - self.assertEqual({step1.stream_id}, nodes_with_input) - -def iterator_stream(ins, outs, state): - if 'iterator' not in state: - state['iterator'] = iter(range(100)) - iterator = state['iterator'] - for item in iterator: - if not outs.entity.push(item): - return True +async def iterator_stream(ins, outs, state): + for item in iter(range(100)): + await outs.entity.push(item) + outs.entity.close() -def collect_stream(ins, outs, state): - if 'collection' not in state: - state['collection'] = [] - collection = state['collection'] - while len(ins.entity) > 0: - collection.append(ins.entity.pop()) +async def collect_stream(ins, outs, state): + async for item in ins.entity: + state.append(item) if __name__ == '__main__':