Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions docs/_core_features/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,40 @@ puts response.content
# => "Current weather at 52.52, 13.4: Temperature: 12.5°C, Wind Speed: 8.3 km/h, Conditions: Mainly clear, partly cloudy, and overcast."
```

### Tool Choice Control

Control when and how tools are called using `choice` and `parallel` options.

**Parameter Values:**
- **`choice`**: Controls tool choice behavior
- `:auto` Model decides whether to use any tools
- `:required` - Model must use one of the provided tools
- `:none` - Disable all tools
- `"tool_name"` - Force a specific tool (e.g., `:weather` for `Weather` tool)
- **`parallel`**: Controls parallel tool calls
- `true` Allow multiple tool calls simultaneously
- `false` - One at a time

If not provided, RubyLLM will use the provider's default behavior for tool choice and parallel tool calls.

**Examples:**

```ruby
chat = RubyLLM.chat(model: 'gpt-4o')

# Basic usage with defaults
chat.with_tools(Weather, Calculator) # uses provider defaults

# Force tool usage, one at a time
chat.with_tools(Weather, Calculator, choice: :required, parallel: false)

# Force specific tool
chat.with_tool(Weather, choice: :weather, parallel: true)
Comment on lines 193 to 205
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer that to be code first, then parameter values. Also when specifying a tool in choice, it should accept the tool class too.

```

> With `:required` or specific tool choices, the tool_choice is automatically reset to `nil` after tool execution to prevent infinite loops.
{: .note }

### Model Compatibility
{: .d-inline-block }

Expand Down
39 changes: 34 additions & 5 deletions lib/ruby_llm/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module RubyLLM
class Chat
include Enumerable

attr_reader :model, :messages, :tools, :params, :headers, :schema
attr_reader :model, :messages, :tools, :tool_prefs, :params, :headers, :schema

def initialize(model: nil, provider: nil, assume_model_exists: false, context: nil)
if assume_model_exists && !provider
Expand All @@ -17,6 +17,7 @@ def initialize(model: nil, provider: nil, assume_model_exists: false, context: n
model_id = model || @config.default_model
with_model(model_id, provider: provider, assume_exists: assume_model_exists)
@temperature = nil
@tool_prefs = { choice: nil, parallel: nil }
@messages = []
@tools = {}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

small thing, but make it after @tools

@params = {}
Expand Down Expand Up @@ -44,15 +45,19 @@ def with_instructions(instructions, replace: false)
self
end

def with_tool(tool)
tool_instance = tool.is_a?(Class) ? tool.new : tool
@tools[tool_instance.name.to_sym] = tool_instance
def with_tool(tool, choice: nil, parallel: nil)
unless tool.nil?
tool_instance = tool.is_a?(Class) ? tool.new : tool
@tools[tool_instance.name.to_sym] = tool_instance
end
update_tool_options(choice:, parallel:)
self
end

def with_tools(*tools, replace: false)
def with_tools(*tools, replace: false, choice: nil, parallel: nil)
@tools.clear if replace
tools.compact.each { |tool| with_tool tool }
update_tool_options(choice:, parallel:)
self
end

Expand Down Expand Up @@ -130,6 +135,7 @@ def complete(&) # rubocop:disable Metrics/PerceivedComplexity
params: @params,
headers: @headers,
schema: @schema,
tool_prefs: @tool_prefs,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pass it after tools

&wrap_streaming_block(&)
)

Expand Down Expand Up @@ -196,6 +202,7 @@ def handle_tool_calls(response, &) # rubocop:disable Metrics/PerceivedComplexity
halt_result = result if result.is_a?(Tool::Halt)
end

reset_tool_choice if forced_tool_choice?
halt_result || complete(&)
end

Expand All @@ -205,6 +212,28 @@ def execute_tool(tool_call)
tool.call(args)
end

def update_tool_options(choice:, parallel:)
unless choice.nil?
valid_tool_choices = %i[auto none required] + tools.keys
unless valid_tool_choices.include?(choice.to_sym)
raise InvalidToolChoiceError,
"Invalid tool choice: #{choice}. Valid choices are: #{valid_tool_choices.join(', ')}"
end

@tool_prefs[:choice] = choice.to_sym
end

@tool_prefs[:parallel] = !!parallel unless parallel.nil?
end

def forced_tool_choice?
@tool_prefs[:choice] && !%i[auto none].include?(@tool_prefs[:choice])
end

def reset_tool_choice
@tool_prefs[:choice] = nil
end

def instance_variables
super - %i[@connection @config]
end
Expand Down
1 change: 1 addition & 0 deletions lib/ruby_llm/error.rb
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def initialize(response = nil, message = nil)
# Error classes for non-HTTP errors
class ConfigurationError < StandardError; end
class InvalidRoleError < StandardError; end
class InvalidToolChoiceError < StandardError; end
class ModelNotFoundError < StandardError; end
class UnsupportedAttachmentError < StandardError; end

Expand Down
6 changes: 5 additions & 1 deletion lib/ruby_llm/provider.rb
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,17 @@ def configuration_requirements
self.class.configuration_requirements
end

def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, schema: nil, &) # rubocop:disable Metrics/ParameterLists
# rubocop:disable Metrics/ParameterLists
def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, schema: nil,
tool_prefs: nil, &)
normalized_temperature = maybe_normalize_temperature(temperature, model)

payload = Utils.deep_merge(
params,
render_payload(
messages,
tools: tools,
tool_prefs: tool_prefs,
temperature: normalized_temperature,
model: model,
stream: block_given?,
Expand All @@ -58,6 +61,7 @@ def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, sc
sync_response @connection, payload, headers
end
end
# rubocop:enable Metrics/ParameterLists

def list_models
response = @connection.get models_url
Expand Down
15 changes: 11 additions & 4 deletions lib/ruby_llm/providers/anthropic/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@ def completion_url
'/v1/messages'
end

def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
# rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
def render_payload(messages, tools:, tool_prefs:,
temperature:, model:, stream: false, schema: nil)
system_messages, chat_messages = separate_messages(messages)
system_content = build_system_content(system_messages)

build_base_payload(chat_messages, model, stream).tap do |payload|
add_optional_fields(payload, system_content:, tools:, temperature:)
add_optional_fields(payload, system_content:, tools:, tool_prefs:, temperature:)
end
end
# rubocop:enable Metrics/ParameterLists,Lint/UnusedMethodArgument

def separate_messages(messages)
messages.partition { |msg| msg.role == :system }
Expand All @@ -44,8 +47,12 @@ def build_base_payload(chat_messages, model, stream)
}
end

def add_optional_fields(payload, system_content:, tools:, temperature:)
payload[:tools] = tools.values.map { |t| Tools.function_for(t) } if tools.any?
def add_optional_fields(payload, system_content:, tools:, tool_prefs:, temperature:)
if tools.any?
payload[:tools] = tools.values.map { |t| Tools.function_for(t) }
payload[:tool_choice] = build_tool_choice(tool_prefs) unless tool_prefs[:choice].nil?
end

payload[:system] = system_content unless system_content.empty?
payload[:temperature] = temperature unless temperature.nil?
end
Expand Down
19 changes: 19 additions & 0 deletions lib/ruby_llm/providers/anthropic/tools.rb
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,25 @@ def clean_parameters(parameters)
def required_parameters(parameters)
parameters.select { |_, param| param.required }.keys
end

def build_tool_choice(tool_prefs)
tool_choice = tool_prefs[:choice]
parallel_tool_calls = tool_prefs[:parallel]

{
type: case tool_choice
when :auto, :none
tool_choice
when :required
:any
else
:tool
end
}.tap do |tc|
tc[:name] = tool_choice if tc[:type] == :tool
tc[:disable_parallel_tool_use] = !parallel_tool_calls unless tc[:type] == :none || parallel_tool_calls.nil?
end
end
end
end
end
Expand Down
8 changes: 6 additions & 2 deletions lib/ruby_llm/providers/bedrock/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,20 @@ def completion_url
"model/#{@model_id}/invoke"
end

def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Lint/UnusedMethodArgument,Metrics/ParameterLists
# rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
def render_payload(messages, tools:, tool_prefs:, temperature:, model:, stream: false,
schema: nil)
@model_id = model.id

system_messages, chat_messages = Anthropic::Chat.separate_messages(messages)
system_content = Anthropic::Chat.build_system_content(system_messages)

build_base_payload(chat_messages, model).tap do |payload|
Anthropic::Chat.add_optional_fields(payload, system_content:, tools:, temperature:)
Anthropic::Chat.add_optional_fields(payload, system_content:, tools:, tool_prefs:,
temperature:)
end
end
# rubocop:enable Metrics/ParameterLists,Lint/UnusedMethodArgument

def build_base_payload(chat_messages, model)
{
Expand Down
11 changes: 9 additions & 2 deletions lib/ruby_llm/providers/gemini/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ def completion_url
"models/#{@model}:generateContent"
end

def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
# rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
def render_payload(messages, tools:, tool_prefs:, temperature:, model:, stream: false, schema: nil)
@model = model.id
payload = {
contents: format_messages(messages),
Expand All @@ -25,9 +26,15 @@ def render_payload(messages, tools:, temperature:, model:, stream: false, schema
payload[:generationConfig][:responseSchema] = convert_schema_to_gemini(schema)
end

payload[:tools] = format_tools(tools) if tools.any?
if tools.any?
payload[:tools] = format_tools(tools)
# Gemini doesn't support controlling parallel tool calls
payload[:toolConfig] = build_tool_config(tool_prefs[:choice]) unless tool_prefs[:choice].nil?
end

payload
end
# rubocop:enable Metrics/ParameterLists,Lint/UnusedMethodArgument

private

Expand Down
19 changes: 19 additions & 0 deletions lib/ruby_llm/providers/gemini/tools.rb
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,25 @@ def param_type_for_gemini(type)
else 'STRING'
end
end

def build_tool_config(tool_choice)
{
functionCallingConfig: {
mode: forced_tool_choice?(tool_choice) ? 'any' : tool_choice
}.tap do |config|
# Use allowedFunctionNames to simulate specific tool choice
config[:allowedFunctionNames] = [tool_choice] if specific_tool_choice?(tool_choice)
end
}
end

def forced_tool_choice?(tool_choice)
tool_choice == :required || specific_tool_choice?(tool_choice)
end

def specific_tool_choice?(tool_choice)
!%i[auto none required].include?(tool_choice)
end
end
end
end
Expand Down
3 changes: 2 additions & 1 deletion lib/ruby_llm/providers/mistral/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ def format_role(role)
end

# rubocop:disable Metrics/ParameterLists
def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil)
def render_payload(messages, tools:, tool_prefs:, temperature:, model:, stream: false,
schema: nil)
payload = super
payload.delete(:stream_options)
payload
Expand Down
11 changes: 9 additions & 2 deletions lib/ruby_llm/providers/openai/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,21 @@ def completion_url

module_function

def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Metrics/ParameterLists
# rubocop:disable Metrics/ParameterLists
def render_payload(messages, tools:, tool_prefs:, temperature:, model:, stream: false, schema: nil)
payload = {
model: model.id,
messages: format_messages(messages),
stream: stream
}

payload[:temperature] = temperature unless temperature.nil?
payload[:tools] = tools.map { |_, tool| tool_for(tool) } if tools.any?

if tools.any?
payload[:tools] = tools.map { |_, tool| tool_for(tool) }
payload[:tool_choice] = build_tool_choice(tool_prefs[:choice]) unless tool_prefs[:choice].nil?
payload[:parallel_tool_calls] = tool_prefs[:parallel] unless tool_prefs[:parallel].nil?
end

if schema
strict = schema[:strict] != false
Expand All @@ -37,6 +43,7 @@ def render_payload(messages, tools:, temperature:, model:, stream: false, schema
payload[:stream_options] = { include_usage: true } if stream
payload
end
# rubocop:enable Metrics/ParameterLists

def parse_completion_response(response)
data = response.body
Expand Down
14 changes: 14 additions & 0 deletions lib/ruby_llm/providers/openai/tools.rb
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,20 @@ def parse_tool_calls(tool_calls, parse_arguments: true)
]
end
end

def build_tool_choice(tool_choice)
case tool_choice
when :auto, :none, :required
tool_choice
else
{
type: 'function',
function: {
name: tool_choice
}
}
end
end
end
end
end
Expand Down