|
4 | 4 | import os |
5 | 5 | import pkgutil |
6 | 6 | import re |
| 7 | +import types |
7 | 8 | import warnings |
8 | 9 | from abc import abstractmethod |
9 | 10 | from typing import Any, Dict, List, Optional, Tuple, Union, final |
|
27 | 28 | from .utils import ( |
28 | 29 | artifacts_json_cache, |
29 | 30 | json_dump, |
| 31 | + json_load, |
| 32 | + load_json, |
30 | 33 | save_to_file, |
31 | 34 | shallow_copy, |
32 | 35 | ) |
@@ -119,16 +122,32 @@ def reset(self): |
119 | 122 | self.catalogs = [] |
120 | 123 |
|
121 | 124 |
|
def maybe_recover_function_operator(func):
    """Wrap ``func`` in a ``FunctionOperator`` when its signature matches.

    A function qualifies when its parameter names are exactly
    ``{"stream", "stream_name"}`` or ``{"instance", "stream_name"}``
    (in any order). Any other function is returned unchanged.

    Args:
        func: A callable whose signature ``inspect.signature`` can read.

    Returns:
        ``FunctionOperator(function=func)`` for a matching signature,
        otherwise ``func`` itself.
    """
    # Parameter names are unique, so a set comparison is equivalent to
    # comparing the sorted tuple of names.
    param_names = set(inspect.signature(func).parameters)
    operator_signatures = (
        {"stream", "stream_name"},
        {"instance", "stream_name"},
    )
    if param_names in operator_signatures:
        # Imported at call time rather than module level
        # (presumably to avoid a circular import -- TODO confirm).
        from .operators import FunctionOperator

        return FunctionOperator(function=func)
    return func
| 136 | + |
| 137 | + |
def maybe_recover_artifacts_structure(obj):
    """Recursively resolve artifact references inside ``obj``.

    Plain functions are first passed through
    ``maybe_recover_function_operator``. A value that
    ``Artifact.is_possible_identifier`` accepts is replaced by the result
    of ``verbosed_fetch_artifact``. Dicts and lists are traversed
    recursively and updated in place; anything else is returned as is.

    Args:
        obj: Any value, possibly a nested dict/list structure.

    Returns:
        The fetched artifact, or ``obj`` (the same container object for
        dicts and lists, with its values replaced in place).
    """
    if isinstance(obj, types.FunctionType):
        obj = maybe_recover_function_operator(obj)

    if Artifact.is_possible_identifier(obj):
        return verbosed_fetch_artifact(obj)
    if isinstance(obj, dict):
        # Only existing keys are reassigned, so the dict's size is stable
        # during iteration.
        for key, value in obj.items():
            obj[key] = maybe_recover_artifacts_structure(value)
        return obj
    if isinstance(obj, list):
        for index, item in enumerate(obj):
            obj[index] = maybe_recover_artifacts_structure(item)
        return obj
    return obj
134 | 153 |
|
@@ -237,8 +256,7 @@ def __init_subclass__(cls, **kwargs): |
def is_artifact_file(cls, path):
    """Return whether ``path`` is a regular file holding an artifact dict.

    Args:
        path: Filesystem path to inspect.

    Returns:
        ``False`` when ``path`` is missing or is not a regular file;
        otherwise the result of ``cls.is_artifact_dict`` applied to the
        content loaded from the file with ``load_json``.
    """
    # os.path.isfile is already False for paths that do not exist, so a
    # separate existence check is unnecessary.
    if not os.path.isfile(path):
        return False
    d = load_json(path)
    return cls.is_artifact_dict(d)
243 | 261 |
|
244 | 262 | @classmethod |
@@ -384,14 +402,15 @@ def serialize(self): |
384 | 402 | return self.to_json() |
385 | 403 |
|
def save(self, path):
    """Write this artifact to ``path`` as JSON, refusing if it was modified.

    The artifact is converted to a dict once. A fresh artifact is rebuilt
    from that dict and its ``get_repr_dict()`` is compared against this
    instance's; any reported difference aborts the save.

    Args:
        path: Destination passed to ``save_to_file``.

    Raises:
        UnitxtError: If ``dict_diff_string`` reports differences between
            the rebuilt artifact's arguments and the current ones.
    """
    # The same dict feeds both the consistency check and the output, so
    # what is validated is exactly what gets written.
    data = self.to_dict()
    original_args = Artifact.from_dict(data).get_repr_dict()
    current_args = self.get_repr_dict()
    diffs = dict_diff_string(original_args, current_args)
    if diffs:
        raise UnitxtError(
            f"Cannot save catalog artifacts that have changed since initialization. Detected differences in the following fields:\n{diffs}"
        )
    save_to_file(path, json_dump(data))
395 | 414 |
|
396 | 415 | def verify_instance( |
397 | 416 | self, instance: Dict[str, Any], name: Optional[str] = None |
@@ -581,7 +600,7 @@ def fetch_artifact( |
581 | 600 |
|
582 | 601 | # If Json string, first load into dictionary |
583 | 602 | if isinstance(artifact_rep, str): |
584 | | - artifact_rep = json.loads(artifact_rep) |
| 603 | + artifact_rep = json_load(artifact_rep) |
585 | 604 | # Load from dictionary (fails if not valid dictionary) |
586 | 605 | return Artifact.from_dict(artifact_rep), None |
587 | 606 |
|
@@ -657,7 +676,7 @@ def get_artifacts_data_classification(artifact: str) -> Optional[List[str]]: |
657 | 676 | ) |
658 | 677 |
|
659 | 678 | try: |
660 | | - data_classification = json.loads(data_classification) |
| 679 | + data_classification = json_load(data_classification) |
661 | 680 | except json.decoder.JSONDecodeError as e: |
662 | 681 | raise RuntimeError(error_msg) from e |
663 | 682 |
|
|
0 commit comments