Skip to content

Commit

Permalink
converted library to asyncio
Browse files Browse the repository at this point in the history
  • Loading branch information
Liam Childs committed Sep 26, 2016
1 parent dd96cdf commit 39fad83
Show file tree
Hide file tree
Showing 15 changed files with 171 additions and 227 deletions.
18 changes: 0 additions & 18 deletions mimo/connection/connection.py

This file was deleted.

78 changes: 69 additions & 9 deletions mimo/connection/input.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,73 @@
from .connection import Connection
import asyncio


class Input(Connection):
def peek(self):
return self.entities[0]
class ConnectionClosed(Exception):
pass

def pop(self):
return self.entities.popleft()

def extend(self, entities):
self.entities.extend(entities)
return len(self.entities) < self.threshold
class Input(asyncio.Queue):
def __init__(self, name, maxsize=0, loop=None):
super().__init__(maxsize, loop=loop)
self.name = name
self._closed = False

async def __aiter__(self):
return self

async def __anext__(self):
try:
return await self.pop()
except ConnectionClosed:
raise StopAsyncIteration

async def push(self, item):
if self._closed:
raise ConnectionClosed
return await self.put(item)

async def peek(self):
while self.empty():
if self._closed and len(self._putters) == 0:
raise ConnectionClosed

getter = self._loop.create_future()
self._getters.append(getter)
try:
await getter
except:
getter.cancel()
if not self.empty() and not getter.cancelled():
self._wakeup_next(self._getters)
raise
return self._queue[0]

async def pop(self):
while self.empty():
if self._closed and len(self._putters) == 0:
raise ConnectionClosed

getter = self._loop.create_future()
self._getters.append(getter)
try:
await getter
except:
getter.cancel()
if not self.empty() and not getter.cancelled():
self._wakeup_next(self._getters)
raise
return self.get_nowait()

def close(self):
self._closed = True
self._maxsize = 0
while self._putters:
putter = self._putters.popleft()
if not putter.done():
putter.set_result(None)
while self._getters:
getter = self._getters.popleft()
if not getter.done():
if self.empty():
getter.set_exception(ConnectionClosed)
else:
getter.set_result(None)
23 changes: 13 additions & 10 deletions mimo/connection/output.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from .connection import Connection
class Output:
def __init__(self, name):
self.name = name
self._connections = []

async def push(self, entity):
for connection in self._connections:
await connection.push(entity)

class Output(Connection):
def push(self, entity):
"""
Add an entity to the end of the connection. Return if connection can still be pushed to.
:param entity:
:return:
"""
self.entities.append(entity)
return len(self.entities) < self.threshold
def pipe(self, connection):
self._connections.append(connection)

def close(self):
for connection in self._connections:
connection.close()
5 changes: 2 additions & 3 deletions mimo/stream/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class Stream:
IN = []
OUT = []

def __init__(self, ins=None, outs=None, fn=None, name=None):
def __init__(self, ins=None, outs=None, fn=None, name=None, state=None):
"""
Initialise a stream. Streams can be sub-classed to alter the behaviour or customised directly.
If sub-classing a stream, the class members `IN` and `OUT` define the names of the input and output entities.
Expand All @@ -24,12 +24,11 @@ def __init__(self, ins=None, outs=None, fn=None, name=None):
:param name: name of the stream
:param fn: run function
"""
self.state = {}

self.ins = self.IN if ins is None else ins
self.outs = self.OUT if outs is None else outs
self.fn = fn
self.name = type(self).__name__ if name is None else name
self.state = {} if state is None else state

def run(self, ins, outs):
"""
Expand Down
24 changes: 14 additions & 10 deletions mimo/test_helper.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
import asyncio

from .connection.input import Input
from .connection.output import Output
from .connection.connection_set import ConnectionSet


class TestHelper:
def __init__(self, stream):
def __init__(self, stream, timeout=5):
ins = [Input(in_) for in_ in stream.ins]
outs = [Output(out) for out in stream.outs]
self.sinks = [Input(out) for out in stream.outs]
for out, sink in zip(outs, self.sinks):
out.pipe(sink)
self.ins = ConnectionSet(ins)
self.outs = ConnectionSet(outs)
self.stream = stream
self._timeout = timeout
self._loop = asyncio.get_event_loop()

def run(self, ins):
def run(self, ins={}):
for key, value in ins.items():
self.ins[key].clear()
self.ins[key].extend(value)
outs = {out.name: [] for out in self.outs}
while self.stream.run(self.ins, self.outs):
for out in self.outs:
outs[out.name].extend(out.entities)
out.clear()
return outs
self.ins[key]._queue.extend(value)
self.ins[key].close()
task = self._loop.create_task(self.stream.run(self.ins, self.outs))
self._loop.run_until_complete(asyncio.wait_for(task, self._timeout, loop=self._loop))
return {sink.name: list(sink._queue) for sink in self.sinks}
1 change: 1 addition & 0 deletions mimo/workflow/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@ def pipe(self, step, output=None, input=None):
input_id = step.input_ids[input]

self.workflow.graph.add_edge(output_id, input_id)
self.workflow.outputs[output_id].pipe(self.workflow.inputs[input_id])
return step
47 changes: 8 additions & 39 deletions mimo/workflow/workflow.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import deque
import asyncio

from uuid import uuid4
from lhc.graph import NPartiteGraph
from mimo.connection.input import Input
Expand Down Expand Up @@ -32,50 +33,18 @@ def __str__(self):
return '\n'.join(res)

def run(self):
stacked = set(self._get_head_streams()) | set(self._get_streams_with_input())
stack = deque(stacked)
while len(stack) > 0:
stream_id = stack.popleft()
stacked.remove(stream_id)
stream = self.streams[stream_id]
paused = stream.run(self.input_sets[stream_id], self.output_sets[stream_id])

output_ids = self.graph.get_children(stream_id)
for output_id in output_ids:
output = self.outputs[output_id]
input_ids = self.graph.get_children(output_id)
inputs = [self.inputs[input_id] for input_id in input_ids]
if any(input.is_full() for input in inputs):
continue
for input in inputs:
input.extend(output.entities)
output.clear()
for input_id in input_ids:
stream_ids = self.graph.get_children(input_id)
stack.extend(stream_id for stream_id in stream_ids if stream_id not in stacked)
stacked.update(stream_ids)
if paused:
stack.append(stream_id)
stacked.add(stream_id)
loop = asyncio.get_event_loop()
tasks = [stream.run(self.input_sets[stream_id], self.output_sets[stream_id])
for stream_id, stream in self.streams.items()]
loop.run_until_complete(asyncio.gather(*tasks))
loop.close()

def add_stream(self, stream):
stream_id, input_ids, output_ids = self._get_identifiers(stream)
self._add_vertices(stream, stream_id, input_ids, output_ids)
self._add_edges(stream_id, input_ids, output_ids)
return Node(self, stream_id, input_ids, output_ids)

def _get_head_streams(self):
for stream_id in self.graph.partitions[0]:
if len(self.graph.get_parents(stream_id)) == 0:
yield stream_id

def _get_streams_with_input(self):
inputs = self.inputs
for stream_id in self.graph.partitions[0]:
input_ids = self.graph.get_parents(stream_id)
if any(len(inputs[input_id]) > 0 for input_id in input_ids):
yield stream_id

def _get_identifiers(self, stream):
return str(uuid4())[:8],\
{name: str(uuid4())[:8] for name in stream.ins},\
Expand All @@ -84,7 +53,7 @@ def _get_identifiers(self, stream):
def _add_vertices(self, stream, stream_id, input_ids, output_ids):
self.streams[stream_id] = stream
self.input_sets[stream_id] = ConnectionSet(Input(name, self.threshold) for name in stream.ins)
self.output_sets[stream_id] = ConnectionSet(Output(name, self.threshold) for name in stream.outs)
self.output_sets[stream_id] = ConnectionSet(Output(name) for name in stream.outs)

self.graph.add_vertex(stream_id, 0)
for input, input_id in input_ids.items():
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

setup(
name='mimo',
version='1.0.6',
version='1.0.7',
author='Liam H. Childs',
author_email='[email protected]',
packages=find_packages(exclude=['tests']),
Expand Down
16 changes: 0 additions & 16 deletions tests/test_connection/test_connection.py

This file was deleted.

20 changes: 0 additions & 20 deletions tests/test_connection/test_connection_set.py

This file was deleted.

44 changes: 31 additions & 13 deletions tests/test_connection/test_input.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,47 @@
import asyncio
import unittest

from mimo.connection.input import Input
from mimo.connection.input import Input, ConnectionClosed


class TestInput(unittest.TestCase):
def setUp(self):
self.loop = asyncio.get_event_loop()

def test_push(self):
connection = Input('a')

self.loop.run_until_complete(connection.push(0))

self.assertIn(0, connection._queue)

def test_push_closed(self):
connection = Input('a')
connection.close()

task = self.loop.create_task(connection.push(0))

self.assertRaises(ConnectionClosed, self.loop.run_until_complete, task)

def test_peek(self):
connection = Input('a')
connection._queue.extend((0, 1, 2))

self.assertRaises(IndexError, connection.peek)
connection.entities.append(1)
self.assertEqual(1, connection.peek())
task = self.loop.create_task(connection.peek())
self.loop.run_until_complete(task)

self.assertEqual(0, task.result())
self.assertEqual([0, 1, 2], list(connection._queue))

def test_pop(self):
connection = Input('a')
connection._queue.extend((0, 1, 2))

self.assertRaises(IndexError, connection.pop)
connection.entities.extend((1, 2, 3))
self.assertEqual(1, connection.pop())
self.assertEqual([2, 3], list(connection.entities))

def test_extend(self):
connection = Input('a', 3)
task = self.loop.create_task(connection.pop())
self.loop.run_until_complete(task)

self.assertFalse(connection.extend((1, 2, 3, 4)))
self.assertEqual([1, 2, 3, 4], list(connection.entities))
self.assertEqual(0, task.result())
self.assertEqual([1, 2], list(connection._queue))


if __name__ == '__main__':
Expand Down
19 changes: 0 additions & 19 deletions tests/test_connection/test_output.py

This file was deleted.

Loading

0 comments on commit 39fad83

Please sign in to comment.