Skip to content

Commit 59be002

Browse files
authored
Merge pull request #65 from PriorLabs/georg/eng-192-count-extension-entry-points-instead-of-model-calls
feat: `extension_entry` event that only fires once per instantiation
2 parents 9580f44 + fb600a6 commit 59be002

File tree

6 files changed

+186
-18
lines changed

6 files changed

+186
-18
lines changed

src/tabpfn_common_utils/telemetry/core/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from .events import (
66
BaseTelemetryEvent,
7+
ExtensionEntryEvent,
78
ModelLoadEvent,
89
PingEvent,
910
DatasetEvent,
@@ -24,6 +25,7 @@
2425
# Public exports
2526
__all__ = [
2627
"BaseTelemetryEvent",
28+
"ExtensionEntryEvent",
2729
"PingEvent",
2830
"DatasetEvent",
2931
"ModelLoadEvent",

src/tabpfn_common_utils/telemetry/core/decorators.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from functools import wraps
1616
from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union
1717

18-
from .events import FitEvent, PredictEvent
18+
from .events import ExtensionEntryEvent, FitEvent, PredictEvent
1919
from .service import capture_event
2020
from tabpfn_common_utils.utils import shape_of
2121

@@ -214,6 +214,7 @@ def wrapped(*args, **kwargs):
214214
if _get_context_var("tabpfn_current_extension").get() is not None:
215215
return fn(*args, **kwargs)
216216
with _extension_context(extension_name):
217+
capture_event(ExtensionEntryEvent(extension_name=extension_name))
217218
return fn(*args, **kwargs)
218219

219220
setattr(wrapped, _MARKER_ATTR, extension_name)

src/tabpfn_common_utils/telemetry/core/events.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,3 +499,21 @@ class PredictEvent(ModelCallEvent):
499499
@property
500500
def name(self) -> str:
501501
return "predict_called"
502+
503+
504+
@dataclass
505+
class ExtensionEntryEvent(BaseTelemetryEvent):
506+
"""
507+
Event emitted once per user-facing extension entry point call.
508+
509+
Unlike FitEvent/PredictEvent which fire per downstream model call,
510+
this fires exactly once when the outermost @set_extension decorator
511+
is entered, giving an unbiased count of extension usage.
512+
"""
513+
514+
# Name of the extension that was entered
515+
extension_name: str = ""
516+
517+
@property
518+
def name(self) -> str:
519+
return "extension_entry"

tests/telemetry/core/test_decorators.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
from unittest.mock import patch
6+
57
import pytest
68

79
from tabpfn_common_utils.telemetry.core.decorators import (
@@ -10,6 +12,7 @@
1012
get_current_extension,
1113
_extension_context,
1214
)
15+
from tabpfn_common_utils.telemetry.core.events import ExtensionEntryEvent
1316

1417

1518
class TestSetExtensionDecorator:
@@ -268,3 +271,115 @@ def test_round_dims_special_cases(self) -> None:
268271
# Test some intermediate values
269272
assert _round_dims((1234, 67)) == (1200, 75) # 1234 -> 1200, 67 -> 75
270273
assert _round_dims((876, 89)) == (1000, 100) # 876 -> 1000, 89 -> 100
274+
275+
276+
class TestExtensionEntryEventEmission:
277+
"""Test that set_extension emits ExtensionEntryEvent correctly."""
278+
279+
@patch("tabpfn_common_utils.telemetry.core.decorators.capture_event")
280+
def test_single_call_emits_one_event(self, mock_capture):
281+
"""A single decorated function call emits exactly one ExtensionEntryEvent."""
282+
283+
@set_extension("test_ext")
284+
def my_func():
285+
return 42
286+
287+
result = my_func()
288+
289+
assert result == 42
290+
assert mock_capture.call_count == 1
291+
event = mock_capture.call_args[0][0]
292+
assert isinstance(event, ExtensionEntryEvent)
293+
assert event.extension_name == "test_ext"
294+
295+
@patch("tabpfn_common_utils.telemetry.core.decorators.capture_event")
296+
def test_nested_calls_emit_one_event(self, mock_capture):
297+
"""Nested decorated calls emit only one event for the outermost extension."""
298+
299+
@set_extension("outer")
300+
def outer():
301+
return inner()
302+
303+
@set_extension("inner")
304+
def inner():
305+
return get_current_extension()
306+
307+
result = outer()
308+
309+
# Inner should see the outer context
310+
assert result == "outer"
311+
# Only one event emitted (for the outer entry)
312+
assert mock_capture.call_count == 1
313+
event = mock_capture.call_args[0][0]
314+
assert isinstance(event, ExtensionEntryEvent)
315+
assert event.extension_name == "outer"
316+
317+
@patch("tabpfn_common_utils.telemetry.core.decorators.capture_event")
318+
def test_sequential_calls_emit_separate_events(self, mock_capture):
319+
"""Two sequential (non-nested) calls emit two separate events."""
320+
321+
@set_extension("ext_a")
322+
def func_a():
323+
return "a"
324+
325+
@set_extension("ext_b")
326+
def func_b():
327+
return "b"
328+
329+
func_a()
330+
func_b()
331+
332+
assert mock_capture.call_count == 2
333+
assert mock_capture.call_args_list[0][0][0].extension_name == "ext_a"
334+
assert mock_capture.call_args_list[1][0][0].extension_name == "ext_b"
335+
336+
@patch("tabpfn_common_utils.telemetry.core.decorators.capture_event")
337+
def test_class_decorator_emits_one_event_per_public_method_call(self, mock_capture):
338+
"""A class decorated with set_extension emits one event per public method call.
339+
__init__ is private (starts with _) so it's not wrapped by default."""
340+
341+
@set_extension("cls_ext")
342+
class MyClass:
343+
def do_work(self):
344+
return "done"
345+
346+
obj = MyClass()
347+
obj.do_work()
348+
349+
# Only do_work is public; __init__ starts with _ so not wrapped
350+
assert mock_capture.call_count == 1
351+
event = mock_capture.call_args[0][0]
352+
assert isinstance(event, ExtensionEntryEvent)
353+
assert event.extension_name == "cls_ext"
354+
355+
@patch("tabpfn_common_utils.telemetry.core.decorators.capture_event")
356+
def test_class_nested_method_calls_emit_one_event(self, mock_capture):
357+
"""When a class method calls another decorated function, only the outer emits."""
358+
359+
@set_extension("inner_ext")
360+
def helper():
361+
return "helped"
362+
363+
@set_extension("cls_ext")
364+
class MyClass:
365+
def do_work(self):
366+
return helper()
367+
368+
obj = MyClass()
369+
result = obj.do_work()
370+
371+
assert result == "helped"
372+
# do_work emits one event, helper() is nested so no event
373+
assert mock_capture.call_count == 1
374+
assert mock_capture.call_args[0][0].extension_name == "cls_ext"
375+
376+
@patch("posthog.Posthog.capture", side_effect=RuntimeError("PostHog down"))
377+
def test_capture_event_resilient_to_posthog_failure(self, _mock_posthog):
378+
"""PostHog client failure doesn't prevent the wrapped function from running."""
379+
380+
@set_extension("fail_ext")
381+
def my_func():
382+
return "still works"
383+
384+
result = my_func()
385+
assert result == "still works"

tests/telemetry/core/test_events.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from tabpfn_common_utils.telemetry.core.events import (
66
BaseTelemetryEvent,
77
DatasetEvent,
8+
ExtensionEntryEvent,
89
FitEvent,
910
ModelLoadEvent,
1011
PingEvent,
@@ -422,6 +423,51 @@ def test_model_load_event_properties_method(self):
422423
assert "install_id" in props
423424

424425

426+
class TestExtensionEntryEvent:
427+
"""Test ExtensionEntryEvent class"""
428+
429+
def test_extension_entry_event_initialization(self):
430+
"""Test ExtensionEntryEvent initialization with extension name"""
431+
event = ExtensionEntryEvent(extension_name="post_hoc_ensembles")
432+
433+
assert event.extension_name == "post_hoc_ensembles"
434+
assert event.name == "extension_entry"
435+
436+
def test_extension_entry_event_default_extension_name(self):
437+
"""Test ExtensionEntryEvent default extension_name is empty string"""
438+
event = ExtensionEntryEvent()
439+
440+
assert event.extension_name == ""
441+
assert event.name == "extension_entry"
442+
443+
def test_extension_entry_event_inherits_base_properties(self):
444+
"""Test that ExtensionEntryEvent inherits base telemetry properties"""
445+
event = ExtensionEntryEvent(extension_name="rf_pfn")
446+
447+
assert isinstance(event.python_version, str)
448+
assert isinstance(event.tabpfn_version, str)
449+
assert isinstance(event.timestamp, datetime)
450+
assert event.source == "sdk"
451+
452+
def test_extension_entry_event_properties_method(self):
453+
"""Test ExtensionEntryEvent properties method"""
454+
event = ExtensionEntryEvent(extension_name="interpretability")
455+
456+
props = event.properties
457+
458+
assert "name" not in props
459+
assert props["extension_name"] == "interpretability"
460+
assert "python_version" in props
461+
assert "tabpfn_version" in props
462+
463+
def test_extension_entry_event_with_colon_separated_name(self):
464+
"""Test ExtensionEntryEvent with sub-extension names like unsupervised:impute"""
465+
event = ExtensionEntryEvent(extension_name="unsupervised:impute")
466+
467+
assert event.extension_name == "unsupervised:impute"
468+
assert event.name == "extension_entry"
469+
470+
425471
class TestEventIntegration:
426472
"""Integration tests for all event types"""
427473

0 commit comments

Comments
 (0)