diff --git a/src/crud/document.py b/src/crud/document.py index 11eb121d4..ed3813409 100644 --- a/src/crud/document.py +++ b/src/crud/document.py @@ -428,6 +428,19 @@ async def query_documents( return docs +def _normalize_content(content: str) -> str: + """Normalize document content for exact-match deduplication. + + Content is compared after trimming surrounding whitespace and lowercasing + + The SQL filter in ``create_documents`` must stay in sync with this: + ``lower(regexp_replace(content, '^\\s+|\\s+$', '', 'g'))``. Postgres' + ``trim()`` only strips spaces, so a regex is used to match Python's + ``str.strip()`` across all whitespace. + """ + return content.strip().lower() + + async def create_documents( db: AsyncSession, documents: list[schemas.DocumentCreate], @@ -440,12 +453,17 @@ async def create_documents( """ Create multiple documents with optional duplicate detection. + The ``deduplicate`` flag additionally enables semantic (cosine-similarity) + dedup via ``is_rejected_duplicate`` for documents that survive the exact + deduplication check. + Args: db: Database session documents: List of document creation schemas workspace_name: Name of the workspace observer: Name of the observing peer observed: Name of the observed peer + deduplicate: Enable semantic duplicate detection Returns: List of DocumentCreate schemas that were actually inserted (excludes @@ -456,8 +474,76 @@ async def create_documents( # Store (document_model, embedding) pairs - IDs aren't available until after commit docs_with_embeddings: list[tuple[models.Document, list[float]]] = [] + # exact-content dedup (independent of `deduplicate`): pre-fetch + # existing live documents whose normalized content matches anything in this + # batch, scoped to (workspace, observer, observed). The SQL normalization must + # mirror _normalize_content. + batch_normalized: set[str] = {_normalize_content(d.content) for d in documents} + existing_by_normalized: dict[str, models.Document] = {} + if batch_normalized: + # The `normalized_content_sql.in_(...)` filter below narrows to the + # (workspace, observer, observed) partition via the single-column indexes, + # then evaluates lower(regexp_replace(...)) per row. + # TODO: add a partial expression index matching + # this filter exactly + # CREATE INDEX ix_documents_normalized_content + # ON documents ( + # workspace_name, + # observer, + # observed, + # (lower(regexp_replace(content, '^\s+|\s+$', '', 'g'))) + # ) + # WHERE deleted_at IS NULL; + normalized_content_sql = func.lower( + func.regexp_replace(models.Document.content, r"^\s+|\s+$", "", "g") + ) + existing_result = await db.execute( + select(models.Document).where( + models.Document.workspace_name == workspace_name, + models.Document.observer == observer, + models.Document.observed == observed, + models.Document.deleted_at.is_(None), + normalized_content_sql.in_(batch_normalized), + ) + ) + for existing_doc in existing_result.scalars(): + # If multiple historical rows share normalized content, reinforcing + # one is sufficient; keep the first. + existing_by_normalized.setdefault( + _normalize_content(existing_doc.content), existing_doc + ) + + # Tracks normalized content already accepted from this batch so exact + # duplicates within a single inference call collapse to one document. + seen_in_batch: set[str] = set() + for doc in documents: try: + normalized_content = _normalize_content(doc.content) + + # Exact-match dedup, always on: + # 1) collapse exact duplicates within this batch (drop silently). + if normalized_content in seen_in_batch: + continue + seen_in_batch.add(normalized_content) + + # 2) drop exact duplicates of an existing live document, recording + # the re-derivation as reinforcement on the existing row. + existing_match = existing_by_normalized.get(normalized_content) + if existing_match is not None: + # Reinforce the existing row. greatest(...) keeps the bump atomic + # server-side (concurrent workers can't lose an increment) while + # still honoring an incoming doc that already carries accumulated + # reinforcement (times_derived > 1, e.g. a future re-ingestion or + # collection-merge path). Mirrors the superior-replacement branch + # in is_rejected_duplicate. + existing_match.times_derived = func.greatest( + models.Document.times_derived + 1, + doc.times_derived, + ) + await db.flush() + continue + # for each document, if deduplicate is True, perform a process # that checks against existing documents and either rejects this document # as a duplicate OR deletes an existing document that is a duplicate. @@ -1038,10 +1124,14 @@ async def is_rejected_duplicate( return False # Don't reject the new document # Existing document has more information, reject the new one but record the - # reinforcement: a semantic duplicate was derived again. Assign a SQL - # expression so the increment is atomic server-side -- concurrent workers - # reinforcing the same document must not lose updates. - existing_doc.times_derived = models.Document.times_derived + 1 + # reinforcement: a semantic duplicate was derived again. greatest(...) keeps + # the increment atomic server-side -- concurrent workers reinforcing the same + # document must not lose updates -- while still honoring an incoming doc that + # already carries accumulated reinforcement (times_derived > 1). + existing_doc.times_derived = func.greatest( + models.Document.times_derived + 1, + doc.times_derived, + ) await db.flush() logger.debug( "[DUPLICATE DETECTION] Rejecting new in favor of existing. new=%r, existing=%r.", diff --git a/tests/crud/test_document.py b/tests/crud/test_document.py index ccde3dac0..1258b5229 100644 --- a/tests/crud/test_document.py +++ b/tests/crud/test_document.py @@ -465,6 +465,338 @@ async def test_duplicate_replacement_carries_count_forward( # Original is soft-deleted; replacement isn't inserted until create_documents runs. assert len(live) == 0 + @pytest.mark.asyncio + async def test_exact_dedup_within_batch_drops_repeat( + self, + db_session: AsyncSession, + sample_data: tuple[models.Workspace, models.Peer], + ): + """Exact (case/whitespace-insensitive) duplicates within a single batch + collapse to one document, even with semantic dedup disabled.""" + test_workspace, test_peer = sample_data + test_peer2, test_session, _ = await self._setup_test_data( + db_session, test_workspace, test_peer + ) + + # Three "exact" matches that differ only by case/surrounding whitespace. + doc_schemas = [ + schemas.DocumentCreate( + content="User likes coffee", + embedding=[0.1] * 1536, + session_name=test_session.name, + metadata=schemas.DocumentMetadata( + message_ids=[1], + message_created_at="2026-01-01T00:00:00Z", + ), + ), + schemas.DocumentCreate( + content="user likes coffee", + embedding=[0.2] * 1536, + session_name=test_session.name, + metadata=schemas.DocumentMetadata( + message_ids=[2], + message_created_at="2026-01-01T00:01:00Z", + ), + ), + schemas.DocumentCreate( + content=" User likes coffee\n", + embedding=[0.3] * 1536, + session_name=test_session.name, + metadata=schemas.DocumentMetadata( + message_ids=[3], + message_created_at="2026-01-01T00:02:00Z", + ), + ), + ] + + accepted = await crud.create_documents( + db_session, + documents=doc_schemas, + workspace_name=test_workspace.name, + observer=test_peer.name, + observed=test_peer2.name, + deduplicate=False, + ) + + assert len(accepted) == 1 + live = ( + ( + await db_session.execute( + select(models.Document).where( + models.Document.workspace_name == test_workspace.name, + models.Document.observer == test_peer.name, + models.Document.observed == test_peer2.name, + models.Document.deleted_at.is_(None), + ) + ) + ) + .scalars() + .all() + ) + assert len(live) == 1 + # Within-batch repeats are dropped silently, no reinforcement. + assert live[0].times_derived == 1 + + @pytest.mark.asyncio + async def test_exact_dedup_against_existing_reinforces( + self, + db_session: AsyncSession, + sample_data: tuple[models.Workspace, models.Peer], + ): + """An exact match of an existing live document is rejected and reinforces + the existing row, even with semantic dedup disabled.""" + test_workspace, test_peer = sample_data + test_peer2, test_session, _ = await self._setup_test_data( + db_session, test_workspace, test_peer + ) + + await crud.create_documents( + db_session, + [ + schemas.DocumentCreate( + content="User likes coffee", + embedding=[0.1] * 1536, + session_name=test_session.name, + times_derived=1, + metadata=schemas.DocumentMetadata( + message_ids=[1], + message_created_at="2026-01-01T00:00:00Z", + ), + ) + ], + workspace_name=test_workspace.name, + observer=test_peer.name, + observed=test_peer2.name, + deduplicate=False, + ) + + # Case/whitespace variant of the existing content -> exact match. + accepted = await crud.create_documents( + db_session, + [ + schemas.DocumentCreate( + content="user likes coffee ", + embedding=[0.9] * 1536, + session_name=test_session.name, + times_derived=1, + metadata=schemas.DocumentMetadata( + message_ids=[2], + message_created_at="2026-01-02T00:00:00Z", + ), + ) + ], + workspace_name=test_workspace.name, + observer=test_peer.name, + observed=test_peer2.name, + deduplicate=False, + ) + + assert len(accepted) == 0 + surviving = ( + ( + await db_session.execute( + select(models.Document).where( + models.Document.workspace_name == test_workspace.name, + models.Document.observer == test_peer.name, + models.Document.observed == test_peer2.name, + models.Document.deleted_at.is_(None), + ) + ) + ) + .scalars() + .all() + ) + assert len(surviving) == 1 + assert surviving[0].content == "User likes coffee" + assert surviving[0].times_derived == 2 + + @pytest.mark.asyncio + async def test_exact_dedup_honors_incoming_times_derived( + self, + db_session: AsyncSession, + sample_data: tuple[models.Workspace, models.Peer], + ): + """Reinforcement folds in an incoming doc that already carries + accumulated reinforcement: the existing row becomes + ``greatest(existing + 1, incoming)``.""" + test_workspace, test_peer = sample_data + test_peer2, test_session, _ = await self._setup_test_data( + db_session, test_workspace, test_peer + ) + + async def _live() -> list[models.Document]: + return list( + ( + await db_session.execute( + select(models.Document).where( + models.Document.workspace_name == test_workspace.name, + models.Document.observer == test_peer.name, + models.Document.observed == test_peer2.name, + models.Document.deleted_at.is_(None), + ) + ) + ) + .scalars() + .all() + ) + + # Existing row already reinforced twice. + await crud.create_documents( + db_session, + [ + schemas.DocumentCreate( + content="User likes coffee", + embedding=[0.1] * 1536, + session_name=test_session.name, + times_derived=2, + metadata=schemas.DocumentMetadata( + message_ids=[1], + message_created_at="2026-01-01T00:00:00Z", + ), + ) + ], + workspace_name=test_workspace.name, + observer=test_peer.name, + observed=test_peer2.name, + deduplicate=False, + ) + + # Incoming exact match claims more accumulated reinforcement (5) than + # existing + 1 (3) -> incoming wins. + accepted = await crud.create_documents( + db_session, + [ + schemas.DocumentCreate( + content="user likes coffee ", + embedding=[0.9] * 1536, + session_name=test_session.name, + times_derived=5, + metadata=schemas.DocumentMetadata( + message_ids=[2], + message_created_at="2026-01-02T00:00:00Z", + ), + ) + ], + workspace_name=test_workspace.name, + observer=test_peer.name, + observed=test_peer2.name, + deduplicate=False, + ) + assert len(accepted) == 0 + live = await _live() + assert len(live) == 1 + assert live[0].times_derived == 5 + + # A normal re-derivation (times_derived defaults to 1) now bumps by one: + # greatest(existing + 1, 1) -> existing + 1. + accepted = await crud.create_documents( + db_session, + [ + schemas.DocumentCreate( + content="USER LIKES COFFEE", + embedding=[0.4] * 1536, + session_name=test_session.name, + metadata=schemas.DocumentMetadata( + message_ids=[3], + message_created_at="2026-01-03T00:00:00Z", + ), + ) + ], + workspace_name=test_workspace.name, + observer=test_peer.name, + observed=test_peer2.name, + deduplicate=False, + ) + assert len(accepted) == 0 + live = await _live() + assert len(live) == 1 + assert live[0].times_derived == 6 + + @pytest.mark.asyncio + async def test_exact_dedup_flushes_before_semantic_replacement( + self, + db_session: AsyncSession, + sample_data: tuple[models.Workspace, models.Peer], + ): + """An exact-match reinforcement in a batch must be visible to a later + semantic replacement of the same existing row when autoflush is off.""" + test_workspace, test_peer = sample_data + test_peer2, test_session, _ = await self._setup_test_data( + db_session, test_workspace, test_peer + ) + + await crud.create_documents( + db_session, + [ + schemas.DocumentCreate( + content="User likes coffee", + embedding=[0.5] * 1536, + session_name=test_session.name, + times_derived=1, + metadata=schemas.DocumentMetadata( + message_ids=[1], + message_created_at="2026-01-01T00:00:00Z", + ), + ) + ], + workspace_name=test_workspace.name, + observer=test_peer.name, + observed=test_peer2.name, + deduplicate=False, + ) + + db_session.autoflush = False + accepted = await crud.create_documents( + db_session, + [ + schemas.DocumentCreate( + content=" user likes coffee ", + embedding=[0.5] * 1536, + session_name=test_session.name, + times_derived=1, + metadata=schemas.DocumentMetadata( + message_ids=[2], + message_created_at="2026-01-02T00:00:00Z", + ), + ), + schemas.DocumentCreate( + content="User likes coffee and tea", + embedding=[0.5] * 1536, + session_name=test_session.name, + times_derived=1, + metadata=schemas.DocumentMetadata( + message_ids=[3], + message_created_at="2026-01-03T00:00:00Z", + ), + ), + ], + workspace_name=test_workspace.name, + observer=test_peer.name, + observed=test_peer2.name, + deduplicate=True, + ) + + assert len(accepted) == 1 + assert accepted[0].content == "User likes coffee and tea" + + surviving = ( + ( + await db_session.execute( + select(models.Document).where( + models.Document.workspace_name == test_workspace.name, + models.Document.observer == test_peer.name, + models.Document.observed == test_peer2.name, + models.Document.deleted_at.is_(None), + ) + ) + ) + .scalars() + .all() + ) + assert len(surviving) == 1 + assert surviving[0].content == "User likes coffee and tea" + assert surviving[0].times_derived == 3 + @pytest.mark.asyncio async def test_delete_document_success( self,