Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions docs/_core_features/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,27 @@ puts response.content
# => "Current weather at 52.52, 13.4: Temperature: 12.5°C, Wind Speed: 8.3 km/h, Conditions: Mainly clear, partly cloudy, and overcast."
```

### Tool Choice Control

Control when and how tools are called using `choice` and `parallel` options:

```ruby
chat = RubyLLM.chat(model: 'gpt-4o')

# Choice options
chat.with_tool(Weather, choice: :auto) # Model decides whether to call any provided tools or not (default)
chat.with_tool(Weather, choice: :any) # Model must use one of the provided tools
chat.with_tool(Weather, choice: :none) # No tools
chat.with_tool(Weather, choice: :weather) # Force specific tool

# Parallel tool calls
chat.with_tools(Weather, Calculator, parallel: true) # Model can output multiple tool calls at once (default)
chat.with_tools(Weather, Calculator, parallel: false) # At most one tool call
```
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both examples of with_tool and with_tools should contain both choice and parallel parameters otherwise people get the false sense that one parameter is for one call only.


> With `:any` or specific tool choices, tool results are not automatically sent back to the AI model (see The Tool Execution Flow section below) to prevent infinite loops.
{: .note }

### Model Compatibility
{: .d-inline-block }

Expand Down
42 changes: 36 additions & 6 deletions lib/ruby_llm/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module RubyLLM
class Chat
include Enumerable

attr_reader :model, :messages, :tools, :params, :headers, :schema
attr_reader :model, :messages, :tools, :tool_choice, :parallel_tool_calls, :params, :headers, :schema
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to compress these two into one tool_prefs = {choice: ..., parallel: ...}

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should choice and parallel default to nil (like temperature) to use provider defaults, or should we set them explicitly to :auto and true?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely set them to nil


def initialize(model: nil, provider: nil, assume_model_exists: false, context: nil)
if assume_model_exists && !provider
Expand All @@ -17,6 +17,8 @@ def initialize(model: nil, provider: nil, assume_model_exists: false, context: n
model_id = model || @config.default_model
with_model(model_id, provider: provider, assume_exists: assume_model_exists)
@temperature = nil
@tool_choice = nil
@parallel_tool_calls = nil
@messages = []
@tools = {}
@params = {}
Expand Down Expand Up @@ -44,15 +46,19 @@ def with_instructions(instructions, replace: false)
self
end

def with_tool(tool)
tool_instance = tool.is_a?(Class) ? tool.new : tool
@tools[tool_instance.name.to_sym] = tool_instance
def with_tool(tool, choice: nil, parallel: nil)
unless tool.nil?
tool_instance = tool.is_a?(Class) ? tool.new : tool
@tools[tool_instance.name.to_sym] = tool_instance
end
update_tool_options(choice:, parallel:)
self
end

def with_tools(*tools, replace: false)
def with_tools(*tools, replace: false, choice: nil, parallel: nil)
@tools.clear if replace
tools.compact.each { |tool| with_tool tool }
update_tool_options(choice:, parallel:)
self
end

Expand Down Expand Up @@ -130,6 +136,8 @@ def complete(&) # rubocop:disable Metrics/PerceivedComplexity
params: @params,
headers: @headers,
schema: @schema,
tool_choice: @tool_choice,
parallel_tool_calls: @parallel_tool_calls,
&wrap_streaming_block(&)
)

Expand Down Expand Up @@ -196,7 +204,9 @@ def handle_tool_calls(response, &) # rubocop:disable Metrics/PerceivedComplexity
halt_result = result if result.is_a?(Tool::Halt)
end

halt_result || complete(&)
return halt_result if halt_result

should_continue_after_tools? ? complete(&) : response
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The solution to this is not to halt the conversation, but to reset tool_choice to auto after tools have been called.

Halting the conversation is not normal behavior and it's for very specific use cases. More info here: https://community.openai.com/t/infinite-loop-with-tool-choice-required-or-type-function/755129

end

def execute_tool(tool_call)
Expand All @@ -205,6 +215,26 @@ def execute_tool(tool_call)
tool.call(args)
end

def update_tool_options(choice:, parallel:)
unless choice.nil?
valid_tool_choices = %i[auto none any] + tools.keys
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as mentioned in the issue, I want the API to look like auto, none or required as required captures better what's going to happen.

unless valid_tool_choices.include?(choice.to_sym)
raise InvalidToolChoiceError,
"Invalid tool choice: #{choice}. Valid choices are: #{valid_tool_choices.join(', ')}"
end

@tool_choice = choice.to_sym
end

@parallel_tool_calls = !!parallel unless parallel.nil?
end

def should_continue_after_tools?
# Continue conversation only with :auto tool choice to avoid infinite loops.
# With :any or specific tool choices, the model would keep calling tools repeatedly.
tool_choice.nil? || tool_choice == :auto
end

def instance_variables
super - %i[@connection @config]
end
Expand Down
1 change: 1 addition & 0 deletions lib/ruby_llm/error.rb
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def initialize(response = nil, message = nil)
# Error classes for non-HTTP errors
class ConfigurationError < StandardError; end
class InvalidRoleError < StandardError; end
class InvalidToolChoiceError < StandardError; end
class ModelNotFoundError < StandardError; end
class UnsupportedAttachmentError < StandardError; end

Expand Down
7 changes: 6 additions & 1 deletion lib/ruby_llm/provider.rb
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,18 @@ def configuration_requirements
self.class.configuration_requirements
end

def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, schema: nil, &) # rubocop:disable Metrics/ParameterLists
# rubocop:disable Metrics/ParameterLists
def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, schema: nil,
tool_choice: nil, parallel_tool_calls: nil, &)
normalized_temperature = maybe_normalize_temperature(temperature, model)

payload = Utils.deep_merge(
params,
render_payload(
messages,
tools: tools,
tool_choice: tool_choice,
parallel_tool_calls: parallel_tool_calls,
temperature: normalized_temperature,
model: model,
stream: block_given?,
Expand All @@ -58,6 +62,7 @@ def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, sc
sync_response @connection, payload, headers
end
end
# rubocop:enable Metrics/ParameterLists

def list_models
response = @connection.get models_url
Expand Down
15 changes: 11 additions & 4 deletions lib/ruby_llm/providers/anthropic/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@ def completion_url
'/v1/messages'
end

def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
# rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
def render_payload(messages, tools:, tool_choice:, parallel_tool_calls:,
temperature:, model:, stream: false, schema: nil)
system_messages, chat_messages = separate_messages(messages)
system_content = build_system_content(system_messages)

build_base_payload(chat_messages, model, stream).tap do |payload|
add_optional_fields(payload, system_content:, tools:, temperature:)
add_optional_fields(payload, system_content:, tools:, tool_choice:, parallel_tool_calls:, temperature:)
end
end
# rubocop:enable Metrics/ParameterLists,Lint/UnusedMethodArgument

def separate_messages(messages)
messages.partition { |msg| msg.role == :system }
Expand All @@ -44,8 +47,12 @@ def build_base_payload(chat_messages, model, stream)
}
end

def add_optional_fields(payload, system_content:, tools:, temperature:)
payload[:tools] = tools.values.map { |t| Tools.function_for(t) } if tools.any?
def add_optional_fields(payload, system_content:, tools:, tool_choice:, parallel_tool_calls:, temperature:) # rubocop:disable Metrics/ParameterLists
if tools.any?
payload[:tools] = tools.values.map { |t| Tools.function_for(t) }
payload[:tool_choice] = build_tool_choice(tool_choice, parallel_tool_calls) unless tool_choice.nil?
end

payload[:system] = system_content unless system_content.empty?
payload[:temperature] = temperature unless temperature.nil?
end
Expand Down
9 changes: 9 additions & 0 deletions lib/ruby_llm/providers/anthropic/tools.rb
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,15 @@ def clean_parameters(parameters)
def required_parameters(parameters)
parameters.select { |_, param| param.required }.keys
end

def build_tool_choice(tool_choice, parallel_tool_calls)
{
type: %i[auto any none].include?(tool_choice) ? tool_choice : :tool
}.tap do |tc|
tc[:name] = tool_choice if tc[:type] == :tool
tc[:disable_parallel_tool_use] = !parallel_tool_calls unless tc[:type] == :none || parallel_tool_calls.nil?
end
end
end
end
end
Expand Down
8 changes: 6 additions & 2 deletions lib/ruby_llm/providers/bedrock/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,20 @@ def completion_url
"model/#{@model_id}/invoke"
end

def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Lint/UnusedMethodArgument,Metrics/ParameterLists
# rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
def render_payload(messages, tools:, tool_choice:, parallel_tool_calls:,
temperature:, model:, stream: false, schema: nil)
@model_id = model

system_messages, chat_messages = Anthropic::Chat.separate_messages(messages)
system_content = Anthropic::Chat.build_system_content(system_messages)

build_base_payload(chat_messages, model).tap do |payload|
Anthropic::Chat.add_optional_fields(payload, system_content:, tools:, temperature:)
Anthropic::Chat.add_optional_fields(payload, system_content:, tools:, tool_choice:,
parallel_tool_calls:, temperature:)
end
end
# rubocop:enable Metrics/ParameterLists,Lint/UnusedMethodArgument

def build_base_payload(chat_messages, model)
{
Expand Down
12 changes: 10 additions & 2 deletions lib/ruby_llm/providers/gemini/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ def completion_url
"models/#{@model}:generateContent"
end

def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
# rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
def render_payload(messages, tools:, tool_choice:, parallel_tool_calls:,
temperature:, model:, stream: false, schema: nil)
@model = model
payload = {
contents: format_messages(messages),
Expand All @@ -25,9 +27,15 @@ def render_payload(messages, tools:, temperature:, model:, stream: false, schema
payload[:generationConfig][:responseSchema] = convert_schema_to_gemini(schema)
end

payload[:tools] = format_tools(tools) if tools.any?
if tools.any?
payload[:tools] = format_tools(tools)
# Gemini doesn't support controlling parallel tool calls
payload[:toolConfig] = build_tool_config(tool_choice) unless tool_choice.nil?
end

payload
end
# rubocop:enable Metrics/ParameterLists,Lint/UnusedMethodArgument

private

Expand Down
15 changes: 15 additions & 0 deletions lib/ruby_llm/providers/gemini/tools.rb
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,21 @@ def param_type_for_gemini(type)
else 'STRING'
end
end

def build_tool_config(tool_choice)
{
functionCallingConfig: {
mode: specific_tool_choice?(tool_choice) ? 'any' : tool_choice
}.tap do |config|
# Use allowedFunctionNames to simulate specific tool choice
config[:allowedFunctionNames] = [tool_choice] if specific_tool_choice?(tool_choice)
end
}
end

def specific_tool_choice?(tool_choice)
!%i[auto none any].include?(tool_choice)
end
end
end
end
Expand Down
3 changes: 2 additions & 1 deletion lib/ruby_llm/providers/mistral/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ def format_role(role)
end

# rubocop:disable Metrics/ParameterLists
def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil)
def render_payload(messages, tools:, tool_choice:, parallel_tool_calls:, temperature:, model:, stream: false,
schema: nil)
payload = super
payload.delete(:stream_options)
payload
Expand Down
12 changes: 10 additions & 2 deletions lib/ruby_llm/providers/openai/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,22 @@ def completion_url

module_function

def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Metrics/ParameterLists
# rubocop:disable Metrics/ParameterLists
def render_payload(messages, tools:, tool_choice:, parallel_tool_calls:,
temperature:, model:, stream: false, schema: nil)
payload = {
model: model,
messages: format_messages(messages),
stream: stream
}

payload[:temperature] = temperature unless temperature.nil?
payload[:tools] = tools.map { |_, tool| tool_for(tool) } if tools.any?

if tools.any?
payload[:tools] = tools.map { |_, tool| tool_for(tool) }
payload[:tool_choice] = build_tool_choice(tool_choice) unless tool_choice.nil?
payload[:parallel_tool_calls] = parallel_tool_calls unless parallel_tool_calls.nil?
end

if schema
strict = schema[:strict] != false
Expand All @@ -37,6 +44,7 @@ def render_payload(messages, tools:, temperature:, model:, stream: false, schema
payload[:stream_options] = { include_usage: true } if stream
payload
end
# rubocop:enable Metrics/ParameterLists

def parse_completion_response(response)
data = response.body
Expand Down
16 changes: 16 additions & 0 deletions lib/ruby_llm/providers/openai/tools.rb
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,22 @@ def parse_tool_calls(tool_calls, parse_arguments: true)
]
end
end

def build_tool_choice(tool_choice)
case tool_choice
when :auto, :none
tool_choice
when :any
:required
else
{
type: 'function',
function: {
name: tool_choice
}
}
end
end
end
end
end
Expand Down