1
- use std:: hash:: { Hash , Hasher } ;
1
+ use std:: {
2
+ collections:: HashMap ,
3
+ fmt,
4
+ hash:: { Hash , Hasher } ,
5
+ } ;
2
6
3
7
use proc_macro:: TokenStream ;
4
8
@@ -41,6 +45,7 @@ pub fn kernel(attr: TokenStream, func: TokenStream) -> TokenStream {
41
45
let mut func = parse_kernel_fn ( func) ;
42
46
43
47
let mut crate_path = None ;
48
+ let mut lint_levels = HashMap :: new ( ) ;
44
49
45
50
func. attrs . retain ( |attr| {
46
51
if attr. path . is_ident ( "kernel" ) {
@@ -58,7 +63,7 @@ pub fn kernel(attr: TokenStream, func: TokenStream) -> TokenStream {
58
63
syn:: parse_quote_spanned! { s. span( ) => #new_crate_path } ,
59
64
) ;
60
65
61
- return false ;
66
+ continue ;
62
67
}
63
68
64
69
emit_error ! (
@@ -73,18 +78,114 @@ pub fn kernel(attr: TokenStream, func: TokenStream) -> TokenStream {
73
78
err
74
79
) ,
75
80
} ,
81
+ syn:: NestedMeta :: Meta ( syn:: Meta :: List ( syn:: MetaList {
82
+ path,
83
+ nested,
84
+ ..
85
+ } ) ) if path. is_ident ( "allow" ) || path. is_ident ( "warn" ) || path. is_ident ( "deny" ) || path. is_ident ( "forbid" ) => {
86
+ let level = match path. get_ident ( ) {
87
+ Some ( ident) if ident == "allow" => LintLevel :: Allow ,
88
+ Some ( ident) if ident == "warn" => LintLevel :: Warn ,
89
+ Some ( ident) if ident == "deny" => LintLevel :: Deny ,
90
+ Some ( ident) if ident == "forbid" => LintLevel :: Forbid ,
91
+ _ => unreachable ! ( ) ,
92
+ } ;
93
+
94
+ for meta in nested {
95
+ let syn:: NestedMeta :: Meta ( syn:: Meta :: Path ( path) ) = meta else {
96
+ emit_error ! (
97
+ meta. span( ) ,
98
+ "[rust-cuda]: Invalid #[kernel({}(<lint>))] attribute." ,
99
+ level,
100
+ ) ;
101
+ continue ;
102
+ } ;
103
+
104
+ if path. leading_colon . is_some ( ) || path. segments . empty_or_trailing ( ) || path. segments . len ( ) != 2 {
105
+ emit_error ! (
106
+ meta. span( ) ,
107
+ "[rust-cuda]: Invalid #[kernel({}(<lint>))] attribute: <lint> must be of the form `ptx::lint`." ,
108
+ level,
109
+ ) ;
110
+ continue ;
111
+ }
112
+
113
+ let Some ( syn:: PathSegment { ident : namespace, arguments : syn:: PathArguments :: None } ) = path. segments . first ( ) else {
114
+ emit_error ! (
115
+ meta. span( ) ,
116
+ "[rust-cuda]: Invalid #[kernel({}(<lint>))] attribute: <lint> must be of the form `ptx::lint`." ,
117
+ level,
118
+ ) ;
119
+ continue ;
120
+ } ;
121
+
122
+ if namespace != "ptx" {
123
+ emit_error ! (
124
+ meta. span( ) ,
125
+ "[rust-cuda]: Invalid #[kernel({}(<lint>))] attribute: <lint> must be of the form `ptx::lint`." ,
126
+ level,
127
+ ) ;
128
+ continue ;
129
+ }
130
+
131
+ let Some ( syn:: PathSegment { ident : lint, arguments : syn:: PathArguments :: None } ) = path. segments . last ( ) else {
132
+ emit_error ! (
133
+ meta. span( ) ,
134
+ "[rust-cuda]: Invalid #[kernel({}(<lint>))] attribute: <lint> must be of the form `ptx::lint`." ,
135
+ level,
136
+ ) ;
137
+ continue ;
138
+ } ;
139
+
140
+ let lint = match lint {
141
+ l if l == "verbose" => PtxLint :: Verbose ,
142
+ l if l == "double_precision_use" => PtxLint :: DoublePrecisionUse ,
143
+ l if l == "local_memory_usage" => PtxLint :: LocalMemoryUsage ,
144
+ l if l == "register_spills" => PtxLint :: RegisterSpills ,
145
+ _ => {
146
+ emit_error ! (
147
+ meta. span( ) ,
148
+ "[rust-cuda]: Unknown PTX kernel lint `ptx::{}`." ,
149
+ lint,
150
+ ) ;
151
+ continue ;
152
+ }
153
+ } ;
154
+
155
+ match lint_levels. get ( & lint) {
156
+ None => ( ) ,
157
+ Some ( LintLevel :: Forbid ) if level < LintLevel :: Forbid => {
158
+ emit_error ! (
159
+ meta. span( ) ,
160
+ "[rust-cuda]: {}(ptx::{}) incompatible with previous forbid." ,
161
+ level, lint,
162
+ ) ;
163
+ continue ;
164
+ } ,
165
+ Some ( previous) => {
166
+ emit_warning ! (
167
+ meta. span( ) ,
168
+ "[rust-cuda]: {}(ptx::{}) overwrites previous {}." ,
169
+ level, lint, previous,
170
+ ) ;
171
+ }
172
+ }
173
+
174
+ lint_levels. insert ( lint, level) ;
175
+ }
176
+ } ,
76
177
_ => {
77
178
emit_error ! (
78
179
meta. span( ) ,
79
- "[rust-cuda]: Expected #[kernel(crate = \" <crate-path>\" )] function attribute."
180
+ "[rust-cuda]: Expected #[kernel(crate = \" <crate-path>\" )] or #[kernel(allow/warn/deny/forbid(<lint>))] function attribute."
80
181
) ;
81
182
}
82
183
}
83
184
}
84
185
} else {
85
186
emit_error ! (
86
187
attr. span( ) ,
87
- "[rust-cuda]: Expected #[kernel(crate = \" <crate-path>\" )] function attribute."
188
+ "[rust-cuda]: Expected #[kernel(crate = \" <crate-path>\" )] or or #[kernel(allow/warn/deny/forbid(<lint>))] function attribute."
88
189
) ;
89
190
}
90
191
@@ -96,6 +197,10 @@ pub fn kernel(attr: TokenStream, func: TokenStream) -> TokenStream {
96
197
97
198
let crate_path = crate_path. unwrap_or_else ( || syn:: parse_quote!( :: rust_cuda) ) ;
98
199
200
+ let _ = lint_levels. try_insert ( PtxLint :: DoublePrecisionUse , LintLevel :: Warn ) ;
201
+ let _ = lint_levels. try_insert ( PtxLint :: LocalMemoryUsage , LintLevel :: Warn ) ;
202
+ let _ = lint_levels. try_insert ( PtxLint :: RegisterSpills , LintLevel :: Warn ) ;
203
+
99
204
let mut generic_kernel_params = func. sig . generics . params . clone ( ) ;
100
205
let mut func_inputs = parse_function_inputs ( & func, & mut generic_kernel_params) ;
101
206
@@ -338,6 +443,44 @@ struct FuncIdent<'f> {
338
443
func_ident_hash : syn:: Ident ,
339
444
}
340
445
446
+ #[ derive( Clone , Copy , PartialEq , Eq , PartialOrd , Ord , Hash , Debug ) ]
447
+ enum LintLevel {
448
+ Allow ,
449
+ Warn ,
450
+ Deny ,
451
+ Forbid ,
452
+ }
453
+
454
+ impl fmt:: Display for LintLevel {
455
+ fn fmt ( & self , fmt : & mut fmt:: Formatter ) -> fmt:: Result {
456
+ match self {
457
+ Self :: Allow => fmt. write_str ( "allow" ) ,
458
+ Self :: Warn => fmt. write_str ( "warn" ) ,
459
+ Self :: Deny => fmt. write_str ( "deny" ) ,
460
+ Self :: Forbid => fmt. write_str ( "forbid" ) ,
461
+ }
462
+ }
463
+ }
464
+
465
+ #[ derive( Clone , Copy , PartialEq , Eq , PartialOrd , Ord , Hash , Debug ) ]
466
+ enum PtxLint {
467
+ Verbose ,
468
+ DoublePrecisionUse ,
469
+ LocalMemoryUsage ,
470
+ RegisterSpills ,
471
+ }
472
+
473
+ impl fmt:: Display for PtxLint {
474
+ fn fmt ( & self , fmt : & mut fmt:: Formatter ) -> fmt:: Result {
475
+ match self {
476
+ Self :: Verbose => fmt. write_str ( "verbose" ) ,
477
+ Self :: DoublePrecisionUse => fmt. write_str ( "double_precision_use" ) ,
478
+ Self :: LocalMemoryUsage => fmt. write_str ( "local_memory_usage" ) ,
479
+ Self :: RegisterSpills => fmt. write_str ( "register_spills" ) ,
480
+ }
481
+ }
482
+ }
483
+
341
484
fn ident_from_pat ( pat : & syn:: Pat ) -> Option < syn:: Ident > {
342
485
match pat {
343
486
syn:: Pat :: Lit ( _)
0 commit comments