Skip to content

Commit a681d67

Browse files
committed
Try switching from rustacuda to cust
1 parent 2a124b6 commit a681d67

File tree

26 files changed

+230
-222
lines changed

26 files changed

+230
-222
lines changed

.github/workflows/rustdoc.yml

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -37,9 +37,9 @@ jobs:
3737
--enable-index-page \
3838
--extern-html-root-url const_type_layout=https://docs.rs/const-type-layout/0.3.2/ \
3939
--extern-html-root-url final=https://docs.rs/final/0.1.1/ \
40-
--extern-html-root-url rustacuda=https://docs.rs/rustacuda/0.1.3/ \
41-
--extern-html-root-url rustacuda_core=https://docs.rs/rustacuda_core/0.1.2/ \
42-
--extern-html-root-url rustacuda_derive=https://docs.rs/rustacuda_derive/0.1.2/ \
40+
--extern-html-root-url cust=https://docs.rs/cust/0.3.2/ \
41+
--extern-html-root-url cust_core=https://docs.rs/cust_core/0.1/ \
42+
--extern-html-root-url cust_derive=https://docs.rs/cust_derive/0.2/ \
4343
-Zunstable-options \
4444
" cargo doc \
4545
--all-features \

Cargo.toml

Lines changed: 14 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -19,26 +19,22 @@ rust-version = "1.81" # nightly
1919

2020
[features]
2121
default = []
22-
derive = ["dep:rustacuda_derive", "dep:rust-cuda-derive"]
22+
derive = ["dep:cust_derive", "dep:rust-cuda-derive"]
2323
device = []
2424
final = ["dep:final"]
25-
host = ["dep:rustacuda", "dep:regex", "dep:oneshot", "dep:safer_owning_ref"]
25+
host = ["dep:cust", "dep:regex", "dep:oneshot", "dep:safer_owning_ref"]
2626
kernel = ["dep:rust-cuda-kernel"]
2727

2828
[dependencies]
29-
rustacuda_core = { git = "https://github.com/juntyr/RustaCUDA", rev = "c6ea7cc" }
30-
31-
rustacuda = { git = "https://github.com/juntyr/RustaCUDA", rev = "c6ea7cc", optional = true }
32-
rustacuda_derive = { git = "https://github.com/juntyr/RustaCUDA", rev = "c6ea7cc", optional = true }
33-
34-
regex = { version = "1.10", optional = true }
35-
36-
const-type-layout = { version = "0.3.2", features = ["derive"] }
37-
38-
safer_owning_ref = { version = "0.5", optional = true }
39-
oneshot = { version = "0.1", optional = true, features = ["std", "async"] }
40-
41-
final = { version = "0.1.1", optional = true }
42-
43-
rust-cuda-derive = { path = "rust-cuda-derive", optional = true }
44-
rust-cuda-kernel = { path = "rust-cuda-kernel", optional = true }
29+
const-type-layout = { version = "0.3.2", default-features = false, features = ["derive"] }
30+
# FIXME: cust fails to compile without the `bytemuck` feature
31+
cust = { version = "0.3.2", default-features = false, features = ["bytemuck"], optional = true }
32+
cust_core = { version = "0.1", default-features = false }
33+
cust_derive = { version = "0.2", default-features = false, optional = true }
34+
final = { version = "0.1.1", default-features = false, optional = true }
35+
oneshot = { version = "0.1", default-features = false, features = ["std", "async"], optional = true }
36+
regex = { version = "1.10", default-features = false, optional = true }
37+
safer_owning_ref = { version = "0.5", default-features = false, optional = true }
38+
39+
rust-cuda-derive = { path = "rust-cuda-derive", default-features = false, optional = true }
40+
rust-cuda-kernel = { path = "rust-cuda-kernel", default-features = false, optional = true }

examples/print/src/main.rs

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2,38 +2,38 @@
22

33
use print::{kernel, link, Action};
44

5-
fn main() -> rust_cuda::deps::rustacuda::error::CudaResult<()> {
5+
fn main() -> rust_cuda::deps::cust::error::CudaResult<()> {
66
// Link the non-generic CUDA kernel
77
struct KernelPtx;
88
link! { impl kernel for KernelPtx }
99

1010
// Initialize the CUDA API
11-
rust_cuda::deps::rustacuda::init(rust_cuda::deps::rustacuda::CudaFlags::empty())?;
11+
rust_cuda::deps::cust::init(rust_cuda::deps::cust::CudaFlags::empty())?;
1212

1313
// Get the first CUDA GPU device
14-
let device = rust_cuda::deps::rustacuda::device::Device::get_device(0)?;
14+
let device = rust_cuda::deps::cust::device::Device::get_device(0)?;
1515

1616
// Create a CUDA context associated to this device
1717
let _context = rust_cuda::host::CudaDropWrapper::from(
18-
rust_cuda::deps::rustacuda::context::Context::create_and_push(
19-
rust_cuda::deps::rustacuda::context::ContextFlags::MAP_HOST
20-
| rust_cuda::deps::rustacuda::context::ContextFlags::SCHED_AUTO,
18+
rust_cuda::deps::cust::context::Context::create_and_push(
19+
rust_cuda::deps::cust::context::ContextFlags::MAP_HOST
20+
| rust_cuda::deps::cust::context::ContextFlags::SCHED_AUTO,
2121
device,
2222
)?,
2323
);
2424

2525
// Create a new CUDA stream to submit kernels to
2626
let mut stream =
27-
rust_cuda::host::CudaDropWrapper::from(rust_cuda::deps::rustacuda::stream::Stream::new(
28-
rust_cuda::deps::rustacuda::stream::StreamFlags::NON_BLOCKING,
27+
rust_cuda::host::CudaDropWrapper::from(rust_cuda::deps::cust::stream::Stream::new(
28+
rust_cuda::deps::cust::stream::StreamFlags::NON_BLOCKING,
2929
None,
3030
)?);
3131

3232
// Create a new instance of the CUDA kernel and prepare the launch config
3333
let mut kernel = rust_cuda::kernel::TypedPtxKernel::<kernel>::new::<KernelPtx>(None);
3434
let config = rust_cuda::kernel::LaunchConfig {
35-
grid: rust_cuda::deps::rustacuda::function::GridSize::x(1),
36-
block: rust_cuda::deps::rustacuda::function::BlockSize::x(4),
35+
grid: rust_cuda::deps::cust::function::GridSize::x(1),
36+
block: rust_cuda::deps::cust::function::BlockSize::x(4),
3737
ptx_jit: false,
3838
};
3939

rust-cuda-derive/src/rust_to_cuda/impl.rs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,7 @@ pub fn rust_to_cuda_trait(
8484
unsafe fn borrow<CudaAllocType: #crate_path::alloc::CudaAlloc>(
8585
&self,
8686
alloc: CudaAllocType,
87-
) -> #crate_path::deps::rustacuda::error::CudaResult<(
87+
) -> #crate_path::deps::cust::error::CudaResult<(
8888
#crate_path::utils::ffi::DeviceAccessible<Self::CudaRepresentation>,
8989
#crate_path::alloc::CombinedCudaAlloc<Self::CudaAllocation, CudaAllocType>
9090
)> {
@@ -107,7 +107,7 @@ pub fn rust_to_cuda_trait(
107107
alloc: #crate_path::alloc::CombinedCudaAlloc<
108108
Self::CudaAllocation, CudaAllocType
109109
>,
110-
) -> #crate_path::deps::rustacuda::error::CudaResult<CudaAllocType> {
110+
) -> #crate_path::deps::cust::error::CudaResult<CudaAllocType> {
111111
let (alloc_front, alloc_tail) = alloc.split();
112112

113113
#(#r2c_field_destructors)*
@@ -192,7 +192,7 @@ pub fn rust_to_cuda_async_trait(
192192
&self,
193193
alloc: CudaAllocType,
194194
stream: #crate_path::host::Stream<'stream>,
195-
) -> #crate_path::deps::rustacuda::error::CudaResult<(
195+
) -> #crate_path::deps::cust::error::CudaResult<(
196196
#crate_path::utils::r#async::Async<
197197
'_, 'stream,
198198
#crate_path::utils::ffi::DeviceAccessible<Self::CudaRepresentation>,
@@ -220,7 +220,7 @@ pub fn rust_to_cuda_async_trait(
220220
Self::CudaAllocationAsync, CudaAllocType
221221
>,
222222
stream: #crate_path::host::Stream<'stream>,
223-
) -> #crate_path::deps::rustacuda::error::CudaResult<(
223+
) -> #crate_path::deps::cust::error::CudaResult<(
224224
#crate_path::utils::r#async::Async<
225225
'a, 'stream,
226226
#crate_path::deps::owning_ref::BoxRefMut<'a, CudaRestoreOwner, Self>,

src/deps.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,6 @@ pub extern crate const_type_layout;
77
pub extern crate owning_ref;
88

99
#[cfg(feature = "host")]
10-
pub extern crate rustacuda;
10+
pub extern crate cust;
1111

12-
pub extern crate rustacuda_core;
12+
pub extern crate cust_core;

src/host/mod.rs

Lines changed: 18 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -5,13 +5,14 @@ use std::{
55
};
66

77
use const_type_layout::TypeGraphLayout;
8-
use rustacuda::{
8+
use cust::{
99
context::Context,
1010
error::CudaError,
1111
event::Event,
1212
memory::{CopyDestination, DeviceBox, DeviceBuffer, LockedBox, LockedBuffer},
1313
module::Module,
1414
};
15+
use cust_core::DeviceCopy;
1516

1617
use crate::{
1718
safety::PortableBitSemantics,
@@ -30,12 +31,12 @@ type InvariantLifetime<'brand> = PhantomData<fn(&'brand ()) -> &'brand ()>;
3031
#[derive(Copy, Clone)]
3132
#[repr(transparent)]
3233
pub struct Stream<'stream> {
33-
stream: &'stream rustacuda::stream::Stream,
34+
stream: &'stream cust::stream::Stream,
3435
_brand: InvariantLifetime<'stream>,
3536
}
3637

3738
impl<'stream> Deref for Stream<'stream> {
38-
type Target = rustacuda::stream::Stream;
39+
type Target = cust::stream::Stream;
3940

4041
fn deref(&self) -> &Self::Target {
4142
self.stream
@@ -65,7 +66,7 @@ impl<'stream> Stream<'stream> {
6566
/// }
6667
/// ```
6768
pub fn with<O>(
68-
stream: &mut rustacuda::stream::Stream,
69+
stream: &mut cust::stream::Stream,
6970
inner: impl for<'new_stream> FnOnce(Stream<'new_stream>) -> O,
7071
) -> O {
7172
inner(Stream {
@@ -77,7 +78,7 @@ impl<'stream> Stream<'stream> {
7778

7879
pub trait CudaDroppable: Sized {
7980
#[expect(clippy::missing_errors_doc)]
80-
fn drop(val: Self) -> Result<(), (rustacuda::error::CudaError, Self)>;
81+
fn drop(val: Self) -> Result<(), (cust::error::CudaError, Self)>;
8182
}
8283

8384
#[repr(transparent)]
@@ -112,25 +113,27 @@ impl<C: CudaDroppable> DerefMut for CudaDropWrapper<C> {
112113
}
113114
}
114115

115-
impl<T> CudaDroppable for DeviceBox<T> {
116+
impl<T: DeviceCopy> CudaDroppable for DeviceBox<T> {
116117
fn drop(val: Self) -> Result<(), (CudaError, Self)> {
117118
Self::drop(val)
118119
}
119120
}
120121

121-
impl<T: rustacuda_core::DeviceCopy> CudaDroppable for DeviceBuffer<T> {
122+
impl<T: cust_core::DeviceCopy> CudaDroppable for DeviceBuffer<T> {
122123
fn drop(val: Self) -> Result<(), (CudaError, Self)> {
123124
Self::drop(val)
124125
}
125126
}
126127

127-
impl<T> CudaDroppable for LockedBox<T> {
128+
impl<T: DeviceCopy> CudaDroppable for LockedBox<T> {
128129
fn drop(val: Self) -> Result<(), (CudaError, Self)> {
129-
Self::drop(val)
130+
// FIXME: cust's LockedBox no longer has a fallible drop
131+
std::mem::drop(val);
132+
Ok(())
130133
}
131134
}
132135

133-
impl<T: rustacuda_core::DeviceCopy> CudaDroppable for LockedBuffer<T> {
136+
impl<T: cust_core::DeviceCopy> CudaDroppable for LockedBuffer<T> {
134137
fn drop(val: Self) -> Result<(), (CudaError, Self)> {
135138
Self::drop(val)
136139
}
@@ -147,7 +150,7 @@ macro_rules! impl_sealed_drop_value {
147150
}
148151

149152
impl_sealed_drop_value!(Module);
150-
impl_sealed_drop_value!(rustacuda::stream::Stream);
153+
impl_sealed_drop_value!(cust::stream::Stream);
151154
impl_sealed_drop_value!(Context);
152155
impl_sealed_drop_value!(Event);
153156

@@ -207,7 +210,7 @@ impl<'a, T: PortableBitSemantics + TypeGraphLayout> HostAndDeviceMutRef<'a, T> {
207210
'a: 'b,
208211
{
209212
DeviceMutRef {
210-
pointer: DeviceMutPointer(self.device_box.as_device_ptr().as_raw_mut().cast()),
213+
pointer: DeviceMutPointer(self.device_box.as_device_ptr().as_mut_ptr().cast()),
211214
reference: PhantomData,
212215
}
213216
}
@@ -322,10 +325,10 @@ impl<'a, T: PortableBitSemantics + TypeGraphLayout> HostAndDeviceConstRef<'a, T>
322325
where
323326
'a: 'b,
324327
{
325-
let mut hack = ManuallyDrop::new(unsafe { std::ptr::read(self.device_box) });
328+
let hack = ManuallyDrop::new(unsafe { std::ptr::read(self.device_box) });
326329

327330
DeviceConstRef {
328-
pointer: DeviceConstPointer(hack.as_device_ptr().as_raw().cast()),
331+
pointer: DeviceConstPointer(hack.as_device_ptr().as_ptr().cast()),
329332
reference: PhantomData,
330333
}
331334
}
@@ -390,7 +393,7 @@ impl<'a, T: PortableBitSemantics + TypeGraphLayout> HostAndDeviceOwned<'a, T> {
390393
#[must_use]
391394
pub(crate) fn for_device(self) -> DeviceOwnedRef<'a, T> {
392395
DeviceOwnedRef {
393-
pointer: DeviceOwnedPointer(self.device_box.as_device_ptr().as_raw_mut().cast()),
396+
pointer: DeviceOwnedPointer(self.device_box.as_device_ptr().as_mut_ptr().cast()),
394397
marker: PhantomData::<T>,
395398
reference: PhantomData::<&'a mut ()>,
396399
}

src/kernel/mod.rs

Lines changed: 18 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
use core::str;
12
#[cfg(feature = "host")]
23
use std::{
34
ffi::{CStr, CString},
@@ -6,8 +7,9 @@ use std::{
67
ptr::NonNull,
78
};
89

10+
use cust::module::{ModuleJitOption, OptLevel};
911
#[cfg(feature = "host")]
10-
use rustacuda::{
12+
use cust::{
1113
error::{CudaError, CudaResult},
1214
function::Function,
1315
module::Module,
@@ -42,12 +44,7 @@ mod sealed {
4244

4345
#[cfg(all(feature = "host", not(doc)))]
4446
#[doc(hidden)]
45-
pub trait WithNewAsync<
46-
'stream,
47-
P: ?Sized + CudaKernelParameter,
48-
O,
49-
E: From<rustacuda::error::CudaError>,
50-
>
47+
pub trait WithNewAsync<'stream, P: ?Sized + CudaKernelParameter, O, E: From<cust::error::CudaError>>
5148
{
5249
fn with<'b>(self, param: P::AsyncHostType<'stream, 'b>) -> Result<O, E>
5350
where
@@ -59,7 +56,7 @@ impl<
5956
'stream,
6057
P: ?Sized + CudaKernelParameter,
6158
O,
62-
E: From<rustacuda::error::CudaError>,
59+
E: From<cust::error::CudaError>,
6360
F: for<'b> FnOnce(P::AsyncHostType<'stream, 'b>) -> Result<O, E>,
6461
> WithNewAsync<'stream, P, O, E> for F
6562
{
@@ -109,7 +106,7 @@ pub trait CudaKernelParameter: sealed::Sealed {
109106

110107
#[cfg(feature = "host")]
111108
#[expect(clippy::missing_errors_doc)] // FIXME
112-
fn with_new_async<'stream, 'b, O, E: From<rustacuda::error::CudaError>>(
109+
fn with_new_async<'stream, 'b, O, E: From<cust::error::CudaError>>(
113110
param: Self::SyncHostType,
114111
stream: crate::host::Stream<'stream>,
115112
#[cfg(not(doc))] inner: impl WithNewAsync<'stream, Self, O, E>,
@@ -139,7 +136,7 @@ pub trait CudaKernelParameter: sealed::Sealed {
139136

140137
#[doc(hidden)]
141138
#[cfg(feature = "host")]
142-
fn async_to_ffi<'stream, 'b, E: From<rustacuda::error::CudaError>>(
139+
fn async_to_ffi<'stream, 'b, E: From<cust::error::CudaError>>(
143140
param: Self::AsyncHostType<'stream, 'b>,
144141
token: sealed::Token,
145142
) -> Result<Self::FfiType<'stream, 'b>, E>
@@ -286,8 +283,8 @@ impl<'stream, 'kernel, Kernel> Launcher<'stream, 'kernel, Kernel> {
286283
#[cfg(feature = "host")]
287284
#[derive(Clone, Debug, PartialEq, Eq)]
288285
pub struct LaunchConfig {
289-
pub grid: rustacuda::function::GridSize,
290-
pub block: rustacuda::function::BlockSize,
286+
pub grid: cust::function::GridSize,
287+
pub block: cust::function::BlockSize,
291288
pub ptx_jit: bool,
292289
}
293290

@@ -305,9 +302,15 @@ impl RawPtxKernel {
305302
/// Returns a [`CudaError`] if `ptx` is not a valid PTX source, or it does
306303
/// not contain an entry point named `entry_point`.
307304
pub fn new(ptx: &CStr, entry_point: &CStr) -> CudaResult<Self> {
308-
let module: Box<Module> = Box::new(Module::load_from_string(ptx)?);
309-
310-
let function = unsafe { &*std::ptr::from_ref(module.as_ref()) }.get_function(entry_point);
305+
let module: Box<Module> = Box::new(Module::from_ptx_cstr(
306+
ptx,
307+
&[ModuleJitOption::OptLevel(OptLevel::O4)],
308+
)?);
309+
310+
// FIXME: cust's Module::get_function takes a str and turns it back into
311+
// a CString immediately
312+
let function = unsafe { &*std::ptr::from_ref(module.as_ref()) }
313+
.get_function(unsafe { str::from_utf8_unchecked(entry_point.to_bytes()) });
311314

312315
let function = match function {
313316
Ok(function) => function,

0 commit comments

Comments
 (0)