Skip to content

Commit d9425b2

Browse files
committed
change each __type__ in catalog to full qualified name rather than snake. read and write to a registry that just maps qualified-class-name to itself. toward removal of registry
Signed-off-by: dafnapension <[email protected]>
1 parent 1d0596f commit d9425b2

File tree

4,673 files changed

+26597
-26562
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

4,673 files changed

+26597
-26562
lines changed

docs/catalog.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def imports_to_syntax_highlighted_html(subtypes: List[str])-> str:
5050
return ""
5151
module_to_class_names = defaultdict(list)
5252
for subtype in subtypes:
53-
subtype_class = Artifact._class_register.get(subtype)
54-
module_to_class_names[subtype_class.__module__].append(subtype_class.__name__)
53+
(module, class_name) = Artifact.get_module_class(subtype)
54+
module_to_class_names[module].append(class_name)
5555

5656
imports_txt = ""
5757
for modu in sorted(module_to_class_names.keys()):
@@ -148,19 +148,18 @@ def recursive_search(d):
148148

149149
@lru_cache(maxsize=None)
150150
def artifact_type_to_link(artifact_type):
151-
artifact_class = Artifact._class_register.get(artifact_type)
152-
type_class_name = artifact_class.__name__
153-
artifact_class_id = f"{artifact_class.__module__}.{type_class_name}"
154-
return f'<a class="reference internal" href="../{artifact_class.__module__}.html#{artifact_class_id}" title="{artifact_class_id}"><code class="xref py py-class docutils literal notranslate"><span class="pre">{type_class_name}</span></code></a>'
151+
artifact_module, artifact_class_name = Artifact.get_module_class(artifact_type)
152+
return f'<a class="reference internal" href="../{artifact_module}.html#{artifact_module}.{artifact_class_name}" title="{artifact_module}.{artifact_class_name}"><code class="xref py py-class docutils literal notranslate"><span class="pre">{artifact_class_name}</span></code></a>'
155153

156154

157155
# flake8: noqa: C901
156+
157+
158158
def make_content(artifact, label, all_labels):
159-
artifact_type = artifact["__type__"]
160-
artifact_class = Artifact._class_register.get(artifact_type)
161-
type_class_name = artifact_class.__name__
162-
catalog_id = label.replace("catalog.", "")
159+
artifact_type = artifact["__type__"] #qualified class name
160+
artifact_class = Artifact.get_class_from_artifact_type(artifact_type)
163161

162+
catalog_id = label.replace("catalog.", "")
164163
result = ""
165164

166165
if "__description__" in artifact and artifact["__description__"] is not None:
@@ -235,7 +234,7 @@ def make_content(artifact, label, all_labels):
235234
result += " " + html_for_element + "\n"
236235

237236
if artifact_class.__doc__:
238-
explanation_str = f"Explanation about `{type_class_name}`"
237+
explanation_str = f"Explanation about `{artifact_class.__name__}`"
239238
result += f"\n{explanation_str}\n"
240239
result += "+" * len(explanation_str) + "\n\n"
241240
result += artifact_class.__doc__ + "\n"

src/unitxt/artifact.py

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import re
77
import warnings
88
from abc import abstractmethod
9+
from importlib import import_module
910
from typing import Any, Dict, List, Optional, Tuple, Union, final
1011

1112
from .dataclass import (
@@ -22,7 +23,7 @@
2223
separate_inside_and_outside_square_brackets,
2324
)
2425
from .settings_utils import get_constants, get_settings
25-
from .text_utils import camel_to_snake_case, is_camel_case
26+
from .text_utils import is_camel_case
2627
from .type_utils import isoftype, issubtype
2728
from .utils import (
2829
artifacts_json_cache,
@@ -143,7 +144,7 @@ def get_closest_artifact_type(type):
143144

144145
class UnrecognizedArtifactTypeError(ValueError):
145146
def __init__(self, type) -> None:
146-
maybe_class = "".join(word.capitalize() for word in type.split("_"))
147+
maybe_class = type.split(".")[-1]
147148
message = f"'{type}' is not a recognized artifact 'type'. Make sure a the class defined this type (Probably called '{maybe_class}' or similar) is defined and/or imported anywhere in the code executed."
148149
closest_artifact_type = get_closest_artifact_type(type)
149150
if closest_artifact_type is not None:
@@ -200,12 +201,21 @@ def verify_artifact_dict(cls, d):
200201
)
201202
if "__type__" not in d:
202203
raise MissingArtifactTypeError(d)
203-
if not cls.is_registered_type(d["__type__"]):
204-
raise UnrecognizedArtifactTypeError(d["__type__"])
204+
# if not cls.is_registered_type(d["__type__"]):
205+
# raise UnrecognizedArtifactTypeError(d["__type__"])
205206

206207
@classmethod
207208
def get_artifact_type(cls):
208-
return camel_to_snake_case(cls.__name__)
209+
if hasattr(cls, "__qualname__") and "." in cls.__qualname__:
210+
return cls.__module__+"/"+cls.__qualname__
211+
return cls.__module__+"."+cls.__name__
212+
213+
@classmethod
214+
def get_module_class(cls, artifact_type:str):
215+
if "/" in artifact_type:
216+
return artifact_type.split("/")
217+
return artifact_type.rsplit(".", 1)
218+
209219

210220
@classmethod
211221
def register_class(cls, artifact_class):
@@ -216,18 +226,16 @@ def register_class(cls, artifact_class):
216226
artifact_class.__name__
217227
), f"Artifact class name must be legal camel case, got '{artifact_class.__name__}'"
218228

219-
snake_case_key = camel_to_snake_case(artifact_class.__name__)
220-
221-
if cls.is_registered_type(snake_case_key):
229+
if cls.is_registered_type(cls.get_artifact_type()):
222230
assert (
223-
str(cls._class_register[snake_case_key]) == str(artifact_class)
224-
), f"Artifact class name must be unique, '{snake_case_key}' already exists for {cls._class_register[snake_case_key]}. Cannot be overridden by {artifact_class}."
231+
str(cls._class_register[cls.get_artifact_type()]) == cls.get_artifact_type()
232+
), f"Artifact class name must be unique, '{cls.get_artifact_type()}' is already registered as {cls._class_register[cls.get_artifact_type()]}. Cannot be overridden by {artifact_class}."
225233

226-
return snake_case_key
234+
return cls.get_artifact_type()
227235

228-
cls._class_register[snake_case_key] = artifact_class
236+
cls._class_register[cls.get_artifact_type()] = cls.get_artifact_type() # for now, still maintain the registry from qualified to qualified
229237

230-
return snake_case_key
238+
return cls.get_artifact_type()
231239

232240
def __init_subclass__(cls, **kwargs):
233241
super().__init_subclass__(**kwargs)
@@ -247,12 +255,25 @@ def is_registered_type(cls, type: str):
247255

248256
@classmethod
249257
def is_registered_class_name(cls, class_name: str):
250-
snake_case_key = camel_to_snake_case(class_name)
251-
return cls.is_registered_type(snake_case_key)
258+
for k in cls._class_register:
259+
_, artifact_class_name = cls.get_module_class(k)
260+
if artifact_class_name == class_name:
261+
return True
262+
return False
252263

253264
@classmethod
254-
def is_registered_class(cls, clz: object):
255-
return clz in set(cls._class_register.values())
265+
def get_class_from_artifact_type(cls, type:str):
266+
module_path, class_name = cls.get_module_class(type)
267+
module = import_module(module_path)
268+
if "." not in class_name:
269+
return getattr(module, class_name)
270+
class_name_components = class_name.split(".")
271+
klass = getattr(module, class_name_components[0])
272+
for i in range (1, len(class_name_components)):
273+
klass = getattr(klass, class_name_components[i])
274+
return klass
275+
276+
256277

257278
@classmethod
258279
def _recursive_load(cls, obj):
@@ -267,9 +288,12 @@ def _recursive_load(cls, obj):
267288
pass
268289
if cls.is_artifact_dict(obj):
269290
cls.verify_artifact_dict(obj)
270-
artifact_class = cls._class_register[obj.pop("__type__")]
271-
obj = artifact_class.process_data_after_load(obj)
272-
return artifact_class(**obj)
291+
try:
292+
artifact_class = cls.get_class_from_artifact_type(obj.pop("__type__"))
293+
obj = artifact_class.process_data_after_load(obj)
294+
return artifact_class(**obj)
295+
except (ImportError, AttributeError) as e:
296+
raise ImportError(obj) from e
273297

274298
return obj
275299

@@ -283,7 +307,7 @@ def from_dict(cls, d, overwrite_args=None):
283307
@classmethod
284308
def load(cls, path, artifact_identifier=None, overwrite_args=None):
285309
d = artifacts_json_cache(path)
286-
if "__type__" in d and d["__type__"] == "artifact_link":
310+
if "__type__" in d and d["__type__"].endswith("ArtifactLink"):
287311
cls.from_dict(d) # for verifications and warnings
288312
catalog, artifact_rep, _ = get_catalog_name_and_args(name=d["to"])
289313
return catalog.get_with_overwrite(
@@ -347,7 +371,7 @@ def __post_init__(self):
347371

348372
def _to_raw_dict(self):
349373
return {
350-
"__type__": self.__type__,
374+
"__type__": self.__class__.get_artifact_type(),
351375
**self.process_data_before_dump(self._init_dict),
352376
}
353377

src/unitxt/catalog.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ def add_to_catalog(
128128
catalog_path: Optional[str] = None,
129129
verbose=True,
130130
):
131+
# print(artifact.__class__.__name__)
132+
# print(artifact.__module__)
131133
reset_artifacts_json_cache()
132134
if catalog is None:
133135
if catalog_path is None:
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"__type__": "artifact_link",
2+
"__type__": "unitxt.artifact.ArtifactLink",
33
"to": "augmentors.text.whitespace_prefix_suffix",
44
"__deprecated_msg__": "Artifact 'augmentors.augment_whitespace_prefix_and_suffix_task_input' is deprecated. Artifact 'augmentors.text.whitespace_prefix_suffix' will be instantiated instead. In future uses, please reference artifact 'augmentors.text.whitespace_prefix_suffix' directly."
55
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"__type__": "artifact_link",
2+
"__type__": "unitxt.artifact.ArtifactLink",
33
"to": "augmentors.text.whitespace_prefix_suffix",
44
"__deprecated_msg__": "Artifact 'augmentors.augment_whitespace_task_input' is deprecated. Artifact 'augmentors.text.whitespace_prefix_suffix' will be instantiated instead. In future uses, please reference artifact 'augmentors.text.whitespace_prefix_suffix' directly."
55
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
{
2-
"__type__": "gray_scale"
2+
"__type__": "unitxt.image_operators.GrayScale"
33
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
{
2-
"__type__": "grid_lines"
2+
"__type__": "unitxt.image_operators.GridLines"
33
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
{
2-
"__type__": "oldify"
2+
"__type__": "unitxt.image_operators.Oldify"
33
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
{
2-
"__type__": "pixel_noise"
2+
"__type__": "unitxt.image_operators.PixelNoise"
33
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
{
2-
"__type__": "null_augmentor"
2+
"__type__": "unitxt.augmentors.NullAugmentor"
33
}

0 commit comments

Comments
 (0)