Skip to content

Commit

Permalink
Deep argument flattening (#182)
Browse files Browse the repository at this point in the history
* deep argument flattening (pass tuples and reals unboxed when possible, delete unused arguments, and perform uncurrying)
* fixes to recursive function optimisation
* fix-floating
  • Loading branch information
melsman authored Nov 11, 2024
1 parent ec6a14d commit edf4b2b
Show file tree
Hide file tree
Showing 25 changed files with 1,005 additions and 1,731 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:

steps:

- uses: actions/checkout@v2
- uses: actions/checkout@v3

- name: Setup environment
run: |
Expand Down
10 changes: 5 additions & 5 deletions basis/Array2.sml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ structure Array2 :> ARRAY2 = struct

type 'a array = 'a array

(* 26 bits are reserved for the length in the tag field; maxLen = 2^26 *)
val maxLen = 16777216 (* =4096*4096=128 MB *)
(* quite a few bits are available for the length in the tag field! *)
val maxLen = Initial.wordtable_maxlen

type 'a region = {base : 'a array,
row : int,
Expand All @@ -30,7 +30,7 @@ fun update2 (a : 'a array, cols:int, r:int, c:int, v:'a) : unit =
(* The primitive word_table2d0 is in OptLambda compiled into calls to
word_table0 and consecutive updates to store the sizes of each
dimension in slot 0 and 1. Similarly for word_table2d0_init, which
is in OptLambda compiled into calls to word_table_init and
is compiled in OptLambda into calls to word_table_init and
consecutive updates to store the sizes of each dimension in slot 0
and 1.
Expand All @@ -45,8 +45,8 @@ fun table2d0_init (n:int,v:'a,r:int,c:int) : 'a array = prim ("word_table2d0_ini
fun update0 (a : 'a array, i : int, x : 'a) : unit = prim ("word_update0", (a, i, x))

fun check (nr,nc) : int =
if nr < 0 orelse nc < 0 orelse nc > maxLen orelse nr > maxLen then raise Size
else let val n = nr*nc
if nr < 0 orelse nc < 0 then raise Size
else let val n = nr*nc handle Overflow => raise Size
in if n > maxLen then raise Size
else n
end
Expand Down
1 change: 1 addition & 0 deletions src/Common/KitBarry.sml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ structure K =
let structure KC = KitCompiler(ExecutionBarry)
val () = Flags.turn_off "cross_module_opt"
val () = Flags.turn_off "unbox_reals"
val () = Flags.turn_off "unbox_function_arguments"
val () = Flags.turn_off "eliminate_polymorphic_equality"
val () = Flags.turn_off "uncurrying"
val () = List.app Flags.block_entry
Expand Down
10 changes: 7 additions & 3 deletions src/Compiler/Backend/RegAlloc.sml
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,11 @@ struct
adjList:=nil; alias:=NONE; color:=SOME key; lrs:=no_call; uses:=0))
precolored

val Kfpr = List.length RI.f64_phregs
val () = if List.length RI.f64_phregs <> 14 then
die "RegAlloc: f64_phregs error"
else ()
val f64_phregs = List.take (RI.f64_phregs, 12)
val Kfpr = List.length f64_phregs
val Kgpr = List.length RI.caller_save_phregs

fun K (k:kind) : int =
Expand Down Expand Up @@ -834,7 +838,7 @@ struct
S.empty

val f64_phregset =
S.fromList (map key RI.f64_phregs)
S.fromList (map key f64_phregs)

val callee_save_ccall_phregset =
S.fromList (map key RI.callee_save_ccall_phregs)
Expand Down Expand Up @@ -1154,7 +1158,7 @@ struct
structure H = Polyhash
in
val phregKeyToLv : key -> lvar =
let val regs = RI.f64_phregs @ RI.all_regs
let val regs = f64_phregs @ RI.all_regs
val m = H.mkTable (Word.toIntX, op =) (50,Fail "RegAlloc.phregTable")
val () = app (fn lv => H.insert m (key lv, lv)) regs
in fn k => case H.peek m k of
Expand Down
6 changes: 3 additions & 3 deletions src/Compiler/Backend/X64/CodeGenUtilX64.sml
Original file line number Diff line number Diff line change
Expand Up @@ -1966,7 +1966,7 @@ struct
else if I.is_xmm x andalso I.is_xmm y then I.movsd(R x,R y)::C
else die "copy_f64: expecting xmm registers"

fun bin_f64_op s finst (x,y0,d,size_ff:int,C) =
fun bin_f64_op s finst (x,y0,d,size_ff:int,C) = (* d := x op y *) (* e.g.: d := x; sub y d *)
let val (x, x_C) = resolve_arg_aty(x,tmp_freg0,size_ff)
val (y, y_C) = resolve_arg_aty(y0,tmp_freg1,size_ff)
val (d, C') = resolve_aty_def(d,tmp_freg0,size_ff, C)
Expand All @@ -1977,12 +1977,12 @@ struct
if y = d then
if x = d then
finst(R d, R d) :: C'
else
else (* x <> d && x <> y && d <> f1 *)
copy_f64(y, tmp_freg1,
copy_f64(x, d,
finst(R tmp_freg1, R d) ::
C'))
else
else (* y <> d *)
copy_f64(x, d,
finst(R y, R d) ::
C') ))
Expand Down
6 changes: 5 additions & 1 deletion src/Compiler/Backend/X64/CodeGenX64.sml
Original file line number Diff line number Diff line change
Expand Up @@ -1643,6 +1643,10 @@ struct
comment ("END OF STATIC DATA AREA",nil)))))

fun init_x64_code () = [I.dot_text]

(* x64_optimise x: pass x through I.optimise. The constant `true`
   condition is a hand-operated switch: change it to `false` to return
   x unchanged, bypassing I.optimise. *)
fun x64_optimise x =
if true then I.optimise x
else x
in
fun CG {main_lab:label,
code=ss_prg: (StoreTypeCO,offset,AtySS) LinePrg,
Expand All @@ -1660,7 +1664,7 @@ struct
val x64_prg = {top_decls = foldr (fn (func,acc) => CG_top_decl func :: acc) [] ss_prg,
init_code = init_x64_code(),
static_data = static_data main_lab}
val x64_prg = I.optimise x64_prg
val x64_prg = x64_optimise x64_prg
val _ = chat "]\n"
in
x64_prg
Expand Down
12 changes: 0 additions & 12 deletions src/Compiler/Backend/X64/INSTS_X64.sml
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,6 @@ signature INSTS_X64 =
| sqrtsd of ea * ea
| cvtsi2sdq of ea * ea

| fstpq of ea (* store float and pop float stack *)
| fldq of ea (* push float onto the float stack *)
| fldz (* push 0.0 onto the float stack *)
| faddp (* add st(0) to st(1) and pop *)
| fsubp (* subtract st(0) from st(1) and pop *)
| fmulp (* multiply st(0) to st(1) and pop *)
| fdivp (* divide st(1) with st(0) and pop *)
| fcompp (* compare st(0) and st(1) and pop twice *)
| fabs (* st(0) = abs(st(0)) *)
| fchs (* st(0) = neg(st(0)) *)
| fnstsw (* store float status word *)

| jmp of ea (* jump instructions *)
| jl of lab
| jg of lab
Expand Down
46 changes: 6 additions & 40 deletions src/Compiler/Backend/X64/InstsX64.sml
Original file line number Diff line number Diff line change
Expand Up @@ -141,19 +141,6 @@ structure InstsX64: INSTS_X64 =
| xorps of ea * ea
| sqrtsd of ea * ea
| cvtsi2sdq of ea * ea

| fstpq of ea (* store float and pop float stack *)
| fldq of ea (* push float onto the float stack *)
| fldz (* push 0.0 onto the float stack *)
| faddp (* add st(0) to st(1) and pop *)
| fsubp (* subtract st(0) from st(1) and pop *)
| fmulp (* multiply st(0) to st(1) and pop *)
| fdivp (* divide st(1) with st(0) and pop *)
| fcompp (* compare st(0) and st(1) and pop twice *)
| fabs (* st(0) = abs(st(0)) *)
| fchs (* st(0) = neg(st(0)) *)
| fnstsw (* store float status word *)

| jmp of ea (* jump instructions *)
| jl of lab
| jg of lab
Expand Down Expand Up @@ -535,19 +522,6 @@ structure InstsX64: INSTS_X64 =
| xorps a => emit_bin("xorps", a)
| sqrtsd a => emit_bin("sqrtsd", a)
| cvtsi2sdq a => emit_bin("cvtsi2sdq", a)

| fstpq ea => emit_unary("fstpq", ea)
| fldq ea => emit_unary("fldq", ea)
| fldz => emit_nullary "fldz"
| faddp => emit_nullary "faddp"
| fsubp => emit_nullary "fsubp"
| fmulp => emit_nullary "fmulp"
| fdivp => emit_nullary "fdivp"
| fcompp=> emit_nullary "fcompp"
| fabs => emit_nullary "fabs"
| fchs => emit_nullary "fchs"
| fnstsw => emit_nullary "fnstsw"

| jmp (L l) => emit_jump("jmp", l)
| jmp ea => (emit "\tjmp *"; emit(pr_ea K ea); emit_nl())
| jl l => emit_jump("jl", l)
Expand Down Expand Up @@ -677,7 +651,6 @@ structure InstsX64: INSTS_X64 =
v8,v9,v10,v11,v12,v13,v14,v15)
| _ => die "RI.all_fregs mismatch"


val f64_phregset = Lvarset.lvarsetof f64_phregs

val map_lvs_to_reg =
Expand Down Expand Up @@ -771,10 +744,10 @@ structure InstsX64: INSTS_X64 =
fun doubleOfQuadReg r =
case r of
rax => eax | rbx => ebx | rcx => ecx | rdx => edx
| rsi => esi | rdi => edi | rbp => ebp | rsp => esp
| r8 => r8d | r9 => r9d | r10 => r10d | r11 => r11d
| r12 => r12d | r13 => r13d | r14 => r14d | r15 => r15d
| _ => die ("doubleOfQuadReg: " ^ pr_reg r ^ " is not a quad register")
| rsi => esi | rdi => edi | rbp => ebp | rsp => esp
| r8 => r8d | r9 => r9d | r10 => r10d | r11 => r11d
| r12 => r12d | r13 => r13d | r14 => r14d | r15 => r15d
| _ => die ("doubleOfQuadReg: " ^ pr_reg r ^ " is not a quad register")

(* Helper functions *)

Expand Down Expand Up @@ -862,8 +835,10 @@ structure InstsX64: INSTS_X64 =
| xorps (ea1,ea2) => xorps (Em ea1,Em ea2)
| sqrtsd (ea1,ea2) => sqrtsd (Em ea1,Em ea2)
| cvtsi2sdq (ea1,ea2) => cvtsi2sdq (Em ea1,Em ea2)
(*
| fstpq ea => fstpq (Em ea)
| fldq ea => fldq (Em ea)
*)
| jmp ea => jmp (Em ea)
| jl l => jl (Lm l)
| jg l => jg (Lm l)
Expand Down Expand Up @@ -893,15 +868,6 @@ structure InstsX64: INSTS_X64 =
| dot_globl (l,ty) => dot_globl (Lm l,ty)
| dot_size (l, i) => dot_size (Lm l, i)
| lab l => lab (Lm l)
| fldz => i
| faddp => i
| fsubp => i
| fmulp => i
| fdivp => i
| fcompp => i
| fabs => i
| fchs => i
| fnstsw => i
| ret => i
| leave => i
| dot_align n => i
Expand Down
Loading

0 comments on commit edf4b2b

Please sign in to comment.