Skip to content

Commit

Permalink
Polished docstrings.
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxBlesch committed Jun 3, 2024
1 parent 2683232 commit d9c5e19
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 58 deletions.
80 changes: 39 additions & 41 deletions src/upper_envelope/fues_jax/fues_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def fues_jax(
endog_grid: jnp.ndarray,
policy: jnp.ndarray,
value: jnp.ndarray,
expected_value_zero_savings: float,
expected_value_zero_savings: jnp.ndarray | float,
value_function: Callable,
value_function_args: Optional[Tuple] = (),
value_function_kwargs: Optional[Dict] = {},
Expand Down Expand Up @@ -70,28 +70,36 @@ def fues_jax(
subsequent periods t + 1, t + 2, ..., T under the optimal consumption policy.
Args:
endog_grid (np.ndarray): 1d array of shape (n_grid_wealth + 1,)
endog_grid (jnp.ndarray): 1d array of shape (n_grid_wealth + 1,)
containing the current state- and choice-specific endogenous grid.
policy (np.ndarray): 1d array of shape (n_grid_wealth + 1,)
policy (jnp.ndarray): 1d array of shape (n_grid_wealth + 1,)
containing the current state- and choice-specific policy function.
value (np.ndarray): 1d array of shape (n_grid_wealth + 1,)
value (jnp.ndarray): 1d array of shape (n_grid_wealth + 1,)
containing the current state- and choice-specific value function.
expected_value_zero_savings (float): The agent's expected value given that she
saves zero.
utility_function (callable): The utility function. The first argument is
assumed to be consumption.
utility_kwargs (dict): The keyword arguments to be passed to the utility
expected_value_zero_savings (jnp.ndarray | float): The agent's expected value
given that she saves zero.
value_function (callable): The value function for calculating the value if
nothing is saved.
value_function_args (Tuple): The positional arguments to be passed to the value
function.
value_function_kwargs (dict): The keyword arguments to be passed to the value
function.
n_constrained_points_to_add (int): Number of constrained points to add to the
left of the first grid point if there is an area with credit-constrain.
n_final_wealth_grid (int): Size of final function grid. Determines number of
iterations for the scan in the fues_jax.
jump_thresh (float): Jump detection threshold.
n_points_to_scan (int): Number of points to scan for suboptimal points.
Returns:
tuple:
- endog_grid_refined (np.ndarray): 1d array of shape (1.1 * n_grid_wealth,)
containing the refined state- and choice-specific endogenous grid.
- policy_refined_with_nans (np.ndarray): 1d array of shape (1.1 * n_grid_wealth)
containing refined state- and choice-specificconsumption policy.
- value_refined_with_nans (np.ndarray): 1d array of shape (1.1 * n_grid_wealth)
containing refined state- and choice-specific value function.
- endog_grid_refined (jnp.ndarray): 1d array of shape (n_final_wealth_grid,)
containing the refined endogenous wealth grid.
- policy_refined (jnp.ndarray): 1d array of shape (n_final_wealth_grid,)
containing refined consumption policy.
- value_refined (jnp.ndarray): 1d array of shape (n_final_wealth_grid,)
containing refined value function.
"""
# Set default of n_constrained_points_to_add to 10% of the grid size
Expand Down Expand Up @@ -167,40 +175,30 @@ def fues_jax_unconstrained(
"""Remove suboptimal points from the endogenous grid, policy, and value function.
Args:
endog_grid (np.ndarray): 1d array containing the unrefined endogenous wealth
grid of shape (n_grid_wealth + 1,).
value (np.ndarray): 1d array containing the unrefined value correspondence
of shape (n_grid_wealth + 1,).
policy (np.ndarray): 1d array containing the unrefined policy correspondence
of shape (n_grid_wealth + 1,).
expected_value_zero_savings (float): The agent's expected value given that she
saves zero.
n_final_wealth_grid (int): Size of final grid. Determines number of
endog_grid (jnp.ndarray): 1d array of shape (n_grid_wealth + 1,)
containing the current state- and choice-specific endogenous grid.
policy (jnp.ndarray): 1d array of shape (n_grid_wealth + 1,)
containing the current state- and choice-specific policy function.
value (jnp.ndarray): 1d array of shape (n_grid_wealth + 1,)
containing the current state- and choice-specific value function.
expected_value_zero_savings (jnp.ndarray | float): The agent's expected value
given that she saves zero.
n_final_wealth_grid (int): Size of final function grid. Determines number of
iterations for the scan in the fues_jax.
jump_thresh (float): Jump detection threshold.
n_points_to_scan (int): Number of points to scan for suboptimal points.
Returns:
tuple:
- endog_grid_refined (np.ndarray): 1d array containing the refined endogenous
wealth grid of shape (n_grid_clean,), which maps only to the optimal points
in the value function.
- value_refined (np.ndarray): 1d array containing the refined value function
of shape (n_grid_clean,). Overlapping segments have been removed and only
the optimal points are kept.
- policy_refined (np.ndarray): 1d array containing the refined policy function
of shape (n_grid_clean,). Overlapping segments have been removed and only
the optimal points are kept.
- endog_grid_refined (jnp.ndarray): 1d array of shape (n_final_wealth_grid,)
containing the refined endogenous wealth grid.
- policy_refined (jnp.ndarray): 1d array of shape (n_final_wealth_grid,)
containing refined consumption policy.
- value_refined (jnp.ndarray): 1d array of shape (n_final_wealth_grid,)
containing refined value function.
"""
# Comment by Akshay: Determine locations where endogenous grid points are
# equal to the lower bound. Not relevant for us.
# mask = endog_grid <= lower_bound_wealth
# if jnp.any(mask):
# max_value_lower_bound = jnp.nanmax(value[mask])
# mask &= value < max_value_lower_bound
# value[mask] = jnp.nan

# Set default value of final grid size to 1.2 times current if not defined
n_final_wealth_grid = (
int(1.2 * (len(policy))) if n_final_wealth_grid is None else n_final_wealth_grid
Expand Down
18 changes: 9 additions & 9 deletions tests/test_fues_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,16 +130,16 @@ def value_func(consumption, choice, params):
policy_refined,
value_refined,
) = upenv.fues_jax(
endog_grid=policy_egm[0, 1:],
policy=policy_egm[1, 1:],
value=value_egm[1, 1:],
endog_grid=jnp.asarray(policy_egm[0, 1:]),
policy=jnp.asarray(policy_egm[1, 1:]),
value=jnp.asarray(value_egm[1, 1:]),
expected_value_zero_savings=value_egm[1, 0],
value_function=value_func,
value_function_kwargs=value_function_kwargs,
)

wealth_max_to_test = np.max(endog_grid_refined[~np.isnan(endog_grid_refined)]) + 100
wealth_grid_to_test = jnp.linspace(
wealth_grid_to_test = np.linspace(
endog_grid_refined[1], wealth_max_to_test, 1000, dtype=float
)

Expand Down Expand Up @@ -257,9 +257,9 @@ def value_func(consumption, choice, params):
policy_fues,
value_fues,
) = upenv.fues_jax(
endog_grid=policy_egm[0, 1:],
policy=policy_egm[1, 1:],
value=value_egm[1, 1:],
endog_grid=jnp.asarray(policy_egm[0, 1:]),
policy=jnp.asarray(policy_egm[1, 1:]),
value=jnp.asarray(value_egm[1, 1:]),
expected_value_zero_savings=value_egm[1, 0],
value_function=value_func,
value_function_args=(state_choice_vec["choice"], params),
Expand Down Expand Up @@ -296,8 +296,8 @@ def test_back_and_forward_scan_wrapper_direction_flag():
endog_grid_to_scan_from=1.2,
policy_to_scan_from=0.7,
endog_grid=1,
value=np.arange(2, 5),
policy=np.arange(1, 4),
value=jnp.arange(2, 5),
policy=jnp.arange(1, 4),
idx_to_scan_from=2,
n_points_to_scan=3,
is_scan_needed=False,
Expand Down
17 changes: 9 additions & 8 deletions tests/utils/interpolation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Tuple

import jax.numpy as jnp
import numpy as np


def interpolate_policy_and_value_on_wealth_grid(
wealth_beginning_of_period: np.ndarray,
endog_wealth_grid: np.ndarray,
policy_grid: np.ndarray,
value_function_grid: np.ndarray,
wealth_beginning_of_period: np.ndarray | jnp.ndarray,
endog_wealth_grid: np.ndarray | jnp.ndarray,
policy_grid: np.ndarray | jnp.ndarray,
value_function_grid: np.ndarray | jnp.ndarray,
):
"""Interpolate policy and value functions on the wealth grid.
Expand Down Expand Up @@ -62,10 +63,10 @@ def interpolate_policy_and_value_on_wealth_grid(


def interpolate_single_policy_and_value_on_wealth_grid(
wealth_beginning_of_period: np.ndarray,
endog_wealth_grid: np.ndarray,
policy_grid: np.ndarray,
value_function_grid: np.ndarray,
wealth_beginning_of_period: np.ndarray | jnp.ndarray,
endog_wealth_grid: np.ndarray | jnp.ndarray,
policy_grid: np.ndarray | jnp.ndarray,
value_function_grid: np.ndarray | jnp.ndarray,
):
"""Interpolate policy and value functions on the wealth grid.
Expand Down

0 comments on commit d9c5e19

Please sign in to comment.