Skip to content

Commit

Permalink
feat(runtime): Optimize == for lists (#2247)
Browse files Browse the repository at this point in the history
Co-authored-by: Oscar Spencer <[email protected]>
  • Loading branch information
spotandjake and ospencer authored Feb 17, 2025
1 parent 097ae7d commit 1cba005
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 76 deletions.
6 changes: 6 additions & 0 deletions compiler/test/stdlib/pervasives.test.gr
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,9 @@ record Comparable2 {
}
assert compare({ a: 1, b: true, c: void }, { a: 1, b: true, c: void }) == 0
assert compare({ a: 1, b: true, c: void }, { a: 1, b: false, c: void }) > 0

// Large list equality, regression #2247
let rec make_list = (n, acc) => {
if (n == 0) acc else make_list(n - 1, [n, ...acc])
}
assert make_list(500_000, []) == make_list(500_000, [])
2 changes: 1 addition & 1 deletion compiler/test/suites/basic_functionality.re
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,6 @@ describe("basic functionality", ({test, testSkip}) => {
~config_fn=smallestFileConfig,
"smallest_grain_program",
"",
6494,
6503,
);
});
172 changes: 97 additions & 75 deletions stdlib/runtime/equal.gr
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,19 @@ module Equal

from "runtime/unsafe/memory" include Memory
from "runtime/unsafe/wasmi32" include WasmI32
use WasmI32.{ (==), (!=), (&), (^), (+), (-), (*), (<), remS as (%), (<<) }
use WasmI32.{
(==),
(!=),
(&),
(^),
(+),
(-),
(*),
(<),
remS as (%),
(<<),
(>>),
}
from "runtime/unsafe/wasmi64" include WasmI64
from "runtime/unsafe/wasmf32" include WasmF32
from "runtime/unsafe/tags" include Tags
Expand All @@ -14,6 +26,10 @@ primitive (!) = "@not"
primitive (||) = "@or"
primitive (&&) = "@and"
primitive ignore = "@ignore"
primitive builtinId = "@builtin.id"

@unsafe
let _LIST_ID = WasmI32.fromGrain(builtinId("List"))

@unsafe
let cycleMarker = 0x80000000n
Expand All @@ -23,38 +39,47 @@ let rec heapEqualHelp = (heapTag, xptr, yptr) => {
match (heapTag) {
t when t == Tags._GRAIN_ADT_HEAP_TAG => {
// Check if the same constructor variant
if (WasmI32.load(xptr, 12n) != WasmI32.load(yptr, 12n)) {
false
let mut xVariantTag = WasmI32.load(xptr, 12n)
let mut yVariantTag = WasmI32.load(yptr, 12n)
if (xVariantTag != yVariantTag) {
return false
}

// Handle lists separately to avoid stack overflow
if (WasmI32.load(xptr, 8n) == _LIST_ID) {
if (xVariantTag >> 1n == 1n) return true // End of list

if (!equalHelp(WasmI32.load(xptr, 20n), WasmI32.load(yptr, 20n))) {
return false
}

return equalHelp(WasmI32.load(xptr, 24n), WasmI32.load(yptr, 24n))
} else {
let xarity = WasmI32.load(xptr, 16n)
let yarity = WasmI32.load(yptr, 16n)

// Cycle check
if ((xarity & cycleMarker) == cycleMarker) {
true
} else {
WasmI32.store(xptr, xarity ^ cycleMarker, 16n)
WasmI32.store(yptr, yarity ^ cycleMarker, 16n)

let mut result = true

let bytes = xarity * 4n
for (let mut i = 0n; i < bytes; i += 4n) {
if (
!equalHelp(
WasmI32.load(xptr + i, 20n),
WasmI32.load(yptr + i, 20n)
)
) {
result = false
break
}
}
WasmI32.store(xptr, xarity, 16n)
WasmI32.store(yptr, yarity, 16n)
return true
}

result
WasmI32.store(xptr, xarity ^ cycleMarker, 16n)
WasmI32.store(yptr, yarity ^ cycleMarker, 16n)

let bytes = xarity * 4n
for (let mut i = 0n; i < bytes; i += 4n) {
if (
!equalHelp(WasmI32.load(xptr + i, 20n), WasmI32.load(yptr + i, 20n))
) {
WasmI32.store(xptr, xarity, 16n)
WasmI32.store(yptr, yarity, 16n)
return false
}
}
WasmI32.store(xptr, xarity, 16n)
WasmI32.store(yptr, yarity, 16n)

return true
}
},
t when t == Tags._GRAIN_RECORD_HEAP_TAG => {
Expand All @@ -63,65 +88,64 @@ let rec heapEqualHelp = (heapTag, xptr, yptr) => {

// Cycle check
if ((xlength & cycleMarker) == cycleMarker) {
true
} else {
WasmI32.store(xptr, xlength ^ cycleMarker, 12n)
WasmI32.store(yptr, ylength ^ cycleMarker, 12n)

let mut result = true
return true
}

let bytes = xlength * 4n
for (let mut i = 0n; i < bytes; i += 4n) {
if (
!equalHelp(WasmI32.load(xptr + i, 16n), WasmI32.load(yptr + i, 16n))
) {
result = false
break
}
WasmI32.store(xptr, xlength ^ cycleMarker, 12n)
WasmI32.store(yptr, ylength ^ cycleMarker, 12n)

let bytes = xlength * 4n
for (let mut i = 0n; i < bytes; i += 4n) {
if (
!equalHelp(WasmI32.load(xptr + i, 16n), WasmI32.load(yptr + i, 16n))
) {
WasmI32.store(xptr, xlength, 12n)
WasmI32.store(yptr, ylength, 12n)
return false
}
WasmI32.store(xptr, xlength, 12n)
WasmI32.store(yptr, ylength, 12n)

result
}
WasmI32.store(xptr, xlength, 12n)
WasmI32.store(yptr, ylength, 12n)

return true
},
t when t == Tags._GRAIN_ARRAY_HEAP_TAG => {
let xlength = WasmI32.load(xptr, 4n)
let ylength = WasmI32.load(yptr, 4n)

// Check if the same length
if (xlength != ylength) {
false
} else if ((xlength & cycleMarker) == cycleMarker) {
// Cycle check
true
} else {
WasmI32.store(xptr, xlength ^ cycleMarker, 4n)
WasmI32.store(yptr, ylength ^ cycleMarker, 4n)
return false
}

let mut result = true
let bytes = xlength * 4n
for (let mut i = 0n; i < bytes; i += 4n) {
if (
!equalHelp(WasmI32.load(xptr + i, 8n), WasmI32.load(yptr + i, 8n))
) {
result = false
break
}
}
// Cycle check
if ((xlength & cycleMarker) == cycleMarker) {
return true
}

WasmI32.store(xptr, xlength, 4n)
WasmI32.store(yptr, ylength, 4n)
WasmI32.store(xptr, xlength ^ cycleMarker, 4n)
WasmI32.store(yptr, ylength ^ cycleMarker, 4n)

result
let bytes = xlength * 4n
for (let mut i = 0n; i < bytes; i += 4n) {
if (!equalHelp(WasmI32.load(xptr + i, 8n), WasmI32.load(yptr + i, 8n))) {
WasmI32.store(xptr, xlength, 4n)
WasmI32.store(yptr, ylength, 4n)
return false
}
}

WasmI32.store(xptr, xlength, 4n)
WasmI32.store(yptr, ylength, 4n)

return true
},
t when t == Tags._GRAIN_STRING_HEAP_TAG || t == Tags._GRAIN_BYTES_HEAP_TAG => {
let xlength = WasmI32.load(xptr, 4n)
let ylength = WasmI32.load(yptr, 4n)

// Check if the same length
if (xlength != ylength) {
return if (xlength != ylength) {
false
} else {
Memory.compare(xptr + 8n, yptr + 8n, xlength) == 0n
Expand All @@ -132,44 +156,42 @@ let rec heapEqualHelp = (heapTag, xptr, yptr) => {
let ysize = WasmI32.load(yptr, 4n)

if ((xsize & cycleMarker) == cycleMarker) {
true
return true
} else {
WasmI32.store(xptr, xsize ^ cycleMarker, 4n)
WasmI32.store(yptr, ysize ^ cycleMarker, 4n)

let mut result = true
let bytes = xsize * 4n
for (let mut i = 0n; i < bytes; i += 4n) {
if (
!equalHelp(WasmI32.load(xptr + i, 8n), WasmI32.load(yptr + i, 8n))
) {
result = false
break
WasmI32.store(xptr, xsize, 4n)
WasmI32.store(yptr, ysize, 4n)
return false
}
}

WasmI32.store(xptr, xsize, 4n)
WasmI32.store(yptr, ysize, 4n)

result
return true
}
},
t when t == Tags._GRAIN_UINT32_HEAP_TAG || t == Tags._GRAIN_INT32_HEAP_TAG => {
let xval = WasmI32.load(xptr, 4n)
let yval = WasmI32.load(yptr, 4n)
xval == yval
return xval == yval
},
// Float32 is handled by equalHelp directly
t when t == Tags._GRAIN_UINT64_HEAP_TAG => {
use WasmI64.{ (==) }
let xval = WasmI64.load(xptr, 8n)
let yval = WasmI64.load(yptr, 8n)
xval == yval
},
_ => {
// No other implementation
xptr == yptr
return xval == yval
},
// No other implementation
_ => return xptr == yptr,
}
}
and equalHelp = (x, y) => {
Expand Down

0 comments on commit 1cba005

Please sign in to comment.