Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 1 addition & 12 deletions src/codegen/codegen.c
Original file line number Diff line number Diff line change
Expand Up @@ -365,18 +365,7 @@ static void codegen_lambda_expr(ParserContext *ctx, ASTNode *node)
}
else
{
char *tstr = NULL;
if (node->lambda.captured_types_info && node->lambda.captured_types_info[i])
{
tstr = type_to_c_string(node->lambda.captured_types_info[i]);
}
else
{
tstr = xstrdup(node->lambda.captured_types[i]);
}

EMIT(ctx, "*(%s*)(&_z_ctx_%d->%s) = ", tstr, lid, node->lambda.captured_vars[i]);
zfree(tstr);
EMIT(ctx, "_z_ctx_%d->%s = ", lid, node->lambda.captured_vars[i]);

ASTNode *var_node = ast_create(NODE_EXPR_VAR);
var_node->var_ref.name = xstrdup(node->lambda.captured_vars[i]);
Expand Down
2 changes: 2 additions & 0 deletions src/codegen/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ void codegen_match_internal(ParserContext *ctx, ASTNode *node, int use_result);
#include "codegen_shared.h"

void emit_var_decl_type(ParserContext *ctx, const char *type_str, const char *var_name);
void emit_named_decl_type(ParserContext *ctx, const char *type_str, const char *lambda_name,
int ptrs);
void emit_auto_type(ParserContext *ctx, ASTNode *init_expr, Token t);
void emit_func_signature(ParserContext *ctx, ASTNode *func, const char *name_override);
int emit_move_invalidation(ParserContext *ctx, ASTNode *node);
Expand Down
38 changes: 10 additions & 28 deletions src/codegen/codegen_decl.c
Original file line number Diff line number Diff line change
Expand Up @@ -373,23 +373,11 @@ static void emit_type_aliases_internal(ParserContext *ctx, ASTNode *node, Visite
EMIT(ctx, "#if %s\n", node->cfg_condition);
}
char *c_type_str = type_to_c_string(node->type_info);
// Quick fix for raw function pointers and arrays in typedefs:
// Since type_to_c_string returns `int (*)(int)`, simple replacement isn't valid
// C. But Zen C doesn't officially support raw function pointer aliases. We'll just
// print it.
if (c_type_str)
{
if (strstr(c_type_str, "(*)"))
{
char *ptr = strstr(c_type_str, "(*)");
int prefix_len = ptr - c_type_str;
EMIT(ctx, "typedef %.*s (*%s)%s;\n", prefix_len, c_type_str,
node->type_alias.alias, ptr + 3);
}
else
{
EMIT(ctx, "typedef %s %s;\n", c_type_str, node->type_alias.alias);
}
EMIT(ctx, "typedef ");
emit_named_decl_type(ctx, c_type_str, node->type_alias.alias, 0);
EMIT(ctx, ";\n");
zfree(c_type_str);
}
else
Expand Down Expand Up @@ -429,17 +417,9 @@ void emit_global_aliases(ParserContext *ctx)
char *c_type_str = type_to_c_string(ta->type_info);
if (c_type_str)
{
if (strstr(c_type_str, "(*)"))
{
char *ptr = strstr(c_type_str, "(*)");
int prefix_len = ptr - c_type_str;
EMIT(ctx, "typedef %.*s (*%s)%s;\n", prefix_len, c_type_str, ta->alias,
ptr + 3);
}
else
{
EMIT(ctx, "typedef %s %s;\n", c_type_str, ta->alias);
}
EMIT(ctx, "typedef ");
emit_named_decl_type(ctx, c_type_str, ta->alias, 0);
EMIT(ctx, ";\n");
zfree(c_type_str);
}
else
Expand Down Expand Up @@ -559,7 +539,8 @@ void emit_lambda_defs(ParserContext *ctx)
{
tstr = xstrdup(node->lambda.captured_types[i]);
}
EMIT(ctx, "%s* %s;\n", tstr, node->lambda.captured_vars[i]);
emit_named_decl_type(ctx, tstr, node->lambda.captured_vars[i], 1);
EMIT(ctx, ";\n");
zfree(tstr);
}
else
Expand All @@ -573,7 +554,8 @@ void emit_lambda_defs(ParserContext *ctx)
{
tstr = xstrdup(node->lambda.captured_types[i]);
}
EMIT(ctx, "%s %s;\n", tstr, node->lambda.captured_vars[i]);
emit_named_decl_type(ctx, tstr, node->lambda.captured_vars[i], 0);
EMIT(ctx, ";\n");
zfree(tstr);

char *tname = node->lambda.captured_types[i];
Expand Down
87 changes: 77 additions & 10 deletions src/codegen/codegen_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ int is_enum_type_name(ParserContext *ctx, const char *name)
}

// Helper to emit C declaration (handle arrays, function pointers correctly)
void emit_c_decl(ParserContext *ctx, const char *type_str, const char *name)
void emit_c_decl(ParserContext *ctx, const char *type_str, const char *name, int ptrs)
{
char *bracket = strchr(type_str, '[');
char *generic = strchr(type_str, '<');
Expand All @@ -76,13 +76,27 @@ void emit_c_decl(ParserContext *ctx, const char *type_str, const char *name)
if (end_paren)
{
int prefix_len = end_paren - type_str;
EMIT(ctx, "%.*s%s%s", prefix_len, type_str, name, end_paren);
EMIT(ctx, "%.*s", prefix_len, type_str);

for (int i = 0; i < ptrs; i++)
{
EMIT(ctx, "*");
}

EMIT(ctx, "%s%s", name, end_paren);
}
else
{
// Fallback if malformed (shouldn't happen)
int prefix_len = fn_ptr - type_str + 2;
EMIT(ctx, "%.*s%s%s", prefix_len, type_str, name, fn_ptr + 2);
EMIT(ctx, "%.*s", prefix_len, type_str);

for (int i = 0; i < ptrs; i++)
{
EMIT(ctx, "*");
}

EMIT(ctx, "%s%s", name, fn_ptr + 2);
}
}
else if (generic && (!bracket || generic < bracket))
Expand All @@ -104,7 +118,7 @@ void emit_c_decl(ParserContext *ctx, const char *type_str, const char *name)

if (find_struct_def(ctx, mangled_candidate))
{
EMIT(ctx, "%s %s", mangled_candidate, name);
EMIT(ctx, "%s ", mangled_candidate);
success = 1;
}
}
Expand All @@ -113,33 +127,86 @@ void emit_c_decl(ParserContext *ctx, const char *type_str, const char *name)
if (!success)
{
int base_len = generic - type_str;
EMIT(ctx, "%.*s %s", base_len, type_str, name);
EMIT(ctx, "%.*s ", base_len, type_str);
}
else if (gt[1] == '*')

if (gt && gt[1] == '*')
{
EMIT(ctx, "*");
}

if (bracket)
{
if (ptrs > 0)
{
EMIT(ctx, "(");

for (int i = 0; i < ptrs; i++)
{
EMIT(ctx, "*");
}

EMIT(ctx, "%s)", name);
}
else
{
EMIT(ctx, "%s", name);
}
EMIT(ctx, "%s", bracket);
}
else
{
for (int i = 0; i < ptrs; i++)
{
EMIT(ctx, "*");
}

EMIT(ctx, "%s", name);
}
}
else if (bracket)
{
int base_len = bracket - type_str;
EMIT(ctx, "%.*s %s%s", base_len, type_str, name, bracket);
EMIT(ctx, "%.*s ", base_len, type_str);
if (ptrs > 0)
{
EMIT(ctx, "(");
for (int i = 0; i < ptrs; i++)
{
EMIT(ctx, "*");
}
EMIT(ctx, "%s)", name);
}
else
{
EMIT(ctx, "%s", name);
}
EMIT(ctx, "%s", bracket);
}
else
{
EMIT(ctx, "%s %s", type_str, name);
EMIT(ctx, "%s ", type_str);

for (int i = 0; i < ptrs; i++)
{
EMIT(ctx, "*");
}

EMIT(ctx, "%s", name);
}
}

// Helper to emit variable declarations with array types.
void emit_var_decl_type(ParserContext *ctx, const char *type_str, const char *var_name)
{
emit_c_decl(ctx, type_str, var_name);
emit_c_decl(ctx, type_str, var_name, 0);
}

// Helper to emit named declarations.
void emit_named_decl_type(ParserContext *ctx, const char *type_str, const char *lambda_name,
int ptrs)
{
emit_c_decl(ctx, type_str, lambda_name, ptrs);
}

// Get field type from struct.
Expand Down Expand Up @@ -299,7 +366,7 @@ void emit_func_signature(ParserContext *ctx, ASTNode *func, const char *name_ove
}

// check if array type
emit_c_decl(ctx, type_str, name);
emit_c_decl(ctx, type_str, name, 0);
zfree(type_str);
}
if (func->func.is_varargs)
Expand Down
30 changes: 30 additions & 0 deletions tests/language/functions/test_lambda_capture.zc
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// language/functions: functions: test_lambda_capture
alias fp = fn*(int) -> int;
alias clos = fn(int) -> int;

fn square(x: int) -> int { return x*x; }
fn dec(x: int) -> int { return x+1; }
fn inc(x: int) -> int { return x-1; }

fn compose(f: fp, g: fp, h: fp) -> clos {
return fn(x: int) -> int {
return h(g(f(x)));
}
}


test "lambda captures local fn ptr" {
let f: fp = square;
let call_f: clos = fn(x: int) -> int {
return f(x);
}

let res = call_f(5)
assert(res == 25, "call_f(5) should be 25");
}

test "returned clos captures fn ptr args" {
let composed = compose(square, dec, inc);
let res = composed(10)
assert(res == 100, "composed(10) should be 100")
}
Loading