Skip to content

Commit e1a908a

Browse files
committed
prepare for backward compatibility
Signed-off-by: dafnapension <[email protected]>
1 parent 7ca1705 commit e1a908a

File tree

2 files changed

+99
-7
lines changed

2 files changed

+99
-7
lines changed

src/unitxt/artifact.py

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import os
55
import re
6+
import subprocess
67
import sys
78
import sysconfig
89
import warnings
@@ -24,6 +25,7 @@
2425
separate_inside_and_outside_square_brackets,
2526
)
2627
from .settings_utils import get_constants, get_settings
28+
from .text_utils import snake_to_camel_case
2729
from .type_utils import isoftype, issubtype
2830
from .utils import (
2931
artifacts_json_cache,
@@ -227,9 +229,25 @@ def get_module_class_names(artifact_type: dict):
227229
return artifact_type["module"], artifact_type["name"]
228230

229231

232+
def convert_str_type_to_dict(type: str) -> dict:
    """Expand a string-form artifact type into its dict form.

    The string is converted to a CamelCase class name with
    ``snake_to_camel_case``; the module that defines that class is then
    looked up with ``find_unitxt_module_by_classname``.

    Args:
        type: the artifact type given as a single string.

    Returns:
        A dict with the keys ``"module"`` and ``"name"``.
    """
    camel_name = snake_to_camel_case(type)
    defining_module = find_unitxt_module_by_classname(
        camel_case_class_name=camel_name
    )
    return dict(module=defining_module, name=camel_name)
238+
239+
230240
# type is the value of the "__type__" key read from a catalog entry: a dict, or a snake_case class-name string
231241
def get_class_from_artifact_type(type: dict):
232-
module_path, class_name = get_module_class_names(type)
242+
if isinstance(type, str):
243+
if type in Artifact._class_register:
244+
return Artifact._class_register[type]
245+
246+
class_name = snake_to_camel_case(type)
247+
module_path = find_unitxt_module_by_classname(camel_case_class_name=class_name)
248+
else:
249+
module_path, class_name = get_module_class_names(type)
250+
233251
if module_path == "class_register":
234252
if class_name not in Artifact._class_register:
235253
raise ValueError(
@@ -487,12 +505,15 @@ def is_artifact_file(cls, path):
487505
@classmethod
488506
def load(cls, path, artifact_identifier=None, overwrite_args=None):
489507
d = artifacts_json_cache(path)
490-
if "__type__" in d and d["__type__"]["name"].endswith("ArtifactLink"):
491-
from_dict(d) # for verifications and warnings
492-
catalog, artifact_rep, _ = get_catalog_name_and_args(name=d["to"])
493-
return catalog.get_with_overwrite(
494-
artifact_rep, overwrite_args=overwrite_args
495-
)
508+
if "__type__" in d:
509+
if isinstance(d["__type__"], str):
510+
d["__type__"] = convert_str_type_to_dict(d["__type__"])
511+
if d["__type__"]["name"].endswith("ArtifactLink"):
512+
from_dict(d) # for verifications and warnings
513+
catalog, artifact_rep, _ = get_catalog_name_and_args(name=d["to"])
514+
return catalog.get_with_overwrite(
515+
artifact_rep, overwrite_args=overwrite_args
516+
)
496517

497518
new_artifact = from_dict(d, overwrite_args=overwrite_args)
498519
new_artifact.__id__ = artifact_identifier
@@ -898,3 +919,22 @@ def get_artifacts_data_classification(artifact: str) -> Optional[List[str]]:
898919
return None
899920

900921
return data_classification.get(artifact)
922+
923+
924+
def find_unitxt_module_by_classname(
    camel_case_class_name: str, search_dir: Optional[str] = None
) -> str:
    """Find the module, a member of src/unitxt, that contains the definition of the class.

    Every ``.py`` file under the package directory is scanned for a top-level
    ``class <camel_case_class_name>`` statement. The scan is done in pure
    Python, so it does not depend on an external ``grep`` executable.

    Args:
        camel_case_class_name: exact (CamelCase) name of the class to locate.
        search_dir: package directory to scan. Defaults to the directory of
            this file. The directory's own name is used as the first component
            of the returned dotted module path.

    Returns:
        The dotted path of the single module defining the class, e.g.
        ``"unitxt.operators"``.

    Raises:
        ValueError: if no module, or more than one module, defines the class,
            or if a source file cannot be read.
    """
    package_dir = os.path.abspath(
        os.path.dirname(__file__) if search_dir is None else search_dir
    )
    package_name = os.path.basename(package_dir)
    error_message = f"Could not find the unitxt module, under unitxt/src/unitxt, in which class {camel_case_class_name} is defined"

    # Anchor at line start so only top-level definitions match, and require a
    # non-word character (or end of text) after the name so that searching for
    # "AddID" does not also match "class AddIDs".
    class_definition = re.compile(
        r"^class +" + re.escape(camel_case_class_name) + r"(?!\w)",
        re.MULTILINE,
    )

    defining_modules = []
    try:
        for current_dir, sub_dirs, file_names in os.walk(package_dir):
            sub_dirs.sort()  # in-place sort makes the traversal order deterministic
            for file_name in sorted(file_names):
                if not file_name.endswith(".py"):
                    continue
                file_path = os.path.join(current_dir, file_name)
                with open(file_path, encoding="utf-8", errors="replace") as source:
                    if class_definition.search(source.read()) is None:
                        continue
                # Drop the ".py" suffix and turn the relative path into a dotted name.
                relative_path = os.path.relpath(file_path, package_dir)[: -len(".py")]
                defining_modules.append(
                    ".".join([package_name, *relative_path.split(os.sep)])
                )
    except OSError as e:
        raise ValueError(error_message) from e

    # Explicit check rather than ``assert``: asserts vanish under ``python -O``.
    if len(defining_modules) != 1:
        raise ValueError(
            f"{error_message}: expected exactly one defining module, found {defining_modules}"
        )
    return defining_modules[0]

src/unitxt/text_utils.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,58 @@ def camel_to_snake_case(s):
7171
return s.lower()
7272

7373

74+
def snake_to_camel_case(s):
    """Turn a snake_case string into the CamelCase class name it stands for.

    Surrounding whitespace is stripped, the string is split on underscores,
    each piece is capitalized (first letter upper case, the rest lower case)
    and the pieces are concatenated. Since the result is meant to be a class
    name, it starts with an upper-case letter. Names listed in the override
    table below are then replaced by their listed spelling.
    """
    # Plain capitalization cannot reproduce these spellings, so they are
    # looked up explicitly after the conversion.
    spelling_overrides = {
        "LoadHf": "LoadHF",
        "LoadCsv": "LoadCSV",
        "LoadFromHfSpace": "LoadFromHFSpace",
        "LoadIob": "LoadIOB",
        "AddId": "AddID",
        "Anls": "ANLS",
        "AzureOpenAiInferenceEngine": "AzureOpenAIInferenceEngine",
        "ChatApiFormat": "ChatAPIFormat",
        "ExactMatchMm": "ExactMatchMM",
        "FaithfulnessHhem": "FaithfulnessHHEM",
        "FinQaEval": "FinQAEval",
        "FixedGroupPdrParaphraseAccuracy": "FixedGroupPDRParaphraseAccuracy",
        "FixedGroupPdrParaphraseStringContainment": "FixedGroupPDRParaphraseStringContainment",
        "GetSql": "GetSQL",
        "HfPipelineBasedInferenceEngine": "HFPipelineBasedInferenceEngine",
        "HfSystemFormat": "HFSystemFormat",
        "Kpa": "KPA",
        "LlmAsJudge": "LLMAsJudge",
        "LlmJudgeDirect": "LLMJudgeDirect",
        "LlmJudgePairwise": "LLMJudgePairwise",
        "Map": "MAP",
        "MapHtmlTableToJson": "MapHTMLTableToJSON",
        "MapTableListsToStdTableJson": "MapTableListsToStdTableJSON",
        "Mrr": "MRR",
        "Ndcg": "NDCG",
        "Ner": "NER",
        "ParseCsv": "ParseCSV",
        "RitsInferenceEngine": "RITSInferenceEngine",
        "SerializeTableAsDfLoader": "SerializeTableAsDFLoader",
        "SerializeTableAsHtml": "SerializeTableAsHTML",
        "SqlDatabaseAsSchemaSerializer": "SQLDatabaseAsSchemaSerializer",
        "SqlExecutionAccuracy": "SQLExecutionAccuracy",
        "SqlNonExecutionAccuracy": "SQLNonExecutionAccuracy",
        "TaskBasedLlMasJudge": "TaskBasedLLMasJudge",
        "ToRgb": "ToRGB",
        "TurlColumnTypeAnnotationLoader": "TURLColumnTypeAnnotationLoader",
        "WmlInferenceEngine": "WMLInferenceEngine",
        "WmlInferenceEngineGeneration": "WMLInferenceEngineGeneration",
    }
    capitalized = "".join(map(str.capitalize, s.strip().split("_")))
    return spelling_overrides.get(capitalized, capitalized)
124+
125+
74126
def to_pretty_string(
75127
value,
76128
indent=0,

0 commit comments

Comments
 (0)