Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 47 additions & 16 deletions convokit/genai/llmprompttransformer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Optional, Union, Callable, Dict, Any
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional, Union, Callable, Dict, Any, Iterable
from convokit import Transformer, Corpus, Conversation, Speaker, Utterance
from tqdm.auto import tqdm
from .factory import get_llm_client
from .genai_config import GenAIConfigManager

Expand All @@ -20,6 +22,7 @@ class LLMPromptTransformer(Transformer):
:param selector: Optional function to filter which objects to process. Defaults to processing all objects
:param config_manager: GenAIConfigManager instance for LLM API key management
:param llm_kwargs: Additional keyword arguments to pass to the LLM client
:param num_workers: Number of worker threads to use for LLM calls. Defaults to 1 (sequential)
"""

def __init__(
Expand All @@ -35,6 +38,7 @@ def __init__(
] = None,
config_manager: Optional[GenAIConfigManager] = None,
llm_kwargs: Optional[Dict[str, Any]] = None,
num_workers: int = 1,
):
self.provider = provider
self.model = model
Expand All @@ -45,10 +49,14 @@ def __init__(
self.selector = selector or (lambda obj: True)
self.config_manager = config_manager or GenAIConfigManager()
self.llm_kwargs = llm_kwargs or {}
self.num_workers = num_workers

if model is not None:
self.llm_kwargs["model"] = model

if num_workers < 1:
raise ValueError("num_workers must be at least 1")

if object_level not in ["conversation", "speaker", "utterance", "corpus"]:
raise ValueError(
f"Invalid object_level: {object_level}. Must be one of: conversation, speaker, utterance, corpus"
Expand Down Expand Up @@ -88,6 +96,39 @@ def _process_object(self, obj: Union[Corpus, Conversation, Speaker, Utterance])
print(f"Error processing {self.object_level} {obj.id}: {e}")
obj.add_meta(self.metadata_name, None)

def _transform_object(self, obj: Union[Corpus, Conversation, Speaker, Utterance]) -> None:
"""
Apply selector logic and process a single object.

:param obj: Object to transform
"""
if self.selector(obj):
self._process_object(obj)
else:
obj.add_meta(self.metadata_name, None)

def _transform_objects(
self,
objects: Iterable[Union[Conversation, Speaker, Utterance]],
desc: str,
) -> None:
"""
Transform a collection of objects, optionally in parallel.

:param objects: Objects to transform
:param desc: Progress bar description
"""
if self.num_workers == 1:
for obj in tqdm(objects, desc=desc):
self._transform_object(obj)
return

objects = list(objects)
with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
futures = [executor.submit(self._transform_object, obj) for obj in objects]
for future in tqdm(as_completed(futures), total=len(futures), desc=desc):
future.result()

def transform(self, corpus: Corpus) -> Corpus:
"""
Apply the GenAI transformer to the corpus.
Expand All @@ -96,25 +137,15 @@ def transform(self, corpus: Corpus) -> Corpus:
:return: The transformed corpus with LLM responses added as metadata
"""
if self.object_level == "utterance":
for utterance in corpus.iter_utterances():
if self.selector(utterance):
self._process_object(utterance)
else:
utterance.add_meta(self.metadata_name, None)
self._transform_objects(corpus.iter_utterances(), "Applying LLM prompt to utterances")

elif self.object_level == "conversation":
for conversation in corpus.iter_conversations():
if self.selector(conversation):
self._process_object(conversation)
else:
conversation.add_meta(self.metadata_name, None)
self._transform_objects(
corpus.iter_conversations(), "Applying LLM prompt to conversations"
)

elif self.object_level == "speaker":
for speaker in corpus.iter_speakers():
if self.selector(speaker):
self._process_object(speaker)
else:
speaker.add_meta(self.metadata_name, None)
self._transform_objects(corpus.iter_speakers(), "Applying LLM prompt to speakers")

elif self.object_level == "corpus":
if self.selector(corpus):
Expand Down
Loading