diff --git a/src/DependencyInjection/Configuration.php b/src/DependencyInjection/Configuration.php index c5e2454..c4a0dfd 100644 --- a/src/DependencyInjection/Configuration.php +++ b/src/DependencyInjection/Configuration.php @@ -197,7 +197,7 @@ public function getConfigTreeBuilder(): TreeBuilder ->end() ->arrayNode('model') ->children() - ->scalarNode('name')->isRequired()->end() + ->scalarNode('className')->isRequired()->end() ->scalarNode('version')->defaultNull()->end() ->arrayNode('options') ->scalarPrototype()->end() diff --git a/src/DependencyInjection/LlmChainExtension.php b/src/DependencyInjection/LlmChainExtension.php index 03a8678..b465653 100644 --- a/src/DependencyInjection/LlmChainExtension.php +++ b/src/DependencyInjection/LlmChainExtension.php @@ -25,11 +25,9 @@ use PhpLlm\LlmChain\Platform\Bridge\Meta\Llama; use PhpLlm\LlmChain\Platform\Bridge\Mistral\Mistral; use PhpLlm\LlmChain\Platform\Bridge\Mistral\PlatformFactory as MistralPlatformFactory; -use PhpLlm\LlmChain\Platform\Bridge\OpenAI\Embeddings; use PhpLlm\LlmChain\Platform\Bridge\OpenAI\GPT; use PhpLlm\LlmChain\Platform\Bridge\OpenAI\PlatformFactory as OpenAIPlatformFactory; use PhpLlm\LlmChain\Platform\Bridge\OpenRouter\PlatformFactory as OpenRouterPlatformFactory; -use PhpLlm\LlmChain\Platform\Bridge\Voyage\Voyage; use PhpLlm\LlmChain\Platform\Model; use PhpLlm\LlmChain\Platform\ModelClientInterface; use PhpLlm\LlmChain\Platform\Platform; @@ -457,14 +455,13 @@ private function processStoreConfig(string $type, array $stores, ContainerBuilde */ private function processEmbedderConfig(int|string $name, array $config, ContainerBuilder $container): void { - ['name' => $modelName, 'version' => $version, 'options' => $options] = $config['model']; + ['className' => $modelClassName, 'version' => $version, 'options' => $options] = $config['model']; - $modelClass = match (strtolower((string) $modelName)) { - 'embeddings' => Embeddings::class, - 'voyage' => Voyage::class, - default => throw new \InvalidArgumentException(sprintf('Model "%s" is not supported.', $modelName)), - }; - $modelDefinition = (new Definition($modelClass)); + if (!is_a($modelClassName, Model::class, true)) { + throw new \InvalidArgumentException(sprintf('"%s" class is not extending PhpLlm\LlmChain\Platform\Model.', $modelClassName)); + } + + $modelDefinition = (new Definition((string) $modelClassName)); if (null !== $version) { $modelDefinition->setArgument('$name', $version); }