diff --git a/Sources/backends/hlsl.c b/Sources/backends/hlsl.c index 1056718..864346d 100644 --- a/Sources/backends/hlsl.c +++ b/Sources/backends/hlsl.c @@ -145,6 +145,17 @@ static void write_types(char *hlsl, size_t *offset, shader_stage stage, type_id } } } + else if (stage == SHADER_STAGE_MESH && types[i] == output) { + for (size_t j = 0; j < t->members.size; ++j) { + if (j == 0) { + *offset += sprintf(&hlsl[*offset], "\t%s %s : SV_POSITION;\n", type_string(t->members.m[j].type.type), get_name(t->members.m[j].name)); + } + else { + *offset += + sprintf(&hlsl[*offset], "\t%s %s : TEXCOORD%zu;\n", type_string(t->members.m[j].type.type), get_name(t->members.m[j].name), j - 1); + } + } + } else if (stage == SHADER_STAGE_FRAGMENT && types[i] == input) { for (size_t j = 0; j < t->members.size; ++j) { if (j == 0) { @@ -416,6 +427,21 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func error(context, "Mesh function requires a threads attribute with three parameters"); } + attribute *tris_attribute = find_attribute(&f->attributes, add_name("tris")); + if (tris_attribute == NULL || tris_attribute->paramters_count != 1) { + debug_context context = {0}; + error(context, "Mesh function requires a tris attribute with one parameter"); + } + + attribute *vertices_attribute = find_attribute(&f->attributes, add_name("vertices")); + if (vertices_attribute == NULL || vertices_attribute->paramters_count != 2) { + debug_context context = {0}; + error(context, "Mesh function requires a vertices attribute with two parameters"); + } + + type_id vertex_type = (type_id)vertices_attribute->parameters[1]; + char *vertex_name = get_name(get_type(vertex_type)->name); + *offset += sprintf(&hlsl[*offset], "[outputtopology(\"triangle\")][numthreads(%i, %i, %i)] %s main(", (int)threads_attribute->parameters[0], (int)threads_attribute->parameters[1], (int)threads_attribute->parameters[2], type_string(f->return_type.type)); for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) { @@ -431,8 +457,12 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func if (f->parameters_size > 0) { *offset += sprintf(&hlsl[*offset], ", "); } - *offset += sprintf(&hlsl[*offset], "in uint3 _kong_group_id : SV_GroupID, in uint3 _kong_group_thread_id : SV_GroupThreadID, in uint3 " - "_kong_dispatch_thread_id : SV_DispatchThreadID, in uint _kong_group_index : SV_GroupIndex) {\n"); + *offset += + sprintf(&hlsl[*offset], + "out indices uint3 _kong_mesh_tris[%i], out vertices %s _kong_mesh_vertices[%i], in uint3 _kong_group_id : SV_GroupID, in uint3 " + "_kong_group_thread_id : SV_GroupThreadID, in uint3 " + "_kong_dispatch_thread_id : SV_DispatchThreadID, in uint _kong_group_index : SV_GroupIndex) {\n", + (int)tris_attribute->parameters[0], vertex_name, (int)vertices_attribute->parameters[0]); } else { debug_context context = {0}; @@ -636,6 +666,21 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func sprintf(&hlsl[*offset], "DispatchMesh(_%" PRIu64 ", _%" PRIu64 ", _%" PRIu64 ", _%" PRIu64 ");\n", o->op_call.parameters[0].index, o->op_call.parameters[1].index, o->op_call.parameters[2].index, o->op_call.parameters[3].index); } + else if (o->op_call.func == add_name("set_mesh_output_counts")) { + check(o->op_call.parameters_size == 2, context, "set_mesh_output_counts requires two parameters"); + *offset += sprintf(&hlsl[*offset], "SetMeshOutputCounts(_%" PRIu64 ", _%" PRIu64 ");\n", o->op_call.parameters[0].index, + o->op_call.parameters[1].index); + } + else if (o->op_call.func == add_name("set_mesh_triangle")) { + check(o->op_call.parameters_size == 2, context, "set_mesh_triangle requires two parameters"); + *offset += sprintf(&hlsl[*offset], "_kong_mesh_tris[_%" PRIu64 "] = _%" PRIu64 ";\n", o->op_call.parameters[0].index, + o->op_call.parameters[1].index); + } + else if (o->op_call.func == add_name("set_mesh_vertex")) { + check(o->op_call.parameters_size == 2, context, "set_mesh_vertex requires two parameters"); + *offset += sprintf(&hlsl[*offset], "_kong_mesh_vertices[_%" PRIu64 "] = _%" PRIu64 ";\n", o->op_call.parameters[0].index, + o->op_call.parameters[1].index); + } else { if (o->op_call.var.type.type == void_id) { *offset += sprintf(&hlsl[*offset], "%s(", function_string(o->op_call.func)); @@ -745,7 +790,14 @@ static void hlsl_export_mesh(char *directory, function *main) { char *hlsl = (char *)calloc(1024 * 1024, 1); size_t offset = 0; - write_types(hlsl, &offset, SHADER_STAGE_MESH, NO_TYPE, NO_TYPE, main, NULL, 0); + attribute *vertices_attribute = find_attribute(&main->attributes, add_name("vertices")); + if (vertices_attribute == NULL || vertices_attribute->paramters_count != 2) { + debug_context context = {0}; + error(context, "Mesh function requires a vertices attribute with two parameters"); + } + type_id vertex_output = (type_id)vertices_attribute->parameters[1]; + + write_types(hlsl, &offset, SHADER_STAGE_MESH, NO_TYPE, vertex_output, main, NULL, 0); write_globals(hlsl, &offset, main, NULL, 0); diff --git a/Sources/functions.c b/Sources/functions.c index 794bc80..ef75a55 100644 --- a/Sources/functions.c +++ b/Sources/functions.c @@ -82,6 +82,24 @@ static void add_func_float3_float3(char *name) { f->block = NULL; } +static void add_func_void_uint_uint(char *name) { + function_id func = add_function(add_name(name)); + function *f = get_function(func); + + init_type_ref(&f->return_type, add_name("void")); + f->return_type.type = find_type_by_ref(&f->return_type); + + f->parameter_names[0] = add_name("a"); + f->parameter_names[1] = add_name("b"); + for (int i = 0; i < 2; ++i) { + init_type_ref(&f->parameter_types[0], add_name("uint")); + f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]); + } + f->parameters_size = 2; + + f->block = NULL; +} + void functions_init(void) { function *new_functions = realloc(functions, functions_size * sizeof(function)); debug_context context = {0}; @@ -121,7 +139,13 @@ void functions_init(void) { f->parameter_names[0] = add_name("x"); init_type_ref(&f->parameter_types[0], add_name("float")); f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]); - f->parameters_size = 1; + + f->parameter_names[1] = add_name("y"); + init_type_ref(&f->parameter_types[1], add_name("float")); + f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]); + + f->parameters_size = 2; + f->block = NULL; } @@ -130,10 +154,21 @@ void functions_init(void) { function *f = get_function(float3_constructor_id); init_type_ref(&f->return_type, add_name("float3")); f->return_type.type = find_type_by_ref(&f->return_type); + f->parameter_names[0] = add_name("x"); init_type_ref(&f->parameter_types[0], add_name("float")); f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]); - f->parameters_size = 1; + + f->parameter_names[1] = add_name("y"); + init_type_ref(&f->parameter_types[1], add_name("float")); + f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]); + + f->parameter_names[2] = add_name("z"); + init_type_ref(&f->parameter_types[2], add_name("float")); + f->parameter_types[2].type = find_type_by_ref(&f->parameter_types[2]); + + f->parameters_size = 3; + f->block = NULL; } @@ -142,10 +177,163 @@ void functions_init(void) { function *f = get_function(float4_constructor_id); init_type_ref(&f->return_type, add_name("float4")); f->return_type.type = find_type_by_ref(&f->return_type); + f->parameter_names[0] = add_name("x"); init_type_ref(&f->parameter_types[0], add_name("float")); f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]); - f->parameters_size = 1; + + f->parameter_names[1] = add_name("y"); + init_type_ref(&f->parameter_types[1], add_name("float")); + f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]); + + f->parameter_names[2] = add_name("z"); + init_type_ref(&f->parameter_types[2], add_name("float")); + f->parameter_types[2].type = find_type_by_ref(&f->parameter_types[2]); + + f->parameter_names[3] = add_name("w"); + init_type_ref(&f->parameter_types[3], add_name("float")); + f->parameter_types[3].type = find_type_by_ref(&f->parameter_types[3]); + + f->parameters_size = 4; + + f->block = NULL; + } + + { + float2_constructor_id = add_function(add_name("int2")); + function *f = get_function(float2_constructor_id); + init_type_ref(&f->return_type, add_name("int2")); + f->return_type.type = find_type_by_ref(&f->return_type); + + f->parameter_names[0] = add_name("x"); + init_type_ref(&f->parameter_types[0], add_name("int")); + f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]); + + f->parameter_names[1] = add_name("y"); + init_type_ref(&f->parameter_types[1], add_name("int")); + f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]); + + f->parameters_size = 2; + + f->block = NULL; + } + + { + float3_constructor_id = add_function(add_name("int3")); + function *f = get_function(float3_constructor_id); + init_type_ref(&f->return_type, add_name("int3")); + f->return_type.type = find_type_by_ref(&f->return_type); + + f->parameter_names[0] = add_name("x"); + init_type_ref(&f->parameter_types[0], add_name("int")); + f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]); + + f->parameter_names[1] = add_name("y"); + init_type_ref(&f->parameter_types[1], add_name("int")); + f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]); + + f->parameter_names[2] = add_name("z"); + init_type_ref(&f->parameter_types[2], add_name("int")); + f->parameter_types[2].type = find_type_by_ref(&f->parameter_types[2]); + + f->parameters_size = 3; + + f->block = NULL; + } + + { + float4_constructor_id = add_function(add_name("int4")); + function *f = get_function(float4_constructor_id); + init_type_ref(&f->return_type, add_name("int4")); + f->return_type.type = find_type_by_ref(&f->return_type); + + f->parameter_names[0] = add_name("x"); + init_type_ref(&f->parameter_types[0], add_name("int")); + f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]); + + f->parameter_names[1] = add_name("y"); + init_type_ref(&f->parameter_types[1], add_name("int")); + f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]); + + f->parameter_names[2] = add_name("z"); + init_type_ref(&f->parameter_types[2], add_name("int")); + f->parameter_types[2].type = find_type_by_ref(&f->parameter_types[2]); + + f->parameter_names[3] = add_name("w"); + init_type_ref(&f->parameter_types[3], add_name("int")); + f->parameter_types[3].type = find_type_by_ref(&f->parameter_types[3]); + + f->parameters_size = 4; + + f->block = NULL; + } + + { + float2_constructor_id = add_function(add_name("uint2")); + function *f = get_function(float2_constructor_id); + init_type_ref(&f->return_type, add_name("uint2")); + f->return_type.type = find_type_by_ref(&f->return_type); + + f->parameter_names[0] = add_name("x"); + init_type_ref(&f->parameter_types[0], add_name("uint")); + f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]); + + f->parameter_names[1] = add_name("y"); + init_type_ref(&f->parameter_types[1], add_name("uint")); + f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]); + + f->parameters_size = 2; + + f->block = NULL; + } + + { + float3_constructor_id = add_function(add_name("uint3")); + function *f = get_function(float3_constructor_id); + init_type_ref(&f->return_type, add_name("uint3")); + f->return_type.type = find_type_by_ref(&f->return_type); + + f->parameter_names[0] = add_name("x"); + init_type_ref(&f->parameter_types[0], add_name("uint")); + f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]); + + f->parameter_names[1] = add_name("y"); + init_type_ref(&f->parameter_types[1], add_name("uint")); + f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]); + + f->parameter_names[2] = add_name("z"); + init_type_ref(&f->parameter_types[2], add_name("uint")); + f->parameter_types[2].type = find_type_by_ref(&f->parameter_types[2]); + + f->parameters_size = 3; + + f->block = NULL; + } + + { + float4_constructor_id = add_function(add_name("uint4")); + function *f = get_function(float4_constructor_id); + init_type_ref(&f->return_type, add_name("uint4")); + f->return_type.type = find_type_by_ref(&f->return_type); + + f->parameter_names[0] = add_name("x"); + init_type_ref(&f->parameter_types[0], add_name("uint")); + f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]); + + f->parameter_names[1] = add_name("y"); + init_type_ref(&f->parameter_types[1], add_name("uint")); + f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]); + + f->parameter_names[2] = add_name("z"); + init_type_ref(&f->parameter_types[2], add_name("uint")); + f->parameter_types[2].type = find_type_by_ref(&f->parameter_types[2]); + + f->parameter_names[3] = add_name("w"); + init_type_ref(&f->parameter_types[3], add_name("uint")); + f->parameter_types[3].type = find_type_by_ref(&f->parameter_types[3]); + + f->parameters_size = 4; + f->block = NULL; } @@ -215,6 +403,48 @@ void functions_init(void) { add_func_float_float("saturate"); add_func_uint3("ray_index"); add_func_float3("ray_dimensions"); + + add_func_void_uint_uint("set_mesh_output_counts"); + + { + function_id func = add_function(add_name("set_mesh_triangle")); + function *f = get_function(func); + init_type_ref(&f->return_type, add_name("void")); + f->return_type.type = find_type_by_ref(&f->return_type); + f->return_type.array_size = 1; + + f->parameter_names[0] = add_name("x"); + init_type_ref(&f->parameter_types[0], add_name("uint")); + f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]); + + f->parameter_names[1] = add_name("y"); + init_type_ref(&f->parameter_types[1], add_name("uint3")); + f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]); + + f->parameters_size = 2; + + f->block = NULL; + } + + { + function_id func = add_function(add_name("set_mesh_vertex")); + function *f = get_function(func); + init_type_ref(&f->return_type, add_name("void")); + f->return_type.type = find_type_by_ref(&f->return_type); + f->return_type.array_size = 1; + + f->parameter_names[0] = add_name("x"); + init_type_ref(&f->parameter_types[0], add_name("uint")); + f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]); + + f->parameter_names[1] = add_name("y"); + init_type_ref(&f->parameter_types[1], add_name("void")); + f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]); + + f->parameters_size = 2; + + f->block = NULL; + } } static void grow_if_needed(uint64_t size) { diff --git a/Sources/parser.c b/Sources/parser.c index 64f72ae..1919692 100644 --- a/Sources/parser.c +++ b/Sources/parser.c @@ -153,6 +153,12 @@ static double attribute_parameter_to_number(name_id attribute_name, name_id para if (attribute_name == add_name("topology") && parameter_name == add_name("triangle")) { return 0; } + + type_id type = find_type_by_name(parameter_name); + if (type != NO_TYPE) { + return (double)type; + } + debug_context context = {0}; error(context, "Unknown attribute parameter %s", get_name(parameter_name)); return 0; @@ -875,7 +881,7 @@ static expression *parse_call(state_t *state, name_id func_name) { advance_state(state); - bool dynamic = square && current(state).kind == TOKEN_NUMBER; + bool dynamic = square && current(state).kind != TOKEN_NUMBER; expression *right = parse_member(state, square); diff --git a/tests/in/test.kong b/tests/in/test.kong index 9802638..5e2bb00 100644 --- a/tests/in/test.kong +++ b/tests/in/test.kong @@ -101,19 +101,28 @@ struct RayPipe { closest = closesthit; }*/ -struct VertexIn { - position: float3; -} - struct FragmentIn { position: float4; } -//fun amplify(): void {} +struct payloadStruct { + myArbitraryData: uint; +} + +#[threads(32, 1, 1)] +fun amplify(): void { + var p: payloadStruct; + p.myArbitraryData = group_id().z; + dispatch_mesh(1, 1, 1, p); +} -#[topology(triangle), threads(32, 1, 1)] +#[topology(triangle), tris(128), vertices(128, FragmentIn), threads(32, 1, 1)] fun meshy(): void { - + set_mesh_output_counts(10, 11); + set_mesh_triangle(group_thread_id().x, uint3(0, 0, 0)); + var vertex_out: FragmentIn; + vertex_out.position = float4(1.0, 0.0, 0.0, 1.0); + set_mesh_vertex(group_thread_id().x, vertex_out); } fun pixel(input: FragmentIn): float4 { @@ -131,7 +140,7 @@ fun pixel(input: FragmentIn): float4 { #[pipe] struct Pipe { - //prim = amplify; + amplification = amplify; mesh = meshy; fragment = pixel; }