Skip to content

Commit 3eba1a2

Browse files
committed
Added error handling to the compile-time PTX checking
1 parent bacf064 commit 3eba1a2

File tree

4 files changed

+485
-96
lines changed

4 files changed

+485
-96
lines changed

rust-cuda-derive/Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ serde_json = "1.0"
2222
cargo_metadata = { version = "0.17", features = ["builder"] }
2323
strip-ansi-escapes = "0.2"
2424
colored = "2.0"
25-
25+
thiserror = "1.0"
2626
seahash = "4.1"
2727
ptx-builder = { git = "https://github.com/juntyr/rust-ptx-builder", rev = "1f1f49d" }
28-
ptx_compiler = "0.1"
28+
29+
[build-dependencies]
30+
find_cuda_helper = "0.2"

rust-cuda-derive/build.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
fn main() {
2+
find_cuda_helper::include_cuda();
3+
24
println!("cargo:rustc-link-lib=nvptxcompiler_static");
35
}

rust-cuda-derive/src/kernel/link/mod.rs

Lines changed: 178 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
use std::{
22
env,
33
ffi::CString,
4+
fmt::Write as FmtWrite,
45
fs,
56
io::{Read, Write},
6-
mem::MaybeUninit,
77
os::raw::c_int,
88
path::{Path, PathBuf},
99
ptr::addr_of_mut,
@@ -16,15 +16,16 @@ use ptx_builder::{
1616
builder::{BuildStatus, Builder, MessageFormat, Profile},
1717
error::{BuildErrorKind, Error, Result},
1818
};
19-
use ptx_compiler::sys::size_t;
2019

2120
use super::utils::skip_kernel_compilation;
2221

2322
mod config;
2423
mod error;
24+
mod ptx_compiler_sys;
2525

2626
use config::{CheckKernelConfig, LinkKernelConfig};
2727
use error::emit_ptx_build_error;
28+
use ptx_compiler_sys::NvptxError;
2829

2930
pub fn check_kernel(tokens: TokenStream) -> TokenStream {
3031
proc_macro_error::set_dummy(quote! {
@@ -199,110 +200,41 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
199200
kernel_ptx.replace_range(type_layout_start..type_layout_end, "");
200201
}
201202

202-
let mut compiler = MaybeUninit::uninit();
203-
let r = unsafe {
204-
ptx_compiler::sys::nvPTXCompilerCreate(
205-
compiler.as_mut_ptr(),
206-
kernel_ptx.len() as size_t,
207-
kernel_ptx.as_ptr().cast(),
208-
)
209-
};
210-
emit_call_site_warning!("PTX compiler create result {}", r);
211-
let compiler = unsafe { compiler.assume_init() };
212-
213-
let mut major = 0;
214-
let mut minor = 0;
215-
let r = unsafe {
216-
ptx_compiler::sys::nvPTXCompilerGetVersion(addr_of_mut!(major), addr_of_mut!(minor))
217-
};
218-
emit_call_site_warning!("PTX version result {}", r);
219-
emit_call_site_warning!("PTX compiler version {}.{}", major, minor);
203+
let (result, error_log, info_log, version, drop) =
204+
check_kernel_ptx(&kernel_ptx, &specialisation, &kernel_hash);
220205

221-
let kernel_name = if specialisation.is_empty() {
222-
format!("{kernel_hash}_kernel")
223-
} else {
224-
format!(
225-
"{kernel_hash}_kernel_{:016x}",
226-
seahash::hash(specialisation.as_bytes())
227-
)
228-
};
229-
230-
let options = vec![
231-
CString::new("--entry").unwrap(),
232-
CString::new(kernel_name).unwrap(),
233-
CString::new("--verbose").unwrap(),
234-
CString::new("--warn-on-double-precision-use").unwrap(),
235-
CString::new("--warn-on-local-memory-usage").unwrap(),
236-
CString::new("--warn-on-spills").unwrap(),
237-
];
238-
let options_ptrs = options.iter().map(|o| o.as_ptr()).collect::<Vec<_>>();
239-
240-
let r = unsafe {
241-
ptx_compiler::sys::nvPTXCompilerCompile(
242-
compiler,
243-
options_ptrs.len() as c_int,
244-
options_ptrs.as_ptr().cast(),
245-
)
206+
let ptx_compiler = match &version {
207+
Ok((major, minor)) => format!("PTX compiler v{major}.{minor}"),
208+
Err(_) => String::from("PTX compiler"),
246209
};
247-
emit_call_site_warning!("PTX compile result {}", r);
248210

249-
let mut info_log_size = 0;
250-
let r = unsafe {
251-
ptx_compiler::sys::nvPTXCompilerGetInfoLogSize(compiler, addr_of_mut!(info_log_size))
252-
};
253-
emit_call_site_warning!("PTX info log size result {}", r);
254-
#[allow(clippy::cast_possible_truncation)]
255-
let mut info_log: Vec<u8> = Vec::with_capacity(info_log_size as usize);
256-
if info_log_size > 0 {
257-
let r = unsafe {
258-
ptx_compiler::sys::nvPTXCompilerGetInfoLog(compiler, info_log.as_mut_ptr().cast())
259-
};
260-
emit_call_site_warning!("PTX info log content result {}", r);
261-
#[allow(clippy::cast_possible_truncation)]
262-
unsafe {
263-
info_log.set_len(info_log_size as usize);
264-
}
265-
}
266-
let info_log = String::from_utf8_lossy(&info_log);
267-
268-
let mut error_log_size = 0;
269-
let r = unsafe {
270-
ptx_compiler::sys::nvPTXCompilerGetErrorLogSize(compiler, addr_of_mut!(error_log_size))
271-
};
272-
emit_call_site_warning!("PTX error log size result {}", r);
273-
#[allow(clippy::cast_possible_truncation)]
274-
let mut error_log: Vec<u8> = Vec::with_capacity(error_log_size as usize);
275-
if error_log_size > 0 {
276-
let r = unsafe {
277-
ptx_compiler::sys::nvPTXCompilerGetErrorLog(compiler, error_log.as_mut_ptr().cast())
278-
};
279-
emit_call_site_warning!("PTX error log content result {}", r);
280-
#[allow(clippy::cast_possible_truncation)]
281-
unsafe {
282-
error_log.set_len(error_log_size as usize);
283-
}
211+
// TODO: allow user to select
212+
// - warn on double
213+
// - warn on float
214+
// - warn on spills
215+
// - verbose warn
216+
// - warnings as errors
217+
// - show PTX source if warning or error
218+
219+
let mut errors = String::new();
220+
if let Err(err) = drop {
221+
let _ = errors.write_fmt(format_args!("Error dropping the {ptx_compiler}: {err}\n"));
284222
}
285-
let error_log = String::from_utf8_lossy(&error_log);
286-
287-
// Ensure the compiler is not dropped
288-
let mut compiler = MaybeUninit::new(compiler);
289-
let r = unsafe { ptx_compiler::sys::nvPTXCompilerDestroy(compiler.as_mut_ptr()) };
290-
emit_call_site_warning!("PTX compiler destroy result {}", r);
291-
292-
if !info_log.is_empty() {
293-
emit_call_site_warning!("PTX compiler info log:\n{}", info_log);
223+
if let Err(err) = version {
224+
let _ = errors.write_fmt(format_args!(
225+
"Error fetching the version of the {ptx_compiler}: {err}\n"
226+
));
294227
}
295-
if !error_log.is_empty() {
228+
if let (Ok(Some(_)), _) | (_, Ok(Some(_))) = (&info_log, &error_log) {
296229
let mut max_lines = kernel_ptx.chars().filter(|c| *c == '\n').count() + 1;
297230
let mut indent = 0;
298231
while max_lines > 0 {
299232
max_lines /= 10;
300233
indent += 1;
301234
}
302235

303-
abort_call_site!(
304-
"PTX compiler error log:\n{}\nPTX source:\n{}",
305-
error_log,
236+
emit_call_site_warning!(
237+
"PTX source code:\n{}",
306238
kernel_ptx
307239
.lines()
308240
.enumerate()
@@ -311,10 +243,162 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
311243
.join("\n")
312244
);
313245
}
246+
match info_log {
247+
Ok(None) => (),
248+
Ok(Some(info_log)) => emit_call_site_warning!("{ptx_compiler} info log:\n{}", info_log),
249+
Err(err) => {
250+
let _ = errors.write_fmt(format_args!(
251+
"Error fetching the info log of the {ptx_compiler}: {err}\n"
252+
));
253+
},
254+
};
255+
match error_log {
256+
Ok(None) => (),
257+
Ok(Some(error_log)) => emit_call_site_error!("{ptx_compiler} error log:\n{}", error_log),
258+
Err(err) => {
259+
let _ = errors.write_fmt(format_args!(
260+
"Error fetching the error log of the {ptx_compiler}: {err}\n"
261+
));
262+
},
263+
};
264+
if let Err(err) = result {
265+
let _ = errors.write_fmt(format_args!("Error compiling the PTX source code: {err}\n"));
266+
}
267+
if !errors.is_empty() {
268+
abort_call_site!("{}", errors);
269+
}
314270

315271
(quote! { const PTX_STR: &'static str = #kernel_ptx; #(#type_layouts)* }).into()
316272
}
317273

274+
#[allow(clippy::type_complexity)]
275+
fn check_kernel_ptx(
276+
kernel_ptx: &str,
277+
specialisation: &str,
278+
kernel_hash: &proc_macro2::Ident,
279+
) -> (
280+
Result<(), NvptxError>,
281+
Result<Option<String>, NvptxError>,
282+
Result<Option<String>, NvptxError>,
283+
Result<(u32, u32), NvptxError>,
284+
Result<(), NvptxError>,
285+
) {
286+
let compiler = {
287+
let mut compiler = std::ptr::null_mut();
288+
if let Err(err) = NvptxError::try_err_from(unsafe {
289+
ptx_compiler_sys::nvPTXCompilerCreate(
290+
addr_of_mut!(compiler),
291+
kernel_ptx.len() as ptx_compiler_sys::size_t,
292+
kernel_ptx.as_ptr().cast(),
293+
)
294+
}) {
295+
abort_call_site!("PTX compiler creation failed: {}", err);
296+
}
297+
compiler
298+
};
299+
300+
let result = {
301+
let kernel_name = if specialisation.is_empty() {
302+
format!("{kernel_hash}_kernel")
303+
} else {
304+
format!(
305+
"{kernel_hash}_kernel_{:016x}",
306+
seahash::hash(specialisation.as_bytes())
307+
)
308+
};
309+
310+
let options = vec![
311+
CString::new("--entry").unwrap(),
312+
CString::new(kernel_name).unwrap(),
313+
CString::new("--verbose").unwrap(),
314+
CString::new("--warn-on-double-precision-use").unwrap(),
315+
CString::new("--warn-on-local-memory-usage").unwrap(),
316+
CString::new("--warn-on-spills").unwrap(),
317+
];
318+
let options_ptrs = options.iter().map(|o| o.as_ptr()).collect::<Vec<_>>();
319+
320+
NvptxError::try_err_from(unsafe {
321+
ptx_compiler_sys::nvPTXCompilerCompile(
322+
compiler,
323+
options_ptrs.len() as c_int,
324+
options_ptrs.as_ptr().cast(),
325+
)
326+
})
327+
};
328+
329+
let error_log = (|| {
330+
let mut error_log_size = 0;
331+
332+
NvptxError::try_err_from(unsafe {
333+
ptx_compiler_sys::nvPTXCompilerGetErrorLogSize(compiler, addr_of_mut!(error_log_size))
334+
})?;
335+
336+
if error_log_size == 0 {
337+
return Ok(None);
338+
}
339+
340+
#[allow(clippy::cast_possible_truncation)]
341+
let mut error_log: Vec<u8> = Vec::with_capacity(error_log_size as usize);
342+
343+
NvptxError::try_err_from(unsafe {
344+
ptx_compiler_sys::nvPTXCompilerGetErrorLog(compiler, error_log.as_mut_ptr().cast())
345+
})?;
346+
347+
#[allow(clippy::cast_possible_truncation)]
348+
unsafe {
349+
error_log.set_len(error_log_size as usize);
350+
}
351+
352+
Ok(Some(String::from_utf8_lossy(&error_log).into_owned()))
353+
})();
354+
355+
let info_log = (|| {
356+
let mut info_log_size = 0;
357+
358+
NvptxError::try_err_from(unsafe {
359+
ptx_compiler_sys::nvPTXCompilerGetInfoLogSize(compiler, addr_of_mut!(info_log_size))
360+
})?;
361+
362+
if info_log_size == 0 {
363+
return Ok(None);
364+
}
365+
366+
#[allow(clippy::cast_possible_truncation)]
367+
let mut info_log: Vec<u8> = Vec::with_capacity(info_log_size as usize);
368+
369+
NvptxError::try_err_from(unsafe {
370+
ptx_compiler_sys::nvPTXCompilerGetInfoLog(compiler, info_log.as_mut_ptr().cast())
371+
})?;
372+
373+
#[allow(clippy::cast_possible_truncation)]
374+
unsafe {
375+
info_log.set_len(info_log_size as usize);
376+
}
377+
378+
Ok(Some(String::from_utf8_lossy(&info_log).into_owned()))
379+
})();
380+
381+
let version = (|| {
382+
let mut major = 0;
383+
let mut minor = 0;
384+
385+
NvptxError::try_err_from(unsafe {
386+
ptx_compiler_sys::nvPTXCompilerGetVersion(addr_of_mut!(major), addr_of_mut!(minor))
387+
})?;
388+
389+
Ok((major, minor))
390+
})();
391+
392+
let drop = {
393+
let mut compiler = compiler;
394+
NvptxError::try_err_from(unsafe {
395+
ptx_compiler_sys::nvPTXCompilerDestroy(addr_of_mut!(compiler))
396+
})
397+
};
398+
399+
(result, error_log, info_log, version, drop)
400+
}
401+
318402
fn compile_kernel(
319403
args: &syn::Ident,
320404
crate_name: &str,

0 commit comments

Comments
 (0)