Skip to content

Commit e5c2953

Browse files
authored
Add debug_printf! and debug_printfln! macros that uses the DebugPrintf extension (#768)
1 parent 28313a2 commit e5c2953

File tree

5 files changed

+469
-0
lines changed

5 files changed

+469
-0
lines changed

crates/spirv-std/macros/src/lib.rs

+263
Original file line numberDiff line numberDiff line change
@@ -350,3 +350,266 @@ fn path_from_ident(ident: Ident) -> syn::Type {
350350
path: syn::Path::from(ident),
351351
})
352352
}
353+
354+
/// Print a formatted string with a newline using the debug printf extension.
355+
///
356+
/// Examples:
357+
///
358+
/// ```rust,ignore
359+
/// debug_printfln!("uv: %v2f", uv);
360+
/// debug_printfln!("pos.x: %f, pos.z: %f, int: %i", pos.x, pos.z, int);
361+
/// ```
362+
///
363+
/// See <https://github.com/KhronosGroup/Vulkan-ValidationLayers/blob/master/docs/debug_printf.md#debug-printf-format-string> for formatting rules.
364+
#[proc_macro]
365+
pub fn debug_printf(input: TokenStream) -> TokenStream {
366+
debug_printf_inner(syn::parse_macro_input!(input as DebugPrintfInput))
367+
}
368+
369+
/// Similar to `debug_printf` but appends a newline to the format string.
370+
#[proc_macro]
371+
pub fn debug_printfln(input: TokenStream) -> TokenStream {
372+
let mut input = syn::parse_macro_input!(input as DebugPrintfInput);
373+
input.format_string.push('\n');
374+
debug_printf_inner(input)
375+
}
376+
377+
struct DebugPrintfInput {
378+
span: proc_macro2::Span,
379+
format_string: String,
380+
variables: Vec<syn::Expr>,
381+
}
382+
383+
impl syn::parse::Parse for DebugPrintfInput {
384+
fn parse(input: syn::parse::ParseStream<'_>) -> syn::parse::Result<Self> {
385+
let span = input.span();
386+
387+
if input.is_empty() {
388+
return Ok(Self {
389+
span,
390+
format_string: Default::default(),
391+
variables: Default::default(),
392+
});
393+
}
394+
395+
let format_string = input.parse::<syn::LitStr>()?;
396+
if !input.is_empty() {
397+
input.parse::<syn::token::Comma>()?;
398+
}
399+
let variables =
400+
syn::punctuated::Punctuated::<syn::Expr, syn::token::Comma>::parse_terminated(input)?;
401+
402+
Ok(Self {
403+
span,
404+
format_string: format_string.value(),
405+
variables: variables.into_iter().collect(),
406+
})
407+
}
408+
}
409+
410+
fn parsing_error(message: &str, span: proc_macro2::Span) -> TokenStream {
411+
syn::Error::new(span, message).to_compile_error().into()
412+
}
413+
414+
enum FormatType {
415+
Scalar {
416+
ty: proc_macro2::TokenStream,
417+
},
418+
Vector {
419+
ty: proc_macro2::TokenStream,
420+
width: usize,
421+
},
422+
}
423+
424+
fn debug_printf_inner(input: DebugPrintfInput) -> TokenStream {
425+
let DebugPrintfInput {
426+
format_string,
427+
variables,
428+
span,
429+
} = input;
430+
431+
fn map_specifier_to_type(
432+
specifier: char,
433+
chars: &mut std::str::Chars<'_>,
434+
) -> Option<proc_macro2::TokenStream> {
435+
let mut peekable = chars.peekable();
436+
437+
Some(match specifier {
438+
'd' | 'i' => quote::quote! { i32 },
439+
'o' | 'x' | 'X' => quote::quote! { u32 },
440+
'a' | 'A' | 'e' | 'E' | 'f' | 'F' | 'g' | 'G' => quote::quote! { f32 },
441+
'u' => {
442+
if matches!(peekable.peek(), Some('l')) {
443+
chars.next();
444+
quote::quote! { u64 }
445+
} else {
446+
quote::quote! { u32 }
447+
}
448+
}
449+
'l' => {
450+
if matches!(peekable.peek(), Some('u' | 'x')) {
451+
chars.next();
452+
quote::quote! { u64 }
453+
} else {
454+
return None;
455+
}
456+
}
457+
_ => return None,
458+
})
459+
}
460+
461+
let mut chars = format_string.chars();
462+
let mut format_arguments = Vec::new();
463+
464+
while let Some(mut ch) = chars.next() {
465+
if ch == '%' {
466+
ch = match chars.next() {
467+
Some('%') => continue,
468+
None => return parsing_error("Unterminated format specifier", span),
469+
Some(ch) => ch,
470+
};
471+
472+
let mut has_precision = false;
473+
474+
while matches!(ch, '0'..='9') {
475+
ch = match chars.next() {
476+
Some(ch) => ch,
477+
None => {
478+
return parsing_error(
479+
"Unterminated format specifier: missing type after precision",
480+
span,
481+
)
482+
}
483+
};
484+
485+
has_precision = true;
486+
}
487+
488+
if has_precision && ch == '.' {
489+
ch = match chars.next() {
490+
Some(ch) => ch,
491+
None => {
492+
return parsing_error(
493+
"Unterminated format specifier: missing type after decimal point",
494+
span,
495+
)
496+
}
497+
};
498+
499+
while matches!(ch, '0'..='9') {
500+
ch = match chars.next() {
501+
Some(ch) => ch,
502+
None => return parsing_error(
503+
"Unterminated format specifier: missing type after fraction precision",
504+
span,
505+
),
506+
};
507+
}
508+
}
509+
510+
if ch == 'v' {
511+
let width = match chars.next() {
512+
Some('2') => 2,
513+
Some('3') => 3,
514+
Some('4') => 4,
515+
Some(ch) => {
516+
return parsing_error(&format!("Invalid width for vector: {}", ch), span)
517+
}
518+
None => return parsing_error("Missing vector dimensions specifier", span),
519+
};
520+
521+
ch = match chars.next() {
522+
Some(ch) => ch,
523+
None => return parsing_error("Missing vector type specifier", span),
524+
};
525+
526+
let ty = match map_specifier_to_type(ch, &mut chars) {
527+
Some(ty) => ty,
528+
_ => {
529+
return parsing_error(
530+
&format!("Unrecognised vector type specifier: '{}'", ch),
531+
span,
532+
)
533+
}
534+
};
535+
536+
format_arguments.push(FormatType::Vector { ty, width });
537+
} else {
538+
let ty = match map_specifier_to_type(ch, &mut chars) {
539+
Some(ty) => ty,
540+
_ => {
541+
return parsing_error(
542+
&format!("Unrecognised format specifier: '{}'", ch),
543+
span,
544+
)
545+
}
546+
};
547+
548+
format_arguments.push(FormatType::Scalar { ty });
549+
}
550+
}
551+
}
552+
553+
if format_arguments.len() != variables.len() {
554+
return syn::Error::new(
555+
span,
556+
&format!(
557+
"{} % arguments were found, but {} variables were given",
558+
format_arguments.len(),
559+
variables.len()
560+
),
561+
)
562+
.to_compile_error()
563+
.into();
564+
}
565+
566+
let mut variable_idents = String::new();
567+
let mut input_registers = Vec::new();
568+
let mut op_loads = Vec::new();
569+
570+
for (i, (variable, format_argument)) in variables.into_iter().zip(format_arguments).enumerate()
571+
{
572+
let ident = quote::format_ident!("_{}", i);
573+
574+
variable_idents.push_str(&format!("%{} ", ident));
575+
576+
let assert_fn = match format_argument {
577+
FormatType::Scalar { ty } => {
578+
quote::quote! { spirv_std::debug_printf_assert_is_type::<#ty> }
579+
}
580+
FormatType::Vector { ty, width } => {
581+
quote::quote! { spirv_std::debug_printf_assert_is_vector::<#ty, _, #width> }
582+
}
583+
};
584+
585+
input_registers.push(quote::quote! {
586+
#ident = in(reg) &#assert_fn(#variable),
587+
});
588+
589+
let op_load = format!("%{ident} = OpLoad _ {{{ident}}}", ident = ident);
590+
591+
op_loads.push(quote::quote! {
592+
#op_load,
593+
});
594+
}
595+
596+
let input_registers = input_registers
597+
.into_iter()
598+
.collect::<proc_macro2::TokenStream>();
599+
let op_loads = op_loads.into_iter().collect::<proc_macro2::TokenStream>();
600+
601+
let op_string = format!("%string = OpString {:?}", format_string);
602+
603+
let output = quote::quote! {
604+
asm!(
605+
"%void = OpTypeVoid",
606+
#op_string,
607+
"%debug_printf = OpExtInstImport \"NonSemantic.DebugPrintf\"",
608+
#op_loads
609+
concat!("%result = OpExtInst %void %debug_printf 1 %string ", #variable_idents),
610+
#input_registers
611+
)
612+
};
613+
614+
output.into()
615+
}

crates/spirv-std/src/lib.rs

+16
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,19 @@ extern "C" fn rust_eh_personality() {}
130130
#[doc(hidden)]
131131
/// [spirv_types]
132132
pub fn workaround_rustdoc_ice_84738() {}
133+
134+
#[doc(hidden)]
135+
pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
136+
ty
137+
}
138+
139+
#[doc(hidden)]
140+
pub fn debug_printf_assert_is_vector<
141+
TY: crate::scalar::Scalar,
142+
V: crate::vector::Vector<TY, SIZE>,
143+
const SIZE: usize,
144+
>(
145+
vec: V,
146+
) -> V {
147+
vec
148+
}

tests/ui/arch/debug_printf.rs

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// build-pass
2+
// compile-flags: -Ctarget-feature=+ext:SPV_KHR_non_semantic_info
3+
4+
use spirv_std::{
5+
glam::{IVec2, UVec2, Vec2, Vec3, Vec4},
6+
macros::{debug_printf, debug_printfln},
7+
};
8+
9+
fn func(a: f32, b: f32) -> f32 {
10+
a * b + 1.0
11+
}
12+
13+
struct Struct {
14+
a: f32,
15+
}
16+
17+
impl Struct {
18+
fn method(&self, b: f32, c: f32) -> f32 {
19+
self.a * b + c
20+
}
21+
}
22+
23+
#[spirv(fragment)]
24+
pub fn main() {
25+
unsafe {
26+
debug_printf!();
27+
debug_printfln!();
28+
debug_printfln!("Hello World");
29+
debug_printfln!("Hello World",);
30+
debug_printfln!(r#"Hello "World""#);
31+
debug_printfln!(
32+
r#"Hello "World"
33+
"#
34+
);
35+
debug_printfln!("Hello \"World\"\n\n");
36+
debug_printfln!("%%r %%f %%%%f %%%%%u", 77);
37+
}
38+
39+
let vec = Vec2::new(1.52, 25.1);
40+
41+
unsafe {
42+
debug_printfln!("%v2f", vec);
43+
debug_printfln!("%1v2f", { vec * 2.0 });
44+
debug_printfln!("%1.2v2f", vec * 3.0);
45+
debug_printfln!("%% %v2f %%", vec * 4.0);
46+
debug_printfln!("%u %i %f 🐉", 11_u32, -11_i32, 11.0_f32);
47+
debug_printfln!("%f", func(33.0, 44.0));
48+
debug_printfln!("%f", Struct { a: 33.0 }.method(44.0, 55.0));
49+
debug_printfln!("%v3f %v4f", Vec3::new(1.0, 1.0, 1.0), Vec4::splat(5.0));
50+
debug_printfln!("%v2u %v2i", UVec2::new(1, 1), IVec2::splat(-5));
51+
}
52+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// build-fail
2+
// normalize-stderr-test "\S*/crates/spirv-std/src/" -> "$$SPIRV_STD_SRC/"
3+
// compile-flags: -Ctarget-feature=+ext:SPV_KHR_non_semantic_info
4+
5+
use spirv_std::{glam::Vec2, macros::debug_printf};
6+
7+
#[spirv(fragment)]
8+
pub fn main() {
9+
unsafe {
10+
debug_printf!("%1");
11+
debug_printf!("%1.");
12+
debug_printf!("%.");
13+
debug_printf!("%.1");
14+
debug_printf!("%1.1");
15+
debug_printf!("%1.1v");
16+
debug_printf!("%1.1v5");
17+
debug_printf!("%1.1v2");
18+
debug_printf!("%1.1v2r");
19+
debug_printf!("%r", 11_i32);
20+
debug_printf!("%f", 11_u32);
21+
debug_printf!("%u", 11.0_f32);
22+
debug_printf!("%v2f", 11.0);
23+
debug_printf!("%f", Vec2::splat(33.3));
24+
}
25+
}

0 commit comments

Comments
 (0)