From 39fad835c61af417ec482bf90c0beb4c9bf39180 Mon Sep 17 00:00:00 2001 From: Liam Childs Date: Mon, 26 Sep 2016 10:06:37 +0200 Subject: [PATCH] converted library to asyncio --- mimo/connection/connection.py | 18 ----- mimo/connection/input.py | 78 +++++++++++++++++--- mimo/connection/output.py | 23 +++--- mimo/stream/stream.py | 5 +- mimo/test_helper.py | 24 +++--- mimo/workflow/node.py | 1 + mimo/workflow/workflow.py | 47 ++---------- setup.py | 2 +- tests/test_connection/test_connection.py | 16 ---- tests/test_connection/test_connection_set.py | 20 ----- tests/test_connection/test_input.py | 44 +++++++---- tests/test_connection/test_output.py | 19 ----- tests/test_stream/test_stream.py | 31 +++----- tests/test_test_helper.py | 18 ++++- tests/test_workflow/test_workflow.py | 52 +++---------- 15 files changed, 171 insertions(+), 227 deletions(-) delete mode 100644 mimo/connection/connection.py delete mode 100644 tests/test_connection/test_connection.py delete mode 100644 tests/test_connection/test_connection_set.py delete mode 100644 tests/test_connection/test_output.py diff --git a/mimo/connection/connection.py b/mimo/connection/connection.py deleted file mode 100644 index 50148c7..0000000 --- a/mimo/connection/connection.py +++ /dev/null @@ -1,18 +0,0 @@ -from collections import deque - - -class Connection: - def __init__(self, name, threshold=10): - self.entities = deque() - - self.name = name - self.threshold = threshold - - def __len__(self): - return len(self.entities) - - def clear(self): - self.entities.clear() - - def is_full(self): - return len(self.entities) >= self.threshold diff --git a/mimo/connection/input.py b/mimo/connection/input.py index 02d1a64..f68130f 100644 --- a/mimo/connection/input.py +++ b/mimo/connection/input.py @@ -1,13 +1,73 @@ -from .connection import Connection +import asyncio -class Input(Connection): - def peek(self): - return self.entities[0] +class ConnectionClosed(Exception): + pass - def pop(self): - return self.entities.popleft() - def extend(self, entities): - self.entities.extend(entities) - return len(self.entities) < self.threshold +class Input(asyncio.Queue): + def __init__(self, name, maxsize=0, loop=None): + super().__init__(maxsize, loop=loop) + self.name = name + self._closed = False + + async def __aiter__(self): + return self + + async def __anext__(self): + try: + return await self.pop() + except ConnectionClosed: + raise StopAsyncIteration + + async def push(self, item): + if self._closed: + raise ConnectionClosed + return await self.put(item) + + async def peek(self): + while self.empty(): + if self._closed and len(self._putters) == 0: + raise ConnectionClosed + + getter = self._loop.create_future() + self._getters.append(getter) + try: + await getter + except: + getter.cancel() + if not self.empty() and not getter.cancelled(): + self._wakeup_next(self._getters) + raise + return self._queue[0] + + async def pop(self): + while self.empty(): + if self._closed and len(self._putters) == 0: + raise ConnectionClosed + + getter = self._loop.create_future() + self._getters.append(getter) + try: + await getter + except: + getter.cancel() + if not self.empty() and not getter.cancelled(): + self._wakeup_next(self._getters) + raise + return self.get_nowait() + + def close(self): + self._closed = True + self._maxsize = 0 + while self._putters: + putter = self._putters.popleft() + if not putter.done(): + putter.set_result(None) + while self._getters: + getter = self._getters.popleft() + if not getter.done(): + if self.empty(): + getter.set_exception(ConnectionClosed) + else: + getter.set_result(None) diff --git a/mimo/connection/output.py b/mimo/connection/output.py index 40a0cf4..b14ac48 100644 --- a/mimo/connection/output.py +++ b/mimo/connection/output.py @@ -1,12 +1,15 @@ -from .connection import Connection +class Output: + def __init__(self, name): + self.name = name + self._connections = [] + async def push(self, entity): + for connection in self._connections: + await connection.push(entity) -class Output(Connection): - def push(self, entity): - """ - Add an entity to the end of the connection. Return if connection can still be pushed to. - :param entity: - :return: - """ - self.entities.append(entity) - return len(self.entities) < self.threshold + def pipe(self, connection): + self._connections.append(connection) + + def close(self): + for connection in self._connections: + connection.close() diff --git a/mimo/stream/stream.py b/mimo/stream/stream.py index 5a1bead..ea200de 100644 --- a/mimo/stream/stream.py +++ b/mimo/stream/stream.py @@ -8,7 +8,7 @@ class Stream: IN = [] OUT = [] - def __init__(self, ins=None, outs=None, fn=None, name=None): + def __init__(self, ins=None, outs=None, fn=None, name=None, state=None): """ Initialise a stream. Streams can be sub-classed to alter the behaviour or customised directly. If sub-classing a stream, the class members `IN` and `OUT` define the names of the input and output entities. @@ -24,12 +24,11 @@ def __init__(self, ins=None, outs=None, fn=None, name=None): :param name: name of the stream :param fn: run function """ - self.state = {} - self.ins = self.IN if ins is None else ins self.outs = self.OUT if outs is None else outs self.fn = fn self.name = type(self).__name__ if name is None else name + self.state = {} if state is None else state def run(self, ins, outs): """ diff --git a/mimo/test_helper.py b/mimo/test_helper.py index c351958..939d0dc 100644 --- a/mimo/test_helper.py +++ b/mimo/test_helper.py @@ -1,23 +1,27 @@ +import asyncio + from .connection.input import Input from .connection.output import Output from .connection.connection_set import ConnectionSet class TestHelper: - def __init__(self, stream): + def __init__(self, stream, timeout=5): ins = [Input(in_) for in_ in stream.ins] outs = [Output(out) for out in stream.outs] + self.sinks = [Input(out) for out in stream.outs] + for out, sink in zip(outs, self.sinks): + out.pipe(sink) self.ins = ConnectionSet(ins) self.outs = ConnectionSet(outs) self.stream = stream + self._timeout = timeout + self._loop = asyncio.get_event_loop() - def run(self, ins): + def run(self, ins={}): for key, value in ins.items(): - self.ins[key].clear() - self.ins[key].extend(value) - outs = {out.name: [] for out in self.outs} - while self.stream.run(self.ins, self.outs): - for out in self.outs: - outs[out.name].extend(out.entities) - out.clear() - return outs + self.ins[key]._queue.extend(value) + self.ins[key].close() + task = self._loop.create_task(self.stream.run(self.ins, self.outs)) + self._loop.run_until_complete(asyncio.wait_for(task, self._timeout, loop=self._loop)) + return {sink.name: list(sink._queue) for sink in self.sinks} diff --git a/mimo/workflow/node.py b/mimo/workflow/node.py index 2c97193..5d0baca 100644 --- a/mimo/workflow/node.py +++ b/mimo/workflow/node.py @@ -40,4 +40,5 @@ def pipe(self, step, output=None, input=None): input_id = step.input_ids[input] self.workflow.graph.add_edge(output_id, input_id) + self.workflow.outputs[output_id].pipe(self.workflow.inputs[input_id]) return step diff --git a/mimo/workflow/workflow.py b/mimo/workflow/workflow.py index 6dd1aad..85a0254 100644 --- a/mimo/workflow/workflow.py +++ b/mimo/workflow/workflow.py @@ -1,4 +1,5 @@ -from collections import deque +import asyncio + from uuid import uuid4 from lhc.graph import NPartiteGraph from mimo.connection.input import Input @@ -32,31 +33,11 @@ def __str__(self): return '\n'.join(res) def run(self): - stacked = set(self._get_head_streams()) | set(self._get_streams_with_input()) - stack = deque(stacked) - while len(stack) > 0: - stream_id = stack.popleft() - stacked.remove(stream_id) - stream = self.streams[stream_id] - paused = stream.run(self.input_sets[stream_id], self.output_sets[stream_id]) - - output_ids = self.graph.get_children(stream_id) - for output_id in output_ids: - output = self.outputs[output_id] - input_ids = self.graph.get_children(output_id) - inputs = [self.inputs[input_id] for input_id in input_ids] - if any(input.is_full() for input in inputs): - continue - for input in inputs: - input.extend(output.entities) - output.clear() - for input_id in input_ids: - stream_ids = self.graph.get_children(input_id) - stack.extend(stream_id for stream_id in stream_ids if stream_id not in stacked) - stacked.update(stream_ids) - if paused: - stack.append(stream_id) - stacked.add(stream_id) + loop = asyncio.get_event_loop() + tasks = [stream.run(self.input_sets[stream_id], self.output_sets[stream_id]) + for stream_id, stream in self.streams.items()] + loop.run_until_complete(asyncio.gather(*tasks)) + loop.close() def add_stream(self, stream): stream_id, input_ids, output_ids = self._get_identifiers(stream) @@ -64,18 +45,6 @@ def add_stream(self, stream): self._add_edges(stream_id, input_ids, output_ids) return Node(self, stream_id, input_ids, output_ids) - def _get_head_streams(self): - for stream_id in self.graph.partitions[0]: - if len(self.graph.get_parents(stream_id)) == 0: - yield stream_id - - def _get_streams_with_input(self): - inputs = self.inputs - for stream_id in self.graph.partitions[0]: - input_ids = self.graph.get_parents(stream_id) - if any(len(inputs[input_id]) > 0 for input_id in input_ids): - yield stream_id - def _get_identifiers(self, stream): return str(uuid4())[:8],\ {name: str(uuid4())[:8] for name in stream.ins},\ @@ -84,7 +53,7 @@ def _get_identifiers(self, stream): def _add_vertices(self, stream, stream_id, input_ids, output_ids): self.streams[stream_id] = stream self.input_sets[stream_id] = ConnectionSet(Input(name, self.threshold) for name in stream.ins) - self.output_sets[stream_id] = ConnectionSet(Output(name, self.threshold) for name in stream.outs) + self.output_sets[stream_id] = ConnectionSet(Output(name) for name in stream.outs) self.graph.add_vertex(stream_id, 0) for input, input_id in input_ids.items(): diff --git a/setup.py b/setup.py index 3451cbf..e7575d3 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( name='mimo', - version='1.0.6', + version='1.0.7', author='Liam H. Childs', author_email='liam.h.childs@gmail.com', packages=find_packages(exclude=['tests']), diff --git a/tests/test_connection/test_connection.py b/tests/test_connection/test_connection.py deleted file mode 100644 index 2868d49..0000000 --- a/tests/test_connection/test_connection.py +++ /dev/null @@ -1,16 +0,0 @@ -import unittest - -from mimo.connection.connection import Connection - - -class TestConnection(unittest.TestCase): - def test_is_full(self): - connection = Connection('a', 5) - - self.assertFalse(connection.is_full()) - connection.entities.extend([1, 2, 3, 4, 5]) - self.assertTrue(connection.is_full()) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_connection/test_connection_set.py b/tests/test_connection/test_connection_set.py deleted file mode 100644 index 094b38c..0000000 --- a/tests/test_connection/test_connection_set.py +++ /dev/null @@ -1,20 +0,0 @@ -import unittest - -from mimo.connection.output import Output -from mimo.connection.connection_set import ConnectionSet - - -class TestConnectionSet(unittest.TestCase): - def test_items_and_attributes(self): - a = Output('a') - b = Output('b') - connection_set = ConnectionSet([a, b]) - - self.assertEqual(a, connection_set.a) - self.assertEqual(b, connection_set.b) - self.assertEqual(a, connection_set['a']) - self.assertEqual(b, connection_set['b']) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_connection/test_input.py b/tests/test_connection/test_input.py index 309d5bd..8934afd 100644 --- a/tests/test_connection/test_input.py +++ b/tests/test_connection/test_input.py @@ -1,29 +1,47 @@ +import asyncio import unittest -from mimo.connection.input import Input +from mimo.connection.input import Input, ConnectionClosed class TestInput(unittest.TestCase): + def setUp(self): + self.loop = asyncio.get_event_loop() + + def test_push(self): + connection = Input('a') + + self.loop.run_until_complete(connection.push(0)) + + self.assertIn(0, connection._queue) + + def test_push_closed(self): + connection = Input('a') + connection.close() + + task = self.loop.create_task(connection.push(0)) + + self.assertRaises(ConnectionClosed, self.loop.run_until_complete, task) + def test_peek(self): connection = Input('a') + connection._queue.extend((0, 1, 2)) - self.assertRaises(IndexError, connection.peek) - connection.entities.append(1) - self.assertEqual(1, connection.peek()) + task = self.loop.create_task(connection.peek()) + self.loop.run_until_complete(task) + + self.assertEqual(0, task.result()) + self.assertEqual([0, 1, 2], list(connection._queue)) def test_pop(self): connection = Input('a') + connection._queue.extend((0, 1, 2)) - self.assertRaises(IndexError, connection.pop) - connection.entities.extend((1, 2, 3)) - self.assertEqual(1, connection.pop()) - self.assertEqual([2, 3], list(connection.entities)) - - def test_extend(self): - connection = Input('a', 3) + task = self.loop.create_task(connection.pop()) + self.loop.run_until_complete(task) - self.assertFalse(connection.extend((1, 2, 3, 4))) - self.assertEqual([1, 2, 3, 4], list(connection.entities)) + self.assertEqual(0, task.result()) + self.assertEqual([1, 2], list(connection._queue)) if __name__ == '__main__': diff --git a/tests/test_connection/test_output.py b/tests/test_connection/test_output.py deleted file mode 100644 index 9ca6302..0000000 --- a/tests/test_connection/test_output.py +++ /dev/null @@ -1,19 +0,0 @@ -import unittest - -from mimo.connection.output import Output - - -class TestOutput(unittest.TestCase): - def test_push(self): - connection = Output('a', 3) - - self.assertTrue(connection.push(1)) - self.assertEqual([1], list(connection.entities)) - self.assertTrue(connection.push(2)) - self.assertEqual([1, 2], list(connection.entities)) - self.assertFalse(connection.push(3)) - self.assertEqual([1, 2, 3], list(connection.entities)) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_stream/test_stream.py b/tests/test_stream/test_stream.py index 3e3aceb..b8fa797 100644 --- a/tests/test_stream/test_stream.py +++ b/tests/test_stream/test_stream.py @@ -1,33 +1,20 @@ import unittest from mimo import Stream -from mimo.connection.connection_set import ConnectionSet -from mimo.connection.input import Input -from mimo.connection.output import Output +from mimo.test_helper import TestHelper class TestStream(unittest.TestCase): def test_run(self): stream = Stream(['a'], ['b'], fn=fn) - input = Input('a', 5) - ins = ConnectionSet([input]) - output = Output('b', 5) - outs = ConnectionSet([output]) - input.entities.extend([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]) - - self.assertTrue(stream.run(ins, outs)) - self.assertEqual([2, 4, 6, 8, 10], list(output.entities)) - self.assertTrue(stream.run(ins, outs)) - self.assertEqual([2, 4, 6, 8, 10, 12], list(output.entities)) - output.entities.clear() - self.assertFalse(stream.run(ins, outs)) - self.assertEqual([14, 16, 18, 0], list(output.entities)) - - -def fn(ins, outs, state): - while len(ins.a) > 0: - if not outs.b.push(2 * ins.a.pop()): - return True + helper = TestHelper(stream) + + self.assertEqual({'b': [2, 4, 6, 8, 10, 12, 14, 16, 18, 0]}, helper.run({'a': [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]})) + + +async def fn(ins, outs, state): + async for item in ins.a: + await outs.b.push(2 * item) if __name__ == '__main__': diff --git a/tests/test_test_helper.py b/tests/test_test_helper.py index 338b384..020d745 100644 --- a/tests/test_test_helper.py +++ b/tests/test_test_helper.py @@ -1,3 +1,4 @@ +import asyncio import unittest from mimo import Stream @@ -12,11 +13,20 @@ def test_run(self): self.assertEqual({'b': [2, 4, 6, 8, 10, 12, 14, 16, 18, 0]}, helper.run({'a': [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]})) + def test_run_timeout(self): + stream = Stream(ins=['a'], outs=['b'], fn=will_timeout) + helper = TestHelper(stream, timeout=1) -def fn(ins, outs, state): - while len(ins.a) > 0: - if not outs.b.push(2 * ins.a.pop()): - return True + self.assertRaises(asyncio.TimeoutError, helper.run) + + +async def fn(ins, outs, state): + async for item in ins.a: + await outs.b.push(2 * item) + + +async def will_timeout(ins, outs, state): + await asyncio.sleep(10) if __name__ == '__main__': diff --git a/tests/test_workflow/test_workflow.py b/tests/test_workflow/test_workflow.py index 14216ab..978df95 100644 --- a/tests/test_workflow/test_workflow.py +++ b/tests/test_workflow/test_workflow.py @@ -7,15 +7,14 @@ class TestWorkflow(unittest.TestCase): def test_run(self): workflow = Workflow() stream1 = Stream(outs=['entity'], fn=iterator_stream) - stream2 = Stream(ins=['entity'], fn=collect_stream) + stream2 = Stream(ins=['entity'], fn=collect_stream, state=[]) step1 = workflow.add_stream(stream1) step2 = workflow.add_stream(stream2) step1.pipe(step2) workflow.run() - self.assertIn('collection', stream2.state) - self.assertEqual(list(range(100)), stream2.state['collection']) + self.assertEqual(list(range(100)), stream2.state) def test_add_stream(self): stream = Stream(['a'], ['b']) @@ -30,49 +29,16 @@ def test_add_stream(self): self.assertEqual(set(workflow.inputs), set(node.input_ids.values())) self.assertEqual(set(workflow.outputs), set(node.output_ids.values())) - def test_get_head_streams(self): - workflow = Workflow() - stream1 = Stream(outs=['a']) - stream2 = Stream(outs=['b']) - stream3 = Stream(['c', 'd'], ['e']) - step1 = workflow.add_stream(stream1) - step2 = workflow.add_stream(stream2) - step3 = workflow.add_stream(stream3) - step1.pipe(step3, input='c') - step2.pipe(step3, input='d') - - heads = set(workflow._get_head_streams()) - - self.assertEqual({step1.stream_id, step2.stream_id}, heads) - - def test_get_streams_with_input(self): - workflow = Workflow() - stream1 = Stream(['a']) - stream2 = Stream(['b']) - step1 = workflow.add_stream(stream1) - workflow.add_stream(stream2) - - workflow.input_sets[step1.stream_id].a.entities.extend([1, 2, 3]) - nodes_with_input = set(workflow._get_streams_with_input()) - - self.assertEqual({step1.stream_id}, nodes_with_input) - -def iterator_stream(ins, outs, state): - if 'iterator' not in state: - state['iterator'] = iter(range(100)) - iterator = state['iterator'] - for item in iterator: - if not outs.entity.push(item): - return True +async def iterator_stream(ins, outs, state): + for item in iter(range(100)): + await outs.entity.push(item) + outs.entity.close() -def collect_stream(ins, outs, state): - if 'collection' not in state: - state['collection'] = [] - collection = state['collection'] - while len(ins.entity) > 0: - collection.append(ins.entity.pop()) +async def collect_stream(ins, outs, state): + async for item in ins.entity: + state.append(item) if __name__ == '__main__':