5 changes: 5 additions & 0 deletions .gitignore
@@ -31,3 +31,8 @@ tests/ops/data/*dup*
tests/tools/tmp_*/
tests/ops/deduplicator/chinese_dedup/
tests/ops/deduplicator/english_dedup/
demos/process_video_on_ray/data/videos/
log.txt
mps_test_results/
spatial-sharing-sys/
test.py
109 changes: 109 additions & 0 deletions data_juicer/core/RayOperatorWrapper.py
@@ -0,0 +1,109 @@
import ray


@ray.remote(num_gpus=0.0)
class Actor:
    def __init__(self, op, rank=None):
        self.op = op
        self._model_loaded = False  # flag marking whether the model has been loaded
self.rank = rank
self.model = None
self.processor = None

    def load_model(self):
        if self.op.use_cuda() and not self._model_loaded:
            self.model, self.processor = self.op.load_model(rank=self.rank)
            self._model_loaded = True

def mapper_cuda(self, data):
if not self._model_loaded:
self.load_model() # ensure model is loaded before processing
# process data
data = self.op.process_single_actor(data, self.model, self.processor)
return data

def mapper_cuda_batched(self, data):
if not self._model_loaded:
self.load_model() # ensure model is loaded before processing
# process data
data = self.op.process_batched_actor(data, self.model, self.processor)
return data

def mapper_cpu(self, data):
# process data
processed_data = self.op.process_single(data)
return processed_data

def filter_cuda_single(self, data):
if not self._model_loaded:
self.load_model()
# Call the Filter operator function
data = self.op.compute_stats_single_actor(data, self.model, self.processor)
keep = self.op.process_single(data)

if keep:
return data
else:
return None

def filter_cuda_batched(self, data):
if not self._model_loaded:
self.load_model()
data = self.op.compute_stats_batched(data, self.model, self.processor)
# transform the map object to a list
keep_mask = list(self.op.process_batched(data))

if not any(keep_mask):
return None

# filter data based on the keep_mask
if isinstance(data, dict):
filtered_data = {
key: [value for value, keep in zip(values, keep_mask) if keep] for key, values in data.items()
}
elif isinstance(data, list):
filtered_data = [item for item, keep in zip(data, keep_mask) if keep]
else:
raise ValueError("Unsupported data type for batch filtering")

return filtered_data

    def filter_cpu_single(self, data):
        # unwrap single-element lists so the op sees a flat sample
        if "text" in data and isinstance(data["text"], list) and len(data["text"]) == 1:
            data["text"] = data["text"][0]
        if "__dj__stats__" in data and isinstance(data["__dj__stats__"], list) and len(data["__dj__stats__"]) == 1:
            data["__dj__stats__"] = data["__dj__stats__"][0]
data = self.op.compute_stats_single(data)
keep = self.op.process_single(data)
if keep:
return data
else:
return None

def filter_cpu_batched(self, data):
data = self.op.compute_stats_batched(data)

keep_mask = list(self.op.process_batched(data))

if not any(keep_mask):
return None

# filter data based on the keep_mask
if isinstance(data, dict):
filtered_data = {}
for key, values in data.items():
if key in ["text", "__dj__stats__"]:
                    # apply the filter to these fields
filtered_data[key] = [value for value, keep in zip(values, keep_mask) if keep]
else:
                    # keep other fields unchanged
filtered_data[key] = values
elif isinstance(data, list):
filtered_data = [item for item, keep in zip(data, keep_mask) if keep]
else:
raise ValueError("Unsupported data type for batch filtering")

return filtered_data
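
A minimal sketch (not part of this diff) of how the Actor wrapper can be driven end to end. StubOp, the sample payloads, and the pool size are hypothetical stand-ins; only Actor and its mapper_cpu method come from the file above:

import ray

from data_juicer.core.RayOperatorWrapper import Actor


class StubOp:
    """Hypothetical CPU-only op, used purely to illustrate the call pattern."""

    def use_cuda(self):
        return False

    def process_single(self, sample):
        sample["text"] = sample["text"].strip()
        return sample


ray.init(ignore_reinit_error=True)
actors = [Actor.remote(StubOp()) for _ in range(2)]
samples = [{"text": " hello "}, {"text": " world "}]
futures = [actors[i % len(actors)].mapper_cpu.remote(s) for i, s in enumerate(samples)]
print(ray.get(futures))  # -> [{'text': 'hello'}, {'text': 'world'}]
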
7 changes: 7 additions & 0 deletions data_juicer/core/data/dj_dataset.py
@@ -35,6 +35,10 @@ class DJDataset(ABC):
def process(self, operators, *, exporter=None, checkpointer=None, tracer=None) -> DJDataset: # TODO: add type hint
"""process a list of operators on the dataset."""

@abstractmethod
def process_parallel(self, operators, *, exporter=None, checkpointer=None, tracer=None) -> DJDataset:
"""Implementing op parallel data processing based on Ray Actor"""

@abstractmethod
def schema(self) -> Schema:
"""Get dataset schema.
@@ -344,6 +348,9 @@ def process(
logger.error("Error occurred when making log summarization")
return dataset

def process_parallel(self, *args, **kwargs):
raise NotImplementedError("The process_parallel method needs to be implemented for the NestedDataset class.")

def update_args(self, args, kargs, is_filter=False):
if args:
args = list(args)
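
For orientation, a rough sketch of what a concrete process_parallel could look like on a Ray-backed dataset. RayDataset, its batches attribute, and the fixed pool size are assumptions, not part of this change; only Actor and the filter_*_batched methods come from the new wrapper:

import ray

from data_juicer.core.RayOperatorWrapper import Actor


class RayDataset:  # would subclass DJDataset; other abstract methods omitted
    def __init__(self, batches):
        self.batches = batches  # list of column-oriented batch dicts

    def process_parallel(self, operators, *, exporter=None, checkpointer=None, tracer=None):
        # Assumes batched Filter ops; Mapper ops would route to mapper_* instead.
        for op in operators:
            pool = [Actor.remote(op) for _ in range(2)]  # pool size is an assumption
            futures = []
            for i, batch in enumerate(self.batches):
                actor = pool[i % len(pool)]
                if op.use_cuda():
                    futures.append(actor.filter_cuda_batched.remote(batch))
                else:
                    futures.append(actor.filter_cpu_batched.remote(batch))
            # The wrapper returns None when every row in a batch is dropped.
            self.batches = [b for b in ray.get(futures) if b is not None]
        return self
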