diff --git a/src/codegen/codegen.c b/src/codegen/codegen.c index 57928672..6c0a85ff 100644 --- a/src/codegen/codegen.c +++ b/src/codegen/codegen.c @@ -365,18 +365,7 @@ static void codegen_lambda_expr(ParserContext *ctx, ASTNode *node) } else { - char *tstr = NULL; - if (node->lambda.captured_types_info && node->lambda.captured_types_info[i]) - { - tstr = type_to_c_string(node->lambda.captured_types_info[i]); - } - else - { - tstr = xstrdup(node->lambda.captured_types[i]); - } - - EMIT(ctx, "*(%s*)(&_z_ctx_%d->%s) = ", tstr, lid, node->lambda.captured_vars[i]); - zfree(tstr); + EMIT(ctx, "_z_ctx_%d->%s = ", lid, node->lambda.captured_vars[i]); ASTNode *var_node = ast_create(NODE_EXPR_VAR); var_node->var_ref.name = xstrdup(node->lambda.captured_vars[i]); diff --git a/src/codegen/codegen.h b/src/codegen/codegen.h index 9cdd959a..23e1e596 100644 --- a/src/codegen/codegen.h +++ b/src/codegen/codegen.h @@ -66,6 +66,8 @@ void codegen_match_internal(ParserContext *ctx, ASTNode *node, int use_result); #include "codegen_shared.h" void emit_var_decl_type(ParserContext *ctx, const char *type_str, const char *var_name); +void emit_named_decl_type(ParserContext *ctx, const char *type_str, const char *lambda_name, + int ptrs); void emit_auto_type(ParserContext *ctx, ASTNode *init_expr, Token t); void emit_func_signature(ParserContext *ctx, ASTNode *func, const char *name_override); int emit_move_invalidation(ParserContext *ctx, ASTNode *node); diff --git a/src/codegen/codegen_decl.c b/src/codegen/codegen_decl.c index 989feb5a..a905ac89 100644 --- a/src/codegen/codegen_decl.c +++ b/src/codegen/codegen_decl.c @@ -373,23 +373,11 @@ static void emit_type_aliases_internal(ParserContext *ctx, ASTNode *node, Visite EMIT(ctx, "#if %s\n", node->cfg_condition); } char *c_type_str = type_to_c_string(node->type_info); - // Quick fix for raw function pointers and arrays in typedefs: - // Since type_to_c_string returns `int (*)(int)`, simple replacement isn't valid - // C. But Zen C doesn't officially support raw function pointer aliases. We'll just - // print it. if (c_type_str) { - if (strstr(c_type_str, "(*)")) - { - char *ptr = strstr(c_type_str, "(*)"); - int prefix_len = ptr - c_type_str; - EMIT(ctx, "typedef %.*s (*%s)%s;\n", prefix_len, c_type_str, - node->type_alias.alias, ptr + 3); - } - else - { - EMIT(ctx, "typedef %s %s;\n", c_type_str, node->type_alias.alias); - } + EMIT(ctx, "typedef "); + emit_named_decl_type(ctx, c_type_str, node->type_alias.alias, 0); + EMIT(ctx, ";\n"); zfree(c_type_str); } else @@ -429,17 +417,9 @@ void emit_global_aliases(ParserContext *ctx) char *c_type_str = type_to_c_string(ta->type_info); if (c_type_str) { - if (strstr(c_type_str, "(*)")) - { - char *ptr = strstr(c_type_str, "(*)"); - int prefix_len = ptr - c_type_str; - EMIT(ctx, "typedef %.*s (*%s)%s;\n", prefix_len, c_type_str, ta->alias, - ptr + 3); - } - else - { - EMIT(ctx, "typedef %s %s;\n", c_type_str, ta->alias); - } + EMIT(ctx, "typedef "); + emit_named_decl_type(ctx, c_type_str, ta->alias, 0); + EMIT(ctx, ";\n"); zfree(c_type_str); } else @@ -559,7 +539,8 @@ void emit_lambda_defs(ParserContext *ctx) { tstr = xstrdup(node->lambda.captured_types[i]); } - EMIT(ctx, "%s* %s;\n", tstr, node->lambda.captured_vars[i]); + emit_named_decl_type(ctx, tstr, node->lambda.captured_vars[i], 1); + EMIT(ctx, ";\n"); zfree(tstr); } else @@ -573,7 +554,8 @@ void emit_lambda_defs(ParserContext *ctx) { tstr = xstrdup(node->lambda.captured_types[i]); } - EMIT(ctx, "%s %s;\n", tstr, node->lambda.captured_vars[i]); + emit_named_decl_type(ctx, tstr, node->lambda.captured_vars[i], 0); + EMIT(ctx, ";\n"); zfree(tstr); char *tname = node->lambda.captured_types[i]; diff --git a/src/codegen/codegen_utils.c b/src/codegen/codegen_utils.c index 892c7cdb..c9fdf105 100644 --- a/src/codegen/codegen_utils.c +++ b/src/codegen/codegen_utils.c @@ -64,7 +64,7 @@ int is_enum_type_name(ParserContext *ctx, const char *name) } // Helper to emit C declaration (handle arrays, function pointers correctly) -void emit_c_decl(ParserContext *ctx, const char *type_str, const char *name) +void emit_c_decl(ParserContext *ctx, const char *type_str, const char *name, int ptrs) { char *bracket = strchr(type_str, '['); char *generic = strchr(type_str, '<'); @@ -76,13 +76,27 @@ void emit_c_decl(ParserContext *ctx, const char *type_str, const char *name) if (end_paren) { int prefix_len = end_paren - type_str; - EMIT(ctx, "%.*s%s%s", prefix_len, type_str, name, end_paren); + EMIT(ctx, "%.*s", prefix_len, type_str); + + for (int i = 0; i < ptrs; i++) + { + EMIT(ctx, "*"); + } + + EMIT(ctx, "%s%s", name, end_paren); } else { // Fallback if malformed (shouldn't happen) int prefix_len = fn_ptr - type_str + 2; - EMIT(ctx, "%.*s%s%s", prefix_len, type_str, name, fn_ptr + 2); + EMIT(ctx, "%.*s", prefix_len, type_str); + + for (int i = 0; i < ptrs; i++) + { + EMIT(ctx, "*"); + } + + EMIT(ctx, "%s%s", name, fn_ptr + 2); } } else if (generic && (!bracket || generic < bracket)) @@ -104,7 +118,7 @@ void emit_c_decl(ParserContext *ctx, const char *type_str, const char *name) if (find_struct_def(ctx, mangled_candidate)) { - EMIT(ctx, "%s %s", mangled_candidate, name); + EMIT(ctx, "%s ", mangled_candidate); success = 1; } } @@ -113,33 +127,86 @@ void emit_c_decl(ParserContext *ctx, const char *type_str, const char *name) if (!success) { int base_len = generic - type_str; - EMIT(ctx, "%.*s %s", base_len, type_str, name); + EMIT(ctx, "%.*s ", base_len, type_str); } - else if (gt[1] == '*') + + if (gt && gt[1] == '*') { EMIT(ctx, "*"); } if (bracket) { + if (ptrs > 0) + { + EMIT(ctx, "("); + + for (int i = 0; i < ptrs; i++) + { + EMIT(ctx, "*"); + } + + EMIT(ctx, "%s)", name); + } + else + { + EMIT(ctx, "%s", name); + } EMIT(ctx, "%s", bracket); } + else + { + for (int i = 0; i < ptrs; i++) + { + EMIT(ctx, "*"); + } + + EMIT(ctx, "%s", name); + } } else if (bracket) { int base_len = bracket - type_str; - EMIT(ctx, "%.*s %s%s", base_len, type_str, name, bracket); + EMIT(ctx, "%.*s ", base_len, type_str); + if (ptrs > 0) + { + EMIT(ctx, "("); + for (int i = 0; i < ptrs; i++) + { + EMIT(ctx, "*"); + } + EMIT(ctx, "%s)", name); + } + else + { + EMIT(ctx, "%s", name); + } + EMIT(ctx, "%s", bracket); } else { - EMIT(ctx, "%s %s", type_str, name); + EMIT(ctx, "%s ", type_str); + + for (int i = 0; i < ptrs; i++) + { + EMIT(ctx, "*"); + } + + EMIT(ctx, "%s", name); } } // Helper to emit variable declarations with array types. void emit_var_decl_type(ParserContext *ctx, const char *type_str, const char *var_name) { - emit_c_decl(ctx, type_str, var_name); + emit_c_decl(ctx, type_str, var_name, 0); +} + +// Helper to emit named declarations. +void emit_named_decl_type(ParserContext *ctx, const char *type_str, const char *lambda_name, + int ptrs) +{ + emit_c_decl(ctx, type_str, lambda_name, ptrs); } // Get field type from struct. @@ -299,7 +366,7 @@ void emit_func_signature(ParserContext *ctx, ASTNode *func, const char *name_ove } // check if array type - emit_c_decl(ctx, type_str, name); + emit_c_decl(ctx, type_str, name, 0); zfree(type_str); } if (func->func.is_varargs) diff --git a/tests/language/functions/test_lambda_capture.zc b/tests/language/functions/test_lambda_capture.zc new file mode 100644 index 00000000..c0bfc852 --- /dev/null +++ b/tests/language/functions/test_lambda_capture.zc @@ -0,0 +1,30 @@ +// language/functions: functions: test_lambda_capture +alias fp = fn*(int) -> int; +alias clos = fn(int) -> int; + +fn square(x: int) -> int { return x*x; } +fn dec(x: int) -> int { return x+1; } +fn inc(x: int) -> int { return x-1; } + +fn compose(f: fp, g: fp, h: fp) -> clos { + return fn(x: int) -> int { + return h(g(f(x))); + } +} + + +test "lambda captures local fn ptr" { + let f: fp = square; + let call_f: clos = fn(x: int) -> int { + return f(x); + } + + let res = call_f(5) + assert(res == 25, "call_f(5) should be 25"); +} + +test "returned clos captures fn ptr args" { + let composed = compose(square, dec, inc); + let res = composed(10) + assert(res == 100, "composed(10) should be 100") +} \ No newline at end of file