Skip to content

Commit abc8bed

Browse files
authored
feat: capture ExceptionGroup sub-exceptions in traceback (#20)
* feat: capture ExceptionGroup sub-exceptions in traceback * fix pylint
1 parent a705350 commit abc8bed

2 files changed

Lines changed: 82 additions & 5 deletions

File tree

py/src/braintrust/logger.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4395,11 +4395,7 @@ def log_exc_info_to_span(
43954395

43964396

43974397
def stringify_exception(exc_type: type[BaseException], exc_value: BaseException, tb: TracebackType | None) -> str:
4398-
return "".join(
4399-
traceback.format_exception_only(exc_type, exc_value)
4400-
+ ["\nTraceback (most recent call last):\n"]
4401-
+ traceback.format_tb(tb)
4402-
)
4398+
return "".join(traceback.format_exception(exc_type, exc_value, tb))
44034399

44044400

44054401
def _strip_nones(d: T, deep: bool) -> T:

py/src/braintrust/test_logger.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from unittest.mock import MagicMock, patch
1111

1212
import braintrust
13+
import exceptiongroup
1314
import pytest
1415
from braintrust import (
1516
Attachment,
@@ -27,6 +28,7 @@
2728
parent_context,
2829
render_message,
2930
render_mustache,
31+
stringify_exception,
3032
)
3133
from braintrust.prompt import PromptChatBlock, PromptData, PromptMessage, PromptSchema
3234
from braintrust.test_helpers import (
@@ -3318,3 +3320,82 @@ def my_traced_function():
33183320
my_traced_function()
33193321

33203322
assert captured_name == "my_traced_function"
3323+
3324+
3325+
def _raise_test_exception_group():
3326+
"""Raise and return a standard ExceptionGroup with a traceback for testing."""
3327+
try:
3328+
raise exceptiongroup.ExceptionGroup(
3329+
"Multiple failures",
3330+
[
3331+
ConnectionRefusedError("[Errno 61] Connection refused"),
3332+
ValueError("Invalid configuration"),
3333+
],
3334+
)
3335+
except exceptiongroup.ExceptionGroup as eg:
3336+
return eg
3337+
3338+
3339+
def _assert_test_exception_group_contents(error_str):
3340+
"""Assert that error_str contains the expected sub-exception details."""
3341+
assert "ExceptionGroup: Multiple failures" in error_str
3342+
assert "ConnectionRefusedError" in error_str
3343+
assert "[Errno 61] Connection refused" in error_str
3344+
assert "ValueError" in error_str
3345+
assert "Invalid configuration" in error_str
3346+
3347+
3348+
def test_stringify_exception_with_exception_group():
3349+
eg = _raise_test_exception_group()
3350+
result = stringify_exception(type(eg), eg, eg.__traceback__)
3351+
_assert_test_exception_group_contents(result)
3352+
assert "(2 sub-exceptions)" in result
3353+
3354+
3355+
def test_stringify_exception_with_nested_exception_group():
3356+
result = ""
3357+
try:
3358+
inner = exceptiongroup.ExceptionGroup("inner", [TypeError("bad type")])
3359+
raise exceptiongroup.ExceptionGroup(
3360+
"outer",
3361+
[inner, RuntimeError("top-level error")],
3362+
)
3363+
except exceptiongroup.ExceptionGroup as eg:
3364+
result = stringify_exception(type(eg), eg, eg.__traceback__)
3365+
3366+
assert result, "ExceptionGroup was not raised"
3367+
assert "outer" in result
3368+
assert "inner" in result
3369+
assert "TypeError" in result
3370+
assert "bad type" in result
3371+
assert "RuntimeError" in result
3372+
assert "top-level error" in result
3373+
3374+
3375+
def test_span_exit_logs_exception_group_sub_exceptions(with_memory_logger):
3376+
"""Verify sub-exceptions are captured when an ExceptionGroup propagates through span.__exit__."""
3377+
init_test_logger(__name__)
3378+
3379+
with pytest.raises(exceptiongroup.ExceptionGroup):
3380+
with braintrust.current_logger().start_span(name="eg-span"):
3381+
raise _raise_test_exception_group()
3382+
3383+
logs = with_memory_logger.pop()
3384+
assert len(logs) == 1
3385+
_assert_test_exception_group_contents(logs[0].get("error", ""))
3386+
3387+
3388+
def test_traced_logs_exception_group_sub_exceptions(with_memory_logger):
3389+
"""Verify sub-exceptions are captured when an ExceptionGroup propagates through @traced."""
3390+
init_test_logger(__name__)
3391+
3392+
@logger.traced
3393+
def failing_function():
3394+
raise _raise_test_exception_group()
3395+
3396+
with pytest.raises(exceptiongroup.ExceptionGroup):
3397+
failing_function()
3398+
3399+
logs = with_memory_logger.pop()
3400+
assert len(logs) == 1
3401+
_assert_test_exception_group_contents(logs[0].get("error", ""))

0 commit comments

Comments
 (0)