Skip to content

Commit 4b3f2d0

Browse files
committed
Add PTX lint parsing, no actual support yet
1 parent 3eba1a2 commit 4b3f2d0

File tree

4 files changed

+182
-38
lines changed

4 files changed

+182
-38
lines changed

examples/single-source/src/main.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ pub struct Tuple(u32, i32);
4545

4646
#[rc::common::kernel(use link_kernel! as impl Kernel<KernelArgs, KernelPtx> for Launcher)]
4747
#[kernel(crate = "rc")]
48+
#[kernel(
49+
allow(ptx::double_precision_use),
50+
forbid(ptx::local_memory_usage, ptx::register_spills)
51+
)]
4852
pub fn kernel<'a, T: rc::common::RustToCuda>(
4953
#[kernel(pass = SafeDeviceCopy)] _x: &Dummy,
5054
#[kernel(pass = LendRustToCuda, jit)] _y: &mut ShallowCopy<Wrapper<T>>,

rust-cuda-derive/src/kernel/link/ptx_compiler_sys.rs

Lines changed: 30 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,15 @@ pub type NvptxCompileResult = ::std::os::raw::c_int;
7474

7575
extern "C" {
7676
/// Queries the current major and minor version of PTX Compiler APIs being
77-
/// used
77+
/// used.
7878
///
7979
/// # Parameters
8080
/// - [out] `major`: Major version of the PTX Compiler APIs
8181
/// - [out] `minor`: Minor version of the PTX Compiler APIs
8282
///
83-
/// # Return
84-
/// - [`NvptxCompileResult`]::`NVPTXCOMPILE_SUCCESS`
85-
/// - [`NvptxCompileResult`]::`NVPTXCOMPILE_ERROR_INTERNAL`
83+
/// # Returns
84+
/// - [`NvptxCompileResult::NVPTXCOMPILE_SUCCESS`]
85+
/// - [`NvptxCompileResult::NVPTXCOMPILE_ERROR_INTERNAL`]
8686
///
8787
/// # Note
8888
/// The version of PTX Compiler APIs follows the CUDA Toolkit versioning.
@@ -93,42 +93,38 @@ extern "C" {
9393
minor: *mut ::std::os::raw::c_uint,
9494
) -> NvptxCompileResult;
9595

96-
#[doc = " \\ingroup compilation"]
97-
#[doc = ""]
98-
#[doc = " \\brief Obtains the handle to an instance of the PTX compiler"]
99-
#[doc = " initialized with the given PTX program \\p ptxCode"]
100-
#[doc = ""]
101-
#[doc = " \\param [out] compiler Returns a handle to PTX compiler initialized"]
102-
#[doc = " with the PTX program \\p ptxCode"]
103-
#[doc = " \\param [in] ptxCodeLen Size of the PTX program \\p ptxCode passed as \
104-
string"]
105-
#[doc = " \\param [in] ptxCode The PTX program which is to be compiled passed as \
106-
string."]
107-
#[doc = ""]
108-
#[doc = ""]
109-
#[doc = " \\return"]
110-
#[doc = " - \\link #nvPTXCompileResult NVPTXCOMPILE_SUCCESS \\endlink"]
111-
#[doc = " - \\link #nvPTXCompileResult NVPTXCOMPILE_ERROR_OUT_OF_MEMORY \\endlink"]
112-
#[doc = " - \\link #nvPTXCompileResult NVPTXCOMPILE_ERROR_INTERNAL \\endlink"]
96+
/// Obtains the handle to an instance of the PTX compiler
97+
/// initialized with the given PTX program `ptxCode`.
98+
///
99+
/// # Parameters
100+
/// - [out] `compiler`: Returns a handle to PTX compiler initialized with
101+
/// the PTX program `ptxCode`
102+
/// - [in] `ptxCodeLen`: Size of the PTX program `ptxCode` passed as a
103+
/// string
104+
/// - [in] `ptxCode`: The PTX program which is to be compiled passed as a
105+
/// string
106+
///
107+
/// # Returns
108+
/// - [`NvptxCompileResult::NVPTXCOMPILE_SUCCESS`]
109+
/// - [`NvptxCompileResult::NVPTXCOMPILE_ERROR_OUT_OF_MEMORY`]
110+
/// - [`NvptxCompileResult::NVPTXCOMPILE_ERROR_INTERNAL`]
113111
pub fn nvPTXCompilerCreate(
114112
compiler: *mut NvptxCompilerHandle,
115113
ptxCodeLen: size_t,
116114
ptxCode: *const ::std::os::raw::c_char,
117115
) -> NvptxCompileResult;
118116

119-
#[doc = " \\ingroup compilation"]
120-
#[doc = ""]
121-
#[doc = " \\brief Destroys and cleans the already created PTX compiler"]
122-
#[doc = ""]
123-
#[doc = " \\param [in] compiler A handle to the PTX compiler which is to be \
124-
destroyed"]
125-
#[doc = ""]
126-
#[doc = " \\return"]
127-
#[doc = " - \\link #nvPTXCompileResult NVPTXCOMPILE_SUCCESS \\endlink"]
128-
#[doc = " - \\link #nvPTXCompileResult NVPTXCOMPILE_ERROR_OUT_OF_MEMORY \\endlink"]
129-
#[doc = " - \\link #nvPTXCompileResult NVPTXCOMPILE_ERROR_INTERNAL \\endlink"]
130-
#[doc = " - \\link #nvPTXCompileResult NVPTXCOMPILE_ERROR_INVALID_PROGRAM_HANDLE \\endlink"]
131-
#[doc = ""]
117+
/// Destroys and cleans the already created PTX compiler.
118+
///
119+
/// # Parameters
120+
/// - [in] `compiler`: A handle to the PTX compiler which is to be
121+
/// destroyed.
122+
///
123+
/// # Returns
124+
/// - [`NvptxCompileResult::NVPTXCOMPILE_SUCCESS`]
125+
/// - [`NvptxCompileResult::NVPTXCOMPILE_ERROR_OUT_OF_MEMORY`]
126+
/// - [`NvptxCompileResult::NVPTXCOMPILE_ERROR_INTERNAL`]
127+
/// - [`NvptxCompileResult::NVPTXCOMPILE_ERROR_INVALID_PROGRAM_HANDLE`]
132128
pub fn nvPTXCompilerDestroy(compiler: *mut NvptxCompilerHandle) -> NvptxCompileResult;
133129

134130
#[doc = " \\ingroup compilation"]

rust-cuda-derive/src/kernel/wrapper/mod.rs

Lines changed: 147 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
use std::hash::{Hash, Hasher};
1+
use std::{
2+
collections::HashMap,
3+
fmt,
4+
hash::{Hash, Hasher},
5+
};
26

37
use proc_macro::TokenStream;
48

@@ -41,6 +45,7 @@ pub fn kernel(attr: TokenStream, func: TokenStream) -> TokenStream {
4145
let mut func = parse_kernel_fn(func);
4246

4347
let mut crate_path = None;
48+
let mut lint_levels = HashMap::new();
4449

4550
func.attrs.retain(|attr| {
4651
if attr.path.is_ident("kernel") {
@@ -58,7 +63,7 @@ pub fn kernel(attr: TokenStream, func: TokenStream) -> TokenStream {
5863
syn::parse_quote_spanned! { s.span() => #new_crate_path },
5964
);
6065

61-
return false;
66+
continue;
6267
}
6368

6469
emit_error!(
@@ -73,18 +78,114 @@ pub fn kernel(attr: TokenStream, func: TokenStream) -> TokenStream {
7378
err
7479
),
7580
},
81+
syn::NestedMeta::Meta(syn::Meta::List(syn::MetaList {
82+
path,
83+
nested,
84+
..
85+
})) if path.is_ident("allow") || path.is_ident("warn") || path.is_ident("deny") || path.is_ident("forbid") => {
86+
let level = match path.get_ident() {
87+
Some(ident) if ident == "allow" => LintLevel::Allow,
88+
Some(ident) if ident == "warn" => LintLevel::Warn,
89+
Some(ident) if ident == "deny" => LintLevel::Deny,
90+
Some(ident) if ident == "forbid" => LintLevel::Forbid,
91+
_ => unreachable!(),
92+
};
93+
94+
for meta in nested {
95+
let syn::NestedMeta::Meta(syn::Meta::Path(path)) = meta else {
96+
emit_error!(
97+
meta.span(),
98+
"[rust-cuda]: Invalid #[kernel({}(<lint>))] attribute.",
99+
level,
100+
);
101+
continue;
102+
};
103+
104+
if path.leading_colon.is_some() || path.segments.empty_or_trailing() || path.segments.len() != 2 {
105+
emit_error!(
106+
meta.span(),
107+
"[rust-cuda]: Invalid #[kernel({}(<lint>))] attribute: <lint> must be of the form `ptx::lint`.",
108+
level,
109+
);
110+
continue;
111+
}
112+
113+
let Some(syn::PathSegment { ident: namespace, arguments: syn::PathArguments::None }) = path.segments.first() else {
114+
emit_error!(
115+
meta.span(),
116+
"[rust-cuda]: Invalid #[kernel({}(<lint>))] attribute: <lint> must be of the form `ptx::lint`.",
117+
level,
118+
);
119+
continue;
120+
};
121+
122+
if namespace != "ptx" {
123+
emit_error!(
124+
meta.span(),
125+
"[rust-cuda]: Invalid #[kernel({}(<lint>))] attribute: <lint> must be of the form `ptx::lint`.",
126+
level,
127+
);
128+
continue;
129+
}
130+
131+
let Some(syn::PathSegment { ident: lint, arguments: syn::PathArguments::None }) = path.segments.last() else {
132+
emit_error!(
133+
meta.span(),
134+
"[rust-cuda]: Invalid #[kernel({}(<lint>))] attribute: <lint> must be of the form `ptx::lint`.",
135+
level,
136+
);
137+
continue;
138+
};
139+
140+
let lint = match lint {
141+
l if l == "verbose" => PtxLint::Verbose,
142+
l if l == "double_precision_use" => PtxLint::DoublePrecisionUse,
143+
l if l == "local_memory_usage" => PtxLint::LocalMemoryUsage,
144+
l if l == "register_spills" => PtxLint::RegisterSpills,
145+
_ => {
146+
emit_error!(
147+
meta.span(),
148+
"[rust-cuda]: Unknown PTX kernel lint `ptx::{}`.",
149+
lint,
150+
);
151+
continue;
152+
}
153+
};
154+
155+
match lint_levels.get(&lint) {
156+
None => (),
157+
Some(LintLevel::Forbid) if level < LintLevel::Forbid => {
158+
emit_error!(
159+
meta.span(),
160+
"[rust-cuda]: {}(ptx::{}) incompatible with previous forbid.",
161+
level, lint,
162+
);
163+
continue;
164+
},
165+
Some(previous) => {
166+
emit_warning!(
167+
meta.span(),
168+
"[rust-cuda]: {}(ptx::{}) overwrites previous {}.",
169+
level, lint, previous,
170+
);
171+
}
172+
}
173+
174+
lint_levels.insert(lint, level);
175+
}
176+
},
76177
_ => {
77178
emit_error!(
78179
meta.span(),
79-
"[rust-cuda]: Expected #[kernel(crate = \"<crate-path>\")] function attribute."
180+
"[rust-cuda]: Expected #[kernel(crate = \"<crate-path>\")] or #[kernel(allow/warn/deny/forbid(<lint>))] function attribute."
80181
);
81182
}
82183
}
83184
}
84185
} else {
85186
emit_error!(
86187
attr.span(),
87-
"[rust-cuda]: Expected #[kernel(crate = \"<crate-path>\")] function attribute."
188+
"[rust-cuda]: Expected #[kernel(crate = \"<crate-path>\")] or or #[kernel(allow/warn/deny/forbid(<lint>))] function attribute."
88189
);
89190
}
90191

@@ -96,6 +197,10 @@ pub fn kernel(attr: TokenStream, func: TokenStream) -> TokenStream {
96197

97198
let crate_path = crate_path.unwrap_or_else(|| syn::parse_quote!(::rust_cuda));
98199

200+
let _ = lint_levels.try_insert(PtxLint::DoublePrecisionUse, LintLevel::Warn);
201+
let _ = lint_levels.try_insert(PtxLint::LocalMemoryUsage, LintLevel::Warn);
202+
let _ = lint_levels.try_insert(PtxLint::RegisterSpills, LintLevel::Warn);
203+
99204
let mut generic_kernel_params = func.sig.generics.params.clone();
100205
let mut func_inputs = parse_function_inputs(&func, &mut generic_kernel_params);
101206

@@ -338,6 +443,44 @@ struct FuncIdent<'f> {
338443
func_ident_hash: syn::Ident,
339444
}
340445

446+
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
447+
enum LintLevel {
448+
Allow,
449+
Warn,
450+
Deny,
451+
Forbid,
452+
}
453+
454+
impl fmt::Display for LintLevel {
455+
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
456+
match self {
457+
Self::Allow => fmt.write_str("allow"),
458+
Self::Warn => fmt.write_str("warn"),
459+
Self::Deny => fmt.write_str("deny"),
460+
Self::Forbid => fmt.write_str("forbid"),
461+
}
462+
}
463+
}
464+
465+
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
466+
enum PtxLint {
467+
Verbose,
468+
DoublePrecisionUse,
469+
LocalMemoryUsage,
470+
RegisterSpills,
471+
}
472+
473+
impl fmt::Display for PtxLint {
474+
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
475+
match self {
476+
Self::Verbose => fmt.write_str("verbose"),
477+
Self::DoublePrecisionUse => fmt.write_str("double_precision_use"),
478+
Self::LocalMemoryUsage => fmt.write_str("local_memory_usage"),
479+
Self::RegisterSpills => fmt.write_str("register_spills"),
480+
}
481+
}
482+
}
483+
341484
fn ident_from_pat(pat: &syn::Pat) -> Option<syn::Ident> {
342485
match pat {
343486
syn::Pat::Lit(_)

rust-cuda-derive/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#![feature(non_exhaustive_omitted_patterns_lint)]
66
#![feature(if_let_guard)]
77
#![feature(let_chains)]
8+
#![feature(map_try_insert)]
89
#![doc(html_root_url = "https://juntyr.github.io/rust-cuda/")]
910

1011
extern crate proc_macro;

0 commit comments

Comments
 (0)