
Commit 80e9934

elronbandel authored and dafnapension committed

Allow using python functions instead of operators (e.g. in pre-processing pipeline)

Signed-off-by: elronbandel <[email protected]>

1 parent 01951fd · commit 80e9934

File tree

16 files changed: +373 −46 lines

docs/catalog.py

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,7 @@
 import json
 import os
 import re
+import types
 from collections import defaultdict
 from functools import lru_cache
 from pathlib import Path
@@ -110,6 +111,7 @@ def all_subtypes_of_artifact(artifact):
         or isinstance(artifact, bool)
         or isinstance(artifact, int)
         or isinstance(artifact, float)
+        or isinstance(artifact, types.FunctionType)
     ):
         return []
     if isinstance(artifact, list):

docs/docs/adding_dataset.rst

Lines changed: 22 additions & 1 deletion
@@ -5,7 +5,7 @@
 To use this tutorial, you need to :ref:`install Unitxt <install_unitxt>`.

 =================
-Datasets
+Datasets
 =================

 This guide will assist you in adding or using your new dataset in Unitxt.
@@ -105,6 +105,27 @@ Most data can be normalized to the task schema using built-in operators, ensurin

 For custom operators, refer to the :ref:`Operators Tutorial <adding_operator>`.

+.. tip::
+
+    If you cannot find an operator that fits your needs, simply use an instance function operator:
+
+    .. code-block:: python
+
+        def my_function(instance, stream_name=None):
+            instance["x"] += 42
+            return instance
+
+    Or a stream function operator:
+
+    .. code-block:: python
+
+        def my_other_function(stream, stream_name=None):
+            for instance in stream:
+                instance["x"] += 42
+                yield instance
+
+    Both functions can be plugged in anywhere Unitxt requires operators, e.g. the pre-processing pipeline.
+
 The Template
 ----------------

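The instance-function operator added by this commit's tip can be exercised as plain Python before wiring it into a card. A minimal sketch, using the doc's own `my_function` example (the sample instance dict is a made-up illustration):

```python
def my_function(instance, stream_name=None):
    # An instance function operator: receives one instance dict, mutates it,
    # and returns it, matching the (instance, stream_name) signature unitxt detects.
    instance["x"] += 42
    return instance

instance = {"x": 0, "label": "demo"}
result = my_function(instance)
print(result["x"])  # 42
```

Because the function returns the instance, it composes like any other instance operator in a pre-processing pipeline.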
docs/docs/adding_operator.rst

Lines changed: 26 additions & 4 deletions
@@ -5,11 +5,33 @@
 To use this tutorial, you need to :ref:`install unitxt <install_unitxt>`.

 =====================================
-Operators
+Operators
 =====================================

 Operators are specialized functions designed to process data.

+.. tip::
+
+    If you cannot find an operator that fits your needs, simply use an instance function operator:
+
+    .. code-block:: python
+
+        def my_function(instance, stream_name=None):
+            instance["x"] += 42
+            return instance
+
+    Or a stream function operator:
+
+    .. code-block:: python
+
+        def my_other_function(stream, stream_name=None):
+            for instance in stream:
+                instance["x"] += 42
+                yield instance
+
+    Both functions can be plugged in anywhere Unitxt requires operators, e.g. the pre-processing pipeline.
+
+
 They are used in the TaskCard for preparing data for specific tasks and by Post Processors
 to process the textual output of the model to the expect input of the metrics.

@@ -18,11 +40,11 @@ There are several types of operators.
 1. Field Operators - Operators that modify individual fields of the instances in the input streams. Example of such operators are operators that
 cast field values, uppercase string fields, or translate text between languages.

-2. Instance Operators - Operators that modify individual instances in the input streams. For example, operators that add or remove fields.
+1. Instance Operators - Operators that modify individual instances in the input streams. For example, operators that add or remove fields.

-3. Stream Operators - Operators that perform operations on full streams. For example, operators that remove instances based on some condition.
+2. Stream Operators - Operators that perform operations on full streams. For example, operators that remove instances based on some condition.

-4. MultiStream Operators - Operator that perform operations on multiple streams. For example, operators that repartition the instances between train and test splits.
+3. MultiStream Operators - Operator that perform operations on multiple streams. For example, operators that repartition the instances between train and test splits.

 Unitxt comes with a large collection of built in operators - that were design to cover most common requirements of dataset processing.

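The stream-function variant from the tip is a generator over the whole stream, which is what lets it filter or reorder instances, not just map them. A standalone sketch using the doc's `my_other_function` (the list of instances is a made-up illustration):

```python
def my_other_function(stream, stream_name=None):
    # A stream function operator: lazily transforms every instance in a stream,
    # matching the (stream, stream_name) signature unitxt detects.
    for instance in stream:
        instance["x"] += 42
        yield instance

instances = [{"x": 0}, {"x": 1}, {"x": 2}]
processed = list(my_other_function(iter(instances)))
print([inst["x"] for inst in processed])  # [42, 43, 44]
```

Nothing is computed until the generator is consumed, so such operators stay memory-friendly on large streams.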
prepare/cards/xlam_function_calling.py

Lines changed: 14 additions & 5 deletions
@@ -3,14 +3,26 @@
 from unitxt.loaders import LoadHF
 from unitxt.operators import (
     Copy,
-    ExecuteExpression,
     Move,
     Set,
 )
 from unitxt.splitters import RenameSplits
 from unitxt.struct_data_operators import LoadJson
 from unitxt.test_utils.card import test_card


+def extract_required_parameters(instance, stream_name=None):
+    result = []
+    for tool in instance["tools"]:
+        required_params = []
+        for param_name, param_info in tool["parameters"]["properties"].items():
+            if "optional" not in param_info["type"]:
+                required_params.append(param_name)
+        result.append(required_params)
+    instance["required"] = result
+    return instance
+
+
 card = TaskCard(
     loader=LoadHF(
         path="Salesforce/xlam-function-calling-60k",
@@ -30,10 +42,7 @@
         to_field="tools/*/parameters/properties",
         set_every_value=True,
     ),
-    ExecuteExpression(
-        to_field="required",
-        expression="[[p for p, c in tool['parameters']['properties'].items() if 'optional' not in c['type']] for tool in tools]",
-    ),
+    extract_required_parameters,
     Copy(
         field="required",
         to_field="tools/*/parameters/required",
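The new `extract_required_parameters` function replaces the `ExecuteExpression` one-liner with explicit loops. Its behavior can be checked on a hand-built instance; the tool shape below is an assumption modeled on the xlam function-calling schema the card processes:

```python
def extract_required_parameters(instance, stream_name=None):
    # For each tool, collect the parameter names whose type string
    # is not marked "optional", and store the lists under "required".
    result = []
    for tool in instance["tools"]:
        required_params = []
        for param_name, param_info in tool["parameters"]["properties"].items():
            if "optional" not in param_info["type"]:
                required_params.append(param_name)
        result.append(required_params)
    instance["required"] = result
    return instance

# Hypothetical instance mimicking the dataset's tool schema.
instance = {
    "tools": [
        {"parameters": {"properties": {
            "query": {"type": "str"},
            "limit": {"type": "int, optional"},
        }}}
    ]
}
print(extract_required_parameters(instance)["required"])  # [['query']]
```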

src/unitxt/api.py

Lines changed: 16 additions & 3 deletions
@@ -27,6 +27,7 @@
 from .settings_utils import get_constants, get_settings
 from .standard import DatasetRecipe
 from .task import Task
+from .utils import json_dump

 logger = get_logger()
 constants = get_constants()
@@ -180,6 +181,20 @@ class MyClass:
     return obj_str


+def _remove_id_keys(obj):
+    if isinstance(obj, dict):
+        return {k: _remove_id_keys(v) for k, v in obj.items() if k != "__id__"}
+    if isinstance(obj, list):
+        return [_remove_id_keys(item) for item in obj]
+    return obj
+
+
+def _artifact_string_repr(artifact):
+    artifact_dict = to_dict(artifact, object_to_str_without_addresses)
+    artifact_dict_without_ids = _remove_id_keys(artifact_dict)
+    return json_dump(artifact_dict_without_ids)
+
+
 def _source_to_dataset(
     source: SourceOperator,
     split=None,
@@ -189,9 +204,7 @@ def _source_to_dataset(
     from .dataset import Dataset as UnitxtDataset

     # Generate a unique signature for the source
-    source_signature = json.dumps(
-        to_dict(source, object_to_str_without_addresses), sort_keys=True
-    )
+    source_signature = _artifact_string_repr(source)
     config_name = "recipe-" + short_hex_hash(source_signature)
     # Obtain data stream from the source
     stream = source()

src/unitxt/artifact.py

Lines changed: 27 additions & 8 deletions
@@ -4,6 +4,7 @@
 import os
 import pkgutil
 import re
+import types
 import warnings
 from abc import abstractmethod
 from typing import Any, Dict, List, Optional, Tuple, Union, final
@@ -27,6 +28,8 @@
 from .utils import (
     artifacts_json_cache,
     json_dump,
+    json_load,
+    load_json,
     save_to_file,
     shallow_copy,
 )
@@ -119,16 +122,32 @@ def reset(self):
         self.catalogs = []


+def maybe_recover_function_operator(func):
+    sig = inspect.signature(func)
+    param_names = tuple(sorted(sig.parameters))
+    if param_names == ("stream", "stream_name") or param_names == (
+        "instance",
+        "stream_name",
+    ):
+        from .operators import FunctionOperator
+
+        return FunctionOperator(function=func)
+    return func
+
+
 def maybe_recover_artifacts_structure(obj):
+    if isinstance(obj, types.FunctionType):
+        obj = maybe_recover_function_operator(obj)
+
     if Artifact.is_possible_identifier(obj):
         return verbosed_fetch_artifact(obj)
     if isinstance(obj, dict):
         for key, value in obj.items():
-            obj[key] = maybe_recover_artifact(value)
+            obj[key] = maybe_recover_artifacts_structure(value)
         return obj
     if isinstance(obj, list):
         for i in range(len(obj)):
-            obj[i] = maybe_recover_artifact(obj[i])
+            obj[i] = maybe_recover_artifacts_structure(obj[i])
         return obj
     return obj

@@ -237,8 +256,7 @@ def __init_subclass__(cls, **kwargs):
     def is_artifact_file(cls, path):
         if not os.path.exists(path) or not os.path.isfile(path):
             return False
-        with open(path) as f:
-            d = json.load(f)
+        d = load_json(path)
         return cls.is_artifact_dict(d)

     @classmethod
@@ -384,14 +402,15 @@ def serialize(self):
         return self.to_json()

     def save(self, path):
-        original_args = Artifact.from_dict(self.to_dict()).get_repr_dict()
+        data = self.to_dict()
+        original_args = Artifact.from_dict(data).get_repr_dict()
         current_args = self.get_repr_dict()
         diffs = dict_diff_string(original_args, current_args)
         if diffs:
             raise UnitxtError(
                 f"Cannot save catalog artifacts that have changed since initialization. Detected differences in the following fields:\n{diffs}"
             )
-        save_to_file(path, self.to_json())
+        save_to_file(path, json_dump(data))

     def verify_instance(
         self, instance: Dict[str, Any], name: Optional[str] = None
@@ -581,7 +600,7 @@ def fetch_artifact(

     # If Json string, first load into dictionary
     if isinstance(artifact_rep, str):
-        artifact_rep = json.loads(artifact_rep)
+        artifact_rep = json_load(artifact_rep)
     # Load from dictionary (fails if not valid dictionary)
     return Artifact.from_dict(artifact_rep), None

@@ -657,7 +676,7 @@ def get_artifacts_data_classification(artifact: str) -> Optional[List[str]]:
     )

     try:
-        data_classification = json.loads(data_classification)
+        data_classification = json_load(data_classification)
     except json.decoder.JSONDecodeError as e:
         raise RuntimeError(error_msg) from e
src/unitxt/catalog.py

Lines changed: 2 additions & 1 deletion
@@ -18,6 +18,7 @@
 from .logging_utils import get_logger
 from .settings_utils import get_constants
 from .text_utils import print_dict
+from .utils import json_load
 from .version import version

 logger = get_logger()
@@ -228,7 +229,7 @@ def _get_tags_from_file(file_path):
     result = Counter()

     with open(file_path) as f:
-        data = json.load(f)
+        data = json_load(f)
     if "__tags__" in data and isinstance(data["__tags__"], dict):
         tags = data["__tags__"]
         for key, value in tags.items():

src/unitxt/catalog/cards/xlam_function_calling_60k.json

Lines changed: 2 additions & 3 deletions
@@ -61,9 +61,8 @@
         "set_every_value": true
     },
     {
-        "__type__": "execute_expression",
-        "to_field": "required",
-        "expression": "[[p for p, c in tool['parameters']['properties'].items() if 'optional' not in c['type']] for tool in tools]"
+        "__function__": "extract_required_parameters",
+        "source": "def extract_required_parameters(instance, stream_name=None):\n    result = []\n    for tool in instance[\"tools\"]:\n        required_params = []\n        for param_name, param_info in tool['parameters']['properties'].items():\n            if 'optional' not in param_info['type']:\n                required_params.append(param_name)\n        result.append(required_params)\n    instance[\"required\"] = result\n    return instance\n"
     },
     {
         "__type__": "copy",

src/unitxt/inference.py

Lines changed: 7 additions & 7 deletions
@@ -51,7 +51,7 @@
 from .operators import ArtifactFetcherMixin
 from .settings_utils import get_constants, get_settings
 from .type_utils import isoftype
-from .utils import retry_connection_with_exponential_backoff
+from .utils import json_load, retry_connection_with_exponential_backoff

 constants = get_constants()
 settings = get_settings()
@@ -391,7 +391,7 @@ def to_tools(self, instance):
         if task_data is None:
             return None
         if isinstance(task_data, str):
-            task_data = json.loads(task_data)
+            task_data = json_load(task_data)
         if "__tools__" in task_data:
             return task_data["__tools__"]
         return None
@@ -814,7 +814,7 @@ def _infer_fn(
             if "task_data" in instance and "__tools__" in instance["task_data"]:
                 task_data = instance["task_data"]
                 if isinstance(task_data, str):
-                    task_data = json.loads(task_data)
+                    task_data = json_load(task_data)
                 tools.append(task_data["__tools__"])
             else:
                 tools.append(None)
@@ -2768,7 +2768,7 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
     def _extract_queries(instance: Dict[str, Any]) -> Tuple[Optional[str], List]:
         task_data = instance["task_data"]
         if isinstance(task_data, str):
-            task_data = json.loads(task_data)
+            task_data = json_load(task_data)
         question = task_data.get("question")

         images = [None]
@@ -2888,7 +2888,7 @@ def to_tools(
             return {"tools": None, "tool_choice": None}

         if isinstance(task_data, str):
-            task_data = json.loads(task_data)
+            task_data = json_load(task_data)
         if "__tools__" in task_data:
             tools: List[Dict[str, str]] = task_data["__tools__"]
             tool_choice: Optional[Dict[str, str]] = task_data.get("__tool_choice__")
@@ -3186,7 +3186,7 @@ def _infer(
             task_data = instance["task_data"]

             if isinstance(task_data, str):
-                task_data = json.loads(task_data)
+                task_data = json_load(task_data)

             for option in task_data["options"]:
                 requests.append(
@@ -3866,7 +3866,7 @@ def _infer(
         return_meta_data: bool = False,
     ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
         task_data = [
-            json.loads(instance["task_data"]) if "task_data" in instance else {}
+            json_load(instance["task_data"]) if "task_data" in instance else {}
             for instance in dataset
         ]
         predictions = [td[self.prediction_field] for td in task_data]

src/unitxt/llm_as_judge_from_template.py

Lines changed: 2 additions & 3 deletions
@@ -12,16 +12,15 @@
 from .settings_utils import get_settings
 from .system_prompts import EmptySystemPrompt, SystemPrompt
 from .templates import Template
+from .utils import json_load

 settings = get_settings()


 def get_task_data_dict(task_data):
-    import json
-
     # seems like the task data sometimes comes as a string, not a dict
     # this fixes it
-    return json.loads(task_data) if isinstance(task_data, str) else task_data
+    return json_load(task_data) if isinstance(task_data, str) else task_data


 class LLMAsJudgeBase(BulkInstanceMetric, ArtifactFetcherMixin):
