Skip to content

Commit 9fb8ff5

Browse files
committed
merge mempool classes
1 parent 4ca1b47 commit 9fb8ff5

File tree

5 files changed

+81
-210
lines changed

5 files changed

+81
-210
lines changed

cuda_core/cuda/core/experimental/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from cuda.core.experimental._event import EventOptions
88
from cuda.core.experimental._launcher import LaunchConfig, launch
99
from cuda.core.experimental._linker import Linker, LinkerOptions
10-
from cuda.core.experimental._memory import Mempool
10+
from cuda.core.experimental._memory import AsyncMempool
1111
from cuda.core.experimental._module import ObjectCode
1212
from cuda.core.experimental._program import Program, ProgramOptions
1313
from cuda.core.experimental._stream import Stream, StreamOptions

cuda_core/cuda/core/experimental/_device.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Union
77

88
from cuda.core.experimental._context import Context, ContextOptions
9-
from cuda.core.experimental._memory import Buffer, MemoryResource, _DefaultAsyncMempool, _SynchronousMemoryResource
9+
from cuda.core.experimental._memory import AsyncMempool, Buffer, MemoryResource, _SynchronousMemoryResource
1010
from cuda.core.experimental._stream import Stream, StreamOptions, default_stream
1111
from cuda.core.experimental._utils import ComputeCapability, CUDAError, driver, handle_return, precondition, runtime
1212

@@ -962,7 +962,7 @@ def __new__(cls, device_id=None):
962962
runtime.cudaDeviceGetAttribute(runtime.cudaDeviceAttr.cudaDevAttrMemoryPoolsSupported, 0)
963963
)
964964
) == 1:
965-
dev._mr = _DefaultAsyncMempool(dev_id)
965+
dev._mr = AsyncMempool._from_device(dev_id)
966966
else:
967967
dev._mr = _SynchronousMemoryResource(dev_id)
968968

0 commit comments

Comments
 (0)