|
7 | 7 | from agent_memory_server.extraction import (
|
8 | 8 | extract_entities,
|
9 | 9 | extract_topics_bertopic,
|
| 10 | + extract_topics_llm, |
10 | 11 | handle_extraction,
|
11 | 12 | )
|
12 | 13 |
|
@@ -135,3 +136,106 @@ async def test_handle_extraction_disabled_features(
|
135 | 136 | # Restore settings
|
136 | 137 | settings.enable_topic_extraction = original_topic_setting
|
137 | 138 | settings.enable_ner = original_ner_setting
|
| 139 | + |
| 140 | + |
@pytest.mark.requires_api_keys
class TestTopicExtractionIntegration:
    """End-to-end checks for both topic-extraction backends.

    Each test temporarily overrides ``settings.topic_model_source`` to force
    a specific backend and restores the original value in a ``finally`` block
    so other tests are unaffected regardless of pass/fail/skip.
    """

    @pytest.mark.asyncio
    async def test_bertopic_integration(self):
        """Integration test for BERTopic topic extraction (skipped if not available)"""
        # Save and set topic_model_source; restored in the finally clause.
        original_source = settings.topic_model_source
        settings.topic_model_source = "BERTopic"
        sample_text = (
            "OpenAI and Google are leading companies in artificial intelligence."
        )
        try:
            try:
                # BERTopic (or its model download) may be unavailable in this
                # environment; treat any failure to run it as "environment not
                # ready" and skip rather than fail.
                topics = extract_topics_bertopic(sample_text)
            except Exception as e:
                pytest.skip(f"BERTopic integration test skipped: {e}")
            assert isinstance(topics, list)
            # Keywords the pretrained topic model is expected to surface for
            # AI-related text; any single match is sufficient.
            expected_keywords = {
                "generative",
                "transformer",
                "neural",
                "learning",
                "trained",
                "multimodal",
                "generates",
                "models",
                "encoding",
                "text",
            }
            assert any(t.lower() in expected_keywords for t in topics)
        finally:
            settings.topic_model_source = original_source

    @pytest.mark.asyncio
    async def test_llm_integration(self):
        """Integration test for LLM-based topic extraction (skipped if no API key)"""
        # Save and set topic_model_source; restored in the finally clause.
        original_source = settings.topic_model_source
        settings.topic_model_source = "LLM"
        sample_text = (
            "OpenAI and Google are leading companies in artificial intelligence."
        )
        try:
            # Without an LLM credential the call cannot succeed; skip early.
            if not (settings.openai_api_key or settings.anthropic_api_key):
                pytest.skip("No LLM API key available for integration test.")
            topics = await extract_topics_llm(sample_text)
            assert isinstance(topics, list)
            assert any(
                t.lower() in ["technology", "business", "artificial intelligence"]
                for t in topics
            )
        finally:
            settings.topic_model_source = original_source
| 199 | + |
| 200 | + |
class TestHandleExtractionPathSelection:
    """Checks that ``handle_extraction`` routes to the backend selected by
    ``settings.topic_model_source``."""

    @pytest.mark.asyncio
    @patch("agent_memory_server.extraction.extract_topics_bertopic")
    @patch("agent_memory_server.extraction.extract_topics_llm")
    async def test_handle_extraction_path_selection(
        self, mock_extract_topics_llm, mock_extract_topics_bertopic
    ):
        """Test that handle_extraction uses the correct extraction path based on settings.topic_model_source"""
        sample_text = (
            "OpenAI and Google are leading companies in artificial intelligence."
        )
        # Snapshot every setting we mutate so it can be restored afterwards.
        saved = (
            settings.topic_model_source,
            settings.enable_topic_extraction,
            settings.enable_ner,
        )
        try:
            # Topic extraction on, NER off, so only the topic path is exercised.
            settings.enable_topic_extraction = True
            settings.enable_ner = False

            # --- BERTopic path ---------------------------------------------
            settings.topic_model_source = "BERTopic"
            mock_extract_topics_bertopic.return_value = ["technology"]
            mock_extract_topics_llm.return_value = ["should not be called"]
            topics, _ = await handle_extraction(sample_text)
            mock_extract_topics_bertopic.assert_called_once()
            mock_extract_topics_llm.assert_not_called()
            assert topics == ["technology"]
            mock_extract_topics_bertopic.reset_mock()

            # --- LLM path --------------------------------------------------
            settings.topic_model_source = "LLM"
            mock_extract_topics_llm.return_value = ["ai"]
            topics, _ = await handle_extraction(sample_text)
            mock_extract_topics_llm.assert_called_once()
            mock_extract_topics_bertopic.assert_not_called()
            assert topics == ["ai"]
        finally:
            (
                settings.topic_model_source,
                settings.enable_topic_extraction,
                settings.enable_ner,
            ) = saved
0 commit comments