
Commit f4c295a

Add integration tests for topic extraction
1 parent 403a5aa commit f4c295a

File tree

1 file changed (+104 −0 lines)


tests/test_extraction.py

Lines changed: 104 additions & 0 deletions
@@ -7,6 +7,7 @@
 from agent_memory_server.extraction import (
     extract_entities,
     extract_topics_bertopic,
+    extract_topics_llm,
     handle_extraction,
 )
 

@@ -135,3 +136,106 @@ async def test_handle_extraction_disabled_features(
     # Restore settings
     settings.enable_topic_extraction = original_topic_setting
     settings.enable_ner = original_ner_setting
+
+
+@pytest.mark.requires_api_keys
+class TestTopicExtractionIntegration:
+    @pytest.mark.asyncio
+    async def test_bertopic_integration(self):
+        """Integration test for BERTopic topic extraction (skipped if not available)"""
+
+        # Save and set topic_model_source
+        original_source = settings.topic_model_source
+        settings.topic_model_source = "BERTopic"
+        sample_text = (
+            "OpenAI and Google are leading companies in artificial intelligence."
+        )
+        try:
+            try:
+                # Try to import BERTopic and check model loading
+                topics = extract_topics_bertopic(sample_text)
+                # print(f"[DEBUG] BERTopic returned topics: {topics}")
+            except Exception as e:
+                pytest.skip(f"BERTopic integration test skipped: {e}")
+            assert isinstance(topics, list)
+            expected_keywords = {
+                "generative",
+                "transformer",
+                "neural",
+                "learning",
+                "trained",
+                "multimodal",
+                "generates",
+                "models",
+                "encoding",
+                "text",
+            }
+            assert any(t.lower() in expected_keywords for t in topics)
+        finally:
+            settings.topic_model_source = original_source
+
+    @pytest.mark.asyncio
+    async def test_llm_integration(self):
+        """Integration test for LLM-based topic extraction (skipped if no API key)"""
+
+        # Save and set topic_model_source
+        original_source = settings.topic_model_source
+        settings.topic_model_source = "LLM"
+        sample_text = (
+            "OpenAI and Google are leading companies in artificial intelligence."
+        )
+        try:
+            # Check for API key
+            if not (settings.openai_api_key or settings.anthropic_api_key):
+                pytest.skip("No LLM API key available for integration test.")
+            topics = await extract_topics_llm(sample_text)
+            assert isinstance(topics, list)
+            assert any(
+                t.lower() in ["technology", "business", "artificial intelligence"]
+                for t in topics
+            )
+        finally:
+            settings.topic_model_source = original_source
+
+
+class TestHandleExtractionPathSelection:
+    @pytest.mark.asyncio
+    @patch("agent_memory_server.extraction.extract_topics_bertopic")
+    @patch("agent_memory_server.extraction.extract_topics_llm")
+    async def test_handle_extraction_path_selection(
+        self, mock_extract_topics_llm, mock_extract_topics_bertopic
+    ):
+        """Test that handle_extraction uses the correct extraction path based on settings.topic_model_source"""
+
+        sample_text = (
+            "OpenAI and Google are leading companies in artificial intelligence."
+        )
+        original_source = settings.topic_model_source
+        original_enable_topic_extraction = settings.enable_topic_extraction
+        original_enable_ner = settings.enable_ner
+        try:
+            # Enable topic extraction and disable NER for clarity
+            settings.enable_topic_extraction = True
+            settings.enable_ner = False
+
+            # Test BERTopic path
+            settings.topic_model_source = "BERTopic"
+            mock_extract_topics_bertopic.return_value = ["technology"]
+            mock_extract_topics_llm.return_value = ["should not be called"]
+            topics, _ = await handle_extraction(sample_text)
+            mock_extract_topics_bertopic.assert_called_once()
+            mock_extract_topics_llm.assert_not_called()
+            assert topics == ["technology"]
+            mock_extract_topics_bertopic.reset_mock()
+
+            # Test LLM path
+            settings.topic_model_source = "LLM"
+            mock_extract_topics_llm.return_value = ["ai"]
+            topics, _ = await handle_extraction(sample_text)
+            mock_extract_topics_llm.assert_called_once()
+            mock_extract_topics_bertopic.assert_not_called()
+            assert topics == ["ai"]
+        finally:
+            settings.topic_model_source = original_source
+            settings.enable_topic_extraction = original_enable_topic_extraction
+            settings.enable_ner = original_enable_ner
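For readers skimming the diff, the path-selection test above hinges on handle_extraction dispatching on settings.topic_model_source. The sketch below is an illustrative stand-in for that dispatch, not the actual agent_memory_server.extraction implementation: the Settings dataclass, the stub extractor bodies, and their return values are hypothetical, kept only to make the routing shape concrete and runnable.

# Illustrative sketch only -- not the real agent_memory_server.extraction module.
# It mirrors the routing behaviour the path-selection test asserts:
# settings.topic_model_source picks the BERTopic or the LLM extractor.
import asyncio
from dataclasses import dataclass


@dataclass
class Settings:
    # Hypothetical stand-in for the project's settings object
    enable_topic_extraction: bool = True
    enable_ner: bool = False
    topic_model_source: str = "BERTopic"


settings = Settings()


def extract_topics_bertopic(text: str) -> list[str]:
    # Stub: the real function runs a BERTopic model over the text
    return ["technology"]


async def extract_topics_llm(text: str) -> list[str]:
    # Stub: the real function prompts an LLM (OpenAI/Anthropic) for topics
    return ["ai"]


async def handle_extraction(text: str) -> tuple[list[str], list[str]]:
    topics: list[str] = []
    if settings.enable_topic_extraction:
        if settings.topic_model_source == "BERTopic":
            topics = extract_topics_bertopic(text)
        else:
            topics = await extract_topics_llm(text)
    entities: list[str] = []  # NER path omitted in this sketch
    return topics, entities


if __name__ == "__main__":
    sample = "OpenAI and Google are leading companies in artificial intelligence."
    print(asyncio.run(handle_extraction(sample)))

Patching extract_topics_bertopic and extract_topics_llm, as TestHandleExtractionPathSelection does, exercises exactly these two branches without loading a BERTopic model or calling an LLM provider.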
