I encountered an OOM error when running `inference_treebench.py` on a multi-GPU machine.
The script uses multiprocessing with `device_map="auto"`, but it does not assign a specific GPU to each worker process. As a result, all spawned processes attempt to load the model onto the same GPU (usually GPU 0), causing OOM.
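A minimal sketch of one possible fix: pin each worker to a single GPU by setting `CUDA_VISIBLE_DEVICES` inside the worker *before* any CUDA library is imported, so `device_map="auto"` only sees that device. The worker function, round-robin mapping, and the commented-out model load are illustrative assumptions, not the actual code in `inference_treebench.py`:

```python
import os
import multiprocessing as mp

def assign_gpus(num_workers: int, num_gpus: int) -> dict:
    # Round-robin mapping of worker rank -> GPU index.
    return {rank: rank % num_gpus for rank in range(num_workers)}

def worker(rank: int, num_gpus: int) -> str:
    # Restrict this process to one GPU before importing torch/transformers,
    # so device_map="auto" places the whole model on that device.
    gpu = rank % num_gpus
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)
    # Hypothetical model load (done here so the env var takes effect):
    # from transformers import AutoModelForCausalLM
    # model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    return f"worker {rank} -> GPU {gpu}"

if __name__ == "__main__":
    num_gpus = 2  # assumption: 2 visible GPUs
    with mp.get_context("spawn").Pool(4) as pool:
        for line in pool.starmap(worker, [(r, num_gpus) for r in range(4)]):
            print(line)
```

The key point is that `CUDA_VISIBLE_DEVICES` must be set in the child process before CUDA is initialized there; with the `spawn` start method each worker gets a fresh interpreter, so the assignment takes effect.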