diff --git a/docs/catalog.py b/docs/catalog.py
index 0d06d5d54b..657f212a2c 100644
--- a/docs/catalog.py
+++ b/docs/catalog.py
@@ -1,6 +1,7 @@
 import json
 import os
 import re
+import types
 from collections import defaultdict
 from functools import lru_cache
 from pathlib import Path
@@ -110,6 +111,7 @@ def all_subtypes_of_artifact(artifact):
         or isinstance(artifact, bool)
         or isinstance(artifact, int)
         or isinstance(artifact, float)
+        or isinstance(artifact, types.FunctionType)
     ):
         return []
     if isinstance(artifact, list):
diff --git a/docs/docs/adding_dataset.rst b/docs/docs/adding_dataset.rst
index 0f6af4ef02..7094668241 100644
--- a/docs/docs/adding_dataset.rst
+++ b/docs/docs/adding_dataset.rst
@@ -5,7 +5,7 @@ To use this tutorial, you need to :ref:`install Unitxt `.
 =================
-Datasets 
+Datasets
 =================
 
 This guide will assist you in adding or using your new dataset in Unitxt.
@@ -105,6 +105,27 @@ Most data can be normalized to the task schema using built-in operators, ensurin
 
 For custom operators, refer to the :ref:`Operators Tutorial `.
 
+.. tip::
+
+    If you cannot find an operator that fits your needs, simply use an instance function operator:
+
+    .. code-block:: python
+
+        def my_function(instance, stream_name=None):
+            instance["x"] += 42
+            return instance
+
+    Or a stream function operator:
+
+    .. code-block:: python
+
+        def my_other_function(stream, stream_name=None):
+            for instance in stream:
+                instance["x"] += 42
+                yield instance
+
+    Both functions can be plugged in anywhere Unitxt requires an operator, e.g. the pre-processing pipeline.
+
 The Template
 ----------------
diff --git a/docs/docs/adding_operator.rst b/docs/docs/adding_operator.rst
index ecb3c2d54b..c258ad9139 100644
--- a/docs/docs/adding_operator.rst
+++ b/docs/docs/adding_operator.rst
@@ -5,11 +5,33 @@ To use this tutorial, you need to :ref:`install unitxt `.
 =====================================
-Operators 
+Operators
 =====================================
 
 Operators are specialized functions designed to process data.
 
+.. tip::
+
+    If you cannot find an operator that fits your needs, simply use an instance function operator:
+
+    .. code-block:: python
+
+        def my_function(instance, stream_name=None):
+            instance["x"] += 42
+            return instance
+
+    Or a stream function operator:
+
+    .. code-block:: python
+
+        def my_other_function(stream, stream_name=None):
+            for instance in stream:
+                instance["x"] += 42
+                yield instance
+
+    Both functions can be plugged in anywhere Unitxt requires an operator, e.g. the pre-processing pipeline.
+
+
 They are used in the TaskCard for preparing data for specific tasks and by Post Processors to process the textual output of the model to the expect input of the metrics.
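The tips above only show how such functions are defined. As a rough illustration of the "pre-processing pipeline" case mentioned there, a plain function of this shape could be passed directly where a card expects operators; the loader path and catalog ids below are hypothetical placeholders, not part of this change:

    from unitxt.card import TaskCard
    from unitxt.loaders import LoadHF

    def add_offset(instance, stream_name=None):
        # instance function operator: receives one instance dict and returns it modified
        instance["x"] += 42
        return instance

    card = TaskCard(
        loader=LoadHF(path="org/some_dataset"),   # hypothetical dataset path
        preprocess_steps=[add_offset],            # the plain function is used in place of an operator
        task="tasks.some_task",                   # hypothetical catalog task id
        templates="templates.some_templates",     # hypothetical catalog templates id
    )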
diff --git a/src/unitxt/api.py b/src/unitxt/api.py
index 23de331bd4..0bf36b98f2 100644
--- a/src/unitxt/api.py
+++ b/src/unitxt/api.py
@@ -26,7 +26,7 @@
 from .settings_utils import get_constants, get_settings
 from .standard import DatasetRecipe
 from .task import Task
-from .utils import lru_cache_decorator
+from .utils import json_dump, lru_cache_decorator
 
 logger = get_logger()
 constants = get_constants()
@@ -180,6 +180,20 @@ class MyClass:
     return obj_str
 
 
+def _remove_id_keys(obj):
+    if isinstance(obj, dict):
+        return {k: _remove_id_keys(v) for k, v in obj.items() if k != "__id__"}
+    if isinstance(obj, list):
+        return [_remove_id_keys(item) for item in obj]
+    return obj
+
+
+def _artifact_string_repr(artifact):
+    artifact_dict = to_dict(artifact, object_to_str_without_addresses)
+    artifact_dict_without_ids = _remove_id_keys(artifact_dict)
+    return json_dump(artifact_dict_without_ids)
+
+
 def _source_to_dataset(
     source: SourceOperator,
     split=None,
@@ -189,9 +203,7 @@
     from .dataset import Dataset as UnitxtDataset
 
     # Generate a unique signature for the source
-    source_signature = json.dumps(
-        to_dict(source, object_to_str_without_addresses), sort_keys=True
-    )
+    source_signature = _artifact_string_repr(source)
     config_name = "recipe-" + short_hex_hash(source_signature)
     # Obtain data stream from the source
     stream = source()
diff --git a/src/unitxt/artifact.py b/src/unitxt/artifact.py
index e1ccae320e..b2bd693ec8 100644
--- a/src/unitxt/artifact.py
+++ b/src/unitxt/artifact.py
@@ -4,6 +4,7 @@
 import os
 import pkgutil
 import re
+import types
 import warnings
 from abc import abstractmethod
 from typing import Any, Dict, List, Optional, Tuple, Union, final
@@ -27,6 +28,8 @@
 from .utils import (
     artifacts_json_cache,
     json_dump,
+    json_load,
+    load_json,
     save_to_file,
     shallow_copy,
 )
@@ -119,16 +122,32 @@ def reset(self):
         self.catalogs = []
 
 
+def maybe_recover_function_operator(func):
+    sig = inspect.signature(func)
+    param_names = tuple(sorted(sig.parameters))
+    if param_names == ("stream", "stream_name") or param_names == (
+        "instance",
+        "stream_name",
+    ):
+        from .operators import FunctionOperator
+
+        return FunctionOperator(function=func)
+    return func
+
+
 def maybe_recover_artifacts_structure(obj):
+    if isinstance(obj, types.FunctionType):
+        obj = maybe_recover_function_operator(obj)
+
     if Artifact.is_possible_identifier(obj):
         return verbosed_fetch_artifact(obj)
     if isinstance(obj, dict):
         for key, value in obj.items():
-            obj[key] = maybe_recover_artifact(value)
+            obj[key] = maybe_recover_artifacts_structure(value)
         return obj
     if isinstance(obj, list):
         for i in range(len(obj)):
-            obj[i] = maybe_recover_artifact(obj[i])
+            obj[i] = maybe_recover_artifacts_structure(obj[i])
         return obj
     return obj
@@ -237,8 +256,7 @@ def __init_subclass__(cls, **kwargs):
     def is_artifact_file(cls, path):
         if not os.path.exists(path) or not os.path.isfile(path):
             return False
-        with open(path) as f:
-            d = json.load(f)
+        d = load_json(path)
         return cls.is_artifact_dict(d)
 
     @classmethod
@@ -384,14 +402,15 @@ def serialize(self):
         return self.to_json()
 
     def save(self, path):
-        original_args = Artifact.from_dict(self.to_dict()).get_repr_dict()
+        data = self.to_dict()
+        original_args = Artifact.from_dict(data).get_repr_dict()
         current_args = self.get_repr_dict()
         diffs = dict_diff_string(original_args, current_args)
         if diffs:
             raise UnitxtError(
                 f"Cannot save catalog artifacts that have changed since initialization. Detected differences in the following fields:\n{diffs}"
             )
-        save_to_file(path, self.to_json())
+        save_to_file(path, json_dump(data))
 
     def verify_instance(
         self, instance: Dict[str, Any], name: Optional[str] = None
@@ -581,7 +600,7 @@ def fetch_artifact(
 
     # If Json string, first load into dictionary
     if isinstance(artifact_rep, str):
-        artifact_rep = json.loads(artifact_rep)
+        artifact_rep = json_load(artifact_rep)
     # Load from dictionary (fails if not valid dictionary)
     return Artifact.from_dict(artifact_rep), None
 
@@ -657,7 +676,7 @@ def get_artifacts_data_classification(artifact: str) -> Optional[List[str]]:
     )
 
     try:
-        data_classification = json.loads(data_classification)
+        data_classification = json_load(data_classification)
     except json.decoder.JSONDecodeError as e:
         raise RuntimeError(error_msg) from e
diff --git a/src/unitxt/catalog.py b/src/unitxt/catalog.py
index 3221c3ee0d..e1a6a114fb 100644
--- a/src/unitxt/catalog.py
+++ b/src/unitxt/catalog.py
@@ -18,6 +18,7 @@
 from .logging_utils import get_logger
 from .settings_utils import get_constants
 from .text_utils import print_dict
+from .utils import json_load
 from .version import version
 
 logger = get_logger()
@@ -228,7 +229,7 @@ def _get_tags_from_file(file_path):
     result = Counter()
 
     with open(file_path) as f:
-        data = json.load(f)
+        data = json_load(f.read())
         if "__tags__" in data and isinstance(data["__tags__"], dict):
             tags = data["__tags__"]
             for key, value in tags.items():
diff --git a/src/unitxt/dataclass.py b/src/unitxt/dataclass.py
index afb92bcf0e..49332a9ed4 100644
--- a/src/unitxt/dataclass.py
+++ b/src/unitxt/dataclass.py
@@ -2,6 +2,7 @@
 import dataclasses
 import functools
 import inspect
+import json
 from abc import ABCMeta
 from inspect import Parameter, Signature
 from typing import Any, Dict, List, Optional, final
@@ -321,6 +322,10 @@ def to_dict(obj, func=copy.deepcopy, _visited=None):
     # Get object ID to track visited objects
     obj_id = id(obj)
 
+    if isinstance(obj, (int, float, bool)):
+        # normalize constants like re.DOTALL
+        obj = json.loads(json.dumps(obj))
+
     # If we've seen this object before, return a placeholder to avoid infinite recursion
     if obj_id in _visited:
         return func(obj)
diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py
index 87488d1da7..809cbc7604 100644
--- a/src/unitxt/inference.py
+++ b/src/unitxt/inference.py
@@ -51,7 +51,7 @@
 from .operators import ArtifactFetcherMixin
 from .settings_utils import get_constants, get_settings
 from .type_utils import isoftype
-from .utils import retry_connection_with_exponential_backoff
+from .utils import json_load, retry_connection_with_exponential_backoff
 
 constants = get_constants()
 settings = get_settings()
@@ -403,7 +403,7 @@ def to_tools(self, instance):
         if task_data is None:
             return None
         if isinstance(task_data, str):
-            task_data = json.loads(task_data)
+            task_data = json_load(task_data)
         if "__tools__" in task_data:
             return task_data["__tools__"]
         return None
@@ -2562,7 +2562,7 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
     def _extract_queries(instance: Dict[str, Any]) -> Tuple[Optional[str], List]:
         task_data = instance["task_data"]
         if isinstance(task_data, str):
-            task_data = json.loads(task_data)
+            task_data = json_load(task_data)
         question = task_data.get("question")
 
         images = [None]
@@ -2682,7 +2682,7 @@ def to_tools(
             return {"tools": None, "tool_choice": None}
 
         if isinstance(task_data, str):
-            task_data = json.loads(task_data)
+            task_data = json_load(task_data)
 
         if "__tools__" in task_data:
             tools: List[Dict[str, str]] = task_data["__tools__"]
             tool_choice: Optional[Dict[str, str]] = task_data.get("__tool_choice__")
@@ -2980,7 +2980,7 @@ def _infer(
 
         task_data = instance["task_data"]
         if isinstance(task_data, str):
-            task_data = json.loads(task_data)
+            task_data = json_load(task_data)
 
         for option in task_data["options"]:
             requests.append(
@@ -3691,7 +3691,7 @@ def _infer(
         return_meta_data: bool = False,
     ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
         task_data = [
-            json.loads(instance["task_data"]) if "task_data" in instance else {}
+            json_load(instance["task_data"]) if "task_data" in instance else {}
             for instance in dataset
         ]
         predictions = (
diff --git a/src/unitxt/llm_as_judge_from_template.py b/src/unitxt/llm_as_judge_from_template.py
index df2d5abab8..fcb924ce02 100644
--- a/src/unitxt/llm_as_judge_from_template.py
+++ b/src/unitxt/llm_as_judge_from_template.py
@@ -12,16 +12,15 @@
 from .settings_utils import get_settings
 from .system_prompts import EmptySystemPrompt, SystemPrompt
 from .templates import Template
+from .utils import json_load
 
 settings = get_settings()
 
 
 def get_task_data_dict(task_data):
-    import json
-
     # seems like the task data sometimes comes as a string, not a dict
     # this fixes it
-    return json.loads(task_data) if isinstance(task_data, str) else task_data
+    return json_load(task_data) if isinstance(task_data, str) else task_data
 
 
 class LLMAsJudgeBase(BulkInstanceMetric, ArtifactFetcherMixin):
diff --git a/src/unitxt/operators.py b/src/unitxt/operators.py
index 410b940df1..7647cd02fa 100644
--- a/src/unitxt/operators.py
+++ b/src/unitxt/operators.py
@@ -3,6 +3,28 @@
 Operators: Building Blocks of Unitxt Processing Pipelines
 ==============================================================
 
+.. tip::
+
+    If you cannot find an operator that fits your needs, simply use an instance function operator:
+
+    .. code-block:: python
+
+        def my_function(instance, stream_name=None):
+            instance["x"] += 42
+            return instance
+
+    Or a stream function operator:
+
+    .. code-block:: python
+
+        def my_other_function(stream, stream_name=None):
+            for instance in stream:
+                instance["x"] += 42
+                yield instance
+
+    Both functions can be plugged in anywhere Unitxt requires an operator, e.g. the pre-processing pipeline.
+
+
 Within the Unitxt framework, operators serve as the foundational elements used to assemble processing pipelines.
 Each operator is designed to perform specific manipulations on dictionary structures within a stream.
 These operators are callable entities that receive a MultiStream as input.
@@ -39,6 +61,7 @@
 ------------------------
 """
 
+import inspect
 import operator
 import re
 import uuid
@@ -2714,3 +2737,40 @@ def process(
         ), f"field '{self.main_field}' must reside in instance in order to verify its jsonschema correctness. got {instance}"
         self.recursive_trace_for_type_fields(instance[self.main_field])
         return instance
+
+
+class FunctionOperator(StreamOperator):
+    function: Callable
+
+    def verify(self):
+        super().verify()
+
+        if not callable(self.function):
+            raise ValueError("Function must be callable.")
+        sig = inspect.signature(self.function)
+        param_names = set(sig.parameters)
+
+        if "stream_name" not in param_names:
+            raise TypeError(
+                "The provided function must have a 'stream_name' parameter."
+            )
+
+        if "stream" not in param_names and "instance" not in param_names:
+            raise TypeError(
+                "The provided function must have a 'stream' parameter or 'instance' parameter."
+            )
+
+        if len(param_names) != 2:
+            raise TypeError("The provided function must have exactly 2 parameters.")
+
+        if "stream" in param_names:
+            self._mode = "stream"
+        if "instance" in param_names:
+            self._mode = "instance"
+
+    def process(self, stream: Stream, stream_name: Optional[str] = None):
+        if self._mode == "stream":
+            yield from self.function(stream, stream_name)
+        if self._mode == "instance":
+            for instance in stream:
+                yield self.function(instance, stream_name)
diff --git a/src/unitxt/test_utils/artifact.py b/src/unitxt/test_utils/artifact.py
index 79b46ba04e..5294f79c99 100644
--- a/src/unitxt/test_utils/artifact.py
+++ b/src/unitxt/test_utils/artifact.py
@@ -1,7 +1,7 @@
-import json
 import tempfile
 
 from .. import add_to_catalog, register_local_catalog
+from ..api import _artifact_string_repr
 from ..artifact import fetch_artifact
 from ..logging_utils import get_logger
 from ..text_utils import print_dict
@@ -19,11 +19,14 @@ def test_artfifact_saving_and_loading(artifact, tester=None):
     loaded_artifact, _ = fetch_artifact(TEMP_NAME)
     if tester is not None:
         with tester.subTest(artifact=artifact, loaded_artifact=loaded_artifact):
-            tester.assertDictEqual(loaded_artifact.to_dict(), artifact.to_dict())
+            tester.assertEqual(
+                _artifact_string_repr(loaded_artifact),
+                _artifact_string_repr(artifact),
+            )
     else:
-        if not json.dumps(
-            loaded_artifact.to_dict(), sort_keys=True, ensure_ascii=False
-        ) == json.dumps(artifact.to_dict(), sort_keys=True):
+        if not _artifact_string_repr(loaded_artifact) == _artifact_string_repr(
+            artifact
+        ):
             logger.info("Artifact loaded is not equal to artifact stored")
             print_dict(loaded_artifact.to_dict())
             print_dict(artifact.to_dict())
diff --git a/src/unitxt/test_utils/card.py b/src/unitxt/test_utils/card.py
index a02c811a12..c6ef53693e 100644
--- a/src/unitxt/test_utils/card.py
+++ b/src/unitxt/test_utils/card.py
@@ -4,6 +4,7 @@
 import tempfile
 
 from .. import add_to_catalog, register_local_catalog
+from ..api import _artifact_string_repr
 from ..artifact import fetch_artifact
 from ..collections import Collection
 from ..logging_utils import get_logger
@@ -41,8 +42,8 @@ def test_loading_from_catalog(card):
     )
     register_local_catalog(tmp_dir)
     card_, _ = fetch_artifact(TEMP_NAME)
-    assert json.dumps(card_.to_dict(), sort_keys=True) == json.dumps(
-        card.to_dict(), sort_keys=True
+    assert _artifact_string_repr(card_) == _artifact_string_repr(
+        card
     ), "Card loaded is not equal to card stored"
diff --git a/src/unitxt/text_utils.py b/src/unitxt/text_utils.py
index c54d3fbd72..5644cf2ca1 100644
--- a/src/unitxt/text_utils.py
+++ b/src/unitxt/text_utils.py
@@ -1,5 +1,6 @@
 import re
 import shutil
+import types
 from typing import List, Tuple
 
 import pandas as pd
@@ -295,6 +296,36 @@ def construct_dict_as_python_lines(d, indent_delta=4) -> List[str]:
         return [f'"{d}"']
     if d is None or isinstance(d, (int, float, bool)):
         return [f"{d}"]
+
+    if isinstance(d, types.FunctionType):
+        from .utils import get_function_source
+
+        try:
+            source = get_function_source(d)
+            source_lines = source.splitlines()
+
+            # Find the base indentation of the function definition
+            base_indent = len(source_lines[0]) - len(source_lines[0].lstrip())
+
+            # Remove only the base indentation from each line
+            result_lines = []
+            for line in source_lines:
+                # Preserve empty lines
+                if line.strip() == "":
+                    result_lines.append("")
+                else:
+                    # Remove base indent while preserving internal indentation
+                    if line.startswith(" " * base_indent):
+                        result_lines.append(line[base_indent:])
+                    else:
+                        result_lines.append(line.lstrip())
+
+            return result_lines
+
+        except (OSError, TypeError):
+            # If source is not available
+            return [f""]
+
     raise RuntimeError(f"unrecognized value to print as python: {d}")
diff --git a/src/unitxt/utils.py b/src/unitxt/utils.py
index 2bfd7b1522..70fb31a0fb 100644
--- a/src/unitxt/utils.py
+++ b/src/unitxt/utils.py
@@ -1,11 +1,13 @@
 import copy
 import functools
 import importlib.util
+import inspect
 import json
 import os
 import random
 import re
 import time
+import types
 from collections import OrderedDict
 from contextvars import ContextVar
 from functools import wraps
@@ -221,7 +223,7 @@ def flatten_dict(
 def load_json(path):
     with open(path) as f:
         try:
-            return json.load(f)
+            return json.load(f, object_hook=decode_function)
         except json.decoder.JSONDecodeError as e:
             with open(path) as f:
                 file_content = "\n".join(f.readlines())
@@ -236,8 +238,56 @@ def save_to_file(path, data):
         f.write("\n")
 
 
-def json_dump(data):
-    return json.dumps(data, indent=4, ensure_ascii=False)
+def encode_function(obj):
+    # Allow only plain (module-level) functions
+    if isinstance(obj, types.FunctionType):
+        try:
+            return {"__function__": obj.__name__, "source": get_function_source(obj)}
+        except Exception as e:
+            raise TypeError(f"Failed to serialize function {obj.__name__}") from e
+    elif isinstance(obj, types.MethodType):
+        raise TypeError(
+            f"Method {obj.__func__.__name__} of class {obj.__self__.__class__.__name__} is not JSON serializable"
+        )
+    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
+
+
+def json_dump(data, sort_keys=False):
+    return json.dumps(
+        data, indent=4, default=encode_function, ensure_ascii=False, sort_keys=sort_keys
+    )
+
+
+def get_function_source(func):
+    if hasattr(func, "__exec_source__"):
+        return func.__exec_source__
+    return inspect.getsource(func)
+
+
+def decode_function(obj):
+    # Detect our special function marker
+    if "__function__" in obj and "source" in obj:
+        namespace = {}
+        func_name = obj["__function__"]
+        try:
+            exec(obj["source"], namespace)
+            func = namespace.get(func_name)
+            func.__exec_source__ = obj["source"]
+            if not callable(func):
+                raise ValueError(
+                    f"Source did not define a callable named {func_name!r}"
+                )
+            return func
+        except Exception as e:
+            raise ValueError(
+                f"Failed to load function {func_name!r} from source:\n{obj['source']}"
+            ) from e
+
+    return obj
+
+
+def json_load(s):
+    return json.loads(s, object_hook=decode_function)
 
 
 def is_package_installed(package_name):
diff --git a/tests/library/test_function_operators.py b/tests/library/test_function_operators.py
index f101727e05..522ced9322 100644
--- a/tests/library/test_function_operators.py
+++ b/tests/library/test_function_operators.py
@@ -1,13 +1,110 @@
 import json
+import os
+import tempfile
+import types
 
+from unitxt.artifact import Artifact
 from unitxt.operator import SequentialOperator
-from unitxt.operators import Apply, CopyFields
+from unitxt.operators import Apply, CopyFields, FunctionOperator
 from unitxt.test_utils.operators import check_operator
 
 from tests.utils import UnitxtTestCase
 
 
+def process_stream(stream, stream_name=None):
+    for instance in stream:
+        instance["x"] += 1
+        yield instance
+
+
+def process_instance(instance, stream_name=None):
+    instance["x"] += 1
+    return instance
+
+
+def wrong_function(instance):
+    ...
+
+
 class TestFunctionOperators(UnitxtTestCase):
+    def test_saving_and_loading_operator_holding_function_operator(self):
+        with tempfile.TemporaryDirectory() as temp_dir:
+            artifact_path = os.path.join(temp_dir, "temp_func.json")
+            SequentialOperator(steps=[process_stream]).save(artifact_path)
+
+            loaded = Artifact.load(artifact_path)
+            self.assertIsInstance(loaded, SequentialOperator)
+            if isinstance(loaded, SequentialOperator):
+                self.assertIsInstance(loaded.steps[0], FunctionOperator)
+                if isinstance(loaded.steps[0], FunctionOperator):
+                    self.assertIsInstance(loaded.steps[0].function, types.FunctionType)
+
+    def test_saving_and_loading_function_operator(self):
+        with tempfile.TemporaryDirectory() as temp_dir:
+            artifact_path = os.path.join(temp_dir, "temp_func.json")
+            FunctionOperator(function=process_stream).save(artifact_path)
+
+            loaded = Artifact.load(artifact_path)
+            self.assertIsInstance(loaded, FunctionOperator)
+            if isinstance(loaded, FunctionOperator):
+                self.assertIsInstance(loaded.function, types.FunctionType)
+
+    def test_saving_and_loading_operator_with_regular_function(self):
+        with tempfile.TemporaryDirectory() as temp_dir:
+            artifact_path = os.path.join(temp_dir, "temp_func.json")
+            SequentialOperator(steps=[wrong_function]).save(artifact_path)
+
+            loaded = Artifact.load(artifact_path)
+            self.assertIsInstance(loaded, SequentialOperator)
+            if isinstance(loaded, SequentialOperator):
+                self.assertIsInstance(loaded.steps[0], types.FunctionType)
+
+    def test_stream_function_operators(self):
+        operator = FunctionOperator(function=process_stream)
+
+        inputs = [
+            {"x": 1, "b": "2"},
+            {"x": 2, "b": "3"},
+        ]
+
+        targets = [
+            {"x": 2, "b": "2"},
+            {"x": 3, "b": "3"},
+        ]
+
+        check_operator(
+            operator=operator,
+            inputs=inputs,
+            targets=targets,
+            tester=self,
+        )
+
+    def test_instance_function_operators(self):
+        operator = FunctionOperator(function=process_instance)
+
+        inputs = [
+            {"x": 1, "b": "2"},
+            {"x": 2, "b": "3"},
+        ]
+
+        targets = [
+            {"x": 2, "b": "2"},
+            {"x": 3, "b": "3"},
+        ]
+
+        check_operator(
+            operator=operator,
+            inputs=inputs,
+            targets=targets,
+            tester=self,
+        )
+
+    def test_function_operator_with_wrong_function(self):
+        with self.assertRaises(ValueError):
+            FunctionOperator(function=[])
+        with self.assertRaises(TypeError):
+            FunctionOperator(function=wrong_function)
+
     def test_apply_function_operator(self):
         operator = Apply("a", function=str.upper, to_field="b")
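As a rough sketch of the save/load round trip the tests above exercise, a wrapped function could be serialized and recovered like this; the file path is a placeholder and error handling is omitted:

    from unitxt.artifact import Artifact
    from unitxt.operators import FunctionOperator

    def add_one(instance, stream_name=None):
        instance["x"] += 1
        return instance

    op = FunctionOperator(function=add_one)  # wraps the plain function as an operator
    op.save("/tmp/add_one.json")             # the function's source code is embedded in the saved JSON

    loaded = Artifact.load("/tmp/add_one.json")  # the source is executed again to rebuild the callable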