|  | 
| 3 | 3 | import json | 
| 4 | 4 | import os | 
| 5 | 5 | import re | 
|  | 6 | +import subprocess | 
| 6 | 7 | import sys | 
| 7 | 8 | import sysconfig | 
| 8 | 9 | import warnings | 
|  | 
| 24 | 25 |     separate_inside_and_outside_square_brackets, | 
| 25 | 26 | ) | 
| 26 | 27 | from .settings_utils import get_constants, get_settings | 
|  | 28 | +from .text_utils import snake_to_camel_case | 
| 27 | 29 | from .type_utils import isoftype, issubtype | 
| 28 | 30 | from .utils import ( | 
| 29 | 31 |     artifacts_json_cache, | 
| @@ -227,9 +229,25 @@ def get_module_class_names(artifact_type: dict): | 
| 227 | 229 |     return artifact_type["module"], artifact_type["name"] | 
| 228 | 230 | 
 | 
| 229 | 231 | 
 | 
def convert_str_type_to_dict(type: str) -> dict:
    """Expand a string-valued ``__type__`` into its dict form.

    The string is passed through ``snake_to_camel_case`` to obtain the class
    name, and the module defining that class is looked up with
    ``find_unitxt_module_by_classname``.

    Args:
        type: The string found under the ``"__type__"`` key.

    Returns:
        A dict with the keys ``"module"`` and ``"name"``.

    Raises:
        ValueError: Propagated from ``find_unitxt_module_by_classname`` when
            the defining module cannot be determined.
    """
    camel_name = snake_to_camel_case(type)
    owning_module = find_unitxt_module_by_classname(camel_case_class_name=camel_name)
    return {"module": owning_module, "name": camel_name}
|  | 238 | + | 
|  | 239 | + | 
| 230 | 240 | # type is the value of the key "__type__" read from a catalog entry: either a dict or a string | 
| 231 | 241 | def get_class_from_artifact_type(type: dict): | 
| 232 |  | -    module_path, class_name = get_module_class_names(type) | 
|  | 242 | +    if isinstance(type, str): | 
|  | 243 | +        if type in Artifact._class_register: | 
|  | 244 | +            return Artifact._class_register[type] | 
|  | 245 | + | 
|  | 246 | +        class_name = snake_to_camel_case(type) | 
|  | 247 | +        module_path = find_unitxt_module_by_classname(camel_case_class_name=class_name) | 
|  | 248 | +    else: | 
|  | 249 | +        module_path, class_name = get_module_class_names(type) | 
|  | 250 | + | 
| 233 | 251 |     if module_path == "class_register": | 
| 234 | 252 |         if class_name not in Artifact._class_register: | 
| 235 | 253 |             raise ValueError( | 
| @@ -487,12 +505,15 @@ def is_artifact_file(cls, path): | 
| 487 | 505 |     @classmethod | 
| 488 | 506 |     def load(cls, path, artifact_identifier=None, overwrite_args=None): | 
| 489 | 507 |         d = artifacts_json_cache(path) | 
| 490 |  | -        if "__type__" in d and d["__type__"]["name"].endswith("ArtifactLink"): | 
| 491 |  | -            from_dict(d)  # for verifications and warnings | 
| 492 |  | -            catalog, artifact_rep, _ = get_catalog_name_and_args(name=d["to"]) | 
| 493 |  | -            return catalog.get_with_overwrite( | 
| 494 |  | -                artifact_rep, overwrite_args=overwrite_args | 
| 495 |  | -            ) | 
|  | 508 | +        if "__type__" in d: | 
|  | 509 | +            if isinstance(d["__type__"], str): | 
|  | 510 | +                d["__type__"] = convert_str_type_to_dict(d["__type__"]) | 
|  | 511 | +            if d["__type__"]["name"].endswith("ArtifactLink"): | 
|  | 512 | +                from_dict(d)  # for verifications and warnings | 
|  | 513 | +                catalog, artifact_rep, _ = get_catalog_name_and_args(name=d["to"]) | 
|  | 514 | +                return catalog.get_with_overwrite( | 
|  | 515 | +                    artifact_rep, overwrite_args=overwrite_args | 
|  | 516 | +                ) | 
| 496 | 517 | 
 | 
| 497 | 518 |         new_artifact = from_dict(d, overwrite_args=overwrite_args) | 
| 498 | 519 |         new_artifact.__id__ = artifact_identifier | 
| @@ -898,3 +919,22 @@ def get_artifacts_data_classification(artifact: str) -> Optional[List[str]]: | 
| 898 | 919 |         return None | 
| 899 | 920 | 
 | 
| 900 | 921 |     return data_classification.get(artifact) | 
|  | 922 | + | 
|  | 923 | + | 
def find_unitxt_module_by_classname(camel_case_class_name: str, search_dir=None):
    """Find the module, under the unitxt package directory, that defines a class.

    Every ``.py`` file below the package directory is scanned for a top-level
    ``class <camel_case_class_name>`` statement (whole-word match on the class
    name). The scan is done in pure Python, so it does not depend on an
    external ``grep`` executable, on the path separator of the platform, or on
    the install path being ASCII.

    Args:
        camel_case_class_name: Name of the class to look for, e.g. ``"MyMetric"``.
        search_dir: Directory to scan. Defaults to the directory containing
            this file (the package directory, e.g. ``src/unitxt``).

    Returns:
        The dotted module path, starting at the package directory's own name,
        e.g. ``"unitxt.metrics"``.

    Raises:
        ValueError: If the class name is empty, or if the class is defined in
            no file or in more than one file under the package directory.
    """
    not_found_message = f"Could not find the unitxt module, under unitxt/src/unitxt, in which class {camel_case_class_name} is defined"
    if not camel_case_class_name:
        raise ValueError(not_found_message)

    package_dir = os.path.abspath(
        os.path.dirname(__file__) if search_dir is None else search_dir
    )

    # Escape the name so it is matched literally, and require that it is not
    # followed by another word character ("Foo" must not match "FooBar").
    class_definition = re.compile(
        r"^class +" + re.escape(camel_case_class_name) + r"(?!\w)", re.MULTILINE
    )

    matching_files = []
    for root, _subdirs, file_names in os.walk(package_dir):
        for file_name in file_names:
            if not file_name.endswith(".py"):
                continue  # only Python sources can be imported as modules
            path = os.path.join(root, file_name)
            try:
                with open(path, encoding="utf-8", errors="ignore") as source:
                    if class_definition.search(source.read()):
                        matching_files.append(path)
            except OSError:
                # An unreadable file cannot be the answer; keep scanning.
                continue

    # Exactly one defining file is required: zero means unknown class, more
    # than one means the name is ambiguous.
    if len(matching_files) != 1:
        raise ValueError(
            f"{not_found_message}; matching files: {sorted(matching_files)}"
        )

    # Express the file relative to the parent of the package directory, so the
    # result starts with the package name itself.
    relative_path = os.path.relpath(matching_files[0], os.path.dirname(package_dir))
    module_path, _extension = os.path.splitext(relative_path)
    return module_path.replace(os.sep, ".")
0 commit comments