Replies: 3 comments
-
Hey, I simplified your example and could not find any errors:

```python
import jax
from pathlib import Path
from flax.training import checkpoints

async_manager = checkpoints.AsyncManager()
for i in range(3):
    bigarray = jax.random.normal(jax.random.PRNGKey(i), (30000, 10000))
    checkpoints.save_checkpoint(
        Path(),  # current working directory
        bigarray,
        step=i,
        prefix="checkpoint_",
        overwrite=True,
        async_manager=async_manager,
    )
```

Can you try to pinpoint the error you are observing?
-
Yeah, this example works. My point is that the wall time stays the same with or without
-
I ran a very basic benchmark to find the impact of using

```python
import jax
from pathlib import Path
from flax.training import checkpoints
from timeit import default_timer as timer
from tempfile import TemporaryDirectory

bigarray = jax.random.normal(jax.random.PRNGKey(0), (30000, 10000)).block_until_ready()
async_manager = checkpoints.AsyncManager()

with TemporaryDirectory() as tmpdir:
    i = 0
    start = timer()
    checkpoints.save_checkpoint(
        Path(tmpdir) / f"checkpoint_{i}",
        bigarray,
        step=i,
        async_manager=async_manager,  # comment out to run without async_manager
    )
    end = timer()
    print(f"checkpoint_{i} saved in {end - start} seconds")
    # Let any background write finish before the temporary directory is deleted.
    async_manager.wait_previous_save()
```

Using
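A single timer around the call only measures how long the save blocks the caller, not how long the write takes overall. The sketch below separates the two using a toy stand-in; it is not flax code. `serialize`, `write`, and `save` are hypothetical functions, and `time.sleep` stands in for real work. In this toy, anything done before the hand-off to the background thread still blocks the caller, so the blocked time only shrinks by the part that actually runs in the background.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from timeit import default_timer as timer


# Hypothetical stand-ins (not flax APIs): `serialize` runs on the caller's
# thread before the hand-off, `write` is the part that may run in the background.
def serialize(data):
    time.sleep(0.2)
    return bytes(data)


def write(blob):
    time.sleep(0.3)
    return len(blob)


def save(data, executor=None):
    """Serialize synchronously; write in the background if an executor is given."""
    blob = serialize(data)
    if executor is None:
        write(blob)
        return None
    return executor.submit(write, blob)


def benchmark(use_async):
    data = [1, 2, 3]
    with ThreadPoolExecutor(max_workers=1) as executor:
        start = timer()
        future = save(data, executor if use_async else None)
        blocked = timer() - start  # time until save() handed control back
        if future is not None:
            future.result()  # wait for the background write to finish
        total = timer() - start  # time until the data is fully written
    return blocked, total


if __name__ == "__main__":
    for use_async in (False, True):
        blocked, total = benchmark(use_async)
        print(f"async={use_async}: blocked {blocked:.2f}s, total {total:.2f}s")
```

With these sleeps, the async run should report a shorter blocked time while the total stays roughly the same.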
-
I tried a few ways to use the `async_manager` feature, but none of them seem to work and my task still blocks as usual. Here is my minimal example. Am I doing it wrong? I could not find any relevant example except the short comment in the source code.
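For reference, one common way to structure an async save manager is a single background worker thread, where each new save first waits for the previous one to finish. The stdlib sketch below illustrates that pattern. `SketchAsyncManager` and `slow_write` are hypothetical; the method names `save_async` and `wait_previous_save` are modeled on flax's `AsyncManager`, but this is not its implementation, so check the source of your installed flax version.

```python
import time
from concurrent.futures import ThreadPoolExecutor


class SketchAsyncManager:
    """Hypothetical sketch, not flax's AsyncManager: one background worker,
    and each new save waits for the previous one to complete."""

    def __init__(self, max_workers=1):
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.save_future = None

    def wait_previous_save(self):
        # Block until the previously submitted save (if any) has finished.
        if self.save_future is not None:
            self.save_future.result()

    def save_async(self, task):
        # Never run two saves at once, so files stay consistent.
        self.wait_previous_save()
        self.save_future = self.executor.submit(task)


def slow_write(log, step):
    time.sleep(0.1)  # stand-in for file I/O
    log.append(step)


if __name__ == "__main__":
    log = []
    manager = SketchAsyncManager()
    for step in range(3):
        manager.save_async(lambda step=step: slow_write(log, step))
        # Control returns here while the write for `step` is still running.
    manager.wait_previous_save()  # flush the last pending write
    manager.executor.shutdown()
    print(log)  # [0, 1, 2]
```

In this pattern only the submitted task runs in the background; whatever the caller does before calling `save_async` still blocks, and a back-to-back save blocks until the previous one completes.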