Skip to content

Commit

Permalink
Merge pull request cupy#8774 from kmaehashi/fix-softlink-init-race
Browse files Browse the repository at this point in the history
Fix race during SoftLink initialization
  • Loading branch information
asi1024 authored and chainer-ci committed Dec 3, 2024
1 parent 40d48cd commit 2d187e0
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
7 changes: 4 additions & 3 deletions cupy_backends/cuda/api/_driver_extern.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,10 @@ cdef inline void initialize() except *:
global _L
if _L is not None:
return
_initialize()
_L = _initialize()


cdef void _initialize() except *:
global _L
cdef SoftLink _initialize():
_L = _get_softlink()

cdef str version = '_v2' if CUPY_CUDA_VERSION != 0 else ''
Expand Down Expand Up @@ -139,6 +138,8 @@ cdef void _initialize() except *:
global cuStreamGetCtx
cuStreamGetCtx = <F_cuStreamGetCtx>_L.get('StreamGetCtx')

return _L


cdef SoftLink _get_softlink():
cdef str prefix = 'cu'
Expand Down
7 changes: 4 additions & 3 deletions cupy_backends/cuda/api/_runtime_softlink.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@ cdef inline void initialize() except *:
global _L
if _L is not None:
return
_initialize()
_L = _initialize()


cdef void _initialize() except *:
global _L
cdef SoftLink _initialize():
_L = _get_softlink()

global DYN_cudaRuntimeGetVersion
DYN_cudaRuntimeGetVersion = <F_cudaRuntimeGetVersion>_L.get('RuntimeGetVersion') # noqa

return _L


cdef SoftLink _get_softlink():
cdef int runtime_version
Expand Down
7 changes: 4 additions & 3 deletions cupy_backends/cuda/libs/_cnvrtc.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,9 @@ cdef inline void initialize() except *:
global _L
if _L is not None:
return
_initialize()
_L = _initialize()

cdef void _initialize() except *:
global _L
cdef SoftLink _initialize():
_L = _get_softlink()

global nvrtcGetErrorString
Expand Down Expand Up @@ -110,6 +109,8 @@ cdef void _initialize() except *:
global nvrtcGetNVVM
nvrtcGetNVVM = <F_nvrtcGetNVVM>_L.get('GetNVVM')

return _L


cdef SoftLink _get_softlink():
cdef int runtime_version
Expand Down

0 comments on commit 2d187e0

Please sign in to comment.