diff --git a/yacs/config.py b/yacs/config.py index 54bec35..17400a1 100644 --- a/yacs/config.py +++ b/yacs/config.py @@ -45,7 +45,7 @@ _FILE_TYPES = (io.IOBase,) # CfgNodes can only contain a limited set of valid types -_VALID_TYPES = {tuple, list, str, int, float, bool} +_VALID_TYPES = {tuple, list, str, int, float, bool, type(None)} # py2 allow for str and unicode if _PY2: _VALID_TYPES = _VALID_TYPES.union({unicode}) # noqa: F821 @@ -489,6 +489,12 @@ def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): if replacement_type == original_type: return replacement + # If either of them is None, allow type conversion to one of the valid types + if (replacement_type == type(None) and original_type in _VALID_TYPES) or ( + original_type == type(None) and replacement_type in _VALID_TYPES + ): + return replacement + # Cast replacement from from_type to to_type if the replacement and original # types match from_type and to_type def conditional_cast(from_type, to_type): diff --git a/yacs/tests.py b/yacs/tests.py index 6748214..23bc871 100644 --- a/yacs/tests.py +++ b/yacs/tests.py @@ -3,6 +3,7 @@ import unittest import yacs.config +import yaml from yacs.config import CfgNode as CN try: @@ -196,8 +197,14 @@ def test_nonexistant_key_from_list(self): cfg.merge_from_list(opts) def test_load_cfg_invalid_type(self): - # FOO.BAR.QUUX will have type None, which is not allowed - cfg_string = "FOO:\n BAR:\n QUUX:" + class CustomClass(yaml.YAMLObject): + """A custom class that yaml.safe_load can load.""" + + yaml_loader = yaml.SafeLoader + yaml_tag = u"!CustomClass" + + # FOO.BAR.QUUX will have type CustomClass, which is not allowed + cfg_string = "FOO:\n BAR:\n QUUX: !CustomClass {}" with self.assertRaises(AssertionError): yacs.config.load_cfg(cfg_string)