Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/ai-bundle/config/services.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

namespace Symfony\Component\DependencyInjection\Loader\Configurator;

use Symfony\AI\Agent\AgentInterface;
use Symfony\AI\Agent\StructuredOutput\AgentProcessor as StructureOutputProcessor;
use Symfony\AI\Agent\StructuredOutput\ResponseFormatFactory;
use Symfony\AI\Agent\StructuredOutput\ResponseFormatFactoryInterface;
Expand All @@ -23,6 +24,8 @@
use Symfony\AI\AiBundle\Command\AgentCallCommand;
use Symfony\AI\AiBundle\Command\PlatformInvokeCommand;
use Symfony\AI\AiBundle\Profiler\DataCollector;
use Symfony\AI\AiBundle\Profiler\TraceableAgent;
use Symfony\AI\AiBundle\Profiler\TraceableToolbox;
use Symfony\AI\AiBundle\Security\EventListener\IsGrantedToolAttributeListener;
use Symfony\AI\Chat\Command\DropStoreCommand as DropMessageStoreCommand;
use Symfony\AI\Chat\Command\SetupStoreCommand as SetupMessageStoreCommand;
Expand Down Expand Up @@ -165,10 +168,18 @@
->tag('kernel.event_listener')

// profiler
->set('ai.traceable_agent', TraceableAgent::class)
->decorate(AgentInterface::class, priority: 5)
->args([
service('.inner'),
service('ai.data_collector'),
service('request_stack'),
])
->set('ai.data_collector', DataCollector::class)
->args([
tagged_iterator('ai.traceable_platform'),
tagged_iterator('ai.traceable_toolbox'),
tagged_iterator('ai.platform'),
service('ai.toolbox.default'),
tagged_iterator('ai.toolbox'),
])
->tag('data_collector')

Expand Down
67 changes: 48 additions & 19 deletions src/ai-bundle/src/Profiler/DataCollector.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,20 @@

namespace Symfony\AI\AiBundle\Profiler;

use Symfony\AI\Agent\Toolbox\ToolResult;
use Symfony\AI\Platform\Metadata\Metadata;
use Symfony\AI\Agent\Toolbox\ToolboxInterface;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\Tool\Tool;
use Symfony\Bundle\FrameworkBundle\DataCollector\AbstractDataCollector;
use Symfony\Component\HttpFoundation\Request;
use Symfony\Component\HttpFoundation\Response;
use Symfony\Component\HttpKernel\DataCollector\LateDataCollectorInterface;
use Symfony\Component\VarDumper\Cloner\Data;

/**
* @author Christopher Hertel <[email protected]>
*
* @phpstan-import-type PlatformCallData from TraceablePlatform
* @phpstan-import-type ToolCallData from TraceableToolbox
*/
final class DataCollector extends AbstractDataCollector implements LateDataCollectorInterface
{
Expand All @@ -37,11 +39,17 @@ final class DataCollector extends AbstractDataCollector implements LateDataColle
private readonly array $toolboxes;

/**
* @param TraceablePlatform[] $platforms
* @param TraceableToolbox[] $toolboxes
* @var list<array{method: string, duration: float, input: mixed, result: mixed, error: ?\Throwable}>
*/
private array $collectedChatCalls = [];

/**
* @param iterable<TraceablePlatform> $platforms
* @param iterable<TraceableToolbox> $toolboxes
*/
public function __construct(
iterable $platforms,
private readonly ToolboxInterface $defaultToolBox,
iterable $toolboxes,
) {
$this->platforms = $platforms instanceof \Traversable ? iterator_to_array($platforms) : $platforms;
Expand All @@ -50,15 +58,26 @@ public function __construct(

public function collect(Request $request, Response $response, ?\Throwable $exception = null): void
{
$this->lateCollect();
}

public function lateCollect(): void
{
$this->data = [
'tools' => $this->getAllTools(),
'tools' => $this->defaultToolBox->getTools(),
'platform_calls' => array_merge(...array_map($this->awaitCallResults(...), $this->platforms)),
'tool_calls' => array_merge(...array_map(fn (TraceableToolbox $toolbox) => $toolbox->calls, $this->toolboxes)),
'chat_calls' => $this->cloneVar($this->collectedChatCalls),
];
}

public function collectChatCall(string $method, float $duration, mixed $input, mixed $result, ?\Throwable $error): void
{
$this->collectedChatCalls[] = [
'method' => $method,
'duration' => $duration,
'input' => $input,
'result' => $result,
'error' => $error,
];
}

Expand All @@ -84,44 +103,54 @@ public function getTools(): array
}

/**
* @return ToolResult[]
* @return ToolCallData[]
*/
public function getToolCalls(): array
{
return $this->data['tool_calls'] ?? [];
}

/**
* @return Tool[]
* @return list<array{method: string, duration: float, input: mixed, result: mixed, error: ?\Throwable}>
*/
private function getAllTools(): array
public function getChatCalls(): array
{
return array_merge(...array_map(fn (TraceableToolbox $toolbox) => $toolbox->getTools(), $this->toolboxes));
if (!isset($this->data['chat_calls'])) {
return [];
}

$chatCalls = $this->data['chat_calls']->getValue(true);

/** @var list<array{method: string, duration: float, input: mixed, result: mixed, error: ?\Throwable}> $chatCalls */
return $chatCalls;
}

public function reset(): void
{
$this->data = [];
$this->collectedChatCalls = [];
}

/**
* @return array{
* model: string,
* input: array<mixed>|string|object,
* options: array<string, mixed>,
* result: string|iterable<mixed>|object|null,
* metadata: Metadata,
* model: Model,
* input: array<mixed>|string|object,
* options: array<string, mixed>,
* result: string|iterable<mixed>|object|null
* }[]
*/
private function awaitCallResults(TraceablePlatform $platform): array
{
$calls = $platform->calls;
foreach ($calls as $key => $call) {
$result = $call['result']->getResult();
$result = $call['result'];

if (isset($platform->resultCache[$result])) {
$call['result'] = $platform->resultCache[$result];
} else {
$call['result'] = $result->getContent();
$call['result'] = $result->asText();
}

$call['metadata'] = $result->getMetadata();

$calls[$key] = $call;
}

Expand Down
55 changes: 55 additions & 0 deletions src/ai-bundle/src/Profiler/TraceableAgent.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
<?php

namespace Symfony\AI\AiBundle\Profiler;

use Symfony\AI\Agent\AgentInterface;
use Symfony\AI\Platform\Message\MessageBag;
use Symfony\AI\Platform\Result\ResultInterface;
use Symfony\Component\HttpFoundation\RequestStack;
use Symfony\Contracts\Service\ResetInterface;

final class TraceableAgent implements AgentInterface, ResetInterface
{
public function __construct(
private readonly AgentInterface $decorated,
private readonly DataCollector $collector,
private readonly RequestStack $requestStack,
) {
}

public function call(MessageBag $messages, array $options = []): ResultInterface
{
$startTime = microtime(true);
$error = null;
$response = null;

try {
return $response = $this->decorated->call($messages, $options);
} catch (\Throwable $e) {
$error = $e;
throw $e;
} finally {
if ($this->requestStack->getMainRequest() === $this->requestStack->getCurrentRequest()) {
$this->collector->collectChatCall(
'call',
microtime(true) - $startTime,
$messages,
$response,
$error
);
}
}
}

public function getName(): string
{
return $this->decorated->getName();
}

public function reset(): void
{
if ($this->decorated instanceof ResetInterface) {
$this->decorated->reset();
}
}
}
Loading
Loading