diff --git a/pynamodb/attributes.py b/pynamodb/attributes.py index 957e148f..80034f76 100644 --- a/pynamodb/attributes.py +++ b/pynamodb/attributes.py @@ -6,12 +6,14 @@ import collections.abc import json import time +import typing import warnings from base64 import b64encode, b64decode from copy import deepcopy from datetime import datetime from datetime import timedelta from datetime import timezone +from decimal import Decimal from inspect import getfullargspec from inspect import getmembers from typing import Any, Callable, Dict, Generic, List, Mapping, Optional, TypeVar, Type, Union, Set, overload, Iterable @@ -178,7 +180,12 @@ def serialize(self, value: Any) -> Any: see `DynamoDB.Client.get_item API reference `_. """ - return value + if value is None or isinstance(value, str): + return value + raise TypeError( + f"UnicodeAttribute expected str for '{self.attr_name}', " + f"got {type(value).__name__}" + ) def deserialize(self, value: Any) -> Any: """ @@ -741,6 +748,13 @@ def serialize(self, value): """ Encode numbers as JSON """ + if isinstance(value, bool): + raise TypeError("Boolean values are not allowed for NumberAttribute") + + if not isinstance(value, (int, float, Decimal)): + raise TypeError( + f"Expected int, float, or Decimal for NumberAttribute, got {type(value).__name__}" + ) return json.dumps(value) def deserialize(self, value): diff --git a/tests/test_attributes.py b/tests/test_attributes.py index 862a492c..0d4a69a0 100644 --- a/tests/test_attributes.py +++ b/tests/test_attributes.py @@ -342,6 +342,24 @@ def test_number_deserialize(self): assert attr.deserialize('3.141') == 3.141 assert attr.deserialize('12345678909876543211234234324234') == 12345678909876543211234234324234 + @pytest.mark.parametrize( + "value", + [ + "42", # string + True, # bool + None, # None + [1, 2, 3], # list + {"a": 1}, # dict + {1, 2, 3}, # set + object(), # arbitrary object + ], + ) + def test_serialize_invalid_types(self, value): + attr = NumberAttribute() + with pytest.raises(TypeError) as exc_info: + attr.serialize(value) + assert "Expected int, float, or Decimal" in str(exc_info.value) + def test_number_set_deserialize(self): """ NumberSetAttribute.deserialize @@ -399,6 +417,12 @@ def test_unicode_deserialize(self): assert attr.deserialize('') == '' assert attr.deserialize(None) is None + @pytest.mark.parametrize("value", [123, 1.23, True, False, [], {}, object()]) + def test_unicode_serialize_invalid(self, value): + attr = UnicodeAttribute() + with pytest.raises(TypeError): + attr.serialize(value) + def test_unicode_set_serialize(self): """ UnicodeSetAttribute.serialize diff --git a/tests/test_model.py b/tests/test_model.py index da54303b..11b21751 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -727,12 +727,12 @@ def test_delete_doesnt_do_validation_on_null_attributes(self): """ with patch(PATCH_METHOD) as req: req.return_value = {} - CarModel('foo').delete() + CarModel(1234).delete() with patch(PATCH_METHOD) as req: req.return_value = {} with CarModel.batch_write() as batch: - car = CarModel('foo') + car = CarModel(1234) batch.delete(car) @patch('time.time') @@ -1046,7 +1046,7 @@ def fake_dynamodb(*args, **kwargs): def test_count_no_hash_key(self): with pytest.raises(ValueError): - UserModel.count(filter_condition=(UserModel.zip_code <= '94117')) + UserModel.count(filter_condition=(UserModel.zip_code <= 94117)) def test_index_count(self): """