Skip to content

Commit f415a1b

Browse files
committed
make Let bind variables
1 parent 4d4dd59 commit f415a1b

35 files changed

+314
-334
lines changed

include/shady/grammar.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ inline static bool is_nominal(const Node* node) {
8585
NodeTag tag = node->tag;
8686
if (node->tag == PrimOp_TAG && has_primop_got_side_effects(node->payload.prim_op.op))
8787
return true;
88-
return tag == Function_TAG || tag == BasicBlock_TAG || tag == Constant_TAG || tag == Param_TAG || tag == GlobalVariable_TAG || tag == NominalType_TAG || tag == Case_TAG;
88+
return tag == Function_TAG || tag == BasicBlock_TAG || tag == Constant_TAG || tag == Param_TAG || tag == Variablez_TAG || tag == GlobalVariable_TAG || tag == NominalType_TAG || tag == Case_TAG;
8989
}
9090

9191
inline static bool is_function(const Node* node) { return node->tag == Function_TAG; }

include/shady/grammar.json

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@
8888
"name": "param",
8989
"generate-enum": false
9090
},
91+
{
92+
"name": "variable",
93+
"generate-enum": false
94+
},
9195
{
9296
"name": "instruction"
9397
},
@@ -254,7 +258,7 @@
254258
{
255259
"name": "Variablez",
256260
"snake_name": "varz",
257-
"class": ["value"],
261+
"class": ["value", "variable"],
258262
"constructor": "custom",
259263
"ops": [
260264
{ "name": "name", "class": "string" },
@@ -448,6 +452,7 @@
448452
"constructor": "custom",
449453
"ops": [
450454
{ "name": "instruction", "class": "instruction" },
455+
{ "name": "variables", "class": "variable", "list": true },
451456
{ "name": "tail", "class": "case" }
452457
]
453458
},
@@ -458,7 +463,8 @@
458463
"front-end-only": true,
459464
"ops": [
460465
{ "name": "instruction", "class": "instruction" },
461-
{ "name": "variables", "class": "param", "list": true }
466+
{ "name": "variables", "class": "param", "list": true },
467+
{ "name": "types", "class": "type", "list": true }
462468
]
463469
},
464470
{
@@ -655,8 +661,7 @@
655661
"constructor": "custom",
656662
"class": "case",
657663
"description": [
658-
"An unnamed abstraction that lives inside a function, and can be used as part of various control-flow constructs",
659-
"Most notably, the tails of standard `let` nodes"
664+
"An unnamed abstraction that lives inside a function, and can be used as part of various control-flow constructs"
660665
],
661666
"ops": [
662667
{ "name": "params", "class": "param", "list": true },

include/shady/ir.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,9 @@ const Node* quote_helper(IrArena*, Nodes values);
225225
const Node* prim_op_helper(IrArena*, Op, Nodes, Nodes);
226226

227227
// terminators
228-
const Node* let(IrArena*, const Node* instruction, const Node* tail);
229-
const Node* let_mut(IrArena*, const Node* instruction, Nodes variables);
228+
const Node* var(IrArena* arena, const char* name, const Node* instruction, size_t i);
229+
const Node* let(IrArena*, const Node* instruction, Nodes vars, const Node* tail);
230+
const Node* let_mut(IrArena*, const Node* instruction, Nodes variables, Nodes types);
230231
const Node* jump_helper(IrArena* a, const Node* dst, Nodes args);
231232

232233
// decl ctors

src/frontends/llvm/l2s.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ static TodoBB prepare_bb(Parser* p, FnParseCtx* fn_ctx, LLVMBasicBlockRef bb) {
7878
while (instr) {
7979
switch (LLVMGetInstructionOpcode(instr)) {
8080
case LLVMPHI: {
81-
const Node* nparam = var(a, qualified_type_helper(convert_type(p, LLVMTypeOf(instr)), false), "phi");
81+
const Node* nparam = param(a, qualified_type_helper(convert_type(p, LLVMTypeOf(instr)), false), "phi");
8282
insert_dict(LLVMValueRef, const Node*, p->map, instr, nparam);
8383
append_list(LLVMValueRef, phis, instr);
8484
params = append_nodes(a, params, nparam);
@@ -137,9 +137,9 @@ const Node* convert_function(Parser* p, LLVMValueRef fn) {
137137
for (LLVMValueRef oparam = LLVMGetFirstParam(fn); oparam; oparam = LLVMGetNextParam(oparam)) {
138138
LLVMTypeRef ot = LLVMTypeOf(oparam);
139139
const Type* t = convert_type(p, ot);
140-
const Node* param = var(a, qualified_type_helper(t, false), LLVMGetValueName(oparam));
141-
insert_dict(LLVMValueRef, const Node*, p->map, oparam, param);
142-
params = append_nodes(a, params, param);
140+
const Node* nparam = param(a, qualified_type_helper(t, false), LLVMGetValueName(oparam));
141+
insert_dict(LLVMValueRef, const Node*, p->map, oparam, nparam);
142+
params = append_nodes(a, params, nparam);
143143
if (oparam == LLVMGetLastParam(fn))
144144
break;
145145
}

src/frontends/llvm/l2s_postprocess.c

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,12 @@ static const Node* wrap_in_controls(Context* ctx, Controls* controls, const Node
3939
const Node* token = controls->tokens.nodes[i];
4040
const Node* dst = controls->destinations.nodes[i];
4141
Nodes o_dst_params = get_abstraction_params(dst);
42-
LARRAY(const Node*, new_control_params, o_dst_params.count);
43-
for (size_t j = 0; j < o_dst_params.count; j++)
44-
new_control_params[j] = param(a, o_dst_params.nodes[j]->payload.param.type, unique_name(a, "v"));
45-
Nodes nparams = nodes(a, o_dst_params.count, new_control_params);
46-
body = let(a, control(a, (Control) {
47-
.yield_types = get_param_types(a, o_dst_params),
48-
.inside = case_(a, singleton(token), body)
49-
}), case_(a, nparams, jump_helper(a, rewrite_node(&ctx->rewriter, dst), nparams)));
42+
BodyBuilder* bb = begin_body(a);
43+
Nodes results = bind_instruction(bb, control(a, (Control) {
44+
.yield_types = get_param_types(a, o_dst_params),
45+
.inside = case_(a, singleton(token), body)
46+
}));
47+
body = finish_body(bb, jump_helper(a, rewrite_node(&ctx->rewriter, dst), results));
5048
}
5149
return body;
5250
}

src/frontends/slim/parser.c

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ static void expect_parameters(ctxparams, Nodes* parameters, Nodes* default_value
368368
const char* id = accept_identifier(ctx);
369369
expect(id);
370370

371-
const Node* node = var(arena, qtype, id);
371+
const Node* node = param(arena, qtype, id);
372372
append_list(Node*, params, node);
373373

374374
if (default_values) {
@@ -617,11 +617,11 @@ static const Node* accept_control_flow_instruction(ctxparams, Node* fn) {
617617
expect(accept_token(ctx, lpar_tok));
618618
String str = accept_identifier(ctx);
619619
expect(str);
620-
const Node* param = var(arena, join_point_type(arena, (JoinPointType) {
620+
const Node* jp = param(arena, join_point_type(arena, (JoinPointType) {
621621
.yield_types = yield_types,
622622
}), str);
623623
expect(accept_token(ctx, rpar_tok));
624-
const Node* body = case_(arena, singleton(param), expect_body(ctx, fn, NULL));
624+
const Node* body = case_(arena, singleton(jp), expect_body(ctx, fn, NULL));
625625
return control(arena, (Control) {
626626
.inside = body,
627627
.yield_types = yield_types
@@ -726,6 +726,15 @@ static const Node* expect_jump(ctxparams) {
726726
});
727727
}
728728

729+
/// for convenience, parse variables as parameters
730+
static Nodes params2vars(IrArena* arena, const Node* instruction, Nodes params) {
731+
LARRAY(const Node*, vars, params.count);
732+
for (size_t i = 0; i < params.count; i++) {
733+
vars[i] = var(arena, params.nodes[i]->payload.param.name, instruction, i);
734+
}
735+
return nodes(arena, params.count, vars);
736+
}
737+
729738
static const Node* accept_terminator(ctxparams, Node* fn) {
730739
TokenTag tag = curr_token(tokenizer).tag;
731740
switch (tag) {
@@ -738,7 +747,7 @@ static const Node* accept_terminator(ctxparams, Node* fn) {
738747
case let_tok: {
739748
const Node* lam = accept_case(ctx, fn);
740749
expect(lam);
741-
return let(arena, instruction, lam);
750+
return let(arena, instruction, params2vars(arena, instruction, get_abstraction_params(lam)), get_abstraction_body(lam));
742751
}
743752
default: SHADY_UNREACHABLE;
744753
}

src/frontends/spirv/s2s.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -864,7 +864,7 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) {
864864
parser->defs[result].type = Value;
865865
String param_name = get_name(parser, result);
866866
param_name = param_name ? param_name : format_string_arena(parser->arena->arena, "param%d", parser->fun_arg_i);
867-
parser->defs[result].node = var(parser->arena, qualified_type_helper(get_def_type(parser, result_t), parser->is_entry_pt), param_name);
867+
parser->defs[result].node = param(parser->arena, qualified_type_helper(get_def_type(parser, result_t), parser->is_entry_pt), param_name);
868868
break;
869869
}
870870
case SpvOpLabel: {
@@ -923,7 +923,7 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) {
923923
parser->defs[result].type = Value;
924924
String phi_name = get_name(parser, result);
925925
phi_name = phi_name ? phi_name : unique_name(parser->arena, "phi");
926-
parser->defs[result].node = var(parser->arena, qualified_type_helper(get_def_type(parser, result_t), false), phi_name);
926+
parser->defs[result].node = param(parser->arena, qualified_type_helper(get_def_type(parser, result_t), false), phi_name);
927927
assert(size % 2 == 1);
928928
int num_callsites = (size - 3) / 2;
929929
for (size_t i = 0; i < num_callsites; i++) {

src/shady/analysis/free_variables.c

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,28 @@ typedef struct {
2323
static void search_op_for_free_variables(Context* visitor, NodeClass class, String op_name, const Node* node) {
2424
assert(node);
2525
switch (node->tag) {
26-
case Param_TAG: {
27-
Nodes params = get_abstraction_params(visitor->current_scope->node->node);
28-
for (size_t i = 0; i < params.count; i++) {
29-
if (params.nodes[i] == node)
30-
return;
26+
case Let_TAG: {
27+
Nodes variables = node->payload.let.variables;
28+
for (size_t j = 0; j < variables.count; j++) {
29+
const Node* var = variables.nodes[j];
30+
bool r = insert_set_get_result(const Node*, visitor->current_scope->bound_set, var);
31+
assert(r);
3132
}
32-
insert_set_get_result(const Node*, visitor->current_scope->free_set, node);
3333
break;
3434
}
35-
case Variablez_TAG: {
36-
insert_set_get_result(const Node*, visitor->current_scope->free_set, node);
37-
break;
35+
case Variablez_TAG:
36+
case Param_TAG: {
37+
const Node** found = find_key_dict(const Node*, visitor->current_scope->bound_set, node);
38+
if (!found)
39+
insert_set_get_result(const Node*, visitor->current_scope->free_set, node);
40+
return;
3841
}
3942
case Function_TAG:
4043
case Case_TAG:
4144
case BasicBlock_TAG: assert(false);
42-
default: visit_node_operands(&visitor->visitor, IGNORE_ABSTRACTIONS_MASK, node); break;
45+
default: break;
4346
}
47+
visit_node_operands(&visitor->visitor, IGNORE_ABSTRACTIONS_MASK | NcVariable, node);
4448
}
4549

4650
static CFNodeVariables* create_node_variables(CFNode* cfnode) {
@@ -57,11 +61,7 @@ static CFNodeVariables* visit_domtree(Context* ctx, CFNode* cfnode, int depth, C
5761
ctx = &new_context;
5862

5963
ctx->current_scope = create_node_variables(cfnode);
60-
if (parent) {
61-
ctx->current_scope->bound_set = clone_dict(parent->bound_set);
62-
} else {
63-
ctx->current_scope->bound_set = new_set(const Node*, (HashFn) hash_node, (CmpFn) compare_node);
64-
}
64+
ctx->current_scope->bound_set = new_set(const Node*, (HashFn) hash_node, (CmpFn) compare_node);
6565
insert_dict(CFNode*, CFNodeVariables*, ctx->map, cfnode, ctx->current_scope);
6666
const Node* abs = cfnode->node;
6767

@@ -85,15 +85,21 @@ static CFNodeVariables* visit_domtree(Context* ctx, CFNode* cfnode, int depth, C
8585
size_t j = 0;
8686
const Node* free_var;
8787
while (dict_iter(child_variables->free_set, &j, &free_var, NULL)) {
88-
for (size_t k = 0; k < params.count; k++) {
89-
if (params.nodes[k] == free_var)
90-
goto next;
91-
}
92-
insert_set_get_result(const Node*, ctx->current_scope->free_set, free_var);
88+
const Node** found = find_key_dict(const Node*, ctx->current_scope->bound_set, free_var);
89+
if (!found)
90+
insert_set_get_result(const Node*, ctx->current_scope->free_set, free_var);
9391
next:;
9492
}
9593
}
9694

95+
if (parent) {
96+
size_t j = 0;
97+
const Node* bound;
98+
while (dict_iter(parent->bound_set, &j, &bound, NULL)) {
99+
insert_set_get_result(const Node*, ctx->current_scope->bound_set, bound);
100+
}
101+
}
102+
97103
/*String abs_name = get_abstraction_name_unsafe(abs);
98104
for (int i = 0; i < depth; i++)
99105
debugvv_print(" ");

src/shady/body_builder.c

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ BodyBuilder* begin_body(IrArena* a) {
2323
return bb;
2424
}
2525

26-
Node* var(IrArena* arena, const char* name, const Node* instruction, size_t i);
26+
const Node* var(IrArena* arena, const char* name, const Node* instruction, size_t i);
2727

2828
static Nodes create_output_variables(IrArena* a, const Node* value, size_t outputs_count, const Node** output_types, String const output_names[]) {
2929
Nodes types;
@@ -93,7 +93,7 @@ Nodes bind_instruction_explicit_result_types(BodyBuilder* bb, const Node* instru
9393

9494
Nodes create_mutable_variables(BodyBuilder* bb, const Node* instruction, Nodes provided_types, String const output_names[]) {
9595
Nodes mutable_vars = create_output_variables(bb->arena, instruction, provided_types.count, provided_types.nodes, output_names);
96-
const Node* let_mut_instr = let_mut(bb->arena, instruction, mutable_vars);
96+
const Node* let_mut_instr = let_mut(bb->arena, instruction, mutable_vars, provided_types);
9797
return bind_internal(bb, let_mut_instr, 0, NULL, NULL);
9898
}
9999

@@ -109,12 +109,19 @@ void bind_variables(BodyBuilder* bb, Nodes vars, Nodes values) {
109109
append_list(StackEntry, bb->stack, entry);
110110
}
111111

112+
void bind_variables2(BodyBuilder* bb, Nodes vars, const Node* instr) {
113+
StackEntry entry = {
114+
.instr = instr,
115+
.vars = vars,
116+
};
117+
append_list(StackEntry, bb->stack, entry);
118+
}
119+
112120
const Node* finish_body(BodyBuilder* bb, const Node* terminator) {
113121
size_t stack_size = entries_count_list(bb->stack);
114122
for (size_t i = stack_size - 1; i < stack_size; i--) {
115123
StackEntry entry = read_list(StackEntry, bb->stack)[i];
116-
const Node* lam = case_(bb->arena, entry.vars, terminator);
117-
terminator = let(bb->arena, entry.instr, lam);
124+
terminator = let(bb->arena, entry.instr, entry.vars, case_(bb->arena, empty(bb->arena), terminator));
118125
}
119126

120127
destroy_list(bb->stack);

src/shady/compile.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ CompilationResult run_compiler_passes(CompilerConfig* config, Module** pmod) {
6767

6868
RUN_PASS(specialize_execution_model)
6969

70-
RUN_PASS(opt_stack)
70+
//RUN_PASS(opt_stack)
7171

7272
RUN_PASS(lower_tailcalls)
7373
RUN_PASS(lower_switch_btree)

0 commit comments

Comments
 (0)