Skip to content

Commit 322cf1b

Browse files
authored
Python 3 Metaclasses [Support ABC and Enums - Part 1] (#577)
* Python 3 Metaclasses * Add more tests for Python 3 Metaclasses
1 parent 0e3e7b5 commit 322cf1b

File tree

2 files changed

+128
-20
lines changed

2 files changed

+128
-20
lines changed

dill/_dill.py

Lines changed: 88 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,17 @@
6565
SliceType = slice
6666
TypeType = type # 'new-style' classes #XXX: unregistered
6767
XRangeType = range
68-
from types import MappingProxyType as DictProxyType
68+
from types import MappingProxyType as DictProxyType, new_class
6969
from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError
7070
import __main__ as _main_module
7171
import marshal
7272
import gc
7373
# import zlib
74+
import abc
7475
import dataclasses
7576
from weakref import ReferenceType, ProxyType, CallableProxyType
7677
from collections import OrderedDict
78+
from enum import Enum, EnumMeta
7779
from functools import partial
7880
from operator import itemgetter, attrgetter
7981
GENERATOR_FAIL = False
@@ -1669,6 +1671,35 @@ def save_module(pickler, obj):
16691671
logger.trace(pickler, "# M2")
16701672
return
16711673

1674+
# The following function is based on '_extract_class_dict' from 'cloudpickle'
1675+
# Copyright (c) 2012, Regents of the University of California.
1676+
# Copyright (c) 2009 `PiCloud, Inc. <http://www.picloud.com>`_.
1677+
# License: https://github.com/cloudpipe/cloudpickle/blob/master/LICENSE
1678+
def _get_typedict_type(cls, clsdict, attrs, postproc_list):
1679+
"""Retrieve a copy of the dict of a class without the inherited methods"""
1680+
if len(cls.__bases__) == 1:
1681+
inherited_dict = cls.__bases__[0].__dict__
1682+
else:
1683+
inherited_dict = {}
1684+
for base in reversed(cls.__bases__):
1685+
inherited_dict.update(base.__dict__)
1686+
to_remove = []
1687+
for name, value in dict.items(clsdict):
1688+
try:
1689+
base_value = inherited_dict[name]
1690+
if value is base_value:
1691+
to_remove.append(name)
1692+
except KeyError:
1693+
pass
1694+
for name in to_remove:
1695+
dict.pop(clsdict, name)
1696+
1697+
if issubclass(type(cls), type):
1698+
clsdict.pop('__dict__', None)
1699+
clsdict.pop('__weakref__', None)
1700+
# clsdict.pop('__prepare__', None)
1701+
return clsdict, attrs
1702+
16721703
@register(TypeType)
16731704
def save_type(pickler, obj, postproc_list=None):
16741705
if obj in _typemap:
@@ -1680,15 +1711,22 @@ def save_type(pickler, obj, postproc_list=None):
16801711
elif obj.__bases__ == (tuple,) and all([hasattr(obj, attr) for attr in ('_fields','_asdict','_make','_replace')]):
16811712
# special case: namedtuples
16821713
logger.trace(pickler, "T6: %s", obj)
1714+
1715+
obj_name = getattr(obj, '__qualname__', getattr(obj, '__name__', None))
1716+
if obj.__name__ != obj_name:
1717+
if postproc_list is None:
1718+
postproc_list = []
1719+
postproc_list.append((setattr, (obj, '__qualname__', obj_name)))
1720+
16831721
if not obj._field_defaults:
1684-
pickler.save_reduce(_create_namedtuple, (obj.__name__, obj._fields, obj.__module__), obj=obj)
1722+
_save_with_postproc(pickler, (_create_namedtuple, (obj.__name__, obj._fields, obj.__module__)), obj=obj, postproc_list=postproc_list)
16851723
else:
16861724
defaults = [obj._field_defaults[field] for field in obj._fields if field in obj._field_defaults]
1687-
pickler.save_reduce(_create_namedtuple, (obj.__name__, obj._fields, obj.__module__, defaults), obj=obj)
1725+
_save_with_postproc(pickler, (_create_namedtuple, (obj.__name__, obj._fields, obj.__module__, defaults)), obj=obj, postproc_list=postproc_list)
16881726
logger.trace(pickler, "# T6")
16891727
return
16901728

1691-
# special cases: NoneType, NotImplementedType, EllipsisType
1729+
# special cases: NoneType, NotImplementedType, EllipsisType, EnumMeta
16921730
elif obj is type(None):
16931731
logger.trace(pickler, "T7: %s", obj)
16941732
#XXX: pickler.save_reduce(type, (None,), obj=obj)
@@ -1702,35 +1740,63 @@ def save_type(pickler, obj, postproc_list=None):
17021740
logger.trace(pickler, "T7: %s", obj)
17031741
pickler.save_reduce(type, (Ellipsis,), obj=obj)
17041742
logger.trace(pickler, "# T7")
1743+
elif obj is EnumMeta:
1744+
logger.trace(pickler, "T7: %s", obj)
1745+
pickler.write(GLOBAL + b'enum\nEnumMeta\n')
1746+
logger.trace(pickler, "# T7")
17051747

17061748
else:
1707-
obj_name = getattr(obj, '__qualname__', getattr(obj, '__name__', None))
17081749
_byref = getattr(pickler, '_byref', None)
17091750
obj_recursive = id(obj) in getattr(pickler, '_postproc', ())
17101751
incorrectly_named = not _locate_function(obj, pickler)
17111752
if not _byref and not obj_recursive and incorrectly_named: # not a function, but the name was held over
1753+
if postproc_list is None:
1754+
postproc_list = []
1755+
17121756
# thanks to Tom Stepleton pointing out pickler._session unneeded
17131757
logger.trace(pickler, "T2: %s", obj)
1714-
_dict = obj.__dict__.copy() # convert dictproxy to dict
1758+
_dict, attrs = _get_typedict_type(obj, obj.__dict__.copy(), None, postproc_list) # copy dict proxy to a dict
1759+
17151760
#print (_dict)
17161761
#print ("%s\n%s" % (type(obj), obj.__name__))
17171762
#print ("%s\n%s" % (obj.__bases__, obj.__dict__))
17181763
slots = _dict.get('__slots__', ())
1719-
if type(slots) == str: slots = (slots,) # __slots__ accepts a single string
1764+
if type(slots) == str:
1765+
# __slots__ accepts a single string
1766+
slots = (slots,)
1767+
17201768
for name in slots:
1721-
del _dict[name]
1722-
_dict.pop('__dict__', None)
1723-
_dict.pop('__weakref__', None)
1724-
_dict.pop('__prepare__', None)
1725-
if obj_name != obj.__name__:
1726-
if postproc_list is None:
1727-
postproc_list = []
1728-
postproc_list.append((setattr, (obj, '__qualname__', obj_name)))
1729-
_save_with_postproc(pickler, (_create_type, (
1730-
type(obj), obj.__name__, obj.__bases__, _dict
1731-
)), obj=obj, postproc_list=postproc_list)
1769+
_dict.pop(name, None)
1770+
1771+
qualname = getattr(obj, '__qualname__', None)
1772+
if attrs is not None:
1773+
for k, v in attrs.items():
1774+
postproc_list.append((setattr, (obj, k, v)))
1775+
# TODO: Consider using the state argument to save_reduce?
1776+
if qualname is not None:
1777+
postproc_list.append((setattr, (obj, '__qualname__', qualname)))
1778+
1779+
if not hasattr(obj, '__orig_bases__'):
1780+
_save_with_postproc(pickler, (_create_type, (
1781+
type(obj), obj.__name__, obj.__bases__, _dict
1782+
)), obj=obj, postproc_list=postproc_list)
1783+
else:
1784+
# This case will always work, but might be overkill.
1785+
_metadict = {
1786+
'metaclass': type(obj)
1787+
}
1788+
1789+
if _dict:
1790+
_dict_update = PartialType(_setitems, source=_dict)
1791+
else:
1792+
_dict_update = None
1793+
1794+
_save_with_postproc(pickler, (new_class, (
1795+
obj.__name__, obj.__orig_bases__, _metadict, _dict_update
1796+
)), obj=obj, postproc_list=postproc_list)
17321797
logger.trace(pickler, "# T2")
17331798
else:
1799+
obj_name = getattr(obj, '__qualname__', getattr(obj, '__name__', None))
17341800
logger.trace(pickler, "T4: %s", obj)
17351801
if incorrectly_named:
17361802
warnings.warn(
@@ -1753,14 +1819,17 @@ def save_type(pickler, obj, postproc_list=None):
17531819
return
17541820

17551821
@register(property)
1822+
@register(abc.abstractproperty)
17561823
def save_property(pickler, obj):
17571824
logger.trace(pickler, "Pr: %s", obj)
1758-
pickler.save_reduce(property, (obj.fget, obj.fset, obj.fdel, obj.__doc__),
1825+
pickler.save_reduce(type(obj), (obj.fget, obj.fset, obj.fdel, obj.__doc__),
17591826
obj=obj)
17601827
logger.trace(pickler, "# Pr")
17611828

17621829
@register(staticmethod)
17631830
@register(classmethod)
1831+
@register(abc.abstractstaticmethod)
1832+
@register(abc.abstractclassmethod)
17641833
def save_classmethod(pickler, obj):
17651834
logger.trace(pickler, "Cm: %s", obj)
17661835
orig_func = obj.__func__

dill/tests/test_classdef.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# - https://github.com/uqfoundation/dill/blob/master/LICENSE
88

99
import dill
10+
from enum import EnumMeta
1011
import sys
1112
dill.settings['recurse'] = True
1213

@@ -54,6 +55,14 @@ def ok(self):
5455
nc = _newclass2()
5556
m = _mclass()
5657

58+
if sys.hexversion < 0x03090000:
59+
import typing
60+
class customIntList(typing.List[int]):
61+
pass
62+
else:
63+
class customIntList(list[int]):
64+
pass
65+
5766
# test pickles for class instances
5867
def test_class_instances():
5968
assert dill.pickles(o)
@@ -89,6 +98,7 @@ def test_specialtypes():
8998
assert dill.pickles(type(None))
9099
assert dill.pickles(type(NotImplemented))
91100
assert dill.pickles(type(Ellipsis))
101+
assert dill.pickles(type(EnumMeta))
92102

93103
from collections import namedtuple
94104
Z = namedtuple("Z", ['a','b'])
@@ -99,19 +109,23 @@ def test_specialtypes():
99109
Xi = X(0,1)
100110
Bad = namedtuple("FakeName", ['a','b'])
101111
Badi = Bad(0,1)
112+
Defaults = namedtuple('Defaults', ['x', 'y'], defaults=[1])
113+
Defaultsi = Defaults(2)
102114

103115
# test namedtuple
104116
def test_namedtuple():
105117
assert Z is dill.loads(dill.dumps(Z))
106118
assert Zi == dill.loads(dill.dumps(Zi))
107119
assert X is dill.loads(dill.dumps(X))
108120
assert Xi == dill.loads(dill.dumps(Xi))
121+
assert Defaults is dill.loads(dill.dumps(Defaults))
122+
assert Defaultsi == dill.loads(dill.dumps(Defaultsi))
109123
assert Bad is not dill.loads(dill.dumps(Bad))
110124
assert Bad._fields == dill.loads(dill.dumps(Bad))._fields
111125
assert tuple(Badi) == tuple(dill.loads(dill.dumps(Badi)))
112126

113127
class A:
114-
class B(namedtuple("B", ["one", "two"])):
128+
class B(namedtuple("C", ["one", "two"])):
115129
'''docstring'''
116130
B.__module__ = 'testing'
117131

@@ -123,6 +137,15 @@ class B(namedtuple("B", ["one", "two"])):
123137
assert dill.copy(A.B).__doc__ == 'docstring'
124138
assert dill.copy(A.B).__module__ == 'testing'
125139

140+
from typing import NamedTuple
141+
142+
def A():
143+
class B(NamedTuple):
144+
x: int
145+
return B
146+
147+
assert type(dill.copy(A()(8))).__qualname__ == type(A()(8)).__qualname__
148+
126149
def test_dtype():
127150
try:
128151
import numpy as np
@@ -204,11 +227,20 @@ def __init__(self, y):
204227
value = 123
205228
y = Y(value)
206229

230+
class Y2(object):
231+
__slots__ = 'y'
232+
def __init__(self, y):
233+
self.y = y
234+
207235
def test_slots():
208236
assert dill.pickles(Y)
209237
assert dill.pickles(y)
210238
assert dill.pickles(Y.y)
211239
assert dill.copy(y).y == value
240+
assert dill.copy(Y2(value)).y == value
241+
242+
def test_origbases():
243+
assert dill.copy(customIntList).__orig_bases__ == customIntList.__orig_bases__
212244

213245
def test_attr():
214246
import attr
@@ -238,6 +270,11 @@ def __new__(cls):
238270

239271
assert dill.copy(subclass_with_new())
240272

273+
def test_enummeta():
274+
from http import HTTPStatus
275+
import enum
276+
assert dill.copy(HTTPStatus.OK) is HTTPStatus.OK
277+
assert dill.copy(enum.EnumMeta) is enum.EnumMeta
241278

242279
if __name__ == '__main__':
243280
test_class_instances()
@@ -249,4 +286,6 @@ def __new__(cls):
249286
test_array_subclass()
250287
test_method_decorator()
251288
test_slots()
289+
test_origbases()
252290
test_metaclass()
291+
test_enummeta()

0 commit comments

Comments
 (0)