Skip to content

Commit 137df10

Browse files
committed
remove _class_register altogether
Signed-off-by: dafnapension <[email protected]>
1 parent 1eb4241 commit 137df10

File tree

4 files changed: 14 additions and 89 deletions.

docs/conf.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ def autodoc_skip_member(app, what, name, obj, would_skip, options):
116116
class_name = obj.__qualname__.split(".")[0]
117117
if (
118118
class_name
119-
and Artifact.is_registered_class_name(class_name)
120119
and class_name != name
121120
):
122121
return True

src/unitxt/artifact.py

Lines changed: 12 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import difflib
21
import inspect
32
import json
43
import os
@@ -23,7 +22,6 @@
2322
separate_inside_and_outside_square_brackets,
2423
)
2524
from .settings_utils import get_constants, get_settings
26-
from .text_utils import is_camel_case
2725
from .type_utils import isoftype, issubtype
2826
from .utils import (
2927
artifacts_json_cache,
@@ -134,21 +132,11 @@ def maybe_recover_artifacts_structure(obj):
134132
return obj
135133

136134

137-
def get_closest_artifact_type(type):
138-
artifact_type_options = list(Artifact._class_register.keys())
139-
matches = difflib.get_close_matches(type, artifact_type_options)
140-
if matches:
141-
return matches[0] # Return the closest match
142-
return None
143-
144135

145136
class UnrecognizedArtifactTypeError(ValueError):
146137
def __init__(self, type) -> None:
147138
maybe_class = type.split(".")[-1]
148139
message = f"'{type}' is not a recognized artifact 'type'. Make sure a the class defined this type (Probably called '{maybe_class}' or similar) is defined and/or imported anywhere in the code executed."
149-
closest_artifact_type = get_closest_artifact_type(type)
150-
if closest_artifact_type is not None:
151-
message += f"\n\nDid you mean '{closest_artifact_type}'?"
152140
super().__init__(message)
153141

154142

@@ -161,7 +149,7 @@ def __init__(self, dic) -> None:
161149

162150

163151
class Artifact(Dataclass):
164-
_class_register = {}
152+
# _class_register = {}
165153

166154
__type__: str = Field(default=None, final=True, init=False)
167155
__title__: str = NonPositionalField(
@@ -252,29 +240,9 @@ def get_module_class(cls, artifact_type:str):
252240
return artifact_type.rsplit(".", 1)
253241

254242

255-
@classmethod
256-
def register_class(cls, artifact_class):
257-
assert issubclass(
258-
artifact_class, Artifact
259-
), f"Artifact class must be a subclass of Artifact, got '{artifact_class}'"
260-
assert is_camel_case(
261-
artifact_class.__name__
262-
), f"Artifact class name must be legal camel case, got '{artifact_class.__name__}'"
263-
264-
if cls.is_registered_type(cls.get_artifact_type()):
265-
assert (
266-
str(cls._class_register[cls.get_artifact_type()]) == cls.get_artifact_type()
267-
), f"Artifact class name must be unique, '{cls.get_artifact_type()}' is already registered as {cls._class_register[cls.get_artifact_type()]}. Cannot be overridden by {artifact_class}."
268-
269-
return cls.get_artifact_type()
270-
271-
cls._class_register[cls.get_artifact_type()] = cls.get_artifact_type() # for now, still maintain the registry from qualified to qualified
272-
273-
return cls.get_artifact_type()
274243

275244
def __init_subclass__(cls, **kwargs):
276245
super().__init_subclass__(**kwargs)
277-
cls.register_class(cls)
278246

279247
@classmethod
280248
def is_artifact_file(cls, path):
@@ -284,18 +252,6 @@ def is_artifact_file(cls, path):
284252
d = json.load(f)
285253
return cls.is_artifact_dict(d)
286254

287-
@classmethod
288-
def is_registered_type(cls, type: str):
289-
return type in cls._class_register
290-
291-
@classmethod
292-
def is_registered_class_name(cls, class_name: str):
293-
for k in cls._class_register:
294-
_, artifact_class_name = cls.get_module_class(k)
295-
if artifact_class_name == class_name:
296-
return True
297-
return False
298-
299255
@classmethod
300256
def get_class_from_artifact_type(cls, type:str):
301257
module_path, class_name = cls.get_module_class(type)
@@ -313,23 +269,17 @@ def get_class_from_artifact_type(cls, type:str):
313269
@classmethod
314270
def _recursive_load(cls, obj):
315271
if isinstance(obj, dict):
316-
new_d = {}
317-
for key, value in obj.items():
318-
new_d[key] = cls._recursive_load(value)
319-
obj = new_d
272+
obj = {key: cls._recursive_load(value) for key, value in obj.items()}
273+
if cls.is_artifact_dict(obj):
274+
try:
275+
artifact_type = obj.pop("__type__")
276+
artifact_class = cls.get_class_from_artifact_type(artifact_type)
277+
obj = artifact_class.process_data_after_load(obj)
278+
return artifact_class(**obj)
279+
except (ImportError, AttributeError) as e:
280+
raise UnrecognizedArtifactTypeError(artifact_type) from e
320281
elif isinstance(obj, list):
321-
obj = [cls._recursive_load(value) for value in obj]
322-
else:
323-
pass
324-
if cls.is_artifact_dict(obj):
325-
cls.verify_artifact_dict(obj)
326-
try:
327-
artifact_type = obj.pop("__type__")
328-
artifact_class = cls.get_class_from_artifact_type(artifact_type)
329-
obj = artifact_class.process_data_after_load(obj)
330-
return artifact_class(**obj)
331-
except (ImportError, AttributeError) as e:
332-
raise UnrecognizedArtifactTypeError(artifact_type) from e
282+
return [cls._recursive_load(value) for value in obj]
333283

334284
return obj
335285

@@ -389,7 +339,7 @@ def verify_data_classification_policy(self):
389339

390340
@final
391341
def __post_init__(self):
392-
self.__type__ = self.register_class(self.__class__)
342+
self.__type__ = self.__class__.get_artifact_type()
393343

394344
for field in fields(self):
395345
if issubtype(
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
{
2-
"__type__": "hf_system_format",
2+
"__type__": "unitxt.formats.HFSystemFormat",
33
"model_name": "ibm-granite/granite-3.1-2b-instruct"
44
}

src/unitxt/register.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
import importlib
2-
import inspect
31
import os
42
from pathlib import Path
53

6-
from .artifact import Artifact, Catalogs
4+
from .artifact import Catalogs
75
from .catalog import EnvironmentLocalCatalog, GithubCatalog, LocalCatalog
86
from .error_utils import Documentation, UnitxtError, UnitxtWarning
97
from .settings_utils import get_constants, get_settings
@@ -89,27 +87,6 @@ def _reset_env_local_catalogs():
8987
_register_catalog(EnvironmentLocalCatalog(location=path))
9088

9189

92-
def _register_all_artifacts():
93-
dir = os.path.dirname(__file__)
94-
file_name = os.path.basename(__file__)
95-
96-
for file in os.listdir(dir):
97-
if (
98-
file.endswith(".py")
99-
and file not in constants.non_registered_files
100-
and file != file_name
101-
):
102-
module_name = file.replace(".py", "")
103-
104-
module = importlib.import_module("." + module_name, __package__)
105-
106-
for _name, obj in inspect.getmembers(module):
107-
# Make sure the object is a class
108-
if inspect.isclass(obj):
109-
# Make sure the class is a subclass of Artifact (but not Artifact itself)
110-
if issubclass(obj, Artifact) and obj is not Artifact:
111-
Artifact.register_class(obj)
112-
11390

11491
class ProjectArtifactRegisterer(metaclass=Singleton):
11592
def __init__(self):
@@ -118,7 +95,6 @@ def __init__(self):
11895

11996
if not self._registered:
12097
_register_all_catalogs()
121-
_register_all_artifacts()
12298
self._registered = True
12399

124100

0 commit comments

Comments
 (0)