Skip to content

Commit 6b55755

Browse files
committed
session: deal with modules with unpickleable objects
1 parent 2fdd31d commit 6b55755

File tree

5 files changed

+125
-29
lines changed

5 files changed

+125
-29
lines changed

dill/_dill.py

Lines changed: 87 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
#XXX: get types from .objtypes ?
4141
import builtins as __builtin__
4242
from pickle import _Pickler as StockPickler, Unpickler as StockUnpickler
43+
from pickle import BINPUT, DICT, EMPTY_DICT, LONG_BINPUT, MARK, PUT, SETITEM
44+
from struct import pack
4345
from _thread import LockType
4446
from _thread import RLock as RLockType
4547
#from io import IOBase
@@ -234,6 +236,9 @@ def __reduce_ex__(self, protocol):
234236
#: Pickles the entire file (handle and contents), preserving mode and position.
235237
FILE_FMODE = 2
236238

239+
# Exceptions commonly raised by unpickleable objects.
240+
UNPICKLEABLE_ERRORS = (PicklingError, TypeError, NotImplementedError)
241+
237242
### Shorthands (modified from python2.5/lib/pickle.py)
238243
def copy(obj, *args, **kwds):
239244
"""
@@ -349,16 +354,18 @@ class Pickler(StockPickler):
349354
def __init__(self, file, *args, **kwds):
350355
settings = Pickler.settings
351356
_byref = kwds.pop('byref', None)
352-
#_strictio = kwds.pop('strictio', None)
353357
_fmode = kwds.pop('fmode', None)
354358
_recurse = kwds.pop('recurse', None)
359+
#_refonfail = kwds.pop('refonfail', None)
360+
#_strictio = kwds.pop('strictio', None)
355361
StockPickler.__init__(self, file, *args, **kwds)
356362
self._main = _main_module
357363
self._diff_cache = {}
358364
self._byref = settings['byref'] if _byref is None else _byref
359-
self._strictio = False #_strictio
360365
self._fmode = settings['fmode'] if _fmode is None else _fmode
361366
self._recurse = settings['recurse'] if _recurse is None else _recurse
367+
self._refonfail = False #settings['dump_module']['refonfail'] if _refonfail is None else _refonfail
368+
self._strictio = False #_strictio
362369
self._postproc = OrderedDict()
363370
self._file = file # for the logger
364371

@@ -395,7 +402,7 @@ def save_numpy_dtype(pickler, obj):
395402
if NumpyArrayType and ndarraysubclassinstance(obj):
396403
@register(type(obj))
397404
def save_numpy_array(pickler, obj):
398-
logger.trace(pickler, "Nu: (%s, %s)", obj.shape, obj.dtype)
405+
logger.trace(pickler, "Nu: (%s, %s)", obj.shape, obj.dtype, obj=obj)
399406
npdict = getattr(obj, '__dict__', None)
400407
f, args, state = obj.__reduce__()
401408
pickler.save_reduce(_create_array, (f,args,state,npdict), obj=obj)
@@ -407,9 +414,68 @@ def save_numpy_array(pickler, obj):
407414
raise PicklingError(msg)
408415
logger.trace_setup(self)
409416
StockPickler.dump(self, obj)
410-
411417
dump.__doc__ = StockPickler.dump.__doc__
412418

419+
def save(self, obj, save_persistent_id=True, *, name=None):
420+
"""If self._refonfail is True, try to save object by reference if pickling fails."""
421+
if not self._refonfail:
422+
super().save(obj, save_persistent_id)
423+
return
424+
if self.framer.current_frame:
425+
# protocol >= 4
426+
self.framer.commit_frame()
427+
stream = self.framer.current_frame
428+
else:
429+
stream = self._file
430+
position = stream.tell()
431+
memo_size = len(self.memo)
432+
try:
433+
super().save(obj, save_persistent_id)
434+
except UNPICKLEABLE_ERRORS + (AttributeError,) as error_stack:
435+
# AttributeError may happen in save_global() call for child object.
436+
if (type(error_stack) == AttributeError
437+
and "no attribute '__name__'" not in error_stack.args[0]):
438+
raise
439+
# roll back the stream
440+
stream.seek(position)
441+
stream.truncate()
442+
# roll back memo
443+
for _ in range(len(self.memo) - memo_size):
444+
self.memo.popitem() # LIFO order is guaranteed since Python 3.7
445+
try:
446+
self.save_global(obj, name)
447+
except (AttributeError, PicklingError) as error:
448+
if getattr(self, '_trace_stack', None) and id(obj) == self._trace_stack[-1]:
449+
# roll back trace state
450+
self._trace_stack.pop()
451+
self._size_stack.pop()
452+
raise error from error_stack
453+
logger.trace(self, "# X: fallback to save_global: <%s object at %#012x>",
454+
type(obj).__name__, id(obj), obj=obj)
455+
456+
def _save_module_dict(self, obj):
457+
"""
458+
Use object name in the module namespace as a last resort to try to
459+
save it by reference when pickling fails.
460+
461+
Modified from Pickler.save_dict() and Pickler._batch_setitems().
462+
"""
463+
if not self._refonfail:
464+
super().save_dict(obj)
465+
return
466+
if self.bin:
467+
self.write(EMPTY_DICT)
468+
else: # proto 0 -- can't use EMPTY_DICT
469+
self.write(MARK + DICT)
470+
self.memoize(obj)
471+
for k, v in obj.items():
472+
self.save(k)
473+
if hasattr(v, '__name__') or hasattr(v, '__qualname__'):
474+
self.save(v)
475+
else:
476+
self.save(v, name=k)
477+
self.write(SETITEM)
478+
413479
class Unpickler(StockUnpickler):
414480
"""python's Unpickler extended to interpreter sessions and more types"""
415481
from .settings import settings
@@ -1173,26 +1239,30 @@ def _repr_dict(obj):
11731239

11741240
@register(dict)
11751241
def save_module_dict(pickler, obj):
1176-
if is_dill(pickler, child=False) and obj == pickler._main.__dict__ and \
1242+
pickler_is_dill = is_dill(pickler, child=False)
1243+
if pickler_is_dill and obj == pickler._main.__dict__ and \
11771244
not (pickler._session and pickler._first_pass):
1178-
logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
1245+
logger.trace(pickler, "D1: %s", _repr_dict(obj), obj=obj)
11791246
pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
11801247
logger.trace(pickler, "# D1")
1181-
elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
1182-
logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
1248+
elif (not pickler_is_dill) and (obj == _main_module.__dict__):
1249+
logger.trace(pickler, "D3: %s", _repr_dict(obj), obj=obj)
11831250
pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
11841251
logger.trace(pickler, "# D3")
11851252
elif '__name__' in obj and obj != _main_module.__dict__ \
11861253
and type(obj['__name__']) is str \
11871254
and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None):
1188-
logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
1255+
logger.trace(pickler, "D4: %s", _repr_dict(obj), obj=obj)
11891256
pickler.write(bytes('c%s\n__dict__\n' % obj['__name__'], 'UTF-8'))
11901257
logger.trace(pickler, "# D4")
1258+
elif pickler_is_dill and pickler._session and pickler._first_pass:
1259+
# we only care about session the first pass thru
1260+
pickler._first_pass = False
1261+
logger.trace(pickler, "D5: %s", _repr_dict(obj), obj=obj)
1262+
pickler._save_module_dict(obj)
1263+
logger.trace(pickler, "# D5")
11911264
else:
1192-
logger.trace(pickler, "D2: %s", _repr_dict(obj)) # obj
1193-
if is_dill(pickler, child=False) and pickler._session:
1194-
# we only care about session the first pass thru
1195-
pickler._first_pass = False
1265+
logger.trace(pickler, "D2: %s", _repr_dict(obj), obj=obj)
11961266
StockPickler.save_dict(pickler, obj)
11971267
logger.trace(pickler, "# D2")
11981268
return
@@ -1491,15 +1561,15 @@ def save_cell(pickler, obj):
14911561
if MAPPING_PROXY_TRICK:
14921562
@register(DictProxyType)
14931563
def save_dictproxy(pickler, obj):
1494-
logger.trace(pickler, "Mp: %s", _repr_dict(obj)) # obj
1564+
logger.trace(pickler, "Mp: %s", _repr_dict(obj), obj=obj)
14951565
mapping = obj | _dictproxy_helper_instance
14961566
pickler.save_reduce(DictProxyType, (mapping,), obj=obj)
14971567
logger.trace(pickler, "# Mp")
14981568
return
14991569
else:
15001570
@register(DictProxyType)
15011571
def save_dictproxy(pickler, obj):
1502-
logger.trace(pickler, "Mp: %s", _repr_dict(obj)) # obj
1572+
logger.trace(pickler, "Mp: %s", _repr_dict(obj), obj=obj)
15031573
pickler.save_reduce(DictProxyType, (obj.copy(),), obj=obj)
15041574
logger.trace(pickler, "# Mp")
15051575
return
@@ -1575,7 +1645,7 @@ def save_weakproxy(pickler, obj):
15751645
logger.trace(pickler, "%s: %s", _t, obj)
15761646
except ReferenceError:
15771647
_t = "R3"
1578-
logger.trace(pickler, "%s: %s", _t, sys.exc_info()[1])
1648+
logger.trace(pickler, "%s: %s", _t, sys.exc_info()[1], obj=obj)
15791649
#callable = bool(getattr(refobj, '__call__', None))
15801650
if type(obj) is CallableProxyType: callable = True
15811651
else: callable = False
@@ -1914,7 +1984,7 @@ def pickles(obj,exact=False,safe=False,**kwds):
19141984
"""
19151985
if safe: exceptions = (Exception,) # RuntimeError, ValueError
19161986
else:
1917-
exceptions = (TypeError, AssertionError, NotImplementedError, PicklingError, UnpicklingError)
1987+
exceptions = UNPICKLEABLE_ERRORS + (AssertionError, UnpicklingError)
19181988
try:
19191989
pik = copy(obj, **kwds)
19201990
#FIXME: should check types match first, then check content if "exact"

dill/logger.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -129,18 +129,22 @@ def trace_setup(self, pickler):
129129
if not dill._dill.is_dill(pickler, child=False):
130130
return
131131
if self.isEnabledFor(logging.INFO):
132-
pickler._trace_depth = 1
132+
pickler._trace_stack = []
133133
pickler._size_stack = []
134134
else:
135-
pickler._trace_depth = None
136-
def trace(self, pickler, msg, *args, **kwargs):
137-
if not hasattr(pickler, '_trace_depth'):
135+
pickler._trace_stack = None
136+
def trace(self, pickler, msg, *args, obj=None, **kwargs):
137+
if not hasattr(pickler, '_trace_stack'):
138138
logger.info(msg, *args, **kwargs)
139139
return
140-
if pickler._trace_depth is None:
140+
if pickler._trace_stack is None:
141141
return
142142
extra = kwargs.get('extra', {})
143143
pushed_obj = msg.startswith('#')
144+
if not pushed_obj:
145+
if obj is None:
146+
obj = args[-1]
147+
pickler._trace_stack.append(id(obj))
144148
size = None
145149
try:
146150
# Streams are not required to be tellable.
@@ -159,13 +163,11 @@ def trace(self, pickler, msg, *args, **kwargs):
159163
else:
160164
size -= pickler._size_stack.pop()
161165
extra['size'] = size
162-
if pushed_obj:
163-
pickler._trace_depth -= 1
164-
extra['depth'] = pickler._trace_depth
166+
extra['depth'] = len(pickler._trace_stack)
165167
kwargs['extra'] = extra
166168
self.info(msg, *args, **kwargs)
167-
if not pushed_obj:
168-
pickler._trace_depth += 1
169+
if pushed_obj:
170+
pickler._trace_stack.pop()
169171

170172
class TraceFormatter(logging.Formatter):
171173
"""

dill/session.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,10 @@ def dump_module(
184184
filename = str(TEMPDIR/'session.pkl'),
185185
module: Union[ModuleType, str] = None,
186186
refimported: bool = False,
187+
refonfail: bool = False,
187188
**kwds
188189
) -> None:
189-
"""Pickle the current state of :py:mod:`__main__` or another module to a file.
190+
R"""Pickle the current state of :py:mod:`__main__` or another module to a file.
190191
191192
Save the contents of :py:mod:`__main__` (e.g. from an interactive
192193
interpreter session), an imported module, or a module-type object (e.g.
@@ -202,6 +203,10 @@ def dump_module(
202203
similar but independent from ``dill.settings[`byref`]``, as
203204
``refimported`` refers to virtually all imported objects, while
204205
``byref`` only affects select objects.
206+
refonfail: if `True`, objects that fail to be saved by value will try to
207+
be saved by reference. If that also fails, saving their parent
208+
objects by reference will be attempted recursively. In the worst
209+
case scenario, the module itself may be saved by reference.
205210
**kwds: extra keyword arguments passed to :py:class:`Pickler()`.
206211
207212
Raises:
@@ -232,6 +237,15 @@ def dump_module(
232237
>>> foo.sin = math.sin
233238
>>> dill.dump_module('foo_session.pkl', module=foo, refimported=True)
234239
240+
- Save the state of a module with unpickleable objects:
241+
242+
>>> import dill
243+
>>> import os
244+
>>> os.altsep = '\\'
245+
>>> dill.dump_module('os_session.pkl', module=os)
246+
PicklingError: ...
247+
>>> dill.dump_module('os_session.pkl', module=os, refonfail=True)
248+
235249
- Restore the state of the saved modules:
236250
237251
>>> import dill
@@ -244,6 +258,9 @@ def dump_module(
244258
>>> foo = dill.load_module('foo_session.pkl')
245259
>>> [foo.sin(x) for x in foo.values]
246260
[0.8414709848078965, 0.9092974268256817, 0.1411200080598672]
261+
>>> os = dill.load_module('os_session.pkl')
262+
>>> print(os.altsep.join('path'))
263+
p\a\t\h
247264
248265
*Changed in version 0.3.6:* Function ``dump_session()`` was renamed to
249266
``dump_module()``. Parameters ``main`` and ``byref`` were renamed to
@@ -266,6 +283,8 @@ def dump_module(
266283

267284
from .settings import settings
268285
protocol = settings['protocol']
286+
if refimported is None: refimported = settings['dump_module']['refimported']
287+
if refonfail is None: refonfail = settings['dump_module']['refonfail']
269288
main = module
270289
if main is None:
271290
main = _main_module
@@ -283,6 +302,7 @@ def dump_module(
283302
pickler._main = main #FIXME: dill.settings are disabled
284303
pickler._byref = False # disable pickling by name reference
285304
pickler._recurse = False # disable pickling recursion for globals
305+
pickler._refonfail = refonfail
286306
pickler._session = True # is best indicator of when pickling a session
287307
pickler._first_pass = True
288308
pickler.dump(main)

dill/settings.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
'fmode' : 0, #HANDLE_FMODE
2020
'recurse' : False,
2121
'ignore' : False,
22+
'dump_module' : {
23+
'refimported': False,
24+
'refonfail' : False,
25+
},
2226
}
2327

2428
del DEFAULT_PROTOCOL

dill/tests/test_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def test_runtime_module():
197197
runtime = ModuleType(modname)
198198
runtime.x = 42
199199

200-
mod = dill._dill._stash_modules(runtime)
200+
mod = dill.session._stash_modules(runtime)
201201
if mod is not runtime:
202202
print("There are objects to save by referenece that shouldn't be:",
203203
mod.__dill_imported, mod.__dill_imported_as, mod.__dill_imported_top_level,

0 commit comments

Comments
 (0)