Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cuda_core/cuda/core/experimental/_event.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Optional
from cuda.core.experimental._context import Context
from cuda.core.experimental._utils.cuda_utils import (
CUDAError,
check_multiprocessing_start_method,
driver,
)
if TYPE_CHECKING:
Expand Down Expand Up @@ -300,6 +301,7 @@ cdef class IPCEventDescriptor:


def _reduce_event(event):
    """Reduction hook for ``Event`` objects sent through multiprocessing.

    Runs the start-method check first, then returns the
    ``(callable, args)`` pair used to rebuild the event: the object's
    ``from_ipc_descriptor`` callable and a 1-tuple holding its IPC descriptor.
    """
    check_multiprocessing_start_method()
    rebuild = event.from_ipc_descriptor
    descriptor = event.get_ipc_descriptor()
    return rebuild, (descriptor,)

multiprocessing.reduction.register(Event, _reduce_event)
3 changes: 3 additions & 0 deletions cuda_core/cuda/core/experimental/_memory/_ipc.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ from cuda.bindings cimport cydriver
from cuda.core.experimental._memory._buffer cimport Buffer
from cuda.core.experimental._stream cimport default_stream
from cuda.core.experimental._utils.cuda_utils cimport HANDLE_RETURN
from cuda.core.experimental._utils.cuda_utils import check_multiprocessing_start_method

import multiprocessing
import os
Expand Down Expand Up @@ -129,6 +130,7 @@ cdef class IPCAllocationHandle:


def _reduce_allocation_handle(alloc_handle):
    """Reduction hook for allocation handles sent through multiprocessing.

    Runs the start-method check first, wraps the handle's file descriptor
    in ``multiprocessing.reduction.DupFd``, and returns the
    ``(callable, args)`` pair consumed by ``_reconstruct_allocation_handle``:
    the handle's concrete type, the wrapped descriptor and the handle's uuid.
    """
    check_multiprocessing_start_method()
    dup_fd = multiprocessing.reduction.DupFd(alloc_handle.handle)
    return (
        _reconstruct_allocation_handle,
        (type(alloc_handle), dup_fd, alloc_handle.uuid),
    )

Expand All @@ -141,6 +143,7 @@ multiprocessing.reduction.register(IPCAllocationHandle, _reduce_allocation_handl


def _deep_reduce_device_memory_resource(mr):
check_multiprocessing_start_method()
from .._device import Device
device = Device(mr.device_id)
alloc_handle = mr.get_allocation_handle()
Expand Down
48 changes: 48 additions & 0 deletions cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import functools
from functools import partial
import importlib.metadata
import multiprocessing
import platform
import warnings
from collections import namedtuple
from collections.abc import Sequence
from contextlib import ExitStack
Expand Down Expand Up @@ -283,3 +286,48 @@ class Transaction:
"""
# pop_all() empties this stack so no callbacks are triggered on exit.
self._stack.pop_all()


# Latch set by check_multiprocessing_start_method() the first time it runs, so
# the start-method check (and any warning it emits) is performed only once.
# Note that it records "checked", not "warned": it is set even when no warning
# was needed. reset_fork_warning() clears it again.
_fork_warning_checked = False


def reset_fork_warning():
    """Clear the "start method already checked" latch.

    Meant for tests: after calling this, the next call to
    ``check_multiprocessing_start_method()`` performs the check (and may
    warn) again instead of returning early.
    """
    global _fork_warning_checked
    _fork_warning_checked = False


def check_multiprocessing_start_method():
    """Warn (once) if multiprocessing will use the 'fork' start method.

    The check runs only on the first call; later calls return immediately
    until ``reset_fork_warning()`` clears the latch. A ``UserWarning`` is
    emitted when the start method is explicitly 'fork', or when no start
    method has been chosen yet and the platform default is 'fork'.

    The start method is queried with ``allow_none=True`` so that this check
    never *fixes* the start method as a side effect. With the default
    ``allow_none=False``, merely asking would lock in the platform default and
    make the ``multiprocessing.set_start_method('spawn')`` call recommended by
    the warning fail with "context has already been set".
    """
    global _fork_warning_checked
    if _fork_warning_checked:
        return
    # Latch before doing any work so the check is attempted at most once.
    _fork_warning_checked = True

    # Common warning message parts
    common_message = (
        "CUDA does not support. Forked subprocesses exhibit undefined behavior, "
        "including failure to initialize CUDA contexts and devices. Set the start method "
        "to 'spawn' before creating processes that use CUDA. "
        "Use: multiprocessing.set_start_method('spawn')"
    )

    start_method = multiprocessing.get_start_method(allow_none=True)
    if start_method is None:
        # Nothing chosen yet: the effective method is the platform default,
        # which get_all_start_methods() reports as its first entry. This avoids
        # hard-coding which platforms/Python versions default to 'fork'.
        if multiprocessing.get_all_start_methods()[0] != "fork":
            return
        message = (
            "multiprocessing start method is not set and defaults to 'fork' on this platform, "
            f"which {common_message}"
        )
    elif start_method == "fork":
        message = f"multiprocessing start method is 'fork', which {common_message}"
    else:
        return

    # stacklevel=3 skips this function and the reduction hook that called it.
    warnings.warn(message, UserWarning, stacklevel=3)
149 changes: 149 additions & 0 deletions cuda_core/tests/test_multiprocessing_warning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Test that warnings are emitted when multiprocessing start method is 'fork'
and IPC objects are serialized.

These tests use mocking to simulate the 'fork' start method without actually
using fork, avoiding the need for subprocess isolation.
"""

import warnings
from unittest.mock import patch

from cuda.core.experimental import DeviceMemoryResource, DeviceMemoryResourceOptions, EventOptions
from cuda.core.experimental._event import _reduce_event
from cuda.core.experimental._memory._ipc import (
_deep_reduce_device_memory_resource,
_reduce_allocation_handle,
)
from cuda.core.experimental._utils.cuda_utils import reset_fork_warning


def test_warn_on_fork_method_device_memory_resource(ipc_device):
    """Reducing a DeviceMemoryResource under 'fork' emits exactly one UserWarning.

    ``multiprocessing.get_start_method`` is patched to report 'fork', so no
    process is actually forked. The memory resource is closed in a ``finally``
    block so it is released even when the reduction or an assertion fails.
    """
    device = ipc_device
    device.set_current()
    options = DeviceMemoryResourceOptions(max_size=2097152, ipc_enabled=True)
    mr = DeviceMemoryResource(device, options=options)

    try:
        with patch("multiprocessing.get_start_method", return_value="fork"), warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            # The check is latched after its first run; clear the latch so this test sees it.
            reset_fork_warning()

            # Trigger the reduction function directly
            _deep_reduce_device_memory_resource(mr)

            # Check that warning was emitted
            assert len(w) == 1, f"Expected 1 warning, got {len(w)}: {[str(warning.message) for warning in w]}"
            warning = w[0]
            assert warning.category is UserWarning
            message = str(warning.message).lower()
            assert "fork" in message
            assert "spawn" in message
            assert "undefined behavior" in message
    finally:
        mr.close()


def test_warn_on_fork_method_allocation_handle(ipc_device):
    """Reducing an allocation handle under 'fork' emits exactly one UserWarning.

    ``multiprocessing.get_start_method`` is patched to report 'fork', so no
    process is actually forked. The memory resource is closed in a ``finally``
    block so it is released even when the reduction or an assertion fails.
    """
    device = ipc_device
    device.set_current()
    options = DeviceMemoryResourceOptions(max_size=2097152, ipc_enabled=True)
    mr = DeviceMemoryResource(device, options=options)

    try:
        alloc_handle = mr.get_allocation_handle()

        with patch("multiprocessing.get_start_method", return_value="fork"), warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            # The check is latched after its first run; clear the latch so this test sees it.
            reset_fork_warning()

            # Trigger the reduction function directly
            _reduce_allocation_handle(alloc_handle)

            # Check that warning was emitted
            assert len(w) == 1
            warning = w[0]
            assert warning.category is UserWarning
            assert "fork" in str(warning.message).lower()
    finally:
        mr.close()


def test_warn_on_fork_method_event(mempool_device):
    """Reducing an IPC-enabled Event under 'fork' emits exactly one UserWarning.

    ``multiprocessing.get_start_method`` is patched to report 'fork', so no
    process is actually forked. The event is closed in a ``finally`` block so
    it is released even when the reduction or an assertion fails.
    """
    device = mempool_device
    device.set_current()
    stream = device.create_stream()
    ipc_event_options = EventOptions(ipc_enabled=True)
    event = stream.record(options=ipc_event_options)

    try:
        with patch("multiprocessing.get_start_method", return_value="fork"), warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            # The check is latched after its first run; clear the latch so this test sees it.
            reset_fork_warning()

            # Trigger the reduction function directly
            _reduce_event(event)

            # Check that warning was emitted
            assert len(w) == 1
            warning = w[0]
            assert warning.category is UserWarning
            assert "fork" in str(warning.message).lower()
    finally:
        event.close()


def test_no_warning_with_spawn_method(ipc_device):
    """Reducing a DeviceMemoryResource under 'spawn' emits no fork-related warning.

    ``multiprocessing.get_start_method`` is patched to report 'spawn'. Only
    warnings mentioning "fork" are counted, so unrelated warnings do not fail
    the test. The memory resource is closed in a ``finally`` block so it is
    released even when the reduction or the assertion fails.
    """
    device = ipc_device
    device.set_current()
    options = DeviceMemoryResourceOptions(max_size=2097152, ipc_enabled=True)
    mr = DeviceMemoryResource(device, options=options)

    try:
        with patch("multiprocessing.get_start_method", return_value="spawn"), warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            # The check is latched after its first run; clear the latch so it really runs here.
            reset_fork_warning()

            # Trigger the reduction function directly
            _deep_reduce_device_memory_resource(mr)

            # Check that no fork-related warning was emitted
            fork_warnings = [warning for warning in w if "fork" in str(warning.message).lower()]
            assert len(fork_warnings) == 0, f"Unexpected warning: {fork_warnings[0].message if fork_warnings else None}"
    finally:
        mr.close()


def test_warning_emitted_only_once(ipc_device):
    """Reducing several objects under 'fork' emits the warning only once.

    ``multiprocessing.get_start_method`` is patched to report 'fork'. Two
    memory resources are reduced back to back, and only the first reduction
    may warn. Every resource that was successfully created is closed in a
    ``finally`` block, in creation order, even when a later step fails.
    """
    device = ipc_device
    device.set_current()
    options = DeviceMemoryResourceOptions(max_size=2097152, ipc_enabled=True)

    resources = []
    try:
        mr1 = DeviceMemoryResource(device, options=options)
        resources.append(mr1)
        mr2 = DeviceMemoryResource(device, options=options)
        resources.append(mr2)

        with patch("multiprocessing.get_start_method", return_value="fork"), warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            # The check is latched after its first run; clear the latch once, before both reductions.
            reset_fork_warning()

            # Trigger reduction multiple times
            _deep_reduce_device_memory_resource(mr1)
            _deep_reduce_device_memory_resource(mr2)

            # Check that warning was emitted only once
            fork_warnings = [warning for warning in w if "fork" in str(warning.message).lower()]
            assert len(fork_warnings) == 1, f"Expected 1 warning, got {len(fork_warnings)}"
    finally:
        for resource in resources:
            resource.close()
Loading