diff --git a/src/unitxt/api.py b/src/unitxt/api.py
index 23de331bd4..3e66a2fc60 100644
--- a/src/unitxt/api.py
+++ b/src/unitxt/api.py
@@ -22,7 +22,7 @@
 from .logging_utils import get_logger
 from .metric_utils import EvaluationResults, _compute, _inference_post_process
 from .operator import SourceOperator
-from .schema import loads_batch
+from .schema import SerializeInstancesBeforeDump, loads_batch
 from .settings_utils import get_constants, get_settings
 from .standard import DatasetRecipe
 from .task import Task
@@ -204,7 +204,7 @@ def _source_to_dataset(
     )
     if split is not None:
         stream = {split: stream[split]}
-    ds_builder._generators = stream
+    ds_builder._generators = SerializeInstancesBeforeDump()(stream)
 
     try:
         ds_builder.download_and_prepare(
@@ -354,10 +354,12 @@ def produce(
     if not is_list:
         instance_or_instances = [instance_or_instances]
     dataset_recipe = _get_recipe_with_cache(dataset_query, **kwargs)
-    result = dataset_recipe.produce(instance_or_instances)
+    instances = dataset_recipe.produce(instance_or_instances)
+    serialize = SerializeInstancesBeforeDump()
+    instances = [serialize.process_instance(instance) for instance in instances]
     if not is_list:
-        return result[0]
-    return Dataset.from_list(result).with_transform(loads_batch)
+        return instances[0]
+    return Dataset.from_list(instances).with_transform(loads_batch)
 
 
 def infer(
diff --git a/src/unitxt/dataset.py b/src/unitxt/dataset.py
index 94529f42ff..1c207470fc 100644
--- a/src/unitxt/dataset.py
+++ b/src/unitxt/dataset.py
@@ -47,13 +47,14 @@
 from .random_utils import __file__ as _
 from .recipe import __file__ as _
 from .register import __file__ as _
-from .schema import loads_batch, loads_instance
+from .schema import SerializeInstancesBeforeDump, loads_batch, loads_instance
 from .serializers import __file__ as _
 from .settings_utils import get_constants
 from .span_lableing_operators import __file__ as _
 from .split_utils import __file__ as _
 from .splitters import __file__ as _
 from .standard import __file__ as _
+from .stream import MultiStream
 from .stream import __file__ as _
 from .stream_operators import __file__ as _
 from .string_operators import __file__ as _
@@ -92,7 +93,9 @@ def generators(self):
             logger.info("Loading with huggingface unitxt copy...")
             dataset = get_dataset_artifact(self.config.name)
 
-            self._generators = dataset()
+            multi_stream: MultiStream = dataset()
+
+            self._generators = SerializeInstancesBeforeDump()(multi_stream)
         return self._generators
diff --git a/src/unitxt/schema.py b/src/unitxt/schema.py
index 5926b120cc..21d2b72b5d 100644
--- a/src/unitxt/schema.py
+++ b/src/unitxt/schema.py
@@ -7,7 +7,7 @@
 from .artifact import Artifact
 from .dict_utils import dict_get
 from .image_operators import ImageDataString
-from .operator import InstanceOperatorValidator
+from .operator import InstanceOperator, InstanceOperatorValidator
 from .settings_utils import get_constants, get_settings
 from .type_utils import isoftype
 from .types import Image
@@ -68,6 +68,18 @@ def load_chat_source(chat_str):
     return chat
 
 
+class SerializeInstancesBeforeDump(InstanceOperator):
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
+        if settings.task_data_as_text:
+            instance["task_data"] = json.dumps(instance["task_data"])
+
+        if not isinstance(instance["source"], str):
+            instance["source"] = json.dumps(instance["source"])
+        return instance
+
+
 def loads_batch(batch):
     if (
         "source" in batch
@@ -148,14 +160,6 @@ def _get_instance_task_data(
             task_data["__tools__"] = instance["__tools__"]
         return task_data
 
-    def serialize_instance_fields(self, instance, task_data):
-        if settings.task_data_as_text:
-            instance["task_data"] = json.dumps(task_data)
-
-        if not isinstance(instance["source"], str):
-            instance["source"] = json.dumps(instance["source"])
-        return instance
-
     def process(
         self, instance: Dict[str, Any], stream_name: Optional[str] = None
     ) -> Dict[str, Any]:
@@ -179,7 +183,7 @@ def process(
                 for instance in instance.pop(constants.demos_field)
             ]
 
-        instance = self.serialize_instance_fields(instance, task_data)
+        instance["task_data"] = task_data
 
         if self.remove_unnecessary_fields:
             keys_to_delete = []
@@ -224,7 +228,8 @@ def validate(self, instance: Dict[str, Any], stream_name: Optional[str] = None):
             instance, dict
         ), f"Instance should be a dict, got {type(instance)}"
         schema = get_schema(stream_name)
+
         assert all(
             key in instance for key in schema
-        ), f"Instance should have the following keys: {schema}. Instance is: {instance}"
-        schema.encode_example(instance)
+        ), f"Instance should have the following keys: {schema.keys()}. Instance is: {instance.keys()}"
+        # schema.encode_example(instance)