diff --git a/convokit/genai/llmprompttransformer.py b/convokit/genai/llmprompttransformer.py index 1b48ea7a..241db40c 100644 --- a/convokit/genai/llmprompttransformer.py +++ b/convokit/genai/llmprompttransformer.py @@ -1,5 +1,7 @@ -from typing import Optional, Union, Callable, Dict, Any +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Optional, Union, Callable, Dict, Any, Iterable from convokit import Transformer, Corpus, Conversation, Speaker, Utterance +from tqdm.auto import tqdm from .factory import get_llm_client from .genai_config import GenAIConfigManager @@ -20,6 +22,7 @@ class LLMPromptTransformer(Transformer): :param selector: Optional function to filter which objects to process. Defaults to processing all objects :param config_manager: GenAIConfigManager instance for LLM API key management :param llm_kwargs: Additional keyword arguments to pass to the LLM client + :param num_workers: Number of worker threads to use for LLM calls. Defaults to 1 (sequential) """ def __init__( @@ -35,6 +38,7 @@ def __init__( ] = None, config_manager: Optional[GenAIConfigManager] = None, llm_kwargs: Optional[Dict[str, Any]] = None, + num_workers: int = 1, ): self.provider = provider self.model = model @@ -45,10 +49,14 @@ def __init__( self.selector = selector or (lambda obj: True) self.config_manager = config_manager or GenAIConfigManager() self.llm_kwargs = llm_kwargs or {} + self.num_workers = num_workers if model is not None: self.llm_kwargs["model"] = model + if num_workers < 1: + raise ValueError("num_workers must be at least 1") + if object_level not in ["conversation", "speaker", "utterance", "corpus"]: raise ValueError( f"Invalid object_level: {object_level}. Must be one of: conversation, speaker, utterance, corpus" @@ -88,6 +96,39 @@ def _process_object(self, obj: Union[Corpus, Conversation, Speaker, Utterance]) print(f"Error processing {self.object_level} {obj.id}: {e}") obj.add_meta(self.metadata_name, None) + def _transform_object(self, obj: Union[Corpus, Conversation, Speaker, Utterance]) -> None: + """ + Apply selector logic and process a single object. + + :param obj: Object to transform + """ + if self.selector(obj): + self._process_object(obj) + else: + obj.add_meta(self.metadata_name, None) + + def _transform_objects( + self, + objects: Iterable[Union[Conversation, Speaker, Utterance]], + desc: str, + ) -> None: + """ + Transform a collection of objects, optionally in parallel. + + :param objects: Objects to transform + :param desc: Progress bar description + """ + if self.num_workers == 1: + for obj in tqdm(objects, desc=desc): + self._transform_object(obj) + return + + objects = list(objects) + with ThreadPoolExecutor(max_workers=self.num_workers) as executor: + futures = [executor.submit(self._transform_object, obj) for obj in objects] + for future in tqdm(as_completed(futures), total=len(futures), desc=desc): + future.result() + def transform(self, corpus: Corpus) -> Corpus: """ Apply the GenAI transformer to the corpus. @@ -96,25 +137,15 @@ def transform(self, corpus: Corpus) -> Corpus: :return: The transformed corpus with LLM responses added as metadata """ if self.object_level == "utterance": - for utterance in corpus.iter_utterances(): - if self.selector(utterance): - self._process_object(utterance) - else: - utterance.add_meta(self.metadata_name, None) + self._transform_objects(corpus.iter_utterances(), "Applying LLM prompt to utterances") elif self.object_level == "conversation": - for conversation in corpus.iter_conversations(): - if self.selector(conversation): - self._process_object(conversation) - else: - conversation.add_meta(self.metadata_name, None) + self._transform_objects( + corpus.iter_conversations(), "Applying LLM prompt to conversations" + ) elif self.object_level == "speaker": - for speaker in corpus.iter_speakers(): - if self.selector(speaker): - self._process_object(speaker) - else: - speaker.add_meta(self.metadata_name, None) + self._transform_objects(corpus.iter_speakers(), "Applying LLM prompt to speakers") elif self.object_level == "corpus": if self.selector(corpus):