Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cuda_core/cuda/core/experimental/_event.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Optional
from cuda.core.experimental._context import Context
from cuda.core.experimental._utils.cuda_utils import (
CUDAError,
_check_multiprocessing_start_method,
driver,
)
if TYPE_CHECKING:
Expand Down Expand Up @@ -300,6 +301,7 @@ cdef class IPCEventDescriptor:


def _reduce_event(event):
    """Pickle support: reduce an Event to its IPC descriptor for transport."""
    # Alert the user if the process is configured to fork, which CUDA does
    # not support for IPC objects.
    _check_multiprocessing_start_method()
    rebuild = event.from_ipc_descriptor
    state = (event.get_ipc_descriptor(),)
    return rebuild, state


multiprocessing.reduction.register(Event, _reduce_event)
3 changes: 3 additions & 0 deletions cuda_core/cuda/core/experimental/_memory/_ipc.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ from cuda.bindings cimport cydriver
from cuda.core.experimental._memory._buffer cimport Buffer
from cuda.core.experimental._stream cimport default_stream
from cuda.core.experimental._utils.cuda_utils cimport HANDLE_RETURN
from cuda.core.experimental._utils.cuda_utils import _check_multiprocessing_start_method

import multiprocessing
import os
Expand Down Expand Up @@ -129,6 +130,7 @@ cdef class IPCAllocationHandle:


def _reduce_allocation_handle(alloc_handle):
    """Pickle support: reduce an IPCAllocationHandle by duplicating its OS fd."""
    # Alert the user if the process is configured to fork, which CUDA does
    # not support for IPC objects.
    _check_multiprocessing_start_method()
    dup_fd = multiprocessing.reduction.DupFd(alloc_handle.handle)
    rebuild_args = (type(alloc_handle), dup_fd, alloc_handle.uuid)
    return _reconstruct_allocation_handle, rebuild_args

Expand All @@ -141,6 +143,7 @@ multiprocessing.reduction.register(IPCAllocationHandle, _reduce_allocation_handl


def _deep_reduce_device_memory_resource(mr):
_check_multiprocessing_start_method()
from .._device import Device
device = Device(mr.device_id)
alloc_handle = mr.get_allocation_handle()
Expand Down
39 changes: 39 additions & 0 deletions cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import functools
from functools import partial
import importlib.metadata
import multiprocessing
import platform
import warnings
from collections import namedtuple
from collections.abc import Sequence
from contextlib import ExitStack
Expand Down Expand Up @@ -283,3 +286,39 @@ class Transaction:
"""
# pop_all() empties this stack so no callbacks are triggered on exit.
self._stack.pop_all()


# Tracks whether the fork-start-method check has already run in this process.
# The multiprocessing start method is a process-wide setting that can normally
# be chosen only once (barring set_start_method(..., force=True)), so one
# check per process lifetime is sufficient.
_fork_warning_emitted = False


def _check_multiprocessing_start_method():
    """Warn once per process if the multiprocessing start method is 'fork'.

    CUDA does not support forked subprocesses; this helper is called from the
    pickling (reduction) paths of IPC-capable objects so users are alerted
    before such objects are handed to a fork-started worker.

    Emits a ``UserWarning`` when the start method is 'fork', or when it is
    unset and the platform default ('fork' on Linux) would apply. Subsequent
    calls are no-ops regardless of whether a warning was emitted.
    """
    global _fork_warning_emitted
    if _fork_warning_emitted:
        return
    # Mark as checked unconditionally: the start method cannot normally be
    # changed once set, so re-checking on every reduction would be redundant.
    _fork_warning_emitted = True

    # Common warning message parts
    common_message = (
        "CUDA does not support. Forked subprocesses exhibit undefined behavior, "
        "including failure to initialize CUDA contexts and devices. Set the start method "
        "to 'spawn' before creating processes that use CUDA. "
        "Use: multiprocessing.set_start_method('spawn')"
    )

    try:
        start_method = multiprocessing.get_start_method()
    except RuntimeError:
        # get_start_method() can raise RuntimeError if start method hasn't been set
        # In this case, default is 'fork' on Linux, so we should warn
        if platform.system() == "Linux":
            message = (
                f"multiprocessing start method is not set and defaults to 'fork' on Linux, "
                f"which {common_message}"
            )
            warnings.warn(message, UserWarning, stacklevel=3)
        return

    if start_method == "fork":
        message = f"multiprocessing start method is 'fork', which {common_message}"
        warnings.warn(message, UserWarning, stacklevel=3)
158 changes: 158 additions & 0 deletions cuda_core/tests/test_multiprocessing_warning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Test that warnings are emitted when multiprocessing start method is 'fork'
and IPC objects are serialized.

These tests use mocking to simulate the 'fork' start method without actually
using fork, avoiding the need for subprocess isolation.
"""

import warnings
from unittest.mock import patch

from cuda.core.experimental import DeviceMemoryResource, DeviceMemoryResourceOptions, EventOptions
from cuda.core.experimental._event import _reduce_event
from cuda.core.experimental._memory._ipc import (
_deep_reduce_device_memory_resource,
_reduce_allocation_handle,
)


def test_warn_on_fork_method_device_memory_resource(ipc_device):
    """Pickling a DeviceMemoryResource under the 'fork' start method must warn."""
    device = ipc_device
    device.set_current()
    options = DeviceMemoryResourceOptions(max_size=2097152, ipc_enabled=True)
    mr = DeviceMemoryResource(device, options=options)

    with patch("multiprocessing.get_start_method", return_value="fork"), warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        # Clear the one-shot flag so the warning can fire in this test.
        from cuda.core.experimental._utils import cuda_utils

        cuda_utils._fork_warning_emitted = False

        # Invoke the reduction function directly.
        _deep_reduce_device_memory_resource(mr)

        # Exactly one warning, and it mentions the key facts.
        assert len(caught) == 1, f"Expected 1 warning, got {len(caught)}: {[str(item.message) for item in caught]}"
        emitted = caught[0]
        assert emitted.category is UserWarning
        lowered = str(emitted.message).lower()
        assert "fork" in lowered
        assert "spawn" in lowered
        assert "undefined behavior" in lowered

    mr.close()


def test_warn_on_fork_method_allocation_handle(ipc_device):
    """Pickling an IPCAllocationHandle under the 'fork' start method must warn."""
    device = ipc_device
    device.set_current()
    options = DeviceMemoryResourceOptions(max_size=2097152, ipc_enabled=True)
    mr = DeviceMemoryResource(device, options=options)
    handle = mr.get_allocation_handle()

    with patch("multiprocessing.get_start_method", return_value="fork"), warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        # Clear the one-shot flag so the warning can fire in this test.
        from cuda.core.experimental._utils import cuda_utils

        cuda_utils._fork_warning_emitted = False

        # Invoke the reduction function directly.
        _reduce_allocation_handle(handle)

        # Exactly one fork-related warning is expected.
        assert len(caught) == 1
        emitted = caught[0]
        assert emitted.category is UserWarning
        assert "fork" in str(emitted.message).lower()

    mr.close()


def test_warn_on_fork_method_event(mempool_device):
    """Pickling an Event under the 'fork' start method must warn."""
    device = mempool_device
    device.set_current()
    stream = device.create_stream()
    event = stream.record(options=EventOptions(ipc_enabled=True))

    with patch("multiprocessing.get_start_method", return_value="fork"), warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        # Clear the one-shot flag so the warning can fire in this test.
        from cuda.core.experimental._utils import cuda_utils

        cuda_utils._fork_warning_emitted = False

        # Invoke the reduction function directly.
        _reduce_event(event)

        # Exactly one fork-related warning is expected.
        assert len(caught) == 1
        emitted = caught[0]
        assert emitted.category is UserWarning
        assert "fork" in str(emitted.message).lower()

    event.close()


def test_no_warning_with_spawn_method(ipc_device):
    """The 'spawn' start method must not trigger any fork-related warning."""
    device = ipc_device
    device.set_current()
    options = DeviceMemoryResourceOptions(max_size=2097152, ipc_enabled=True)
    mr = DeviceMemoryResource(device, options=options)

    with patch("multiprocessing.get_start_method", return_value="spawn"), warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        # Clear the one-shot flag so a warning could fire if it were going to.
        from cuda.core.experimental._utils import cuda_utils

        cuda_utils._fork_warning_emitted = False

        # Invoke the reduction function directly.
        _deep_reduce_device_memory_resource(mr)

        # No warning mentioning 'fork' should have been recorded.
        fork_related = [item for item in caught if "fork" in str(item.message).lower()]
        assert not fork_related, f"Unexpected warning: {fork_related[0].message if fork_related else None}"

    mr.close()


def test_warning_emitted_only_once(ipc_device):
    """Pickling several objects must emit the fork warning at most once."""
    device = ipc_device
    device.set_current()
    options = DeviceMemoryResourceOptions(max_size=2097152, ipc_enabled=True)
    mr1 = DeviceMemoryResource(device, options=options)
    mr2 = DeviceMemoryResource(device, options=options)

    with patch("multiprocessing.get_start_method", return_value="fork"), warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        # Clear the one-shot flag so the first reduction can warn.
        from cuda.core.experimental._utils import cuda_utils

        cuda_utils._fork_warning_emitted = False

        # Reduce two distinct objects; only the first should warn.
        _deep_reduce_device_memory_resource(mr1)
        _deep_reduce_device_memory_resource(mr2)

        fork_related = [item for item in caught if "fork" in str(item.message).lower()]
        assert len(fork_related) == 1, f"Expected 1 warning, got {len(fork_related)}"

    mr1.close()
    mr2.close()
Loading