Skip to content

Commit

Permalink
Merge pull request #17 from shtopane/back-up-2
Browse files Browse the repository at this point in the history
Enable JAX persistent caching via `EconPizzaConfig`
  • Loading branch information
gboehl authored Oct 11, 2024
2 parents 8bf74aa + c804b0c commit 958d700
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 1 deletion.
2 changes: 1 addition & 1 deletion econpizza/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .solvers.solve_linear_state_space import solve_linear_state_space, find_path_linear_state_space
from .solvers.shooting import find_path_shooting
from .parser import parse, load

from .config import config

# set number of cores for XLA
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={os.cpu_count()}"
Expand Down
41 changes: 41 additions & 0 deletions econpizza/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
import jax

class EconPizzaConfig(dict):
def __init__(self, *args, **kwargs):
super(EconPizzaConfig, self).__init__(*args, **kwargs)
self.__dict__ = self
self.enable_jax_persistent_cache = False
self.jax_cache_folder = "__jax_cache__"

self._setup_persistent_cache_map = {
"enable_jax_persistent_cache": self.setup_persistent_cache_jax
}

def __setitem__(self, key, value):
return self.update(key, value)

def update(self, key, value):
"""Updates the attribute, and if it's related to caching, calls the appropriate setup function."""
if hasattr(self, key):
setattr(self, key, value)
if key in self._setup_persistent_cache_map and value:
self._setup_persistent_cache_map[key]()
else:
raise AttributeError(f"'EconPizzaConfig' object has no attribute '{key}'")

def _create_cache_dir(self, folder_name: str):
cwd = os.getcwd()
folder_path = os.path.join(cwd, folder_name)
os.makedirs(folder_path, exist_ok=True)
return folder_path

def setup_persistent_cache_jax(self):
"""Setup JAX persistent cache if enabled."""
if jax.config.jax_compilation_cache_dir is None and not os.path.exists(self.jax_cache_folder):
folder_path_jax = self._create_cache_dir(self.jax_cache_folder)
jax.config.update("jax_compilation_cache_dir", folder_path_jax)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
self.jax_cache_folder = folder_path_jax

config = EconPizzaConfig()
87 changes: 87 additions & 0 deletions econpizza/testing/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Tests for the config module. Delete any __econpizza__ or __jax_cache__ folders you might have in the current folder before running"""
import pytest
import jax
from unittest.mock import patch
import shutil
import os
import sys
# autopep8: off
sys.path.insert(0, os.path.abspath("."))
import econpizza as ep
from econpizza.config import EconPizzaConfig
# autopep8: on

@pytest.fixture(scope="function", autouse=True)
def ep_config_reset():
ep.config = EconPizzaConfig()

@pytest.fixture(scope="function", autouse=True)
def os_getcwd_create():
test_cache_folder = os.path.abspath("config_working_dir")

if not os.path.exists(test_cache_folder):
os.makedirs(test_cache_folder)

with patch("os.getcwd", return_value=test_cache_folder):
yield

if os.path.exists(test_cache_folder):
shutil.rmtree(test_cache_folder)

def test_config_default_values():
assert ep.config["enable_jax_persistent_cache"] == False
assert ep.config.jax_cache_folder == "__jax_cache__"

def test_config_jax_default_values():
assert jax.config.values["jax_compilation_cache_dir"] is None
assert jax.config.values["jax_persistent_cache_min_entry_size_bytes"] == .0
assert jax.config.values["jax_persistent_cache_min_compile_time_secs"] == 1.0

@patch("os.makedirs")
@patch("jax.config.update")
def test_config_enable_jax_persistent_cache(mock_jax_update, mock_makedirs):
ep.config["enable_jax_persistent_cache"] = True
mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True)

mock_jax_update.assert_any_call("jax_compilation_cache_dir", os.path.join(os.getcwd(), "__jax_cache__"))
mock_jax_update.assert_any_call("jax_persistent_cache_min_compile_time_secs", 0)

@patch("os.makedirs")
@patch("jax.config.update")
def test_config_set_jax_folder(mock_jax_update, mock_makedirs):
ep.config.jax_cache_folder = "test1"
ep.config["enable_jax_persistent_cache"] = True
mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "test1"), exist_ok=True)
mock_jax_update.assert_any_call("jax_compilation_cache_dir", os.path.join(os.getcwd(), "test1"))

@patch("jax.config.update")
def test_config_jax_folder_set_from_outside(mock_jax_update):
mock_jax_update("jax_compilation_cache_dir", "jax_from_outside")
ep.config["enable_jax_persistent_cache"] = True
mock_jax_update.assert_any_call("jax_compilation_cache_dir", "jax_from_outside")

@patch("os.path.exists")
@patch("os.makedirs")
@patch("jax.config.update")
def test_jax_cache_folder_not_created_second_time(mock_jax_update, mock_makedirs, mock_exists):
# Set to first return False when the folder is not created, then True when the folder is created
mock_exists.side_effect = [False, True]

# When called for the first time, a cache folder should be created(default is __jax_cache__)
ep.config["enable_jax_persistent_cache"] = True
mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True)
assert mock_jax_update.call_count == 2
# Now reset the mock so that the calls are 0 again.
mock_makedirs.reset_mock()
mock_jax_update.reset_mock()
# The second time we should not create a folder
ep.config["enable_jax_persistent_cache"] = True
mock_makedirs.assert_not_called()
assert mock_jax_update.call_count == 0

def test_config_enable_jax_persistent_cache_called_after_model_load():
_ = ep.load(ep.examples.dsge)

assert os.path.exists(ep.config.jax_cache_folder) == False
ep.config["enable_jax_persistent_cache"] = True
assert os.path.exists(ep.config.jax_cache_folder) == True

0 comments on commit 958d700

Please sign in to comment.