8 changes: 6 additions & 2 deletions data_juicer/ops/deduplicator/__init__.py
@@ -1,9 +1,11 @@
from .document_deduplicator import DocumentDeduplicator
from .document_minhash_deduplicator import DocumentMinhashDeduplicator
from .document_minhash_deduplicator import (DocumentMinhashDeduplicator,
                                            DocumentMinhashDeduplicatorWithUid)
from .document_simhash_deduplicator import DocumentSimhashDeduplicator
from .image_deduplicator import ImageDeduplicator
from .ray_basic_deduplicator import RayBasicDeduplicator
from .ray_bts_minhash_deduplicator import RayBTSMinhashDeduplicator
from .ray_bts_minhash_deduplicator import (RayBTSMinhashDeduplicator,
                                           RayBTSMinhashDeduplicatorWithUid)
from .ray_document_deduplicator import RayDocumentDeduplicator
from .ray_image_deduplicator import RayImageDeduplicator
from .ray_video_deduplicator import RayVideoDeduplicator
@@ -12,6 +14,7 @@
__all__ = [
    'DocumentDeduplicator',
    'DocumentMinhashDeduplicator',
    'DocumentMinhashDeduplicatorWithUid',
    'DocumentSimhashDeduplicator',
    'ImageDeduplicator',
    'RayBasicDeduplicator',
@@ -20,5 +23,6 @@
    'RayVideoDeduplicator',
    'RayImageDeduplicator',
    'RayBTSMinhashDeduplicator',
    'RayBTSMinhashDeduplicatorWithUid',
    'VideoDeduplicator',
]
90 changes: 90 additions & 0 deletions data_juicer/ops/deduplicator/document_minhash_deduplicator.py
@@ -335,3 +335,93 @@ def _filter_minhash_dup_helper(sample, index):
        logger.info(f'Keep {len(dataset)} samples after MinHash dedup.')

        return dataset, dup_pairs


@OPERATORS.register_module(f'{OP_NAME}_with_uid')
class DocumentMinhashDeduplicatorWithUid(DocumentMinhashDeduplicator):
    """
    A Deduplicator that performs document-level deduplication using MinHashLSH.

    Unlike `DocumentMinhashDeduplicator`, this class requires the dataset to include an additional column named
    '__dj__uid' of type int, with unique values for each sample. This column is essential for supporting
    incremental deduplication scenarios.

    For example, consider a scenario where you have an already deduplicated dataset A and a new dataset B that
    you wish to add. If you want to perform joint deduplication on both A and B while prioritizing the retention of
    data from A, you can ensure that all '__dj__uid' values in B are greater than those in A. Then, by applying
    this deduplicator to the combined dataset, duplicates will be resolved in favor of the entries from A.
    """

    def process(self, dataset, show_num=0):
        """
        For doc-level, dataset --> dataset.

        :param dataset: input dataset
        :param show_num: number of traced samples used when tracer is
            open.
        :return: deduplicated dataset and the sampled duplicate pairs.
        """
        # no need to deduplicate because there are too few samples
        if len(dataset) <= 1:
            return dataset, {}

        minhashes = dataset[HashKeys.minhash]
        # remove the bytes minhash column, otherwise an unexpected error would
        # occur when exporting the processed dataset
        dataset = dataset.remove_columns([HashKeys.minhash])
        uids = dataset[HashKeys.uid]
        uid2idx = {uid: idx for idx, uid in enumerate(uids)}

        # make clusters -- construct the minhash lookup tables of seg to ids
        logger.info(f'Start clustering for {len(dataset)} samples...')
        batch_size = 10000
        for i in tqdm(range(0, len(minhashes), batch_size),
                      dynamic_ncols=True,
                      desc='Iterating MinHashes of samples...'):
            batch = minhashes[i:i + batch_size]
            batch_uid = uids[i:i + batch_size]
            for uid, hs in zip(batch_uid, batch):
                for h, hashtable in zip(hs, self.hash_tables):
                    hashtable[h].add(uid)

        # use a UnionFind set to union samples within the same clusters
        union_find = UnionFind()
        for table in tqdm(self.hash_tables,
                          dynamic_ncols=True,
                          desc='Clustering'):
            for cluster in table.values():
                if len(cluster) <= 1:
                    continue
                uid = min(cluster)
                for x in cluster:
                    union_find.union(x, uid)
        logger.info(f'There are {len(set(union_find.parent.values()))} '
                    f'clusters that include multiple near-duplicate samples.')

        # record the duplicate sample pairs
        dup_pairs = {}
        if show_num > 0:
            for i in range(len(dataset)):
                uid = uids[i]
                cluster_uid = union_find.find(uid)
                cluster_idx = uid2idx[cluster_uid]
                if cluster_uid != uid and cluster_idx not in dup_pairs:
                    dup_pairs[cluster_idx] = [
                        dataset[cluster_idx],
                        dataset[i],
                    ]
                if len(dup_pairs) >= show_num:
                    break

        # filtering -- only keep those samples whose parent uid is itself,
        # including:
        #   1. samples that form a cluster by themselves
        #   2. the sample with the smallest uid in a cluster that includes
        #      multiple samples
        def _filter_minhash_dup_helper(sample):
            uid = sample[HashKeys.uid]
            return union_find.find(uid) == uid

        dataset = dataset.filter(_filter_minhash_dup_helper)
        logger.info(f'Keep {len(dataset)} samples after MinHash dedup.')

        return dataset, dup_pairs
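
To make the intended usage of the new OP concrete, the following sketch (not part of this diff) walks through the incremental-deduplication workflow described in the class docstring. It assumes that HashKeys.uid resolves to the '__dj__uid' column and that the usual compute_hash-then-process driver (what the test helper _run_minhash_dedup is expected to wrap) applies; the exact entry points may differ.

# Minimal sketch: joint dedup of an existing dataset A and new data B,
# keeping A's copy of any duplicate. Assumes HashKeys.uid == '__dj__uid'.
from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.deduplicator import DocumentMinhashDeduplicatorWithUid
from data_juicer.utils.constant import HashKeys

samples_a = [{'text': 'Today is Sunday and it is a happy day!'}]   # already deduplicated
samples_b = [{'text': 'Today is Sunday and it is a happy day!'},   # new data to merge in
             {'text': 'A completely different document.'}]

# give B strictly larger uids than A, so clusters resolve in favor of A
for uid, sample in enumerate(samples_a + samples_b):
    sample[HashKeys.uid] = uid

dataset = Dataset.from_list(samples_a + samples_b)
op = DocumentMinhashDeduplicatorWithUid(ignore_pattern=r'\p{P}')
dataset = dataset.map(op.compute_hash)    # adds the MinHash column
dataset, dup_pairs = op.process(dataset)  # keeps A's copy of each duplicate
print(dataset['text'])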
51 changes: 51 additions & 0 deletions data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py
@@ -568,3 +568,54 @@ def minhash_with_uid(table: pa.Table) -> pa.Table:
            zero_copy_batch=True,
        )
        return result


@OPERATORS.register_module(f'{OP_NAME}_with_uid')
class RayBTSMinhashDeduplicatorWithUid(RayBTSMinhashDeduplicator):
    """
    A MinHashLSH deduplicator based on Ray.

    Unlike `RayBTSMinhashDeduplicator`, this class requires the input dataset to contain an additional column
    named '__dj__uid' of type int, where each value is unique across samples. This column serves two main purposes:

    1. **Reduced I/O overhead**: unlike `RayBTSMinhashDeduplicator`, this class does not persist intermediate
       results, thereby reducing disk reads and writes.

    2. **Support for incremental deduplication**: duplicates are resolved in favor of existing data. For example,
       given an already deduplicated dataset A and a new dataset B to be added, ensure that all '__dj__uid' values
       in B are greater than those in A; applying this deduplicator to the combined dataset then keeps the entries
       from A wherever duplicates occur.
    """

    def run(self, dataset, **kwargs):
        # Ignore additional parameters like exporter, tracer, etc.
        start_time = time.time()

        def minhash_with_uid(table: pa.Table) -> pa.Table:
            uid_list = table[HashKeys.uid].to_pylist()
            self.calc_minhash(table[self.text_key], uid_list)
            return table

        dataset.map_batches(
            minhash_with_uid,
            batch_format='pyarrow',
            zero_copy_batch=True,
        ).materialize()
        end_time = time.time()
        logger.info(f'MinHash time = {end_time - start_time}')

        start_time = time.time()
        self.merge()
        end_time = time.time()
        logger.info(f'merge time = {end_time - start_time}')
        result = dataset.map_batches(
            self.filter_with_union_find,
            batch_format='pyarrow',
            zero_copy_batch=True,
        )
        return result
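
A similar sketch (also not part of this diff) for the Ray variant. It assumes a plain ray.data.Dataset is acceptable as input to run(), since the method shown above only relies on map_batches, and that HashKeys.uid resolves to '__dj__uid'; the constructor arguments follow the existing tests.

# Minimal sketch: Ray-based dedup without persisting intermediate results.
# Assumes a plain ray.data.Dataset works as run()'s input.
import ray

from data_juicer.ops.deduplicator import RayBTSMinhashDeduplicatorWithUid
from data_juicer.utils.constant import HashKeys

samples = [
    {'text': 'Today is Sunday and it is a happy day!'},
    {'text': 'Today is Sunday and it is a happy day!'},
    {'text': 'A completely different document.'},
]
# unique, monotonically increasing uids; append new data after existing data
for uid, sample in enumerate(samples):
    sample[HashKeys.uid] = uid

dataset = ray.data.from_items(samples)
op = RayBTSMinhashDeduplicatorWithUid(ignore_pattern=r'\p{P}')
deduped = op.run(dataset)   # no work_dir or intermediate files required
print(deduped.count())      # duplicates collapse toward the smallest uid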
20 changes: 18 additions & 2 deletions tests/ops/deduplicator/test_document_minhash_deduplicator.py
@@ -2,8 +2,11 @@

from data_juicer.core.data import NestedDataset as Dataset

from data_juicer.ops.deduplicator.document_minhash_deduplicator import \
    DocumentMinhashDeduplicator
from data_juicer.ops.deduplicator.document_minhash_deduplicator import (
    DocumentMinhashDeduplicator,
    DocumentMinhashDeduplicatorWithUid,
)
from data_juicer.utils.constant import HashKeys
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase


@@ -822,6 +825,12 @@ def test_english_deduplication(self):
        op = DocumentMinhashDeduplicator(ignore_pattern=r'\p{P}')
        self._run_minhash_dedup(dataset, tgt_list, op)

        for i, ds in enumerate(ds_list):
            ds[HashKeys.uid] = i
        dataset = Dataset.from_list(ds_list)
        op = DocumentMinhashDeduplicatorWithUid(ignore_pattern=r'\p{P}')
        self._run_minhash_dedup(dataset, tgt_list, op)

    def test_chinese_deduplication(self):
        ds_list = [
            {
@@ -958,6 +967,13 @@ def test_chinese_deduplication(self):
                                         ignore_pattern=r'\p{P}')
        self._run_minhash_dedup(dataset, tgt_list, op)

        for i, ds in enumerate(ds_list):
            ds[HashKeys.uid] = i
        dataset = Dataset.from_list(ds_list)
        op = DocumentMinhashDeduplicatorWithUid(tokenization='character',
                                                ignore_pattern=r'\p{P}')
        self._run_minhash_dedup(dataset, tgt_list, op)


if __name__ == '__main__':
    unittest.main()
26 changes: 22 additions & 4 deletions tests/ops/deduplicator/test_ray_bts_minhash_deduplicator.py
@@ -1,9 +1,14 @@
import unittest
import os
import shutil

from data_juicer.core.data import NestedDataset as Dataset

from data_juicer.ops.deduplicator.ray_bts_minhash_deduplicator import \
    RayBTSMinhashDeduplicator
from data_juicer.ops.deduplicator.ray_bts_minhash_deduplicator import (
    RayBTSMinhashDeduplicator,
    RayBTSMinhashDeduplicatorWithUid,
)
from data_juicer.utils.constant import HashKeys
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, TEST_TAG


@@ -817,11 +822,17 @@ def test_english_deduplication(self):
            },
        ]
        dataset = self.generate_dataset(ds_list)
        import os
        cur_dir = os.path.dirname(os.path.abspath(__file__))
        work_dir = os.path.join(cur_dir, 'english_dedup')
        op = RayBTSMinhashDeduplicator(ignore_pattern=r'\p{P}', work_dir=work_dir)
        self._run_minhash_dedup(dataset, tgt_list, op)
        shutil.rmtree(work_dir)

        for i, ds in enumerate(ds_list):
            ds[HashKeys.uid] = i
        dataset = self.generate_dataset(ds_list)
        op = RayBTSMinhashDeduplicatorWithUid(ignore_pattern=r'\p{P}')
        self._run_minhash_dedup(dataset, tgt_list, op)

    @TEST_TAG("ray")
    def test_chinese_deduplication(self):
@@ -956,13 +967,20 @@ def test_chinese_deduplication(self):
            },
        ]
        dataset = self.generate_dataset(ds_list)
        import os
        cur_dir = os.path.dirname(os.path.abspath(__file__))
        work_dir = os.path.join(cur_dir, 'chinese_dedup')
        op = RayBTSMinhashDeduplicator(tokenization='character',
                                       ignore_pattern=r'\p{P}',
                                       work_dir=work_dir)
        self._run_minhash_dedup(dataset, tgt_list, op)
        shutil.rmtree(work_dir)

        for i, ds in enumerate(ds_list):
            ds[HashKeys.uid] = i
        dataset = self.generate_dataset(ds_list)
        op = RayBTSMinhashDeduplicatorWithUid(tokenization='character',
                                              ignore_pattern=r'\p{P}')
        self._run_minhash_dedup(dataset, tgt_list, op)


if __name__ == '__main__':