diff --git a/mujoco_py/builder.py b/mujoco_py/builder.py index 5280b2c3..dd672076 100644 --- a/mujoco_py/builder.py +++ b/mujoco_py/builder.py @@ -72,9 +72,12 @@ def load_cython_ext(mujoco_path): Builder = MacExtensionBuilder elif sys.platform == 'linux': _ensure_set_env_var("LD_LIBRARY_PATH", lib_path) + if os.getenv('MUJOCO_PY_FORCE_CPU') is None and get_nvidia_lib_dir() is not None: _ensure_set_env_var("LD_LIBRARY_PATH", get_nvidia_lib_dir()) Builder = LinuxGPUExtensionBuilder + elif os.getenv('MUJOCO_PY_FORCE_GPU') is not None: + Builder = LinuxGPUExtensionBuilder else: Builder = LinuxCPUExtensionBuilder elif sys.platform.startswith("win"):