Skip to content

Commit 50afba7

Browse files
Add attempt to work around the safetensors mmap issue. (#8928)
1 parent 6b8062f commit 50afba7

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

comfy/cli_args.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ class PerformanceFeature(enum.Enum):
144144
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
145145

146146
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
147+
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
147148

148149
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
149150
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")

comfy/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from comfy.cli_args import args
3232

3333
MMAP_TORCH_FILES = args.mmap_torch_files
34+
DISABLE_MMAP = args.disable_mmap
3435

3536
ALWAYS_SAFE_LOAD = False
3637
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
@@ -58,7 +59,10 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
5859
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
5960
sd = {}
6061
for k in f.keys():
61-
sd[k] = f.get_tensor(k)
62+
tensor = f.get_tensor(k)
63+
if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
64+
tensor = tensor.to(device=device, copy=True)
65+
sd[k] = tensor
6266
if return_metadata:
6367
metadata = f.metadata()
6468
except Exception as e:

0 commit comments

Comments
 (0)