1
1
use std:: {
2
2
env,
3
3
ffi:: CString ,
4
+ fmt:: Write as FmtWrite ,
4
5
fs,
5
6
io:: { Read , Write } ,
6
- mem:: MaybeUninit ,
7
7
os:: raw:: c_int,
8
8
path:: { Path , PathBuf } ,
9
9
ptr:: addr_of_mut,
@@ -16,15 +16,16 @@ use ptx_builder::{
16
16
builder:: { BuildStatus , Builder , MessageFormat , Profile } ,
17
17
error:: { BuildErrorKind , Error , Result } ,
18
18
} ;
19
- use ptx_compiler:: sys:: size_t;
20
19
21
20
use super :: utils:: skip_kernel_compilation;
22
21
23
22
mod config;
24
23
mod error;
24
+ mod ptx_compiler_sys;
25
25
26
26
use config:: { CheckKernelConfig , LinkKernelConfig } ;
27
27
use error:: emit_ptx_build_error;
28
+ use ptx_compiler_sys:: NvptxError ;
28
29
29
30
pub fn check_kernel ( tokens : TokenStream ) -> TokenStream {
30
31
proc_macro_error:: set_dummy ( quote ! {
@@ -199,110 +200,41 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
199
200
kernel_ptx. replace_range ( type_layout_start..type_layout_end, "" ) ;
200
201
}
201
202
202
- let mut compiler = MaybeUninit :: uninit ( ) ;
203
- let r = unsafe {
204
- ptx_compiler:: sys:: nvPTXCompilerCreate (
205
- compiler. as_mut_ptr ( ) ,
206
- kernel_ptx. len ( ) as size_t ,
207
- kernel_ptx. as_ptr ( ) . cast ( ) ,
208
- )
209
- } ;
210
- emit_call_site_warning ! ( "PTX compiler create result {}" , r) ;
211
- let compiler = unsafe { compiler. assume_init ( ) } ;
212
-
213
- let mut major = 0 ;
214
- let mut minor = 0 ;
215
- let r = unsafe {
216
- ptx_compiler:: sys:: nvPTXCompilerGetVersion ( addr_of_mut ! ( major) , addr_of_mut ! ( minor) )
217
- } ;
218
- emit_call_site_warning ! ( "PTX version result {}" , r) ;
219
- emit_call_site_warning ! ( "PTX compiler version {}.{}" , major, minor) ;
203
+ let ( result, error_log, info_log, version, drop) =
204
+ check_kernel_ptx ( & kernel_ptx, & specialisation, & kernel_hash) ;
220
205
221
- let kernel_name = if specialisation. is_empty ( ) {
222
- format ! ( "{kernel_hash}_kernel" )
223
- } else {
224
- format ! (
225
- "{kernel_hash}_kernel_{:016x}" ,
226
- seahash:: hash( specialisation. as_bytes( ) )
227
- )
228
- } ;
229
-
230
- let options = vec ! [
231
- CString :: new( "--entry" ) . unwrap( ) ,
232
- CString :: new( kernel_name) . unwrap( ) ,
233
- CString :: new( "--verbose" ) . unwrap( ) ,
234
- CString :: new( "--warn-on-double-precision-use" ) . unwrap( ) ,
235
- CString :: new( "--warn-on-local-memory-usage" ) . unwrap( ) ,
236
- CString :: new( "--warn-on-spills" ) . unwrap( ) ,
237
- ] ;
238
- let options_ptrs = options. iter ( ) . map ( |o| o. as_ptr ( ) ) . collect :: < Vec < _ > > ( ) ;
239
-
240
- let r = unsafe {
241
- ptx_compiler:: sys:: nvPTXCompilerCompile (
242
- compiler,
243
- options_ptrs. len ( ) as c_int ,
244
- options_ptrs. as_ptr ( ) . cast ( ) ,
245
- )
206
+ let ptx_compiler = match & version {
207
+ Ok ( ( major, minor) ) => format ! ( "PTX compiler v{major}.{minor}" ) ,
208
+ Err ( _) => String :: from ( "PTX compiler" ) ,
246
209
} ;
247
- emit_call_site_warning ! ( "PTX compile result {}" , r) ;
248
210
249
- let mut info_log_size = 0 ;
250
- let r = unsafe {
251
- ptx_compiler:: sys:: nvPTXCompilerGetInfoLogSize ( compiler, addr_of_mut ! ( info_log_size) )
252
- } ;
253
- emit_call_site_warning ! ( "PTX info log size result {}" , r) ;
254
- #[ allow( clippy:: cast_possible_truncation) ]
255
- let mut info_log: Vec < u8 > = Vec :: with_capacity ( info_log_size as usize ) ;
256
- if info_log_size > 0 {
257
- let r = unsafe {
258
- ptx_compiler:: sys:: nvPTXCompilerGetInfoLog ( compiler, info_log. as_mut_ptr ( ) . cast ( ) )
259
- } ;
260
- emit_call_site_warning ! ( "PTX info log content result {}" , r) ;
261
- #[ allow( clippy:: cast_possible_truncation) ]
262
- unsafe {
263
- info_log. set_len ( info_log_size as usize ) ;
264
- }
265
- }
266
- let info_log = String :: from_utf8_lossy ( & info_log) ;
267
-
268
- let mut error_log_size = 0 ;
269
- let r = unsafe {
270
- ptx_compiler:: sys:: nvPTXCompilerGetErrorLogSize ( compiler, addr_of_mut ! ( error_log_size) )
271
- } ;
272
- emit_call_site_warning ! ( "PTX error log size result {}" , r) ;
273
- #[ allow( clippy:: cast_possible_truncation) ]
274
- let mut error_log: Vec < u8 > = Vec :: with_capacity ( error_log_size as usize ) ;
275
- if error_log_size > 0 {
276
- let r = unsafe {
277
- ptx_compiler:: sys:: nvPTXCompilerGetErrorLog ( compiler, error_log. as_mut_ptr ( ) . cast ( ) )
278
- } ;
279
- emit_call_site_warning ! ( "PTX error log content result {}" , r) ;
280
- #[ allow( clippy:: cast_possible_truncation) ]
281
- unsafe {
282
- error_log. set_len ( error_log_size as usize ) ;
283
- }
211
+ // TODO: allow user to select
212
+ // - warn on double
213
+ // - warn on float
214
+ // - warn on spills
215
+ // - verbose warn
216
+ // - warnings as errors
217
+ // - show PTX source if warning or error
218
+
219
+ let mut errors = String :: new ( ) ;
220
+ if let Err ( err) = drop {
221
+ let _ = errors. write_fmt ( format_args ! ( "Error dropping the {ptx_compiler}: {err}\n " ) ) ;
284
222
}
285
- let error_log = String :: from_utf8_lossy ( & error_log) ;
286
-
287
- // Ensure the compiler is not dropped
288
- let mut compiler = MaybeUninit :: new ( compiler) ;
289
- let r = unsafe { ptx_compiler:: sys:: nvPTXCompilerDestroy ( compiler. as_mut_ptr ( ) ) } ;
290
- emit_call_site_warning ! ( "PTX compiler destroy result {}" , r) ;
291
-
292
- if !info_log. is_empty ( ) {
293
- emit_call_site_warning ! ( "PTX compiler info log:\n {}" , info_log) ;
223
+ if let Err ( err) = version {
224
+ let _ = errors. write_fmt ( format_args ! (
225
+ "Error fetching the version of the {ptx_compiler}: {err}\n "
226
+ ) ) ;
294
227
}
295
- if !error_log . is_empty ( ) {
228
+ if let ( Ok ( Some ( _ ) ) , _ ) | ( _ , Ok ( Some ( _ ) ) ) = ( & info_log , & error_log ) {
296
229
let mut max_lines = kernel_ptx. chars ( ) . filter ( |c| * c == '\n' ) . count ( ) + 1 ;
297
230
let mut indent = 0 ;
298
231
while max_lines > 0 {
299
232
max_lines /= 10 ;
300
233
indent += 1 ;
301
234
}
302
235
303
- abort_call_site ! (
304
- "PTX compiler error log:\n {}\n PTX source:\n {}" ,
305
- error_log,
236
+ emit_call_site_warning ! (
237
+ "PTX source code:\n {}" ,
306
238
kernel_ptx
307
239
. lines( )
308
240
. enumerate( )
@@ -311,10 +243,162 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
311
243
. join( "\n " )
312
244
) ;
313
245
}
246
+ match info_log {
247
+ Ok ( None ) => ( ) ,
248
+ Ok ( Some ( info_log) ) => emit_call_site_warning ! ( "{ptx_compiler} info log:\n {}" , info_log) ,
249
+ Err ( err) => {
250
+ let _ = errors. write_fmt ( format_args ! (
251
+ "Error fetching the info log of the {ptx_compiler}: {err}\n "
252
+ ) ) ;
253
+ } ,
254
+ } ;
255
+ match error_log {
256
+ Ok ( None ) => ( ) ,
257
+ Ok ( Some ( error_log) ) => emit_call_site_error ! ( "{ptx_compiler} error log:\n {}" , error_log) ,
258
+ Err ( err) => {
259
+ let _ = errors. write_fmt ( format_args ! (
260
+ "Error fetching the error log of the {ptx_compiler}: {err}\n "
261
+ ) ) ;
262
+ } ,
263
+ } ;
264
+ if let Err ( err) = result {
265
+ let _ = errors. write_fmt ( format_args ! ( "Error compiling the PTX source code: {err}\n " ) ) ;
266
+ }
267
+ if !errors. is_empty ( ) {
268
+ abort_call_site ! ( "{}" , errors) ;
269
+ }
314
270
315
271
( quote ! { const PTX_STR : & ' static str = #kernel_ptx; #( #type_layouts) * } ) . into ( )
316
272
}
317
273
274
+ #[ allow( clippy:: type_complexity) ]
275
+ fn check_kernel_ptx (
276
+ kernel_ptx : & str ,
277
+ specialisation : & str ,
278
+ kernel_hash : & proc_macro2:: Ident ,
279
+ ) -> (
280
+ Result < ( ) , NvptxError > ,
281
+ Result < Option < String > , NvptxError > ,
282
+ Result < Option < String > , NvptxError > ,
283
+ Result < ( u32 , u32 ) , NvptxError > ,
284
+ Result < ( ) , NvptxError > ,
285
+ ) {
286
+ let compiler = {
287
+ let mut compiler = std:: ptr:: null_mut ( ) ;
288
+ if let Err ( err) = NvptxError :: try_err_from ( unsafe {
289
+ ptx_compiler_sys:: nvPTXCompilerCreate (
290
+ addr_of_mut ! ( compiler) ,
291
+ kernel_ptx. len ( ) as ptx_compiler_sys:: size_t ,
292
+ kernel_ptx. as_ptr ( ) . cast ( ) ,
293
+ )
294
+ } ) {
295
+ abort_call_site ! ( "PTX compiler creation failed: {}" , err) ;
296
+ }
297
+ compiler
298
+ } ;
299
+
300
+ let result = {
301
+ let kernel_name = if specialisation. is_empty ( ) {
302
+ format ! ( "{kernel_hash}_kernel" )
303
+ } else {
304
+ format ! (
305
+ "{kernel_hash}_kernel_{:016x}" ,
306
+ seahash:: hash( specialisation. as_bytes( ) )
307
+ )
308
+ } ;
309
+
310
+ let options = vec ! [
311
+ CString :: new( "--entry" ) . unwrap( ) ,
312
+ CString :: new( kernel_name) . unwrap( ) ,
313
+ CString :: new( "--verbose" ) . unwrap( ) ,
314
+ CString :: new( "--warn-on-double-precision-use" ) . unwrap( ) ,
315
+ CString :: new( "--warn-on-local-memory-usage" ) . unwrap( ) ,
316
+ CString :: new( "--warn-on-spills" ) . unwrap( ) ,
317
+ ] ;
318
+ let options_ptrs = options. iter ( ) . map ( |o| o. as_ptr ( ) ) . collect :: < Vec < _ > > ( ) ;
319
+
320
+ NvptxError :: try_err_from ( unsafe {
321
+ ptx_compiler_sys:: nvPTXCompilerCompile (
322
+ compiler,
323
+ options_ptrs. len ( ) as c_int ,
324
+ options_ptrs. as_ptr ( ) . cast ( ) ,
325
+ )
326
+ } )
327
+ } ;
328
+
329
+ let error_log = ( || {
330
+ let mut error_log_size = 0 ;
331
+
332
+ NvptxError :: try_err_from ( unsafe {
333
+ ptx_compiler_sys:: nvPTXCompilerGetErrorLogSize ( compiler, addr_of_mut ! ( error_log_size) )
334
+ } ) ?;
335
+
336
+ if error_log_size == 0 {
337
+ return Ok ( None ) ;
338
+ }
339
+
340
+ #[ allow( clippy:: cast_possible_truncation) ]
341
+ let mut error_log: Vec < u8 > = Vec :: with_capacity ( error_log_size as usize ) ;
342
+
343
+ NvptxError :: try_err_from ( unsafe {
344
+ ptx_compiler_sys:: nvPTXCompilerGetErrorLog ( compiler, error_log. as_mut_ptr ( ) . cast ( ) )
345
+ } ) ?;
346
+
347
+ #[ allow( clippy:: cast_possible_truncation) ]
348
+ unsafe {
349
+ error_log. set_len ( error_log_size as usize ) ;
350
+ }
351
+
352
+ Ok ( Some ( String :: from_utf8_lossy ( & error_log) . into_owned ( ) ) )
353
+ } ) ( ) ;
354
+
355
+ let info_log = ( || {
356
+ let mut info_log_size = 0 ;
357
+
358
+ NvptxError :: try_err_from ( unsafe {
359
+ ptx_compiler_sys:: nvPTXCompilerGetInfoLogSize ( compiler, addr_of_mut ! ( info_log_size) )
360
+ } ) ?;
361
+
362
+ if info_log_size == 0 {
363
+ return Ok ( None ) ;
364
+ }
365
+
366
+ #[ allow( clippy:: cast_possible_truncation) ]
367
+ let mut info_log: Vec < u8 > = Vec :: with_capacity ( info_log_size as usize ) ;
368
+
369
+ NvptxError :: try_err_from ( unsafe {
370
+ ptx_compiler_sys:: nvPTXCompilerGetInfoLog ( compiler, info_log. as_mut_ptr ( ) . cast ( ) )
371
+ } ) ?;
372
+
373
+ #[ allow( clippy:: cast_possible_truncation) ]
374
+ unsafe {
375
+ info_log. set_len ( info_log_size as usize ) ;
376
+ }
377
+
378
+ Ok ( Some ( String :: from_utf8_lossy ( & info_log) . into_owned ( ) ) )
379
+ } ) ( ) ;
380
+
381
+ let version = ( || {
382
+ let mut major = 0 ;
383
+ let mut minor = 0 ;
384
+
385
+ NvptxError :: try_err_from ( unsafe {
386
+ ptx_compiler_sys:: nvPTXCompilerGetVersion ( addr_of_mut ! ( major) , addr_of_mut ! ( minor) )
387
+ } ) ?;
388
+
389
+ Ok ( ( major, minor) )
390
+ } ) ( ) ;
391
+
392
+ let drop = {
393
+ let mut compiler = compiler;
394
+ NvptxError :: try_err_from ( unsafe {
395
+ ptx_compiler_sys:: nvPTXCompilerDestroy ( addr_of_mut ! ( compiler) )
396
+ } )
397
+ } ;
398
+
399
+ ( result, error_log, info_log, version, drop)
400
+ }
401
+
318
402
fn compile_kernel (
319
403
args : & syn:: Ident ,
320
404
crate_name : & str ,
0 commit comments