From 4633f687b68b7c03285829a0384b59664b14894a Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Sun, 15 Jun 2025 12:41:32 +0200 Subject: [PATCH 1/3] style: apply all ruff rules --- pyproject.toml | 4 +- src/raglite/__init__.py | 2 +- src/raglite/_bench.py | 23 +++++++---- src/raglite/_chainlit.py | 6 ++- src/raglite/_chatml_function_calling.py | 55 +++++++++++++++++-------- src/raglite/_cli.py | 25 +++++++---- src/raglite/_config.py | 11 +++-- src/raglite/_database.py | 36 +++++++++------- src/raglite/_embed.py | 35 +++++++++++----- src/raglite/_eval.py | 30 +++++++++----- src/raglite/_extract.py | 4 +- src/raglite/_insert.py | 32 +++++++++----- src/raglite/_lazy_llama.py | 6 +-- src/raglite/_litellm.py | 10 +++-- src/raglite/_markdown.py | 5 ++- src/raglite/_mcp.py | 4 +- src/raglite/_query_adapter.py | 16 ++++--- src/raglite/_rag.py | 20 +++++---- src/raglite/_search.py | 38 +++++++++++------ src/raglite/_split_chunklets.py | 3 +- src/raglite/_split_chunks.py | 3 +- src/raglite/_split_sentences.py | 6 ++- src/raglite/_typing.py | 35 ++++++++++++---- tests/test_chatml_function_calling.py | 15 ++++--- tests/test_extract.py | 8 +++- tests/test_insert.py | 2 +- tests/test_lazy_llama.py | 2 +- tests/test_rerank.py | 4 +- tests/test_search.py | 3 +- tests/test_split_sentences.py | 8 +++- 30 files changed, 305 insertions(+), 146 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f354425..618e6a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -144,8 +144,8 @@ target-version = "py310" docstring-code-format = true [tool.ruff.lint] -select = ["A", "ASYNC", "B", "BLE", "C4", "C90", "D", "DTZ", "E", "EM", "ERA", "F", "FBT", "FLY", "FURB", "G", "I", "ICN", "INP", "INT", "ISC", "LOG", "N", "NPY", "PERF", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "Q", "RET", "RSE", "RUF", "S", "SIM", "SLF", "SLOT", "T10", "T20", "TCH", "TID", "TRY", "UP", "W", "YTT"] -ignore = ["D203", "D213", "E501", "RET504", "RUF002", "RUF022", "S101", "S307", "TC004"] +select = ["ALL"] +ignore = ["CPY", "FIX", "ARG001", "COM812", "D203", "D213", "E501", "PD008", "PD009", "RET504", "S101", "TD003"] unfixable = ["ERA001", "F401", "F841", "T201", "T203"] [tool.ruff.lint.flake8-tidy-imports] diff --git a/src/raglite/__init__.py b/src/raglite/__init__.py index df2e30d..a0b9995 100644 --- a/src/raglite/__init__.py +++ b/src/raglite/__init__.py @@ -17,7 +17,7 @@ vector_search, ) -__all__ = [ +__all__ = [ # noqa: RUF022 # Config "RAGLiteConfig", # Insert diff --git a/src/raglite/_bench.py b/src/raglite/_bench.py index 3356c3c..9026141 100644 --- a/src/raglite/_bench.py +++ b/src/raglite/_bench.py @@ -94,7 +94,7 @@ def __init__( insert_variant: str | None = None, search_variant: str | None = None, config: RAGLiteConfig | None = None, - ): + ) -> None: super().__init__( dataset, num_results=num_results, @@ -145,7 +145,7 @@ def __init__( num_results: int = 10, insert_variant: str | None = None, search_variant: str | None = None, - ): + ) -> None: super().__init__( dataset, num_results=num_results, @@ -156,7 +156,7 @@ def __init__( self.embedder_dim = 3072 self.persist_path = self.cwd / self.insert_id - def insert_documents(self, max_workers: int | None = None) -> None: + def insert_documents(self, max_workers: int | None = None) -> None: # noqa: ARG002 # Adapted from https://docs.llamaindex.ai/en/stable/examples/vector_stores/FaissIndexDemo/. 
import faiss from llama_index.core import Document, StorageContext, VectorStoreIndex @@ -178,14 +178,15 @@ def insert_documents(self, max_workers: int | None = None) -> None: index.storage_context.persist(persist_dir=self.persist_path) @cached_property - def index(self) -> Any: + def index(self) -> Any: # noqa: ANN401 from llama_index.core import StorageContext, load_index_from_storage from llama_index.embeddings.openai import OpenAIEmbedding from llama_index.vector_stores.faiss import FaissVectorStore vector_store = FaissVectorStore.from_persist_dir(persist_dir=self.persist_path.as_posix()) storage_context = StorageContext.from_defaults( - vector_store=vector_store, persist_dir=self.persist_path.as_posix() + vector_store=vector_store, + persist_dir=self.persist_path.as_posix(), ) embed_model = OpenAIEmbedding(model=self.embedder, dimensions=self.embedder_dim) index = load_index_from_storage(storage_context, embed_model=embed_model) @@ -215,7 +216,7 @@ def __init__( num_results: int = 10, insert_variant: str | None = None, search_variant: str | None = None, - ): + ) -> None: super().__init__( dataset, num_results=num_results, @@ -227,7 +228,7 @@ def __init__( ) @cached_property - def client(self) -> Any: + def client(self) -> Any: # noqa: ANN401 import openai return openai.OpenAI() @@ -269,7 +270,9 @@ def insert_documents(self, max_workers: int | None = None) -> None: files.append(temp_file.open("rb")) if len(files) == max_files_per_batch or (i == self.dataset.docs_count() - 1): self.client.vector_stores.file_batches.upload_and_poll( - vector_store_id=vector_store.id, files=files, max_concurrency=max_workers + vector_store_id=vector_store.id, + files=files, + max_concurrency=max_workers, ) for f in files: f.close() @@ -283,7 +286,9 @@ def search(self, query_id: str, query: str, *, num_results: int = 10) -> list[Sc if not self.vector_store_id: return [] response = self.client.vector_stores.search( - vector_store_id=self.vector_store_id, query=query, max_num_results=2 * num_results + vector_store_id=self.vector_store_id, + query=query, + max_num_results=2 * num_results, ) scored_docs = [ ScoredDoc( diff --git a/src/raglite/_chainlit.py b/src/raglite/_chainlit.py index 8c94f52..402d6f1 100644 --- a/src/raglite/_chainlit.py +++ b/src/raglite/_chainlit.py @@ -39,7 +39,7 @@ async def start_chat() -> None: TextInput(id="llm", label="LLM", initial=config.llm), TextInput(id="embedder", label="Embedder", initial=config.embedder), Switch(id="vector_search_query_adapter", label="Query adapter", initial=True), - ] + ], ).send() await update_config(settings) @@ -95,7 +95,9 @@ async def handle_message(user_message: cl.Message) -> None: messages: list[dict[str, str]] = cl.chat_context.to_openai()[:-1] # type: ignore[no-untyped-call] messages.append({"role": "user", "content": user_prompt}) async for token in async_rag( - messages, on_retrieval=lambda x: chunk_spans.extend(x), config=config + messages, + on_retrieval=lambda x: chunk_spans.extend(x), + config=config, ): await assistant_message.stream_token(token) # Append RAG sources, if any. 
diff --git a/src/raglite/_chatml_function_calling.py b/src/raglite/_chatml_function_calling.py index dcbaa4d..bed4bda 100644 --- a/src/raglite/_chatml_function_calling.py +++ b/src/raglite/_chatml_function_calling.py @@ -98,9 +98,9 @@ def _convert_chunks_to_completion( { "text": text, "index": 0, - "logprobs": logprobs, # TODO: Improve accumulation of logprobs + "logprobs": logprobs, # TODO(lsorber): Improve accumulation of logprobs "finish_reason": finish_reason, # type: ignore[typeddict-item] - } + }, ], } # Add usage section if present in the chunks @@ -131,7 +131,8 @@ def _stream_tool_calls( prompt += f"functions.{tool_name}:\n" try: grammar = llama_grammar.LlamaGrammar.from_json_schema( - json.dumps(tool["function"]["parameters"]), verbose=llama.verbose + json.dumps(tool["function"]["parameters"]), + verbose=llama.verbose, ) except Exception as e: warnings.warn( @@ -140,7 +141,8 @@ def _stream_tool_calls( stacklevel=2, ) grammar = llama_grammar.LlamaGrammar.from_string( - llama_grammar.JSON_GBNF, verbose=llama.verbose + llama_grammar.JSON_GBNF, + verbose=llama.verbose, ) completion_or_chunks = llama.create_completion( prompt=prompt, @@ -182,7 +184,8 @@ def _stream_tool_calls( "stop": [*completion_kwargs["stop"], ":", ""], "max_tokens": None, "grammar": llama_grammar.LlamaGrammar.from_string( - follow_up_gbnf_tool_grammar, verbose=llama.verbose + follow_up_gbnf_tool_grammar, + verbose=llama.verbose, ), }, ), @@ -253,7 +256,7 @@ def chatml_function_calling_with_streaming( grammar: Optional[llama.LlamaGrammar] = None, # type: ignore[name-defined] logprobs: Optional[bool] = None, top_logprobs: Optional[int] = None, - **kwargs: Any, + **kwargs: Any, # noqa: ANN401 ) -> Union[ llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse], @@ -381,7 +384,10 @@ def chatml_function_calling_with_streaming( or len(tools) == 0 ): prompt = template_renderer.render( - messages=messages, tools=[], tool_calls=None, add_generation_prompt=True + messages=messages, + tools=[], + tool_calls=None, + add_generation_prompt=True, ) return llama_chat_format._convert_completion_to_chat( # noqa: SLF001 llama.create_completion( @@ -404,7 +410,10 @@ def chatml_function_calling_with_streaming( assert tools function_names = " | ".join([f'''"functions.{t["function"]["name"]}:"''' for t in tools]) prompt = template_renderer.render( - messages=messages, tools=tools, tool_calls=True, add_generation_prompt=True + messages=messages, + tools=tools, + tool_calls=True, + add_generation_prompt=True, ) initial_gbnf_tool_grammar = ( ( @@ -429,7 +438,8 @@ def chatml_function_calling_with_streaming( "stream": False, "max_tokens": None, "grammar": llama_grammar.LlamaGrammar.from_string( - initial_gbnf_tool_grammar, verbose=llama.verbose + initial_gbnf_tool_grammar, + verbose=llama.verbose, ), }, ), @@ -449,7 +459,10 @@ def chatml_function_calling_with_streaming( # Case 2 step 2A: Respond with a message if tool_name is None: prompt = template_renderer.render( - messages=messages, tools=[], tool_calls=None, add_generation_prompt=True + messages=messages, + tools=[], + tool_calls=None, + add_generation_prompt=True, ) prompt += think return llama_chat_format._convert_completion_to_chat( # noqa: SLF001 @@ -469,7 +482,12 @@ def chatml_function_calling_with_streaming( prompt += "\n" if stream: return _stream_tool_calls( - llama, prompt, tools, tool_name, completion_kwargs, follow_up_gbnf_tool_grammar + llama, + prompt, + tools, + tool_name, + completion_kwargs, + follow_up_gbnf_tool_grammar, ) tool = 
next((tool for tool in tools if tool["function"]["name"] == tool_name), None) completions: List[llama_types.CreateCompletionResponse] = [] @@ -479,7 +497,8 @@ def chatml_function_calling_with_streaming( prompt += f"functions.{tool_name}:\n" try: grammar = llama_grammar.LlamaGrammar.from_json_schema( - json.dumps(tool["function"]["parameters"]), verbose=llama.verbose + json.dumps(tool["function"]["parameters"]), + verbose=llama.verbose, ) except Exception as e: warnings.warn( @@ -488,7 +507,8 @@ def chatml_function_calling_with_streaming( stacklevel=2, ) grammar = llama_grammar.LlamaGrammar.from_string( - llama_grammar.JSON_GBNF, verbose=llama.verbose + llama_grammar.JSON_GBNF, + verbose=llama.verbose, ) completion_or_chunks = llama.create_completion( prompt=prompt, @@ -515,7 +535,8 @@ def chatml_function_calling_with_streaming( "stop": [*completion_kwargs["stop"], ":", ""], # type: ignore[misc] "max_tokens": None, "grammar": llama_grammar.LlamaGrammar.from_string( - follow_up_gbnf_tool_grammar, verbose=llama.verbose + follow_up_gbnf_tool_grammar, + verbose=llama.verbose, ), }, ), @@ -533,7 +554,7 @@ def chatml_function_calling_with_streaming( "finish_reason": "tool_calls", "index": 0, "logprobs": _convert_text_completion_logprobs_to_chat( - completion["choices"][0]["logprobs"] + completion["choices"][0]["logprobs"], ), "message": { "role": "assistant", @@ -548,11 +569,11 @@ def chatml_function_calling_with_streaming( }, } for i, (tool_name, completion) in enumerate( - zip(completions_tool_name, completions, strict=True) + zip(completions_tool_name, completions, strict=True), ) ], }, - } + }, ], "usage": { "completion_tokens": sum( diff --git a/src/raglite/_cli.py b/src/raglite/_cli.py index 34fcd24..f2664e6 100644 --- a/src/raglite/_cli.py +++ b/src/raglite/_cli.py @@ -14,7 +14,9 @@ class RAGLiteCLIConfig(BaseSettings): """RAGLite CLI config.""" model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict( - env_prefix="RAGLITE_", env_file=".env", extra="allow" + env_prefix="RAGLITE_", + env_file=".env", + extra="allow", ) mcp_server_name: str = "RAGLite" @@ -67,7 +69,7 @@ def install_mcp_server( claude_config_path = get_claude_config_path() if not claude_config_path: typer.echo( - "Please download the Claude desktop app from https://claude.ai/download before installing an MCP server." + "Please download the Claude desktop app from https://claude.ai/download before installing an MCP server.", ) return claude_config_filepath = claude_config_path / "claude_desktop_config.json" @@ -88,7 +90,7 @@ def install_mcp_server( "--python", "3.11", "--with", - "numpy<2.0.0", # TODO: Remove this constraint when uv no longer needs it to solve the environment. + "numpy<2.0.0", # TODO(lsorber): Remove this constraint when uv no longer needs it to solve the environment. 
"raglite", "mcp", "run", @@ -112,7 +114,9 @@ def run_mcp_server( from raglite._mcp import create_mcp_server config = RAGLiteConfig( - db_url=ctx.obj["db_url"], llm=ctx.obj["llm"], embedder=ctx.obj["embedder"] + db_url=ctx.obj["db_url"], + llm=ctx.obj["llm"], + embedder=ctx.obj["embedder"], ) mcp = create_mcp_server(server_name, config=config) mcp.run() @@ -122,7 +126,10 @@ def run_mcp_server( def bench( ctx: typer.Context, dataset_name: str = typer.Option( - "nano-beir/hotpotqa", "--dataset", "-d", help="Dataset to use from https://ir-datasets.com/" + "nano-beir/hotpotqa", + "--dataset", + "-d", + help="Dataset to use from https://ir-datasets.com/", ), measure: str = typer.Option( "AP@10", @@ -157,7 +164,9 @@ def bench( ) dataset = ir_datasets.load(dataset_name) evaluator = RAGLiteEvaluator( - dataset, insert_variant=f"single-vector-{chunk_max_size // 4}t", config=config + dataset, + insert_variant=f"single-vector-{chunk_max_size // 4}t", + config=config, ) index.append("RAGLite (single-vector)") results.append(ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score())) @@ -170,7 +179,9 @@ def bench( ) dataset = ir_datasets.load(dataset_name) evaluator = RAGLiteEvaluator( - dataset, insert_variant=f"multi-vector-{chunk_max_size // 4}t", config=config + dataset, + insert_variant=f"multi-vector-{chunk_max_size // 4}t", + config=config, ) index.append("RAGLite (multi-vector)") results.append(ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score())) diff --git a/src/raglite/_config.py b/src/raglite/_config.py index 3988060..37071cd 100644 --- a/src/raglite/_config.py +++ b/src/raglite/_config.py @@ -24,9 +24,12 @@ # Lazily load the default search method to avoid circular imports. -# TODO: Replace with search_and_rerank_chunk_spans after benchmarking. +# TODO(lsorber): Replace with search_and_rerank_chunk_spans after benchmarking. def _vector_search( - query: str, *, num_results: int = 8, config: "RAGLiteConfig | None" = None + query: str, + *, + num_results: int = 8, + config: "RAGLiteConfig | None" = None, ) -> tuple[list[ChunkId], list[float]]: from raglite._search import vector_search @@ -45,7 +48,7 @@ class RAGLiteConfig: "llama-cpp-python/unsloth/Qwen3-8B-GGUF/*Q4_K_M.gguf@8192" if llama_supports_gpu_offload() else "llama-cpp-python/unsloth/Qwen3-4B-GGUF/*Q4_K_M.gguf@8192" - ) + ), ) llm_max_tries: int = 4 # Embedder config used for indexing. @@ -54,7 +57,7 @@ class RAGLiteConfig: "llama-cpp-python/lm-kit/bge-m3-gguf/*F16.gguf@512" if llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 4 # noqa: PLR2004 else "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf@512" - ) + ), ) embedder_normalize: bool = True # Chunk config used to partition documents into chunks. diff --git a/src/raglite/_database.py b/src/raglite/_database.py index f4c76a8..0b9ea1a 100644 --- a/src/raglite/_database.py +++ b/src/raglite/_database.py @@ -68,7 +68,7 @@ class Document(SQLModel, table=True): chunks: list["Chunk"] = Relationship(back_populates="document", cascade_delete=True) evals: list["Eval"] = Relationship(back_populates="document", cascade_delete=True) - def __init__(self, **kwargs: Any) -> None: + def __init__(self, **kwargs: Any) -> None: # noqa: ANN401 # Workaround for https://github.com/fastapi/sqlmodel/issues/149. 
super().__init__(**kwargs) self.content = kwargs.get("content") @@ -87,7 +87,7 @@ def from_path( *, id: DocumentId | None = None, # noqa: A002 url: str | None = None, - **kwargs: Any, + **kwargs: Any, # noqa: ANN401 ) -> "Document": """Create a document from a file path. @@ -133,7 +133,7 @@ def from_text( id: DocumentId | None = None, # noqa: A002 url: str | None = None, filename: str | None = None, - **kwargs: Any, + **kwargs: Any, # noqa: ANN401 ) -> "Document": """Create a document from text content. @@ -197,7 +197,11 @@ class Chunk(SQLModel, table=True): @staticmethod def from_body( - document: Document, index: int, body: str, headings: str = "", **kwargs: Any + document: Document, + index: int, + body: str, + headings: str = "", + **kwargs: Any, # noqa: ANN401 ) -> "Chunk": """Create a chunk from Markdown.""" return Chunk( @@ -324,7 +328,7 @@ def to_xml(self, index: int | None = None) -> str: f"\n{escape(''.join(chunk.body for chunk in self.chunks).strip())}\n", "", "", - ] + ], ) return xml_document @@ -410,10 +414,11 @@ class IndexMetadata(SQLModel, table=True): # Table columns. id: IndexId = Field(..., primary_key=True) version: datetime.datetime = Field( - default_factory=lambda: datetime.datetime.now(datetime.timezone.utc) + default_factory=lambda: datetime.datetime.now(datetime.timezone.utc), ) metadata_: dict[str, Any] = Field( - default_factory=dict, sa_column=Column("metadata", PickledObject) + default_factory=dict, + sa_column=Column("metadata", PickledObject), ) @staticmethod @@ -453,7 +458,10 @@ class Eval(SQLModel, table=True): @staticmethod def from_chunks( - question: str, contexts: list[Chunk], ground_truth: str, **kwargs: Any + question: str, + contexts: list[Chunk], + ground_truth: str, + **kwargs: Any, # noqa: ANN401 ) -> "Eval": """Create a chunk from Markdown.""" document_id = contexts[0].document_id @@ -521,7 +529,7 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine: # no session.execute( text(""" CREATE INDEX IF NOT EXISTS keyword_search_chunk_index ON chunk USING GIN (to_tsvector('simple', body)); - """) + """), ) metrics = {"cosine": "cosine", "dot": "ip", "l1": "l1", "l2": "l2"} create_vector_index_sql = f""" @@ -534,7 +542,7 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine: # no """ # Enable iterative scan for pgvector v0.8.0 and up. pgvector_version = session.execute( - text("SELECT extversion FROM pg_extension WHERE extname = 'vector'") + text("SELECT extversion FROM pg_extension WHERE extname = 'vector'"), ).scalar_one() if pgvector_version and version.parse(pgvector_version) >= version.parse("0.8.0"): create_vector_index_sql += f"\nSET hnsw.iterative_scan = {'relaxed_order' if config.reranker else 'strict_order'};" @@ -548,20 +556,20 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine: # no num_chunks = session.execute(text("SELECT COUNT(*) FROM chunk")).scalar_one() try: num_indexed_chunks = session.execute( - text("SELECT COUNT(*) FROM fts_main_chunk.docs") + text("SELECT COUNT(*) FROM fts_main_chunk.docs"), ).scalar_one() except ProgrammingError: num_indexed_chunks = 0 if num_indexed_chunks == 0 or num_indexed_chunks != num_chunks: session.execute( - text("PRAGMA create_fts_index('chunk', 'id', 'body', overwrite = 1);") + text("PRAGMA create_fts_index('chunk', 'id', 'body', overwrite = 1);"), ) # Create a vector search index with VSS if it doesn't exist. 
session.execute( text(f""" SET hnsw_ef_search = {ef_search}; SET hnsw_enable_experimental_persistence = true; - """) + """), ) vss_index_exists = session.execute( text(""" @@ -570,7 +578,7 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine: # no WHERE schema_name = current_schema() AND table_name = 'chunk_embedding' AND index_name = 'vector_search_chunk_index' - """) + """), ).scalar_one() if not vss_index_exists: metrics = {"cosine": "cosine", "dot": "ip", "l2": "l2sq"} diff --git a/src/raglite/_embed.py b/src/raglite/_embed.py index 0e523f5..5e49e4a 100644 --- a/src/raglite/_embed.py +++ b/src/raglite/_embed.py @@ -14,16 +14,22 @@ def embed_strings_with_late_chunking( # noqa: C901,PLR0915 - sentences: list[str], *, config: RAGLiteConfig | None = None + sentences: list[str], + *, + config: RAGLiteConfig | None = None, ) -> FloatMatrix: """Embed a document's sentences with late chunking.""" def _count_tokens( - sentences: list[str], embedder: Llama, sentinel_char: str, sentinel_tokens: list[int] + sentences: list[str], + embedder: Llama, + sentinel_char: str, + sentinel_tokens: list[int], ) -> list[int]: # Join the sentences with the sentinel token and tokenise the result. sentences_tokens = np.asarray( - embedder.tokenize(sentinel_char.join(sentences).encode(), add_bos=False), dtype=np.intp + embedder.tokenize(sentinel_char.join(sentences).encode(), add_bos=False), + dtype=np.intp, ) # Map all sentinel token variants to the first one. for sentinel_token in sentinel_tokens[1:]: @@ -62,7 +68,9 @@ def _create_segment( config = config or RAGLiteConfig() assert config.embedder.startswith("llama-cpp-python") embedder = LlamaCppPythonLLM.llm( - config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE + config.embedder, + embedding=True, + pooling_type=LLAMA_POOLING_TYPE_NONE, ) n_ctx = embedder.n_ctx() n_batch = embedder.n_batch @@ -78,7 +86,7 @@ def _create_segment( # Compute the number of tokens per sentence. We use a method based on a sentinel token to # minimise the number of calls to embedder.tokenize, which incurs a significant overhead # (presumably to load the tokenizer) [1]. - # TODO: Make token counting faster and more robust once [1] is fixed. + # TODO(lsorber): Make token counting faster and more robust once [1] is fixed. # [1] https://github.com/abetlen/llama-cpp-python/issues/1763 num_tokens_list: list[int] = [] sentence_batch, sentence_batch_len = [], 0 @@ -87,14 +95,14 @@ def _create_segment( sentence_batch_len += len(sentence) if i == len(sentences) - 1 or sentence_batch_len > (n_ctx // 2): num_tokens_list.extend( - _count_tokens(sentence_batch, embedder, sentinel_char, sentinel_tokens) + _count_tokens(sentence_batch, embedder, sentinel_char, sentinel_tokens), ) sentence_batch, sentence_batch_len = [], 0 num_tokens = np.asarray(num_tokens_list, dtype=np.intp) # Compute the maximum number of tokens for each segment's preamble and content. # Unfortunately, llama-cpp-python truncates the input to n_batch tokens and crashes if you try # to increase it [1]. Until this is fixed, we have to limit max_tokens to n_batch. - # TODO: Improve the context window size once [1] is fixed. + # TODO(lsorber): Improve the context window size once [1] is fixed. # [1] https://github.com/abetlen/llama-cpp-python/issues/1762 max_tokens = min(n_ctx, n_batch) - 16 max_tokens_preamble = round(0.382 * max_tokens) # Golden ratio. 
@@ -104,7 +112,10 @@ def _create_segment( content_start_index = 0 while content_start_index < len(sentences): segment_start_index, segment_end_index = _create_segment( - content_start_index, max_tokens_preamble, max_tokens_content, num_tokens + content_start_index, + max_tokens_preamble, + max_tokens_content, + num_tokens, ) segments.append((segment_start_index, content_start_index, segment_end_index)) content_start_index = segment_end_index @@ -149,7 +160,9 @@ def _embed_string_batch(string_batch: list[str], *, config: RAGLiteConfig) -> Fl # embeddings because token embeddings are universally supported, while sequence # embeddings are only supported by some models. embedder = LlamaCppPythonLLM.llm( - config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE + config.embedder, + embedding=True, + pooling_type=LLAMA_POOLING_TYPE_NONE, ) embeddings = np.asarray([np.mean(row, axis=0) for row in embedder.embed(string_batch)]) else: @@ -166,7 +179,9 @@ def _embed_string_batch(string_batch: list[str], *, config: RAGLiteConfig) -> Fl def embed_strings_without_late_chunking( - strings: list[str], *, config: RAGLiteConfig | None = None + strings: list[str], + *, + config: RAGLiteConfig | None = None, ) -> FloatMatrix: """Embed a list of text strings in batches.""" config = config or RAGLiteConfig() diff --git a/src/raglite/_eval.py b/src/raglite/_eval.py index ab2296c..451b308 100644 --- a/src/raglite/_eval.py +++ b/src/raglite/_eval.py @@ -29,10 +29,11 @@ class QuestionResponse(BaseModel): """A specific question about the content of a set of document contexts.""" model_config = ConfigDict( - extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode. + extra="forbid", # Forbid extra attributes as required by OpenAI's strict mode. ) question: str = Field( - ..., description="A specific question about the content of a set of document contexts." + ..., + description="A specific question about the content of a set of document contexts.", ) system_prompt: ClassVar[str] = """ You are given a set of contexts extracted from a document. @@ -69,7 +70,7 @@ def validate_question(cls, value: str) -> str: select(Chunk) .where(Chunk.document_id == seed_document.id) .order_by(func.random()) - .limit(1) + .limit(1), ).first() assert isinstance(seed_chunk, Chunk) # Expand the seed chunk into a set of related chunks. @@ -84,11 +85,16 @@ def validate_question(cls, value: str) -> str: ] # Extract a question from the seed chunk's related chunks. question = extract_with_llm( - QuestionResponse, related_chunks, strict=True, config=config + QuestionResponse, + related_chunks, + strict=True, + config=config, ).question # Search for candidate chunks to answer the generated question. candidate_chunk_ids, _ = vector_search( - query=question, num_results=2 * max_chunks, config=config + query=question, + num_results=2 * max_chunks, + config=config, ) candidate_chunks = [session.get(Chunk, chunk_id) for chunk_id in candidate_chunk_ids] @@ -97,7 +103,7 @@ class ContextEvalResponse(BaseModel): """Indicate whether the provided context can be used to answer a given question.""" model_config = ConfigDict( - extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode. + extra="forbid", # Forbid extra attributes as required by OpenAI's strict mode. 
) hit: bool = Field( ..., @@ -120,7 +126,10 @@ class ContextEvalResponse(BaseModel): ): try: context_eval_response = extract_with_llm( - ContextEvalResponse, str(candidate_chunk), strict=True, config=config + ContextEvalResponse, + str(candidate_chunk), + strict=True, + config=config, ) except ValueError: # noqa: PERF203 pass @@ -136,7 +145,7 @@ class AnswerResponse(BaseModel): """Answer a question using the provided context.""" model_config = ConfigDict( - extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode. + extra="forbid", # Forbid extra attributes as required by OpenAI's strict mode. ) answer: str = Field( ..., @@ -229,7 +238,8 @@ def answer_evals( def evaluate( - answered_evals: "pd.DataFrame | int" = 100, config: RAGLiteConfig | None = None + answered_evals: "pd.DataFrame | int" = 100, + config: RAGLiteConfig | None = None, ) -> "pd.DataFrame": """Evaluate the performance of a set of answered evals with Ragas.""" try: @@ -251,7 +261,7 @@ def evaluate( class RAGLiteRagasEmbeddings(BaseRagasEmbeddings): """A RAGLite embedder for Ragas.""" - def __init__(self, config: RAGLiteConfig | None = None): + def __init__(self, config: RAGLiteConfig | None = None) -> None: self.config = config or RAGLiteConfig() def embed_query(self, text: str) -> list[float]: diff --git a/src/raglite/_extract.py b/src/raglite/_extract.py index f904747..cac1a6b 100644 --- a/src/raglite/_extract.py +++ b/src/raglite/_extract.py @@ -15,7 +15,7 @@ def extract_with_llm( user_prompt: str | list[str], strict: bool = False, # noqa: FBT001, FBT002 config: RAGLiteConfig | None = None, - **kwargs: Any, + **kwargs: Any, # noqa: ANN401 ) -> T: """Extract structured data from unstructured text with an LLM. @@ -45,7 +45,7 @@ class MyNameResponse(BaseModel): # is disabled by default because it only supports a subset of JSON schema features [2]. # [1] https://docs.litellm.ai/docs/completion/json_mode # [2] https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported - # TODO: Fall back to {"type": "json_object"} if JSON schema is not supported by the LLM. + # TODO(lsorber): Fall back to {"type": "json_object"} if JSON schema isn't supported by the LLM. response_format: dict[str, Any] | None = ( { "type": "json_schema", diff --git a/src/raglite/_insert.py b/src/raglite/_insert.py index bd63c29..d7fe0d7 100644 --- a/src/raglite/_insert.py +++ b/src/raglite/_insert.py @@ -20,7 +20,8 @@ def _create_chunk_records( - document: Document, config: RAGLiteConfig + document: Document, + config: RAGLiteConfig, ) -> tuple[Document, list[Chunk], list[list[ChunkEmbedding]]]: """Process chunks into chunk and chunk embedding records.""" # Partition the document into chunks. @@ -38,7 +39,11 @@ def _create_chunk_records( for i, chunk in enumerate(chunks): # Create and append the chunk record. record = Chunk.from_body( - document=document, index=i, body=chunk, headings=headings, **document.metadata_ + document=document, + index=i, + body=chunk, + headings=headings, + **document.metadata_, ) chunk_records.append(record) # Update the Markdown headings with those of this chunk. @@ -53,19 +58,23 @@ def _create_chunk_records( [ ChunkEmbedding(chunk_id=chunk_record.id, embedding=chunklet_embedding) for chunklet_embedding in chunk_embedding - ] + ], ) else: # Embed the full chunks, including the current Markdown headings. 
full_chunk_embeddings = embed_strings_without_late_chunking( - [chunk_record.content for chunk_record in chunk_records], config=config + [chunk_record.content for chunk_record in chunk_records], + config=config, ) # Every chunk record is associated with a list of chunk embedding records. The chunk # embedding records each correspond to a linear combination of a chunklet embedding and an # embedding of the full chunk with Markdown headings. α = 0.15 # Benchmark-optimised value. # noqa: PLC2401 for chunk_record, chunk_embedding, full_chunk_embedding in zip( - chunk_records, chunk_embeddings, full_chunk_embeddings, strict=True + chunk_records, + chunk_embeddings, + full_chunk_embeddings, + strict=True, ): if config.vector_search_multivector: chunk_embedding_records_list.append( @@ -75,7 +84,7 @@ def _create_chunk_records( embedding=α * chunklet_embedding + (1 - α) * full_chunk_embedding, ) for chunklet_embedding in chunk_embedding - ] + ], ) else: chunk_embedding_records_list.append( @@ -83,8 +92,8 @@ def _create_chunk_records( ChunkEmbedding( chunk_id=chunk_record.id, embedding=full_chunk_embedding, - ) - ] + ), + ], ) return document, chunk_records, chunk_embedding_records_list @@ -125,7 +134,7 @@ def insert_documents( # noqa: C901 for i in range(0, len(documents), batch_size): doc_id_batch = [doc.id for doc in documents[i : i + batch_size]] existing_doc_ids.update( - session.exec(select(Document.id).where(col(Document.id).in_(doc_id_batch))).all() + session.exec(select(Document.id).where(col(Document.id).in_(doc_id_batch))).all(), ) documents = [doc for doc in documents if doc.id not in existing_doc_ids] if not documents: @@ -146,7 +155,10 @@ def insert_documents( # noqa: C901 Session(engine) as session, ThreadPoolExecutor(max_workers=max_workers) as executor, tqdm( - total=len(documents), desc="Inserting documents", unit="document", dynamic_ncols=True + total=len(documents), + desc="Inserting documents", + unit="document", + dynamic_ncols=True, ) as pbar, ): futures = [ diff --git a/src/raglite/_lazy_llama.py b/src/raglite/_lazy_llama.py index 8c10a51..c3a1040 100644 --- a/src/raglite/_lazy_llama.py +++ b/src/raglite/_lazy_llama.py @@ -36,17 +36,17 @@ def __getattr__(name: str) -> object: class LazyAttributeError: error_message = "To use llama.cpp models, please install `llama-cpp-python`." 
- def __init__(self, error: ModuleNotFoundError | None = None): + def __init__(self, error: ModuleNotFoundError | None = None) -> None: self.error = error def __getattr__(self, name: str) -> NoReturn: raise ModuleNotFoundError(self.error_message) from self.error - def __call__(self, *args: Any, **kwargs: Any) -> NoReturn: + def __call__(self, *args: Any, **kwargs: Any) -> NoReturn: # noqa: ARG002, ANN401 raise ModuleNotFoundError(self.error_message) from self.error class LazySubmoduleError: - def __init__(self, error: ModuleNotFoundError): + def __init__(self, error: ModuleNotFoundError) -> None: self.error = error def __getattr__(self, name: str) -> LazyAttributeError | type[LazyAttributeError]: diff --git a/src/raglite/_litellm.py b/src/raglite/_litellm.py index 6525920..50bf886 100644 --- a/src/raglite/_litellm.py +++ b/src/raglite/_litellm.py @@ -1,5 +1,7 @@ """Add support for llama-cpp-python models to LiteLLM.""" +# ruff: noqa: ANN401, ARG002 + import asyncio import contextlib import logging @@ -142,7 +144,7 @@ def llm(model: str, **kwargs: Any) -> Llama: "supports_function_calling": True, "supports_parallel_function_calling": True, "supports_vision": False, - } + }, } litellm.register_model(model_info) # type: ignore[attr-defined] return llm @@ -166,7 +168,9 @@ def _translate_openai_params(self, optional_params: dict[str, Any]) -> dict[str, return llama_cpp_python_params def _add_recommended_model_params( - self, model: str, llama_cpp_python_params: dict[str, Any] + self, + model: str, + llama_cpp_python_params: dict[str, Any], ) -> dict[str, Any]: """Add recommended model settings.""" recommended_settings = {} @@ -320,7 +324,7 @@ async def astreaming( # type: ignore[misc,override] # noqa: PLR0913 # Register the LlamaCppPythonLLM provider. if not any(provider["provider"] == "llama-cpp-python" for provider in litellm.custom_provider_map): litellm.custom_provider_map.append( - {"provider": "llama-cpp-python", "custom_handler": LlamaCppPythonLLM()} + {"provider": "llama-cpp-python", "custom_handler": LlamaCppPythonLLM()}, ) custom_llm_setup() # type: ignore[no-untyped-call] diff --git a/src/raglite/_markdown.py b/src/raglite/_markdown.py index b5269fb..ba87de4 100644 --- a/src/raglite/_markdown.py +++ b/src/raglite/_markdown.py @@ -40,7 +40,7 @@ def extract_font_size(span: dict[str, Any]) -> float: for block in page["blocks"] for line in block["lines"] for span in line["spans"] - ] + ], ) font_sizes = np.round(font_sizes * 2) / 2 unique_font_sizes, counts = np.unique(font_sizes, return_counts=True) @@ -113,7 +113,8 @@ def strip_page_numbers(pages: list[dict[str, Any]]) -> list[dict[str, Any]]: line for line in block["lines"] if not re.match( - r"^\s*[#0]*\d+\s*$", "".join(span["text"] for span in line["spans"]) + r"^\s*[#0]*\d+\s*$", + "".join(span["text"] for span in line["spans"]), ) ] return pages diff --git a/src/raglite/_mcp.py b/src/raglite/_mcp.py index ec97364..aea6a37 100644 --- a/src/raglite/_mcp.py +++ b/src/raglite/_mcp.py @@ -14,7 +14,7 @@ description=( "The `query` string MUST be a precise single-faceted question in the user's language.\n" "The `query` string MUST resolve all pronouns to explicit nouns." 
- ) + ), ), ] @@ -42,7 +42,7 @@ def search_knowledge_base(query: Query) -> str: rag_context = '{{"documents": [{elements}]}}'.format( elements=", ".join( chunk_span.to_json(index=i + 1) for i, chunk_span in enumerate(chunk_spans) - ) + ), ) return rag_context diff --git a/src/raglite/_query_adapter.py b/src/raglite/_query_adapter.py index 47957cc..ad4e832 100644 --- a/src/raglite/_query_adapter.py +++ b/src/raglite/_query_adapter.py @@ -1,6 +1,6 @@ """Compute and update an optimal query adapter.""" -# ruff: noqa: N806 +# ruff: noqa: N806, RUF002 from dataclasses import replace @@ -154,13 +154,19 @@ def update_query_adapter( Q = np.zeros((0, len(chunk_embedding.embedding))) T = np.zeros_like(Q) for eval_ in tqdm( - evals, desc="Optimizing evals", unit="eval", dynamic_ncols=True, leave=False + evals, + desc="Optimizing evals", + unit="eval", + dynamic_ncols=True, + leave=False, ): # Embed the question. q = embed_strings([eval_.question], config=config)[0] # Retrieve chunks that would be used to answer the question. chunk_ids, _ = vector_search( - q, num_results=optimize_top_k, config=config_no_query_adapter + q, + num_results=optimize_top_k, + config=config_no_query_adapter, ) retrieved_chunks = session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all() retrieved_chunks = sorted(retrieved_chunks, key=lambda chunk: chunk_ids.index(chunk.id)) @@ -173,13 +179,13 @@ def update_query_adapter( [ chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]] for chunk in np.array(retrieved_chunks)[is_relevant] - ] + ], ) N = np.vstack( [ chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]] for chunk in np.array(retrieved_chunks)[~is_relevant] - ] + ], ) # Compute the optimal target vector t for this query embedding q. t = _optimize_query_target(q, P, N, α=optimize_gap) diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 07f676a..b5c0081 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -37,7 +37,10 @@ def retrieve_context( - query: str, *, num_chunks: int = 10, config: RAGLiteConfig | None = None + query: str, + *, + num_chunks: int = 10, + config: RAGLiteConfig | None = None, ) -> list[ChunkSpan]: """Retrieve context for RAG.""" # Call the search method. @@ -86,7 +89,8 @@ def _clip(messages: list[dict[str, str]], max_tokens: int) -> list[dict[str, str def _get_tools( - messages: list[dict[str, str]], config: RAGLiteConfig + messages: list[dict[str, str]], + config: RAGLiteConfig, ) -> tuple[list[dict[str, Any]] | None, dict[str, Any] | str | None]: """Get tools to search the knowledge base if no RAG context is provided in the messages.""" # Check if messages already contain RAG context or if the LLM supports tool use. @@ -125,7 +129,7 @@ def _get_tools( "additionalProperties": False, }, }, - } + }, ] if not messages_contain_rag_context else None @@ -153,10 +157,10 @@ def _run_tools( elements=", ".join( chunk_span.to_json(index=i + 1) for i, chunk_span in enumerate(chunk_spans) - ) + ), ), "tool_call_id": tool_call.id, - } + }, ) if chunk_spans and callable(on_retrieval): on_retrieval(chunk_spans) @@ -237,12 +241,14 @@ async def async_rag( # Add the tool call requests to the message array. messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] # Run the tool calls to retrieve the RAG context and append the output to the message array. - # TODO: Make this async. + # TODO(lsorber): Make this async. messages.extend(_run_tools(tool_calls, on_retrieval, config)) # Asynchronously stream the assistant response. 
chunks = [] async_stream = await acompletion( - model=config.llm, messages=_clip(messages, max_tokens), stream=True + model=config.llm, + messages=_clip(messages, max_tokens), + stream=True, ) async for chunk in async_stream: chunks.append(chunk) diff --git a/src/raglite/_search.py b/src/raglite/_search.py index 652275c..a611761 100644 --- a/src/raglite/_search.py +++ b/src/raglite/_search.py @@ -49,7 +49,8 @@ def vector_search( corrected_oversample = oversample * config.chunk_max_size / RAGLiteConfig.chunk_max_size num_hits = round(corrected_oversample) * max(num_results, 10) dist = ChunkEmbedding.embedding.distance( # type: ignore[attr-defined] - query_embedding, metric=config.vector_search_distance_metric + query_embedding, + metric=config.vector_search_distance_metric, ).label("dist") sim = (1.0 - dist).label("sim") top_vectors = select(ChunkEmbedding.chunk_id, sim).order_by(dist).limit(num_hits).subquery() @@ -67,7 +68,10 @@ def vector_search( def keyword_search( - query: str, *, num_results: int = 3, config: RAGLiteConfig | None = None + query: str, + *, + num_results: int = 3, + config: RAGLiteConfig | None = None, ) -> tuple[list[ChunkId], list[float]]: """Search chunks using BM25 keyword search.""" # Read the config. @@ -100,7 +104,7 @@ def keyword_search( WHERE score IS NOT NULL ORDER BY score DESC LIMIT :limit; - """ + """, ) results = session.execute(statement, params={"query": query, "limit": num_results}) # Unpack the results. @@ -111,7 +115,10 @@ def keyword_search( def reciprocal_rank_fusion( - rankings: list[list[ChunkId]], *, k: int = 60, weights: list[float] | None = None + rankings: list[list[ChunkId]], + *, + k: int = 60, + weights: list[float] | None = None, ) -> tuple[list[ChunkId], list[float]]: """Reciprocal Rank Fusion.""" if weights is None: @@ -129,7 +136,8 @@ def reciprocal_rank_fusion( return [], [] # Rank RRF results according to descending RRF score. rrf_chunk_ids, rrf_score = zip( - *sorted(chunk_id_score.items(), key=lambda x: x[1], reverse=True), strict=True + *sorted(chunk_id_score.items(), key=lambda x: x[1], reverse=True), + strict=True, ) return list(rrf_chunk_ids), list(rrf_score) @@ -149,14 +157,17 @@ def hybrid_search( # noqa: PLR0913 ks_chunk_ids, _ = keyword_search(query, num_results=oversample * num_results, config=config) # Combine the results with Reciprocal Rank Fusion (RRF). chunk_ids, hybrid_score = reciprocal_rank_fusion( - [vs_chunk_ids, ks_chunk_ids], weights=[vector_search_weight, keyword_search_weight] + [vs_chunk_ids, ks_chunk_ids], + weights=[vector_search_weight, keyword_search_weight], ) chunk_ids, hybrid_score = chunk_ids[:num_results], hybrid_score[:num_results] return chunk_ids, hybrid_score def retrieve_chunks( - chunk_ids: list[ChunkId], *, config: RAGLiteConfig | None = None + chunk_ids: list[ChunkId], + *, + config: RAGLiteConfig | None = None, ) -> list[Chunk]: """Retrieve chunks by their ids.""" if not chunk_ids: @@ -167,8 +178,8 @@ def retrieve_chunks( select(Chunk) .where(col(Chunk.id).in_(chunk_ids)) # Eagerly load chunk.document. - .options(joinedload(Chunk.document)) # type: ignore[arg-type] - ).all() + .options(joinedload(Chunk.document)), # type: ignore[arg-type] + ).all(), ) chunks = sorted(chunks, key=lambda chunk: chunk_ids.index(chunk.id)) return chunks @@ -210,8 +221,8 @@ def retrieve_chunk_spans( select(Chunk) .where(or_(*neighbor_conditions)) # Eagerly load chunk.document. 
- .options(joinedload(Chunk.document)) # type: ignore[arg-type] - ).all() + .options(joinedload(Chunk.document)), # type: ignore[arg-type] + ).all(), ) # Deduplicate and sort the chunks by document_id and index (needed for groupby). unique_chunks = sorted(set(chunks), key=lambda chunk: (chunk.document_id, chunk.index)) @@ -237,7 +248,10 @@ def retrieve_chunk_spans( def rerank_chunks( - query: str, chunk_ids: list[ChunkId] | list[Chunk], *, config: RAGLiteConfig | None = None + query: str, + chunk_ids: list[ChunkId] | list[Chunk], + *, + config: RAGLiteConfig | None = None, ) -> list[Chunk]: """Rerank chunks according to their relevance to a given query.""" # Retrieve the chunks. diff --git a/src/raglite/_split_chunklets.py b/src/raglite/_split_chunklets.py index d4a64aa..7c0e250 100644 --- a/src/raglite/_split_chunklets.py +++ b/src/raglite/_split_chunklets.py @@ -58,7 +58,8 @@ def markdown_chunklet_boundaries(sentences: list[str]) -> FloatVector: def compute_num_statements(sentences: list[str]) -> FloatVector: """Compute the approximate number of statements of each sentence in a list of sentences.""" sentence_word_length = np.asarray( - [len(sentence.split()) for sentence in sentences], dtype=np.float64 + [len(sentence.split()) for sentence in sentences], + dtype=np.float64, ) q25, q75 = np.quantile(sentence_word_length, [0.25, 0.75]) q25 = max(q25, np.sqrt(np.finfo(np.float64).eps)) diff --git a/src/raglite/_split_chunks.py b/src/raglite/_split_chunks.py index d85b3e8..6550bf5 100644 --- a/src/raglite/_split_chunks.py +++ b/src/raglite/_split_chunks.py @@ -68,7 +68,8 @@ def split_chunks( # noqa: C901, PLR0915 partition_similarity = np.sum(X[:-1] * X[1:], axis=1) # Make partition similarity nonnegative before modification and optimisation. partition_similarity = np.maximum( - (partition_similarity + 1) / 2, np.sqrt(np.finfo(X.dtype).eps) + (partition_similarity + 1) / 2, + np.sqrt(np.finfo(X.dtype).eps), ) # Modify the partition similarity to encourage splitting on Markdown headings. prev_chunklet_is_heading = True diff --git a/src/raglite/_split_sentences.py b/src/raglite/_split_sentences.py index 1cbeb18..da84c8f 100644 --- a/src/raglite/_split_sentences.py +++ b/src/raglite/_split_sentences.py @@ -56,7 +56,11 @@ def get_markdown_heading_indexes(doc: str) -> list[tuple[int, int]]: def _split_sentences( - doc: str, probas: FloatVector, *, min_len: int, max_len: int | None = None + doc: str, + probas: FloatVector, + *, + min_len: int, + max_len: int | None = None, ) -> list[str]: # Solve an optimisation problem to find the best sentence boundaries given the predicted # boundary probabilities. The objective is to select boundaries that maximise the sum of the diff --git a/src/raglite/_typing.py b/src/raglite/_typing.py index 5a0cb73..1ef690d 100644 --- a/src/raglite/_typing.py +++ b/src/raglite/_typing.py @@ -1,5 +1,7 @@ """RAGLite typing.""" +# ruff: noqa: ANN401, ARG002 + import io import pickle from collections.abc import Callable @@ -31,13 +33,21 @@ class BasicSearchMethod(Protocol): def __call__( - self, query: str, *, num_results: int, config: "RAGLiteConfig | None" = None + self, + query: str, + *, + num_results: int, + config: "RAGLiteConfig | None" = None, ) -> tuple[list[ChunkId], list[float]]: ... 
class SearchMethod(Protocol): def __call__( - self, query: str, *, num_results: int, config: "RAGLiteConfig | None" = None + self, + query: str, + *, + num_results: int, + config: "RAGLiteConfig | None" = None, ) -> tuple[list[ChunkId], list[float]] | list["Chunk"] | list["ChunkSpan"]: ... @@ -47,7 +57,9 @@ class NumpyArray(TypeDecorator[np.ndarray[Any, np.dtype[np.floating[Any]]]]): impl = LargeBinary def process_bind_param( - self, value: np.ndarray[Any, np.dtype[np.floating[Any]]] | None, dialect: Dialect + self, + value: np.ndarray[Any, np.dtype[np.floating[Any]]] | None, + dialect: Dialect, ) -> bytes | None: """Convert a NumPy array to bytes.""" if value is None: @@ -57,7 +69,9 @@ def process_bind_param( return buffer.getvalue() def process_result_value( - self, value: bytes | None, dialect: Dialect + self, + value: bytes | None, + dialect: Dialect, ) -> np.ndarray[Any, np.dtype[np.floating[Any]]] | None: """Convert bytes to a NumPy array.""" if value is None: @@ -150,7 +164,9 @@ def process(value: FloatVector | None) -> str | None: return process def result_processor( - self, dialect: Dialect, coltype: Any + self, + dialect: Dialect, + coltype: Any, ) -> Callable[[str | None], FloatVector | None]: """Process PostgreSQL halfvec format to NumPy ndarray.""" @@ -175,7 +191,8 @@ def get_col_spec(self, **kwargs: Any) -> str: return f"FLOAT[{self.dim}]" if self.dim is not None else "FLOAT[]" def bind_processor( - self, dialect: Dialect + self, + dialect: Dialect, ) -> Callable[[FloatVector | None], list[float] | None]: """Process NumPy ndarray to DuckDB single precision vector format for bound parameters.""" @@ -185,7 +202,9 @@ def process(value: FloatVector | None) -> list[float] | None: return process def result_processor( - self, dialect: Dialect, coltype: Any + self, + dialect: Dialect, + coltype: Any, ) -> Callable[[list[float] | None], FloatVector | None]: """Process DuckDB single precision vector format to NumPy ndarray.""" @@ -202,7 +221,7 @@ class Embedding(TypeDecorator[FloatVector]): impl = NumpyArray comparator_factory: type[EmbeddingComparator] = EmbeddingComparator - def __init__(self, dim: int = -1): + def __init__(self, dim: int = -1) -> None: super().__init__() self.dim = dim diff --git a/tests/test_chatml_function_calling.py b/tests/test_chatml_function_calling.py index fd0c91a..d7f02a3 100644 --- a/tests/test_chatml_function_calling.py +++ b/tests/test_chatml_function_calling.py @@ -39,7 +39,8 @@ def is_accelerator_available() -> bool: pytest.param("none", id="tool_choice=none"), pytest.param("auto", id="tool_choice=auto"), pytest.param( - {"type": "function", "function": {"name": "get_weather"}}, id="tool_choice=fixed" + {"type": "function", "function": {"name": "get_weather"}}, + id="tool_choice=fixed", ), ], ) @@ -68,7 +69,8 @@ def is_accelerator_available() -> bool: "unsloth/Qwen3-8B-GGUF", id="qwen3_8B", marks=pytest.mark.skipif( - not is_accelerator_available(), reason="Accelerator not available" + not is_accelerator_available(), + reason="Accelerator not available", ), ), ], @@ -93,7 +95,7 @@ def test_llama_cpp_python_tool_use( chat_handler=chatml_function_calling_with_streaming, ) messages: list[llama_types.ChatCompletionRequestMessage] = [ - {"role": "user", "content": user_prompt} + {"role": "user", "content": user_prompt}, ] tools: list[llama_types.ChatCompletionTool] = [ { @@ -106,10 +108,13 @@ def test_llama_cpp_python_tool_use( "properties": {"location": {"type": "string", "description": "A city name."}}, }, }, - } + }, ] response = 
llm.create_chat_completion( - messages=messages, tools=tools, tool_choice=tool_choice, stream=stream + messages=messages, + tools=tools, + tool_choice=tool_choice, + stream=stream, ) if stream: response = cast("Iterator[llama_types.CreateChatCompletionStreamResponse]", response) diff --git a/tests/test_extract.py b/tests/test_extract.py index 9cf8dd1..a81df5c 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -10,7 +10,8 @@ @pytest.mark.parametrize( - "strict", [pytest.param(False, id="strict=False"), pytest.param(True, id="strict=True")] + "strict", + [pytest.param(False, id="strict=False"), pytest.param(True, id="strict=True")], ) def test_extract(llm: str, strict: bool) -> None: # noqa: FBT001 """Test extracting structured data.""" @@ -29,7 +30,10 @@ class LoginResponse(BaseModel): # Extract structured data. username, password = "cypher", "steak" login_response = extract_with_llm( - LoginResponse, f"username: {username}\npassword: {password}", strict=strict, config=config + LoginResponse, + f"username: {username}\npassword: {password}", + strict=strict, + config=config, ) # Validate the response. assert isinstance(login_response, LoginResponse) diff --git a/tests/test_insert.py b/tests/test_insert.py index acc9246..1ab62b2 100644 --- a/tests/test_insert.py +++ b/tests/test_insert.py @@ -19,7 +19,7 @@ def test_insert(raglite_test_config: RAGLiteConfig) -> None: assert document is not None, "No document found in the database" # Get the existing chunks for this document. chunks = session.exec( - select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.index) # type: ignore[arg-type] + select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.index), # type: ignore[arg-type] ).all() assert len(chunks) > 0, "No chunks found for the document" restored_document = "" diff --git a/tests/test_lazy_llama.py b/tests/test_lazy_llama.py index 42363c4..c68067d 100644 --- a/tests/test_lazy_llama.py +++ b/tests/test_lazy_llama.py @@ -19,7 +19,7 @@ def test_raglite_import_without_llama_cpp(monkeypatch: pytest.MonkeyPatch) -> No original_import = builtins.__import__ # Define a fake import function that raises ModuleNotFoundError when trying to import llama_cpp. - def fake_import(name: str, *args: Any) -> Any: + def fake_import(name: str, *args: Any) -> Any: # noqa: ANN401 if name.startswith("llama_cpp"): import_error = f"No module named '{name}'" raise ModuleNotFoundError(import_error) diff --git a/tests/test_rerank.py b/tests/test_rerank.py index 25ddf7b..b13f4a6 100644 --- a/tests/test_rerank.py +++ b/tests/test_rerank.py @@ -48,7 +48,9 @@ def test_reranker( """Test inserting a document, updating the indexes, and searching for a query.""" # Update the config with the reranker. raglite_test_config = RAGLiteConfig( - db_url=raglite_test_config.db_url, embedder=raglite_test_config.embedder, reranker=reranker + db_url=raglite_test_config.db_url, + embedder=raglite_test_config.embedder, + reranker=reranker, ) # Search for a query. query = "What does it mean for two events to be simultaneous?" 
diff --git a/tests/test_search.py b/tests/test_search.py index 8151018..f972422 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -57,7 +57,8 @@ def test_search(raglite_test_config: RAGLiteConfig, search_method: BasicSearchMe def test_search_no_results( - raglite_test_config: RAGLiteConfig, search_method: BasicSearchMethod + raglite_test_config: RAGLiteConfig, + search_method: BasicSearchMethod, ) -> None: """Test searching for a query with no keyword search results.""" query = "supercalifragilisticexpialidocious" diff --git a/tests/test_split_sentences.py b/tests/test_split_sentences.py index b8de530..8a40008 100644 --- a/tests/test_split_sentences.py +++ b/tests/test_split_sentences.py @@ -39,7 +39,9 @@ def test_split_sentences() -> None: assert all( sentence == expected_sentence for sentence, expected_sentence in zip( - sentences[: len(expected_sentences)], expected_sentences, strict=True + sentences[: len(expected_sentences)], + expected_sentences, + strict=True, ) ) @@ -74,6 +76,8 @@ def test_split_sentences_edge_cases(case: tuple[str, list[str], tuple[int, int | assert all( sentence == expected_sentence for sentence, expected_sentence in zip( - sentences[: len(expected_sentences)], expected_sentences, strict=True + sentences[: len(expected_sentences)], + expected_sentences, + strict=True, ) ) From 1d05914ef9e453e98232aa1c53783180e77f113f Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Sun, 15 Jun 2025 12:51:57 +0200 Subject: [PATCH 2/3] style: skip magic trailing comma --- pyproject.toml | 4 ++ src/raglite/_bench.py | 11 +--- src/raglite/_chainlit.py | 6 +- src/raglite/_chatml_function_calling.py | 86 +++++++------------------ src/raglite/_cli.py | 23 ++----- src/raglite/_config.py | 9 +-- src/raglite/_database.py | 19 +++--- src/raglite/_embed.py | 31 +++------ src/raglite/_eval.py | 43 ++++--------- src/raglite/_insert.py | 39 +++-------- src/raglite/_litellm.py | 15 ++--- src/raglite/_markdown.py | 5 +- src/raglite/_mcp.py | 4 +- src/raglite/_query_adapter.py | 14 ++-- src/raglite/_rag.py | 20 ++---- src/raglite/_search.py | 38 ++++------- src/raglite/_split_chunklets.py | 3 +- src/raglite/_split_chunks.py | 13 +--- src/raglite/_split_sentences.py | 6 +- src/raglite/_typing.py | 38 +++-------- tests/test_chatml_function_calling.py | 37 +++-------- tests/test_extract.py | 8 +-- tests/test_insert.py | 2 +- tests/test_rag.py | 6 +- tests/test_rerank.py | 13 ++-- tests/test_search.py | 9 +-- tests/test_split_sentences.py | 11 +--- 27 files changed, 146 insertions(+), 367 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 618e6a8..ba77e24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -142,6 +142,7 @@ target-version = "py310" [tool.ruff.format] docstring-code-format = true +skip-magic-trailing-comma = true [tool.ruff.lint] select = ["ALL"] @@ -151,6 +152,9 @@ unfixable = ["ERA001", "F401", "F841", "T201", "T203"] [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" +[tool.ruff.lint.isort] +split-on-trailing-comma = false + [tool.ruff.lint.pycodestyle] max-doc-length = 100 diff --git a/src/raglite/_bench.py b/src/raglite/_bench.py index 9026141..6a40637 100644 --- a/src/raglite/_bench.py +++ b/src/raglite/_bench.py @@ -185,8 +185,7 @@ def index(self) -> Any: # noqa: ANN401 vector_store = FaissVectorStore.from_persist_dir(persist_dir=self.persist_path.as_posix()) storage_context = StorageContext.from_defaults( - vector_store=vector_store, - persist_dir=self.persist_path.as_posix(), + vector_store=vector_store, 
persist_dir=self.persist_path.as_posix() ) embed_model = OpenAIEmbedding(model=self.embedder, dimensions=self.embedder_dim) index = load_index_from_storage(storage_context, embed_model=embed_model) @@ -270,9 +269,7 @@ def insert_documents(self, max_workers: int | None = None) -> None: files.append(temp_file.open("rb")) if len(files) == max_files_per_batch or (i == self.dataset.docs_count() - 1): self.client.vector_stores.file_batches.upload_and_poll( - vector_store_id=vector_store.id, - files=files, - max_concurrency=max_workers, + vector_store_id=vector_store.id, files=files, max_concurrency=max_workers ) for f in files: f.close() @@ -286,9 +283,7 @@ def search(self, query_id: str, query: str, *, num_results: int = 10) -> list[Sc if not self.vector_store_id: return [] response = self.client.vector_stores.search( - vector_store_id=self.vector_store_id, - query=query, - max_num_results=2 * num_results, + vector_store_id=self.vector_store_id, query=query, max_num_results=2 * num_results ) scored_docs = [ ScoredDoc( diff --git a/src/raglite/_chainlit.py b/src/raglite/_chainlit.py index 402d6f1..8c94f52 100644 --- a/src/raglite/_chainlit.py +++ b/src/raglite/_chainlit.py @@ -39,7 +39,7 @@ async def start_chat() -> None: TextInput(id="llm", label="LLM", initial=config.llm), TextInput(id="embedder", label="Embedder", initial=config.embedder), Switch(id="vector_search_query_adapter", label="Query adapter", initial=True), - ], + ] ).send() await update_config(settings) @@ -95,9 +95,7 @@ async def handle_message(user_message: cl.Message) -> None: messages: list[dict[str, str]] = cl.chat_context.to_openai()[:-1] # type: ignore[no-untyped-call] messages.append({"role": "user", "content": user_prompt}) async for token in async_rag( - messages, - on_retrieval=lambda x: chunk_spans.extend(x), - config=config, + messages, on_retrieval=lambda x: chunk_spans.extend(x), config=config ): await assistant_message.stream_token(token) # Append RAG sources, if any. 
diff --git a/src/raglite/_chatml_function_calling.py b/src/raglite/_chatml_function_calling.py index bed4bda..9042080 100644 --- a/src/raglite/_chatml_function_calling.py +++ b/src/raglite/_chatml_function_calling.py @@ -25,24 +25,12 @@ import json import warnings -from typing import ( # noqa: UP035 - Any, - Iterator, - List, - Optional, - Union, - cast, -) +from typing import Any, Iterator, List, Optional, Union, cast # noqa: UP035 import jinja2 from jinja2.sandbox import ImmutableSandboxedEnvironment -from raglite._lazy_llama import ( - llama, - llama_chat_format, - llama_grammar, - llama_types, -) +from raglite._lazy_llama import llama, llama_chat_format, llama_grammar, llama_types def _accumulate_chunks( @@ -100,7 +88,7 @@ def _convert_chunks_to_completion( "index": 0, "logprobs": logprobs, # TODO(lsorber): Improve accumulation of logprobs "finish_reason": finish_reason, # type: ignore[typeddict-item] - }, + } ], } # Add usage section if present in the chunks @@ -131,8 +119,7 @@ def _stream_tool_calls( prompt += f"functions.{tool_name}:\n" try: grammar = llama_grammar.LlamaGrammar.from_json_schema( - json.dumps(tool["function"]["parameters"]), - verbose=llama.verbose, + json.dumps(tool["function"]["parameters"]), verbose=llama.verbose ) except Exception as e: warnings.warn( @@ -141,16 +128,10 @@ def _stream_tool_calls( stacklevel=2, ) grammar = llama_grammar.LlamaGrammar.from_string( - llama_grammar.JSON_GBNF, - verbose=llama.verbose, + llama_grammar.JSON_GBNF, verbose=llama.verbose ) completion_or_chunks = llama.create_completion( - prompt=prompt, - **{ - **completion_kwargs, - "max_tokens": None, - "grammar": grammar, - }, + prompt=prompt, **{**completion_kwargs, "max_tokens": None, "grammar": grammar} ) chunks: List[llama_types.CreateCompletionResponse] = [] chat_chunks = llama_chat_format._convert_completion_to_chat_function( # noqa: SLF001 @@ -184,8 +165,7 @@ def _stream_tool_calls( "stop": [*completion_kwargs["stop"], ":", ""], "max_tokens": None, "grammar": llama_grammar.LlamaGrammar.from_string( - follow_up_gbnf_tool_grammar, - verbose=llama.verbose, + follow_up_gbnf_tool_grammar, verbose=llama.verbose ), }, ), @@ -209,11 +189,7 @@ def _convert_text_completion_logprobs_to_chat( "bytes": None, "logprob": logprob, # type: ignore[typeddict-item] "top_logprobs": [ - { - "token": top_token, - "logprob": top_logprob, - "bytes": None, - } + {"token": top_token, "logprob": top_logprob, "bytes": None} for top_token, top_logprob in (top_logprobs or {}).items() ], } @@ -321,9 +297,9 @@ def chatml_function_calling_with_streaming( "{% endfor %}" "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" ) - template_renderer = ImmutableSandboxedEnvironment( - undefined=jinja2.StrictUndefined, - ).from_string(function_calling_template) + template_renderer = ImmutableSandboxedEnvironment(undefined=jinja2.StrictUndefined).from_string( + function_calling_template + ) # Convert legacy functions to tools if functions is not None: @@ -384,10 +360,7 @@ def chatml_function_calling_with_streaming( or len(tools) == 0 ): prompt = template_renderer.render( - messages=messages, - tools=[], - tool_calls=None, - add_generation_prompt=True, + messages=messages, tools=[], tool_calls=None, add_generation_prompt=True ) return llama_chat_format._convert_completion_to_chat( # noqa: SLF001 llama.create_completion( @@ -410,10 +383,7 @@ def chatml_function_calling_with_streaming( assert tools function_names = " | ".join([f'''"functions.{t["function"]["name"]}:"''' for t in tools]) prompt = 
template_renderer.render( - messages=messages, - tools=tools, - tool_calls=True, - add_generation_prompt=True, + messages=messages, tools=tools, tool_calls=True, add_generation_prompt=True ) initial_gbnf_tool_grammar = ( ( @@ -438,8 +408,7 @@ def chatml_function_calling_with_streaming( "stream": False, "max_tokens": None, "grammar": llama_grammar.LlamaGrammar.from_string( - initial_gbnf_tool_grammar, - verbose=llama.verbose, + initial_gbnf_tool_grammar, verbose=llama.verbose ), }, ), @@ -459,10 +428,7 @@ def chatml_function_calling_with_streaming( # Case 2 step 2A: Respond with a message if tool_name is None: prompt = template_renderer.render( - messages=messages, - tools=[], - tool_calls=None, - add_generation_prompt=True, + messages=messages, tools=[], tool_calls=None, add_generation_prompt=True ) prompt += think return llama_chat_format._convert_completion_to_chat( # noqa: SLF001 @@ -482,12 +448,7 @@ def chatml_function_calling_with_streaming( prompt += "\n" if stream: return _stream_tool_calls( - llama, - prompt, - tools, - tool_name, - completion_kwargs, - follow_up_gbnf_tool_grammar, + llama, prompt, tools, tool_name, completion_kwargs, follow_up_gbnf_tool_grammar ) tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None) completions: List[llama_types.CreateCompletionResponse] = [] @@ -497,8 +458,7 @@ def chatml_function_calling_with_streaming( prompt += f"functions.{tool_name}:\n" try: grammar = llama_grammar.LlamaGrammar.from_json_schema( - json.dumps(tool["function"]["parameters"]), - verbose=llama.verbose, + json.dumps(tool["function"]["parameters"]), verbose=llama.verbose ) except Exception as e: warnings.warn( @@ -507,8 +467,7 @@ def chatml_function_calling_with_streaming( stacklevel=2, ) grammar = llama_grammar.LlamaGrammar.from_string( - llama_grammar.JSON_GBNF, - verbose=llama.verbose, + llama_grammar.JSON_GBNF, verbose=llama.verbose ) completion_or_chunks = llama.create_completion( prompt=prompt, @@ -535,8 +494,7 @@ def chatml_function_calling_with_streaming( "stop": [*completion_kwargs["stop"], ":", ""], # type: ignore[misc] "max_tokens": None, "grammar": llama_grammar.LlamaGrammar.from_string( - follow_up_gbnf_tool_grammar, - verbose=llama.verbose, + follow_up_gbnf_tool_grammar, verbose=llama.verbose ), }, ), @@ -554,7 +512,7 @@ def chatml_function_calling_with_streaming( "finish_reason": "tool_calls", "index": 0, "logprobs": _convert_text_completion_logprobs_to_chat( - completion["choices"][0]["logprobs"], + completion["choices"][0]["logprobs"] ), "message": { "role": "assistant", @@ -569,11 +527,11 @@ def chatml_function_calling_with_streaming( }, } for i, (tool_name, completion) in enumerate( - zip(completions_tool_name, completions, strict=True), + zip(completions_tool_name, completions, strict=True) ) ], }, - }, + } ], "usage": { "completion_tokens": sum( diff --git a/src/raglite/_cli.py b/src/raglite/_cli.py index f2664e6..e6342f2 100644 --- a/src/raglite/_cli.py +++ b/src/raglite/_cli.py @@ -14,9 +14,7 @@ class RAGLiteCLIConfig(BaseSettings): """RAGLite CLI config.""" model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict( - env_prefix="RAGLITE_", - env_file=".env", - extra="allow", + env_prefix="RAGLITE_", env_file=".env", extra="allow" ) mcp_server_name: str = "RAGLite" @@ -69,7 +67,7 @@ def install_mcp_server( claude_config_path = get_claude_config_path() if not claude_config_path: typer.echo( - "Please download the Claude desktop app from https://claude.ai/download before installing an MCP server.", + "Please download the 
Claude desktop app from https://claude.ai/download before installing an MCP server.", + "Please download the Claude desktop app from https://claude.ai/download before installing an MCP server." ) return claude_config_filepath = claude_config_path / "claude_desktop_config.json" @@ -114,9 +112,7 @@ def run_mcp_server( from raglite._mcp import create_mcp_server config = RAGLiteConfig( - db_url=ctx.obj["db_url"], - llm=ctx.obj["llm"], - embedder=ctx.obj["embedder"], + db_url=ctx.obj["db_url"], llm=ctx.obj["llm"], embedder=ctx.obj["embedder"] ) mcp = create_mcp_server(server_name, config=config) mcp.run() @@ -126,10 +122,7 @@ def run_mcp_server( def bench( ctx: typer.Context, dataset_name: str = typer.Option( - "nano-beir/hotpotqa", - "--dataset", - "-d", - help="Dataset to use from https://ir-datasets.com/", + "nano-beir/hotpotqa", "--dataset", "-d", help="Dataset to use from https://ir-datasets.com/" ), measure: str = typer.Option( "AP@10", @@ -164,9 +157,7 @@ def bench( ) dataset = ir_datasets.load(dataset_name) evaluator = RAGLiteEvaluator( - dataset, - insert_variant=f"single-vector-{chunk_max_size // 4}t", - config=config, + dataset, insert_variant=f"single-vector-{chunk_max_size // 4}t", config=config ) index.append("RAGLite (single-vector)") results.append(ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score())) @@ -179,9 +170,7 @@ def bench( ) dataset = ir_datasets.load(dataset_name) evaluator = RAGLiteEvaluator( - dataset, - insert_variant=f"multi-vector-{chunk_max_size // 4}t", - config=config, + dataset, insert_variant=f"multi-vector-{chunk_max_size // 4}t", config=config ) index.append("RAGLite (multi-vector)") results.append(ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score())) diff --git a/src/raglite/_config.py b/src/raglite/_config.py index 37071cd..af13126 100644 --- a/src/raglite/_config.py +++ b/src/raglite/_config.py @@ -26,10 +26,7 @@ # Lazily load the default search method to avoid circular imports. # TODO(lsorber): Replace with search_and_rerank_chunk_spans after benchmarking. def _vector_search( - query: str, - *, - num_results: int = 8, - config: "RAGLiteConfig | None" = None, + query: str, *, num_results: int = 8, config: "RAGLiteConfig | None" = None ) -> tuple[list[ChunkId], list[float]]: from raglite._search import vector_search @@ -48,7 +45,7 @@ class RAGLiteConfig: "llama-cpp-python/unsloth/Qwen3-8B-GGUF/*Q4_K_M.gguf@8192" if llama_supports_gpu_offload() else "llama-cpp-python/unsloth/Qwen3-4B-GGUF/*Q4_K_M.gguf@8192" - ), + ) ) llm_max_tries: int = 4 # Embedder config used for indexing. @@ -57,7 +54,7 @@ class RAGLiteConfig: "llama-cpp-python/lm-kit/bge-m3-gguf/*F16.gguf@512" if llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 4 # noqa: PLR2004 else "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf@512" - ), + ) ) embedder_normalize: bool = True # Chunk config used to partition documents into chunks. diff --git a/src/raglite/_database.py b/src/raglite/_database.py index 0b9ea1a..ceb9629 100644 --- a/src/raglite/_database.py +++ b/src/raglite/_database.py @@ -328,7 +328,7 @@ def to_xml(self, index: int | None = None) -> str: f"\n{escape(''.join(chunk.body for chunk in self.chunks).strip())}\n", "", "", - ], + ] ) return xml_document @@ -414,11 +414,10 @@ class IndexMetadata(SQLModel, table=True): # Table columns. 
id: IndexId = Field(..., primary_key=True) version: datetime.datetime = Field( - default_factory=lambda: datetime.datetime.now(datetime.timezone.utc), + default_factory=lambda: datetime.datetime.now(datetime.timezone.utc) ) metadata_: dict[str, Any] = Field( - default_factory=dict, - sa_column=Column("metadata", PickledObject), + default_factory=dict, sa_column=Column("metadata", PickledObject) ) @staticmethod @@ -529,7 +528,7 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine: # no session.execute( text(""" CREATE INDEX IF NOT EXISTS keyword_search_chunk_index ON chunk USING GIN (to_tsvector('simple', body)); - """), + """) ) metrics = {"cosine": "cosine", "dot": "ip", "l1": "l1", "l2": "l2"} create_vector_index_sql = f""" @@ -542,7 +541,7 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine: # no """ # Enable iterative scan for pgvector v0.8.0 and up. pgvector_version = session.execute( - text("SELECT extversion FROM pg_extension WHERE extname = 'vector'"), + text("SELECT extversion FROM pg_extension WHERE extname = 'vector'") ).scalar_one() if pgvector_version and version.parse(pgvector_version) >= version.parse("0.8.0"): create_vector_index_sql += f"\nSET hnsw.iterative_scan = {'relaxed_order' if config.reranker else 'strict_order'};" @@ -556,20 +555,20 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine: # no num_chunks = session.execute(text("SELECT COUNT(*) FROM chunk")).scalar_one() try: num_indexed_chunks = session.execute( - text("SELECT COUNT(*) FROM fts_main_chunk.docs"), + text("SELECT COUNT(*) FROM fts_main_chunk.docs") ).scalar_one() except ProgrammingError: num_indexed_chunks = 0 if num_indexed_chunks == 0 or num_indexed_chunks != num_chunks: session.execute( - text("PRAGMA create_fts_index('chunk', 'id', 'body', overwrite = 1);"), + text("PRAGMA create_fts_index('chunk', 'id', 'body', overwrite = 1);") ) # Create a vector search index with VSS if it doesn't exist. session.execute( text(f""" SET hnsw_ef_search = {ef_search}; SET hnsw_enable_experimental_persistence = true; - """), + """) ) vss_index_exists = session.execute( text(""" @@ -578,7 +577,7 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine: # no WHERE schema_name = current_schema() AND table_name = 'chunk_embedding' AND index_name = 'vector_search_chunk_index' - """), + """) ).scalar_one() if not vss_index_exists: metrics = {"cosine": "cosine", "dot": "ip", "l2": "l2sq"} diff --git a/src/raglite/_embed.py b/src/raglite/_embed.py index 5e49e4a..fc50d08 100644 --- a/src/raglite/_embed.py +++ b/src/raglite/_embed.py @@ -14,22 +14,16 @@ def embed_strings_with_late_chunking( # noqa: C901,PLR0915 - sentences: list[str], - *, - config: RAGLiteConfig | None = None, + sentences: list[str], *, config: RAGLiteConfig | None = None ) -> FloatMatrix: """Embed a document's sentences with late chunking.""" def _count_tokens( - sentences: list[str], - embedder: Llama, - sentinel_char: str, - sentinel_tokens: list[int], + sentences: list[str], embedder: Llama, sentinel_char: str, sentinel_tokens: list[int] ) -> list[int]: # Join the sentences with the sentinel token and tokenise the result. sentences_tokens = np.asarray( - embedder.tokenize(sentinel_char.join(sentences).encode(), add_bos=False), - dtype=np.intp, + embedder.tokenize(sentinel_char.join(sentences).encode(), add_bos=False), dtype=np.intp ) # Map all sentinel token variants to the first one. 
for sentinel_token in sentinel_tokens[1:]: @@ -68,9 +62,7 @@ def _create_segment( config = config or RAGLiteConfig() assert config.embedder.startswith("llama-cpp-python") embedder = LlamaCppPythonLLM.llm( - config.embedder, - embedding=True, - pooling_type=LLAMA_POOLING_TYPE_NONE, + config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE ) n_ctx = embedder.n_ctx() n_batch = embedder.n_batch @@ -95,7 +87,7 @@ def _create_segment( sentence_batch_len += len(sentence) if i == len(sentences) - 1 or sentence_batch_len > (n_ctx // 2): num_tokens_list.extend( - _count_tokens(sentence_batch, embedder, sentinel_char, sentinel_tokens), + _count_tokens(sentence_batch, embedder, sentinel_char, sentinel_tokens) ) sentence_batch, sentence_batch_len = [], 0 num_tokens = np.asarray(num_tokens_list, dtype=np.intp) @@ -112,10 +104,7 @@ def _create_segment( content_start_index = 0 while content_start_index < len(sentences): segment_start_index, segment_end_index = _create_segment( - content_start_index, - max_tokens_preamble, - max_tokens_content, - num_tokens, + content_start_index, max_tokens_preamble, max_tokens_content, num_tokens ) segments.append((segment_start_index, content_start_index, segment_end_index)) content_start_index = segment_end_index @@ -160,9 +149,7 @@ def _embed_string_batch(string_batch: list[str], *, config: RAGLiteConfig) -> Fl # embeddings because token embeddings are universally supported, while sequence # embeddings are only supported by some models. embedder = LlamaCppPythonLLM.llm( - config.embedder, - embedding=True, - pooling_type=LLAMA_POOLING_TYPE_NONE, + config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE ) embeddings = np.asarray([np.mean(row, axis=0) for row in embedder.embed(string_batch)]) else: @@ -179,9 +166,7 @@ def _embed_string_batch(string_batch: list[str], *, config: RAGLiteConfig) -> Fl def embed_strings_without_late_chunking( - strings: list[str], - *, - config: RAGLiteConfig | None = None, + strings: list[str], *, config: RAGLiteConfig | None = None ) -> FloatMatrix: """Embed a list of text strings in batches.""" config = config or RAGLiteConfig() diff --git a/src/raglite/_eval.py b/src/raglite/_eval.py index 451b308..cf57377 100644 --- a/src/raglite/_eval.py +++ b/src/raglite/_eval.py @@ -29,11 +29,10 @@ class QuestionResponse(BaseModel): """A specific question about the content of a set of document contexts.""" model_config = ConfigDict( - extra="forbid", # Forbid extra attributes as required by OpenAI's strict mode. + extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode. ) question: str = Field( - ..., - description="A specific question about the content of a set of document contexts.", + ..., description="A specific question about the content of a set of document contexts." ) system_prompt: ClassVar[str] = """ You are given a set of contexts extracted from a document. @@ -70,7 +69,7 @@ def validate_question(cls, value: str) -> str: select(Chunk) .where(Chunk.document_id == seed_document.id) .order_by(func.random()) - .limit(1), + .limit(1) ).first() assert isinstance(seed_chunk, Chunk) # Expand the seed chunk into a set of related chunks. @@ -85,16 +84,11 @@ def validate_question(cls, value: str) -> str: ] # Extract a question from the seed chunk's related chunks. question = extract_with_llm( - QuestionResponse, - related_chunks, - strict=True, - config=config, + QuestionResponse, related_chunks, strict=True, config=config ).question # Search for candidate chunks to answer the generated question. 
candidate_chunk_ids, _ = vector_search( - query=question, - num_results=2 * max_chunks, - config=config, + query=question, num_results=2 * max_chunks, config=config ) candidate_chunks = [session.get(Chunk, chunk_id) for chunk_id in candidate_chunk_ids] @@ -103,7 +97,7 @@ class ContextEvalResponse(BaseModel): """Indicate whether the provided context can be used to answer a given question.""" model_config = ConfigDict( - extra="forbid", # Forbid extra attributes as required by OpenAI's strict mode. + extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode. ) hit: bool = Field( ..., @@ -118,18 +112,11 @@ class ContextEvalResponse(BaseModel): relevant_chunks = [] for candidate_chunk in tqdm( - candidate_chunks, - desc="Evaluating chunks", - unit="chunk", - dynamic_ncols=True, - leave=False, + candidate_chunks, desc="Evaluating chunks", unit="chunk", dynamic_ncols=True, leave=False ): try: context_eval_response = extract_with_llm( - ContextEvalResponse, - str(candidate_chunk), - strict=True, - config=config, + ContextEvalResponse, str(candidate_chunk), strict=True, config=config ) except ValueError: # noqa: PERF203 pass @@ -145,11 +132,10 @@ class AnswerResponse(BaseModel): """Answer a question using the provided context.""" model_config = ConfigDict( - extra="forbid", # Forbid extra attributes as required by OpenAI's strict mode. + extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode. ) answer: str = Field( - ..., - description="A complete answer to the given question using the provided context.", + ..., description="A complete answer to the given question using the provided context." ) system_prompt: ClassVar[str] = f""" You are given a set of contexts extracted from a document. @@ -200,11 +186,7 @@ def insert_evals( session.execute(text("CHECKPOINT;")) -def answer_evals( - num_evals: int = 100, - *, - config: RAGLiteConfig | None = None, -) -> "pd.DataFrame": +def answer_evals(num_evals: int = 100, *, config: RAGLiteConfig | None = None) -> "pd.DataFrame": """Read evals from the database and answer them with RAG.""" try: import pandas as pd @@ -238,8 +220,7 @@ def answer_evals( def evaluate( - answered_evals: "pd.DataFrame | int" = 100, - config: RAGLiteConfig | None = None, + answered_evals: "pd.DataFrame | int" = 100, config: RAGLiteConfig | None = None ) -> "pd.DataFrame": """Evaluate the performance of a set of answered evals with Ragas.""" try: diff --git a/src/raglite/_insert.py b/src/raglite/_insert.py index d7fe0d7..2deb171 100644 --- a/src/raglite/_insert.py +++ b/src/raglite/_insert.py @@ -20,8 +20,7 @@ def _create_chunk_records( - document: Document, - config: RAGLiteConfig, + document: Document, config: RAGLiteConfig ) -> tuple[Document, list[Chunk], list[list[ChunkEmbedding]]]: """Process chunks into chunk and chunk embedding records.""" # Partition the document into chunks. @@ -30,20 +29,14 @@ def _create_chunk_records( chunklets = split_chunklets(sentences, max_size=config.chunk_max_size) chunklet_embeddings = embed_strings(chunklets, config=config) chunks, chunk_embeddings = split_chunks( - chunklets=chunklets, - chunklet_embeddings=chunklet_embeddings, - max_size=config.chunk_max_size, + chunklets=chunklets, chunklet_embeddings=chunklet_embeddings, max_size=config.chunk_max_size ) # Create the chunk records. chunk_records, headings = [], "" for i, chunk in enumerate(chunks): # Create and append the chunk record. 
record = Chunk.from_body( - document=document, - index=i, - body=chunk, - headings=headings, - **document.metadata_, + document=document, index=i, body=chunk, headings=headings, **document.metadata_ ) chunk_records.append(record) # Update the Markdown headings with those of this chunk. @@ -58,23 +51,19 @@ def _create_chunk_records( [ ChunkEmbedding(chunk_id=chunk_record.id, embedding=chunklet_embedding) for chunklet_embedding in chunk_embedding - ], + ] ) else: # Embed the full chunks, including the current Markdown headings. full_chunk_embeddings = embed_strings_without_late_chunking( - [chunk_record.content for chunk_record in chunk_records], - config=config, + [chunk_record.content for chunk_record in chunk_records], config=config ) # Every chunk record is associated with a list of chunk embedding records. The chunk # embedding records each correspond to a linear combination of a chunklet embedding and an # embedding of the full chunk with Markdown headings. α = 0.15 # Benchmark-optimised value. # noqa: PLC2401 for chunk_record, chunk_embedding, full_chunk_embedding in zip( - chunk_records, - chunk_embeddings, - full_chunk_embeddings, - strict=True, + chunk_records, chunk_embeddings, full_chunk_embeddings, strict=True ): if config.vector_search_multivector: chunk_embedding_records_list.append( @@ -84,16 +73,11 @@ def _create_chunk_records( embedding=α * chunklet_embedding + (1 - α) * full_chunk_embedding, ) for chunklet_embedding in chunk_embedding - ], + ] ) else: chunk_embedding_records_list.append( - [ - ChunkEmbedding( - chunk_id=chunk_record.id, - embedding=full_chunk_embedding, - ), - ], + [ChunkEmbedding(chunk_id=chunk_record.id, embedding=full_chunk_embedding)] ) return document, chunk_records, chunk_embedding_records_list @@ -134,7 +118,7 @@ def insert_documents( # noqa: C901 for i in range(0, len(documents), batch_size): doc_id_batch = [doc.id for doc in documents[i : i + batch_size]] existing_doc_ids.update( - session.exec(select(Document.id).where(col(Document.id).in_(doc_id_batch))).all(), + session.exec(select(Document.id).where(col(Document.id).in_(doc_id_batch))).all() ) documents = [doc for doc in documents if doc.id not in existing_doc_ids] if not documents: @@ -155,10 +139,7 @@ def insert_documents( # noqa: C901 Session(engine) as session, ThreadPoolExecutor(max_workers=max_workers) as executor, tqdm( - total=len(documents), - desc="Inserting documents", - unit="document", - dynamic_ncols=True, + total=len(documents), desc="Inserting documents", unit="document", dynamic_ncols=True ) as pbar, ): futures = [ diff --git a/src/raglite/_litellm.py b/src/raglite/_litellm.py index 50bf886..77c99ab 100644 --- a/src/raglite/_litellm.py +++ b/src/raglite/_litellm.py @@ -28,12 +28,7 @@ from raglite._chatml_function_calling import chatml_function_calling_with_streaming from raglite._config import RAGLiteConfig -from raglite._lazy_llama import ( - Llama, - LlamaRAMCache, - llama_supports_gpu_offload, - llama_types, -) +from raglite._lazy_llama import Llama, LlamaRAMCache, llama_supports_gpu_offload, llama_types # Reduce the logging level for LiteLLM, flashrank, and httpx. 
litellm.suppress_debug_info = True @@ -144,7 +139,7 @@ def llm(model: str, **kwargs: Any) -> Llama: "supports_function_calling": True, "supports_parallel_function_calling": True, "supports_vision": False, - }, + } } litellm.register_model(model_info) # type: ignore[attr-defined] return llm @@ -168,9 +163,7 @@ def _translate_openai_params(self, optional_params: dict[str, Any]) -> dict[str, return llama_cpp_python_params def _add_recommended_model_params( - self, - model: str, - llama_cpp_python_params: dict[str, Any], + self, model: str, llama_cpp_python_params: dict[str, Any] ) -> dict[str, Any]: """Add recommended model settings.""" recommended_settings = {} @@ -324,7 +317,7 @@ async def astreaming( # type: ignore[misc,override] # noqa: PLR0913 # Register the LlamaCppPythonLLM provider. if not any(provider["provider"] == "llama-cpp-python" for provider in litellm.custom_provider_map): litellm.custom_provider_map.append( - {"provider": "llama-cpp-python", "custom_handler": LlamaCppPythonLLM()}, + {"provider": "llama-cpp-python", "custom_handler": LlamaCppPythonLLM()} ) custom_llm_setup() # type: ignore[no-untyped-call] diff --git a/src/raglite/_markdown.py b/src/raglite/_markdown.py index ba87de4..b5269fb 100644 --- a/src/raglite/_markdown.py +++ b/src/raglite/_markdown.py @@ -40,7 +40,7 @@ def extract_font_size(span: dict[str, Any]) -> float: for block in page["blocks"] for line in block["lines"] for span in line["spans"] - ], + ] ) font_sizes = np.round(font_sizes * 2) / 2 unique_font_sizes, counts = np.unique(font_sizes, return_counts=True) @@ -113,8 +113,7 @@ def strip_page_numbers(pages: list[dict[str, Any]]) -> list[dict[str, Any]]: line for line in block["lines"] if not re.match( - r"^\s*[#0]*\d+\s*$", - "".join(span["text"] for span in line["spans"]), + r"^\s*[#0]*\d+\s*$", "".join(span["text"] for span in line["spans"]) ) ] return pages diff --git a/src/raglite/_mcp.py b/src/raglite/_mcp.py index aea6a37..ec97364 100644 --- a/src/raglite/_mcp.py +++ b/src/raglite/_mcp.py @@ -14,7 +14,7 @@ description=( "The `query` string MUST be a precise single-faceted question in the user's language.\n" "The `query` string MUST resolve all pronouns to explicit nouns." - ), + ) ), ] @@ -42,7 +42,7 @@ def search_knowledge_base(query: Query) -> str: rag_context = '{{"documents": [{elements}]}}'.format( elements=", ".join( chunk_span.to_json(index=i + 1) for i, chunk_span in enumerate(chunk_spans) - ), + ) ) return rag_context diff --git a/src/raglite/_query_adapter.py b/src/raglite/_query_adapter.py index ad4e832..cbd4e0f 100644 --- a/src/raglite/_query_adapter.py +++ b/src/raglite/_query_adapter.py @@ -154,19 +154,13 @@ def update_query_adapter( Q = np.zeros((0, len(chunk_embedding.embedding))) T = np.zeros_like(Q) for eval_ in tqdm( - evals, - desc="Optimizing evals", - unit="eval", - dynamic_ncols=True, - leave=False, + evals, desc="Optimizing evals", unit="eval", dynamic_ncols=True, leave=False ): # Embed the question. q = embed_strings([eval_.question], config=config)[0] # Retrieve chunks that would be used to answer the question. 
chunk_ids, _ = vector_search( - q, - num_results=optimize_top_k, - config=config_no_query_adapter, + q, num_results=optimize_top_k, config=config_no_query_adapter ) retrieved_chunks = session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all() retrieved_chunks = sorted(retrieved_chunks, key=lambda chunk: chunk_ids.index(chunk.id)) @@ -179,13 +173,13 @@ def update_query_adapter( [ chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]] for chunk in np.array(retrieved_chunks)[is_relevant] - ], + ] ) N = np.vstack( [ chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]] for chunk in np.array(retrieved_chunks)[~is_relevant] - ], + ] ) # Compute the optimal target vector t for this query embedding q. t = _optimize_query_target(q, P, N, α=optimize_gap) diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index b5c0081..8393e31 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -37,10 +37,7 @@ def retrieve_context( - query: str, - *, - num_chunks: int = 10, - config: RAGLiteConfig | None = None, + query: str, *, num_chunks: int = 10, config: RAGLiteConfig | None = None ) -> list[ChunkSpan]: """Retrieve context for RAG.""" # Call the search method. @@ -89,8 +86,7 @@ def _clip(messages: list[dict[str, str]], max_tokens: int) -> list[dict[str, str def _get_tools( - messages: list[dict[str, str]], - config: RAGLiteConfig, + messages: list[dict[str, str]], config: RAGLiteConfig ) -> tuple[list[dict[str, Any]] | None, dict[str, Any] | str | None]: """Get tools to search the knowledge base if no RAG context is provided in the messages.""" # Check if messages already contain RAG context or if the LLM supports tool use. @@ -123,13 +119,13 @@ def _get_tools( "The `query` string MUST be a precise single-faceted question in the user's language.\n" "The `query` string MUST resolve all pronouns to explicit nouns." ), - }, + } }, "required": ["query"], "additionalProperties": False, }, }, - }, + } ] if not messages_contain_rag_context else None @@ -157,10 +153,10 @@ def _run_tools( elements=", ".join( chunk_span.to_json(index=i + 1) for i, chunk_span in enumerate(chunk_spans) - ), + ) ), "tool_call_id": tool_call.id, - }, + } ) if chunk_spans and callable(on_retrieval): on_retrieval(chunk_spans) @@ -246,9 +242,7 @@ async def async_rag( # Asynchronously stream the assistant response. 
chunks = [] async_stream = await acompletion( - model=config.llm, - messages=_clip(messages, max_tokens), - stream=True, + model=config.llm, messages=_clip(messages, max_tokens), stream=True ) async for chunk in async_stream: chunks.append(chunk) diff --git a/src/raglite/_search.py b/src/raglite/_search.py index a611761..652275c 100644 --- a/src/raglite/_search.py +++ b/src/raglite/_search.py @@ -49,8 +49,7 @@ def vector_search( corrected_oversample = oversample * config.chunk_max_size / RAGLiteConfig.chunk_max_size num_hits = round(corrected_oversample) * max(num_results, 10) dist = ChunkEmbedding.embedding.distance( # type: ignore[attr-defined] - query_embedding, - metric=config.vector_search_distance_metric, + query_embedding, metric=config.vector_search_distance_metric ).label("dist") sim = (1.0 - dist).label("sim") top_vectors = select(ChunkEmbedding.chunk_id, sim).order_by(dist).limit(num_hits).subquery() @@ -68,10 +67,7 @@ def vector_search( def keyword_search( - query: str, - *, - num_results: int = 3, - config: RAGLiteConfig | None = None, + query: str, *, num_results: int = 3, config: RAGLiteConfig | None = None ) -> tuple[list[ChunkId], list[float]]: """Search chunks using BM25 keyword search.""" # Read the config. @@ -104,7 +100,7 @@ def keyword_search( WHERE score IS NOT NULL ORDER BY score DESC LIMIT :limit; - """, + """ ) results = session.execute(statement, params={"query": query, "limit": num_results}) # Unpack the results. @@ -115,10 +111,7 @@ def keyword_search( def reciprocal_rank_fusion( - rankings: list[list[ChunkId]], - *, - k: int = 60, - weights: list[float] | None = None, + rankings: list[list[ChunkId]], *, k: int = 60, weights: list[float] | None = None ) -> tuple[list[ChunkId], list[float]]: """Reciprocal Rank Fusion.""" if weights is None: @@ -136,8 +129,7 @@ def reciprocal_rank_fusion( return [], [] # Rank RRF results according to descending RRF score. rrf_chunk_ids, rrf_score = zip( - *sorted(chunk_id_score.items(), key=lambda x: x[1], reverse=True), - strict=True, + *sorted(chunk_id_score.items(), key=lambda x: x[1], reverse=True), strict=True ) return list(rrf_chunk_ids), list(rrf_score) @@ -157,17 +149,14 @@ def hybrid_search( # noqa: PLR0913 ks_chunk_ids, _ = keyword_search(query, num_results=oversample * num_results, config=config) # Combine the results with Reciprocal Rank Fusion (RRF). chunk_ids, hybrid_score = reciprocal_rank_fusion( - [vs_chunk_ids, ks_chunk_ids], - weights=[vector_search_weight, keyword_search_weight], + [vs_chunk_ids, ks_chunk_ids], weights=[vector_search_weight, keyword_search_weight] ) chunk_ids, hybrid_score = chunk_ids[:num_results], hybrid_score[:num_results] return chunk_ids, hybrid_score def retrieve_chunks( - chunk_ids: list[ChunkId], - *, - config: RAGLiteConfig | None = None, + chunk_ids: list[ChunkId], *, config: RAGLiteConfig | None = None ) -> list[Chunk]: """Retrieve chunks by their ids.""" if not chunk_ids: @@ -178,8 +167,8 @@ def retrieve_chunks( select(Chunk) .where(col(Chunk.id).in_(chunk_ids)) # Eagerly load chunk.document. - .options(joinedload(Chunk.document)), # type: ignore[arg-type] - ).all(), + .options(joinedload(Chunk.document)) # type: ignore[arg-type] + ).all() ) chunks = sorted(chunks, key=lambda chunk: chunk_ids.index(chunk.id)) return chunks @@ -221,8 +210,8 @@ def retrieve_chunk_spans( select(Chunk) .where(or_(*neighbor_conditions)) # Eagerly load chunk.document. 
- .options(joinedload(Chunk.document)), # type: ignore[arg-type] - ).all(), + .options(joinedload(Chunk.document)) # type: ignore[arg-type] + ).all() ) # Deduplicate and sort the chunks by document_id and index (needed for groupby). unique_chunks = sorted(set(chunks), key=lambda chunk: (chunk.document_id, chunk.index)) @@ -248,10 +237,7 @@ def retrieve_chunk_spans( def rerank_chunks( - query: str, - chunk_ids: list[ChunkId] | list[Chunk], - *, - config: RAGLiteConfig | None = None, + query: str, chunk_ids: list[ChunkId] | list[Chunk], *, config: RAGLiteConfig | None = None ) -> list[Chunk]: """Rerank chunks according to their relevance to a given query.""" # Retrieve the chunks. diff --git a/src/raglite/_split_chunklets.py b/src/raglite/_split_chunklets.py index 7c0e250..d4a64aa 100644 --- a/src/raglite/_split_chunklets.py +++ b/src/raglite/_split_chunklets.py @@ -58,8 +58,7 @@ def markdown_chunklet_boundaries(sentences: list[str]) -> FloatVector: def compute_num_statements(sentences: list[str]) -> FloatVector: """Compute the approximate number of statements of each sentence in a list of sentences.""" sentence_word_length = np.asarray( - [len(sentence.split()) for sentence in sentences], - dtype=np.float64, + [len(sentence.split()) for sentence in sentences], dtype=np.float64 ) q25, q75 = np.quantile(sentence_word_length, [0.25, 0.75]) q25 = max(q25, np.sqrt(np.finfo(np.float64).eps)) diff --git a/src/raglite/_split_chunks.py b/src/raglite/_split_chunks.py index 6550bf5..f3355a7 100644 --- a/src/raglite/_split_chunks.py +++ b/src/raglite/_split_chunks.py @@ -10,9 +10,7 @@ def split_chunks( # noqa: C901, PLR0915 - chunklets: list[str], - chunklet_embeddings: FloatMatrix, - max_size: int = 2048, + chunklets: list[str], chunklet_embeddings: FloatMatrix, max_size: int = 2048 ) -> tuple[list[str], list[FloatMatrix]]: """Split chunklets into optimal semantic chunks with corresponding chunklet embeddings. @@ -68,8 +66,7 @@ def split_chunks( # noqa: C901, PLR0915 partition_similarity = np.sum(X[:-1] * X[1:], axis=1) # Make partition similarity nonnegative before modification and optimisation. partition_similarity = np.maximum( - (partition_similarity + 1) / 2, - np.sqrt(np.finfo(X.dtype).eps), + (partition_similarity + 1) / 2, np.sqrt(np.finfo(X.dtype).eps) ) # Modify the partition similarity to encourage splitting on Markdown headings. prev_chunklet_is_heading = True @@ -104,11 +101,7 @@ def split_chunks( # noqa: C901, PLR0915 ) b_ub = np.ones(A.shape[0], dtype=np.float32) res = linprog( - partition_similarity, - A_ub=-A, - b_ub=-b_ub, - bounds=(0, 1), - integrality=[1] * A.shape[1], + partition_similarity, A_ub=-A, b_ub=-b_ub, bounds=(0, 1), integrality=[1] * A.shape[1] ) if not res.success: error_message = "Optimization of chunk partitions failed." diff --git a/src/raglite/_split_sentences.py b/src/raglite/_split_sentences.py index da84c8f..1cbeb18 100644 --- a/src/raglite/_split_sentences.py +++ b/src/raglite/_split_sentences.py @@ -56,11 +56,7 @@ def get_markdown_heading_indexes(doc: str) -> list[tuple[int, int]]: def _split_sentences( - doc: str, - probas: FloatVector, - *, - min_len: int, - max_len: int | None = None, + doc: str, probas: FloatVector, *, min_len: int, max_len: int | None = None ) -> list[str]: # Solve an optimisation problem to find the best sentence boundaries given the predicted # boundary probabilities. 
The objective is to select boundaries that maximise the sum of the diff --git a/src/raglite/_typing.py b/src/raglite/_typing.py index 1ef690d..94c4661 100644 --- a/src/raglite/_typing.py +++ b/src/raglite/_typing.py @@ -33,21 +33,13 @@ class BasicSearchMethod(Protocol): def __call__( - self, - query: str, - *, - num_results: int, - config: "RAGLiteConfig | None" = None, + self, query: str, *, num_results: int, config: "RAGLiteConfig | None" = None ) -> tuple[list[ChunkId], list[float]]: ... class SearchMethod(Protocol): def __call__( - self, - query: str, - *, - num_results: int, - config: "RAGLiteConfig | None" = None, + self, query: str, *, num_results: int, config: "RAGLiteConfig | None" = None ) -> tuple[list[ChunkId], list[float]] | list["Chunk"] | list["ChunkSpan"]: ... @@ -57,9 +49,7 @@ class NumpyArray(TypeDecorator[np.ndarray[Any, np.dtype[np.floating[Any]]]]): impl = LargeBinary def process_bind_param( - self, - value: np.ndarray[Any, np.dtype[np.floating[Any]]] | None, - dialect: Dialect, + self, value: np.ndarray[Any, np.dtype[np.floating[Any]]] | None, dialect: Dialect ) -> bytes | None: """Convert a NumPy array to bytes.""" if value is None: @@ -69,9 +59,7 @@ def process_bind_param( return buffer.getvalue() def process_result_value( - self, - value: bytes | None, - dialect: Dialect, + self, value: bytes | None, dialect: Dialect ) -> np.ndarray[Any, np.dtype[np.floating[Any]]] | None: """Convert bytes to a NumPy array.""" if value is None: @@ -110,12 +98,7 @@ def __init__(self, left: Any, right: Any, metric: DistanceMetric) -> None: @compiles(EmbeddingDistance, "postgresql") def _embedding_distance_postgresql(element: EmbeddingDistance, compiler: Any, **kwargs: Any) -> str: - op_map: dict[DistanceMetric, str] = { - "cosine": "<=>", - "dot": "<#>", - "l1": "<+>", - "l2": "<->", - } + op_map: dict[DistanceMetric, str] = {"cosine": "<=>", "dot": "<#>", "l1": "<+>", "l2": "<->"} left, right = list(element.clauses) operator = op_map[element.metric] return f"({compiler.process(left)} {operator} {compiler.process(right)})" @@ -164,9 +147,7 @@ def process(value: FloatVector | None) -> str | None: return process def result_processor( - self, - dialect: Dialect, - coltype: Any, + self, dialect: Dialect, coltype: Any ) -> Callable[[str | None], FloatVector | None]: """Process PostgreSQL halfvec format to NumPy ndarray.""" @@ -191,8 +172,7 @@ def get_col_spec(self, **kwargs: Any) -> str: return f"FLOAT[{self.dim}]" if self.dim is not None else "FLOAT[]" def bind_processor( - self, - dialect: Dialect, + self, dialect: Dialect ) -> Callable[[FloatVector | None], list[float] | None]: """Process NumPy ndarray to DuckDB single precision vector format for bound parameters.""" @@ -202,9 +182,7 @@ def process(value: FloatVector | None) -> list[float] | None: return process def result_processor( - self, - dialect: Dialect, - coltype: Any, + self, dialect: Dialect, coltype: Any ) -> Callable[[list[float] | None], FloatVector | None]: """Process DuckDB single precision vector format to NumPy ndarray.""" diff --git a/tests/test_chatml_function_calling.py b/tests/test_chatml_function_calling.py index d7f02a3..a2b612f 100644 --- a/tests/test_chatml_function_calling.py +++ b/tests/test_chatml_function_calling.py @@ -10,11 +10,7 @@ from typeguard import ForwardRefPolicy, check_type from raglite._chatml_function_calling import chatml_function_calling_with_streaming -from raglite._lazy_llama import ( - Llama, - llama_supports_gpu_offload, - llama_types, -) +from raglite._lazy_llama import Llama, 
llama_supports_gpu_offload, llama_types def is_accelerator_available() -> bool: @@ -27,11 +23,7 @@ def is_accelerator_available() -> bool: @pytest.mark.parametrize( - "stream", - [ - pytest.param(True, id="stream=True"), - pytest.param(False, id="stream=False"), - ], + "stream", [pytest.param(True, id="stream=True"), pytest.param(False, id="stream=False")] ) @pytest.mark.parametrize( "tool_choice", @@ -39,22 +31,15 @@ def is_accelerator_available() -> bool: pytest.param("none", id="tool_choice=none"), pytest.param("auto", id="tool_choice=auto"), pytest.param( - {"type": "function", "function": {"name": "get_weather"}}, - id="tool_choice=fixed", + {"type": "function", "function": {"name": "get_weather"}}, id="tool_choice=fixed" ), ], ) @pytest.mark.parametrize( "user_prompt_expected_tool_calls", [ - pytest.param( - ("Is 7 a prime number?", 0), - id="expected_tool_calls=0", - ), - pytest.param( - ("What's the weather like in Paris today?", 1), - id="expected_tool_calls=1", - ), + pytest.param(("Is 7 a prime number?", 0), id="expected_tool_calls=0"), + pytest.param(("What's the weather like in Paris today?", 1), id="expected_tool_calls=1"), pytest.param( ("What's the weather like in Paris today? What about New York?", 2), id="expected_tool_calls=2", @@ -69,8 +54,7 @@ def is_accelerator_available() -> bool: "unsloth/Qwen3-8B-GGUF", id="qwen3_8B", marks=pytest.mark.skipif( - not is_accelerator_available(), - reason="Accelerator not available", + not is_accelerator_available(), reason="Accelerator not available" ), ), ], @@ -95,7 +79,7 @@ def test_llama_cpp_python_tool_use( chat_handler=chatml_function_calling_with_streaming, ) messages: list[llama_types.ChatCompletionRequestMessage] = [ - {"role": "user", "content": user_prompt}, + {"role": "user", "content": user_prompt} ] tools: list[llama_types.ChatCompletionTool] = [ { @@ -108,13 +92,10 @@ def test_llama_cpp_python_tool_use( "properties": {"location": {"type": "string", "description": "A city name."}}, }, }, - }, + } ] response = llm.create_chat_completion( - messages=messages, - tools=tools, - tool_choice=tool_choice, - stream=stream, + messages=messages, tools=tools, tool_choice=tool_choice, stream=stream ) if stream: response = cast("Iterator[llama_types.CreateChatCompletionStreamResponse]", response) diff --git a/tests/test_extract.py b/tests/test_extract.py index a81df5c..9cf8dd1 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -10,8 +10,7 @@ @pytest.mark.parametrize( - "strict", - [pytest.param(False, id="strict=False"), pytest.param(True, id="strict=True")], + "strict", [pytest.param(False, id="strict=False"), pytest.param(True, id="strict=True")] ) def test_extract(llm: str, strict: bool) -> None: # noqa: FBT001 """Test extracting structured data.""" @@ -30,10 +29,7 @@ class LoginResponse(BaseModel): # Extract structured data. username, password = "cypher", "steak" login_response = extract_with_llm( - LoginResponse, - f"username: {username}\npassword: {password}", - strict=strict, - config=config, + LoginResponse, f"username: {username}\npassword: {password}", strict=strict, config=config ) # Validate the response. assert isinstance(login_response, LoginResponse) diff --git a/tests/test_insert.py b/tests/test_insert.py index 1ab62b2..acc9246 100644 --- a/tests/test_insert.py +++ b/tests/test_insert.py @@ -19,7 +19,7 @@ def test_insert(raglite_test_config: RAGLiteConfig) -> None: assert document is not None, "No document found in the database" # Get the existing chunks for this document. 
chunks = session.exec( - select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.index), # type: ignore[arg-type] + select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.index) # type: ignore[arg-type] ).all() assert len(chunks) > 0, "No chunks found for the document" restored_document = "" diff --git a/tests/test_rag.py b/tests/test_rag.py index 151cd96..bb3ca1e 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -2,11 +2,7 @@ import json -from raglite import ( - RAGLiteConfig, - add_context, - retrieve_context, -) +from raglite import RAGLiteConfig, add_context, retrieve_context from raglite._database import ChunkSpan from raglite._rag import rag diff --git a/tests/test_rerank.py b/tests/test_rerank.py index b13f4a6..51da375 100644 --- a/tests/test_rerank.py +++ b/tests/test_rerank.py @@ -31,26 +31,21 @@ def kendall_tau(a: list[T], b: list[T]) -> float: }, id="flashrank_multilingual", ), - ], + ] ) -def reranker( - request: pytest.FixtureRequest, -) -> BaseRanker | dict[str, BaseRanker] | None: +def reranker(request: pytest.FixtureRequest) -> BaseRanker | dict[str, BaseRanker] | None: """Get a reranker to test RAGLite with.""" reranker: BaseRanker | dict[str, BaseRanker] | None = request.param return reranker def test_reranker( - raglite_test_config: RAGLiteConfig, - reranker: BaseRanker | dict[str, BaseRanker] | None, + raglite_test_config: RAGLiteConfig, reranker: BaseRanker | dict[str, BaseRanker] | None ) -> None: """Test inserting a document, updating the indexes, and searching for a query.""" # Update the config with the reranker. raglite_test_config = RAGLiteConfig( - db_url=raglite_test_config.db_url, - embedder=raglite_test_config.embedder, - reranker=reranker, + db_url=raglite_test_config.db_url, embedder=raglite_test_config.embedder, reranker=reranker ) # Search for a query. query = "What does it mean for two events to be simultaneous?" 
diff --git a/tests/test_search.py b/tests/test_search.py index f972422..6c116ee 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -19,11 +19,9 @@ pytest.param(keyword_search, id="keyword_search"), pytest.param(vector_search, id="vector_search"), pytest.param(hybrid_search, id="hybrid_search"), - ], + ] ) -def search_method( - request: pytest.FixtureRequest, -) -> BasicSearchMethod: +def search_method(request: pytest.FixtureRequest) -> BasicSearchMethod: """Get a search method to test RAGLite with.""" search_method: BasicSearchMethod = request.param return search_method @@ -57,8 +55,7 @@ def test_search(raglite_test_config: RAGLiteConfig, search_method: BasicSearchMe def test_search_no_results( - raglite_test_config: RAGLiteConfig, - search_method: BasicSearchMethod, + raglite_test_config: RAGLiteConfig, search_method: BasicSearchMethod ) -> None: """Test searching for a query with no keyword search results.""" query = "supercalifragilisticexpialidocious" diff --git a/tests/test_split_sentences.py b/tests/test_split_sentences.py index 8a40008..5122297 100644 --- a/tests/test_split_sentences.py +++ b/tests/test_split_sentences.py @@ -39,9 +39,7 @@ def test_split_sentences() -> None: assert all( sentence == expected_sentence for sentence, expected_sentence in zip( - sentences[: len(expected_sentences)], - expected_sentences, - strict=True, + sentences[: len(expected_sentences)], expected_sentences, strict=True ) ) @@ -63,8 +61,7 @@ def test_split_sentences() -> None: id="huge-2a", ), pytest.param( - ("X" * 768 + " " + "X" * 768, ["X" * 768 + " ", "X" * 768], (4, 1024)), - id="huge-2b", + ("X" * 768 + " " + "X" * 768, ["X" * 768 + " ", "X" * 768], (4, 1024)), id="huge-2b" ), ], ) @@ -76,8 +73,6 @@ def test_split_sentences_edge_cases(case: tuple[str, list[str], tuple[int, int | assert all( sentence == expected_sentence for sentence, expected_sentence in zip( - sentences[: len(expected_sentences)], - expected_sentences, - strict=True, + sentences[: len(expected_sentences)], expected_sentences, strict=True ) ) From 163455e881f6d7dd7e97d600b0b92bd184e2b6e5 Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Sun, 15 Jun 2025 13:19:27 +0200 Subject: [PATCH 3/3] style: allow star arg any --- pyproject.toml | 3 +++ src/raglite/_chatml_function_calling.py | 2 +- src/raglite/_database.py | 17 +++++------------ src/raglite/_extract.py | 2 +- src/raglite/_lazy_llama.py | 2 +- 5 files changed, 11 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ba77e24..7b496d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,6 +149,9 @@ select = ["ALL"] ignore = ["CPY", "FIX", "ARG001", "COM812", "D203", "D213", "E501", "PD008", "PD009", "RET504", "S101", "TD003"] unfixable = ["ERA001", "F401", "F841", "T201", "T203"] +[tool.ruff.lint.flake8-annotations] +allow-star-arg-any = true + [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" diff --git a/src/raglite/_chatml_function_calling.py b/src/raglite/_chatml_function_calling.py index 9042080..54c0da1 100644 --- a/src/raglite/_chatml_function_calling.py +++ b/src/raglite/_chatml_function_calling.py @@ -232,7 +232,7 @@ def chatml_function_calling_with_streaming( grammar: Optional[llama.LlamaGrammar] = None, # type: ignore[name-defined] logprobs: Optional[bool] = None, top_logprobs: Optional[int] = None, - **kwargs: Any, # noqa: ANN401 + **kwargs: Any, ) -> Union[ llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse], diff --git a/src/raglite/_database.py 
b/src/raglite/_database.py index ceb9629..f4c76a8 100644 --- a/src/raglite/_database.py +++ b/src/raglite/_database.py @@ -68,7 +68,7 @@ class Document(SQLModel, table=True): chunks: list["Chunk"] = Relationship(back_populates="document", cascade_delete=True) evals: list["Eval"] = Relationship(back_populates="document", cascade_delete=True) - def __init__(self, **kwargs: Any) -> None: # noqa: ANN401 + def __init__(self, **kwargs: Any) -> None: # Workaround for https://github.com/fastapi/sqlmodel/issues/149. super().__init__(**kwargs) self.content = kwargs.get("content") @@ -87,7 +87,7 @@ def from_path( *, id: DocumentId | None = None, # noqa: A002 url: str | None = None, - **kwargs: Any, # noqa: ANN401 + **kwargs: Any, ) -> "Document": """Create a document from a file path. @@ -133,7 +133,7 @@ def from_text( id: DocumentId | None = None, # noqa: A002 url: str | None = None, filename: str | None = None, - **kwargs: Any, # noqa: ANN401 + **kwargs: Any, ) -> "Document": """Create a document from text content. @@ -197,11 +197,7 @@ class Chunk(SQLModel, table=True): @staticmethod def from_body( - document: Document, - index: int, - body: str, - headings: str = "", - **kwargs: Any, # noqa: ANN401 + document: Document, index: int, body: str, headings: str = "", **kwargs: Any ) -> "Chunk": """Create a chunk from Markdown.""" return Chunk( @@ -457,10 +453,7 @@ class Eval(SQLModel, table=True): @staticmethod def from_chunks( - question: str, - contexts: list[Chunk], - ground_truth: str, - **kwargs: Any, # noqa: ANN401 + question: str, contexts: list[Chunk], ground_truth: str, **kwargs: Any ) -> "Eval": """Create a chunk from Markdown.""" document_id = contexts[0].document_id diff --git a/src/raglite/_extract.py b/src/raglite/_extract.py index cac1a6b..328f0ac 100644 --- a/src/raglite/_extract.py +++ b/src/raglite/_extract.py @@ -15,7 +15,7 @@ def extract_with_llm( user_prompt: str | list[str], strict: bool = False, # noqa: FBT001, FBT002 config: RAGLiteConfig | None = None, - **kwargs: Any, # noqa: ANN401 + **kwargs: Any, ) -> T: """Extract structured data from unstructured text with an LLM. diff --git a/src/raglite/_lazy_llama.py b/src/raglite/_lazy_llama.py index c3a1040..3e6e238 100644 --- a/src/raglite/_lazy_llama.py +++ b/src/raglite/_lazy_llama.py @@ -42,7 +42,7 @@ def __init__(self, error: ModuleNotFoundError | None = None) -> None: def __getattr__(self, name: str) -> NoReturn: raise ModuleNotFoundError(self.error_message) from self.error - def __call__(self, *args: Any, **kwargs: Any) -> NoReturn: # noqa: ARG002, ANN401 + def __call__(self, *args: Any, **kwargs: Any) -> NoReturn: # noqa: ARG002 raise ModuleNotFoundError(self.error_message) from self.error class LazySubmoduleError: