Skip to content

Commit

Permalink
Merge pull request #3485 from chiamp:enum
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 582804999
  • Loading branch information
Flax Authors committed Nov 15, 2023
2 parents e820325 + 01049c7 commit 79d925f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
8 changes: 7 additions & 1 deletion flax/linen/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Flax Module summary library."""

import dataclasses
import enum
import io
from abc import ABC, abstractmethod
from types import MappingProxyType
Expand Down Expand Up @@ -709,12 +710,17 @@ def _normalize_structure(obj):
if isinstance(obj, (tuple, list)):
return tuple(map(_normalize_structure, obj))
elif isinstance(obj, Mapping):
return {k: _normalize_structure(v) for k, v in obj.items()}
return {
_normalize_structure(k): _normalize_structure(v) for k, v in obj.items()
}
elif dataclasses.is_dataclass(obj):
return {
f.name: _normalize_structure(getattr(obj, f.name))
for f in dataclasses.fields(obj)
}
elif isinstance(obj, enum.Enum):
# `yaml.safe_dump` does not support Enum key types so extract the underlying value
return obj.value
else:
return obj

Expand Down
21 changes: 21 additions & 0 deletions tests/linen/summary_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
from typing import List

import jax
Expand Down Expand Up @@ -735,6 +736,26 @@ def __call__(self, x):
lines = rep.splitlines()
self.assertIn('Total Parameters: 50', lines[-2])

def test_tabulate_enum(self):
class Net(nn.Module):
@nn.compact
def __call__(self, inputs):
x = inputs['x']
x = nn.Dense(features=2)(x)
return jnp.sum(x)

class InputEnum(str, enum.Enum):
x = 'x'

inputs = {InputEnum.x: jnp.ones((1, 1))}
# test args
lines = Net().tabulate(jax.random.key(0), inputs).split('\n')
self.assertIn('x: \x1b[2mfloat32\x1b[0m[1,1]', lines[5])
# test kwargs
lines = Net().tabulate(jax.random.key(0), inputs=inputs).split('\n')
self.assertIn('inputs:', lines[5])
self.assertIn('x: \x1b[2mfloat32\x1b[0m[1,1]', lines[6])


if __name__ == '__main__':
absltest.main()

0 comments on commit 79d925f

Please sign in to comment.