diff --git a/examples/join.py b/examples/join.py new file mode 100644 index 0000000..2e005b0 --- /dev/null +++ b/examples/join.py @@ -0,0 +1,50 @@ +from mimo import Workflow, Stream + + +def main(): + workflow = Workflow(10) + step1 = workflow.add_stream(Stream(outs=['a'], fn=stream1)) + step2 = workflow.add_stream(Stream(outs=['b'], fn=stream2)) + step3 = workflow.add_stream(Stream(['c', 'd'], fn=stream3)) + + step1.pipe(step3, input='c') + step2.pipe(step3, input='d') + + print(str(workflow)) + workflow.run() + + +def stream1(ins, outs, state): + """ + Generates integers from 0 to 99. + """ + if 'iterator' not in state: + state['iterator'] = iter(range(100)) + iterator = state['iterator'] + for item in iterator: + if not outs.a.push(item): + return True + + +def stream2(ins, outs, state): + """ + Generates integers from 99 to 0. + """ + if 'iterator' not in state: + state['iterator'] = iter(range(99, -1, -1)) + iterator = state['iterator'] + for item in iterator: + if not outs.b.push(item): + return True + + +def stream3(ins, outs, state): + """ + Add each entity from input 'c' to the matching entity from input 'd' and print the sum to stdout + """ + while len(ins.c) > 0 and len(ins.d) > 0: + sys.stdout.write('{}\n'.format(ins.c.pop() + ins.d.pop())) + +if __name__ == '__main__': + import sys + sys.exit(main()) diff --git a/examples/linear.py b/examples/linear.py new file mode 100644 index 0000000..a8733ac --- /dev/null +++ b/examples/linear.py @@ -0,0 +1,48 @@ +from mimo import Workflow, Stream + + +def main(): + workflow = Workflow(10) + step1 = workflow.add_stream(Stream(outs=['a'], fn=stream1)) + step2 = workflow.add_stream(Stream(['b'], ['c'], fn=stream2)) + step3 = workflow.add_stream(Stream(['d'], fn=stream3)) + + step1.pipe(step2).pipe(step3) + + print(str(workflow)) + workflow.run() + + +def stream1(ins, outs, state): + """ + Generates integers from 0 to 99. 
+ """ + if 'iterator' not in state: + state['iterator'] = iter(range(100)) + iterator = state['iterator'] + for item in iterator: + if not outs.a.push(item): + return True + + +def stream2(ins, outs, state): + """ + Multiplies the integers by 2. + """ + while len(ins.b) > 0: + item = ins.b.pop() + if not outs.c.push(item * 2): + break + return len(ins.b) > 0 + + +def stream3(ins, outs, state): + """ + Print incoming entities to stdout + """ + while len(ins.d) > 0: + print(ins.d.pop()) + +if __name__ == '__main__': + import sys + sys.exit(main()) diff --git a/examples/multi_output.py b/examples/multi_output.py new file mode 100644 index 0000000..bdc2edc --- /dev/null +++ b/examples/multi_output.py @@ -0,0 +1,46 @@ +from mimo import Workflow, Stream + + +def main(): + workflow = Workflow(10) + step1 = workflow.add_stream(Stream(outs=['a'], fn=stream1)) + step2 = workflow.add_stream(Stream(['b'], fn=stream2)) + step3 = workflow.add_stream(Stream(['c'], fn=stream3)) + + step1.pipe(step2) + step1.pipe(step3) + + print(str(workflow)) + workflow.run() + + +def stream1(ins, outs, state): + """ + Generates integers from 0 to 99. 
+ """ + if 'iterator' not in state: + state['iterator'] = iter(range(100)) + iterator = state['iterator'] + for item in iterator: + if not outs.a.push(item): + return True + + +def stream2(ins, outs, state): + """ + Multiply incoming entities by 2 and print to stdout + """ + while len(ins.b) > 0: + sys.stdout.write('{}\n'.format(2 * ins.b.pop())) + + +def stream3(ins, outs, state): + """ + Divide incoming entities by 10 and print to stdout + """ + while len(ins.c) > 0: + sys.stdout.write('{}\n'.format(ins.c.pop() / 10)) + +if __name__ == '__main__': + import sys + sys.exit(main()) diff --git a/examples/split_output.py b/examples/split_output.py new file mode 100644 index 0000000..bdc2edc --- /dev/null +++ b/examples/split_output.py @@ -0,0 +1,46 @@ +from mimo import Workflow, Stream + + +def main(): + workflow = Workflow(10) + step1 = workflow.add_stream(Stream(outs=['a'], fn=stream1)) + step2 = workflow.add_stream(Stream(['b'], fn=stream2)) + step3 = workflow.add_stream(Stream(['c'], fn=stream3)) + + step1.pipe(step2) + step1.pipe(step3) + + print(str(workflow)) + workflow.run() + + +def stream1(ins, outs, state): + """ + Generates integers from 0 to 99. 
+ """ + if 'iterator' not in state: + state['iterator'] = iter(range(100)) + iterator = state['iterator'] + for item in iterator: + if not outs.a.push(item): + return True + + +def stream2(ins, outs, state): + """ + Multiply incoming entities by 2 and print to stdout + """ + while len(ins.b) > 0: + sys.stdout.write('{}\n'.format(2 * ins.b.pop())) + + +def stream3(ins, outs, state): + """ + Divide incoming entities by 10 and print to stdout + """ + while len(ins.c) > 0: + sys.stdout.write('{}\n'.format(ins.c.pop() / 10)) + +if __name__ == '__main__': + import sys + sys.exit(main()) diff --git a/mimo/__init__.py b/mimo/__init__.py index 1fde943..283777b 100644 --- a/mimo/__init__.py +++ b/mimo/__init__.py @@ -1 +1,2 @@ from .stream import Stream +from .workflow import Workflow diff --git a/mimo/connection/__init__.py b/mimo/connection/__init__.py index acaec5a..e69de29 100644 --- a/mimo/connection/__init__.py +++ b/mimo/connection/__init__.py @@ -1,2 +0,0 @@ -from .connection import Connection -from .connection_set import ConnectionSet diff --git a/mimo/connection/connection.py b/mimo/connection/connection.py index 3496c0f..5cb6479 100644 --- a/mimo/connection/connection.py +++ b/mimo/connection/connection.py @@ -4,41 +4,12 @@ class Connection: def __init__(self, name, threshold=10): self.entities = deque() - self.connections = set() self.name = name self.threshold = threshold - def peek(self): - return self.entities[0] - - def pop(self): - return self.entities.popleft() - - def push(self, entity): - """ - Add an entity to the end of the connection. Return if connection can still be pushed to. 
- :param entity: - :return: - """ - self.entities.append(entity) - return len(self.entities) < self.threshold - - def extend(self, entities): - self.entities.extend(entities) - return len(self.entities) < self.threshold + def __len__(self): + return len(self.entities) def is_full(self): return len(self.entities) >= self.threshold - - def join(self, connection): - self.connections.add(connection) - - def drain(self): - entities = self.entities - if len(entities) == 0 or any(connection.is_full() for connection in self.connections): - return False - for connection in self.connections: - connection.extend(entities) - entities.clear() - return True diff --git a/mimo/connection/connection_set.py b/mimo/connection/connection_set.py index 85ec5cc..78b3c42 100644 --- a/mimo/connection/connection_set.py +++ b/mimo/connection/connection_set.py @@ -16,14 +16,3 @@ def __getattr__(self, key): def __getitem__(self, key): return self.connections[key] - - def drain(self): - """ - Drain all connections and return the streams that received updates. - :return: Set of updated streams - """ - drained = set() - for connection in self.connections.values(): - if connection.drain(): - drained.add(connection.name) - return drained diff --git a/mimo/connection/input.py b/mimo/connection/input.py new file mode 100644 index 0000000..02d1a64 --- /dev/null +++ b/mimo/connection/input.py @@ -0,0 +1,13 @@ +from .connection import Connection + + +class Input(Connection): + def peek(self): + return self.entities[0] + + def pop(self): + return self.entities.popleft() + + def extend(self, entities): + self.entities.extend(entities) + return len(self.entities) < self.threshold diff --git a/mimo/connection/output.py b/mimo/connection/output.py new file mode 100644 index 0000000..40a0cf4 --- /dev/null +++ b/mimo/connection/output.py @@ -0,0 +1,12 @@ +from .connection import Connection + + +class Output(Connection): + def push(self, entity): + """ + Add an entity to the end of the connection. 
Return if connection can still be pushed to. + :param entity: + :return: + """ + self.entities.append(entity) + return len(self.entities) < self.threshold diff --git a/mimo/stream/stream.py b/mimo/stream/stream.py index 2770a29..5a1bead 100644 --- a/mimo/stream/stream.py +++ b/mimo/stream/stream.py @@ -1,84 +1,46 @@ -from mimo.connection import Connection, ConnectionSet +from mimo.connection.connection_set import ConnectionSet class Stream: + __slots__ = ('state', 'ins', 'outs', 'name', 'fn') + IN = [] OUT = [] - def __init__(self, ins=None, outs=None, name=None): - self.paused = False + def __init__(self, ins=None, outs=None, fn=None, name=None): + """ + Initialise a stream. Streams can be sub-classed to alter the behaviour or customised directly. + If sub-classing a stream, the class members `IN` and `OUT` define the names of the input and output entities. + Overriding the `run` function will determine what the stream does and the name of the class determines the name + of the stream. + If creating a stream directly, the parameters `ins` and `outs` define the names of the input and output + entities. The `fn` parameter is a function that will determine what the stream does. This function takes a set + of inputs, a set of outputs and the state of the stream as a dictionary. The `name` parameter determines the + name of the stream. 
+ + :param ins: names of input entities + :param outs: names of output entities + :param name: name of the stream + :param fn: run function + """ + self.state = {} - ins = self.IN if ins is None else ins - outs = self.OUT if outs is None else outs - self.ins = ConnectionSet(Connection(name) for name in ins) - self.outs = ConnectionSet(Connection(name) for name in outs) - self.children = {out: set() for out in outs} + self.ins = self.IN if ins is None else ins + self.outs = self.OUT if outs is None else outs + self.fn = fn self.name = type(self).__name__ if name is None else name def run(self, ins, outs): """ - The main method to over-ride when implementing custom streams. + The main method to override when implementing custom streams. This can also be overridden by providing the + 'fn' parameter when creating a new stream. + :param ins: connection set of input connections :type ins: ConnectionSet :param outs: connection set of output connections :type outs: ConnectionSet - :return: True if stream is paused - """ - raise NotImplementedError - - def activate(self): - """ - Run a step and propogate the output entities to any connected child streams. - :return: + :return: True if the stream did not finish running (e.g. it was suspended because an output was full) + :rtype: bool """ - run = self.run - ins = self.ins - outs = self.outs - children = self.children - - paused = True - while paused: - paused = run(ins, outs) - updated_streams = set() - for streams in children.values(): - updated_streams.update(stream for stream in streams if stream.paused) - for out in outs.drain(): - updated_streams.update(children[out]) - for updated_stream in updated_streams: - updated_stream.activate() - self.paused = paused - - def pipe(self, stream, output=None, input=None): - """ - Pipe the output of one stream to the input of another. If there are more than one outputs or inputs, the - specific output/input must be specified. 
- :param stream: stream to connect to - :param output: name of the output connection (default: None) - :param input: name of the input connection (default: None) - :return: stream connected to - """ - if len(self.outs) == 0: - raise ValueError('{} has no output to pipe from'.format(self.name)) - elif len(stream.ins) == 0: - raise ValueError('{} has no input to pipe to'.format(stream.name)) - - if output is None: - if len(self.outs) == 1: - from_connection = next(iter(self.outs)) - else: - raise ValueError('{} has multiple output and none chosen to pipe from'.format(self.name)) - else: - from_connection = self.outs[output] - - if input is None: - if len(stream.ins) == 1: - to_connection = next(iter(stream.ins)) - else: - raise ValueError('{} has multiple output and none chosen to pipe from'.format(stream.name)) - else: - to_connection = stream.ins[input] - - from_connection.join(to_connection) - self.children[from_connection.name].add(stream) - return stream + return self.fn(ins, outs, self.state) diff --git a/mimo/workflow/__init__.py b/mimo/workflow/__init__.py new file mode 100644 index 0000000..74af3ea --- /dev/null +++ b/mimo/workflow/__init__.py @@ -0,0 +1 @@ +from .workflow import Workflow diff --git a/mimo/workflow/node.py b/mimo/workflow/node.py new file mode 100644 index 0000000..2c97193 --- /dev/null +++ b/mimo/workflow/node.py @@ -0,0 +1,43 @@ +class Node: + __slots__ = ('workflow', 'stream_id', 'input_ids', 'output_ids') + + def __init__(self, workflow, stream_id, input_ids, output_ids): + self.workflow = workflow + self.stream_id = stream_id + self.input_ids = input_ids + self.output_ids = output_ids + + def pipe(self, step, output=None, input=None): + """ + Pipe an output of this node's stream to an input of another node's stream. If either side has more than one + output/input, the specific output/input must be specified. 
+ :param step: node of the stream to connect to + :param output: name of the output connection (default: None) + :param input: name of the input connection (default: None) + :return: the node piped to (`step`), so that calls can be chained + """ + if len(self.output_ids) == 0: + raise ValueError('{} has no output to pipe from'.format(self.workflow.streams[self.stream_id].name)) + elif len(step.input_ids) == 0: + raise ValueError('{} has no input to pipe to'.format(self.workflow.streams[step.stream_id].name)) + + if output is None: + if len(self.output_ids) == 1: + output_id = next(iter(self.output_ids.values())) + else: + msg = '{} has multiple output and none chosen to pipe from' + raise ValueError(msg.format(self.workflow.streams[self.stream_id].name)) + else: + output_id = self.output_ids[output] + + if input is None: + if len(step.input_ids) == 1: + input_id = next(iter(step.input_ids.values())) + else: + msg = '{} has multiple input and none chosen to pipe to' + raise ValueError(msg.format(self.workflow.streams[step.stream_id].name)) + else: + input_id = step.input_ids[input] + + self.workflow.graph.add_edge(output_id, input_id) + return step diff --git a/mimo/workflow/workflow.py b/mimo/workflow/workflow.py new file mode 100644 index 0000000..8a6deb9 --- /dev/null +++ b/mimo/workflow/workflow.py @@ -0,0 +1,101 @@ +from collections import deque +from uuid import uuid4 +from lhc.graph import NPartiteGraph +from mimo.connection.input import Input +from mimo.connection.output import Output +from mimo.connection.connection_set import ConnectionSet +from .node import Node + + +class Workflow: + def __init__(self, threshold=100): + self.graph = NPartiteGraph(n=3) + + self.streams = {} + self.input_sets = {} + self.output_sets = {} + self.inputs = {} + self.outputs = {} + + self.threshold = threshold + + def __str__(self): + graph = self.graph + res = ['digraph {} {{'.format(graph.name)] + partitions = [self.streams, self.inputs, self.outputs] + for i, partition, shape in zip(range(3), graph.partitions, graph.shapes): + for vertex in 
partition: + res.append(' "{}" [shape={},label="{}"];'.format(vertex, shape, partitions[i][vertex].name)) + for fr, to in sorted(graph.graph.es): + res.append(' "{}" -> "{}";'.format(fr, to)) + res.append('}') + return '\n'.join(res) + + def run(self): + stacked = set(self._get_head_streams()) | set(self._get_streams_with_input()) + stack = deque(stacked) + while len(stack) > 0: + stream_id = stack.popleft() + stacked.remove(stream_id) + stream = self.streams[stream_id] + paused = stream.run(self.input_sets[stream_id], self.output_sets[stream_id]) + + output_ids = self.graph.get_children(stream_id) + for output_id in output_ids: + output = self.outputs[output_id] + input_ids = self.graph.get_children(output_id) + inputs = [self.inputs[input_id] for input_id in input_ids] + if any(input.is_full() for input in inputs): + continue + for input in inputs: + input.extend(output.entities) + output.entities.clear() + for input_id in input_ids: + stream_ids = self.graph.get_children(input_id) + stack.extend(stream_id for stream_id in stream_ids if stream_id not in stacked) + stacked.update(stream_ids) + if paused: + stack.append(stream_id) + stacked.add(stream_id) + + def add_stream(self, stream): + stream_id, input_ids, output_ids = self._get_identifiers(stream) + self._add_vertices(stream, stream_id, input_ids, output_ids) + self._add_edges(stream_id, input_ids, output_ids) + return Node(self, stream_id, input_ids, output_ids) + + def _get_head_streams(self): + for stream_id in self.graph.partitions[0]: + if len(self.graph.get_parents(stream_id)) == 0: + yield stream_id + + def _get_streams_with_input(self): + inputs = self.inputs + for stream_id in self.graph.partitions[0]: + input_ids = self.graph.get_parents(stream_id) + if any(len(inputs[input_id]) > 0 for input_id in input_ids): + yield stream_id + + def _get_identifiers(self, stream): + return str(uuid4())[:8],\ + {name: str(uuid4())[:8] for name in stream.ins},\ + {name: str(uuid4())[:8] for name in stream.outs} + 
+ def _add_vertices(self, stream, stream_id, input_ids, output_ids): + self.streams[stream_id] = stream + self.input_sets[stream_id] = ConnectionSet(Input(name, self.threshold) for name in stream.ins) + self.output_sets[stream_id] = ConnectionSet(Output(name, self.threshold) for name in stream.outs) + + self.graph.add_vertex(stream_id, 0) + for input, input_id in input_ids.items(): + self.inputs[input_id] = self.input_sets[stream_id][input] + self.graph.add_vertex(input_id, 1) + for output, output_id in output_ids.items(): + self.outputs[output_id] = self.output_sets[stream_id][output] + self.graph.add_vertex(output_id, 2) + + def _add_edges(self, stream_id, input_ids, output_ids): + for input, in_id in input_ids.items(): + self.graph.add_edge(in_id, stream_id) + for output, out_id in output_ids.items(): + self.graph.add_edge(stream_id, out_id) diff --git a/tests/test_connection/test_connection.py b/tests/test_connection/test_connection.py index 441d266..2868d49 100644 --- a/tests/test_connection/test_connection.py +++ b/tests/test_connection/test_connection.py @@ -1,59 +1,15 @@ import unittest -from mimo.connection import Connection -from mimo.stream import Stream +from mimo.connection.connection import Connection class TestConnection(unittest.TestCase): - def test_peek(self): - connection = Connection('a') + def test_is_full(self): + connection = Connection('a', 5) - self.assertRaises(IndexError, connection.peek) - connection.entities.append(1) - self.assertEqual(1, connection.peek()) - - def test_pop(self): - connection = Connection('a') - - self.assertRaises(IndexError, connection.pop) - connection.entities.extend((1, 2, 3)) - self.assertEqual(1, connection.pop()) - self.assertEqual([2, 3], list(connection.entities)) - - def test_push(self): - connection = Connection('a', 3) - - self.assertTrue(connection.push(1)) - self.assertEqual([1], list(connection.entities)) - self.assertTrue(connection.push(2)) - self.assertEqual([1, 2], list(connection.entities)) - 
self.assertFalse(connection.push(3)) - self.assertEqual([1, 2, 3], list(connection.entities)) - - def test_extend(self): - connection = Connection('a', 3) - - self.assertFalse(connection.extend((1, 2, 3, 4))) - self.assertEqual([1, 2, 3, 4], list(connection.entities)) - - def test_connect_to_input(self): - connection = Connection('a') - stream = Stream(['input'], ['output']) - - connection.join(stream, 'input') - - self.assertEqual(stream, next(iter(connection.streams))) - self.assertEqual(stream.ins['input'], next(iter(connection.connections))) - - def test_drain(self): - connection = Connection('a') - stream = Stream(['input', 'output']) - - connection.join(stream, 'input') + self.assertFalse(connection.is_full()) connection.entities.extend([1, 2, 3, 4, 5]) - connection.drain() - - self.assertEqual([1, 2, 3, 4, 5], list(stream.ins.input.entities)) + self.assertTrue(connection.is_full()) if __name__ == '__main__': diff --git a/tests/test_connection/test_connection_set.py b/tests/test_connection/test_connection_set.py index e4dfcd2..094b38c 100644 --- a/tests/test_connection/test_connection_set.py +++ b/tests/test_connection/test_connection_set.py @@ -1,29 +1,19 @@ import unittest -from mimo.connection import Connection, ConnectionSet +from mimo.connection.output import Output +from mimo.connection.connection_set import ConnectionSet class TestConnectionSet(unittest.TestCase): def test_items_and_attributes(self): - a = Connection('a') - b = Connection('b') - connections = ConnectionSet([a, b]) - - self.assertEqual(a, connections.a) - self.assertEqual(b, connections.b) - self.assertEqual(a, connections['a']) - self.assertEqual(b, connections['b']) - - def test_drain(self): - connection = Connection('a') - connection.entities.extend((1, 2, 3, 4)) - - self.assertTrue(connection.drain()) - - def test_drain_empty(self): - connection = Connection('a') - - self.assertFalse(connection.drain()) + a = Output('a') + b = Output('b') + connection_set = ConnectionSet([a, b]) + + 
self.assertEqual(a, connection_set.a) + self.assertEqual(b, connection_set.b) + self.assertEqual(a, connection_set['a']) + self.assertEqual(b, connection_set['b']) if __name__ == '__main__': diff --git a/tests/test_connection/test_input.py b/tests/test_connection/test_input.py new file mode 100644 index 0000000..309d5bd --- /dev/null +++ b/tests/test_connection/test_input.py @@ -0,0 +1,30 @@ +import unittest + +from mimo.connection.input import Input + + +class TestInput(unittest.TestCase): + def test_peek(self): + connection = Input('a') + + self.assertRaises(IndexError, connection.peek) + connection.entities.append(1) + self.assertEqual(1, connection.peek()) + + def test_pop(self): + connection = Input('a') + + self.assertRaises(IndexError, connection.pop) + connection.entities.extend((1, 2, 3)) + self.assertEqual(1, connection.pop()) + self.assertEqual([2, 3], list(connection.entities)) + + def test_extend(self): + connection = Input('a', 3) + + self.assertFalse(connection.extend((1, 2, 3, 4))) + self.assertEqual([1, 2, 3, 4], list(connection.entities)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_connection/test_output.py b/tests/test_connection/test_output.py new file mode 100644 index 0000000..902e52c --- /dev/null +++ b/tests/test_connection/test_output.py @@ -0,0 +1,20 @@ +import unittest + +from mimo import Stream +from mimo.connection.output import Output + + +class TestOutput(unittest.TestCase): + def test_push(self): + connection = Output('a', 3) + + self.assertTrue(connection.push(1)) + self.assertEqual([1], list(connection.entities)) + self.assertTrue(connection.push(2)) + self.assertEqual([1, 2], list(connection.entities)) + self.assertFalse(connection.push(3)) + self.assertEqual([1, 2, 3], list(connection.entities)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_stream/test_stream.py b/tests/test_stream/test_stream.py index c0737f1..263782a 100644 --- a/tests/test_stream/test_stream.py +++ 
b/tests/test_stream/test_stream.py @@ -1,17 +1,33 @@ import unittest from mimo import Stream +from mimo.connection.connection_set import ConnectionSet +from mimo.connection.input import Input +from mimo.connection.output import Output class TestStream(unittest.TestCase): - def test_pipe(self): - stream1 = Stream(['a'], ['b']) - stream2 = Stream(['c'], ['d']) + def test_run(self): + stream = Stream(['a'], ['b'], fn=fn) + input = Input('a', 5) + ins = ConnectionSet([input]) + output = Output('b', 5) + outs = ConnectionSet([output]) + input.entities.extend([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]) - stream1.pipe(stream2) + stream.run(ins, outs) + self.assertEqual([2, 4, 6, 8, 10], list(output.entities)) + stream.run(ins, outs) + self.assertEqual([2, 4, 6, 8, 10, 12], list(output.entities)) + output.entities.clear() + stream.run(ins, outs) + self.assertEqual([14, 16, 18, 0], list(output.entities)) - self.assertEqual({stream2}, stream1.children['b']) - self.assertEqual({stream2.ins['c']}, stream1.outs['b'].connections) + +def fn(ins, outs, state): + while len(ins.a) > 0: + if not outs.b.push(2 * ins.a.pop()): + return True if __name__ == '__main__': diff --git a/tests/test_workflow/__init__.py b/tests/test_workflow/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_workflow/test_node.py b/tests/test_workflow/test_node.py new file mode 100644 index 0000000..1d9ed38 --- /dev/null +++ b/tests/test_workflow/test_node.py @@ -0,0 +1,19 @@ +import unittest + +from mimo import Workflow, Stream + + +class TestStep(unittest.TestCase): + def test_pipe(self): + workflow = Workflow() + step1 = workflow.add_stream(Stream(['a'], ['b'])) + step2 = workflow.add_stream(Stream(['c'], ['d'])) + + step1.pipe(step2) + + self.assertEqual(2, len(workflow.streams)) + self.assertIn(step2.input_ids['c'], workflow.graph.graph.adjacency[step1.output_ids['b']].children) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_workflow/test_workflow.py 
b/tests/test_workflow/test_workflow.py new file mode 100644 index 0000000..14216ab --- /dev/null +++ b/tests/test_workflow/test_workflow.py @@ -0,0 +1,79 @@ +import unittest + +from mimo import Workflow, Stream + + +class TestWorkflow(unittest.TestCase): + def test_run(self): + workflow = Workflow() + stream1 = Stream(outs=['entity'], fn=iterator_stream) + stream2 = Stream(ins=['entity'], fn=collect_stream) + step1 = workflow.add_stream(stream1) + step2 = workflow.add_stream(stream2) + step1.pipe(step2) + + workflow.run() + + self.assertIn('collection', stream2.state) + self.assertEqual(list(range(100)), stream2.state['collection']) + + def test_add_stream(self): + stream = Stream(['a'], ['b']) + workflow = Workflow() + node = workflow.add_stream(stream) + + self.assertEqual({node.stream_id: stream}, workflow.streams) + self.assertEqual(stream.ins, list(conn.name for conn in workflow.inputs.values())) + self.assertEqual(stream.outs, list(conn.name for conn in workflow.outputs.values())) + self.assertEqual(workflow, node.workflow) + self.assertIn(node.stream_id, workflow.streams) + self.assertEqual(set(workflow.inputs), set(node.input_ids.values())) + self.assertEqual(set(workflow.outputs), set(node.output_ids.values())) + + def test_get_head_streams(self): + workflow = Workflow() + stream1 = Stream(outs=['a']) + stream2 = Stream(outs=['b']) + stream3 = Stream(['c', 'd'], ['e']) + step1 = workflow.add_stream(stream1) + step2 = workflow.add_stream(stream2) + step3 = workflow.add_stream(stream3) + step1.pipe(step3, input='c') + step2.pipe(step3, input='d') + + heads = set(workflow._get_head_streams()) + + self.assertEqual({step1.stream_id, step2.stream_id}, heads) + + def test_get_streams_with_input(self): + workflow = Workflow() + stream1 = Stream(['a']) + stream2 = Stream(['b']) + step1 = workflow.add_stream(stream1) + workflow.add_stream(stream2) + + workflow.input_sets[step1.stream_id].a.entities.extend([1, 2, 3]) + nodes_with_input = 
set(workflow._get_streams_with_input()) + + self.assertEqual({step1.stream_id}, nodes_with_input) + + +def iterator_stream(ins, outs, state): + if 'iterator' not in state: + state['iterator'] = iter(range(100)) + iterator = state['iterator'] + for item in iterator: + if not outs.entity.push(item): + return True + + +def collect_stream(ins, outs, state): + if 'collection' not in state: + state['collection'] = [] + collection = state['collection'] + while len(ins.entity) > 0: + collection.append(ins.entity.pop()) + + +if __name__ == '__main__': + unittest.main()