5 changes: 5 additions & 0 deletions pydantic_ai_slim/pydantic_ai/__init__.py
@@ -18,6 +18,9 @@
WebSearchTool,
WebSearchUserLocation,
)
from .embeddings import (
Embedder,
)
from .exceptions import (
AgentRunError,
ApprovalRequired,
@@ -119,6 +122,8 @@
'UserPromptNode',
'capture_run_messages',
'InstrumentationSettings',
# embeddings
'Embedder',
# exceptions
'AgentRunError',
'CallDeferred',
137 changes: 137 additions & 0 deletions pydantic_ai_slim/pydantic_ai/embeddings/__init__.py
@@ -0,0 +1,137 @@
from collections.abc import Iterator, Sequence
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
from typing import Literal, overload

from typing_extensions import TypeAliasType

from pydantic_ai import _utils
from pydantic_ai.embeddings.embedding_model import EmbeddingModel
from pydantic_ai.embeddings.settings import EmbeddingSettings, merge_embedding_settings
from pydantic_ai.exceptions import UserError
from pydantic_ai.models.instrumented import InstrumentationSettings
from pydantic_ai.providers import infer_provider

KnownEmbeddingModelName = TypeAliasType(
Collaborator review comment:
Add a test like this one to verify this is up to date:
def test_known_model_names():  # pragma: lax no cover

'KnownEmbeddingModelName',
Literal[
'openai:text-embedding-ada-002',
'openai:text-embedding-3-small',
'openai:text-embedding-3-large',
'cohere:embed-v4.0',
],
)
"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].

`KnownModelName` is provided as a concise way to specify a model.
"""


def infer_model(model: EmbeddingModel | KnownEmbeddingModelName | str) -> EmbeddingModel:
"""Infer the model from the name."""
if isinstance(model, EmbeddingModel):
return model

try:
provider_name, model_name = model.split(':', maxsplit=1)
except ValueError as e:
raise ValueError('You must provide a provider prefix when specifying an embedding model name') from e

provider = infer_provider(provider_name)

model_kind = provider_name
if model_kind.startswith('gateway/'):
model_kind = provider_name.removeprefix('gateway/')

# TODO: extend the following list for other providers as appropriate
if model_kind in ('openai',):
model_kind = 'openai'

if model_kind == 'openai':
from .openai import OpenAIEmbeddingModel

return OpenAIEmbeddingModel(model_name, provider=provider)
elif model_kind == 'cohere':
from .cohere import CohereEmbeddingModel

return CohereEmbeddingModel(model_name, provider=provider)
else:
raise UserError(f'Unknown embeddings model: {model}') # pragma: no cover
Collaborator review comment:
https://github.com/ggozad/haiku.rag/tree/main/src/haiku/rag/embeddings has Ollama, vLLM and VoyageAI, which would be worth adding as well.
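
For orientation, a small usage sketch of `infer_model` as defined above; the OpenAI model name is illustrative and assumes `OPENAI_API_KEY` is set so `infer_provider('openai')` can build a client:

from pydantic_ai.embeddings import infer_model

model = infer_model('openai:text-embedding-3-small')  # resolves to OpenAIEmbeddingModel
print(model.model_name)  # 'text-embedding-3-small'

# A bare name with no provider prefix is rejected:
# infer_model('text-embedding-3-small')  # raises ValueError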



@dataclass
class Embedder:
instrument: InstrumentationSettings | bool | None
"""Options to automatically instrument with OpenTelemetry."""

def __init__(
self,
model: EmbeddingModel | KnownEmbeddingModelName | str,
*,
settings: EmbeddingSettings | None = None,
defer_model_check: bool = True,
# TODO: Figure out instrumentation later.
instrument: InstrumentationSettings | bool | None = None,
) -> None:
self._model = model if defer_model_check else infer_model(model)
self._settings = settings
self._instrument = instrument

self._override_model: ContextVar[EmbeddingModel | None] = ContextVar('_override_model', default=None)

@property
def model(self) -> EmbeddingModel | KnownEmbeddingModelName | str:
return self._model

@contextmanager
def override(
self,
*,
model: EmbeddingModel | KnownEmbeddingModelName | str | _utils.Unset = _utils.UNSET,
) -> Iterator[None]:
if _utils.is_set(model):
model_token = self._override_model.set(infer_model(model))
else:
model_token = None

try:
yield
finally:
if model_token is not None:
self._override_model.reset(model_token)

@overload
async def embed(self, documents: str, *, settings: EmbeddingSettings | None = None) -> list[float]:
pass

@overload
async def embed(self, documents: Sequence[str], *, settings: EmbeddingSettings | None = None) -> list[list[float]]:
pass

async def embed(
self, documents: str | Sequence[str], *, settings: EmbeddingSettings | None = None
) -> list[float] | list[list[float]]:
model = self._get_model()
settings = merge_embedding_settings(self._settings, settings)
return await model.embed(documents, settings=settings)

def _get_model(self) -> EmbeddingModel:
"""Create a model configured for this agent.

Returns:
The embedding model to use
"""
model_: EmbeddingModel
if some_model := self._override_model.get():
model_ = some_model
else:
model_ = self._model = infer_model(self.model)

# TODO: Port the instrumentation logic from Model once we settle on an embeddings API
# instrument = self.instrument
# if instrument is None:
# instrument = Agent._instrument_default
#
# return instrument_model(model_, instrument)

return model_
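
To show how the pieces above fit together, a minimal usage sketch of `Embedder`; the model names are illustrative, and the corresponding provider credentials and optional dependencies are assumed to be available:

import asyncio

from pydantic_ai import Embedder


async def main() -> None:
    embedder = Embedder('openai:text-embedding-3-small')

    # A single string returns one vector; a sequence returns one vector per document.
    vector = await embedder.embed('What is vector search?')
    vectors = await embedder.embed(['first document', 'second document'])
    print(len(vector), len(vectors))

    # In tests, the model can be swapped without rebuilding the Embedder.
    with embedder.override(model='cohere:embed-v4.0'):
        await embedder.embed('this call goes to Cohere instead')


asyncio.run(main())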
104 changes: 104 additions & 0 deletions pydantic_ai_slim/pydantic_ai/embeddings/cohere.py
@@ -0,0 +1,104 @@
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Literal, cast, overload

from pydantic_ai.embeddings.embedding_model import EmbeddingModel
from pydantic_ai.embeddings.settings import EmbeddingSettings
from pydantic_ai.providers import Provider, infer_provider

from .settings import merge_embedding_settings

try:
from cohere import AsyncClientV2
except ImportError as _import_error:
raise ImportError(
'Please install `cohere` to use the Cohere embeddings model, '
'you can use the `cohere` optional group — `pip install "pydantic-ai-slim[cohere]"`'
) from _import_error

LatestCohereEmbeddingModelNames = Literal[
'cohere:embed-v4.0',
# TODO: Add the others
]
"""Latest Cohere embeddings models."""

CohereEmbeddingModelName = str | LatestCohereEmbeddingModelNames
"""Possible Cohere embeddings model names."""


@dataclass(init=False)
class CohereEmbeddingModel(EmbeddingModel):
_model_name: CohereEmbeddingModelName = field(repr=False)
_provider: Provider[AsyncClientV2] = field(repr=False)

def __init__(
self,
model_name: CohereEmbeddingModelName,
*,
provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere',
settings: EmbeddingSettings | None = None,
):
"""Initialize an Cohere model.
Args:
model_name: The name of the Cohere model to use. List of model names
available [here](https://docs.cohere.com/docs/models#command).
provider: The provider to use for authentication and API access. Can be either the string
'cohere' or an instance of `Provider[AsyncClientV2]`. If not provided, a new provider will be
created using the other parameters.
profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
settings: Model-specific settings that will be used as defaults for this model.
"""
self._model_name = model_name

if isinstance(provider, str):
provider = infer_provider(provider)
self._provider = provider
self._client = provider.client

super().__init__(settings=settings)

@property
def base_url(self) -> str:
"""The base URL for the provider API, if available."""
return self._provider.base_url

@property
def model_name(self) -> CohereEmbeddingModelName:
"""The embedding model name."""
return self._model_name

@property
def system(self) -> str:
"""The embedding model provider."""
return self._provider.name

@overload
async def embed(self, documents: str, *, settings: EmbeddingSettings | None = None) -> list[float]:
pass

@overload
async def embed(self, documents: Sequence[str], *, settings: EmbeddingSettings | None = None) -> list[list[float]]:
pass

async def embed(
self, documents: str | Sequence[str], *, settings: EmbeddingSettings | None = None
) -> list[float] | list[list[float]]:
input_is_string = isinstance(documents, str)
if input_is_string:
documents = [documents]

settings = merge_embedding_settings(self._settings, settings) or {}
response = await self._client.embed(
model=self.model_name,
input_type=settings.get('input_type', 'search_document'),
texts=cast(Sequence[str], documents),
output_dimension=settings.get('output_dimension'),
)
embeddings = response.embeddings.float_
assert embeddings is not None, 'Cohere embed response unexpectedly returned no float embeddings'

if input_is_string:
return embeddings[0]

return embeddings
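
A sketch of using this model directly with the two settings keys the method reads (`input_type`, `output_dimension`); the exact shape of `EmbeddingSettings` isn't shown in this diff, so the dict literal is an assumption, as is reading the Cohere API key from the environment. The un-prefixed model name follows what `infer_model` passes after stripping the provider prefix.

from pydantic_ai.embeddings.cohere import CohereEmbeddingModel


async def embed_query(query: str) -> list[float]:
    model = CohereEmbeddingModel('embed-v4.0')
    # Per-call settings are merged over (and take precedence over) the model's defaults.
    return await model.embed(
        query,
        settings={'input_type': 'search_query', 'output_dimension': 512},
    )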
55 changes: 55 additions & 0 deletions pydantic_ai_slim/pydantic_ai/embeddings/embedding_model.py
@@ -0,0 +1,55 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import overload

from pydantic_ai.embeddings.settings import EmbeddingSettings


class EmbeddingModel(ABC):
"""Abstract class for a model."""

_settings: EmbeddingSettings | None = None

def __init__(
self,
*,
settings: EmbeddingSettings | None = None,
) -> None:
"""Initialize the model with optional settings and profile.

Args:
settings: Model-specific settings that will be used as defaults for this model.
profile: The model profile to use.
"""
self._settings = settings

@property
def settings(self) -> EmbeddingSettings | None:
"""Get the model settings."""
return self._settings

@property
@abstractmethod
def model_name(self) -> str:
"""The model name."""
raise NotImplementedError()

# TODO: Add system?

@property
def base_url(self) -> str | None:
"""The base URL for the provider API, if available."""
return None

@overload
async def embed(self, documents: str, *, settings: EmbeddingSettings | None = None) -> list[float]:
pass

@overload
async def embed(self, documents: Sequence[str], *, settings: EmbeddingSettings | None = None) -> list[list[float]]:
pass

async def embed(
self, documents: str | Sequence[str], *, settings: EmbeddingSettings | None = None
) -> list[float] | list[list[float]]:
raise NotImplementedError
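
To make the contract above concrete, a toy sketch of a custom subclass; the deterministic hashing "model" is purely illustrative (e.g. as a stand-in during tests) and not a real provider:

import hashlib
from collections.abc import Sequence

from pydantic_ai.embeddings.embedding_model import EmbeddingModel
from pydantic_ai.embeddings.settings import EmbeddingSettings


class HashEmbeddingModel(EmbeddingModel):
    """Deterministic, non-semantic embeddings; useful as a test double."""

    @property
    def model_name(self) -> str:
        return 'hash-embedding-test'

    async def embed(
        self, documents: str | Sequence[str], *, settings: EmbeddingSettings | None = None
    ) -> list[float] | list[list[float]]:
        if isinstance(documents, str):
            return self._embed_one(documents)
        return [self._embed_one(doc) for doc in documents]

    def _embed_one(self, text: str) -> list[float]:
        # Map the SHA-256 digest bytes into a fixed-size vector of floats in [0, 1].
        digest = hashlib.sha256(text.encode()).digest()
        return [byte / 255 for byte in digest[:16]]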