Skip to content

Commit

Permalink
Remove buggy linearization pass (#604)
Browse files Browse the repository at this point in the history
There is only one bug in the base linearization pass that needs to be
handled explicitly, otherwise linearization is fully guaranteed for IR.

---------

Co-authored-by: Shuhei Kadowaki <[email protected]>
  • Loading branch information
vtjnash and aviatesk authored Jan 5, 2024
1 parent 1efae18 commit d319168
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 154 deletions.
64 changes: 41 additions & 23 deletions src/interpret.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,6 @@ function lookup_var(frame, slot::SlotNumber)
throw(UndefVarError(frame.framecode.src.slotnames[slot.id]))
end

function lookup_expr(frame, e::Expr)
head = e.head
head === :the_exception && return frame.framedata.last_exception[]
if head === :static_parameter
arg = e.args[1]::Int
if isassigned(frame.framedata.sparams, arg)
return frame.framedata.sparams[arg]
else
syms = sparam_syms(frame.framecode.scope::Method)
throw(UndefVarError(syms[arg]))
end
end
head === :boundscheck && length(e.args) == 0 && return true
error("invalid lookup expr ", e)
end

"""
rhs = @lookup(frame, node)
rhs = @lookup(mod, frame, node)
Expand Down Expand Up @@ -67,6 +51,32 @@ macro lookup(args...)
end
end

function lookup_expr(frame, e::Expr)
head = e.head
head === :the_exception && return frame.framedata.last_exception[]
if head === :static_parameter
arg = e.args[1]::Int
if isassigned(frame.framedata.sparams, arg)
return frame.framedata.sparams[arg]
else
syms = sparam_syms(frame.framecode.scope::Method)
throw(UndefVarError(syms[arg]))
end
end
head === :boundscheck && length(e.args) == 0 && return true
if head === :call
f = @lookup frame e.args[1]
if (@static VERSION < v"1.11.0-DEV.1180" && true) && f === Core.svec
# work around for a linearization bug in Julia (https://github.com/JuliaLang/julia/pull/52497)
return f(Any[@lookup(frame, e.args[i]) for i in 2:length(e.args)]...)
elseif f === Core.tuple
# handling for ccall literal syntax
return f(Any[@lookup(frame, e.args[i]) for i in 2:length(e.args)]...)
end
end
error("invalid lookup expr ", e)
end

# This is used only for new struct/abstract/primitive nodes.
# The most important issue is that in these expressions, :call Exprs can be nested,
# and hence our re-use of the `callargs` field of Frame would introduce
Expand All @@ -91,18 +101,26 @@ function lookup_or_eval(@nospecialize(recurse), frame, @nospecialize(node))
if ex.head === :call
f = ex.args[1]
if f === Core.svec
return Core.svec(ex.args[2:end]...)
popfirst!(ex.args)
return Core.svec(ex.args...)
elseif f === Core.apply_type
return Core.apply_type(ex.args[2:end]...)
elseif f === Core.typeof
return Core.typeof(ex.args[2])
elseif f === Base.getproperty
popfirst!(ex.args)
return Core.apply_type(ex.args...)
elseif f === typeof && length(ex.args) == 2
return typeof(ex.args[2])
elseif f === typeassert && length(ex.args) == 3
return typeassert(ex.args[2], ex.args[3])
elseif f === Base.getproperty && length(ex.args) == 3
return Base.getproperty(ex.args[2], ex.args[3])
elseif f === Core.Compiler.Val && length(ex.args) == 2
return Core.Compiler.Val(ex.args[2])
elseif f === Val && length(ex.args) == 2
return Val(ex.args[2])
else
Base.invokelatest(error, "unknown call f ", f)
Base.invokelatest(error, "unknown call f introduced by ccall lowering ", f)
end
else
error("unknown expr ", ex)
return lookup_expr(frame, ex)
end
elseif isa(node, Int) || isa(node, Number) # Number is slow, requires subtyping
return node
Expand Down
132 changes: 6 additions & 126 deletions src/optimize.jl
Original file line number Diff line number Diff line change
@@ -1,94 +1,5 @@
const calllike = (:call, :foreigncall)

const compiled_calls = Dict{Any,Any}()

function extract_inner_call!(stmt::Expr, idx, once::Bool=false)
(stmt.head === :toplevel || stmt.head === :thunk) && return nothing
once |= stmt.head calllike
for (i, a) in enumerate(stmt.args)
isa(a, Expr) || continue
# Make sure we don't "damage" special syntax that requires literals
if i == 1 && stmt.head === :foreigncall
continue
end
if i == 2 && stmt.head === :call && stmt.args[1] === :cglobal
continue
end
ret = extract_inner_call!(a, idx, once) # doing this first extracts innermost calls
ret !== nothing && return ret
iscalllike = a.head calllike
if once && iscalllike
stmt.args[i] = NewSSAValue(idx)
return a
end
end
return nothing
end

function replace_ssa(stmt::Expr, ssalookup)
return Expr(stmt.head, Any[
if isa(a, SSAValue)
SSAValue(ssalookup[a.id])
elseif isa(a, NewSSAValue)
SSAValue(a.id)
elseif isa(a, Expr)
replace_ssa(a, ssalookup)
else
a
end
for a in stmt.args
]...)
end

function renumber_ssa!(stmts::Vector{Any}, ssalookup)
# When updating jumps, when lines get split into multiple lines
# (see "Un-nest :call expressions" below), we need to jump to the first of them.
# Consequently we use the previous "old-code" offset and add one.
# Fixes #455.
jumplookup(l, idx) = idx > 1 ? l[idx-1] + 1 : idx

for (i, stmt) in enumerate(stmts)
if isa(stmt, GotoNode)
stmts[i] = GotoNode(jumplookup(ssalookup, stmt.label))
elseif isa(stmt, SSAValue)
stmts[i] = SSAValue(ssalookup[stmt.id])
elseif isa(stmt, NewSSAValue)
stmts[i] = SSAValue(stmt.id)
elseif isexpr(stmt, :enter)
stmt.args[end] = jumplookup(ssalookup, stmt.args[1]::Int)
elseif isa(stmt, Expr)
stmts[i] = replace_ssa(stmt, ssalookup)
elseif isa(stmt, GotoIfNot)
cond = stmt.cond
if isa(cond, SSAValue)
cond = SSAValue(ssalookup[cond.id])
end
stmts[i] = GotoIfNot(cond, jumplookup(ssalookup, stmt.dest))
elseif isa(stmt, ReturnNode)
val = stmt.val
if isa(val, SSAValue)
stmts[i] = ReturnNode(SSAValue(ssalookup[val.id]))
end
elseif @static (isdefined(Core.IR, :EnterNode) && true) && isa(stmt, Core.IR.EnterNode)
stmts[i] = Core.IR.EnterNode(jumplookup(ssalookup, stmt.catch_dest))
end
end
return stmts
end

function compute_ssa_mapping_delete_statements!(code::CodeInfo, stmts::Vector{Int})
stmts = unique!(sort!(stmts))
ssalookup = collect(1:length(codelocs(code)))
cnt = 1
for i in 1:length(stmts)
start = stmts[i] + 1
stop = i == length(stmts) ? length(codelocs(code)) : stmts[i+1]
ssalookup[start:stop] .-= cnt
cnt += 1
end
return ssalookup
end

# Pre-frame-construction lookup
function lookup_stmt(stmts, arg)
if isa(arg, SSAValue)
Expand Down Expand Up @@ -179,7 +90,8 @@ function optimize!(code::CodeInfo, scope)

# Replace :llvmcall and :foreigncall with compiled variants. See
# https://github.com/JuliaDebug/JuliaInterpreter.jl/issues/13#issuecomment-464880123
foreigncalls_idx = Int[]
# Insert the foreigncall wrappers at the updated idxs
methodtables = Vector{Union{Compiled,DispatchableMethod}}(undef, length(code.code))
for (idx, stmt) in enumerate(code.code)
# Foregincalls can be rhs of assignments
if isexpr(stmt, :(=))
Expand All @@ -192,47 +104,16 @@ function optimize!(code::CodeInfo, scope)
if (arg1 === :llvmcall || lookup_stmt(code.code, arg1) === Base.llvmcall) && isempty(sparams) && scope isa Method
# Call via `invokelatest` to avoid compiling it until we need it
Base.invokelatest(build_compiled_llvmcall!, stmt, code, idx, evalmod)
push!(foreigncalls_idx, idx)
methodtables[idx] = Compiled()
end
elseif stmt.head === :foreigncall && scope isa Method
# Call via `invokelatest` to avoid compiling it until we need it
Base.invokelatest(build_compiled_foreigncall!, stmt, code, sparams, evalmod)
push!(foreigncalls_idx, idx)
methodtables[idx] = Compiled()
end
end
end

## Un-nest :call expressions (so that there will be only one :call per line)
# This will allow us to re-use args-buffers rather than having to allocate new ones each time.
old_code, old_codelocs = code.code, codelocs(code)
code.code = new_code = eltype(old_code)[]
code.codelocs = new_codelocs = Int32[]
ssainc = fill(1, length(old_code))
for (i, stmt) in enumerate(old_code)
loc = old_codelocs[i]
if isa(stmt, Expr)
inner = extract_inner_call!(stmt, length(new_code)+1)
while inner !== nothing
push!(new_code, inner)
push!(new_codelocs, loc)
ssainc[i] += 1
inner = extract_inner_call!(stmt, length(new_code)+1)
end
end
push!(new_code, stmt)
push!(new_codelocs, loc)
end
# Fix all the SSAValues and GotoNodes
ssalookup = cumsum(ssainc)
renumber_ssa!(new_code, ssalookup)
code.ssavaluetypes = length(new_code)

# Insert the foreigncall wrappers at the updated idxs
methodtables = Vector{Union{Compiled,DispatchableMethod}}(undef, length(code.code))
for idx in foreigncalls_idx
methodtables[ssalookup[idx]] = Compiled()
end

return code, methodtables
end

Expand All @@ -255,7 +136,7 @@ function parametric_type_to_expr(@nospecialize(t::Type))
return t
end

function build_compiled_llvmcall!(stmt::Expr, code, idx, evalmod)
function build_compiled_llvmcall!(stmt::Expr, code::CodeInfo, idx::Int, evalmod::Module)
# Run a mini-interpreter to extract the types
framecode = FrameCode(CompiledCalls, code; optimize=false)
frame = Frame(framecode, prepare_framedata(framecode, []))
Expand Down Expand Up @@ -292,9 +173,8 @@ function build_compiled_llvmcall!(stmt::Expr, code, idx, evalmod)
append!(stmt.args, args)
end


# Handle :llvmcall & :foreigncall (issue #28)
function build_compiled_foreigncall!(stmt::Expr, code, sparams::Vector{Symbol}, evalmod)
function build_compiled_foreigncall!(stmt::Expr, code::CodeInfo, sparams::Vector{Symbol}, evalmod::Module)
TVal = evalmod == Core.Compiler ? Core.Compiler.Val : Val
cfunc, RetType, ArgType = lookup_stmt(code.code, stmt.args[1]), stmt.args[2], stmt.args[3]::SimpleVector

Expand Down
5 changes: 0 additions & 5 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@ which will cause all calls to be evaluated via the interpreter.
struct Compiled end
Base.similar(::Compiled, sz) = Compiled() # to support similar(stack, 0)

# A type used transiently in renumbering CodeInfo SSAValues (to distinguish a new SSAValue from an old one)
struct NewSSAValue
id::Int
end

# Our own replacements for Core types. We need to do this to ensure we can tell the difference
# between "data" (Core types) and "code" (our types) if we step into Core.Compiler
struct SSAValue
Expand Down
4 changes: 4 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ function scan_ssa_use!(used::BitSet, @nospecialize(stmt))
while iterval !== nothing
useref, state = iterval
val = Core.Compiler.getindex(useref)
if (@static VERSION < v"1.11.0-DEV.1180" && true) && isexpr(val, :call)
# work around for a linearization bug in Julia (https://github.com/JuliaLang/julia/pull/52497)
scan_ssa_use!(used, val)
end
if isa(val, SSAValue)
push!(used, val.id)
end
Expand Down

0 comments on commit d319168

Please sign in to comment.