diff --git a/dill/_dill.py b/dill/_dill.py index 33fabb74..2024e814 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -66,14 +66,17 @@ TypeType = type # 'new-style' classes #XXX: unregistered XRangeType = range from types import MappingProxyType as DictProxyType -from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError +from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, \ + PicklingError, UnpicklingError import __main__ as _main_module import marshal import gc # import zlib +import abc import dataclasses from weakref import ReferenceType, ProxyType, CallableProxyType from collections import OrderedDict +from enum import Enum, EnumMeta from functools import partial from operator import itemgetter, attrgetter GENERATOR_FAIL = False @@ -1057,7 +1060,6 @@ def _locate_function(obj, pickler=None): found = _import_module(module_name + '.' + obj.__name__, safe=True) return found is obj - def _setitems(dest, source): for k, v in source.items(): dest[k] = v @@ -1669,6 +1671,93 @@ def save_module(pickler, obj): logger.trace(pickler, "# M2") return +# The following function is based on '_extract_class_dict' from 'cloudpickle' +# Copyright (c) 2012, Regents of the University of California. +# Copyright (c) 2009 `PiCloud, Inc. <http://www.picloud.com>`_.
+# License: https://github.com/cloudpipe/cloudpickle/blob/master/LICENSE
+def _get_typedict_type(cls, clsdict, postproc_list):
+    """Strip from clsdict, in place, every entry inherited unchanged from a base.
+
+    An entry is dropped when it is the same object the direct bases of cls
+    hold under that name; '__dict__' and '__weakref__' are dropped as well.
+    postproc_list is not used.  Returns the clsdict object passed in.
+    """
+    inherited_dict = {}
+    for base in reversed(cls.__bases__):  # reversed: the leftmost base wins
+        inherited_dict.update(base.__dict__)
+    to_remove = [name for name, value in dict.items(clsdict)
+                 if name in inherited_dict and value is inherited_dict[name]]
+    for name in to_remove:
+        dict.pop(clsdict, name)
+
+    if issubclass(type(cls), type):
+        clsdict.pop('__dict__', None)
+        clsdict.pop('__weakref__', None)
+        # clsdict.pop('__prepare__', None)
+    return clsdict
+
+def _get_typedict_abc(obj, _dict, attrs, postproc_list):
+    """Queue obj's registered classes in postproc_list; strip ABC state from _dict.
+
+    One (obj.register, (cls,)) pair is appended per registered class.  Returns
+    (_dict, attrs); raises PicklingError when obj exposes no registry.
+    """
+    if hasattr(abc, '_get_dump'):
+        # each entry is called to obtain the class; None results are skipped
+        (registry, _, _, _) = abc._get_dump(obj)
+        registry = [cls for cls in (ref() for ref in registry) if cls is not None]
+    elif hasattr(obj, '_abc_registry'):
+        registry = obj._abc_registry
+    else:
+        raise PicklingError("Cannot find registry of ABC %s" % (obj,))
+    register = obj.register
+    postproc_list.extend((register, (cls,)) for cls in registry)
+
+    if '_abc_registry' in _dict:
+        for key in ('_abc_registry', '_abc_cache', '_abc_negative_cache'):
+            _dict.pop(key, None)  # pop, not del: a missing key is not an error
+    else:
+        _dict.pop('_abc_impl', None)
+    return _dict, attrs
+
+CORE_CLASSES = {int, float, type(None), str, dict, tuple, set, list, frozenset}
+
+def _get_typedict_enum(obj, _dict, attrs, postproc_list):
+    """Split an Enum's class dict into member definitions and everything else.
+
+    Returns (original_dict, _dict): member name -> initial value, and the
+    remaining attributes (merged into attrs when attrs is not None).
+    Raises PickleError when a value does not reduce to (its type, args).
+    """
+    import copyreg
+    base = reducer = None
+    original_dict = {}
+    for name, enum_value in obj.__members__.items():
+        value = enum_value.value
+        if base is None:  # NOTE(review): first member's value type is used for all
+            base = type(value)
+            reducer = copyreg.dispatch_table.get(base, None)
+        if base is tuple:
+            init_value = (value,)
+        elif base in CORE_CLASSES:
+            init_value = value
+        else:
+            init_value = reducer(value) if reducer else value.__reduce__()
+            if init_value[0] is not base or len(init_value) != 2:
+                raise PickleError('Cannot pickle Enum class, reduction too complex')
+            init_value = init_value[1]
+        original_dict[name] = init_value
+        _dict.pop(name, None)  # pop, not del: the name may be absent from _dict
+
+    for key in ('_member_names_', '_member_map_', '_value2member_map_', '_generate_next_value_'):
+        _dict.pop(key, None)
+
+    if attrs is not None:
+        attrs.update(_dict)
+        _dict = attrs
+
+    return original_dict, _dict
+
 @register(TypeType) def save_type(pickler, obj, postproc_list=None): if obj in _typemap: @@ -1680,15 +1769,22 @@ def save_type(pickler, obj, postproc_list=None): elif obj.__bases__ == (tuple,) and all([hasattr(obj, attr) for attr in ('_fields','_asdict','_make','_replace')]): # special case: namedtuples logger.trace(pickler, "T6: %s", obj) + + obj_name = getattr(obj, '__qualname__', getattr(obj, '__name__', None)) + if obj.__name__ != obj_name: + if postproc_list is None: + postproc_list = [] + postproc_list.append((setattr, (obj, '__qualname__', obj_name))) + + if not obj._field_defaults: - pickler.save_reduce(_create_namedtuple, (obj.__name__, obj._fields, obj.__module__), obj=obj) + _save_with_postproc(pickler, (_create_namedtuple, (obj.__name__, obj._fields, obj.__module__)), obj=obj, postproc_list=postproc_list) else: defaults = [obj._field_defaults[field] for field in obj._fields if field in obj._field_defaults] - pickler.save_reduce(_create_namedtuple, (obj.__name__, obj._fields, obj.__module__, defaults), obj=obj) + _save_with_postproc(pickler, (_create_namedtuple, (obj.__name__, obj._fields, obj.__module__, defaults)), obj=obj, postproc_list=postproc_list) logger.trace(pickler, "# T6") return - # special cases: NoneType, NotImplementedType, EllipsisType + # special cases: NoneType, NotImplementedType, EllipsisType, EnumMeta elif obj is type(None): logger.trace(pickler, "T7: %s", obj) #XXX: pickler.save_reduce(type, (None,), obj=obj) @@ -1702,35 +1798,74 @@ def save_type(pickler, obj,
postproc_list=None): logger.trace(pickler, "T7: %s", obj) pickler.save_reduce(type, (Ellipsis,), obj=obj) logger.trace(pickler, "# T7") + elif obj is EnumMeta: + logger.trace(pickler, "T7: %s", obj) + pickler.write(GLOBAL + b'enum\nEnumMeta\n') + logger.trace(pickler, "# T7") else: - obj_name = getattr(obj, '__qualname__', getattr(obj, '__name__', None)) _byref = getattr(pickler, '_byref', None) obj_recursive = id(obj) in getattr(pickler, '_postproc', ()) incorrectly_named = not _locate_function(obj, pickler) if not _byref and not obj_recursive and incorrectly_named: # not a function, but the name was held over + if postproc_list is None: + postproc_list = [] + # thanks to Tom Stepleton pointing out pickler._session unneeded logger.trace(pickler, "T2: %s", obj) - _dict = obj.__dict__.copy() # convert dictproxy to dict - #print (_dict) - #print ("%s\n%s" % (type(obj), obj.__name__)) - #print ("%s\n%s" % (obj.__bases__, obj.__dict__)) + _dict = _get_typedict_type(obj, obj.__dict__.copy(), postproc_list) # copy dict proxy to a dict + attrs = None + slots = _dict.get('__slots__', ()) - if type(slots) == str: slots = (slots,) # __slots__ accepts a single string + if type(slots) == str: + # __slots__ accepts a single string + slots = (slots,) for name in slots: - del _dict[name] - _dict.pop('__dict__', None) - _dict.pop('__weakref__', None) - _dict.pop('__prepare__', None) - if obj_name != obj.__name__: - if postproc_list is None: - postproc_list = [] - postproc_list.append((setattr, (obj, '__qualname__', obj_name))) - _save_with_postproc(pickler, (_create_type, ( - type(obj), obj.__name__, obj.__bases__, _dict - )), obj=obj, postproc_list=postproc_list) + _dict.pop(name, None) + + if isinstance(obj, abc.ABCMeta): + logger.trace(pickler, "ABC: %s", obj) + _dict, attrs = _get_typedict_abc(obj, _dict, attrs, postproc_list) + logger.trace(pickler, "# ABC") + + if isinstance(obj, EnumMeta): + logger.trace(pickler, "E: %s", obj) + _dict, attrs = _get_typedict_enum(obj, _dict, 
attrs, postproc_list) + logger.trace(pickler, "# E") + + qualname = getattr(obj, '__qualname__', None) + if attrs is not None: + if qualname is not None: + attrs['__qualname__'] = qualname + for k, v in attrs.items(): + postproc_list.append((setattr, (obj, k, v))) + # TODO: Consider using the state argument to save_reduce? + elif qualname is not None: + postproc_list.append((setattr, (obj, '__qualname__', qualname))) + + if False: # not hasattr(obj, '__orig_bases__'): + _save_with_postproc(pickler, (_create_type, ( + type(obj), obj.__name__, obj.__bases__, _dict + )), obj=obj, postproc_list=postproc_list) + else: + # This case will always work, but might be overkill. + from types import new_class + _metadict = { + 'metaclass': type(obj) + } + + if _dict: + _dict_update = PartialType(_setitems, source=_dict) + else: + _dict_update = None + + bases = getattr(obj, '__orig_bases__', obj.__bases__) + _save_with_postproc(pickler, (new_class, ( + obj.__name__, bases, _metadict, _dict_update + )), obj=obj, postproc_list=postproc_list) logger.trace(pickler, "# T2") else: + obj_name = getattr(obj, '__qualname__', getattr(obj, '__name__', None)) logger.trace(pickler, "T4: %s", obj) if incorrectly_named: warnings.warn( @@ -1753,14 +1888,17 @@ def save_type(pickler, obj, postproc_list=None): return @register(property) +@register(abc.abstractproperty) def save_property(pickler, obj): logger.trace(pickler, "Pr: %s", obj) - pickler.save_reduce(property, (obj.fget, obj.fset, obj.fdel, obj.__doc__), + pickler.save_reduce(type(obj), (obj.fget, obj.fset, obj.fdel, obj.__doc__), obj=obj) logger.trace(pickler, "# Pr") @register(staticmethod) @register(classmethod) +@register(abc.abstractstaticmethod) +@register(abc.abstractclassmethod) def save_classmethod(pickler, obj): logger.trace(pickler, "Cm: %s", obj) orig_func = obj.__func__ diff --git a/dill/tests/test_abc.py b/dill/tests/test_abc.py new file mode 100644 index 00000000..a90c73e4 --- /dev/null +++ b/dill/tests/test_abc.py @@ -0,0 
+1,158 @@ +#!/usr/bin/env python +""" +test dill's ability to pickle abstract base class objects +""" +import dill +import abc +from abc import ABC + +from types import FunctionType + +dill.settings['recurse'] = True + +class OneTwoThree(ABC): + @abc.abstractmethod + def foo(self): + """A method""" + pass + + @property + @abc.abstractmethod + def bar(self): + """Property getter""" + pass + + @bar.setter + @abc.abstractmethod + def bar(self, value): + """Property setter""" + pass + + @classmethod + @abc.abstractmethod + def cfoo(cls): + """Class method""" + pass + + @staticmethod + @abc.abstractmethod + def sfoo(): + """Static method""" + pass + +class EasyAsAbc(OneTwoThree): + def __init__(self): + self._bar = None + + def foo(self): + return "Instance Method FOO" + + @property + def bar(self): + return self._bar + + @bar.setter + def bar(self, value): + self._bar = value + + @classmethod + def cfoo(cls): + return "Class Method CFOO" + + @staticmethod + def sfoo(): + return "Static Method SFOO" + +def test_abc_non_local(): + assert dill.copy(OneTwoThree) is not OneTwoThree + assert dill.copy(EasyAsAbc) is not EasyAsAbc + assert dill.copy(OneTwoThree, byref=True) is OneTwoThree + assert dill.copy(EasyAsAbc, byref=True) is EasyAsAbc + instance = EasyAsAbc() + # Set a property that StockPickle can't preserve + instance.bar = lambda x: x**2 + depickled = dill.copy(instance) + assert type(depickled) is not type(instance) + assert type(depickled.bar) is FunctionType + assert depickled.bar(3) == 9 + assert depickled.sfoo() == "Static Method SFOO" + assert depickled.cfoo() == "Class Method CFOO" + assert depickled.foo() == "Instance Method FOO" + +def test_abc_local(): + """ + Test using locally scoped ABC class + """ + class LocalABC(ABC): + @abc.abstractmethod + def foo(self): + pass + + def baz(self): + return repr(self) + + labc = dill.copy(LocalABC) + assert labc is not LocalABC + assert type(labc) is type(LocalABC) + # TODO should work like it does for non local 
classes + # <class '__main__.LocalABC'> + # <class '__main__.test_abc_local.<locals>.LocalABC'> + + class Real(labc): + def foo(self): + return "True!" + + def baz(self): + return "My " + super(Real, self).baz() + + real = Real() + assert real.foo() == "True!" + + try: + labc() + except TypeError as e: + # Expected error + pass + else: + print('Failed to raise type error') + assert False + + labc2, pik = dill.copy((labc, Real())) + assert 'Real' == type(pik).__name__ + assert '.Real' in type(pik).__qualname__ + assert type(pik) is not Real + assert labc2 is not LocalABC + assert labc2 is not labc + assert isinstance(pik, labc2) + assert not isinstance(pik, labc) + assert not isinstance(pik, LocalABC) + assert pik.baz() == "My " + repr(pik) + +def test_meta_local_no_cache(): + """ + Test calling metaclass and cache registration + """ + LocalMetaABC = abc.ABCMeta('LocalMetaABC', (), {}) + + class ClassyClass: + pass + + class KlassyClass: + pass + + LocalMetaABC.register(ClassyClass) + + assert not issubclass(KlassyClass, LocalMetaABC) + assert issubclass(ClassyClass, LocalMetaABC) + + res = dill.dumps((LocalMetaABC, ClassyClass, KlassyClass)) + + lmabc, cc, kc = dill.loads(res) + assert type(lmabc) == type(LocalMetaABC) + assert not issubclass(kc, lmabc) + assert issubclass(cc, lmabc) + +if __name__ == '__main__': + test_abc_non_local() + test_abc_local() + test_meta_local_no_cache() diff --git a/dill/tests/test_classdef.py b/dill/tests/test_classdef.py index 05338637..7ff0cde1 100644 --- a/dill/tests/test_classdef.py +++ b/dill/tests/test_classdef.py @@ -54,6 +54,14 @@ def ok(self): nc = _newclass2() m = _mclass() +if sys.hexversion < 0x03090000: + import typing + class customIntList(typing.List[int]): + pass +else: + class customIntList(list[int]): + pass + # test pickles for class instances def test_class_instances(): assert dill.pickles(o) @@ -111,7 +119,7 @@ def test_namedtuple(): assert tuple(Badi) == tuple(dill.loads(dill.dumps(Badi))) class A: - class B(namedtuple("B", ["one", "two"])): + class B(namedtuple("C", ["one",
"two"])): '''docstring''' B.__module__ = 'testing' @@ -123,6 +131,15 @@ class B(namedtuple("B", ["one", "two"])): assert dill.copy(A.B).__doc__ == 'docstring' assert dill.copy(A.B).__module__ == 'testing' + from typing import NamedTuple + + def A(): + class B(NamedTuple): + x: int + return B + + assert type(dill.copy(A()(8))).__qualname__ == type(A()(8)).__qualname__ + def test_dtype(): try: import numpy as np @@ -210,6 +227,9 @@ def test_slots(): assert dill.pickles(Y.y) assert dill.copy(y).y == value +def test_origbases(): + assert dill.copy(customIntList).__orig_bases__ == customIntList.__orig_bases__ + def test_attr(): import attr @attr.s @@ -249,4 +269,5 @@ def __new__(cls): test_array_subclass() test_method_decorator() test_slots() + test_origbases() test_metaclass() diff --git a/dill/tests/test_enum.py b/dill/tests/test_enum.py new file mode 100644 index 00000000..adbe83d6 --- /dev/null +++ b/dill/tests/test_enum.py @@ -0,0 +1,161 @@ +import abc +import enum +from enum import Enum, IntEnum, EnumMeta, Flag, IntFlag, unique, auto + +import dill +import sys + +dill.settings['recurse'] = True + +""" +Test cases copied from https://raw.githubusercontent.com/python/cpython/3.10/Lib/test/test_enum.py + +Copyright 1991-1995 by Stichting Mathematisch Centrum, Amsterdam, The Netherlands. + +All Rights Reserved +Permission to use, copy, modify, and distribute this software and its +documentation for any purpose and without fee is hereby granted, provided that +the above copyright notice appear in all copies and that both that copyright +notice and this permission notice appear in supporting documentation, and that +the names of Stichting Mathematisch Centrum or CWI or Corporation for National +Research Initiatives or CNRI not be used in advertising or publicity pertaining +to distribution of the software without specific, written prior permission. 
+ +While CWI is the initial source for this software, a modified version is made +available by the Corporation for National Research Initiatives (CNRI) at the +Internet address http://www.python.org. + +STICHTING MATHEMATISCH CENTRUM AND CNRI DISCLAIM ALL WARRANTIES WITH REGARD TO +THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, +IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM OR CNRI BE LIABLE FOR ANY +SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING +FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE +OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +PERFORMANCE OF THIS SOFTWARE. +""" + +def test_enums(): + + class Stooges(Enum): + LARRY = 1 + CURLY = 2 + MOE = 3 + + class IntStooges(int, Enum): + LARRY = 1 + CURLY = 2 + MOE = 3 + + class FloatStooges(float, Enum): + LARRY = 1.39 + CURLY = 2.72 + MOE = 3.142596 + + class FlagStooges(Flag): + LARRY = 1 + CURLY = 2 + MOE = 3 + + # https://stackoverflow.com/a/56135108 + class ABCEnumMeta(abc.ABCMeta, EnumMeta): + def __new__(mcls, *args, **kw): + abstract_enum_cls = super().__new__(mcls, *args, **kw) + # Only check abstractions if members were defined. + if abstract_enum_cls._member_map_: + try: # Handle existence of undefined abstract methods. + absmethods = list(abstract_enum_cls.__abstractmethods__) + if absmethods: + missing = ', '.join(repr(method) for method in absmethods) + plural = 's' if len(absmethods) > 1 else '' + raise TypeError( + ("cannot instantiate abstract class %r" + " with abstract method%s %s") % (abstract_enum_cls.__name__, plural, missing)) + except AttributeError: + pass + return abstract_enum_cls + + l = locals() + exec("""class StrEnum(str, abc.ABC, Enum, metaclass=ABCEnumMeta): + 'accepts only string values' + def invisible(self): + return "did you see me?" 
""", None, l) + StrEnum = l['StrEnum'] + + class Name(StrEnum): + BDFL = 'Guido van Rossum' + FLUFL = 'Barry Warsaw' + + assert 'invisible' in dir(dill.copy(Name).BDFL) + assert 'invisible' in dir(dill.copy(Name.BDFL)) + assert dill.copy(Name.BDFL) is not Name.BDFL + + Question = Enum('Question', 'who what when where why', module=__name__) + Answer = Enum('Answer', 'him this then there because') + Theory = Enum('Theory', 'rule law supposition', qualname='spanish_inquisition') + + class Fruit(Enum): + TOMATO = 1 + BANANA = 2 + CHERRY = 3 + + assert dill.copy(Fruit).TOMATO.value == 1 and dill.copy(Fruit).TOMATO != 1 \ + and dill.copy(Fruit).TOMATO is not Fruit.TOMATO + + from datetime import date + class Holiday(date, Enum): + NEW_YEAR = 2013, 1, 1 + IDES_OF_MARCH = 2013, 3, 15 + + assert hasattr(dill.copy(Holiday), 'NEW_YEAR') + + class HolidayTuple(tuple, Enum): + NEW_YEAR = 2013, 1, 1 + IDES_OF_MARCH = 2013, 3, 15 + + assert isinstance(dill.copy(HolidayTuple).NEW_YEAR, tuple) + + class SuperEnum(IntEnum): + def __new__(cls, value, description=""): + obj = int.__new__(cls, value) + obj._value_ = value + obj.description = description + return obj + + class SubEnum(SuperEnum): + sample = 5 + + if sys.hexversion >= 0x030a0000: + assert 'description' in dir(dill.copy(SubEnum.sample)) + assert 'description' in dir(dill.copy(SubEnum).sample) + + class WeekDay(IntEnum): + SUNDAY = 1 + MONDAY = 2 + TUESDAY = TEUSDAY = 3 + WEDNESDAY = 4 + THURSDAY = 5 + FRIDAY = 6 + SATURDAY = 7 + + WeekDay_ = dill.copy(WeekDay) + assert WeekDay_.TUESDAY is WeekDay_.TEUSDAY + + class AutoNumber(IntEnum): + def __new__(cls): + value = len(cls.__members__) + 1 + obj = int.__new__(cls, value) + obj._value_ = value + return obj + + class Color(AutoNumber): + red = () + green = () + blue = () + + # TODO: This doesn't work yet + # Color_ = dill.copy(Color) + # assert list(Color_) == [Color_.red, Color_.green, Color_.blue] + # assert list(map(int, Color_)) == [1, 2, 3] + +if __name__ == 
'__main__': + test_enums()