flax.metrics.tensorboard takes up most of the GPU memory #2379
Hi, I found an issue when I put the following at the first lines of my code:

    import os
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = 'false'

I did not find this issue when I use

Update: the minimal code that reproduces the issue:

    from flax.metrics.tensorboard import SummaryWriter

    writer = SummaryWriter("tmp")   # write event files to the "tmp" directory
    writer.scalar("key", 1.0, 100)  # log the value 1.0 under tag "key" at step 100
    import ipdb; ipdb.set_trace()   # GPU memory will be almost fully occupied at this point
    writer.flush()                  # write buffered events to disk

My installed packages:

Do you know the reason?
Replies: 1 comment 3 replies
This should not happen, but it is hard to debug without any code. Could you provide us with a minimal example in a Colab where the problem occurs? Generally, you should follow the best practices from our examples: use jax.device_get to transfer the metrics from the device to the host, and use flush() to flush the summary writer.