|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
| 5 | +from unittest.mock import patch |
| 6 | + |
5 | 7 | import pytest |
6 | 8 |
|
7 | 9 | from tabpfn_common_utils.telemetry.core.decorators import ( |
|
10 | 12 | get_current_extension, |
11 | 13 | _extension_context, |
12 | 14 | ) |
| 15 | +from tabpfn_common_utils.telemetry.core.events import ExtensionEntryEvent |
13 | 16 |
|
14 | 17 |
|
15 | 18 | class TestSetExtensionDecorator: |
@@ -268,3 +271,115 @@ def test_round_dims_special_cases(self) -> None: |
268 | 271 | # Test some intermediate values |
269 | 272 | assert _round_dims((1234, 67)) == (1200, 75) # 1234 -> 1200, 67 -> 75 |
270 | 273 | assert _round_dims((876, 89)) == (1000, 100) # 876 -> 1000, 89 -> 100 |
| 274 | + |
| 275 | + |
| 276 | +class TestExtensionEntryEventEmission: |
| 277 | + """Test that set_extension emits ExtensionEntryEvent correctly.""" |
| 278 | + |
| 279 | + @patch("tabpfn_common_utils.telemetry.core.decorators.capture_event") |
| 280 | + def test_single_call_emits_one_event(self, mock_capture): |
| 281 | + """A single decorated function call emits exactly one ExtensionEntryEvent.""" |
| 282 | + |
| 283 | + @set_extension("test_ext") |
| 284 | + def my_func(): |
| 285 | + return 42 |
| 286 | + |
| 287 | + result = my_func() |
| 288 | + |
| 289 | + assert result == 42 |
| 290 | + assert mock_capture.call_count == 1 |
| 291 | + event = mock_capture.call_args[0][0] |
| 292 | + assert isinstance(event, ExtensionEntryEvent) |
| 293 | + assert event.extension_name == "test_ext" |
| 294 | + |
| 295 | + @patch("tabpfn_common_utils.telemetry.core.decorators.capture_event") |
| 296 | + def test_nested_calls_emit_one_event(self, mock_capture): |
| 297 | + """Nested decorated calls emit only one event for the outermost extension.""" |
| 298 | + |
| 299 | + @set_extension("outer") |
| 300 | + def outer(): |
| 301 | + return inner() |
| 302 | + |
| 303 | + @set_extension("inner") |
| 304 | + def inner(): |
| 305 | + return get_current_extension() |
| 306 | + |
| 307 | + result = outer() |
| 308 | + |
| 309 | + # Inner should see the outer context |
| 310 | + assert result == "outer" |
| 311 | + # Only one event emitted (for the outer entry) |
| 312 | + assert mock_capture.call_count == 1 |
| 313 | + event = mock_capture.call_args[0][0] |
| 314 | + assert isinstance(event, ExtensionEntryEvent) |
| 315 | + assert event.extension_name == "outer" |
| 316 | + |
| 317 | + @patch("tabpfn_common_utils.telemetry.core.decorators.capture_event") |
| 318 | + def test_sequential_calls_emit_separate_events(self, mock_capture): |
| 319 | + """Two sequential (non-nested) calls emit two separate events.""" |
| 320 | + |
| 321 | + @set_extension("ext_a") |
| 322 | + def func_a(): |
| 323 | + return "a" |
| 324 | + |
| 325 | + @set_extension("ext_b") |
| 326 | + def func_b(): |
| 327 | + return "b" |
| 328 | + |
| 329 | + func_a() |
| 330 | + func_b() |
| 331 | + |
| 332 | + assert mock_capture.call_count == 2 |
| 333 | + assert mock_capture.call_args_list[0][0][0].extension_name == "ext_a" |
| 334 | + assert mock_capture.call_args_list[1][0][0].extension_name == "ext_b" |
| 335 | + |
| 336 | + @patch("tabpfn_common_utils.telemetry.core.decorators.capture_event") |
| 337 | + def test_class_decorator_emits_one_event_per_public_method_call(self, mock_capture): |
| 338 | + """A class decorated with set_extension emits one event per public method call. |
| 339 | + __init__ is private (starts with _) so it's not wrapped by default.""" |
| 340 | + |
| 341 | + @set_extension("cls_ext") |
| 342 | + class MyClass: |
| 343 | + def do_work(self): |
| 344 | + return "done" |
| 345 | + |
| 346 | + obj = MyClass() |
| 347 | + obj.do_work() |
| 348 | + |
| 349 | + # Only do_work is public; __init__ starts with _ so not wrapped |
| 350 | + assert mock_capture.call_count == 1 |
| 351 | + event = mock_capture.call_args[0][0] |
| 352 | + assert isinstance(event, ExtensionEntryEvent) |
| 353 | + assert event.extension_name == "cls_ext" |
| 354 | + |
| 355 | + @patch("tabpfn_common_utils.telemetry.core.decorators.capture_event") |
| 356 | + def test_class_nested_method_calls_emit_one_event(self, mock_capture): |
| 357 | + """When a class method calls another decorated function, only the outer emits.""" |
| 358 | + |
| 359 | + @set_extension("inner_ext") |
| 360 | + def helper(): |
| 361 | + return "helped" |
| 362 | + |
| 363 | + @set_extension("cls_ext") |
| 364 | + class MyClass: |
| 365 | + def do_work(self): |
| 366 | + return helper() |
| 367 | + |
| 368 | + obj = MyClass() |
| 369 | + result = obj.do_work() |
| 370 | + |
| 371 | + assert result == "helped" |
| 372 | + # do_work emits one event, helper() is nested so no event |
| 373 | + assert mock_capture.call_count == 1 |
| 374 | + assert mock_capture.call_args[0][0].extension_name == "cls_ext" |
| 375 | + |
| 376 | + @patch("posthog.Posthog.capture", side_effect=RuntimeError("PostHog down")) |
| 377 | + def test_capture_event_resilient_to_posthog_failure(self, _mock_posthog): |
| 378 | + """PostHog client failure doesn't prevent the wrapped function from running.""" |
| 379 | + |
| 380 | + @set_extension("fail_ext") |
| 381 | + def my_func(): |
| 382 | + return "still works" |
| 383 | + |
| 384 | + result = my_func() |
| 385 | + assert result == "still works" |
0 commit comments