diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5f10688..a15983b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,12 @@ --- repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.2.0 + rev: v5.0.0 hooks: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/pre-commit/mirrors-autopep8 - rev: v1.6.0 + rev: 'v2.0.4' hooks: - id: autopep8 diff --git a/econpizza/__init__.py b/econpizza/__init__.py index a334be3..024a6b5 100644 --- a/econpizza/__init__.py +++ b/econpizza/__init__.py @@ -17,9 +17,7 @@ from .parser import parse, load -# set number of cores for XLA -os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={os.cpu_count()}" - +# setting precision is necessary jax.config.update("jax_enable_x64", True) # create local alias @@ -57,13 +55,14 @@ def get_distributions(self, trajectory, init_dist=None, shock=None, pars=None): a dictionary of the distributions """ + # get transformers (experimental) + transform_back = self['options'].get('transform_back') or (lambda x: x) + dist0 = jnp.array(init_dist) if init_dist is not None else jnp.array( self['steady_state'].get('distributions')) - if self.get('exp_all'): - pars = jnp.log(jnp.array(list(self['pars'].values())) if pars is None else pars) - trajectory = jnp.log(trajectory) - else: - pars = jnp.array(list(self['pars'].values())) if pars is None else pars + pars = transform_back( + jnp.array(list(self['pars'].values())) if pars is None else pars) + trajectory = transform_back(trajectory) shocks = self.get("shocks") or () dist_names = list(self['distributions'].keys()) decisions_inputs = self['decisions']['inputs'] @@ -81,13 +80,17 @@ def get_distributions(self, trajectory, init_dist=None, shock=None, pars=None): # get functions and execute backwards_sweep = self['context']['backwards_sweep'] forwards_sweep = self['context']['forwards_sweep'] - wf_storage, decisions_output_storage = backwards_sweep(x, x0, shock_series.T, pars, return_wf=True) + wf_storage, decisions_output_storage = backwards_sweep( + x, x0, shock_series.T, pars, return_wf=True) dists_storage = forwards_sweep(decisions_output_storage, dist0) # store this - rdict = {oput: wf_storage[i] for i, oput in enumerate(decisions_inputs)} - rdict.update({oput: decisions_output_storage[i] for i, oput in enumerate(decisions_outputs)}) - rdict.update({oput: dists_storage[i] for i, oput in enumerate(dist_names)}) + rdict = {oput: wf_storage[i] + for i, oput in enumerate(decisions_inputs)} + rdict.update( + {oput: decisions_output_storage[i] for i, oput in enumerate(decisions_outputs)}) + rdict.update({oput: dists_storage[i] + for i, oput in enumerate(dist_names)}) return rdict diff --git a/econpizza/__version__.py b/econpizza/__version__.py index 4671d4d..224fbbe 100644 --- a/econpizza/__version__.py +++ b/econpizza/__version__.py @@ -1,3 +1,3 @@ # -*- coding: utf-8 -*- -__version__ = '0.6.4' +__version__ = '0.6.5' diff --git a/econpizza/parser/__init__.py b/econpizza/parser/__init__.py index 196f711..7c0e3d5 100644 --- a/econpizza/parser/__init__.py +++ b/econpizza/parser/__init__.py @@ -14,7 +14,6 @@ import importlib.util as iu from copy import deepcopy, copy from inspect import getmembers, isfunction -from jax.experimental.host_callback import id_print as jax_print from .compile_model_functions import * from .checks import * from ..utilities import grids, dists, interp @@ -189,6 +188,16 @@ def _define_function(func_str, context): return tmpf.name +def wrap_with_transform(func, transform): + if transform: + def outfunc(XLag, X, XPrime, XSS, pars, *args, **kwargs): + Xandpar = (transform(y) for y in (XLag, X, XPrime, XSS, pars)) + return func(*Xandpar, *args, **kwargs) + return outfunc + else: + return func + + def _get_pre_stst_mapping(init_vals, fixed_values, evars, par_names): """Define the mapping from init & fixed vals to model variables & parameters """ @@ -296,7 +305,8 @@ def load( # compile globals & definitions _ = _define_subdict_if_absent(model, "globals") - _, model['context'] = _eval_strs(model['globals'], context=model['context']) + _, model['context'] = _eval_strs( + model['globals'], context=model['context']) defs = model.get("definitions") defs = '' if defs is None else defs defs = '\n'.join(defs) if isinstance(defs, list) else defs @@ -320,7 +330,8 @@ def load( _ = _define_subdict_if_absent(model, "func_strings") _ = _define_subdict_if_absent(model, "steady_state") pars = _define_subdict_if_absent(model, "parameters") - par_names = model["par_names"] = [*pars] if isinstance(pars, dict) else pars + par_names = model["par_names"] = [ + *pars] if isinstance(pars, dict) else pars if 'lambda' in evars + par_names: raise NameError( "Variables or parameters must not use the name of python's build-in functions \"lambda\".") @@ -328,6 +339,11 @@ def load( raise TypeError( f'parameters must be a list and not {type(par_names)}.') + # get and evaluate options + _ = _define_subdict_if_absent(model, "options") + model['options'], _ = _eval_strs( + model['options'], context=model['context']) + # get function strings for decisions and distributions, if they exist if model.get('decisions'): decisions_outputs = model['decisions']['outputs'] @@ -336,10 +352,8 @@ def load( evars, par_names, shocks, decisions_inputs, decisions_outputs, model['decisions']['calls']) _define_function(model['func_strings'] ['func_backw'], model['context']) - if model.get('exp_all'): - model['context']['func_backw'] = lambda xl,xc,xp,XSS,WFPrime,shocks,pars: model['context']['func_backw_raw'](jnp.exp(xl), jnp.exp(xc), jnp.exp(xp), jnp.exp(XSS), WFPrime, shocks, jnp.exp(pars)) - else: - model['context']['func_backw'] = model['context']['func_backw_raw'] + model['context']['func_backw'] = wrap_with_transform( + model['context']['func_backw_raw'], model['options'].get('transform_to')) else: decisions_outputs = [] decisions_inputs = [] @@ -357,10 +371,8 @@ def load( # writing to tempfiles helps to get nice debug traces if the model does not work _define_function(model['func_strings']['func_eqns'], model['context']) - if model.get('exp_all'): - model['context']['func_eqns'] = lambda xl,xc,xp,XSS,shocks,pars,distributions,decisions_outputs: model['context']['func_eqns_raw'](jnp.exp(xl), jnp.exp(xc), jnp.exp(xp), jnp.exp(XSS), shocks, jnp.exp(pars), distributions, decisions_outputs) - else: - model['context']['func_eqns'] = model['context']['func_eqns_raw'] + model['context']['func_eqns'] = wrap_with_transform( + model['context']['func_eqns_raw'], model['options'].get('transform_to')) # compile fixed and initial values stst_inputs = compile_stst_inputs(model) # try if function works on initvals diff --git a/econpizza/parser/build_generic_functions.py b/econpizza/parser/build_generic_functions.py index ad61eb9..9839213 100644 --- a/econpizza/parser/build_generic_functions.py +++ b/econpizza/parser/build_generic_functions.py @@ -132,7 +132,8 @@ def get_stst_derivatives(model, nvars, pars, stst, x_stst, zshocks, horizon, ver combined_sweep = model['context']['combined_sweep'] distSS = jnp.array(model['steady_state']['distributions']) - decisions_outputSS = (jnp.array(d)[..., None] for d in list(model['steady_state']['decisions'].values())) + decisions_outputSS = (jnp.array(d)[..., None] for d in list( + model['steady_state']['decisions'].values())) # basis for steady state jacobian construction basis = jnp.zeros((nvars*(horizon-1), nvars)) @@ -148,7 +149,7 @@ def get_stst_derivatives(model, nvars, pars, stst, x_stst, zshocks, horizon, ver # get steady state jacobians for direct effects x on f jacrev_func_eqns = jax.jacrev(func_eqns, argnums=(0, 1, 2)) f2X = jacrev_func_eqns(stst[:, None], stst[:, None], stst[:, None], - stst, zshocks[:, 0], pars, distSS[..., None], decisions_outputSS) + stst, pars, zshocks[:, 0], distSS[..., None], decisions_outputSS) if verbose: duration = time.time() - st @@ -174,7 +175,8 @@ def get_stacked_func_het_agents(func_backw, func_forw, func_eqns, stst, wfSS, ho forwards_sweep, horizon=horizon, func_forw=partial_forw) final_step_local = jax.tree_util.Partial( final_step, stst=stst, horizon=horizon, nshpe=nshpe, func_eqns=func_eqns) - combined_sweep_local = jax.tree_util.Partial(combined_sweep, forwards_sweep=forwards_sweep_local, final_step=final_step_local) + combined_sweep_local = jax.tree_util.Partial( + combined_sweep, forwards_sweep=forwards_sweep_local, final_step=final_step_local) stacked_func_local = jax.tree_util.Partial( stacked_func_het_agents, backwards_sweep=backwards_sweep_local, combined_sweep=combined_sweep_local) diff --git a/econpizza/parser/checks.py b/econpizza/parser/checks.py index 9e1b5d1..6b92ad3 100644 --- a/econpizza/parser/checks.py +++ b/econpizza/parser/checks.py @@ -71,7 +71,8 @@ def check_initial_values(model, shocks, init_guesses, fixed_values, init_wf, pre mess = '' if model.get('decisions'): # make a test backward and forward run - _, decisions_output_init = model['context']['func_backw'](init_vals, init_vals, init_vals, init_vals, init_wf, jnp.zeros(len(shocks)), par) + _, decisions_output_init = model['context']['func_backw']( + init_vals, init_vals, init_vals, init_vals, par, init_wf, jnp.zeros(len(shocks))) dists_init, _ = model['context']['func_forw_stst']( decisions_output_init, 1e-8, 1) @@ -91,8 +92,8 @@ def check_initial_values(model, shocks, init_guesses, fixed_values, init_wf, pre # final test of main function init_vals = init_vals[..., None] - test = model['context']['func_eqns'](init_vals, init_vals, init_vals, init_vals, jnp.zeros( - len(shocks)), par, jnp.array(dists_init)[..., None], (doi[...,None] for doi in decisions_output_init)) + test = model['context']['func_eqns'](init_vals, init_vals, init_vals, init_vals, par, jnp.zeros( + len(shocks)), jnp.array(dists_init)[..., None], (doi[..., None] for doi in decisions_output_init)) if mess: pass diff --git a/econpizza/parser/compile_model_functions.py b/econpizza/parser/compile_model_functions.py index e834432..f7e0750 100644 --- a/econpizza/parser/compile_model_functions.py +++ b/econpizza/parser/compile_model_functions.py @@ -25,7 +25,7 @@ def compile_backw_func_str(evars, par, shocks, inputs, outputs, calls): if isinstance(calls, str): calls = calls.splitlines() - func_str = f"""def func_backw_raw(XLag, X, XPrime, XSS, WFPrime, shocks, pars): + func_str = f"""def func_backw_raw(XLag, X, XPrime, XSS, pars, WFPrime, shocks): {compile_func_basics_str(evars, par, shocks)} \n ({"".join(v + ", " for v in inputs)}) = WFPrime \n %s @@ -49,8 +49,10 @@ def get_forw_funcs(model): dist = distributions[dist_name] # *_generic should be depreciated at some point - implemented_endo = ('exogenous', 'exogenous_rouwenhorst', 'exogenous_generic', 'exogenous_custom') - implemented_exo = ('endogenous', 'endogenous_log', 'endogenous_generic', 'endogenous_custom') + implemented_endo = ('exogenous', 'exogenous_rouwenhorst', + 'exogenous_generic', 'exogenous_custom') + implemented_exo = ('endogenous', 'endogenous_log', + 'endogenous_generic', 'endogenous_custom') exog = [v for v in dist if dist[v]['type'] in implemented_endo] endo = [v for v in dist if dist[v]['type'] in implemented_exo] other = [dist[v]['type'] for v in dist if dist[v] @@ -68,13 +70,15 @@ def get_forw_funcs(model): # for each object, check if it is provided in decisions_outputs try: - transition = model['decisions']['outputs'].index(dist[exog[0]]['transition_name']) + transition = model['decisions']['outputs'].index( + dist[exog[0]]['transition_name']) except ValueError: transition = model['context'][dist[exog[0]]['transition_name']] grids = [] for i in endo: try: - grids.append(model['decisions']['outputs'].index(dist[i]['grid_name'])) + grids.append(model['decisions'] + ['outputs'].index(dist[i]['grid_name'])) except ValueError: grids.append(model['context'][dist[i]['grid_name']]) indices = [model['decisions']['outputs'].index(i) for i in endo] @@ -108,7 +112,7 @@ def compile_eqn_func_str(evars, eqns, par, eqns_aux, shocks, distributions, deci eqns_stack = "\n ".join(eqns) # compile the final function string - func_str = f"""def func_eqns_raw(XLag, X, XPrime, XSS, shocks, pars, distributions=[], decisions_outputs=[]): + func_str = f"""def func_eqns_raw(XLag, X, XPrime, XSS, pars, shocks, distributions=[], decisions_outputs=[]): {compile_func_basics_str(evars, par, shocks)} \n ({"".join(d+', ' for d in distributions)}) = distributions \n ({"".join(d+', ' for d in decisions_outputs)}) = decisions_outputs diff --git a/econpizza/parser/het_agent_base_funcs.py b/econpizza/parser/het_agent_base_funcs.py index d0f3771..aacd614 100644 --- a/econpizza/parser/het_agent_base_funcs.py +++ b/econpizza/parser/het_agent_base_funcs.py @@ -16,7 +16,7 @@ def _backwards_stst_cond(carry): def _backwards_stst_body(carry): (x, par), (wf, _), (_, cnt), (func, tol, maxit) = carry - return (x, par), func(x, x, x, x, wf, pars=par), (wf, cnt + 1), (func, tol, maxit) + return (x, par), func(x, x, x, x, pars=par, WFPrime=wf), (wf, cnt + 1), (func, tol, maxit) def backwards_sweep_stst(x, par, carry): @@ -29,7 +29,7 @@ def _backwards_step(carry, i): wf, X, shocks, func_backw, stst, pars = carry wf, decisions_output = func_backw( - X[:, i], X[:, i+1], X[:, i+2], WFPrime=wf, shocks=shocks[:, i], pars=pars) + X[:, i], X[:, i+1], X[:, i+2], pars=pars, WFPrime=wf, shocks=shocks[:, i]) return (wf, X, shocks, func_backw, stst, pars), (wf, decisions_output) @@ -40,7 +40,8 @@ def backwards_sweep(x: Array, x0: Array, shocks: Array, pars: Array, stst: Array _, (wf_storage, decisions_output_storage) = jax.lax.scan( _backwards_step, (wfSS, X, shocks, func_backw, stst, pars), jnp.arange(horizon-1), reverse=True) - decisions_output_storage = [jnp.moveaxis(dos, 0, -1) for dos in decisions_output_storage] + decisions_output_storage = [jnp.moveaxis( + dos, 0, -1) for dos in decisions_output_storage] wf_storage = jnp.moveaxis(wf_storage, 0, -1) if return_wf: @@ -51,7 +52,8 @@ def backwards_sweep(x: Array, x0: Array, shocks: Array, pars: Array, stst: Array def _forwards_step(carry, i): dist_old, decisions_output_storage, func_forw = carry - dist = func_forw(dist_old, [dos[..., i] for dos in decisions_output_storage]) + dist = func_forw(dist_old, [dos[..., i] + for dos in decisions_output_storage]) return (dist, decisions_output_storage, func_forw), dist_old @@ -69,7 +71,7 @@ def final_step(x: Array, dists_storage: Array, decisions_output_storage: Array, X = jnp.hstack((x0, x, stst)).reshape(horizon+1, -1).T out = func_eqns(X[:, :-2].reshape(nshpe), X[:, 1:-1].reshape(nshpe), X[:, 2:].reshape( - nshpe), stst, shocks, pars, dists_storage, decisions_output_storage) + nshpe), stst, pars, shocks, dists_storage, decisions_output_storage) return out diff --git a/econpizza/solvers/shooting.py b/econpizza/solvers/shooting.py index 07a0d7f..9a41d7b 100644 --- a/econpizza/solvers/shooting.py +++ b/econpizza/solvers/shooting.py @@ -101,8 +101,7 @@ def solve_current(pars, shock, XLag, XLastGuess, XPrime): """Solves for one period. """ - # partial_func = jax.tree_util.Partial(func, XLag=XLag, XPrime=XPrime, XSS=stst, shocks=shock, pars=pars) - def partial_func(x): return func(XLag, x, XPrime, stst, shock, pars) + def partial_func(x): return func(XLag, x, XPrime, stst, pars, shock) jav = val_and_jacfwd(partial_func) partial_jav = jax.tree_util.Partial(jav) res = newton_jax_jit(partial_jav, XLastGuess, verbose=False) diff --git a/econpizza/solvers/stacking.py b/econpizza/solvers/stacking.py index 44ad46a..3a68c56 100644 --- a/econpizza/solvers/stacking.py +++ b/econpizza/solvers/stacking.py @@ -72,22 +72,19 @@ def find_path_stacking( # only skip jacobian calculation if it exists skip_jacobian = skip_jacobian if self['cache'].get( 'jac_factorized') else False + # get transformers (experimental) + transform_forw = self['options'].get('transform_to') or (lambda x: x) + transform_back = self['options'].get('transform_back') or (lambda x: x) # get variables nvars = len(self["var_names"]) - if self.get('exp_all'): - stst = jnp.log(d2jnp(self["stst"])) - pars = jnp.log(d2jnp(pars if pars is not None else self["pars"])) - else: - stst = d2jnp(self["stst"]) - pars = d2jnp(pars if pars is not None else self["pars"]) + stst = transform_back(d2jnp(self["stst"])) + pars = transform_back(d2jnp(pars if pars is not None else self["pars"])) shocks = self.get("shocks") or () # get initial guess - if self.get('exp_all'): - x0 = jnp.log(jnp.array(list(init_state))) if init_state is not None else stst - else: - x0 = jnp.array(list(init_state)) if init_state is not None else stst + x0 = transform_back(jnp.array(list(init_state)) + ) if init_state is not None else stst init_dist = init_dist if init_dist is not None else self['steady_state'].get( 'distributions') dist0 = jnp.array(init_dist if init_dist is not None else jnp.nan) @@ -110,7 +107,7 @@ def find_path_stacking( func_eqns = self['context']["func_eqns"] jav_func_eqns = val_and_jacrev(func_eqns, (0, 1, 2)) jav_func_eqns_partial = jax.tree_util.Partial( - jav_func_eqns, XSS=stst, pars=pars, distributions=[], decisions_outputs=[]) + jav_func_eqns, pars=pars, XSS=stst, distributions=[], decisions_outputs=[]) self['context']['jav_func'] = jav_func_eqns_partial # mark as compiled write_cache(self, horizon, pars, stst) @@ -161,7 +158,4 @@ def find_path_stacking( elif verbose: print(mess) - if self.get('exp_all'): - return jnp.exp(x_out), (flag, f) - else: - return x_out, (flag, f) + return transform_forw(x_out), (flag, f) diff --git a/econpizza/solvers/steady_state.py b/econpizza/solvers/steady_state.py index fa51f0f..5e74769 100644 --- a/econpizza/solvers/steady_state.py +++ b/econpizza/solvers/steady_state.py @@ -99,30 +99,26 @@ def solve_stst(self, tol=1e-8, maxit=15, tol_backwards=None, maxit_backwards=200 func_backw = self['context'].get('func_backw') func_forw_stst = self['context'].get('func_forw_stst') func_pre_stst = self['context']['func_pre_stst'] + # get transformers (experimental) + transform_forw = self['options'].get('transform_to') or (lambda x: x) + transform_back = self['options'].get('transform_back') or (lambda x: x) # get initial values for heterogenous agents decisions_output_init = self['context']['init_run'].get( 'decisions_output') # get the actual steady state function - if self.get('exp_all'): - func_stst = get_func_stst(func_backw, func_forw_stst, func_eqns, shocks, wf_init, decisions_output_init, fixed_values=jnp.log(d2jnp(fixed_vals)), pre_stst_mapping=pre_stst_mapping, tol_backw=tol_backwards, maxit_backw=maxit_backwards, tol_forw=tol_forwards, maxit_forw=maxit_forwards) - else: - func_stst = get_func_stst(func_backw, func_forw_stst, func_eqns, shocks, wf_init, decisions_output_init, fixed_values=d2jnp(fixed_vals), pre_stst_mapping=pre_stst_mapping, tol_backw=tol_backwards, maxit_backw=maxit_backwards, tol_forw=tol_forwards, maxit_forw=maxit_forwards) + func_stst = get_func_stst(func_backw, func_forw_stst, func_eqns, shocks, wf_init, decisions_output_init, fixed_values=transform_back(d2jnp( + fixed_vals)), pre_stst_mapping=pre_stst_mapping, tol_backw=tol_backwards, maxit_backw=maxit_backwards, tol_forw=tol_forwards, maxit_forw=maxit_forwards) # store jitted stst function that returns jacobian and func. value self["context"]['func_stst'] = func_stst if not self['steady_state'].get('skip'): # actual root finding - if self.get('exp_all'): - res = newton_jax(func_stst, jnp.log(d2jnp(init_vals)), maxit, tol, solver=solver, verbose=verbose, **newton_kwargs) - else: - res = newton_jax(func_stst, d2jnp(init_vals), maxit, tol, solver=solver, verbose=verbose, **newton_kwargs) + res = newton_jax(func_stst, transform_back( + d2jnp(init_vals)), maxit, tol, solver=solver, verbose=verbose, **newton_kwargs) else: - if self.get('exp_all'): - f, jac, aux = func_stst(jnp.log(d2jnp(init_vals))) - else: - f, jac, aux = func_stst(d2jnp(init_vals)) + f, jac, aux = func_stst(transform_back(d2jnp(init_vals))) res = {'x': d2jnp(init_vals), 'fun': f, 'jac': jac, @@ -132,22 +128,18 @@ def solve_stst(self, tol=1e-8, maxit=15, tol_backwards=None, maxit_backwards=200 } # exchange those values that are identified via stst_equations - if self.get('exp_all'): - stst_vals, par_vals = func_pre_stst(res['x'], jnp.log(d2jnp(fixed_vals)), pre_stst_mapping) - else: - stst_vals, par_vals = func_pre_stst(res['x'], d2jnp(fixed_vals), pre_stst_mapping) - res['initial_values'] = {'guesses': init_vals, 'fixed': fixed_vals, 'value_functions': wf_init, 'decisions': decisions_output_init} + stst_vals, par_vals = func_pre_stst( + res['x'], transform_back(d2jnp(fixed_vals)), pre_stst_mapping) + res['initial_values'] = {'guesses': init_vals, 'fixed': fixed_vals, + 'value_functions': wf_init, 'decisions': decisions_output_init} # store results self['steady_state']['root_finding_result'] = res - if self.get('exp_all'): - self['steady_state']['found_values'] = dict(zip(init_vals.keys(),jnp.exp(res['x']))) - self['stst'] = self['steady_state']['all_values'] = dict(zip(evars, jnp.exp(stst_vals))) - self['pars'] = dict(zip(par_names, jnp.exp(par_vals))) - else: - self['steady_state']['found_values'] = dict(zip(init_vals.keys(),res['x'])) - self['stst'] = self['steady_state']['all_values'] = dict(zip(evars, stst_vals)) - self['pars'] = dict(zip(par_names, par_vals)) + self['steady_state']['found_values'] = dict( + zip(init_vals.keys(), transform_forw(res['x']))) + self['stst'] = self['steady_state']['all_values'] = dict( + zip(evars, transform_forw(stst_vals))) + self['pars'] = dict(zip(par_names, transform_forw(par_vals))) # calculate dist objects and compile message mess = _get_stst_dist_objs(self, res, maxit_backwards, diff --git a/econpizza/testing/test_config.py b/econpizza/testing/test_config.py deleted file mode 100644 index d53496b..0000000 --- a/econpizza/testing/test_config.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Tests for the config module. Delete any __econpizza__ or __jax_cache__ folders you might have in the current folder before running""" -import pytest -import jax -from unittest.mock import patch -import shutil -import os -import sys -# autopep8: off -sys.path.insert(0, os.path.abspath(".")) -import econpizza as ep -from econpizza.config import EconPizzaConfig -# autopep8: on - -@pytest.fixture(scope="function", autouse=True) -def ep_config_reset(): - ep.config = EconPizzaConfig() - -@pytest.fixture(scope="function", autouse=True) -def os_getcwd_create(): - test_cache_folder = os.path.abspath("config_working_dir") - - if not os.path.exists(test_cache_folder): - os.makedirs(test_cache_folder) - - with patch("os.getcwd", return_value=test_cache_folder): - yield - - if os.path.exists(test_cache_folder): - shutil.rmtree(test_cache_folder) - -def test_config_default_values(): - assert ep.config["enable_jax_persistent_cache"] == False - assert ep.config.jax_cache_folder == "__jax_cache__" - -def test_config_jax_default_values(): - assert jax.config.values["jax_compilation_cache_dir"] is None - assert jax.config.values["jax_persistent_cache_min_entry_size_bytes"] == .0 - assert jax.config.values["jax_persistent_cache_min_compile_time_secs"] == 1.0 - -@patch("os.makedirs") -@patch("jax.config.update") -def test_config_enable_jax_persistent_cache(mock_jax_update, mock_makedirs): - ep.config["enable_jax_persistent_cache"] = True - mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True) - - mock_jax_update.assert_any_call("jax_compilation_cache_dir", os.path.join(os.getcwd(), "__jax_cache__")) - mock_jax_update.assert_any_call("jax_persistent_cache_min_compile_time_secs", 0) - -@patch("os.makedirs") -@patch("jax.config.update") -def test_config_set_jax_folder(mock_jax_update, mock_makedirs): - ep.config.jax_cache_folder = "test1" - ep.config["enable_jax_persistent_cache"] = True - mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "test1"), exist_ok=True) - mock_jax_update.assert_any_call("jax_compilation_cache_dir", os.path.join(os.getcwd(), "test1")) - -@patch("jax.config.update") -def test_config_jax_folder_set_from_outside(mock_jax_update): - mock_jax_update("jax_compilation_cache_dir", "jax_from_outside") - ep.config["enable_jax_persistent_cache"] = True - mock_jax_update.assert_any_call("jax_compilation_cache_dir", "jax_from_outside") - -@patch("os.path.exists") -@patch("os.makedirs") -@patch("jax.config.update") -def test_jax_cache_folder_not_created_second_time(mock_jax_update, mock_makedirs, mock_exists): - # Set to first return False when the folder is not created, then True when the folder is created - mock_exists.side_effect = [False, True] - - # When called for the first time, a cache folder should be created(default is __jax_cache__) - ep.config["enable_jax_persistent_cache"] = True - mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True) - assert mock_jax_update.call_count == 2 - # Now reset the mock so that the calls are 0 again. - mock_makedirs.reset_mock() - mock_jax_update.reset_mock() - # The second time we should not create a folder - ep.config["enable_jax_persistent_cache"] = True - mock_makedirs.assert_not_called() - assert mock_jax_update.call_count == 0 - -def test_config_enable_jax_persistent_cache_called_after_model_load(): - _ = ep.load(ep.examples.dsge) - - assert os.path.exists(ep.config.jax_cache_folder) == False - ep.config["enable_jax_persistent_cache"] = True - assert os.path.exists(ep.config.jax_cache_folder) == True diff --git a/econpizza/testing/test_rest_config.py b/econpizza/testing/test_rest_config.py new file mode 100644 index 0000000..8d51388 --- /dev/null +++ b/econpizza/testing/test_rest_config.py @@ -0,0 +1,103 @@ +"""Tests for the config module. Delete any __econpizza__ or __jax_cache__ folders you might have in the current folder before running""" +import pytest +import jax +from unittest.mock import patch +import shutil +import os +import sys +# autopep8: off +sys.path.insert(0, os.path.abspath(".")) +import econpizza as ep +from econpizza.config import EconPizzaConfig +# autopep8: on + + +@pytest.fixture(scope="function", autouse=True) +def ep_config_reset(): + ep.config = EconPizzaConfig() + + +@pytest.fixture(scope="function", autouse=True) +def os_getcwd_create(): + test_cache_folder = os.path.abspath("config_working_dir") + + if not os.path.exists(test_cache_folder): + os.makedirs(test_cache_folder) + + with patch("os.getcwd", return_value=test_cache_folder): + yield + + if os.path.exists(test_cache_folder): + shutil.rmtree(test_cache_folder) + + +def test_config_default_values(): + assert ep.config["enable_jax_persistent_cache"] == False + assert ep.config.jax_cache_folder == "__jax_cache__" + + +def test_config_jax_default_values(): + assert jax.config.values["jax_compilation_cache_dir"] is None + assert jax.config.values["jax_persistent_cache_min_entry_size_bytes"] == .0 + assert jax.config.values["jax_persistent_cache_min_compile_time_secs"] == 1.0 + + +@patch("os.makedirs") +@patch("jax.config.update") +def test_config_enable_jax_persistent_cache(mock_jax_update, mock_makedirs): + ep.config["enable_jax_persistent_cache"] = True + mock_makedirs.assert_any_call(os.path.join( + os.getcwd(), "__jax_cache__"), exist_ok=True) + + mock_jax_update.assert_any_call( + "jax_compilation_cache_dir", os.path.join(os.getcwd(), "__jax_cache__")) + mock_jax_update.assert_any_call( + "jax_persistent_cache_min_compile_time_secs", 0) + + +@patch("os.makedirs") +@patch("jax.config.update") +def test_config_set_jax_folder(mock_jax_update, mock_makedirs): + ep.config.jax_cache_folder = "test1" + ep.config["enable_jax_persistent_cache"] = True + mock_makedirs.assert_any_call(os.path.join( + os.getcwd(), "test1"), exist_ok=True) + mock_jax_update.assert_any_call( + "jax_compilation_cache_dir", os.path.join(os.getcwd(), "test1")) + + +@patch("jax.config.update") +def test_config_jax_folder_set_from_outside(mock_jax_update): + mock_jax_update("jax_compilation_cache_dir", "jax_from_outside") + ep.config["enable_jax_persistent_cache"] = True + mock_jax_update.assert_any_call( + "jax_compilation_cache_dir", "jax_from_outside") + + +@patch("os.path.exists") +@patch("os.makedirs") +@patch("jax.config.update") +def test_jax_cache_folder_not_created_second_time(mock_jax_update, mock_makedirs, mock_exists): + # Set to first return False when the folder is not created, then True when the folder is created + mock_exists.side_effect = [False, True] + + # When called for the first time, a cache folder should be created(default is __jax_cache__) + ep.config["enable_jax_persistent_cache"] = True + mock_makedirs.assert_any_call(os.path.join( + os.getcwd(), "__jax_cache__"), exist_ok=True) + assert mock_jax_update.call_count == 2 + # Now reset the mock so that the calls are 0 again. + mock_makedirs.reset_mock() + mock_jax_update.reset_mock() + # The second time we should not create a folder + ep.config["enable_jax_persistent_cache"] = True + mock_makedirs.assert_not_called() + assert mock_jax_update.call_count == 0 + + +def test_config_enable_jax_persistent_cache_called_after_model_load(): + _ = ep.load(ep.examples.dsge) + + assert os.path.exists(ep.config.jax_cache_folder) == False + ep.config["enable_jax_persistent_cache"] = True + assert os.path.exists(ep.config.jax_cache_folder) == True