-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #48 from intelligentnode/39-revamp-the-flow
39 revamp the flow
- Loading branch information
Showing
21 changed files
with
307 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# print the project tree | ||
tree -I '__pycache__|test|Instructions|assets' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,42 +1,45 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
from intelli.controller.remote_image_model import RemoteImageModel | ||
from intelli.flow.types import AgentTypes | ||
from intelli.function.chatbot import Chatbot | ||
from intelli.model.input.chatbot_input import ChatModelInput | ||
from intelli.controller.remote_image_model import RemoteImageModel | ||
from intelli.model.input.image_input import ImageModelInput | ||
from abc import ABC, abstractmethod | ||
from intelli.flow.types import AgentTypes | ||
from intelli.flow.input.agent_input import AgentInput, TextAgentInput, ImageAgentInput | ||
|
||
|
||
class BasicAgent(ABC): | ||
|
||
@abstractmethod | ||
def execute(self, agent_input): | ||
pass | ||
|
||
|
||
class Agent(BasicAgent): | ||
def __init__(self, agent_type, provider, mission, model_params, options=None): | ||
|
||
if agent_type not in AgentTypes._value2member_map_: | ||
raise ValueError("Incorrect agent type. Accepted types in AgentTypes.") | ||
|
||
self.type = agent_type | ||
self.provider = provider | ||
self.mission = mission | ||
self.model_params = model_params | ||
self.options = options | ||
|
||
|
||
def execute(self, agent_input): | ||
def execute(self, agent_input: AgentInput): | ||
|
||
# Check the agent type and call the appropriate function | ||
if self.type == AgentTypes.TEXT.value: | ||
chatbot = Chatbot(self.model_params['key'], self.provider, self.options) | ||
chat_input = ChatModelInput(self.mission, model=self.model_params.get('model')) | ||
chat_input.add_user_message(agent_input) | ||
chat_input.add_user_message(agent_input.desc) | ||
result = chatbot.chat(chat_input)[0] | ||
elif self.type == AgentTypes.IMAGE.value: | ||
image_model = RemoteImageModel(self.model_params['key'], self.provider) | ||
image_input = ImageModelInput(prompt=agent_input, model=self.model_params.get('model')) | ||
image_input = ImageModelInput(prompt=agent_input.desc, model=self.model_params.get('model')) | ||
result = image_model.generate_images(image_input) | ||
else: | ||
raise ValueError(f"Unsupported agent type: {self.type}.") | ||
|
||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import asyncio | ||
import networkx as nx | ||
from intelli.utils.logging import Logger | ||
from functools import partial | ||
|
||
|
||
class Flow: | ||
def __init__(self, tasks, map_paths, log=False): | ||
self.tasks = tasks | ||
self.map_paths = map_paths | ||
self.graph = nx.DiGraph() | ||
self.output = {} | ||
self.logger = Logger(log) | ||
self._prepare_graph() | ||
|
||
def _prepare_graph(self): | ||
# Initialize the graph with tasks as nodes | ||
for task_name in self.tasks: | ||
self.graph.add_node(task_name) | ||
|
||
# Add edges based on map_paths to define dependencies | ||
for parent_task, dependencies in self.map_paths.items(): | ||
for child_task in dependencies: | ||
self.graph.add_edge(parent_task, child_task) | ||
|
||
# Check for cycles in the graph | ||
if not nx.is_directed_acyclic_graph(self.graph): | ||
raise ValueError("The dependency graph has cycles, please revise map_paths.") | ||
|
||
async def _execute_task(self, task_name): | ||
self.logger.log(f'---- execute task {task_name} ---- ') | ||
task = self.tasks[task_name] | ||
predecessor_outputs = [] | ||
predecessor_types = set() | ||
|
||
# Gather inputs and types from previous tasks based on the graph | ||
for pred in self.graph.predecessors(task_name): | ||
if pred in self.output: | ||
predecessor_outputs.append(self.output[pred]['output']) | ||
predecessor_types.add(self.output[pred]['type']) | ||
else: | ||
print(f"Warning: Output for predecessor task '{pred}' not found. Skipping...") | ||
|
||
self.logger.log(f'The number of combined inputs for task {task_name} is {len(predecessor_outputs)}') | ||
merged_input = " ".join(predecessor_outputs) | ||
merged_type = next(iter(predecessor_types)) if len(predecessor_types) == 1 else None | ||
|
||
# Execute task with merged input | ||
loop = asyncio.get_event_loop() | ||
execute_task = partial(task.execute, merged_input, input_type=merged_type) | ||
|
||
# Run the synchronous function | ||
await loop.run_in_executor(None, execute_task) | ||
|
||
# Collect outputs and types | ||
self.output[task_name] = {'output': task.output, 'type': task.output_type} | ||
|
||
async def start(self, max_workers=10): | ||
ordered_tasks = list(nx.topological_sort(self.graph)) | ||
task_coroutines = {task_name: self._execute_task(task_name) for task_name in ordered_tasks} | ||
async with asyncio.Semaphore(max_workers): | ||
for task_name in ordered_tasks: | ||
await task_coroutines[task_name] | ||
|
||
# Filter the outputs (and types) of excluded tasks | ||
filtered_output = { | ||
task_name: { 'output': self.output[task_name]['output'], 'type': self.output[task_name]['type'] } | ||
for task_name in ordered_tasks if not self.tasks[task_name].exclude | ||
} | ||
|
||
return filtered_output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
class AgentInput: | ||
def __init__(self, desc=None, img=None, audio=None): | ||
self.desc = desc | ||
self.img = img | ||
self.audio = audio | ||
|
||
|
||
class TextAgentInput(AgentInput): | ||
def __init__(self, desc): | ||
super().__init__(desc=desc) | ||
|
||
|
||
class ImageAgentInput(AgentInput): | ||
def __init__(self, desc, img): | ||
super().__init__(desc=desc, img=img) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,11 @@ | ||
from enum import Enum | ||
|
||
|
||
class AgentTypes(Enum): | ||
TEXT = 'text' | ||
IMAGE = 'image' | ||
|
||
|
||
class InputTypes(Enum): | ||
TEXT = 'text' | ||
IMAGE = 'image' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Oops, something went wrong.