@@ -32,31 +32,38 @@ use error::emit_ptx_build_error;
32
32
use ptx_compiler_sys:: NvptxError ;
33
33
34
34
pub fn check_kernel ( tokens : TokenStream ) -> TokenStream {
35
- proc_macro_error:: set_dummy ( quote ! {
36
- "ERROR in this PTX compilation"
37
- } ) ;
35
+ proc_macro_error:: set_dummy ( quote ! { :: core:: result:: Result :: Err ( ( ) ) } ) ;
38
36
39
37
let CheckKernelConfig {
38
+ kernel_hash,
40
39
args,
41
40
crate_name,
42
41
crate_path,
43
42
} = match syn:: parse_macro_input:: parse ( tokens) {
44
43
Ok ( config) => config,
45
44
Err ( err) => {
46
45
abort_call_site ! (
47
- "check_kernel!(ARGS NAME PATH) expects ARGS identifier, NAME and PATH string \
48
- literals: {:?}",
46
+ "check_kernel!(HASH ARGS NAME PATH) expects HASH and ARGS identifiers, annd NAME \
47
+ and PATH string literals: {:?}",
49
48
err
50
49
)
51
50
} ,
52
51
} ;
53
52
54
53
let kernel_ptx = compile_kernel ( & args, & crate_name, & crate_path, Specialisation :: Check ) ;
55
54
56
- match kernel_ptx {
57
- Some ( kernel_ptx) => quote ! ( #kernel_ptx) . into ( ) ,
58
- None => quote ! ( "ERROR in this PTX compilation" ) . into ( ) ,
59
- }
55
+ let Some ( kernel_ptx) = kernel_ptx else {
56
+ return quote ! ( :: core:: result:: Result :: Err ( ( ) ) ) . into ( )
57
+ } ;
58
+
59
+ check_kernel_ptx_and_report (
60
+ & kernel_ptx,
61
+ Specialisation :: Check ,
62
+ & kernel_hash,
63
+ & HashMap :: new ( ) ,
64
+ ) ;
65
+
66
+ quote ! ( :: core:: result:: Result :: Ok ( ( ) ) ) . into ( )
60
67
}
61
68
62
69
#[ allow( clippy:: module_name_repetitions, clippy:: too_many_lines) ]
@@ -77,9 +84,9 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
77
84
Ok ( config) => config,
78
85
Err ( err) => {
79
86
abort_call_site ! (
80
- "link_kernel!(KERNEL ARGS NAME PATH SPECIALISATION LINTS,*) expects KERNEL and \
81
- ARGS identifiers, NAME and PATH string literals, SPECIALISATION and LINTS \
82
- tokens: {:?}",
87
+ "link_kernel!(KERNEL HASH ARGS NAME PATH SPECIALISATION LINTS,*) expects KERNEL, \
88
+ HASH, and ARGS identifiers, NAME and PATH string literals, and SPECIALISATION \
89
+ and LINTS tokens: {:?}",
83
90
err
84
91
)
85
92
} ,
@@ -206,88 +213,162 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
206
213
kernel_ptx. replace_range ( type_layout_start..type_layout_end, "" ) ;
207
214
}
208
215
209
- let ( result, error_log, info_log, version, drop) =
210
- check_kernel_ptx ( & kernel_ptx, & specialisation, & kernel_hash, & ptx_lint_levels) ;
216
+ check_kernel_ptx_and_report (
217
+ & kernel_ptx,
218
+ Specialisation :: Link ( & specialisation) ,
219
+ & kernel_hash,
220
+ & ptx_lint_levels,
221
+ ) ;
222
+
223
+ ( quote ! { const PTX_STR : & ' static str = #kernel_ptx; #( #type_layouts) * } ) . into ( )
224
+ }
225
+
226
+ #[ allow( clippy:: too_many_lines) ]
227
+ fn check_kernel_ptx_and_report (
228
+ kernel_ptx : & str ,
229
+ specialisation : Specialisation ,
230
+ kernel_hash : & proc_macro2:: Ident ,
231
+ ptx_lint_levels : & HashMap < PtxLint , LintLevel > ,
232
+ ) {
233
+ let ( result, error_log, info_log, binary, version, drop) =
234
+ check_kernel_ptx ( kernel_ptx, specialisation, kernel_hash, ptx_lint_levels) ;
211
235
212
236
let ptx_compiler = match & version {
213
237
Ok ( ( major, minor) ) => format ! ( "PTX compiler v{major}.{minor}" ) ,
214
238
Err ( _) => String :: from ( "PTX compiler" ) ,
215
239
} ;
216
240
217
- // TODO: allow user to select
218
- // - warn on double
219
- // - warn on float
220
- // - warn on spills
221
- // - verbose warn
222
- // - warnings as errors
223
- // - show PTX source if warning or error
224
-
225
241
let mut errors = String :: new ( ) ;
242
+
226
243
if let Err ( err) = drop {
227
244
let _ = errors. write_fmt ( format_args ! ( "Error dropping the {ptx_compiler}: {err}\n " ) ) ;
228
245
}
246
+
229
247
if let Err ( err) = version {
230
248
let _ = errors. write_fmt ( format_args ! (
231
249
"Error fetching the version of the {ptx_compiler}: {err}\n "
232
250
) ) ;
233
251
}
234
- if let ( Ok ( Some ( _) ) , _) | ( _, Ok ( Some ( _) ) ) = ( & info_log, & error_log) {
252
+
253
+ let ptx_source_code = {
235
254
let mut max_lines = kernel_ptx. chars ( ) . filter ( |c| * c == '\n' ) . count ( ) + 1 ;
236
255
let mut indent = 0 ;
237
256
while max_lines > 0 {
238
257
max_lines /= 10 ;
239
258
indent += 1 ;
240
259
}
241
260
242
- emit_call_site_warning ! (
261
+ format ! (
243
262
"PTX source code:\n {}" ,
244
263
kernel_ptx
245
264
. lines( )
246
265
. enumerate( )
247
266
. map( |( i, l) | format!( "{:indent$}| {l}" , i + 1 ) )
248
267
. collect:: <Vec <_>>( )
249
268
. join( "\n " )
250
- ) ;
269
+ )
270
+ } ;
271
+
272
+ match binary {
273
+ Ok ( None ) => ( ) ,
274
+ Ok ( Some ( binary) ) => {
275
+ if ptx_lint_levels
276
+ . get ( & PtxLint :: DumpBinary )
277
+ . map_or ( false , |level| * level > LintLevel :: Allow )
278
+ {
279
+ const HEX : [ char ; 16 ] = [
280
+ '0' , '1' , '2' , '3' , '4' , '5' , '6' , '7' , '8' , '9' , 'a' , 'b' , 'c' , 'd' , 'e' , 'f' ,
281
+ ] ;
282
+
283
+ let mut binary_hex = String :: with_capacity ( binary. len ( ) * 2 ) ;
284
+ for byte in binary {
285
+ binary_hex. push ( HEX [ usize:: from ( byte >> 4 ) ] ) ;
286
+ binary_hex. push ( HEX [ usize:: from ( byte & 0x0F ) ] ) ;
287
+ }
288
+
289
+ if ptx_lint_levels
290
+ . get ( & PtxLint :: DumpBinary )
291
+ . map_or ( false , |level| * level > LintLevel :: Warn )
292
+ {
293
+ emit_call_site_error ! (
294
+ "{} compiled binary:\n {}\n \n {}" ,
295
+ ptx_compiler,
296
+ binary_hex,
297
+ ptx_source_code
298
+ ) ;
299
+ } else {
300
+ emit_call_site_warning ! (
301
+ "{} compiled binary:\n {}\n \n {}" ,
302
+ ptx_compiler,
303
+ binary_hex,
304
+ ptx_source_code
305
+ ) ;
306
+ }
307
+ }
308
+ } ,
309
+ Err ( err) => {
310
+ let _ = errors. write_fmt ( format_args ! (
311
+ "Error fetching the compiled binary from {ptx_compiler}: {err}\n "
312
+ ) ) ;
313
+ } ,
251
314
}
315
+
252
316
match info_log {
253
317
Ok ( None ) => ( ) ,
254
- Ok ( Some ( info_log) ) => emit_call_site_warning ! ( "{ptx_compiler} info log:\n {}" , info_log) ,
318
+ Ok ( Some ( info_log) ) => emit_call_site_warning ! (
319
+ "{} info log:\n {}\n {}" ,
320
+ ptx_compiler,
321
+ info_log,
322
+ ptx_source_code
323
+ ) ,
255
324
Err ( err) => {
256
325
let _ = errors. write_fmt ( format_args ! (
257
326
"Error fetching the info log of the {ptx_compiler}: {err}\n "
258
327
) ) ;
259
328
} ,
260
329
} ;
261
- match error_log {
262
- Ok ( None ) => ( ) ,
263
- Ok ( Some ( error_log) ) => emit_call_site_error ! ( "{ptx_compiler} error log:\n {}" , error_log) ,
330
+
331
+ let error_log = match error_log {
332
+ Ok ( None ) => String :: new ( ) ,
333
+ Ok ( Some ( error_log) ) => {
334
+ format ! ( "{ptx_compiler} error log:\n {error_log}\n {ptx_source_code}" )
335
+ } ,
264
336
Err ( err) => {
265
337
let _ = errors. write_fmt ( format_args ! (
266
338
"Error fetching the error log of the {ptx_compiler}: {err}\n "
267
339
) ) ;
340
+ String :: new ( )
268
341
} ,
269
342
} ;
343
+
270
344
if let Err ( err) = result {
271
345
let _ = errors. write_fmt ( format_args ! ( "Error compiling the PTX source code: {err}\n " ) ) ;
272
346
}
273
- if !errors. is_empty ( ) {
274
- abort_call_site ! ( "{}" , errors) ;
275
- }
276
347
277
- ( quote ! { const PTX_STR : & ' static str = #kernel_ptx; #( #type_layouts) * } ) . into ( )
348
+ if !error_log. is_empty ( ) || !errors. is_empty ( ) {
349
+ abort_call_site ! (
350
+ "{error_log}{}{errors}" ,
351
+ if !error_log. is_empty( ) && !errors. is_empty( ) {
352
+ "\n \n "
353
+ } else {
354
+ ""
355
+ }
356
+ ) ;
357
+ }
278
358
}
279
359
280
360
#[ allow( clippy:: type_complexity) ]
281
361
#[ allow( clippy:: too_many_lines) ]
282
362
fn check_kernel_ptx (
283
363
kernel_ptx : & str ,
284
- specialisation : & str ,
364
+ specialisation : Specialisation ,
285
365
kernel_hash : & proc_macro2:: Ident ,
286
366
ptx_lint_levels : & HashMap < PtxLint , LintLevel > ,
287
367
) -> (
288
368
Result < ( ) , NvptxError > ,
289
369
Result < Option < String > , NvptxError > ,
290
370
Result < Option < String > , NvptxError > ,
371
+ Result < Option < Vec < u8 > > , NvptxError > ,
291
372
Result < ( u32 , u32 ) , NvptxError > ,
292
373
Result < ( ) , NvptxError > ,
293
374
) {
@@ -306,14 +387,15 @@ fn check_kernel_ptx(
306
387
} ;
307
388
308
389
let result = ( || {
309
- let kernel_name = if specialisation. is_empty ( ) {
310
- format ! ( "{kernel_hash}_kernel" )
311
- } else {
312
- format ! (
390
+ let kernel_name = match specialisation {
391
+ Specialisation :: Check => format ! ( "{kernel_hash}_chECK" ) ,
392
+ Specialisation :: Link ( "" ) => format ! ( "{kernel_hash}_kernel" ) ,
393
+ Specialisation :: Link ( specialisation ) => format ! (
313
394
"{kernel_hash}_kernel_{:016x}" ,
314
395
seahash:: hash( specialisation. as_bytes( ) )
315
- )
396
+ ) ,
316
397
} ;
398
+
317
399
let mut options = vec ! [
318
400
CString :: new( "--entry" ) . unwrap( ) ,
319
401
CString :: new( kernel_name) . unwrap( ) ,
@@ -450,6 +532,39 @@ fn check_kernel_ptx(
450
532
Ok ( Some ( String :: from_utf8_lossy ( & info_log) . into_owned ( ) ) )
451
533
} ) ( ) ;
452
534
535
+ let binary = ( || {
536
+ if result. is_err ( ) {
537
+ return Ok ( None ) ;
538
+ }
539
+
540
+ let mut binary_size = 0 ;
541
+
542
+ NvptxError :: try_err_from ( unsafe {
543
+ ptx_compiler_sys:: nvPTXCompilerGetCompiledProgramSize (
544
+ compiler,
545
+ addr_of_mut ! ( binary_size) ,
546
+ )
547
+ } ) ?;
548
+
549
+ if binary_size == 0 {
550
+ return Ok ( None ) ;
551
+ }
552
+
553
+ #[ allow( clippy:: cast_possible_truncation) ]
554
+ let mut binary: Vec < u8 > = Vec :: with_capacity ( binary_size as usize ) ;
555
+
556
+ NvptxError :: try_err_from ( unsafe {
557
+ ptx_compiler_sys:: nvPTXCompilerGetCompiledProgram ( compiler, binary. as_mut_ptr ( ) . cast ( ) )
558
+ } ) ?;
559
+
560
+ #[ allow( clippy:: cast_possible_truncation) ]
561
+ unsafe {
562
+ binary. set_len ( binary_size as usize ) ;
563
+ }
564
+
565
+ Ok ( Some ( binary) )
566
+ } ) ( ) ;
567
+
453
568
let version = ( || {
454
569
let mut major = 0 ;
455
570
let mut minor = 0 ;
@@ -468,7 +583,7 @@ fn check_kernel_ptx(
468
583
} )
469
584
} ;
470
585
471
- ( result, error_log, info_log, version, drop)
586
+ ( result, error_log, info_log, binary , version, drop)
472
587
}
473
588
474
589
fn compile_kernel (
0 commit comments