|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import decimal |
3 | 4 | from decimal import Decimal |
| 5 | +from math import ceil, floor, isinf |
4 | 6 | from sys import float_info |
5 | 7 | from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast |
6 | 8 |
|
@@ -227,16 +229,62 @@ def generate_constrained_number( |
227 | 229 |
|
228 | 230 | :returns: A value of type T. |
229 | 231 | """ |
230 | | - if minimum is None or maximum is None: |
231 | | - return multiple_of if multiple_of is not None else method(random=random) |
232 | 232 | if multiple_of is None: |
233 | 233 | return method(random=random, minimum=minimum, maximum=maximum) |
234 | | - if multiple_of >= minimum: |
| 234 | + |
| 235 | + def passes_all_constraints(value: T) -> bool: |
| 236 | + return ( |
| 237 | + (minimum is None or value >= minimum) |
| 238 | + and (maximum is None or value <= maximum) |
| 239 | + and (multiple_of is None or passes_pydantic_multiple_validator(value, multiple_of)) |
| 240 | + ) |
| 241 | + |
| 242 | + # If the arguments are Decimals, they might have precision that is greater than the current decimal context. If |
| 243 | + # so, recreate them under the current context to ensure they have the appropriate precision. This is important |
| 244 | + # because otherwise, x * 1 == x may not always hold, which can cause the algorithm below to fail in unintuitive |
| 245 | + # ways. |
| 246 | + if isinstance(minimum, Decimal): |
| 247 | + minimum = decimal.getcontext().create_decimal(minimum) |
| 248 | + if isinstance(maximum, Decimal): |
| 249 | + maximum = decimal.getcontext().create_decimal(maximum) |
| 250 | + if isinstance(multiple_of, Decimal): |
| 251 | + multiple_of = decimal.getcontext().create_decimal(multiple_of) |
| 252 | + |
| 253 | + max_attempts = 10 |
| 254 | + for _ in range(max_attempts): |
| 255 | + # We attempt to generate a random number and find the nearest valid multiple, but a naive approach of rounding |
| 256 | + # to the nearest multiple may push the number out of range. To handle edge cases, we find both the nearest |
| 257 | + # multiple in both the negative and positive directions (floor and ceil), and we pick one that fits within |
| 258 | + # range. We should be guaranteed to find a number other than in the case where the range (minimum, maximum) is |
| 259 | + # narrow and does not contain any multiple of multiple_of. |
| 260 | + random_value = method(random=random, minimum=minimum, maximum=maximum) |
| 261 | + quotient = random_value / multiple_of |
| 262 | + if isinf(quotient): |
| 263 | + continue |
| 264 | + lower = floor(quotient) * multiple_of |
| 265 | + upper = ceil(quotient) * multiple_of |
| 266 | + |
| 267 | + # If both the lower and upper candidates are out of bounds, then there are no valid multiples that fit within |
| 268 | + # the specified range. |
| 269 | + if minimum is not None and maximum is not None and lower < minimum and upper > maximum: |
| 270 | + msg = f"no multiple of {multiple_of} exists between {minimum} and {maximum}" |
| 271 | + raise ParameterException(msg) |
| 272 | + |
| 273 | + for candidate in [lower, upper]: |
| 274 | + if not passes_all_constraints(candidate): |
| 275 | + continue |
| 276 | + return candidate |
| 277 | + |
| 278 | + # Try last-ditch attempt at using the multiple_of, 0, or -multiple_of as the value |
| 279 | + if passes_all_constraints(multiple_of): |
235 | 280 | return multiple_of |
236 | | - result = minimum |
237 | | - while not passes_pydantic_multiple_validator(result, multiple_of): |
238 | | - result = round(method(random=random, minimum=minimum, maximum=maximum) / multiple_of) * multiple_of |
239 | | - return result |
| 281 | + if passes_all_constraints(-multiple_of): |
| 282 | + return -multiple_of |
| 283 | + if passes_all_constraints(multiple_of * 0): |
| 284 | + return multiple_of * 0 |
| 285 | + |
| 286 | + msg = f"could not find solution in {max_attempts} attempts" |
| 287 | + raise ValueError(msg) |
240 | 288 |
|
241 | 289 |
|
242 | 290 | def handle_constrained_int( |
|
0 commit comments