Skip to content

Commit

Permalink
add executor and refactor graph traverse
Browse files Browse the repository at this point in the history
  • Loading branch information
SiriusNEO committed Sep 21, 2024
1 parent bd17a66 commit 8448efc
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 98 deletions.
69 changes: 4 additions & 65 deletions tests/pfunc/test_native_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,78 +7,17 @@

def test_parse_native_function():
@P.native_function()
def add(a: P.Input, b: P.Input) -> P.Output:
return str(int(a) + int(b))
def add(a: P.Input, b: P.Input, c: P.Output):
c = str(int(a) + int(b))

def add_pyfunc(a: str, b: str) -> str:
return str(int(a) + int(b))
c = str(int(a) + int(b))

print(add.display_signature())
print(add.inputs)
print(add.outputs)
print(inspect.signature(add_pyfunc))


# def test_parse_native_function_two_rets():
# @P.native_function()
# def add(a: P.Input, b: P.Input) -> (P.Output, P.Output):
# return str(int(a) + int(b)), str(int(a) - int(b))

# def add_pyfunc(a: str, b: str) -> (str, str):
# return str(int(a) + int(b)), str(int(a) - int(b))

# print(add.display_signature())
# print(add.inputs)
# print(add.outputs)
# print(inspect.signature(add_pyfunc))


def test_call_function():
@P.native_function()
def add(a: P.Input, b: P.Input) -> P.Output:
return str(int(a) + int(b))

call = add("1", b="2")
print(call)

pyfunc = add.get_pyfunc()
result = pyfunc("1", b="2")
print(result)


# def test_serialize_call():
# @P.native_function()
# def add(a: P.Input, b: P.Input) -> P.Output:
# return str(int(a) + int(b))

# call = add("1", b="2")
# print(call)
# call_pickled = call.pickle()
# # print(call_pickled)
# call_unpickled = NativeCall.unpickle(call_pickled)
# print(call_unpickled)

# assert call.func.name == call_unpickled.func.name
# assert len(call.func.params) == len(call_unpickled.func.params)
# for p1, p2 in zip(call.func.params, call_unpickled.func.params):
# assert p1.name == p2.name
# assert p1.typ == p2.typ

# assert len(call.bindings) == len(call_unpickled.bindings)
# for k, v in call.bindings.items():
# assert type(call_unpickled.bindings[k]) == type(v)

# pyfunc = call_unpickled.func.get_pyfunc()
# ret = pyfunc("1", b="2")
# print(ret)

# pyfunc = call.func.get_pyfunc()
# ret = pyfunc("1", b="2")
# print(ret)


if __name__ == "__main__":
# test_parse_native_function()
# test_parse_native_function_two_rets()
test_call_function()
# test_serialize_call()
test_parse_native_function()
17 changes: 9 additions & 8 deletions tests/serve/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
PlaceholderFill,
PlaceholderGen,
PerformanceCriteria,
activate_producer,
activate_sv,
)
from parrot.serve.graph.call_request import (
SemanticCallMetadata,
Expand Down Expand Up @@ -197,19 +197,20 @@ def test_graph_traverse():

var_mgr.create_vars_for_semantic_request_chain(session_id, request3)
graph.insert_and_update_request_chain(request3)
out_var2 = request3.comp_chains[0].gen_node.sv

# view_graph(graph)
activate_producer(request1.comp_chains[0].gen_node, PerformanceCriteria.LATENCY)
activate_producer(request2.comp_chains[0].gen_node, PerformanceCriteria.LATENCY)
activate_producer(request3.comp_chains[0].gen_node, PerformanceCriteria.LATENCY)
# activate_producer(request3.comp_chains[0].gen_node, PerformanceCriteria.LATENCY)
activate_sv(out_var0, PerformanceCriteria.LATENCY)
activate_sv(out_var1, PerformanceCriteria.LATENCY)
activate_sv(out_var2, PerformanceCriteria.LATENCY)

# Expected results: A: depth 2, B: depth 1, C: depth 0
requests = [request1, request2, request3]
for req in requests:
assert req.comp_chains[0].is_activated
assert req.comp_chains[0].criteria == PerformanceCriteria.LATENCY
print(req.comp_chains[0].depth)
sv = req.comp_chains[0].gen_node.sv
assert sv.is_activated
assert sv.criteria == PerformanceCriteria.LATENCY
print(sv.depth)


if __name__ == "__main__":
Expand Down
18 changes: 11 additions & 7 deletions tests/serve/test_graph_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
PlaceholderFill,
PlaceholderGen,
PerformanceCriteria,
activate_producer,
activate_sv,
NativeFuncNode,
)
from parrot.serve.graph.call_request import (
Expand Down Expand Up @@ -91,11 +91,12 @@ def test_view_graph_complex():

var_mgr.create_vars_for_semantic_request_chain(session_id, request3)
graph.insert_and_update_request_chain(request3)
out_var2 = request3.comp_chains[0].gen_node.sv

view_graph(graph)
activate_producer(request1.comp_chains[0].gen_node, PerformanceCriteria.LATENCY)
activate_producer(request2.comp_chains[0].gen_node, PerformanceCriteria.LATENCY)
activate_producer(request3.comp_chains[0].gen_node, PerformanceCriteria.LATENCY)
activate_sv(out_var0, PerformanceCriteria.LATENCY)
activate_sv(out_var1, PerformanceCriteria.LATENCY)
activate_sv(out_var2, PerformanceCriteria.LATENCY)


def test_view_graph_with_native():
Expand Down Expand Up @@ -158,9 +159,12 @@ def test_view_graph_with_native():
graph.insert_and_update_request_chain(request3)

view_graph(graph)
activate_producer(request1.comp_chains[0].gen_node, PerformanceCriteria.LATENCY)
activate_producer(request2.comp_chains[0].gen_node, PerformanceCriteria.LATENCY)
activate_producer
# activate_sv(out_var0, PerformanceCriteria.LATENCY)
# activate_sv(out_var1, PerformanceCriteria.LATENCY)
activate_sv(out_var2, PerformanceCriteria.LATENCY)

# for var in [out_var0, out_var1, out_var2]:
# print(var.is_activated)


if __name__ == "__main__":
Expand Down
18 changes: 9 additions & 9 deletions tests/serve/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
PlaceholderGen,
ComputeGraph,
PerformanceCriteria,
activate_producer,
activate_sv,
SemanticVariable,
)
from parrot.serve.graph.call_request import (
Expand Down Expand Up @@ -75,7 +75,7 @@ def test_default_policy_throughput():
var_mgr.create_vars_for_semantic_request_chain(session_id, request_chain)
graph.insert_and_update_request_chain(request_chain)
comp_chain = request_chain.comp_chains[0]
activate_producer(comp_chain.gen_node, PerformanceCriteria.THROUGHPUT)
activate_sv(comp_chain.gen_node.sv, PerformanceCriteria.THROUGHPUT)
task = task_creator.create_task(comp_chain)
task.tokenize_chain(tokenizers_wrapper)
scheduler.submit_task(task)
Expand Down Expand Up @@ -132,7 +132,7 @@ def test_default_policy_latency():
var_mgr.create_vars_for_semantic_request_chain(session_id, request_chain)
graph.insert_and_update_request_chain(request_chain)
comp_chain = request_chain.comp_chains[0]
activate_producer(comp_chain.gen_node, PerformanceCriteria.LATENCY)
activate_sv(comp_chain.gen_node.sv, PerformanceCriteria.LATENCY)
task = task_creator.create_task(comp_chain)
task.tokenize_chain(tokenizers_wrapper)
scheduler.submit_task(task)
Expand Down Expand Up @@ -215,7 +215,7 @@ def test_app_fifo():
var_mgr.create_vars_for_semantic_request_chain(session_id, request_chain2)
graph.insert_and_update_request_chain(request_chain2)
comp_chain2 = request_chain2.comp_chains[0]
activate_producer(comp_chain2.gen_node, PerformanceCriteria.LATENCY)
activate_sv(comp_chain2.gen_node.sv, PerformanceCriteria.LATENCY)

task1 = task_creator.create_task(comp_chain1)
task1.tokenize_chain(tokenizers_wrapper)
Expand Down Expand Up @@ -324,7 +324,7 @@ def test_graph_group():
var_mgr.create_vars_for_semantic_request_chain(session_id, request_chain)
graph.insert_and_update_request_chain(request_chain)
comp_chain = request_chain.comp_chains[0]
activate_producer(comp_chain.gen_node, PerformanceCriteria.LATENCY)
activate_sv(comp_chain.gen_node.sv, PerformanceCriteria.LATENCY)

# view_graph(graph)

Expand Down Expand Up @@ -386,7 +386,7 @@ def test_ctx_group():
var_mgr.create_vars_for_semantic_request_chain(session_id, request_chain)
graph.insert_and_update_request_chain(request_chain)
comp_chain = request_chain.comp_chains[0]
activate_producer(comp_chain.gen_node, PerformanceCriteria.LATENCY)
activate_sv(comp_chain.gen_node.sv, PerformanceCriteria.LATENCY)
task = task_creator.create_task(comp_chain)
task.tokenize_chain(tokenizers_wrapper)
scheduler.submit_task(task)
Expand Down Expand Up @@ -446,7 +446,7 @@ def test_ctx_aware():
graph.insert_and_update_request_chain(request_chain)
comp_chain = request_chain.comp_chains[0]
first_vars.append(comp_chain.first_node.sv)
activate_producer(comp_chain.gen_node, PerformanceCriteria.THROUGHPUT)
activate_sv(comp_chain.gen_node.sv, PerformanceCriteria.THROUGHPUT)
task = task_creator.create_task(comp_chain)
task.tokenize_chain(tokenizers_wrapper)
scheduler.submit_task(task)
Expand All @@ -465,7 +465,7 @@ def test_ctx_aware():
if __name__ == "__main__":
# test_default_policy_throughput()
# test_default_policy_latency()
test_app_fifo()
# test_graph_group()
# test_app_fifo()
test_graph_group()
# test_ctx_group()
# test_ctx_aware()
71 changes: 67 additions & 4 deletions tests/serve/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from parrot.exceptions import ParrotCoreUserError

from parrot.frontend.pfunc import native_function, Output

from parrot.serve.session_manager import SessionManager
from parrot.serve.scheduler import TaskCreator, GlobalScheduler, GlobalSchedulerConfig
from parrot.serve.prefix_matcher import PrefixMatcher
Expand All @@ -12,19 +14,21 @@
from parrot.serve.context_manager import ServeCoreContextManager
from parrot.serve.engine_manager import EngineManager
from parrot.serve.session.graph_executor import GraphExecutor
from parrot.serve.session.native_executor import PyNativeExecutor
from parrot.serve.backend_repr import ExecutionEngine

from parrot.testing.localhost_server_daemon import fake_engine_server
from parrot.testing.fake_engine_server import engine_config

from parrot.serve.graph import (
RequestChain,
ComputeGraph,
PyNativeCallRequest,
NativeFuncNode,
ConstantFill,
PlaceholderFill,
PlaceholderGen,
PerformanceCriteria,
activate_producer,
activate_sv,
)
from parrot.serve.graph.call_request import SemanticFunctionParameter

Expand Down Expand Up @@ -115,7 +119,7 @@ def test_graph_executor():

async def main():
executor.add_request(request)
activate_producer(request.comp_chains[0].gen_node, PerformanceCriteria.LATENCY)
activate_sv(request.comp_chains[0].gen_node.sv, PerformanceCriteria.LATENCY)
await asyncio.sleep(1)
in_var.set("This is a test value.")
await asyncio.sleep(0.1)
Expand All @@ -126,6 +130,65 @@ async def main():
asyncio.run(main())


def test_native_executor():
    """Smoke-test running a Python native function through PyNativeExecutor.

    The test wires up a GraphExecutor and a PyNativeExecutor that share the
    same graph, turns a decorated native function call into a request
    payload, parses it back into a ``NativeFuncNode``, registers its
    variables, and then submits the node and activates its output variable.
    The resulting value is only printed; nothing is asserted.
    """
    session_id = 0

    # Serving-side components required to construct the GraphExecutor.
    scheduler_config = GlobalSchedulerConfig()
    var_mgr = SemanticVariableManager(666)
    tokenizers_wrapper = TokenizersWrapper()
    context_mgr = ServeCoreContextManager()
    engine_mgr = EngineManager(
        tokenizers_wrapper=tokenizers_wrapper,
        context_mgr=context_mgr,
        engine_heartbeat_timeout=666,
    )
    # Constructed once here; the original built a second, unused instance
    # at the top of the function.
    task_creator = TaskCreator()
    scheduler = GlobalScheduler(scheduler_config, engine_mgr, context_mgr)
    executor = GraphExecutor(
        session_id=session_id,
        task_creator=task_creator,
        scheduler=scheduler,
        engine_mgr=engine_mgr,
        context_mgr=context_mgr,
        tokenizers_wrapper=tokenizers_wrapper,
    )
    # The native executor operates on the graph owned by the GraphExecutor.
    native_executor = PyNativeExecutor(
        session_id=session_id,
        graph=executor.graph,
    )

    var_mgr.register_local_var_space(session_id)

    @native_function()
    def test_native_func(a: str, b: str, c: Output):
        c.set(a + b)

    # Round-trip the call through its request payload form.
    payload = test_native_func("Hello", "World").to_request_payload()

    print(payload)

    native_request = PyNativeCallRequest.parse_from_payload(
        request_id=0, session_id=session_id, payload=payload
    )
    func_node = NativeFuncNode(native_request)

    var_mgr.create_vars_for_pynative_func(session_id, func_node)

    print(func_node.input_values, func_node.output_vars)

    out_var = func_node.output_vars["c"]

    async def main():
        native_executor.add_native_func(func_node)
        activate_sv(out_var, PerformanceCriteria.LATENCY)
        # NOTE(review): presumably this sleep gives the executor time to run
        # the function before the value is read - confirm against
        # PyNativeExecutor.
        await asyncio.sleep(1)
        print(out_var.get())

    asyncio.run(main())


if __name__ == "__main__":
# test_session_manager()
test_session_manager()
test_graph_executor()
test_native_executor()
8 changes: 6 additions & 2 deletions tests/serve/test_sv.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def test_request_chain_hash():
nodes=[
ConstantFill("Test1"),
PlaceholderFill(
parameter=SemanticFunctionParameter(name="a", is_output=False)
parameter=SemanticFunctionParameter(
name="a", is_output=False, value="test"
)
),
ConstantFill("Test2"),
PlaceholderGen(
Expand All @@ -43,7 +45,9 @@ def test_request_chain_hash():
nodes=[
ConstantFill("Test1"),
PlaceholderFill(
parameter=SemanticFunctionParameter(name="a", is_output=False)
parameter=SemanticFunctionParameter(
name="a", is_output=False, value="test"
)
),
ConstantFill("Test2"),
PlaceholderGen(
Expand Down
9 changes: 6 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from parrot.utils import RecyclePool, bytes_to_encoded_b64str, encoded_b64str_to_bytes
from parrot.utils import (
RecyclePool,
bytes_to_encoded_b64str,
encoded_b64str_to_bytes,
)


def test_recycle_pool():
Expand Down Expand Up @@ -37,8 +41,7 @@ def test_serialize_tools():
assert data == decoded



if __name__ == "__main__":
test_recycle_pool()
test_recycle_pool_error()
test_serialize_tools()
test_serialize_tools()

0 comments on commit 8448efc

Please sign in to comment.