Skip to content

Commit 3ec36f2

Browse files
committed
Preserve order of NamedTuple fields in JointDistributionNamed.
This fixes a breakage of `test_can_call_namedtuple_log_prob_with_args_and_kwargs` in python 3.8, when toposort of NamedTuple args returned an order different from input. PiperOrigin-RevId: 311437218
1 parent 3244d86 commit 3ec36f2

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tensorflow_probability/python/distributions/joint_distribution_named.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,5 +287,7 @@ def _convert_to_dict(x):
287287
if isinstance(x, collections.OrderedDict):
288288
return x
289289
if hasattr(x, '_asdict'):
290-
return x._asdict()
290+
# Wrap with `OrderedDict` to indicate that namedtuples have a well-defined
291+
# order (by default, they convert to just `dict` in Python 3.8+).
292+
return collections.OrderedDict(x._asdict())
291293
return dict(x)

0 commit comments

Comments
 (0)