Skip to content

[WIP] Support Huggingface Models #210

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions maab/agents/mlzero_default/mlzero_default.sh
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ fi
mlzero \
-i "$TRAINING_PATH" \
-o "$OUTPUT_DIR" \
-n 10 \
-n 5 \
-v 1 \
-u "complete the task in 10 minutes"
-u "Use models in Huggingface."

# Check if the process was successful
if [ $? -ne 0 ]; then
Expand Down
5 changes: 1 addition & 4 deletions src/autogluon/assistant/agents/reranker_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@ def __call__(self):
"""Select and rerank relevant tutorials from retrieved candidates."""
self.manager.log_agent_start("RerankerAgent: reranking and selecting top tutorials from retrieved candidates.")

# Get retrieved tutorials from manager
retrieved_tutorials = self.manager.tutorial_retrieval

# Build prompt for tutorial reranking
prompt = self.reranker_prompt.build()

Expand All @@ -56,7 +53,7 @@ def __call__(self):
# Fallback: if parsing fails or returns empty, use top tutorials by score
if not selected_tutorials:
logger.warning("Tutorial reranking failed, falling back to top tutorials by retrieval score.")
selected_tutorials = self._select_top_by_score(retrieved_tutorials)
selected_tutorials = self._select_top_by_score(self.reranker_prompt.tutorials)

# Generate tutorial prompt using selected tutorials
tutorial_prompt = self._generate_tutorial_prompt(selected_tutorials)
Expand Down
2 changes: 2 additions & 0 deletions src/autogluon/assistant/agents/tool_selector_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def __call__(self) -> Tuple[str, str]:

selected_tool = self.tool_selector_prompt.parse(response)

selected_tool = "huggingface"

self.manager.log_agent_end("ToolSelectorAgent: selected tool and recorded justification.")

return selected_tool
8 changes: 4 additions & 4 deletions src/autogluon/assistant/configs/default.yaml
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
# Tutorial Prompt Generator Configuration

per_execution_timeout: 86400
per_execution_timeout: 7200

# Data Perception
max_file_group_size_to_show: 5
num_example_files_to_show: 1

max_chars_per_file: 1024
num_tutorial_retrievals: 20
max_num_tutorials: 5
num_tutorial_retrievals: 50
max_num_tutorials: 1
max_user_input_length: 2048
max_error_message_length: 2048
max_tutorial_length: 8192
create_venv: false
create_venv: True
condense_tutorials: True
use_tutorial_summary: True

Expand Down
4 changes: 2 additions & 2 deletions src/autogluon/assistant/prompts/reranker_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def build(self) -> str:
condense_tutorials = self.manager.config.condense_tutorials
use_tutorial_summary = self.manager.config.use_tutorial_summary

# Get all available tutorials
self.tutorials = get_all_tutorials(selected_tool, condensed=condense_tutorials)
# Get retrieved tutorials from manager
self.tutorials = self.manager.tutorial_retrieval

if not self.tutorials:
logger.warning(f"No tutorials found for {selected_tool}")
Expand Down
5 changes: 5 additions & 0 deletions src/autogluon/assistant/tools_registry/_common/catalog.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
"path": "machine learning",
"version": "0.1.0",
"description": "You should select this as a general reference of machine learning or deep learning algorithms in case other tools are not helpful."
},
"huggingface": {
"path": "huggingface",
"version": "1.0.0",
"description": "Here we collect top liked/downloaded models from huggingface for each task."
}
}
}
Loading
Loading