Skip to content

Commit

Permalink
Compile an amplification shader
Browse files Browse the repository at this point in the history
  • Loading branch information
RobDangerous committed Aug 27, 2024
1 parent 332513b commit f190f2b
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 14 deletions.
2 changes: 2 additions & 0 deletions Sources/backends/d3d12.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ static const wchar_t *shader_string(shader_stage stage) {
return L"cs_6_0";
case SHADER_STAGE_RAY_GENERATION:
return L"lib_6_3";
case SHADER_STAGE_AMPLIFICATION:
return L"as_6_5";
case SHADER_STAGE_MESH:
return L"ms_6_5";
default: {
Expand Down
96 changes: 90 additions & 6 deletions Sources/backends/hlsl.c
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,36 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func
sprintf(&hlsl[*offset], ", %s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]);
}
}
*offset += sprintf(&hlsl[*offset], ") {\n");
if (f->parameters_size > 0) {
*offset += sprintf(&hlsl[*offset], ", ");
}
*offset += sprintf(&hlsl[*offset], "in uint3 _kong_group_id : SV_GroupID, in uint3 _kong_group_thread_id : SV_GroupThreadID, in uint3 "
"_kong_dispatch_thread_id : SV_DispatchThreadID, in uint _kong_group_index : SV_GroupIndex) {\n");
}
else if (stage == SHADER_STAGE_AMPLIFICATION) {
attribute *threads_attribute = find_attribute(&f->attributes, add_name("threads"));
if (threads_attribute == NULL || threads_attribute->paramters_count != 3) {
debug_context context = {0};
error(context, "Compute function requires a threads attribute with three parameters");
}

*offset += sprintf(&hlsl[*offset], "[numthreads(%i, %i, %i)] %s main(", (int)threads_attribute->parameters[0],
(int)threads_attribute->parameters[1], (int)threads_attribute->parameters[2], type_string(f->return_type.type));
for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) {
if (parameter_index == 0) {
*offset +=
sprintf(&hlsl[*offset], "%s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]);
}
else {
*offset +=
sprintf(&hlsl[*offset], ", %s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]);
}
}
if (f->parameters_size > 0) {
*offset += sprintf(&hlsl[*offset], ", ");
}
*offset += sprintf(&hlsl[*offset], "in uint3 _kong_group_id : SV_GroupID, in uint3 _kong_group_thread_id : SV_GroupThreadID, in uint3 "
"_kong_dispatch_thread_id : SV_DispatchThreadID, in uint _kong_group_index : SV_GroupIndex) {\n");
}
else if (stage == SHADER_STAGE_MESH) {
attribute *topology_attribute = find_attribute(&f->attributes, add_name("topology"));
Expand All @@ -399,7 +428,11 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func
sprintf(&hlsl[*offset], ", %s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]);
}
}
*offset += sprintf(&hlsl[*offset], ") {\n");
if (f->parameters_size > 0) {
*offset += sprintf(&hlsl[*offset], ", ");
}
*offset += sprintf(&hlsl[*offset], "in uint3 _kong_group_id : SV_GroupID, in uint3 _kong_group_thread_id : SV_GroupThreadID, in uint3 "
"_kong_dispatch_thread_id : SV_DispatchThreadID, in uint _kong_group_index : SV_GroupIndex) {\n");
}
else {
debug_context context = {0};
Expand Down Expand Up @@ -563,19 +596,21 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func
}
else if (o->op_call.func == add_name("group_id")) {
check(o->op_call.parameters_size == 0, context, "group_id can not have a parameter");
*offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = SV_GroupID;\n", type_string(o->op_call.var.type.type), o->op_call.var.index);
*offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = _kong_group_id;\n", type_string(o->op_call.var.type.type), o->op_call.var.index);
}
else if (o->op_call.func == add_name("group_thread_id")) {
check(o->op_call.parameters_size == 0, context, "group_thread_id can not have a parameter");
*offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = SV_GroupThreadID;\n", type_string(o->op_call.var.type.type), o->op_call.var.index);
*offset +=
sprintf(&hlsl[*offset], "%s _%" PRIu64 " = _kong_group_thread_id;\n", type_string(o->op_call.var.type.type), o->op_call.var.index);
}
else if (o->op_call.func == add_name("dispatch_thread_id")) {
check(o->op_call.parameters_size == 0, context, "dispatch_thread_id can not have a parameter");
*offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = SV_DispatchThreadID;\n", type_string(o->op_call.var.type.type), o->op_call.var.index);
*offset +=
sprintf(&hlsl[*offset], "%s _%" PRIu64 " = _kong_dispatch_thread_id;\n", type_string(o->op_call.var.type.type), o->op_call.var.index);
}
else if (o->op_call.func == add_name("group_index")) {
check(o->op_call.parameters_size == 0, context, "group_index can not have a parameter");
*offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = SV_GroupIndex;\n", type_string(o->op_call.var.type.type), o->op_call.var.index);
*offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = _kong_group_index;\n", type_string(o->op_call.var.type.type), o->op_call.var.index);
}
else if (o->op_call.func == add_name("world_ray_direction")) {
check(o->op_call.parameters_size == 0, context, "group_index can not have a parameter");
Expand All @@ -595,6 +630,12 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func
*offset += sprintf(&hlsl[*offset], "TraceRay(_%" PRIu64 ", RAY_FLAG_NONE, 0xFF, 0, 0, 0, _%" PRIu64 ", _%" PRIu64 ");\n",
o->op_call.parameters[0].index, o->op_call.parameters[1].index, o->op_call.parameters[2].index);
}
else if (o->op_call.func == add_name("dispatch_mesh")) {
check(o->op_call.parameters_size == 4, context, "dispatch_mesh requires four parameters");
*offset +=
sprintf(&hlsl[*offset], "DispatchMesh(_%" PRIu64 ", _%" PRIu64 ", _%" PRIu64 ", _%" PRIu64 ");\n", o->op_call.parameters[0].index,
o->op_call.parameters[1].index, o->op_call.parameters[2].index, o->op_call.parameters[3].index);
}
else {
if (o->op_call.var.type.type == void_id) {
*offset += sprintf(&hlsl[*offset], "%s(", function_string(o->op_call.func));
Expand Down Expand Up @@ -672,6 +713,34 @@ static void hlsl_export_vertex(char *directory, api_kind d3d, function *main) {
write_bytecode(hlsl, directory, filename, var_name, output, output_size);
}

static void hlsl_export_amplification(char *directory, function *main) {
char *hlsl = (char *)calloc(1024 * 1024, 1);
size_t offset = 0;

write_types(hlsl, &offset, SHADER_STAGE_AMPLIFICATION, NO_TYPE, NO_TYPE, main, NULL, 0);

write_globals(hlsl, &offset, main, NULL, 0);

write_functions(hlsl, &offset, SHADER_STAGE_AMPLIFICATION, main, NULL, 0);

char *output = NULL;
size_t output_size = 0;
int result = compile_hlsl_to_d3d12(hlsl, &output, &output_size, SHADER_STAGE_AMPLIFICATION, false);

debug_context context = {0};
check(result == 0, context, "HLSL compilation failed");

char *name = get_name(main->name);

char filename[512];
sprintf(filename, "kong_%s", name);

char var_name[256];
sprintf(var_name, "%s_code", name);

write_bytecode(hlsl, directory, filename, var_name, output, output_size);
}

static void hlsl_export_mesh(char *directory, function *main) {
char *hlsl = (char *)calloc(1024 * 1024, 1);
size_t offset = 0;
Expand Down Expand Up @@ -868,6 +937,9 @@ void hlsl_export(char *directory, api_kind d3d) {
function *vertex_shaders[256];
size_t vertex_shaders_size = 0;

function *amplification_shaders[256];
size_t amplification_shaders_size = 0;

function *mesh_shaders[256];
size_t mesh_shaders_size = 0;

Expand All @@ -878,13 +950,17 @@ void hlsl_export(char *directory, api_kind d3d) {
type *t = get_type(i);
if (!t->built_in && has_attribute(&t->attributes, add_name("pipe"))) {
name_id vertex_shader_name = NO_NAME;
name_id amplification_shader_name = NO_NAME;
name_id mesh_shader_name = NO_NAME;
name_id fragment_shader_name = NO_NAME;

for (size_t j = 0; j < t->members.size; ++j) {
if (t->members.m[j].name == add_name("vertex")) {
vertex_shader_name = t->members.m[j].value.identifier;
}
else if (t->members.m[j].name == add_name("amplification")) {
amplification_shader_name = t->members.m[j].value.identifier;
}
else if (t->members.m[j].name == add_name("mesh")) {
mesh_shader_name = t->members.m[j].value.identifier;
}
Expand All @@ -903,6 +979,10 @@ void hlsl_export(char *directory, api_kind d3d) {
vertex_shaders[vertex_shaders_size] = f;
vertex_shaders_size += 1;
}
if (amplification_shader_name != NO_NAME && f->name == amplification_shader_name) {
amplification_shaders[amplification_shaders_size] = f;
amplification_shaders_size += 1;
}
if (mesh_shader_name != NO_NAME && f->name == mesh_shader_name) {
mesh_shaders[mesh_shaders_size] = f;
mesh_shaders_size += 1;
Expand Down Expand Up @@ -973,6 +1053,10 @@ void hlsl_export(char *directory, api_kind d3d) {
}

if (d3d == API_DIRECT3D12) {
for (size_t i = 0; i < amplification_shaders_size; ++i) {
hlsl_export_amplification(directory, amplification_shaders[i]);
}

for (size_t i = 0; i < mesh_shaders_size; ++i) {
hlsl_export_mesh(directory, mesh_shaders[i]);
}
Expand Down
44 changes: 37 additions & 7 deletions Sources/functions.c
Original file line number Diff line number Diff line change
Expand Up @@ -152,31 +152,61 @@ void functions_init(void) {
{
function_id func = add_function(add_name("trace_ray"));
function *f = get_function(func);

init_type_ref(&f->return_type, add_name("void"));
f->return_type.type = find_type_by_ref(&f->return_type);

f->parameter_names[0] = add_name("scene");
init_type_ref(&f->parameter_types[0], add_name("bvh"));
f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]);
f->parameters_size = 1;
f->parameters_size += 1;

f->parameter_names[1] = add_name("ray");
init_type_ref(&f->parameter_types[1], add_name("ray"));
f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]);
f->parameters_size = 1;
f->parameters_size += 1;

f->parameter_names[2] = add_name("payload");
init_type_ref(&f->parameter_types[2], add_name("void"));
f->parameter_types[2].type = find_type_by_ref(&f->parameter_types[2]);
f->parameters_size = 1;
f->parameters_size += 1;

f->block = NULL;
}

{
function_id func = add_function(add_name("dispatch_mesh"));
function *f = get_function(func);

init_type_ref(&f->return_type, add_name("void"));
f->return_type.type = find_type_by_ref(&f->return_type);

f->parameter_names[0] = add_name("x");
init_type_ref(&f->parameter_types[0], add_name("uint"));
f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]);
f->parameters_size += 1;

f->parameter_names[1] = add_name("y");
init_type_ref(&f->parameter_types[1], add_name("uint"));
f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]);
f->parameters_size += 1;

f->parameter_names[2] = add_name("z");
init_type_ref(&f->parameter_types[2], add_name("uint"));
f->parameter_types[2].type = find_type_by_ref(&f->parameter_types[2]);
f->parameters_size += 1;

f->parameter_names[3] = add_name("payload");
init_type_ref(&f->parameter_types[3], add_name("void"));
f->parameter_types[3].type = find_type_by_ref(&f->parameter_types[3]);
f->parameters_size += 1;

f->block = NULL;
}

add_func_int("group_id");
add_func_int("group_thread_id");
add_func_int("dispatch_thread_id");
add_func_uint3("group_id");
add_func_uint3("group_thread_id");
add_func_uint3("dispatch_thread_id");
add_func_int("group_index");

add_func_float3_float_float_float("lerp");
Expand Down
6 changes: 5 additions & 1 deletion Sources/integrations/kinc.c
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ void kinc_export(char *directory, api_kind api) {
fprintf(output, "\tkinc_g4_pipeline_init(&%s);\n\n", get_name(t->name));

name_id vertex_shader_name = NO_NAME;
name_id amplification_shader_name = NO_NAME;
name_id mesh_shader_name = NO_NAME;
name_id fragment_shader_name = NO_NAME;

Expand All @@ -480,7 +481,10 @@ void kinc_export(char *directory, api_kind api) {
fprintf(output, "\t%s.vertex_shader = &%s;\n\n", get_name(t->name), get_name(t->members.m[j].value.identifier));
vertex_shader_name = t->members.m[j].value.identifier;
}
if (t->members.m[j].name == add_name("mesh")) {
else if (t->members.m[j].name == add_name("amplification")) {
amplification_shader_name = t->members.m[j].value.identifier;
}
else if (t->members.m[j].name == add_name("mesh")) {
mesh_shader_name = t->members.m[j].value.identifier;
}
else if (t->members.m[j].name == add_name("fragment")) {
Expand Down
1 change: 1 addition & 0 deletions Sources/shader_stage.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

typedef enum shader_stage {
SHADER_STAGE_VERTEX,
SHADER_STAGE_AMPLIFICATION,
SHADER_STAGE_MESH,
SHADER_STAGE_FRAGMENT,
SHADER_STAGE_COMPUTE,
Expand Down

0 comments on commit f190f2b

Please sign in to comment.