Skip to content

Commit d241323

Browse files
Pouyanpi authored and tgasser-nv committed
fix(llm): add fallback extraction for reasoning traces from <think> tags (#1474)
Adds a compatibility layer for LLM providers that don't properly populate reasoning_content in additional_kwargs. When reasoning_content is missing, the system now falls back to extracting reasoning traces from <think>...</think> tags in the response content and removes the tags from the final output. This fixes compatibility with certain NVIDIA models (e.g., nvidia/llama-3.3-nemotron-super-49b-v1.5) in langchain-nvidia-ai-endpoints that include reasoning traces in <think> tags but fail to populate the reasoning_content field. All reasoning models using ChatNVIDIA should expose reasoning content consistently through the same interface.
1 parent 04e6cfb commit d241323

File tree

4 files changed

+352
-4
lines changed

4 files changed

+352
-4
lines changed

nemoguardrails/actions/llm/utils.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,12 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import logging
1617
import re
1718
from typing import Any, Dict, List, Optional, Sequence, Union
1819

20+
logger = logging.getLogger(__name__)
21+
1922
from langchain.base_language import BaseLanguageModel
2023
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackManager
2124
from langchain_core.runnables import RunnableConfig
@@ -238,15 +241,78 @@ def _convert_messages_to_langchain_format(prompt: List[dict]) -> List:
238241

239242

240243
def _store_reasoning_traces(response) -> None:
    """Record the response's reasoning trace in ``reasoning_trace_var``.

    The trace is looked up in ``response.additional_kwargs["reasoning_content"]``
    first. When that yields nothing, the trace is instead pulled out of
    ``<think>...</think>`` tags in ``response.content``; that fallback also
    strips the tags from the content.

    Args:
        response: The LLM response object.
    """
    # Some LLM providers (e.g., certain NVIDIA models) embed reasoning in
    # <think> tags instead of populating reasoning_content, so both sources
    # are consulted. `or` short-circuits: the tag-stripping fallback only
    # runs when the first lookup produced a falsy value.
    trace = _extract_reasoning_content(response) or _extract_and_remove_think_tags(
        response
    )

    if trace:
        reasoning_trace_var.set(trace)
264+
265+
266+
def _extract_reasoning_content(response):
    """Return ``additional_kwargs["reasoning_content"]`` from the response.

    Args:
        response: The LLM response object.

    Returns:
        The value stored under ``"reasoning_content"``, or None when the
        response has no ``additional_kwargs`` attribute, the attribute is not
        a dict, or the key is absent.
    """
    extra = getattr(response, "additional_kwargs", None)
    if not isinstance(extra, dict):
        return None
    return extra.get("reasoning_content")
275+
276+
277+
def _extract_and_remove_think_tags(response) -> Optional[str]:
    """Extract reasoning from <think> tags and remove them from `response.content`.

    Looks for ``<think>...</think>`` blocks in the response content. When at
    least one complete block is found, this function has a side effect: every
    block (tags included) is removed from ``response.content`` and the
    remaining text is stripped of surrounding whitespace.

    Args:
        response: The LLM response object.

    Returns:
        The extracted reasoning text, or None when the response has no string
        ``content``, contains no tags, or the tags are malformed (one tag
        missing, or closing tag before opening tag). When several blocks are
        present their stripped, non-empty texts are joined with newlines so
        that no reasoning removed from the content is lost.
    """
    content = getattr(response, "content", None)
    # Non-string content (e.g. None, or a list of content blocks) cannot be
    # searched with the regex below; leave it untouched instead of raising.
    if not isinstance(content, str):
        return None

    has_opening_tag = "<think>" in content
    has_closing_tag = "</think>" in content

    if not has_opening_tag and not has_closing_tag:
        return None

    if has_opening_tag != has_closing_tag:
        logger.warning(
            "Malformed <think> tags detected: missing %s tag. "
            "Skipping reasoning extraction to prevent corrupted content.",
            "closing" if has_opening_tag else "opening",
        )
        return None

    think_block = re.compile(r"<think>(.*?)</think>", re.DOTALL)
    traces = [trace.strip() for trace in think_block.findall(content)]
    if not traces:
        # Both tags exist but never as an opening/closing pair (wrong order).
        return None

    response.content = think_block.sub("", content).strip()
    # Every matched block is stripped from the content above, so return all
    # of them rather than only the first.
    return "\n".join(trace for trace in traces if trace)
250316

251317

252318
def _store_tool_calls(response) -> None:

tests/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,15 @@
2222
)
2323

2424

25+
@pytest.fixture(autouse=True)
def reset_reasoning_trace_var():
    """Clear ``reasoning_trace_var`` around every test.

    The variable is set to None both before the test body runs and again
    after it, so a trace stored by one test cannot leak into another.
    """
    from nemoguardrails.context import reasoning_trace_var as trace_var

    # Setup: start each test from a clean slate.
    trace_var.set(None)
    yield
    # Teardown: discard whatever the test stored.
    trace_var.set(None)
33+
34+
2535
def pytest_configure(config):
2636
patch("prompt_toolkit.PromptSession", autospec=True).start()

tests/test_actions_llm_utils.py

Lines changed: 182 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
from nemoguardrails.actions.llm.utils import _infer_provider_from_module
16+
from nemoguardrails.actions.llm.utils import (
17+
_extract_and_remove_think_tags,
18+
_infer_provider_from_module,
19+
_store_reasoning_traces,
20+
)
21+
from nemoguardrails.context import reasoning_trace_var
1722

1823

1924
class MockOpenAILLM:
@@ -123,3 +128,179 @@ class Wrapper3(Wrapper2):
123128
llm = Wrapper3()
124129
provider = _infer_provider_from_module(llm)
125130
assert provider == "anthropic"
131+
132+
133+
class MockResponse:
    """Minimal stand-in for an LLM response used by the tests in this file.

    Exposes only the two attributes the reasoning-trace helpers read:
    ``content`` and ``additional_kwargs``.
    """

    def __init__(self, content="", additional_kwargs=None):
        self.content = content
        # A falsy value (None or an empty mapping) is replaced by a fresh dict.
        self.additional_kwargs = additional_kwargs if additional_kwargs else {}
137+
138+
139+
def test_store_reasoning_traces_from_additional_kwargs():
    """The trace is taken from ``additional_kwargs["reasoning_content"]``."""
    reasoning_trace_var.set(None)

    reply = MockResponse(
        content="The answer is 42",
        additional_kwargs={"reasoning_content": "Let me think about this..."},
    )
    _store_reasoning_traces(reply)

    stored = reasoning_trace_var.get()
    assert stored == "Let me think about this..."
150+
151+
152+
def test_store_reasoning_traces_from_think_tags():
    """Without reasoning_content, the trace comes from <think> tags in content."""
    reasoning_trace_var.set(None)

    reply = MockResponse(
        content="<think>Let me think about this...</think>The answer is 42"
    )
    _store_reasoning_traces(reply)

    stored = reasoning_trace_var.get()
    assert stored == "Let me think about this..."
    # The tags and their body are removed from the visible content.
    assert reply.content == "The answer is 42"
163+
164+
165+
def test_store_reasoning_traces_multiline_think_tags():
    """A <think> block spanning several lines is extracted in full."""
    reasoning_trace_var.set(None)

    reply = MockResponse(
        content="<think>Step 1: Analyze the problem\nStep 2: Consider options\nStep 3: Choose solution</think>The answer is 42"
    )
    _store_reasoning_traces(reply)

    stored = reasoning_trace_var.get()
    assert (
        stored
        == "Step 1: Analyze the problem\nStep 2: Consider options\nStep 3: Choose solution"
    )
    assert reply.content == "The answer is 42"
179+
180+
181+
def test_store_reasoning_traces_prefers_additional_kwargs():
    """When both sources exist, reasoning_content wins over <think> tags."""
    reasoning_trace_var.set(None)

    reply = MockResponse(
        content="<think>This should not be used</think>The answer is 42",
        additional_kwargs={"reasoning_content": "This should be used"},
    )
    _store_reasoning_traces(reply)

    stored = reasoning_trace_var.get()
    assert stored == "This should be used"
192+
193+
194+
def test_store_reasoning_traces_no_reasoning_content():
    """A response with neither source leaves the context variable unset."""
    reasoning_trace_var.set(None)

    reply = MockResponse(content="The answer is 42")
    _store_reasoning_traces(reply)

    stored = reasoning_trace_var.get()
    assert stored is None
202+
203+
204+
def test_store_reasoning_traces_empty_reasoning_content():
    """An empty reasoning_content string is not stored as a trace."""
    reasoning_trace_var.set(None)

    reply = MockResponse(
        content="The answer is 42", additional_kwargs={"reasoning_content": ""}
    )
    _store_reasoning_traces(reply)

    stored = reasoning_trace_var.get()
    assert stored is None
214+
215+
216+
def test_store_reasoning_traces_incomplete_think_tags():
    """An opening <think> tag without a closing tag yields no trace."""
    reasoning_trace_var.set(None)

    reply = MockResponse(content="<think>This is incomplete")
    _store_reasoning_traces(reply)

    stored = reasoning_trace_var.get()
    assert stored is None
224+
225+
226+
def test_store_reasoning_traces_no_content_attribute():
    """A response object lacking ``content`` is handled without a trace."""
    reasoning_trace_var.set(None)

    class ResponseWithoutContent:
        # Only additional_kwargs is defined; there is deliberately no content.
        def __init__(self):
            self.additional_kwargs = {}

    reply = ResponseWithoutContent()
    _store_reasoning_traces(reply)

    stored = reasoning_trace_var.get()
    assert stored is None
238+
239+
240+
def test_store_reasoning_traces_removes_think_tags_with_whitespace():
    """Whitespace around the tags and the answer is stripped after removal."""
    reasoning_trace_var.set(None)

    reply = MockResponse(
        content=" <think>reasoning here</think> \n\n Final answer "
    )
    _store_reasoning_traces(reply)

    stored = reasoning_trace_var.get()
    assert stored == "reasoning here"
    assert reply.content == "Final answer"
251+
252+
253+
def test_extract_and_remove_think_tags_basic():
    """A single <think> block is returned and removed from the content."""
    reply = MockResponse(content="<think>reasoning</think>answer")

    extracted = _extract_and_remove_think_tags(reply)

    assert extracted == "reasoning"
    assert reply.content == "answer"
260+
261+
262+
def test_extract_and_remove_think_tags_multiline():
    """Newlines inside the <think> block are preserved in the result."""
    reply = MockResponse(content="<think>line1\nline2\nline3</think>final answer")

    extracted = _extract_and_remove_think_tags(reply)

    assert extracted == "line1\nline2\nline3"
    assert reply.content == "final answer"
269+
270+
271+
def test_extract_and_remove_think_tags_no_tags():
    """Content without tags returns None and is left unchanged."""
    reply = MockResponse(content="just a normal response")

    extracted = _extract_and_remove_think_tags(reply)

    assert extracted is None
    assert reply.content == "just a normal response"
278+
279+
280+
def test_extract_and_remove_think_tags_incomplete():
    """An unclosed <think> tag returns None and leaves the content intact."""
    reply = MockResponse(content="<think>incomplete")

    extracted = _extract_and_remove_think_tags(reply)

    assert extracted is None
    assert reply.content == "<think>incomplete"
287+
288+
289+
def test_extract_and_remove_think_tags_no_content_attribute():
    """An object with no ``content`` attribute returns None."""

    class ResponseWithoutContent:
        pass

    reply = ResponseWithoutContent()

    extracted = _extract_and_remove_think_tags(reply)

    assert extracted is None
298+
299+
300+
def test_extract_and_remove_think_tags_wrong_order():
    """A closing tag before the opening tag returns None, content unchanged."""
    reply = MockResponse(content="</think> text here <think>")

    extracted = _extract_and_remove_think_tags(reply)

    assert extracted is None
    assert reply.content == "</think> text here <think>"

tests/test_reasoning_trace_extraction.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,3 +304,94 @@ async def test_reasoning_content_with_other_additional_kwargs(self):
304304
assert stored_trace == test_reasoning
305305

306306
reasoning_trace_var.set(None)
307+
308+
@pytest.mark.asyncio
async def test_llm_call_extracts_reasoning_from_think_tags(self):
    """llm_call returns content without <think> tags and stores the trace."""
    from nemoguardrails.actions.llm.utils import llm_call

    test_reasoning = "Let me analyze this step by step"

    # The mocked LLM returns a message with no reasoning_content, only tags.
    reply = AIMessage(
        content=f"<think>{test_reasoning}</think>The answer is 42",
        additional_kwargs={},
    )
    llm = AsyncMock()
    llm.ainvoke = AsyncMock(return_value=reply)

    reasoning_trace_var.set(None)
    output = await llm_call(llm, "What is the answer?")

    assert output == "The answer is 42"
    assert "<think>" not in output
    trace = reasoning_trace_var.get()
    assert trace == test_reasoning

    reasoning_trace_var.set(None)
330+
331+
@pytest.mark.asyncio
async def test_llm_call_prefers_additional_kwargs_over_think_tags(self):
    """With reasoning_content present, the <think> tags stay in the output."""
    from nemoguardrails.actions.llm.utils import llm_call

    reasoning_from_kwargs = "This should be used"
    reasoning_from_tags = "This should be ignored"

    reply = AIMessage(
        content=f"<think>{reasoning_from_tags}</think>Response",
        additional_kwargs={"reasoning_content": reasoning_from_kwargs},
    )
    llm = AsyncMock()
    llm.ainvoke = AsyncMock(return_value=reply)

    reasoning_trace_var.set(None)
    output = await llm_call(llm, "Query")

    # The tag fallback is not used, so the content comes back unmodified.
    assert output == f"<think>{reasoning_from_tags}</think>Response"
    trace = reasoning_trace_var.get()
    assert trace == reasoning_from_kwargs

    reasoning_trace_var.set(None)
353+
354+
@pytest.mark.asyncio
async def test_llm_call_extracts_multiline_reasoning_from_think_tags(self):
    """A multi-line <think> block is stored whole and removed from the output."""
    from nemoguardrails.actions.llm.utils import llm_call

    multiline_reasoning = """Step 1: Understand the question
Step 2: Break down the problem
Step 3: Formulate the answer"""

    reply = AIMessage(
        content=f"<think>{multiline_reasoning}</think>Final answer",
        additional_kwargs={},
    )
    llm = AsyncMock()
    llm.ainvoke = AsyncMock(return_value=reply)

    reasoning_trace_var.set(None)
    output = await llm_call(llm, "Question")

    assert output == "Final answer"
    assert "<think>" not in output
    trace = reasoning_trace_var.get()
    assert trace == multiline_reasoning

    reasoning_trace_var.set(None)
378+
379+
@pytest.mark.asyncio
async def test_llm_call_handles_incomplete_think_tags(self):
    """An unclosed <think> tag is returned as-is and no trace is stored."""
    from nemoguardrails.actions.llm.utils import llm_call

    reply = AIMessage(
        content="<think>This is incomplete",
        additional_kwargs={},
    )
    llm = AsyncMock()
    llm.ainvoke = AsyncMock(return_value=reply)

    reasoning_trace_var.set(None)
    output = await llm_call(llm, "Query")

    assert output == "<think>This is incomplete"
    trace = reasoning_trace_var.get()
    assert trace is None

    reasoning_trace_var.set(None)

0 commit comments

Comments
 (0)