|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
16 | | -from nemoguardrails.actions.llm.utils import _infer_provider_from_module |
| 16 | +from nemoguardrails.actions.llm.utils import ( |
| 17 | + _extract_and_remove_think_tags, |
| 18 | + _infer_provider_from_module, |
| 19 | + _store_reasoning_traces, |
| 20 | +) |
| 21 | +from nemoguardrails.context import reasoning_trace_var |
17 | 22 |
|
18 | 23 |
|
19 | 24 | class MockOpenAILLM: |
@@ -123,3 +128,179 @@ class Wrapper3(Wrapper2): |
123 | 128 | llm = Wrapper3() |
124 | 129 | provider = _infer_provider_from_module(llm) |
125 | 130 | assert provider == "anthropic" |
| 131 | + |
| 132 | + |
| 133 | +class MockResponse: |
| 134 | + def __init__(self, content="", additional_kwargs=None): |
| 135 | + self.content = content |
| 136 | + self.additional_kwargs = additional_kwargs or {} |
| 137 | + |
| 138 | + |
| 139 | +def test_store_reasoning_traces_from_additional_kwargs(): |
| 140 | + reasoning_trace_var.set(None) |
| 141 | + |
| 142 | + response = MockResponse( |
| 143 | + content="The answer is 42", |
| 144 | + additional_kwargs={"reasoning_content": "Let me think about this..."}, |
| 145 | + ) |
| 146 | + |
| 147 | + _store_reasoning_traces(response) |
| 148 | + |
| 149 | + assert reasoning_trace_var.get() == "Let me think about this..." |
| 150 | + |
| 151 | + |
| 152 | +def test_store_reasoning_traces_from_think_tags(): |
| 153 | + reasoning_trace_var.set(None) |
| 154 | + |
| 155 | + response = MockResponse( |
| 156 | + content="<think>Let me think about this...</think>The answer is 42" |
| 157 | + ) |
| 158 | + |
| 159 | + _store_reasoning_traces(response) |
| 160 | + |
| 161 | + assert reasoning_trace_var.get() == "Let me think about this..." |
| 162 | + assert response.content == "The answer is 42" |
| 163 | + |
| 164 | + |
| 165 | +def test_store_reasoning_traces_multiline_think_tags(): |
| 166 | + reasoning_trace_var.set(None) |
| 167 | + |
| 168 | + response = MockResponse( |
| 169 | + content="<think>Step 1: Analyze the problem\nStep 2: Consider options\nStep 3: Choose solution</think>The answer is 42" |
| 170 | + ) |
| 171 | + |
| 172 | + _store_reasoning_traces(response) |
| 173 | + |
| 174 | + assert ( |
| 175 | + reasoning_trace_var.get() |
| 176 | + == "Step 1: Analyze the problem\nStep 2: Consider options\nStep 3: Choose solution" |
| 177 | + ) |
| 178 | + assert response.content == "The answer is 42" |
| 179 | + |
| 180 | + |
| 181 | +def test_store_reasoning_traces_prefers_additional_kwargs(): |
| 182 | + reasoning_trace_var.set(None) |
| 183 | + |
| 184 | + response = MockResponse( |
| 185 | + content="<think>This should not be used</think>The answer is 42", |
| 186 | + additional_kwargs={"reasoning_content": "This should be used"}, |
| 187 | + ) |
| 188 | + |
| 189 | + _store_reasoning_traces(response) |
| 190 | + |
| 191 | + assert reasoning_trace_var.get() == "This should be used" |
| 192 | + |
| 193 | + |
| 194 | +def test_store_reasoning_traces_no_reasoning_content(): |
| 195 | + reasoning_trace_var.set(None) |
| 196 | + |
| 197 | + response = MockResponse(content="The answer is 42") |
| 198 | + |
| 199 | + _store_reasoning_traces(response) |
| 200 | + |
| 201 | + assert reasoning_trace_var.get() is None |
| 202 | + |
| 203 | + |
| 204 | +def test_store_reasoning_traces_empty_reasoning_content(): |
| 205 | + reasoning_trace_var.set(None) |
| 206 | + |
| 207 | + response = MockResponse( |
| 208 | + content="The answer is 42", additional_kwargs={"reasoning_content": ""} |
| 209 | + ) |
| 210 | + |
| 211 | + _store_reasoning_traces(response) |
| 212 | + |
| 213 | + assert reasoning_trace_var.get() is None |
| 214 | + |
| 215 | + |
| 216 | +def test_store_reasoning_traces_incomplete_think_tags(): |
| 217 | + reasoning_trace_var.set(None) |
| 218 | + |
| 219 | + response = MockResponse(content="<think>This is incomplete") |
| 220 | + |
| 221 | + _store_reasoning_traces(response) |
| 222 | + |
| 223 | + assert reasoning_trace_var.get() is None |
| 224 | + |
| 225 | + |
| 226 | +def test_store_reasoning_traces_no_content_attribute(): |
| 227 | + reasoning_trace_var.set(None) |
| 228 | + |
| 229 | + class ResponseWithoutContent: |
| 230 | + def __init__(self): |
| 231 | + self.additional_kwargs = {} |
| 232 | + |
| 233 | + response = ResponseWithoutContent() |
| 234 | + |
| 235 | + _store_reasoning_traces(response) |
| 236 | + |
| 237 | + assert reasoning_trace_var.get() is None |
| 238 | + |
| 239 | + |
| 240 | +def test_store_reasoning_traces_removes_think_tags_with_whitespace(): |
| 241 | + reasoning_trace_var.set(None) |
| 242 | + |
| 243 | + response = MockResponse( |
| 244 | + content=" <think>reasoning here</think> \n\n Final answer " |
| 245 | + ) |
| 246 | + |
| 247 | + _store_reasoning_traces(response) |
| 248 | + |
| 249 | + assert reasoning_trace_var.get() == "reasoning here" |
| 250 | + assert response.content == "Final answer" |
| 251 | + |
| 252 | + |
| 253 | +def test_extract_and_remove_think_tags_basic(): |
| 254 | + response = MockResponse(content="<think>reasoning</think>answer") |
| 255 | + |
| 256 | + result = _extract_and_remove_think_tags(response) |
| 257 | + |
| 258 | + assert result == "reasoning" |
| 259 | + assert response.content == "answer" |
| 260 | + |
| 261 | + |
| 262 | +def test_extract_and_remove_think_tags_multiline(): |
| 263 | + response = MockResponse(content="<think>line1\nline2\nline3</think>final answer") |
| 264 | + |
| 265 | + result = _extract_and_remove_think_tags(response) |
| 266 | + |
| 267 | + assert result == "line1\nline2\nline3" |
| 268 | + assert response.content == "final answer" |
| 269 | + |
| 270 | + |
| 271 | +def test_extract_and_remove_think_tags_no_tags(): |
| 272 | + response = MockResponse(content="just a normal response") |
| 273 | + |
| 274 | + result = _extract_and_remove_think_tags(response) |
| 275 | + |
| 276 | + assert result is None |
| 277 | + assert response.content == "just a normal response" |
| 278 | + |
| 279 | + |
| 280 | +def test_extract_and_remove_think_tags_incomplete(): |
| 281 | + response = MockResponse(content="<think>incomplete") |
| 282 | + |
| 283 | + result = _extract_and_remove_think_tags(response) |
| 284 | + |
| 285 | + assert result is None |
| 286 | + assert response.content == "<think>incomplete" |
| 287 | + |
| 288 | + |
| 289 | +def test_extract_and_remove_think_tags_no_content_attribute(): |
| 290 | + class ResponseWithoutContent: |
| 291 | + pass |
| 292 | + |
| 293 | + response = ResponseWithoutContent() |
| 294 | + |
| 295 | + result = _extract_and_remove_think_tags(response) |
| 296 | + |
| 297 | + assert result is None |
| 298 | + |
| 299 | + |
| 300 | +def test_extract_and_remove_think_tags_wrong_order(): |
| 301 | + response = MockResponse(content="</think> text here <think>") |
| 302 | + |
| 303 | + result = _extract_and_remove_think_tags(response) |
| 304 | + |
| 305 | + assert result is None |
| 306 | + assert response.content == "</think> text here <think>" |
0 commit comments