Skip to content
Open
2 changes: 2 additions & 0 deletions docs/catalog.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import re
import types
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
Expand Down Expand Up @@ -110,6 +111,7 @@ def all_subtypes_of_artifact(artifact):
or isinstance(artifact, bool)
or isinstance(artifact, int)
or isinstance(artifact, float)
or isinstance(artifact, types.FunctionType)
):
return []
if isinstance(artifact, list):
Expand Down
23 changes: 22 additions & 1 deletion docs/docs/adding_dataset.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
To use this tutorial, you need to :ref:`install Unitxt <install_unitxt>`.

=================
Datasets
Datasets
=================

This guide will assist you in adding or using your new dataset in Unitxt.
Expand Down Expand Up @@ -105,6 +105,27 @@ Most data can be normalized to the task schema using built-in operators, ensurin

For custom operators, refer to the :ref:`Operators Tutorial <adding_operator>`.

.. tip::

If you cannot find an operator that fits your needs, simply use an instance function operator:

.. code-block:: python

def my_function(instance, stream_name=None):
instance["x"] += 42
return instance

Or stream function operator:

.. code-block:: python

def my_other_function(stream, stream_name=None):
for instance in stream:
instance["x"] += 42
yield instance

Both functions can be plugged in anywhere Unitxt requires an operator, e.g. the pre-processing pipeline.

The Template
----------------

Expand Down
24 changes: 23 additions & 1 deletion docs/docs/adding_operator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,33 @@
To use this tutorial, you need to :ref:`install unitxt <install_unitxt>`.

=====================================
Operators
Operators
=====================================

Operators are specialized functions designed to process data.

.. tip::

If you cannot find an operator that fits your needs, simply use an instance function operator:

.. code-block:: python

def my_function(instance, stream_name=None):
instance["x"] += 42
return instance

Or stream function operator:

.. code-block:: python

def my_other_function(stream, stream_name=None):
for instance in stream:
instance["x"] += 42
yield instance

Both functions can be plugged in anywhere Unitxt requires an operator, e.g. the pre-processing pipeline.


They are used in the TaskCard for preparing data for specific tasks and by Post Processors
to process the textual output of the model to the expect input of the metrics.

Expand Down
20 changes: 16 additions & 4 deletions src/unitxt/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from .settings_utils import get_constants, get_settings
from .standard import DatasetRecipe
from .task import Task
from .utils import lru_cache_decorator
from .utils import json_dump, lru_cache_decorator

logger = get_logger()
constants = get_constants()
Expand Down Expand Up @@ -180,6 +180,20 @@ class MyClass:
return obj_str


def _remove_id_keys(obj):
if isinstance(obj, dict):
return {k: _remove_id_keys(v) for k, v in obj.items() if k != "__id__"}
if isinstance(obj, list):
return [_remove_id_keys(item) for item in obj]
return obj


def _artifact_string_repr(artifact):
    """Serialize *artifact* to a JSON string suitable for signature hashing.

    Object addresses are replaced by stable string forms, and volatile
    ``__id__`` keys are dropped so equal recipes produce equal strings.
    """
    raw_dict = to_dict(artifact, object_to_str_without_addresses)
    return json_dump(_remove_id_keys(raw_dict))


def _source_to_dataset(
source: SourceOperator,
split=None,
Expand All @@ -189,9 +203,7 @@ def _source_to_dataset(
from .dataset import Dataset as UnitxtDataset

# Generate a unique signature for the source
source_signature = json.dumps(
to_dict(source, object_to_str_without_addresses), sort_keys=True
)
source_signature = _artifact_string_repr(source)
config_name = "recipe-" + short_hex_hash(source_signature)
# Obtain data stream from the source
stream = source()
Expand Down
35 changes: 27 additions & 8 deletions src/unitxt/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import pkgutil
import re
import types
import warnings
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union, final
Expand All @@ -27,6 +28,8 @@
from .utils import (
artifacts_json_cache,
json_dump,
json_load,
load_json,
save_to_file,
shallow_copy,
)
Expand Down Expand Up @@ -119,16 +122,32 @@ def reset(self):
self.catalogs = []


def maybe_recover_function_operator(func):
    """Wrap *func* in a FunctionOperator when its signature marks it as one.

    A function taking exactly ``(instance, stream_name)`` or
    ``(stream, stream_name)`` (in any order) is treated as an operator;
    any other function is returned untouched.
    """
    params = tuple(sorted(inspect.signature(func).parameters))
    if params in (("instance", "stream_name"), ("stream", "stream_name")):
        from .operators import FunctionOperator

        return FunctionOperator(function=func)
    return func


def maybe_recover_artifacts_structure(obj):
    """Recursively turn artifact identifiers found inside *obj* into Artifact objects.

    Bare functions are first promoted to FunctionOperator when their
    signature qualifies (see ``maybe_recover_function_operator``).
    Possible identifiers are fetched via ``verbosed_fetch_artifact``;
    dicts and lists are rewritten in place, all other values returned as-is.
    """
    if isinstance(obj, types.FunctionType):
        obj = maybe_recover_function_operator(obj)

    if Artifact.is_possible_identifier(obj):
        return verbosed_fetch_artifact(obj)
    if isinstance(obj, dict):
        # mutate in place so callers holding a reference see the recovery
        for key, value in obj.items():
            obj[key] = maybe_recover_artifacts_structure(value)
        return obj
    if isinstance(obj, list):
        for i in range(len(obj)):
            obj[i] = maybe_recover_artifacts_structure(obj[i])
        return obj
    return obj

Expand Down Expand Up @@ -237,8 +256,7 @@ def __init_subclass__(cls, **kwargs):
def is_artifact_file(cls, path):
if not os.path.exists(path) or not os.path.isfile(path):
return False
with open(path) as f:
d = json.load(f)
d = load_json(path)
return cls.is_artifact_dict(d)

@classmethod
Expand Down Expand Up @@ -384,14 +402,15 @@ def serialize(self):
return self.to_json()

def save(self, path):
    """Save this artifact to *path* as JSON.

    Refuses to save if any field was mutated after initialization: the
    serialized dict is round-tripped through ``Artifact.from_dict`` and its
    repr compared against the live object's repr; any difference raises
    UnitxtError so the catalog never stores drifted state.
    """
    data = self.to_dict()
    original_args = Artifact.from_dict(data).get_repr_dict()
    current_args = self.get_repr_dict()
    diffs = dict_diff_string(original_args, current_args)
    if diffs:
        raise UnitxtError(
            f"Cannot save catalog artifacts that have changed since initialization. Detected differences in the following fields:\n{diffs}"
        )
    # serialize the dict computed above rather than re-serializing self,
    # guaranteeing the file matches exactly what was verified
    save_to_file(path, json_dump(data))

def verify_instance(
self, instance: Dict[str, Any], name: Optional[str] = None
Expand Down Expand Up @@ -581,7 +600,7 @@ def fetch_artifact(

# If Json string, first load into dictionary
if isinstance(artifact_rep, str):
artifact_rep = json.loads(artifact_rep)
artifact_rep = json_load(artifact_rep)
# Load from dictionary (fails if not valid dictionary)
return Artifact.from_dict(artifact_rep), None

Expand Down Expand Up @@ -657,7 +676,7 @@ def get_artifacts_data_classification(artifact: str) -> Optional[List[str]]:
)

try:
data_classification = json.loads(data_classification)
data_classification = json_load(data_classification)
except json.decoder.JSONDecodeError as e:
raise RuntimeError(error_msg) from e

Expand Down
3 changes: 2 additions & 1 deletion src/unitxt/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .logging_utils import get_logger
from .settings_utils import get_constants
from .text_utils import print_dict
from .utils import json_load
from .version import version

logger = get_logger()
Expand Down Expand Up @@ -228,7 +229,7 @@ def _get_tags_from_file(file_path):
result = Counter()

with open(file_path) as f:
data = json.load(f)
data = json_load(f)
if "__tags__" in data and isinstance(data["__tags__"], dict):
tags = data["__tags__"]
for key, value in tags.items():
Expand Down
5 changes: 5 additions & 0 deletions src/unitxt/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import dataclasses
import functools
import inspect
import json
from abc import ABCMeta
from inspect import Parameter, Signature
from typing import Any, Dict, List, Optional, final
Expand Down Expand Up @@ -321,6 +322,10 @@ def to_dict(obj, func=copy.deepcopy, _visited=None):
# Get object ID to track visited objects
obj_id = id(obj)

if isinstance(obj, (int, float, bool)):
# normalize constants like re.DOTALL
obj = json.loads(json.dumps(obj))

# If we've seen this object before, return a placeholder to avoid infinite recursion
if obj_id in _visited:
return func(obj)
Expand Down
12 changes: 6 additions & 6 deletions src/unitxt/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from .operators import ArtifactFetcherMixin
from .settings_utils import get_constants, get_settings
from .type_utils import isoftype
from .utils import retry_connection_with_exponential_backoff
from .utils import json_load, retry_connection_with_exponential_backoff

constants = get_constants()
settings = get_settings()
Expand Down Expand Up @@ -403,7 +403,7 @@ def to_tools(self, instance):
if task_data is None:
return None
if isinstance(task_data, str):
task_data = json.loads(task_data)
task_data = json_load(task_data)
if "__tools__" in task_data:
return task_data["__tools__"]
return None
Expand Down Expand Up @@ -2562,7 +2562,7 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
def _extract_queries(instance: Dict[str, Any]) -> Tuple[Optional[str], List]:
task_data = instance["task_data"]
if isinstance(task_data, str):
task_data = json.loads(task_data)
task_data = json_load(task_data)
question = task_data.get("question")

images = [None]
Expand Down Expand Up @@ -2682,7 +2682,7 @@ def to_tools(
return {"tools": None, "tool_choice": None}

if isinstance(task_data, str):
task_data = json.loads(task_data)
task_data = json_load(task_data)
if "__tools__" in task_data:
tools: List[Dict[str, str]] = task_data["__tools__"]
tool_choice: Optional[Dict[str, str]] = task_data.get("__tool_choice__")
Expand Down Expand Up @@ -2980,7 +2980,7 @@ def _infer(
task_data = instance["task_data"]

if isinstance(task_data, str):
task_data = json.loads(task_data)
task_data = json_load(task_data)

for option in task_data["options"]:
requests.append(
Expand Down Expand Up @@ -3691,7 +3691,7 @@ def _infer(
return_meta_data: bool = False,
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
task_data = [
json.loads(instance["task_data"]) if "task_data" in instance else {}
json_load(instance["task_data"]) if "task_data" in instance else {}
for instance in dataset
]
predictions = (
Expand Down
5 changes: 2 additions & 3 deletions src/unitxt/llm_as_judge_from_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,15 @@
from .settings_utils import get_settings
from .system_prompts import EmptySystemPrompt, SystemPrompt
from .templates import Template
from .utils import json_load

settings = get_settings()


def get_task_data_dict(task_data):
    """Return *task_data* as a dict.

    Task data sometimes arrives as a JSON-encoded string rather than a
    dict; parse it with ``json_load`` in that case, otherwise pass it
    through unchanged.
    """
    return json_load(task_data) if isinstance(task_data, str) else task_data


class LLMAsJudgeBase(BulkInstanceMetric, ArtifactFetcherMixin):
Expand Down
Loading