Skip to content

fix: replace hardcoded .cuda() with dynamic device selection#553

Open
Mr-Neutr0n wants to merge 1 commit intozai-org:mainfrom
Mr-Neutr0n:fix/hardcoded-cuda-device
Open

fix: replace hardcoded .cuda() with dynamic device selection#553
Mr-Neutr0n wants to merge 1 commit intozai-org:mainfrom
Mr-Neutr0n:fix/hardcoded-cuda-device

Conversation

@Mr-Neutr0n
Copy link
Copy Markdown

Bug

The demo scripts basic_demo/cli_demo_sat.py and basic_demo/web_demo.py hardcode .cuda() calls and device='cuda', which causes failures on non-CUDA systems (e.g., Apple Silicon with MPS, or CPU-only machines).

Fix

  • Added a _get_device() helper that detects the best available device (CUDA > MPS > CPU).
  • Replaced hardcoded device='cuda' with device=_get_device() in model loading.
  • Replaced model.cuda() with model.to(device) in quantization blocks.
  • Broadened the device guard from torch.cuda.is_available() to device != 'cpu' so MPS users also benefit from GPU acceleration when quantization is enabled.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant