Commit 5f99209

feat(llm): add llm_params option to llm_call
Extend llm_call to accept an optional llm_params dictionary for passing configuration parameters (e.g., temperature, max_tokens) to the language model. This enables more flexible control over LLM behavior during calls.

refactor(llm): replace llm_params context manager with argument
Update all usages of the llm_params context manager to pass llm_params as an argument to llm_call instead. This simplifies parameter handling and improves code clarity for LLM calls.

docs: clarify prompt customization and llm_params usage

- update LLMChain config usage
- add unit and e2e tests
- fix failing tests
1 parent 4c34032 commit 5f99209
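
In practice, the change replaces the llm_params context manager with a keyword argument on llm_call. A minimal sketch of the new calling convention (the wrapper function and its arguments are illustrative; only the import path follows the utils module touched by this commit):

```python
# Minimal sketch of the calling convention introduced by this commit.
# `check_input` and its arguments are illustrative, not part of the commit.
from nemoguardrails.actions.llm.utils import llm_call


async def check_input(llm, prompt: str) -> str:
    # Before: parameters were applied via the llm_params context manager:
    #     with llm_params(llm, temperature=0.0):
    #         check = await llm_call(llm, prompt)
    # After: parameters are passed as an optional dictionary argument.
    check = await llm_call(llm, prompt, llm_params={"temperature": 0.0})
    return check
```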

File tree: 21 files changed, +777 −163 lines

docs/user-guides/advanced/prompt-customization.md

Lines changed: 4 additions & 4 deletions
@@ -55,6 +55,7 @@ To override the prompt for any other custom purpose, you can specify the `mode`
 As an example of this, let's consider the case of compacting. Some applications might need concise prompts, for instance to avoid handling long contexts, and lower latency at the risk of slightly degraded performance due to the smaller context. For this, you might want to have multiple versions of a prompt for the same task and same model. This can be achieved as follows:

 Task configuration:
+
 ```yaml
 models:
   - type: main
@@ -65,6 +66,7 @@ prompting_mode: "compact" # Default value is "standard"
 ```

 Prompts configuration:
+
 ```yaml
 prompts:
   - task: generate_user_intent
@@ -117,6 +119,7 @@ prompts:
     content: ...
   # ...
 ```
+
 For each task, you can also specify the maximum length of the prompt to be used for the LLM call in terms of the number of characters. This is useful if you want to limit the number of tokens used by the LLM or when you want to make sure that the prompt length does not exceed the maximum context length. When the maximum length is exceeded, the prompt is truncated by removing older turns from the conversation history until length of the prompt is less than or equal to the maximum length. The default maximum length is 16000 characters.

 For example, for the `generate_user_intent` task, you can specify the following:
@@ -129,7 +132,6 @@ prompts:
     max_length: 3000
 ```

-
 ### Content Template

 The content for a completion prompt or the body for a message in a chat prompt is a string that can also include variables and potentially other types of constructs. NeMo Guardrails uses [Jinja2](https://jinja.palletsprojects.com/) as the templating engine. Check out the [Jinja Synopsis](https://jinja.palletsprojects.com/en/3.1.x/templates/#synopsis) for a quick introduction.
@@ -200,7 +202,6 @@ Optionally, the output from the LLM can be parsed using an *output parser*. The
 - `bot_message`: parse the bot message, i.e., removes the "Bot message:" prefix if present;
 - `verbose_v1`: parse the output of the `verbose_v1` filter.

-
 ## Predefined Prompts

 Currently, the NeMo Guardrails toolkit includes prompts for `openai/gpt-3.5-turbo-instruct`, `openai/gpt-3.5-turbo`, `openai/gpt-4`, `databricks/dolly-v2-3b`, `cohere/command`, `cohere/command-light`, `cohere/command-light-nightly`.
@@ -232,8 +233,7 @@ prompt = llm_task_manager.render_task_prompt(
     },
 )

-with llm_params(llm, temperature=0.0):
-    check = await llm_call(llm, prompt)
+check = await llm_call(llm, prompt, llm_params={"temperature": 0.0})
 ...
 ```
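
A hedged sketch of how the documented snippet might sit inside a custom action; the @action name, the task string, and the injected parameters are illustrative assumptions, not part of this commit:

```python
# Illustrative only: "custom_check" and the context variables are placeholders;
# the guide's actual task name is elided in the hunk above.
from nemoguardrails.actions import action
from nemoguardrails.actions.llm.utils import llm_call


@action(name="custom_check")
async def custom_check(llm_task_manager, llm, context: dict) -> str:
    prompt = llm_task_manager.render_task_prompt(
        task="custom_check",  # assumed custom task defined under `prompts:`
        context={"user_input": context.get("user_message")},
    )
    # New style introduced by this commit: sampling parameters go directly
    # into llm_call instead of the llm_params context manager.
    return await llm_call(llm, prompt, llm_params={"temperature": 0.0})
```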

nemoguardrails/actions/llm/generation.py

Lines changed: 70 additions & 74 deletions
@@ -436,8 +436,9 @@ async def generate_user_intent(
         llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_USER_INTENT.value))

         # We make this call with temperature 0 to have it as deterministic as possible.
-        with llm_params(llm, temperature=self.config.lowest_temperature):
-            result = await llm_call(llm, prompt)
+        result = await llm_call(
+            llm, prompt, llm_params={"temperature": self.config.lowest_temperature}
+        )

         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(
@@ -518,17 +519,15 @@ async def generate_user_intent(
             llm_call_info_var.set(LLMCallInfo(task=Task.GENERAL.value))

             generation_options: GenerationOptions = generation_options_var.get()
-            with llm_params(
+            llm_params = (
+                generation_options and generation_options.llm_params
+            ) or {}
+            text = await llm_call(
                 llm,
-                **(
-                    (generation_options and generation_options.llm_params) or {}
-                ),
-            ):
-                text = await llm_call(
-                    llm,
-                    prompt,
-                    custom_callback_handlers=[streaming_handler_var.get()],
-                )
+                prompt,
+                custom_callback_handlers=[streaming_handler_var.get()],
+                llm_params=llm_params,
+            )
             text = self.llm_task_manager.parse_task_output(
                 Task.GENERAL, output=text
             )
@@ -558,16 +557,16 @@ async def generate_user_intent(
             )

             generation_options: GenerationOptions = generation_options_var.get()
-            with llm_params(
+            llm_params = (
+                generation_options and generation_options.llm_params
+            ) or {}
+            result = await llm_call(
                 llm,
-                **((generation_options and generation_options.llm_params) or {}),
-            ):
-                result = await llm_call(
-                    llm,
-                    prompt,
-                    custom_callback_handlers=[streaming_handler_var.get()],
-                    stop=["User:"],
-                )
+                prompt,
+                custom_callback_handlers=[streaming_handler_var.get()],
+                stop=["User:"],
+                llm_params=llm_params,
+            )

             text = self.llm_task_manager.parse_task_output(
                 Task.GENERAL, output=result
@@ -662,8 +661,9 @@ async def generate_next_step(
         llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_NEXT_STEPS.value))

         # We use temperature 0 for next step prediction as well
-        with llm_params(llm, temperature=self.config.lowest_temperature):
-            result = await llm_call(llm, prompt)
+        result = await llm_call(
+            llm, prompt, llm_params={"temperature": self.config.lowest_temperature}
+        )

         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(
@@ -924,23 +924,23 @@ async def generate_bot_message(
                 prompt = context.get("user_message")

                 generation_options: GenerationOptions = generation_options_var.get()
-                with llm_params(
+                llm_params = (
+                    generation_options and generation_options.llm_params
+                ) or {}
+                result = await llm_call(
                     llm,
-                    **(
-                        (generation_options and generation_options.llm_params) or {}
-                    ),
-                ):
-                    result = await llm_call(
-                        llm, prompt, custom_callback_handlers=[streaming_handler]
-                    )
+                    prompt,
+                    custom_callback_handlers=[streaming_handler],
+                    llm_params=llm_params,
+                )

-                    result = self.llm_task_manager.parse_task_output(
-                        Task.GENERAL, output=result
-                    )
+                result = self.llm_task_manager.parse_task_output(
+                    Task.GENERAL, output=result
+                )

-                    result = _process_parsed_output(
-                        result, self._include_reasoning_traces()
-                    )
+                result = _process_parsed_output(
+                    result, self._include_reasoning_traces()
+                )

                 log.info(
                     "--- :: LLM Bot Message Generation passthrough call took %.2f seconds",
@@ -987,13 +987,15 @@ async def generate_bot_message(
             llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_BOT_MESSAGE.value))

             generation_options: GenerationOptions = generation_options_var.get()
-            with llm_params(
+            llm_params = (
+                generation_options and generation_options.llm_params
+            ) or {}
+            result = await llm_call(
                 llm,
-                **((generation_options and generation_options.llm_params) or {}),
-            ):
-                result = await llm_call(
-                    llm, prompt, custom_callback_handlers=[streaming_handler]
-                )
+                prompt,
+                custom_callback_handlers=[streaming_handler],
+                llm_params=llm_params,
+            )

             log.info(
                 "--- :: LLM Bot Message Generation call took %.2f seconds",
@@ -1094,8 +1096,9 @@ async def generate_value(
         # Initialize the LLMCallInfo object
         llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_VALUE.value))

-        with llm_params(llm, temperature=self.config.lowest_temperature):
-            result = await llm_call(llm, prompt)
+        result = await llm_call(
+            llm, prompt, llm_params={"temperature": self.config.lowest_temperature}
+        )

         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(
@@ -1269,32 +1272,28 @@ async def generate_intent_steps_message(
                 # We buffer the content, so we can get a chance to look at the
                 # first k lines.
                 await _streaming_handler.enable_buffering()
-                with llm_params(llm, temperature=self.config.lowest_temperature):
-                    asyncio.create_task(
-                        llm_call(
-                            llm,
-                            prompt,
-                            custom_callback_handlers=[_streaming_handler],
-                            stop=["\nuser ", "\nUser "],
-                        )
+                asyncio.create_task(
+                    llm_call(
+                        llm,
+                        prompt,
+                        custom_callback_handlers=[_streaming_handler],
+                        stop=["\nuser ", "\nUser "],
+                        llm_params={"temperature": self.config.lowest_temperature},
                     )
-                    result = await _streaming_handler.wait_top_k_nonempty_lines(k=2)
+                )
+                result = await _streaming_handler.wait_top_k_nonempty_lines(k=2)

-                    # We also mark that the message is still being generated
-                    # by a streaming handler.
-                    result += (
-                        f'\nBot message: "<<STREAMING[{_streaming_handler.uid}]>>"'
-                    )
+                # We also mark that the message is still being generated
+                # by a streaming handler.
+                result += f'\nBot message: "<<STREAMING[{_streaming_handler.uid}]>>"'

-                    # Moving forward we need to set the expected pattern to correctly
-                    # parse the message.
-                    # TODO: Figure out a more generic way to deal with this.
-                    if prompt_config.output_parser == "verbose_v1":
-                        _streaming_handler.set_pattern(
-                            prefix='Bot message: "', suffix='"'
-                        )
-                    else:
-                        _streaming_handler.set_pattern(prefix=' "', suffix='"')
+                # Moving forward we need to set the expected pattern to correctly
+                # parse the message.
+                # TODO: Figure out a more generic way to deal with this.
+                if prompt_config.output_parser == "verbose_v1":
+                    _streaming_handler.set_pattern(prefix='Bot message: "', suffix='"')
+                else:
+                    _streaming_handler.set_pattern(prefix=' "', suffix='"')
             else:
                 # Initialize the LLMCallInfo object
                 llm_call_info_var.set(
@@ -1306,8 +1305,7 @@ async def generate_intent_steps_message(
                     **((generation_options and generation_options.llm_params) or {}),
                     "temperature": self.config.lowest_temperature,
                 }
-                with llm_params(llm, **additional_params):
-                    result = await llm_call(llm, prompt)
+                result = await llm_call(llm, prompt, llm_params=additional_params)

                 # Parse the output using the associated parser
                 result = self.llm_task_manager.parse_task_output(
@@ -1388,10 +1386,8 @@ async def generate_intent_steps_message(

         # We make this call with temperature 0 to have it as deterministic as possible.
         generation_options: GenerationOptions = generation_options_var.get()
-        with llm_params(
-            llm, **((generation_options and generation_options.llm_params) or {})
-        ):
-            result = await llm_call(llm, prompt)
+        llm_params = (generation_options and generation_options.llm_params) or {}
+        result = await llm_call(llm, prompt, llm_params=llm_params)

         result = self.llm_task_manager.parse_task_output(
             Task.GENERAL, output=result

nemoguardrails/actions/llm/utils.py

Lines changed: 18 additions & 1 deletion
@@ -75,11 +75,28 @@ async def llm_call(
     model_provider: Optional[str] = None,
     stop: Optional[List[str]] = None,
     custom_callback_handlers: Optional[List[AsyncCallbackHandler]] = None,
+    llm_params: Optional[dict] = None,
 ) -> str:
-    """Calls the LLM with a prompt and returns the generated text."""
+    """Calls the LLM with a prompt and returns the generated text.
+
+    Args:
+        llm: The language model instance to use
+        prompt: The prompt string or list of messages
+        model_name: Optional model name for tracking
+        model_provider: Optional model provider for tracking
+        stop: Optional list of stop tokens
+        custom_callback_handlers: Optional list of callback handlers
+        llm_params: Optional configuration dictionary to pass to the LLM (e.g., temperature, max_tokens)
+
+    Returns:
+        The generated text response
+    """
     _setup_llm_call_info(llm, model_name, model_provider)
     all_callbacks = _prepare_callbacks(custom_callback_handlers)

+    if llm_params and llm is not None:
+        llm = llm.bind(**llm_params)
+
     if isinstance(prompt, str):
         response = await _invoke_with_string_prompt(llm, prompt, all_callbacks, stop)
     else:
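
The new branch wraps the model with LangChain's Runnable.bind, so the extra kwargs are applied on every subsequent invocation without mutating the caller's llm object. A hedged end-to-end sketch (the ChatOpenAI model name and parameter values are placeholders, not part of the commit):

```python
import asyncio

from langchain_openai import ChatOpenAI  # placeholder provider for illustration

from nemoguardrails.actions.llm.utils import llm_call


async def main() -> None:
    llm = ChatOpenAI(model="gpt-4o-mini")  # assumed model name
    # llm_call binds these kwargs via llm.bind(**llm_params) before invoking,
    # matching the branch added in the hunk above.
    text = await llm_call(
        llm, "Say hello.", llm_params={"temperature": 0.0, "max_tokens": 32}
    )
    print(text)


asyncio.run(main())
```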
