Skip to content

Commit 154bb78

Browse files
author
srikris
committed
Modified the code to make sure imports work.
1 parent a314262 commit 154bb78

File tree

1 file changed

+26
-11
lines changed

1 file changed

+26
-11
lines changed

oss_src/unity/python/sframe/_gl_pickle.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,17 @@ def _get_temp_filename():
2929
def _get_tmp_file_location():
3030
return _util._make_temp_directory(prefix='gl_pickle_')
3131

32+
33+
def _get_class_from_name(module_name, class_name):
34+
import importlib
35+
36+
# load the module, will raise ImportError if module cannot be loaded
37+
m = importlib.import_module(module_name)
38+
39+
# get the class, will raise AttributeError if class cannot be found
40+
c = getattr(m, class_name)
41+
return c
42+
3243
def _is_gl_pickle_extensible(obj):
3344
"""
3445
Check if an object has an external serialization prototol. We do so by
@@ -188,14 +199,13 @@ def __init__(self, filename, protocol = -1, min_bytes_to_save = 0):
188199
"""
189200
# Zipfile
190201
# --------
191-
# Version 1: GLC 1.2.1
202+
# Version None: GLC 1.2.1
192203
#
193204
# Directory:
194205
# ----------
195206
# Version 1: GLC 1.4: 1
196-
# Version 2: SFrame 1.8.2+ (new gl_pickle extensibility mechanism)
197207

198-
VERSION = "2.0"
208+
VERSION = "1.0"
199209
self.archive_filename = None
200210
self.gl_temp_storage_path = _get_tmp_file_location()
201211
self.gl_object_memo = set()
@@ -305,11 +315,11 @@ def __gl_pickle_save__(self, filename):
305315
with open(filename, 'w') as f:
306316
f.write(self.member)
307317
308-
@staticmethod
309-
def __gl_pickle_load__(filename):
318+
@classmethod
319+
def __gl_pickle_load__(cls, filename):
310320
with open(filename, 'r') as f:
311321
member = f.read().split()
312-
return SampleClass(member)
322+
return cls(member)
313323
314324
WARNING: Version 1.0 and before of GLPickle only supported the
315325
following extended objects.
@@ -500,9 +510,9 @@ def __init__(self, filename):
500510
self.version = open(version_filename).read().strip()
501511
except:
502512
raise IOError("Corrupted archive: Corrupted version file.")
503-
if self.version not in ["1.0", "2.0"]:
513+
if self.version not in ["1.0"]:
504514
raise Exception(
505-
"Corrupted archive: Version string must be in [1.0, 2.0]")
515+
"Corrupted archive: Version string must be 1.0.")
506516
self.pickle_filename = pickle_filename
507517
self.gl_temp_storage_path = _os.path.abspath(filename)
508518

@@ -543,9 +553,14 @@ def persistent_load(self, pid):
543553
else:
544554
abs_path = _os.path.join(self.gl_temp_storage_path, filename)
545555
if self.version in ["1.0", None]:
546-
obj = _get_gl_object_from_persistent_id(type_tag, abs_path)
547-
elif self.version == "2.0":
548-
obj = type_tag(abs_path)
556+
if type_tag in ["SFrame", "SGraph", "SArray", "Model"]:
557+
obj = _get_gl_object_from_persistent_id(type_tag,
558+
abs_path)
559+
else:
560+
module_name, class_name = type_tag
561+
type_class = _get_class_from_name(module_name,
562+
class_name)
563+
obj = type_class.__gl_pickle_load__(abs_path)
549564
else:
550565
raise Exception(
551566
"Unknown version %s: Expected version in [1.0, 2.0]" \

0 commit comments

Comments
 (0)