From d9c5e19d80eadaabf0aac7def131880616011553 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Mon, 3 Jun 2024 10:07:09 +0200 Subject: [PATCH] Polished docstrings. --- src/upper_envelope/fues_jax/fues_jax.py | 80 ++++++++++++------------- tests/test_fues_jax.py | 18 +++--- tests/utils/interpolation.py | 17 +++--- 3 files changed, 57 insertions(+), 58 deletions(-) diff --git a/src/upper_envelope/fues_jax/fues_jax.py b/src/upper_envelope/fues_jax/fues_jax.py index 4391667..bbfea30 100644 --- a/src/upper_envelope/fues_jax/fues_jax.py +++ b/src/upper_envelope/fues_jax/fues_jax.py @@ -37,7 +37,7 @@ def fues_jax( endog_grid: jnp.ndarray, policy: jnp.ndarray, value: jnp.ndarray, - expected_value_zero_savings: float, + expected_value_zero_savings: jnp.ndarray | float, value_function: Callable, value_function_args: Optional[Tuple] = (), value_function_kwargs: Optional[Dict] = {}, @@ -70,28 +70,36 @@ def fues_jax( subsequent periods t + 1, t + 2, ..., T under the optimal consumption policy. Args: - endog_grid (np.ndarray): 1d array of shape (n_grid_wealth + 1,) + endog_grid (jnp.ndarray): 1d array of shape (n_grid_wealth + 1,) containing the current state- and choice-specific endogenous grid. - policy (np.ndarray): 1d array of shape (n_grid_wealth + 1,) + policy (jnp.ndarray): 1d array of shape (n_grid_wealth + 1,) containing the current state- and choice-specific policy function. - value (np.ndarray): 1d array of shape (n_grid_wealth + 1,) + value (jnp.ndarray): 1d array of shape (n_grid_wealth + 1,) containing the current state- and choice-specific value function. - expected_value_zero_savings (float): The agent's expected value given that she - saves zero. - utility_function (callable): The utility function. The first argument is - assumed to be consumption. - utility_kwargs (dict): The keyword arguments to be passed to the utility + expected_value_zero_savings (jnp.ndarray | float): The agent's expected value + given that she saves zero. + value_function (callable): The value function for calculating the value if + nothing is saved. + value_function_args (Tuple): The positional arguments to be passed to the value function. + value_function_kwargs (dict): The keyword arguments to be passed to the value + function. + n_constrained_points_to_add (int): Number of constrained points to add to the + left of the first grid point if there is an area with credit-constrain. + n_final_wealth_grid (int): Size of final function grid. Determines number of + iterations for the scan in the fues_jax. + jump_thresh (float): Jump detection threshold. + n_points_to_scan (int): Number of points to scan for suboptimal points. Returns: tuple: - - endog_grid_refined (np.ndarray): 1d array of shape (1.1 * n_grid_wealth,) - containing the refined state- and choice-specific endogenous grid. - - policy_refined_with_nans (np.ndarray): 1d array of shape (1.1 * n_grid_wealth) - containing refined state- and choice-specificconsumption policy. - - value_refined_with_nans (np.ndarray): 1d array of shape (1.1 * n_grid_wealth) - containing refined state- and choice-specific value function. + - endog_grid_refined (jnp.ndarray): 1d array of shape (n_final_wealth_grid,) + containing the refined endogenous wealth grid. + - policy_refined (jnp.ndarray): 1d array of shape (n_final_wealth_grid,) + containing refined consumption policy. + - value_refined (jnp.ndarray): 1d array of shape (n_final_wealth_grid,) + containing refined value function. """ # Set default of n_constrained_points_to_add to 10% of the grid size @@ -167,40 +175,30 @@ def fues_jax_unconstrained( """Remove suboptimal points from the endogenous grid, policy, and value function. Args: - endog_grid (np.ndarray): 1d array containing the unrefined endogenous wealth - grid of shape (n_grid_wealth + 1,). - value (np.ndarray): 1d array containing the unrefined value correspondence - of shape (n_grid_wealth + 1,). - policy (np.ndarray): 1d array containing the unrefined policy correspondence - of shape (n_grid_wealth + 1,). - expected_value_zero_savings (float): The agent's expected value given that she - saves zero. - n_final_wealth_grid (int): Size of final grid. Determines number of + endog_grid (jnp.ndarray): 1d array of shape (n_grid_wealth + 1,) + containing the current state- and choice-specific endogenous grid. + policy (jnp.ndarray): 1d array of shape (n_grid_wealth + 1,) + containing the current state- and choice-specific policy function. + value (jnp.ndarray): 1d array of shape (n_grid_wealth + 1,) + containing the current state- and choice-specific value function. + expected_value_zero_savings (jnp.ndarray | float): The agent's expected value + given that she saves zero. + n_final_wealth_grid (int): Size of final function grid. Determines number of iterations for the scan in the fues_jax. jump_thresh (float): Jump detection threshold. + n_points_to_scan (int): Number of points to scan for suboptimal points. Returns: tuple: - - endog_grid_refined (np.ndarray): 1d array containing the refined endogenous - wealth grid of shape (n_grid_clean,), which maps only to the optimal points - in the value function. - - value_refined (np.ndarray): 1d array containing the refined value function - of shape (n_grid_clean,). Overlapping segments have been removed and only - the optimal points are kept. - - policy_refined (np.ndarray): 1d array containing the refined policy function - of shape (n_grid_clean,). Overlapping segments have been removed and only - the optimal points are kept. + - endog_grid_refined (jnp.ndarray): 1d array of shape (n_final_wealth_grid,) + containing the refined endogenous wealth grid. + - policy_refined (jnp.ndarray): 1d array of shape (n_final_wealth_grid,) + containing refined consumption policy. + - value_refined (jnp.ndarray): 1d array of shape (n_final_wealth_grid,) + containing refined value function. """ - # Comment by Akshay: Determine locations where endogenous grid points are - # equal to the lower bound. Not relevant for us. - # mask = endog_grid <= lower_bound_wealth - # if jnp.any(mask): - # max_value_lower_bound = jnp.nanmax(value[mask]) - # mask &= value < max_value_lower_bound - # value[mask] = jnp.nan - # Set default value of final grid size to 1.2 times current if not defined n_final_wealth_grid = ( int(1.2 * (len(policy))) if n_final_wealth_grid is None else n_final_wealth_grid diff --git a/tests/test_fues_jax.py b/tests/test_fues_jax.py index 78c5e45..764e454 100644 --- a/tests/test_fues_jax.py +++ b/tests/test_fues_jax.py @@ -130,16 +130,16 @@ def value_func(consumption, choice, params): policy_refined, value_refined, ) = upenv.fues_jax( - endog_grid=policy_egm[0, 1:], - policy=policy_egm[1, 1:], - value=value_egm[1, 1:], + endog_grid=jnp.asarray(policy_egm[0, 1:]), + policy=jnp.asarray(policy_egm[1, 1:]), + value=jnp.asarray(value_egm[1, 1:]), expected_value_zero_savings=value_egm[1, 0], value_function=value_func, value_function_kwargs=value_function_kwargs, ) wealth_max_to_test = np.max(endog_grid_refined[~np.isnan(endog_grid_refined)]) + 100 - wealth_grid_to_test = jnp.linspace( + wealth_grid_to_test = np.linspace( endog_grid_refined[1], wealth_max_to_test, 1000, dtype=float ) @@ -257,9 +257,9 @@ def value_func(consumption, choice, params): policy_fues, value_fues, ) = upenv.fues_jax( - endog_grid=policy_egm[0, 1:], - policy=policy_egm[1, 1:], - value=value_egm[1, 1:], + endog_grid=jnp.asarray(policy_egm[0, 1:]), + policy=jnp.asarray(policy_egm[1, 1:]), + value=jnp.asarray(value_egm[1, 1:]), expected_value_zero_savings=value_egm[1, 0], value_function=value_func, value_function_args=(state_choice_vec["choice"], params), @@ -296,8 +296,8 @@ def test_back_and_forward_scan_wrapper_direction_flag(): endog_grid_to_scan_from=1.2, policy_to_scan_from=0.7, endog_grid=1, - value=np.arange(2, 5), - policy=np.arange(1, 4), + value=jnp.arange(2, 5), + policy=jnp.arange(1, 4), idx_to_scan_from=2, n_points_to_scan=3, is_scan_needed=False, diff --git a/tests/utils/interpolation.py b/tests/utils/interpolation.py index 0638e89..faf03a8 100644 --- a/tests/utils/interpolation.py +++ b/tests/utils/interpolation.py @@ -1,13 +1,14 @@ from typing import Tuple +import jax.numpy as jnp import numpy as np def interpolate_policy_and_value_on_wealth_grid( - wealth_beginning_of_period: np.ndarray, - endog_wealth_grid: np.ndarray, - policy_grid: np.ndarray, - value_function_grid: np.ndarray, + wealth_beginning_of_period: np.ndarray | jnp.ndarray, + endog_wealth_grid: np.ndarray | jnp.ndarray, + policy_grid: np.ndarray | jnp.ndarray, + value_function_grid: np.ndarray | jnp.ndarray, ): """Interpolate policy and value functions on the wealth grid. @@ -62,10 +63,10 @@ def interpolate_policy_and_value_on_wealth_grid( def interpolate_single_policy_and_value_on_wealth_grid( - wealth_beginning_of_period: np.ndarray, - endog_wealth_grid: np.ndarray, - policy_grid: np.ndarray, - value_function_grid: np.ndarray, + wealth_beginning_of_period: np.ndarray | jnp.ndarray, + endog_wealth_grid: np.ndarray | jnp.ndarray, + policy_grid: np.ndarray | jnp.ndarray, + value_function_grid: np.ndarray | jnp.ndarray, ): """Interpolate policy and value functions on the wealth grid.