Skip to content

Commit 25735d0

Browse files
committed
Implement RustToCuda for Arc<T>
1 parent 697dcf5 commit 25735d0

File tree

2 files changed

+172
-0
lines changed

2 files changed

+172
-0
lines changed

src/lend/impls/arc.rs

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
use core::sync::atomic::AtomicUsize;
2+
#[cfg(feature = "host")]
3+
use std::mem::ManuallyDrop;
4+
5+
use const_type_layout::{TypeGraphLayout, TypeLayout};
6+
7+
#[cfg(feature = "host")]
8+
use rustacuda::{error::CudaResult, memory::DeviceBox, memory::LockedBox};
9+
10+
use crate::{
11+
deps::alloc::sync::Arc,
12+
lend::{CudaAsRust, RustToCuda, RustToCudaAsync},
13+
safety::PortableBitSemantics,
14+
utils::ffi::DeviceOwnedPointer,
15+
};
16+
17+
#[cfg(any(feature = "host", feature = "device"))]
18+
use crate::utils::ffi::DeviceAccessible;
19+
20+
#[cfg(feature = "host")]
21+
use crate::{
22+
alloc::{CombinedCudaAlloc, CudaAlloc},
23+
host::CudaDropWrapper,
24+
utils::adapter::DeviceCopyWithPortableBitSemantics,
25+
utils::r#async::Async,
26+
utils::r#async::CompletionFnMut,
27+
utils::r#async::NoCompletion,
28+
};
29+
30+
/// CUDA-side representation of an `Arc<T>` (see `RustToCuda for Arc<T>`,
/// which names this type as its `CudaRepresentation`).
///
/// It is a `#[repr(transparent)]` wrapper around a single
/// `DeviceOwnedPointer` that points at an `_ArcInner<T>`; the host-side
/// `borrow`/`borrow_async` fill it with the raw device pointer of the
/// allocation they create.
#[doc(hidden)]
31+
#[repr(transparent)]
32+
#[derive(TypeLayout)]
33+
#[allow(clippy::module_name_repetitions)]
34+
pub struct ArcCudaRepresentation<T: PortableBitSemantics + TypeGraphLayout>(
35+
// Pointer to the `_ArcInner<T>` copy that `as_rust` turns back into an `Arc<T>`.
DeviceOwnedPointer<_ArcInner<T>>,
36+
);
37+
38+
// Must be kept in sync (hehe) -- presumably with the field order and layout
// of the standard library's private `ArcInner`, because `as_rust` passes a
// pointer to the `data` field of this struct to `Arc::from_raw`.
// NOTE(review): confirm against the `alloc` version this crate builds with.
39+
#[doc(hidden)]
40+
#[derive(TypeLayout)]
41+
#[repr(C)]
42+
pub struct _ArcInner<T: ?Sized> {
43+
// Strong reference count; initialised to 1 by `borrow`/`borrow_async`.
strong: AtomicUsize,
44+
// Weak reference count; initialised to 1 by `borrow`/`borrow_async`.
weak: AtomicUsize,
45+
// The shared value itself; `as_rust` takes the address of this field.
data: T,
46+
}
47+
48+
// Lends an `Arc<T>` to CUDA by building a fresh `_ArcInner<T>` (counts set to
// 1, pointee copied bitwise) and placing it in a newly created `DeviceBox`.
unsafe impl<T: PortableBitSemantics + TypeGraphLayout> RustToCuda for Arc<T> {
49+
// With the "host" feature (and outside rustdoc builds) the allocation is the
// drop-wrapped device box that holds the copied `_ArcInner<T>`.
#[cfg(all(feature = "host", not(doc)))]
50+
type CudaAllocation =
51+
CudaDropWrapper<DeviceBox<DeviceCopyWithPortableBitSemantics<_ArcInner<T>>>>;
52+
// Without the "host" feature, or when building docs, `SomeCudaAlloc` is used instead.
#[cfg(any(not(feature = "host"), doc))]
53+
type CudaAllocation = crate::alloc::SomeCudaAlloc;
54+
type CudaRepresentation = ArcCudaRepresentation<T>;
55+
56+
/// Creates a device copy of the pointee wrapped in an `_ArcInner` and
/// returns a `DeviceAccessible` pointer to it together with `alloc`
/// extended by the new device box.
///
/// Any error from `DeviceBox::new` is propagated with `?`.
#[cfg(feature = "host")]
57+
#[allow(clippy::type_complexity)]
58+
unsafe fn borrow<A: CudaAlloc>(
59+
&self,
60+
alloc: A,
61+
) -> CudaResult<(
62+
DeviceAccessible<Self::CudaRepresentation>,
63+
CombinedCudaAlloc<Self::CudaAllocation, A>,
64+
)> {
65+
// `ptr::read` makes a bitwise duplicate of the pointee without moving it
// out of `self`; `ManuallyDrop` ensures the duplicate's destructor never
// runs on the host, so the value still owned by `self` is not dropped twice.
let inner = ManuallyDrop::new(_ArcInner {
66+
strong: AtomicUsize::new(1),
67+
weak: AtomicUsize::new(1),
68+
data: std::ptr::read(&**self),
69+
});
70+
71+
// `&*inner` strips the `ManuallyDrop`, so the device box is typed over
// `_ArcInner<T>` itself, matching `Self::CudaAllocation`.
let mut device_box = CudaDropWrapper::from(DeviceBox::new(
72+
DeviceCopyWithPortableBitSemantics::from_ref(&*inner),
73+
)?);
74+
75+
Ok((
76+
// The raw device pointer is taken here, before `device_box` is moved
// into the combined allocation below.
DeviceAccessible::from(ArcCudaRepresentation(DeviceOwnedPointer(
77+
device_box.as_device_ptr().as_raw_mut().cast(),
78+
))),
79+
CombinedCudaAlloc::new(device_box, alloc),
80+
))
81+
}
82+
83+
/// Splits this impl's allocation off the front of `alloc` and returns the
/// remaining tail.
///
/// Nothing is copied back into `self`; the front allocation is simply
/// dropped when this function returns.
#[cfg(feature = "host")]
84+
unsafe fn restore<A: CudaAlloc>(
85+
&mut self,
86+
alloc: CombinedCudaAlloc<Self::CudaAllocation, A>,
87+
) -> CudaResult<A> {
88+
// `_alloc_front` is a named binding, so it lives until the end of the function.
let (_alloc_front, alloc_tail) = alloc.split();
89+
Ok(alloc_tail)
90+
}
91+
}
92+
93+
// Asynchronous counterpart of `RustToCuda for Arc<T>`: the `_ArcInner<T>` copy
// is first written into a `LockedBox` on the host and then copied into an
// uninitialised `DeviceBox` via `async_copy_from` on the given stream.
unsafe impl<T: PortableBitSemantics + TypeGraphLayout> RustToCudaAsync for Arc<T> {
94+
// With the "host" feature (and outside rustdoc builds) the allocation pairs
// the host-side `LockedBox` with the `DeviceBox`; both are typed over
// `ManuallyDrop<_ArcInner<T>>`, unlike the synchronous `CudaAllocation`.
#[cfg(all(feature = "host", not(doc)))]
95+
type CudaAllocationAsync = CombinedCudaAlloc<
96+
CudaDropWrapper<LockedBox<DeviceCopyWithPortableBitSemantics<ManuallyDrop<_ArcInner<T>>>>>,
97+
CudaDropWrapper<DeviceBox<DeviceCopyWithPortableBitSemantics<ManuallyDrop<_ArcInner<T>>>>>,
98+
>;
99+
// Without the "host" feature, or when building docs, `SomeCudaAlloc` is used instead.
#[cfg(any(not(feature = "host"), doc))]
100+
type CudaAllocationAsync = crate::alloc::SomeCudaAlloc;
101+
102+
/// Starts an asynchronous copy of the pointee (wrapped in an `_ArcInner`)
/// to the device on `stream`.
///
/// Returns a pending `Async` holding the device pointer, plus `alloc`
/// extended by the (locked box, device box) pair. Errors from the box
/// allocations, the async copy, and `Async::pending` are propagated with `?`.
#[cfg(feature = "host")]
103+
unsafe fn borrow_async<'stream, A: CudaAlloc>(
104+
&self,
105+
alloc: A,
106+
stream: crate::host::Stream<'stream>,
107+
) -> rustacuda::error::CudaResult<(
108+
Async<'_, 'stream, DeviceAccessible<Self::CudaRepresentation>>,
109+
CombinedCudaAlloc<Self::CudaAllocationAsync, A>,
110+
)> {
111+
// Brings `async_copy_from` into scope for the device box below.
use rustacuda::memory::AsyncCopyDestination;
112+
113+
let locked_box = unsafe {
114+
// Bitwise duplicate of the pointee; `ManuallyDrop` keeps its destructor
// from running on the host, so `self`'s value is not dropped twice.
let inner = ManuallyDrop::new(_ArcInner {
115+
strong: AtomicUsize::new(1),
116+
weak: AtomicUsize::new(1),
117+
data: std::ptr::read(&**self),
118+
});
119+
120+
let mut uninit = CudaDropWrapper::from(LockedBox::<
121+
DeviceCopyWithPortableBitSemantics<ManuallyDrop<_ArcInner<T>>>,
122+
>::uninitialized()?);
123+
// Initialise the locked box by copying exactly one element from the
// local `inner` into it.
std::ptr::copy_nonoverlapping(
124+
std::ptr::from_ref(DeviceCopyWithPortableBitSemantics::from_ref(&inner)),
125+
uninit.as_mut_ptr(),
126+
1,
127+
);
128+
129+
uninit
130+
};
131+
132+
let mut device_box = CudaDropWrapper::from(DeviceBox::<
133+
DeviceCopyWithPortableBitSemantics<ManuallyDrop<_ArcInner<T>>>,
134+
>::uninitialized()?);
135+
// The copy source is the locked box, which is returned in the allocation
// below -- presumably so it outlives the in-flight copy; TODO confirm.
device_box.async_copy_from(&*locked_box, &stream)?;
136+
137+
Ok((
138+
Async::pending(
139+
// The raw device pointer is taken before `device_box` is moved into
// the combined allocation.
DeviceAccessible::from(ArcCudaRepresentation(DeviceOwnedPointer(
140+
device_box.as_device_ptr().as_raw_mut().cast(),
141+
))),
142+
stream,
143+
NoCompletion,
144+
)?,
145+
CombinedCudaAlloc::new(CombinedCudaAlloc::new(locked_box, device_box), alloc),
146+
))
147+
}
148+
149+
/// Splits this impl's allocation off the front of `alloc` and returns
/// `this` wrapped in an already-ready `Async` together with the tail.
///
/// Nothing is copied back from the device; the front allocation is
/// dropped when this function returns.
#[cfg(feature = "host")]
150+
unsafe fn restore_async<'a, 'stream, A: CudaAlloc, O>(
151+
this: owning_ref::BoxRefMut<'a, O, Self>,
152+
alloc: CombinedCudaAlloc<Self::CudaAllocationAsync, A>,
153+
stream: crate::host::Stream<'stream>,
154+
) -> CudaResult<(
155+
Async<'a, 'stream, owning_ref::BoxRefMut<'a, O, Self>, CompletionFnMut<'a, Self>>,
156+
A,
157+
)> {
158+
// `_alloc_front` is a named binding, so it lives until the end of the function.
let (_alloc_front, alloc_tail) = alloc.split();
159+
let r#async = Async::ready(this, stream);
160+
Ok((r#async, alloc_tail))
161+
}
162+
}
163+
164+
// Device-side conversion from the CUDA representation back into an `Arc<T>`.
unsafe impl<T: PortableBitSemantics + TypeGraphLayout> CudaAsRust for ArcCudaRepresentation<T> {
165+
type RustRepresentation = Arc<T>;
166+
167+
/// Rebuilds an `Arc<T>` from the stored pointer to an `_ArcInner<T>`.
///
/// `Arc::from_raw` takes a pointer to the value itself (as returned by
/// `Arc::into_raw`), so the address of the `data` field is passed rather
/// than the address of the `_ArcInner`.
/// NOTE(review): this relies on `_ArcInner` matching the layout of the
/// `Arc` allocation in the `alloc` crate in use -- confirm.
#[cfg(feature = "device")]
168+
unsafe fn as_rust(this: &DeviceAccessible<Self>) -> Self::RustRepresentation {
169+
// `addr_of!` projects to the field without creating an intermediate reference.
crate::deps::alloc::sync::Arc::from_raw(core::ptr::addr_of!((*(this.0 .0)).data))
170+
}
171+
}

src/lend/impls/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
mod arc;
12
mod r#box;
23
mod boxed_slice;
34
#[cfg(feature = "final")]

0 commit comments

Comments
 (0)