Skip to content

Commit 4585e7d

Browse files
committed
lua4jvm: Implement VARIABLE_TRACING pass and re-enable upvalue typing
This should make local functions efficient enough for numerical code! While testing said numerical code, it turned out that lua4jvm's support for using ints for numerical code was, especially mixed with doubles, quite buggy. The bugs are now fixed.
1 parent 88033d3 commit 4585e7d

28 files changed

+236
-64
lines changed

lua4jvm/src/main/java/fi/benjami/code4jvm/lua/LuaVm.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ public LuaModule compile(String chunk) {
101101
public LuaFunction load(LuaModule module, LuaTable env) {
102102
// Instantiate the module
103103
var type = LuaType.function(
104-
// TODO _ENV mutability tracking
105-
List.of(new UpvalueTemplate(module.env(), LuaType.TABLE)),
104+
// TODO _ENV mutability tracking - we need LuaContext, which is a bit tricky here...
105+
List.of(new UpvalueTemplate(module.env(), LuaType.UNKNOWN, true)),
106106
List.of(),
107107
module.root(),
108108
module.name(),

lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/CompilerPass.java

+7-3
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,17 @@ public enum CompilerPass {
88
IR_GEN,
99

1010
/**
11-
* In this phase, variables are traced to determine their mutability and
12-
* other properties. TODO not yet implemented
11+
* Return tracking, to determine if lua4jvm has to insert empty returns.
12+
*/
13+
RETURN_TRACKING,
14+
15+
/**
16+
* Variable flagging, based on e.g. their mutability.
1317
*/
1418
VARIABLE_TRACING,
1519

1620
/**
17-
* In analysis phase, types that can be statically inferred are inferred
21+
* In analysis pass, types that can be statically inferred are inferred
1822
* to generate better code later.
1923
*/
2024
TYPE_ANALYSIS,

lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/FunctionCompiler.java

-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ public static MethodHandle callTarget(LuaType[] argTypes, LuaFunction function,
6363

6464
// Compile and load the function code, or use something that is already cached
6565
var compiledFunc = function.type().specializations().computeIfAbsent(cacheKey, t -> {
66-
CompilerPass.setCurrent(CompilerPass.TYPE_ANALYSIS);
6766
var ctx = LuaContext.forFunction(function.owner(), function.type(), truncateReturn, argTypes);
6867

6968
CompilerPass.setCurrent(CompilerPass.CODEGEN);

lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/IrCompiler.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,8 @@ public IrNode visitStringConcat(StringConcatContext ctx) {
448448
@Override
449449
public IrNode visitNumberLiteral(NumberLiteralContext ctx) {
450450
var value = Double.valueOf(ctx.Numeral().getText());
451-
return new LuaConstant(value.intValue() == value ? value.intValue() : value);
451+
// Use Math.rint() to handle very large doubles safely
452+
return Math.rint(value) == value ? new LuaConstant(value.intValue()) : new LuaConstant(value);
452453
}
453454

454455
@Override

lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/LuaContext.java

+5
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,13 @@ public static LuaContext forFunction(LuaVm vm, LuaType.Function type, boolean tr
4545
ctx.setFlag(arg, VariableFlag.ASSIGNED); // JVM assigns arguments to these
4646
}
4747

48+
// Do variable flagging BEFORE type analysis, we need that mutability information
49+
type.body().flagVariables(ctx);
50+
4851
// Compute types of local variables and the return type
52+
CompilerPass.setCurrent(CompilerPass.TYPE_ANALYSIS);
4953
type.body().outputType(ctx);
54+
CompilerPass.setCurrent(null);
5055
return ctx;
5156
}
5257

lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/VariableFlag.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public enum VariableFlag {
1717
/**
1818
* Variable is mutable; that is, it is assigned to at least twice.
1919
*/
20-
MUTABLE(CompilerPass.TYPE_ANALYSIS)
20+
MUTABLE(CompilerPass.VARIABLE_TRACING)
2121
;
2222

2323
/**

lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ffi/JavaFunction.java

+17-4
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ public Target matchToArgs(LuaType[] argTypes, String intrinsicId) {
8282
if (target.intrinsicId != null && !target.intrinsicId.equals(intrinsicId)) {
8383
continue; // Intrinsic not allowed by caller
8484
}
85-
if (checkArgs(target, argTypes) == MatchResult.SUCCESS) {
85+
var result = checkArgs(target, argTypes);
86+
if (result == MatchResult.SUCCESS || result == MatchResult.INT_DOUBLE_CAST_NEEDED) {
87+
// Linker calls MethodHandle#cast(...), which casts ints to doubles if needed
8688
return target;
8789
}
8890
}
@@ -99,7 +101,8 @@ private enum MatchResult {
99101
SUCCESS,
100102
TOO_FEW_ARGS,
101103
ARG_TYPE_MISMATCH,
102-
VARARGS_TYPE_MISMATCH
104+
VARARGS_TYPE_MISMATCH,
105+
INT_DOUBLE_CAST_NEEDED
103106
}
104107

105108
private MatchResult checkArgs(Target target, LuaType[] argTypes) {
@@ -113,12 +116,18 @@ private MatchResult checkArgs(Target target, LuaType[] argTypes) {
113116
}
114117

115118
// Check types of arguments
119+
var intDoubleCast = false;
116120
for (var i = 0; i < requiredArgs; i++) {
117121
var arg = target.arguments.get(i);
118122
if (!arg.type.isAssignableFrom(argTypes[i])) {
119123
// Allow nil instead of expected type if nullability is allowed
120-
if (!arg.nullable ||!argTypes[i].equals(LuaType.NIL)) {
121-
return MatchResult.ARG_TYPE_MISMATCH;
124+
if (!arg.nullable || !argTypes[i].equals(LuaType.NIL)) {
125+
if (argTypes[i].equals(LuaType.INTEGER) && arg.type.equals(LuaType.FLOAT)) {
126+
// We'll need to cast ints to doubles using MethodHandle magic
127+
intDoubleCast = true;
128+
} else {
129+
return MatchResult.ARG_TYPE_MISMATCH;
130+
}
122131
}
123132
}
124133
}
@@ -133,6 +142,10 @@ private MatchResult checkArgs(Target target, LuaType[] argTypes) {
133142
}
134143
}
135144

145+
if (intDoubleCast) {
146+
return MatchResult.INT_DOUBLE_CAST_NEEDED;
147+
}
148+
136149
return MatchResult.SUCCESS;
137150
}
138151
}

lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/DebugInfoNode.java

+5
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ public boolean hasReturn() {
3131
return node.hasReturn();
3232
}
3333

34+
@Override
35+
public void flagVariables(LuaContext ctx) {
36+
node.flagVariables(ctx);
37+
}
38+
3439
@Override
3540
public IrNode concreteNode() {
3641
return node.concreteNode();

lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/IrNode.java

+4
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ default boolean hasReturn() {
1414
return false;
1515
}
1616

17+
default void flagVariables(LuaContext ctx) {
18+
// No-op
19+
}
20+
1721
default IrNode concreteNode() {
1822
return this;
1923
}

lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/LuaBlock.java

+7
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,11 @@ public boolean hasReturn() {
3939
return false;
4040
}
4141

42+
@Override
43+
public void flagVariables(LuaContext ctx) {
44+
for (var node : nodes) {
45+
node.flagVariables(ctx);
46+
}
47+
}
48+
4249
}

lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/LuaType.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import fi.benjami.code4jvm.Value;
1212
import fi.benjami.code4jvm.lua.compiler.CompiledFunction;
1313
import fi.benjami.code4jvm.lua.compiler.CompiledShape;
14+
import fi.benjami.code4jvm.lua.compiler.CompilerPass;
1415
import fi.benjami.code4jvm.lua.compiler.FunctionCompiler;
1516
import fi.benjami.code4jvm.lua.compiler.ShapeTypes;
1617
import fi.benjami.code4jvm.lua.ir.stmt.ReturnStmt;
@@ -244,7 +245,10 @@ public static Tuple tuple(LuaType... types) {
244245

245246
public static Function function(List<UpvalueTemplate> upvalues, List<LuaLocalVar> args, LuaBlock body,
246247
String moduleName, String name) {
247-
if (!body.hasReturn()) {
248+
CompilerPass.setCurrent(CompilerPass.RETURN_TRACKING);
249+
var hasReturn = body.hasReturn();
250+
CompilerPass.setCurrent(null);
251+
if (!hasReturn) {
248252
// If the function doesn't always return, insert return nil at end
249253
var nodes = new ArrayList<>(body.nodes());
250254
nodes.add(new ReturnStmt(List.of()));

lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/UpvalueTemplate.java

+7-1
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,11 @@ public record UpvalueTemplate(
1515
* {@link LuaFunction#upvalueTypes final types} that are known after
1616
* the function has been instantiated, this may be unknown.
1717
*/
18-
LuaType type
18+
LuaType type,
19+
20+
/**
21+
* Whether or not the upvalue variable is assigned to after its initial
22+
* assignment.
23+
*/
24+
boolean mutable
1925
) {}

lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/ArithmeticExpr.java

+18-5
Original file line numberDiff line numberDiff line change
@@ -32,22 +32,29 @@ public record ArithmeticExpr(
3232
) implements IrNode {
3333

3434
private static final CallTarget MATH_POW = CallTarget.staticMethod(Type.of(Math.class), Type.DOUBLE, "pow", Type.DOUBLE, Type.DOUBLE);
35-
private static final CallTarget MATH_ABS = CallTarget.staticMethod(Type.of(Math.class), Type.DOUBLE, "abs", Type.DOUBLE);
36-
private static final CallTarget FLOOR_DIV = CallTarget.staticMethod(Type.of(ArithmeticExpr.class), Type.DOUBLE, "floorDivide", Type.DOUBLE, Type.DOUBLE);
35+
private static final CallTarget MATH_ABS_INT = CallTarget.staticMethod(Type.of(Math.class), Type.INT, "abs", Type.INT);
36+
private static final CallTarget MATH_ABS_DOUBLE = CallTarget.staticMethod(Type.of(Math.class), Type.DOUBLE, "abs", Type.DOUBLE);
37+
private static final CallTarget FLOOR_DIV_INTS = CallTarget.staticMethod(Type.of(ArithmeticExpr.class), Type.INT, "floorDivide", Type.INT, Type.INT);
38+
private static final CallTarget FLOOR_DIV_DOUBLES = CallTarget.staticMethod(Type.of(ArithmeticExpr.class), Type.DOUBLE, "floorDivide", Type.DOUBLE, Type.DOUBLE);
3739
private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup();
3840

3941
public enum Kind {
40-
POWER(MATH_POW::call, "power", "__pow"),
42+
POWER((lhs, rhs) -> {
43+
// Math.pow() does not have integer variant
44+
return MATH_POW.call(lhs.cast(Type.DOUBLE), rhs.cast(Type.DOUBLE));
45+
}, "power", "__pow"),
4146
MULTIPLY(Arithmetic::multiply, "multiply", "__mul"),
4247
DIVIDE((lhs, rhs) -> {
4348
// Lua uses float division unless integer division is explicitly request (see below)
4449
return Arithmetic.divide(lhs.cast(Type.DOUBLE), rhs.cast(Type.DOUBLE));
4550
}, "divide", "__div"),
46-
FLOOR_DIVIDE(FLOOR_DIV::call, "floorDivide", "__idiv"),
51+
FLOOR_DIVIDE((lhs, rhs)
52+
-> lhs.type().equals(Type.INT) ? FLOOR_DIV_INTS.call(lhs, rhs) : FLOOR_DIV_DOUBLES.call(lhs, rhs),
53+
"floorDivide", "__idiv"),
4754
MODULO((lhs, rhs) -> (block -> {
4855
// Lua expects modulo to be always positive; Java's remainder can return negative values
4956
var remainder = block.add(Arithmetic.remainder(lhs, rhs));
50-
return block.add(MATH_ABS.call(remainder));
57+
return block.add(remainder.type().equals(Type.INT) ? MATH_ABS_INT.call(remainder) : MATH_ABS_DOUBLE.call(remainder));
5158
}), "modulo", "__mod"),
5259
ADD(Arithmetic::add, "add", "__add"),
5360
SUBTRACT(Arithmetic::subtract, "subtract", "__sub");
@@ -156,6 +163,12 @@ public Value emit(LuaContext ctx, Block block) {
156163
var rhsValue = rhs.emit(ctx, block);
157164
if (outputType(ctx).isNumber()) {
158165
// Both arguments are known to be numbers; emit arithmetic operation directly
166+
// Just make sure that if either side is double, the other side is too
167+
if (lhsValue.type().equals(Type.INT) && rhsValue.type().equals(Type.DOUBLE)) {
168+
lhsValue = lhsValue.cast(Type.DOUBLE);
169+
} else if (rhsValue.type().equals(Type.INT) && lhsValue.type().equals(Type.DOUBLE)) {
170+
rhsValue = rhsValue.cast(Type.DOUBLE);
171+
}
159172
return block.add(kind.directEmitter.apply(lhsValue, rhsValue));
160173
} else {
161174
// Types are unknown compile-time; use invokedynamic

lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/FunctionDeclExpr.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import fi.benjami.code4jvm.Value;
88
import fi.benjami.code4jvm.block.Block;
99
import fi.benjami.code4jvm.lua.compiler.LuaContext;
10+
import fi.benjami.code4jvm.lua.compiler.VariableFlag;
1011
import fi.benjami.code4jvm.lua.ir.IrNode;
1112
import fi.benjami.code4jvm.lua.ir.LuaBlock;
1213
import fi.benjami.code4jvm.lua.ir.LuaLocalVar;
@@ -59,7 +60,8 @@ public Value emit(LuaContext ctx, Block block) {
5960
public LuaType.Function outputType(LuaContext ctx) {
6061
// Upvalue template has the variable INSIDE declared function, with type of OUTSIDE variable
6162
var upvalueTemplates = upvalues.stream()
62-
.map(upvalue -> new UpvalueTemplate(upvalue, ctx.variableType(upvalue)))
63+
.map(upvalue -> new UpvalueTemplate(upvalue, ctx.hasFlag(upvalue, VariableFlag.MUTABLE)
64+
? LuaType.UNKNOWN : ctx.variableType(upvalue), ctx.hasFlag(upvalue, VariableFlag.MUTABLE)))
6365
.toList();
6466
return ctx.cached(this, LuaType.function(upvalueTemplates, arguments, body, moduleName, name));
6567
}

lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/NegateExpr.java

+12-4
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,34 @@
1818

1919
public record NegateExpr(IrNode expr) implements IrNode {
2020

21-
private static final MethodHandle NEGATE;
21+
private static final MethodHandle NEGATE_DOUBLE, NEGATE_INT;
2222
private static final DynamicTarget TARGET;
2323

2424
static {
2525
var lookup = MethodHandles.lookup();
2626
try {
27-
NEGATE = lookup.findStatic(NegateExpr.class, "negate", MethodType.methodType(double.class, Object.class, double.class));
27+
NEGATE_DOUBLE = lookup.findStatic(NegateExpr.class, "negate", MethodType.methodType(double.class, Object.class, double.class));
28+
NEGATE_INT = lookup.findStatic(NegateExpr.class, "negate", MethodType.methodType(int.class, Object.class, int.class));
2829
} catch (NoSuchMethodException | IllegalAccessException e) {
2930
throw new AssertionError(e);
3031
}
3132

32-
TARGET = UnaryOp.newTarget(new UnaryOp.Path[] {new UnaryOp.Path(Double.class, NEGATE)}, "__unm",
33-
(val) -> new LuaException("attempted to negate a non-number value"));
33+
TARGET = UnaryOp.newTarget(new UnaryOp.Path[] {
34+
new UnaryOp.Path(Double.class, NEGATE_DOUBLE),
35+
new UnaryOp.Path(Integer.class, NEGATE_INT)
36+
}, "__unm", (val) -> new LuaException("attempted to negate a non-number value"));
3437
}
3538

3639
@SuppressWarnings("unused") // MethodHandle
3740
private static double negate(Object callable, double value) {
3841
return -value;
3942
}
4043

44+
@SuppressWarnings("unused") // MethodHandle
45+
private static int negate(Object callable, int value) {
46+
return -value;
47+
}
48+
4149
@Override
4250
public Value emit(LuaContext ctx, Block block) {
4351
var value = expr.emit(ctx, block);

lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/IfBlockStmt.java

+7
Original file line numberDiff line numberDiff line change
@@ -64,5 +64,12 @@ public boolean hasReturn() {
6464
// If we don't have fallback, we might not return
6565
return fallback != null ? fallback.hasReturn() : false;
6666
}
67+
68+
@Override
69+
public void flagVariables(LuaContext ctx) {
70+
for (var branch : branches) {
71+
branch.body.flagVariables(ctx);
72+
}
73+
}
6774

6875
}

lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/IteratorForStmt.java

+12
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import fi.benjami.code4jvm.call.CallTarget;
1212
import fi.benjami.code4jvm.lua.compiler.LoopRef;
1313
import fi.benjami.code4jvm.lua.compiler.LuaContext;
14+
import fi.benjami.code4jvm.lua.compiler.VariableFlag;
1415
import fi.benjami.code4jvm.lua.ir.IrNode;
1516
import fi.benjami.code4jvm.lua.ir.LuaBlock;
1617
import fi.benjami.code4jvm.lua.ir.LuaLocalVar;
@@ -172,4 +173,15 @@ public LuaType outputType(LuaContext ctx) {
172173
public boolean hasReturn() {
173174
return false; // The loop might run for zero iterations
174175
}
176+
177+
@Override
178+
public void flagVariables(LuaContext ctx) {
179+
for (var loopVar : loopVars) {
180+
ctx.setFlag(loopVar, VariableFlag.MUTABLE);
181+
}
182+
for (var it : iterable) {
183+
it.flagVariables(ctx);
184+
}
185+
body.flagVariables(ctx);
186+
}
175187
}

lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/LoopStmt.java

+6
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,10 @@ public boolean hasReturn() {
6161
// If there is condition before first run, it might not return
6262
return kind != Kind.REPEAT_UNTIL ? false : body.hasReturn();
6363
}
64+
65+
@Override
66+
public void flagVariables(LuaContext ctx) {
67+
condition.flagVariables(ctx);
68+
body.flagVariables(ctx);
69+
}
6470
}

lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/SetVariablesStmt.java

+7-3
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ public LuaType outputType(LuaContext ctx) {
130130
for (var i = 0; i < Math.min(normalSources, targets.size()); i++) {
131131
var target = targets.get(i);
132132
ctx.recordType(target, sources.get(i).outputType(ctx));
133-
markMutable(ctx, target);
134133
}
135134

136135
if (spread) {
@@ -145,19 +144,24 @@ public LuaType outputType(LuaContext ctx) {
145144
// anything else -> first multiValType, rest NIL
146145
var target = targets.get(i);
147146
ctx.recordType(target, LuaType.UNKNOWN);
148-
markMutable(ctx, target);
149147
}
150148
} else {
151149
// If there are leftover targets, set them to nil
152150
for (var i = normalSources; i < targets.size(); i++) {
153151
var target = targets.get(i);
154152
ctx.recordType(target, LuaType.NIL);
155-
markMutable(ctx, target);
156153
}
157154
}
158155
return LuaType.NIL;
159156
}
160157

158+
@Override
159+
public void flagVariables(LuaContext ctx) {
160+
for (var target : targets) {
161+
markMutable(ctx, target);
162+
}
163+
}
164+
161165
private void markMutable(LuaContext ctx, LuaVariable target) {
162166
if (target instanceof LuaLocalVar localVar) {
163167
if (ctx.hasFlag(localVar, VariableFlag.ASSIGNED)) {

lua4jvm/src/main/java/fi/benjami/code4jvm/lua/linker/LuaLinker.java

+1-4
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,9 @@ public static LuaCallTarget linkCall(LuaCallSite meta, Object callable, Object..
116116
specializedTypes = Arrays.copyOf(specializedTypes, function.type().acceptedArgs().size());
117117
Arrays.fill(specializedTypes, compiledTypes.length, specializedTypes.length, LuaType.UNKNOWN);
118118
}
119-
120-
// FIXME upvalue typing is incorrect for mutable upvalues until VARIABLE_TRACING pass is implemented
121-
var useUpvalueTypes = false; // checkTarget
122119

123120
// Truncate multival return if site doesn't want to spread
124-
target = FunctionCompiler.callTarget(specializedTypes, function, useUpvalueTypes,
121+
target = FunctionCompiler.callTarget(specializedTypes, function, checkTarget,
125122
!meta.options.spreadResults());
126123
guard = checkTarget ? TARGET_HAS_CHANGED.bindTo(function)
127124
: PROTOTYPE_HAS_CHANGED.bindTo(function.type());

0 commit comments

Comments
 (0)