@@ -29,6 +29,17 @@ def _get_temp_filename():
29
29
def _get_tmp_file_location ():
30
30
return _util ._make_temp_directory (prefix = 'gl_pickle_' )
31
31
32
+
33
+ def _get_class_from_name (module_name , class_name ):
34
+ import importlib
35
+
36
+ # load the module, will raise ImportError if module cannot be loaded
37
+ m = importlib .import_module (module_name )
38
+
39
+ # get the class, will raise AttributeError if class cannot be found
40
+ c = getattr (m , class_name )
41
+ return c
42
+
32
43
def _is_gl_pickle_extensible (obj ):
33
44
"""
34
45
Check if an object has an external serialization prototol. We do so by
@@ -188,14 +199,13 @@ def __init__(self, filename, protocol = -1, min_bytes_to_save = 0):
188
199
"""
189
200
# Zipfile
190
201
# --------
191
- # Version 1 : GLC 1.2.1
202
+ # Version None : GLC 1.2.1
192
203
#
193
204
# Directory:
194
205
# ----------
195
206
# Version 1: GLC 1.4: 1
196
- # Version 2: SFrame 1.8.2+ (new gl_pickle extensibility mechanism)
197
207
198
- VERSION = "2 .0"
208
+ VERSION = "1 .0"
199
209
self .archive_filename = None
200
210
self .gl_temp_storage_path = _get_tmp_file_location ()
201
211
self .gl_object_memo = set ()
@@ -305,11 +315,11 @@ def __gl_pickle_save__(self, filename):
305
315
with open(filename, 'w') as f:
306
316
f.write(self.member)
307
317
308
- @staticmethod
309
- def __gl_pickle_load__(filename):
318
+ @classmethod
319
+ def __gl_pickle_load__(cls, filename):
310
320
with open(filename, 'r') as f:
311
321
member = f.read().split()
312
- return SampleClass (member)
322
+ return cls (member)
313
323
314
324
WARNING: Version 1.0 and before of GLPickle only supported the
315
325
following extended objects.
@@ -500,9 +510,9 @@ def __init__(self, filename):
500
510
self .version = open (version_filename ).read ().strip ()
501
511
except :
502
512
raise IOError ("Corrupted archive: Corrupted version file." )
503
- if self .version not in ["1.0" , "2.0" ]:
513
+ if self .version not in ["1.0" ]:
504
514
raise Exception (
505
- "Corrupted archive: Version string must be in [ 1.0, 2.0] " )
515
+ "Corrupted archive: Version string must be 1.0. " )
506
516
self .pickle_filename = pickle_filename
507
517
self .gl_temp_storage_path = _os .path .abspath (filename )
508
518
@@ -543,9 +553,14 @@ def persistent_load(self, pid):
543
553
else :
544
554
abs_path = _os .path .join (self .gl_temp_storage_path , filename )
545
555
if self .version in ["1.0" , None ]:
546
- obj = _get_gl_object_from_persistent_id (type_tag , abs_path )
547
- elif self .version == "2.0" :
548
- obj = type_tag (abs_path )
556
+ if type_tag in ["SFrame" , "SGraph" , "SArray" , "Model" ]:
557
+ obj = _get_gl_object_from_persistent_id (type_tag ,
558
+ abs_path )
559
+ else :
560
+ module_name , class_name = type_tag
561
+ type_class = _get_class_from_name (module_name ,
562
+ class_name )
563
+ obj = type_class .__gl_pickle_load__ (abs_path )
549
564
else :
550
565
raise Exception (
551
566
"Unknown version %s: Expected version in [1.0, 2.0]" \
0 commit comments