Skip to content

Commit 853e8ea

Browse files
committed
prepare for backward compatibility
Signed-off-by: dafnapension <[email protected]>
1 parent 66ce6a6 commit 853e8ea

File tree

2 files changed

+71
-7
lines changed

2 files changed

+71
-7
lines changed

src/unitxt/artifact.py

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import os
55
import re
6+
import subprocess
67
import sys
78
import sysconfig
89
import warnings
@@ -24,6 +25,7 @@
2425
separate_inside_and_outside_square_brackets,
2526
)
2627
from .settings_utils import get_constants, get_settings
28+
from .text_utils import snake_to_camel_case
2729
from .type_utils import isoftype, issubtype
2830
from .utils import (
2931
artifacts_json_cache,
@@ -227,9 +229,29 @@ def get_module_class_names(artifact_type: dict):
227229
return artifact_type["module"], artifact_type["name"]
228230

229231

232+
def convert_str_type_to_dict(type: str) -> dict:
    """Expand a snake_case type name into a type dict.

    The name is converted to CamelCase and then looked up with
    find_unitxt_module_and_class_by_classname, which supplies both the
    module and the class name that end up in the result.

    Args:
        type: the snake_case name of the class.

    Returns:
        A dict with the keys "module" and "name".
    """
    module_path, resolved_name = find_unitxt_module_and_class_by_classname(
        camel_case_class_name=snake_to_camel_case(type)
    )
    return dict(module=module_path, name=resolved_name)
241+
242+
230243
# type is the value of the "__type__" key read from a catalog entry: a dict, or a str naming the class in snake_case
231244
def get_class_from_artifact_type(type: dict):
232-
module_path, class_name = get_module_class_names(type)
245+
if isinstance(type, str):
246+
if type in Artifact._class_register:
247+
return Artifact._class_register[type]
248+
249+
module_path, class_name = find_unitxt_module_and_class_by_classname(
250+
snake_to_camel_case(type)
251+
)
252+
else:
253+
module_path, class_name = get_module_class_names(type)
254+
233255
if module_path == "class_register":
234256
if class_name not in Artifact._class_register:
235257
raise ValueError(
@@ -487,12 +509,15 @@ def is_artifact_file(cls, path):
487509
@classmethod
488510
def load(cls, path, artifact_identifier=None, overwrite_args=None):
489511
d = artifacts_json_cache(path)
490-
if "__type__" in d and d["__type__"]["name"].endswith("ArtifactLink"):
491-
from_dict(d) # for verifications and warnings
492-
catalog, artifact_rep, _ = get_catalog_name_and_args(name=d["to"])
493-
return catalog.get_with_overwrite(
494-
artifact_rep, overwrite_args=overwrite_args
495-
)
512+
if "__type__" in d:
513+
if isinstance(d["__type__"], str):
514+
d["__type__"] = convert_str_type_to_dict(d["__type__"])
515+
if d["__type__"]["name"].endswith("ArtifactLink"):
516+
from_dict(d) # for verifications and warnings
517+
catalog, artifact_rep, _ = get_catalog_name_and_args(name=d["to"])
518+
return catalog.get_with_overwrite(
519+
artifact_rep, overwrite_args=overwrite_args
520+
)
496521

497522
new_artifact = from_dict(d, overwrite_args=overwrite_args)
498523
new_artifact.__id__ = artifact_identifier
@@ -898,3 +923,29 @@ def get_artifacts_data_classification(artifact: str) -> Optional[List[str]]:
898923
return None
899924

900925
return data_classification.get(artifact)
926+
927+
928+
def find_unitxt_module_and_class_by_classname(
    camel_case_class_name: str, search_dir: Optional[str] = None
):
    """Find the module under src/unitxt that contains the definition of the class.

    The directory tree is scanned for ``.py`` files holding a top-level line of
    the form ``class <name>``. The comparison is case-insensitive because a
    name produced by snake_to_camel_case cannot restore acronyms (``load_hf``
    becomes ``LoadHf`` although the class is spelled ``LoadHF``); the spelling
    found in the source file is the one returned.

    The scan is done in pure Python rather than by spawning ``grep``, so it
    does not depend on an external binary, on the shape of file paths, or on
    the number of spaces that follow the ``class`` keyword.

    Args:
        camel_case_class_name: the class name to look for, in any letter case.
        search_dir: root directory to scan. Defaults to the directory of this
            file (src/unitxt).

    Returns:
        A ``(module, class_name)`` tuple, where ``module`` is the dotted module
        path starting with ``unitxt.`` and ``class_name`` is spelled exactly as
        in the defining file.

    Raises:
        ValueError: if the class is not defined in exactly one place.
    """
    not_found_message = f"Could not find the unitxt module, under unitxt/src/unitxt, in which class {camel_case_class_name} is defined"
    if not camel_case_class_name:
        raise ValueError(not_found_message)

    root = os.path.dirname(__file__) if search_dir is None else search_dir

    # Escape the name so it is matched literally, and require that it is not
    # followed by another word character (so that 'Foo' does not match 'FooBar').
    definition = re.compile(
        r"^class +(" + re.escape(camel_case_class_name) + r")(?!\w)",
        re.IGNORECASE | re.MULTILINE,
    )

    matches = []
    for dir_path, _dir_names, file_names in os.walk(root):
        for file_name in file_names:
            if not file_name.endswith(".py"):
                continue
            file_path = os.path.join(dir_path, file_name)
            try:
                with open(file_path, encoding="utf-8") as source_file:
                    source_text = source_file.read()
            except (OSError, UnicodeDecodeError):
                # an unreadable file can not be the one we are looking for
                continue
            matches.extend(
                (file_path, found.group(1))
                for found in definition.finditer(source_text)
            )

    if len(matches) != 1:
        locations = sorted(path for path, _ in matches)
        raise ValueError(not_found_message) from LookupError(
            f"expected exactly one definition, found {len(locations)}: {locations}"
        )

    file_path, class_name = matches[0]
    # trim the .py and turn the path, relative to the scanned root, into a dotted module path
    relative_path = os.path.relpath(file_path, root)[:-3]
    module = ".".join(["unitxt", *relative_path.split(os.sep)])
    return module, class_name

src/unitxt/text_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,19 @@ def camel_to_snake_case(s):
7171
return s.lower()
7272

7373

74+
def snake_to_camel_case(s):
    """Return the CamelCase form of the snake_case string s.

    The result is meant to be used as a class name, so it starts with an
    upper case letter: surrounding whitespace is stripped, the string is split
    on underscores, and each part is capitalized before the parts are joined.

    This is not always the inverse of camel_to_snake_case, because capitalizing
    a part lower-cases all but its first letter. For example,
    camel_to_snake_case("LoadHF") gives "load_hf", whereas
    snake_to_camel_case("load_hf") gives "LoadHf".
    """
    return "".join(map(str.capitalize, s.strip().split("_")))
85+
86+
7487
def to_pretty_string(
7588
value,
7689
indent=0,

0 commit comments

Comments
 (0)