Skip to content

Commit 22bbfd9

Browse files
feat(Profiler): Trace and display AgentInterface calls
1 parent bc7036a commit 22bbfd9

File tree

5 files changed

+286
-152
lines changed

5 files changed

+286
-152
lines changed

src/ai-bundle/config/services.php

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
namespace Symfony\Component\DependencyInjection\Loader\Configurator;
1313

14+
use Symfony\AI\Agent\AgentInterface;
1415
use Symfony\AI\Agent\StructuredOutput\AgentProcessor as StructureOutputProcessor;
1516
use Symfony\AI\Agent\StructuredOutput\ResponseFormatFactory;
1617
use Symfony\AI\Agent\StructuredOutput\ResponseFormatFactoryInterface;
@@ -23,6 +24,8 @@
2324
use Symfony\AI\AiBundle\Command\AgentCallCommand;
2425
use Symfony\AI\AiBundle\Command\PlatformInvokeCommand;
2526
use Symfony\AI\AiBundle\Profiler\DataCollector;
27+
use Symfony\AI\AiBundle\Profiler\TraceableAgent;
28+
use Symfony\AI\AiBundle\Profiler\TraceableToolbox;
2629
use Symfony\AI\AiBundle\Security\EventListener\IsGrantedToolAttributeListener;
2730
use Symfony\AI\Chat\Command\DropStoreCommand as DropMessageStoreCommand;
2831
use Symfony\AI\Chat\Command\SetupStoreCommand as SetupMessageStoreCommand;
@@ -165,10 +168,18 @@
165168
->tag('kernel.event_listener')
166169

167170
// profiler
171+
->set('ai.traceable_agent', TraceableAgent::class)
172+
->decorate(AgentInterface::class, priority: 5)
173+
->args([
174+
service('.inner'),
175+
service('ai.data_collector'),
176+
service('request_stack'),
177+
])
168178
->set('ai.data_collector', DataCollector::class)
169179
->args([
170-
tagged_iterator('ai.traceable_platform'),
171-
tagged_iterator('ai.traceable_toolbox'),
180+
tagged_iterator('ai.platform'),
181+
service('ai.toolbox.default'),
182+
tagged_iterator('ai.toolbox'),
172183
])
173184
->tag('data_collector')
174185

src/ai-bundle/src/Profiler/DataCollector.php

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,20 @@
1111

1212
namespace Symfony\AI\AiBundle\Profiler;
1313

14-
use Symfony\AI\Agent\Toolbox\ToolResult;
15-
use Symfony\AI\Platform\Metadata\Metadata;
14+
use Symfony\AI\Agent\Toolbox\ToolboxInterface;
15+
use Symfony\AI\Platform\Model;
1616
use Symfony\AI\Platform\Tool\Tool;
1717
use Symfony\Bundle\FrameworkBundle\DataCollector\AbstractDataCollector;
1818
use Symfony\Component\HttpFoundation\Request;
1919
use Symfony\Component\HttpFoundation\Response;
2020
use Symfony\Component\HttpKernel\DataCollector\LateDataCollectorInterface;
21+
use Symfony\Component\VarDumper\Cloner\Data;
2122

2223
/**
2324
* @author Christopher Hertel <[email protected]>
2425
*
2526
* @phpstan-import-type PlatformCallData from TraceablePlatform
27+
* @phpstan-import-type ToolCallData from TraceableToolbox
2628
*/
2729
final class DataCollector extends AbstractDataCollector implements LateDataCollectorInterface
2830
{
@@ -37,11 +39,17 @@ final class DataCollector extends AbstractDataCollector implements LateDataColle
3739
private readonly array $toolboxes;
3840

3941
/**
40-
* @param TraceablePlatform[] $platforms
41-
* @param TraceableToolbox[] $toolboxes
42+
* @var list<array{method: string, duration: float, input: mixed, result: mixed, error: ?\Throwable}>
43+
*/
44+
private array $collectedChatCalls = [];
45+
46+
/**
47+
* @param iterable<TraceablePlatform> $platforms
48+
* @param iterable<TraceableToolbox> $toolboxes
4249
*/
4350
public function __construct(
4451
iterable $platforms,
52+
private readonly ToolboxInterface $defaultToolBox,
4553
iterable $toolboxes,
4654
) {
4755
$this->platforms = $platforms instanceof \Traversable ? iterator_to_array($platforms) : $platforms;
@@ -50,15 +58,26 @@ public function __construct(
5058

5159
public function collect(Request $request, Response $response, ?\Throwable $exception = null): void
5260
{
53-
$this->lateCollect();
5461
}
5562

5663
public function lateCollect(): void
5764
{
5865
$this->data = [
59-
'tools' => $this->getAllTools(),
66+
'tools' => $this->defaultToolBox->getTools(),
6067
'platform_calls' => array_merge(...array_map($this->awaitCallResults(...), $this->platforms)),
6168
'tool_calls' => array_merge(...array_map(fn (TraceableToolbox $toolbox) => $toolbox->calls, $this->toolboxes)),
69+
'chat_calls' => $this->cloneVar($this->collectedChatCalls),
70+
];
71+
}
72+
73+
public function collectChatCall(string $method, float $duration, mixed $input, mixed $result, ?\Throwable $error): void
74+
{
75+
$this->collectedChatCalls[] = [
76+
'method' => $method,
77+
'duration' => $duration,
78+
'input' => $input,
79+
'result' => $result,
80+
'error' => $error,
6281
];
6382
}
6483

@@ -84,44 +103,54 @@ public function getTools(): array
84103
}
85104

86105
/**
87-
* @return ToolResult[]
106+
* @return ToolCallData[]
88107
*/
89108
public function getToolCalls(): array
90109
{
91110
return $this->data['tool_calls'] ?? [];
92111
}
93112

94113
/**
95-
* @return Tool[]
114+
* @return list<array{method: string, duration: float, input: mixed, result: mixed, error: ?\Throwable}>
96115
*/
97-
private function getAllTools(): array
116+
public function getChatCalls(): array
98117
{
99-
return array_merge(...array_map(fn (TraceableToolbox $toolbox) => $toolbox->getTools(), $this->toolboxes));
118+
if (!isset($this->data['chat_calls'])) {
119+
return [];
120+
}
121+
122+
$chatCalls = $this->data['chat_calls']->getValue(true);
123+
124+
/** @var list<array{method: string, duration: float, input: mixed, result: mixed, error: ?\Throwable}> $chatCalls */
125+
return $chatCalls;
126+
}
127+
128+
public function reset(): void
129+
{
130+
$this->data = [];
131+
$this->collectedChatCalls = [];
100132
}
101133

102134
/**
103135
* @return array{
104-
* model: string,
105-
* input: array<mixed>|string|object,
106-
* options: array<string, mixed>,
107-
* result: string|iterable<mixed>|object|null,
108-
* metadata: Metadata,
136+
* model: Model,
137+
* input: array<mixed>|string|object,
138+
* options: array<string, mixed>,
139+
* result: string|iterable<mixed>|object|null
109140
* }[]
110141
*/
111142
private function awaitCallResults(TraceablePlatform $platform): array
112143
{
113144
$calls = $platform->calls;
114145
foreach ($calls as $key => $call) {
115-
$result = $call['result']->getResult();
146+
$result = $call['result'];
116147

117148
if (isset($platform->resultCache[$result])) {
118149
$call['result'] = $platform->resultCache[$result];
119150
} else {
120-
$call['result'] = $result->getContent();
151+
$call['result'] = $result->asText();
121152
}
122153

123-
$call['metadata'] = $result->getMetadata();
124-
125154
$calls[$key] = $call;
126155
}
127156

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
<?php
2+
3+
namespace Symfony\AI\AiBundle\Profiler;
4+
5+
use Symfony\AI\Agent\AgentInterface;
6+
use Symfony\AI\Platform\Message\MessageBag;
7+
use Symfony\AI\Platform\Result\ResultInterface;
8+
use Symfony\Component\HttpFoundation\RequestStack;
9+
use Symfony\Contracts\Service\ResetInterface;
10+
11+
final class TraceableAgent implements AgentInterface, ResetInterface
12+
{
13+
public function __construct(
14+
private readonly AgentInterface $decorated,
15+
private readonly DataCollector $collector,
16+
private readonly RequestStack $requestStack,
17+
) {
18+
}
19+
20+
public function call(MessageBag $messages, array $options = []): ResultInterface
21+
{
22+
$startTime = microtime(true);
23+
$error = null;
24+
$response = null;
25+
26+
try {
27+
return $response = $this->decorated->call($messages, $options);
28+
} catch (\Throwable $e) {
29+
$error = $e;
30+
throw $e;
31+
} finally {
32+
if ($this->requestStack->getMainRequest() === $this->requestStack->getCurrentRequest()) {
33+
$this->collector->collectChatCall(
34+
'call',
35+
microtime(true) - $startTime,
36+
$messages,
37+
$response,
38+
$error
39+
);
40+
}
41+
}
42+
}
43+
44+
public function getName(): string
45+
{
46+
return $this->decorated->getName();
47+
}
48+
49+
public function reset(): void
50+
{
51+
if ($this->decorated instanceof ResetInterface) {
52+
$this->decorated->reset();
53+
}
54+
}
55+
}

0 commit comments

Comments
 (0)