Skip to content

Commit 98926f4

Browse files
committed
change model dropdown
1 parent 293804e commit 98926f4

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

gui/main.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,24 @@
1818
import customtkinter as ctk
1919

2020
from crab import Experiment
21-
from crab.agents.backend_models import OpenAIModel, ClaudeModel, GeminiModel
21+
from crab.agents.backend_models import ClaudeModel, GeminiModel, OpenAIModel
2222
from crab.agents.policies import SingleAgentPolicy
2323
from gui.utils import get_benchmark
2424

2525
warnings.filterwarnings("ignore")
2626

2727
AVAILABLE_MODELS = {
28-
"gpt-4o": ("OpenAIModel", "gpt-4o"),
29-
"gpt-4turbo": ("OpenAIModel", "gpt-4-turbo"),
30-
"gemini": ("GeminiModel", "gemini-1.5-pro-latest"),
31-
"claude": ("ClaudeModel", "claude-3-opus-20240229"),
28+
"GPT-4o": ("OpenAIModel", "gpt-4o"),
29+
"GPT-4 Turbo": ("OpenAIModel", "gpt-4-turbo"),
30+
"Gemini": ("GeminiModel", "gemini-1.5-pro-latest"),
31+
"Claude": ("ClaudeModel", "claude-3-opus-20240229"),
3232
}
3333

34+
3435
def get_model_instance(model_key: str):
3536
if model_key not in AVAILABLE_MODELS:
3637
raise ValueError(f"Model {model_key} not supported")
37-
38+
3839
model_config = AVAILABLE_MODELS[model_key]
3940
model_class_name = model_config[0]
4041
model_name = model_config[1]
@@ -46,12 +47,13 @@ def get_model_instance(model_key: str):
4647
elif model_class_name == "ClaudeModel":
4748
return ClaudeModel(model=model_name, history_messages_len=2)
4849

50+
4951
def assign_task():
5052
task_description = input_entry.get()
5153
input_entry.delete(0, "end")
5254
display_message(task_description)
5355

54-
model = get_model_instance(selected_model.get())
56+
model = get_model_instance(model_dropdown.get())
5557
agent_policy = SingleAgentPolicy(model_backend=model)
5658

5759
task_id = str(uuid4())
@@ -80,7 +82,8 @@ def display_message(message, sender="user"):
8082

8183

8284
if __name__ == "__main__":
83-
# TODO: Handle JSON decode error from environment action endpoint and display model response in GUI
85+
# TODO: Handle JSON decode error from environment action endpoint and
86+
# display model response in GUI
8487
log_dir = (Path(__file__).parent / "logs").resolve()
8588

8689
ctk.set_appearance_mode("System")
@@ -93,15 +96,14 @@ def display_message(message, sender="user"):
9396
model_frame = ctk.CTkFrame(app)
9497
model_frame.pack(pady=10, padx=10, fill="x")
9598

96-
model_label = ctk.CTkLabel(model_frame, text="Select Model:")
99+
model_label = ctk.CTkLabel(model_frame, text="Model")
97100
model_label.pack(side="left", padx=(0, 10))
98101

99-
selected_model = ctk.StringVar(value="gpt-4o")
100102
model_dropdown = ctk.CTkOptionMenu(
101103
model_frame,
102104
values=list(AVAILABLE_MODELS.keys()),
103-
variable=selected_model,
104105
)
106+
model_dropdown.set(next(iter(AVAILABLE_MODELS)))
105107
model_dropdown.pack(side="left", fill="x", expand=True)
106108

107109
chat_display_frame = ctk.CTkFrame(app, width=380, height=380)

0 commit comments

Comments
 (0)