Skip to content

Commit 429047d

Browse files
committed
change each __type__ in catalog to full qualified name, rather than snake of class name, and remove _class_register altogether
Signed-off-by: dafnapension <[email protected]>
1 parent aac33a6 commit 429047d

File tree

4,717 files changed

+26837
-26859
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

4,717 files changed

+26837
-26859
lines changed

docs/catalog.py

Lines changed: 18 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def imports_to_syntax_highlighted_html(subtypes: List[str])-> str:
5050
return ""
5151
module_to_class_names = defaultdict(list)
5252
for subtype in subtypes:
53-
subtype_class = Artifact._class_register.get(subtype)
54-
module_to_class_names[subtype_class.__module__].append(subtype_class.__name__)
53+
(module, class_name) = Artifact.get_module_class(subtype)
54+
module_to_class_names[module].append(class_name)
5555

5656
imports_txt = ""
5757
for modu in sorted(module_to_class_names.keys()):
@@ -101,31 +101,6 @@ def custom_walk(top):
101101
yield entry
102102

103103

104-
def all_subtypes_of_artifact(artifact):
105-
if (
106-
artifact is None
107-
or isinstance(artifact, str)
108-
or isinstance(artifact, bool)
109-
or isinstance(artifact, int)
110-
or isinstance(artifact, float)
111-
):
112-
return []
113-
if isinstance(artifact, list):
114-
to_return = []
115-
for art in artifact:
116-
to_return.extend(all_subtypes_of_artifact(art))
117-
return to_return
118-
# artifact is a dict
119-
to_return = []
120-
for key, value in artifact.items():
121-
if isinstance(value, str):
122-
if key == "__type__":
123-
to_return.append(value)
124-
else:
125-
to_return.extend(all_subtypes_of_artifact(value))
126-
return to_return
127-
128-
129104
def get_all_type_elements(nested_dict):
130105
type_elements = set()
131106

@@ -148,19 +123,18 @@ def recursive_search(d):
148123

149124
@lru_cache(maxsize=None)
150125
def artifact_type_to_link(artifact_type):
151-
artifact_class = Artifact._class_register.get(artifact_type)
152-
type_class_name = artifact_class.__name__
153-
artifact_class_id = f"{artifact_class.__module__}.{type_class_name}"
154-
return f'<a class="reference internal" href="../{artifact_class.__module__}.html#{artifact_class_id}" title="{artifact_class_id}"><code class="xref py py-class docutils literal notranslate"><span class="pre">{type_class_name}</span></code></a>'
126+
artifact_module, artifact_class_name = Artifact.get_module_class(artifact_type)
127+
return f'<a class="reference internal" href="../{artifact_module}.html#{artifact_module}.{artifact_class_name}" title="{artifact_module}.{artifact_class_name}"><code class="xref py py-class docutils literal notranslate"><span class="pre">{artifact_class_name}</span></code></a>'
155128

156129

157130
# flake8: noqa: C901
131+
132+
158133
def make_content(artifact, label, all_labels):
159-
artifact_type = artifact["__type__"]
160-
artifact_class = Artifact._class_register.get(artifact_type)
161-
type_class_name = artifact_class.__name__
162-
catalog_id = label.replace("catalog.", "")
134+
artifact_type = artifact["__type__"] #qualified class name
135+
artifact_class = Artifact.get_class_from_artifact_type(artifact_type)
163136

137+
catalog_id = label.replace("catalog.", "")
164138
result = ""
165139

166140
if "__description__" in artifact and artifact["__description__"] is not None:
@@ -203,23 +177,16 @@ def make_content(artifact, label, all_labels):
203177
)
204178

205179
for type_name in type_elements:
206-
# source = f'<span class="nt">__type__</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">{type_name}</span>'
207-
source = f'<span class="n">__type__{type_name}</span><span class="p">'
208-
target = artifact_type_to_link(type_name)
209-
html_for_dict = html_for_dict.replace(
210-
source,
211-
f'<span class="n" STYLE="font-size:108%">{target}</span><span class="p">'
212-
# '<span class="nt">&quot;type&quot;</span><span class="p">:</span><span class="w"> </span>'
213-
# + target,
214-
)
215-
216-
pattern = r'(<span class="nt">)&quot;(.*?)&quot;(</span>)'
180+
artifact_module, artifact_class_name = Artifact.get_module_class(type_name)
181+
pattern = re.compile(f'<span class="n">__type__(.*?)<span class="n">{artifact_class_name}</span>')
182+
repl = '<span class="n" STYLE="font-size:108%">'+artifact_type_to_link(type_name)+"</span>"
183+
html_for_dict = pattern.sub(repl, html_for_dict)
217184

185+
# pattern = r'(<span class="nt">)&quot;(.*?)&quot;(</span>)'
218186
# Replacement function
219-
html_for_dict = re.sub(pattern, r"\1\2\3", html_for_dict)
187+
# html_for_dict = re.sub(pattern, r"\1\2\3", html_for_dict)
220188

221-
subtypes = all_subtypes_of_artifact(artifact)
222-
subtypes = list(set(subtypes))
189+
subtypes = type_elements
223190
subtypes.remove(artifact_type) # this was already documented
224191
html_for_imports = imports_to_syntax_highlighted_html(subtypes)
225192

@@ -235,13 +202,13 @@ def make_content(artifact, label, all_labels):
235202
result += " " + html_for_element + "\n"
236203

237204
if artifact_class.__doc__:
238-
explanation_str = f"Explanation about `{type_class_name}`"
205+
explanation_str = f"Explanation about `{artifact_class.__name__}`"
239206
result += f"\n{explanation_str}\n"
240207
result += "+" * len(explanation_str) + "\n\n"
241208
result += artifact_class.__doc__ + "\n"
242209

243210
for subtype in subtypes:
244-
subtype_class = Artifact._class_register.get(subtype)
211+
subtype_class = Artifact.get_class_from_artifact_type(subtype)
245212
subtype_class_name = subtype_class.__name__
246213
if subtype_class.__doc__:
247214
explanation_str = f"Explanation about `{subtype_class_name}`"

docs/conf.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ def autodoc_skip_member(app, what, name, obj, would_skip, options):
115115
class_name = obj.__qualname__.split(".")[0]
116116
if (
117117
class_name
118-
and Artifact.is_registered_class_name(class_name)
119118
and class_name != name
120119
):
121120
return True

examples/evaluate_rag_using_binary_llm_as_judge.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from unitxt import get_logger
22
from unitxt.api import evaluate, load_dataset
33
from unitxt.blocks import TaskCard
4-
from unitxt.inference import WMLInferenceEngine
4+
from unitxt.inference import WMLInferenceEngineGeneration
55
from unitxt.loaders import LoadFromDictionary
66
from unitxt.templates import TemplatesDict
77

@@ -77,7 +77,7 @@
7777

7878
# Infer using flan t5 xl using wml
7979
model_name = "google/flan-t5-xl"
80-
model = WMLInferenceEngine(model_name=model_name, max_new_tokens=32)
80+
model = WMLInferenceEngineGeneration(model_name=model_name, max_new_tokens=32)
8181
predictions = model(test_dataset)
8282

8383
# Evaluate the generated predictions using the selected metrics

prepare/engines/ibm_wml/llama3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from unitxt.catalog import add_to_catalog
2-
from unitxt.inference import WMLInferenceEngine
2+
from unitxt.inference import WMLInferenceEngineGeneration
33

44
model_list = ["meta-llama/llama-3-3-70b-instruct"]
55

66
for model in model_list:
77
model_label = model.split("/")[1].replace("-", "_").replace(".", ",").lower()
8-
inference_model = WMLInferenceEngine(
8+
inference_model = WMLInferenceEngineGeneration(
99
model_name=model, max_new_tokens=2048, random_seed=42
1010
)
1111
add_to_catalog(inference_model, f"engines.ibm_wml.{model_label}", overwrite=True)

prepare/metrics/custom_f1.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,4 +433,7 @@ class NERWithoutClassReporting(NER):
433433
global_target=global_target,
434434
)
435435

436-
add_to_catalog(metric, "metrics.ner", overwrite=True)
436+
if __name__ == "__main__" or __name__ == "custom_f1":
437+
# because a class is defined in this module, need to not add_to_catalog just for importing that module in order to retrieve the defined class
438+
# and need to prepare for case when this module is run directly from python (__main__) or, for example, from test_preparation (custom_f1)
439+
add_to_catalog(metric, "metrics.ner", overwrite=True)

prepare/metrics/llm_as_judge/conversation_groundedness.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
from unitxt import add_to_catalog
44
from unitxt.inference import (
55
IbmGenAiInferenceEngine,
6-
IbmGenAiInferenceEngineParams,
6+
IbmGenAiInferenceEngineParamsMixin,
77
)
88
from unitxt.llm_as_judge import LLMAsJudge
99
from unitxt.metrics import (
1010
RandomForestMetricsEnsemble,
1111
)
1212

1313
platform = "ibm_gen_ai"
14-
gen_params = IbmGenAiInferenceEngineParams(max_new_tokens=256)
14+
gen_params = IbmGenAiInferenceEngineParamsMixin(max_new_tokens=256)
1515

1616
config_filepath = "prepare/metrics/llm_as_judge/ensemble_grounded_v1.json"
1717

prepare/metrics/llm_as_judge/conversation_idk.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from unitxt.inference import (
33
GenericInferenceEngine,
44
IbmGenAiInferenceEngine,
5-
IbmGenAiInferenceEngineParams,
5+
IbmGenAiInferenceEngineParamsMixin,
66
)
77
from unitxt.llm_as_judge import LLMAsJudge
88

@@ -13,7 +13,7 @@
1313
"model_name": "llama370binstruct",
1414
"inference_model": IbmGenAiInferenceEngine(
1515
model_name="meta-llama/llama-3-70b-instruct",
16-
parameters=IbmGenAiInferenceEngineParams(max_new_tokens=256),
16+
parameters=IbmGenAiInferenceEngineParamsMixin(max_new_tokens=256),
1717
),
1818
},
1919
"generic_inference_engine": {

prepare/metrics/llm_as_judge/conversation_topicality.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
from unitxt import add_to_catalog
44
from unitxt.inference import (
55
IbmGenAiInferenceEngine,
6-
IbmGenAiInferenceEngineParams,
6+
IbmGenAiInferenceEngineParamsMixin,
77
)
88
from unitxt.llm_as_judge import LLMAsJudge
99
from unitxt.metrics import (
1010
RandomForestMetricsEnsemble,
1111
)
1212

1313
platform = "ibm_gen_ai"
14-
gen_params = IbmGenAiInferenceEngineParams(max_new_tokens=256)
14+
gen_params = IbmGenAiInferenceEngineParamsMixin(max_new_tokens=256)
1515

1616
config_filepath = "prepare/metrics/llm_as_judge/ensemble_topicality_v1.json"
1717

prepare/metrics/llm_as_judge/pairwise_rating/llama_3_arena_hard_template.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
CrossProviderInferenceEngine,
44
GenericInferenceEngine,
55
IbmGenAiInferenceEngine,
6-
WMLInferenceEngine,
6+
WMLInferenceEngineGeneration,
77
)
88
from unitxt.llm_as_judge import LLMAsJudge
99

@@ -15,7 +15,7 @@
1515
]
1616

1717
inference_engines = [
18-
("ibm_wml", WMLInferenceEngine),
18+
("ibm_wml", WMLInferenceEngineGeneration),
1919
("ibm_genai", IbmGenAiInferenceEngine),
2020
("generic_engine", GenericInferenceEngine),
2121
]

0 commit comments

Comments
 (0)