Skip to content

Commit f37ef1a

Browse files
committed
Added lint checking support to monomorphised kernel impls
1 parent 4b3f2d0 commit f37ef1a

File tree

7 files changed

+291
-149
lines changed

7 files changed

+291
-149
lines changed

rust-cuda-derive/src/kernel/link/config.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
use std::path::PathBuf;
1+
use std::{collections::HashMap, path::PathBuf};
2+
3+
use super::super::lints::{parse_ptx_lint_level, LintLevel, PtxLint};
24

35
#[allow(clippy::module_name_repetitions)]
46
pub(super) struct LinkKernelConfig {
@@ -8,6 +10,7 @@ pub(super) struct LinkKernelConfig {
810
pub(super) crate_name: String,
911
pub(super) crate_path: PathBuf,
1012
pub(super) specialisation: String,
13+
pub(super) ptx_lint_levels: HashMap<PtxLint, LintLevel>,
1114
}
1215

1316
impl syn::parse::Parse for LinkKernelConfig {
@@ -37,13 +40,27 @@ impl syn::parse::Parse for LinkKernelConfig {
3740
String::new()
3841
};
3942

43+
let attrs = syn::punctuated::Punctuated::<
44+
syn::MetaList,
45+
syn::token::Comma,
46+
>::parse_separated_nonempty(input)?;
47+
48+
let mut ptx_lint_levels = HashMap::new();
49+
50+
for syn::MetaList { path, nested, .. } in attrs {
51+
parse_ptx_lint_level(&path, &nested, &mut ptx_lint_levels);
52+
}
53+
54+
proc_macro_error::abort_if_dirty();
55+
4056
Ok(Self {
4157
kernel,
4258
kernel_hash,
4359
args,
4460
crate_name: name.value(),
4561
crate_path: PathBuf::from(path.value()),
4662
specialisation,
63+
ptx_lint_levels,
4764
})
4865
}
4966
}

rust-cuda-derive/src/kernel/link/mod.rs

Lines changed: 84 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::{
2+
collections::HashMap,
23
env,
34
ffi::CString,
45
fmt::Write as FmtWrite,
@@ -17,7 +18,10 @@ use ptx_builder::{
1718
error::{BuildErrorKind, Error, Result},
1819
};
1920

20-
use super::utils::skip_kernel_compilation;
21+
use super::{
22+
lints::{LintLevel, PtxLint},
23+
utils::skip_kernel_compilation,
24+
};
2125

2226
mod config;
2327
mod error;
@@ -68,12 +72,14 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
6872
crate_name,
6973
crate_path,
7074
specialisation,
75+
ptx_lint_levels,
7176
} = match syn::parse_macro_input::parse(tokens) {
7277
Ok(config) => config,
7378
Err(err) => {
7479
abort_call_site!(
75-
"link_kernel!(KERNEL ARGS NAME PATH SPECIALISATION) expects KERNEL and ARGS \
76-
identifiers, NAME and PATH string literals, and SPECIALISATION tokens: {:?}",
80+
"link_kernel!(KERNEL ARGS NAME PATH SPECIALISATION LINTS,*) expects KERNEL and \
81+
ARGS identifiers, NAME and PATH string literals, SPECIALISATION and LINTS \
82+
tokens: {:?}",
7783
err
7884
)
7985
},
@@ -201,7 +207,7 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
201207
}
202208

203209
let (result, error_log, info_log, version, drop) =
204-
check_kernel_ptx(&kernel_ptx, &specialisation, &kernel_hash);
210+
check_kernel_ptx(&kernel_ptx, &specialisation, &kernel_hash, &ptx_lint_levels);
205211

206212
let ptx_compiler = match &version {
207213
Ok((major, minor)) => format!("PTX compiler v{major}.{minor}"),
@@ -272,10 +278,12 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
272278
}
273279

274280
#[allow(clippy::type_complexity)]
281+
#[allow(clippy::too_many_lines)]
275282
fn check_kernel_ptx(
276283
kernel_ptx: &str,
277284
specialisation: &str,
278285
kernel_hash: &proc_macro2::Ident,
286+
ptx_lint_levels: &HashMap<PtxLint, LintLevel>,
279287
) -> (
280288
Result<(), NvptxError>,
281289
Result<Option<String>, NvptxError>,
@@ -297,7 +305,7 @@ fn check_kernel_ptx(
297305
compiler
298306
};
299307

300-
let result = {
308+
let result = (|| {
301309
let kernel_name = if specialisation.is_empty() {
302310
format!("{kernel_hash}_kernel")
303311
} else {
@@ -306,15 +314,79 @@ fn check_kernel_ptx(
306314
seahash::hash(specialisation.as_bytes())
307315
)
308316
};
309-
310-
let options = vec![
317+
let mut options = vec![
311318
CString::new("--entry").unwrap(),
312319
CString::new(kernel_name).unwrap(),
313-
CString::new("--verbose").unwrap(),
314-
CString::new("--warn-on-double-precision-use").unwrap(),
315-
CString::new("--warn-on-local-memory-usage").unwrap(),
316-
CString::new("--warn-on-spills").unwrap(),
317320
];
321+
322+
if ptx_lint_levels
323+
.values()
324+
.any(|level| *level > LintLevel::Warn)
325+
{
326+
let mut options = options.clone();
327+
328+
if ptx_lint_levels
329+
.get(&PtxLint::Verbose)
330+
.map_or(false, |level| *level > LintLevel::Warn)
331+
{
332+
options.push(CString::new("--verbose").unwrap());
333+
}
334+
if ptx_lint_levels
335+
.get(&PtxLint::DoublePrecisionUse)
336+
.map_or(false, |level| *level > LintLevel::Warn)
337+
{
338+
options.push(CString::new("--warn-on-double-precision-use").unwrap());
339+
}
340+
if ptx_lint_levels
341+
.get(&PtxLint::LocalMemoryUsage)
342+
.map_or(false, |level| *level > LintLevel::Warn)
343+
{
344+
options.push(CString::new("--warn-on-local-memory-usage").unwrap());
345+
}
346+
if ptx_lint_levels
347+
.get(&PtxLint::RegisterSpills)
348+
.map_or(false, |level| *level > LintLevel::Warn)
349+
{
350+
options.push(CString::new("--warn-on-spills").unwrap());
351+
}
352+
options.push(CString::new("--warning-as-error").unwrap());
353+
354+
let options_ptrs = options.iter().map(|o| o.as_ptr()).collect::<Vec<_>>();
355+
356+
NvptxError::try_err_from(unsafe {
357+
ptx_compiler_sys::nvPTXCompilerCompile(
358+
compiler,
359+
options_ptrs.len() as c_int,
360+
options_ptrs.as_ptr().cast(),
361+
)
362+
})?;
363+
};
364+
365+
if ptx_lint_levels
366+
.get(&PtxLint::Verbose)
367+
.map_or(false, |level| *level > LintLevel::Allow)
368+
{
369+
options.push(CString::new("--verbose").unwrap());
370+
}
371+
if ptx_lint_levels
372+
.get(&PtxLint::DoublePrecisionUse)
373+
.map_or(false, |level| *level > LintLevel::Allow)
374+
{
375+
options.push(CString::new("--warn-on-double-precision-use").unwrap());
376+
}
377+
if ptx_lint_levels
378+
.get(&PtxLint::LocalMemoryUsage)
379+
.map_or(false, |level| *level > LintLevel::Allow)
380+
{
381+
options.push(CString::new("--warn-on-local-memory-usage").unwrap());
382+
}
383+
if ptx_lint_levels
384+
.get(&PtxLint::RegisterSpills)
385+
.map_or(false, |level| *level > LintLevel::Allow)
386+
{
387+
options.push(CString::new("--warn-on-spills").unwrap());
388+
}
389+
318390
let options_ptrs = options.iter().map(|o| o.as_ptr()).collect::<Vec<_>>();
319391

320392
NvptxError::try_err_from(unsafe {
@@ -324,7 +396,7 @@ fn check_kernel_ptx(
324396
options_ptrs.as_ptr().cast(),
325397
)
326398
})
327-
};
399+
})();
328400

329401
let error_log = (|| {
330402
let mut error_log_size = 0;

rust-cuda-derive/src/kernel/lints.rs

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
use std::{collections::HashMap, fmt};
2+
3+
use syn::spanned::Spanned;
4+
5+
/// Parses one `<level>(<lint>, ...)` lint attribute and records the requested
/// level for every PTX lint it names.
///
/// * `path` - the lint level; must be the single identifier `allow`, `warn`,
///   `deny`, or `forbid`. Any other path emits an error and returns without
///   modifying `ptx_lint_levels`.
/// * `nested` - the lints the level applies to; each entry must be a plain
///   two-segment path of the form `ptx::<lint>`.
/// * `ptx_lint_levels` - accumulated lint levels, updated in place.
///
/// Problems are reported via `emit_error!` / `emit_warning!` instead of being
/// returned (presumably the `proc_macro_error` macros - confirm at the crate
/// root); an invalid entry is skipped and the remaining entries are still
/// processed.
pub fn parse_ptx_lint_level(
6+
path: &syn::Path,
7+
nested: &syn::punctuated::Punctuated<syn::NestedMeta, syn::token::Comma>,
8+
ptx_lint_levels: &mut HashMap<PtxLint, LintLevel>,
9+
) {
10+
// Map the attribute's path onto a lint level; bail out on anything else.
let level = match path.get_ident() {
11+
Some(ident) if ident == "allow" => LintLevel::Allow,
12+
Some(ident) if ident == "warn" => LintLevel::Warn,
13+
Some(ident) if ident == "deny" => LintLevel::Deny,
14+
Some(ident) if ident == "forbid" => LintLevel::Forbid,
15+
_ => {
16+
emit_error!(
17+
path.span(),
18+
"[rust-cuda]: Invalid lint #[kernel(<level>(<lint>))] attribute: unknown lint \
19+
level, must be one of `allow`, `warn`, `deny`, `forbid`.",
20+
);
21+
22+
return;
23+
},
24+
};
25+
26+
for meta in nested {
27+
// Only bare paths are accepted: literals, lists, and name-value pairs are rejected.
let syn::NestedMeta::Meta(syn::Meta::Path(path)) = meta else {
28+
emit_error!(
29+
meta.span(),
30+
"[rust-cuda]: Invalid #[kernel({}(<lint>))] attribute.",
31+
level,
32+
);
33+
continue;
34+
};
35+
36+
// Require exactly `a::b`: no leading `::`, no trailing `::`, two segments.
if path.leading_colon.is_some()
37+
|| path.segments.empty_or_trailing()
38+
|| path.segments.len() != 2
39+
{
40+
emit_error!(
41+
meta.span(),
42+
"[rust-cuda]: Invalid #[kernel({}(<lint>))] attribute: <lint> must be of the form \
43+
`ptx::lint`.",
44+
level,
45+
);
46+
continue;
47+
}
48+
49+
// First segment is the namespace; it must not carry path arguments.
let Some(syn::PathSegment { ident: namespace, arguments: syn::PathArguments::None }) = path.segments.first() else {
50+
emit_error!(
51+
meta.span(),
52+
"[rust-cuda]: Invalid #[kernel({}(<lint>))] attribute: <lint> must be of the form `ptx::lint`.",
53+
level,
54+
);
55+
continue;
56+
};
57+
58+
// `ptx` is the only namespace accepted here.
if namespace != "ptx" {
59+
emit_error!(
60+
meta.span(),
61+
"[rust-cuda]: Invalid #[kernel({}(<lint>))] attribute: <lint> must be of the form \
62+
`ptx::lint`.",
63+
level,
64+
);
65+
continue;
66+
}
67+
68+
// Last segment is the lint name; it must not carry path arguments either.
let Some(syn::PathSegment { ident: lint, arguments: syn::PathArguments::None }) = path.segments.last() else {
69+
emit_error!(
70+
meta.span(),
71+
"[rust-cuda]: Invalid #[kernel({}(<lint>))] attribute: <lint> must be of the form `ptx::lint`.",
72+
level,
73+
);
74+
continue;
75+
};
76+
77+
// Translate the lint name into its `PtxLint` variant (shadows the ident).
let lint = match lint {
78+
l if l == "verbose" => PtxLint::Verbose,
79+
l if l == "double_precision_use" => PtxLint::DoublePrecisionUse,
80+
l if l == "local_memory_usage" => PtxLint::LocalMemoryUsage,
81+
l if l == "register_spills" => PtxLint::RegisterSpills,
82+
_ => {
83+
emit_error!(
84+
meta.span(),
85+
"[rust-cuda]: Unknown PTX kernel lint `ptx::{}`.",
86+
lint,
87+
);
88+
continue;
89+
},
90+
};
91+
92+
// Resolve conflicts with an earlier setting for the same lint: an earlier
// `forbid` cannot be lowered (error, entry left unchanged); any other earlier
// setting, including an identical one, is overwritten with a warning.
match ptx_lint_levels.get(&lint) {
93+
None => (),
94+
Some(LintLevel::Forbid) if level < LintLevel::Forbid => {
95+
emit_error!(
96+
meta.span(),
97+
"[rust-cuda]: {}(ptx::{}) incompatible with previous forbid.",
98+
level,
99+
lint,
100+
);
101+
continue;
102+
},
103+
Some(previous) => {
104+
emit_warning!(
105+
meta.span(),
106+
"[rust-cuda]: {}(ptx::{}) overwrites previous {}.",
107+
level,
108+
lint,
109+
previous,
110+
);
111+
},
112+
}
113+
114+
// Record (or replace) the level for this lint.
ptx_lint_levels.insert(lint, level);
115+
}
116+
}
117+
118+
/// Level assigned to a lint, parsed from the `allow`, `warn`, `deny`, and
/// `forbid` attribute identifiers.
///
/// The variants are declared in ascending order, so the derived
/// `PartialOrd`/`Ord` yields `Allow < Warn < Deny < Forbid`. Callers rely on
/// this ordering for comparisons such as `level < LintLevel::Forbid` and
/// `*level > LintLevel::Warn`; do not reorder the variants.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
119+
pub enum LintLevel {
120+
Allow,
121+
Warn,
122+
Deny,
123+
Forbid,
124+
}
125+
126+
/// Formats a `LintLevel` as its lowercase attribute keyword
/// (`allow`, `warn`, `deny`, `forbid`), i.e. the same spelling that
/// `parse_ptx_lint_level` accepts; used when rendering diagnostics.
impl fmt::Display for LintLevel {
127+
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
128+
match self {
129+
Self::Allow => fmt.write_str("allow"),
130+
Self::Warn => fmt.write_str("warn"),
131+
Self::Deny => fmt.write_str("deny"),
132+
Self::Forbid => fmt.write_str("forbid"),
133+
}
134+
}
135+
}
136+
137+
/// PTX lints that can be configured through `<level>(ptx::<lint>)` attributes.
///
/// In `check_kernel_ptx` each variant selects one PTX compiler option:
/// `Verbose` -> `--verbose`,
/// `DoublePrecisionUse` -> `--warn-on-double-precision-use`,
/// `LocalMemoryUsage` -> `--warn-on-local-memory-usage`,
/// `RegisterSpills` -> `--warn-on-spills`.
///
/// `Eq` and `Hash` are derived because the type is used as a `HashMap` key.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
138+
pub enum PtxLint {
139+
Verbose,
140+
DoublePrecisionUse,
141+
LocalMemoryUsage,
142+
RegisterSpills,
143+
}
144+
145+
/// Formats a `PtxLint` as its snake_case lint name (the segment after `ptx::`),
/// matching the names that `parse_ptx_lint_level` accepts; used when rendering
/// diagnostics.
impl fmt::Display for PtxLint {
146+
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
147+
match self {
148+
Self::Verbose => fmt.write_str("verbose"),
149+
Self::DoublePrecisionUse => fmt.write_str("double_precision_use"),
150+
Self::LocalMemoryUsage => fmt.write_str("local_memory_usage"),
151+
Self::RegisterSpills => fmt.write_str("register_spills"),
152+
}
153+
}
154+
}

rust-cuda-derive/src/kernel/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ pub mod link;
22
pub mod specialise;
33
pub mod wrapper;
44

5+
mod lints;
56
mod utils;

rust-cuda-derive/src/kernel/wrapper/generate/cpu_linker_macro/get_ptx_str.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::kernel::utils::skip_kernel_compilation;
55

66
use super::super::super::{DeclGenerics, FuncIdent, FunctionInputs, InputCudaType, KernelConfig};
77

8+
#[allow(clippy::too_many_arguments)]
89
pub(super) fn quote_get_ptx_str(
910
crate_path: &syn::Path,
1011
FuncIdent {
@@ -21,6 +22,7 @@ pub(super) fn quote_get_ptx_str(
2122
inputs: &FunctionInputs,
2223
func_params: &[syn::Ident],
2324
macro_type_ids: &[syn::Ident],
25+
ptx_lint_levels: &TokenStream,
2426
) -> TokenStream {
2527
let crate_name = match proc_macro::tracked_env::var("CARGO_CRATE_NAME") {
2628
Ok(crate_name) => crate_name.to_uppercase(),
@@ -80,7 +82,7 @@ pub(super) fn quote_get_ptx_str(
8082
#crate_path::host::link_kernel!{
8183
#func_ident #func_ident_hash #args #crate_name #crate_manifest_dir #generic_start_token
8284
#($#macro_type_ids),*
83-
#generic_close_token
85+
#generic_close_token #ptx_lint_levels
8486
}
8587

8688
#matching_kernel_assert

0 commit comments

Comments
 (0)