Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions verbatim_core/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ def __init__(
extraction_mode: str = "auto",
max_display_spans: int = 5,
batch_size: int = 5,
verify_spans: bool = True,
):
"""
Initialize the LLM span extractor.
Expand All @@ -391,11 +392,14 @@ def __init__(
:param extraction_mode: "batch", "individual", or "auto"
:param max_display_spans: Maximum spans to prioritize for display
:param batch_size: Maximum documents to process in batch mode
:param verify_spans: Whether to validate that extracted spans exist
in the source document. If False, all extracted spans are returned.
"""
self.llm_client = llm_client or LLMClient(model)
self.extraction_mode = extraction_mode
self.max_display_spans = max_display_spans
self.batch_size = batch_size
self.verify_spans = verify_spans

def extract_spans(
self, question: str, search_results: List[Any]
Expand Down Expand Up @@ -470,8 +474,10 @@ def _extract_spans_batch(
doc_key = f"doc_{i}"
result_text = getattr(result, "text", "")
if doc_key in extracted_data:
verified = self._verify_spans(extracted_data[doc_key], result_text)
verified_spans[result_text] = verified
spans = extracted_data[doc_key]
if self.verify_spans:
spans = self._verify_spans(spans, result_text)
verified_spans[result_text] = spans
else:
verified_spans[result_text] = []

Expand Down Expand Up @@ -510,8 +516,10 @@ async def _extract_spans_batch_async(
doc_key = f"doc_{i}"
result_text = getattr(result, "text", "")
if doc_key in extracted_data:
verified = self._verify_spans(extracted_data[doc_key], result_text)
verified_spans[result_text] = verified
spans = extracted_data[doc_key]
if self.verify_spans:
spans = self._verify_spans(spans, result_text)
verified_spans[result_text] = spans
else:
verified_spans[result_text] = []

Expand Down Expand Up @@ -543,8 +551,9 @@ def _extract_spans_individual(
extracted_spans = self.llm_client.extract_relevant_spans(
question, result_text
)
verified = self._verify_spans(extracted_spans, result_text)
all_spans[result_text] = verified
if self.verify_spans:
extracted_spans = self._verify_spans(extracted_spans, result_text)
all_spans[result_text] = extracted_spans
except Exception as e:
print(f"Individual extraction failed for document: {e}")
all_spans[result_text] = []
Expand All @@ -566,8 +575,9 @@ async def _extract_spans_individual_async(
extracted_spans = await self.llm_client.extract_relevant_spans_async(
question, result_text
)
verified = self._verify_spans(extracted_spans, result_text)
all_spans[result_text] = verified
if self.verify_spans:
extracted_spans = self._verify_spans(extracted_spans, result_text)
all_spans[result_text] = extracted_spans
except Exception as e:
print(f"Async individual extraction failed for document: {e}")
all_spans[result_text] = []
Expand Down
26 changes: 25 additions & 1 deletion verbatim_core/llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,22 @@ def __init__(
model: str = "gpt-4o-mini",
temperature: float = 0.7,
api_base: str = "https://api.openai.com/v1",
response_log_file: str | None = None,
):
"""
Initialize the LLM client.

:param model: The OpenAI model to use
:param temperature: Default temperature for completions
:param api_base: The base URL for the OpenAI API (can be used with custom models and with VLLM)
:param response_log_file: Path to file for logging raw LLM responses (JSON format).
"""
self.model = model
self.temperature = temperature
self.api_key = os.getenv("OPENAI_API_KEY") or "EMPTY"
self.client = openai.OpenAI(base_url=api_base, api_key=self.api_key)

self.async_client = openai.AsyncOpenAI(base_url=api_base, api_key=self.api_key)
self.response_log_file = response_log_file

def complete(
self, prompt: str, json_mode: bool = False, temperature: Optional[float] = None
Expand Down Expand Up @@ -104,6 +106,8 @@ def extract_spans(
prompt = self._build_extraction_prompt(question, documents)
try:
response = self.complete(prompt, json_mode=True)
if self.response_log_file:
self._log_response(question, documents, response)
return json.loads(response)
except (json.JSONDecodeError, KeyError) as e:
print(f"Span extraction failed: {e}")
Expand All @@ -123,6 +127,8 @@ async def extract_spans_async(
prompt = self._build_extraction_prompt(question, documents)
try:
response = await self.complete_async(prompt, json_mode=True)
if self.response_log_file:
self._log_response(question, documents, response)
return json.loads(response)
except (json.JSONDecodeError, KeyError) as e:
print(f"Async span extraction failed: {e}")
Expand Down Expand Up @@ -309,6 +315,24 @@ async def generate_template_async(
print(f"Async template generation failed: {e}")
return self._fallback_template(citation_count > 0)

def _log_response(
    self, question: str, documents: Dict[str, str], response: str
) -> None:
    """
    Append one raw LLM response to the response log as a JSON line.

    Each call appends a single JSON object with the keys ``question``,
    ``documents`` and ``response``, followed by a newline, to
    ``self.response_log_file``.

    Logging is best-effort: this method is called between receiving the
    LLM response and parsing it, so a failure to write the log is
    reported on stdout instead of being raised, and never discards a
    response that was already obtained.

    :param question: The question that was asked
    :param documents: The documents that were sent to the LLM
    :param response: The raw response string from the LLM
    """
    log_path = self.response_log_file
    if not log_path:
        # Logging is disabled (no path configured); nothing to do.
        return

    log_entry = {
        "question": question,
        "documents": documents,
        "response": response,
    }
    try:
        # Serialize before opening the file so a serialization error can
        # never leave a partially written line behind.
        line = json.dumps(log_entry) + "\n"
        # Explicit encoding keeps the log format independent of the
        # platform's locale settings.
        with open(log_path, "a", encoding="utf-8") as f:
            f.write(line)
    except (OSError, TypeError, ValueError) as e:
        print(f"Failed to log LLM response: {e}")

def _build_extraction_prompt(self, question: str, documents: Dict[str, str]) -> str:
"""Build the prompt for batch span extraction."""
return f"""Extract EXACT verbatim text spans from multiple documents that answer the question.
Expand Down