Skip to content

Commit a8bc0d4

Browse files
committed
Improve kernel checking + added cubin dump lint
1 parent f37ef1a commit a8bc0d4

File tree

6 files changed

+278
-185
lines changed

6 files changed

+278
-185
lines changed

rust-cuda-derive/src/kernel/link/config.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,21 @@ impl syn::parse::Parse for LinkKernelConfig {
6767

6868
#[allow(clippy::module_name_repetitions)]
6969
pub(super) struct CheckKernelConfig {
70+
pub(super) kernel_hash: syn::Ident,
7071
pub(super) args: syn::Ident,
7172
pub(super) crate_name: String,
7273
pub(super) crate_path: PathBuf,
7374
}
7475

7576
impl syn::parse::Parse for CheckKernelConfig {
7677
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
78+
let kernel_hash: syn::Ident = input.parse()?;
7779
let args: syn::Ident = input.parse()?;
7880
let name: syn::LitStr = input.parse()?;
7981
let path: syn::LitStr = input.parse()?;
8082

8183
Ok(Self {
84+
kernel_hash,
8285
args,
8386
crate_name: name.value(),
8487
crate_path: PathBuf::from(path.value()),

rust-cuda-derive/src/kernel/link/mod.rs

Lines changed: 155 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -32,31 +32,38 @@ use error::emit_ptx_build_error;
3232
use ptx_compiler_sys::NvptxError;
3333

3434
pub fn check_kernel(tokens: TokenStream) -> TokenStream {
35-
proc_macro_error::set_dummy(quote! {
36-
"ERROR in this PTX compilation"
37-
});
35+
proc_macro_error::set_dummy(quote! {::core::result::Result::Err(())});
3836

3937
let CheckKernelConfig {
38+
kernel_hash,
4039
args,
4140
crate_name,
4241
crate_path,
4342
} = match syn::parse_macro_input::parse(tokens) {
4443
Ok(config) => config,
4544
Err(err) => {
4645
abort_call_site!(
47-
"check_kernel!(ARGS NAME PATH) expects ARGS identifier, NAME and PATH string \
48-
literals: {:?}",
46+
"check_kernel!(HASH ARGS NAME PATH) expects HASH and ARGS identifiers, and NAME \
47+
and PATH string literals: {:?}",
4948
err
5049
)
5150
},
5251
};
5352

5453
let kernel_ptx = compile_kernel(&args, &crate_name, &crate_path, Specialisation::Check);
5554

56-
match kernel_ptx {
57-
Some(kernel_ptx) => quote!(#kernel_ptx).into(),
58-
None => quote!("ERROR in this PTX compilation").into(),
59-
}
55+
let Some(kernel_ptx) = kernel_ptx else {
56+
return quote!(::core::result::Result::Err(())).into()
57+
};
58+
59+
check_kernel_ptx_and_report(
60+
&kernel_ptx,
61+
Specialisation::Check,
62+
&kernel_hash,
63+
&HashMap::new(),
64+
);
65+
66+
quote!(::core::result::Result::Ok(())).into()
6067
}
6168

6269
#[allow(clippy::module_name_repetitions, clippy::too_many_lines)]
@@ -77,9 +84,9 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
7784
Ok(config) => config,
7885
Err(err) => {
7986
abort_call_site!(
80-
"link_kernel!(KERNEL ARGS NAME PATH SPECIALISATION LINTS,*) expects KERNEL and \
81-
ARGS identifiers, NAME and PATH string literals, SPECIALISATION and LINTS \
82-
tokens: {:?}",
87+
"link_kernel!(KERNEL HASH ARGS NAME PATH SPECIALISATION LINTS,*) expects KERNEL, \
88+
HASH, and ARGS identifiers, NAME and PATH string literals, and SPECIALISATION \
89+
and LINTS tokens: {:?}",
8390
err
8491
)
8592
},
@@ -206,88 +213,162 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
206213
kernel_ptx.replace_range(type_layout_start..type_layout_end, "");
207214
}
208215

209-
let (result, error_log, info_log, version, drop) =
210-
check_kernel_ptx(&kernel_ptx, &specialisation, &kernel_hash, &ptx_lint_levels);
216+
check_kernel_ptx_and_report(
217+
&kernel_ptx,
218+
Specialisation::Link(&specialisation),
219+
&kernel_hash,
220+
&ptx_lint_levels,
221+
);
222+
223+
(quote! { const PTX_STR: &'static str = #kernel_ptx; #(#type_layouts)* }).into()
224+
}
225+
226+
#[allow(clippy::too_many_lines)]
227+
fn check_kernel_ptx_and_report(
228+
kernel_ptx: &str,
229+
specialisation: Specialisation,
230+
kernel_hash: &proc_macro2::Ident,
231+
ptx_lint_levels: &HashMap<PtxLint, LintLevel>,
232+
) {
233+
let (result, error_log, info_log, binary, version, drop) =
234+
check_kernel_ptx(kernel_ptx, specialisation, kernel_hash, ptx_lint_levels);
211235

212236
let ptx_compiler = match &version {
213237
Ok((major, minor)) => format!("PTX compiler v{major}.{minor}"),
214238
Err(_) => String::from("PTX compiler"),
215239
};
216240

217-
// TODO: allow user to select
218-
// - warn on double
219-
// - warn on float
220-
// - warn on spills
221-
// - verbose warn
222-
// - warnings as errors
223-
// - show PTX source if warning or error
224-
225241
let mut errors = String::new();
242+
226243
if let Err(err) = drop {
227244
let _ = errors.write_fmt(format_args!("Error dropping the {ptx_compiler}: {err}\n"));
228245
}
246+
229247
if let Err(err) = version {
230248
let _ = errors.write_fmt(format_args!(
231249
"Error fetching the version of the {ptx_compiler}: {err}\n"
232250
));
233251
}
234-
if let (Ok(Some(_)), _) | (_, Ok(Some(_))) = (&info_log, &error_log) {
252+
253+
let ptx_source_code = {
235254
let mut max_lines = kernel_ptx.chars().filter(|c| *c == '\n').count() + 1;
236255
let mut indent = 0;
237256
while max_lines > 0 {
238257
max_lines /= 10;
239258
indent += 1;
240259
}
241260

242-
emit_call_site_warning!(
261+
format!(
243262
"PTX source code:\n{}",
244263
kernel_ptx
245264
.lines()
246265
.enumerate()
247266
.map(|(i, l)| format!("{:indent$}| {l}", i + 1))
248267
.collect::<Vec<_>>()
249268
.join("\n")
250-
);
269+
)
270+
};
271+
272+
match binary {
273+
Ok(None) => (),
274+
Ok(Some(binary)) => {
275+
if ptx_lint_levels
276+
.get(&PtxLint::DumpBinary)
277+
.map_or(false, |level| *level > LintLevel::Allow)
278+
{
279+
const HEX: [char; 16] = [
280+
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
281+
];
282+
283+
let mut binary_hex = String::with_capacity(binary.len() * 2);
284+
for byte in binary {
285+
binary_hex.push(HEX[usize::from(byte >> 4)]);
286+
binary_hex.push(HEX[usize::from(byte & 0x0F)]);
287+
}
288+
289+
if ptx_lint_levels
290+
.get(&PtxLint::DumpBinary)
291+
.map_or(false, |level| *level > LintLevel::Warn)
292+
{
293+
emit_call_site_error!(
294+
"{} compiled binary:\n{}\n\n{}",
295+
ptx_compiler,
296+
binary_hex,
297+
ptx_source_code
298+
);
299+
} else {
300+
emit_call_site_warning!(
301+
"{} compiled binary:\n{}\n\n{}",
302+
ptx_compiler,
303+
binary_hex,
304+
ptx_source_code
305+
);
306+
}
307+
}
308+
},
309+
Err(err) => {
310+
let _ = errors.write_fmt(format_args!(
311+
"Error fetching the compiled binary from {ptx_compiler}: {err}\n"
312+
));
313+
},
251314
}
315+
252316
match info_log {
253317
Ok(None) => (),
254-
Ok(Some(info_log)) => emit_call_site_warning!("{ptx_compiler} info log:\n{}", info_log),
318+
Ok(Some(info_log)) => emit_call_site_warning!(
319+
"{} info log:\n{}\n{}",
320+
ptx_compiler,
321+
info_log,
322+
ptx_source_code
323+
),
255324
Err(err) => {
256325
let _ = errors.write_fmt(format_args!(
257326
"Error fetching the info log of the {ptx_compiler}: {err}\n"
258327
));
259328
},
260329
};
261-
match error_log {
262-
Ok(None) => (),
263-
Ok(Some(error_log)) => emit_call_site_error!("{ptx_compiler} error log:\n{}", error_log),
330+
331+
let error_log = match error_log {
332+
Ok(None) => String::new(),
333+
Ok(Some(error_log)) => {
334+
format!("{ptx_compiler} error log:\n{error_log}\n{ptx_source_code}")
335+
},
264336
Err(err) => {
265337
let _ = errors.write_fmt(format_args!(
266338
"Error fetching the error log of the {ptx_compiler}: {err}\n"
267339
));
340+
String::new()
268341
},
269342
};
343+
270344
if let Err(err) = result {
271345
let _ = errors.write_fmt(format_args!("Error compiling the PTX source code: {err}\n"));
272346
}
273-
if !errors.is_empty() {
274-
abort_call_site!("{}", errors);
275-
}
276347

277-
(quote! { const PTX_STR: &'static str = #kernel_ptx; #(#type_layouts)* }).into()
348+
if !error_log.is_empty() || !errors.is_empty() {
349+
abort_call_site!(
350+
"{error_log}{}{errors}",
351+
if !error_log.is_empty() && !errors.is_empty() {
352+
"\n\n"
353+
} else {
354+
""
355+
}
356+
);
357+
}
278358
}
279359

280360
#[allow(clippy::type_complexity)]
281361
#[allow(clippy::too_many_lines)]
282362
fn check_kernel_ptx(
283363
kernel_ptx: &str,
284-
specialisation: &str,
364+
specialisation: Specialisation,
285365
kernel_hash: &proc_macro2::Ident,
286366
ptx_lint_levels: &HashMap<PtxLint, LintLevel>,
287367
) -> (
288368
Result<(), NvptxError>,
289369
Result<Option<String>, NvptxError>,
290370
Result<Option<String>, NvptxError>,
371+
Result<Option<Vec<u8>>, NvptxError>,
291372
Result<(u32, u32), NvptxError>,
292373
Result<(), NvptxError>,
293374
) {
@@ -306,14 +387,15 @@ fn check_kernel_ptx(
306387
};
307388

308389
let result = (|| {
309-
let kernel_name = if specialisation.is_empty() {
310-
format!("{kernel_hash}_kernel")
311-
} else {
312-
format!(
390+
let kernel_name = match specialisation {
391+
Specialisation::Check => format!("{kernel_hash}_chECK"),
392+
Specialisation::Link("") => format!("{kernel_hash}_kernel"),
393+
Specialisation::Link(specialisation) => format!(
313394
"{kernel_hash}_kernel_{:016x}",
314395
seahash::hash(specialisation.as_bytes())
315-
)
396+
),
316397
};
398+
317399
let mut options = vec![
318400
CString::new("--entry").unwrap(),
319401
CString::new(kernel_name).unwrap(),
@@ -450,6 +532,39 @@ fn check_kernel_ptx(
450532
Ok(Some(String::from_utf8_lossy(&info_log).into_owned()))
451533
})();
452534

535+
let binary = (|| {
536+
if result.is_err() {
537+
return Ok(None);
538+
}
539+
540+
let mut binary_size = 0;
541+
542+
NvptxError::try_err_from(unsafe {
543+
ptx_compiler_sys::nvPTXCompilerGetCompiledProgramSize(
544+
compiler,
545+
addr_of_mut!(binary_size),
546+
)
547+
})?;
548+
549+
if binary_size == 0 {
550+
return Ok(None);
551+
}
552+
553+
#[allow(clippy::cast_possible_truncation)]
554+
let mut binary: Vec<u8> = Vec::with_capacity(binary_size as usize);
555+
556+
NvptxError::try_err_from(unsafe {
557+
ptx_compiler_sys::nvPTXCompilerGetCompiledProgram(compiler, binary.as_mut_ptr().cast())
558+
})?;
559+
560+
#[allow(clippy::cast_possible_truncation)]
561+
unsafe {
562+
binary.set_len(binary_size as usize);
563+
}
564+
565+
Ok(Some(binary))
566+
})();
567+
453568
let version = (|| {
454569
let mut major = 0;
455570
let mut minor = 0;
@@ -468,7 +583,7 @@ fn check_kernel_ptx(
468583
})
469584
};
470585

471-
(result, error_log, info_log, version, drop)
586+
(result, error_log, info_log, binary, version, drop)
472587
}
473588

474589
fn compile_kernel(

0 commit comments

Comments
 (0)