Skip to content
158 changes: 141 additions & 17 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,22 +106,22 @@ end
```
"""
macro with_pool(pool_name, expr)
_generate_pool_code(pool_name, expr, true)
_generate_pool_code(pool_name, expr, true; source=__source__)
end

macro with_pool(expr)
pool_name = gensym(:pool)
_generate_pool_code(pool_name, expr, true)
_generate_pool_code(pool_name, expr, true; source=__source__)
end

# Backend-specific variants: @with_pool :cuda pool begin ... end
macro with_pool(backend::QuoteNode, pool_name, expr)
_generate_pool_code_with_backend(backend.value, pool_name, expr, true)
_generate_pool_code_with_backend(backend.value, pool_name, expr, true; source=__source__)
end

macro with_pool(backend::QuoteNode, expr)
pool_name = gensym(:pool)
_generate_pool_code_with_backend(backend.value, pool_name, expr, true)
_generate_pool_code_with_backend(backend.value, pool_name, expr, true; source=__source__)
end

"""
Expand Down Expand Up @@ -153,22 +153,22 @@ end
```
"""
macro maybe_with_pool(pool_name, expr)
_generate_pool_code(pool_name, expr, false)
_generate_pool_code(pool_name, expr, false; source=__source__)
end

macro maybe_with_pool(expr)
pool_name = gensym(:pool)
_generate_pool_code(pool_name, expr, false)
_generate_pool_code(pool_name, expr, false; source=__source__)
end

# Backend-specific variants: @maybe_with_pool :cuda pool begin ... end
macro maybe_with_pool(backend::QuoteNode, pool_name, expr)
_generate_pool_code_with_backend(backend.value, pool_name, expr, false)
_generate_pool_code_with_backend(backend.value, pool_name, expr, false; source=__source__)
end

macro maybe_with_pool(backend::QuoteNode, expr)
pool_name = gensym(:pool)
_generate_pool_code_with_backend(backend.value, pool_name, expr, false)
_generate_pool_code_with_backend(backend.value, pool_name, expr, false; source=__source__)
end

# ==============================================================================
Expand All @@ -189,17 +189,131 @@ function _disabled_pool_expr(backend::Symbol)
end
end

# ==============================================================================
# Internal: Source Location Helpers
# ==============================================================================

"""
_find_first_lnn_index(args) -> Union{Int, Nothing}

Find the index of the first LineNumberNode in the leading prefix of `args`.

Scans sequentially, skipping `Expr(:meta, ...)` nodes (inserted by `@inline`,
`@inbounds`, etc.). Returns `nothing` as soon as a non-meta, non-LNN expression
is encountered—this prevents matching LNNs deeper in the AST.

# Example AST prefix patterns
- `[Expr(:meta,:inline), LNN, ...]` → returns 2
- `[LNN, ...]` → returns 1
- `[Expr(:meta,:inline), Expr(:call,...), LNN, ...]` → returns nothing (stopped at call)
"""
function _find_first_lnn_index(args)
for (i, arg) in enumerate(args)
if arg isa LineNumberNode
return i
elseif arg isa Expr && arg.head === :meta
continue
else
return nothing
end
end
return nothing
end

"""
_ensure_body_has_toplevel_lnn(body, source)

Ensure body has a LineNumberNode pointing to user source at the top level.
- Scans first few args to handle Expr(:meta, ...) from @inline etc.
- If first LNN points to user file (same as source.file), preserve it
- If first LNN points elsewhere (e.g., macros.jl), replace with source LNN
- If no LNN exists, prepend source LNN
- If source.file === :none (REPL/eval), don't clobber valid file LNNs

Returns a new Expr to avoid mutating the original AST.
"""
function _ensure_body_has_toplevel_lnn(body, source::Union{LineNumberNode,Nothing})
source === nothing && return body
# Don't clobber valid file info with :none from REPL/eval
source.file === :none && return body
source_lnn = LineNumberNode(source.line, source.file)

if body isa Expr && body.head === :block && !isempty(body.args)
lnn_idx = _find_first_lnn_index(body.args)
if lnn_idx !== nothing
existing_lnn = body.args[lnn_idx]
# Check if LNN already points to user file
if existing_lnn.file == source.file
return body # User file LNN already present
else
# Replace macros.jl LNN with source LNN
new_args = copy(body.args)
new_args[lnn_idx] = source_lnn
return Expr(:block, new_args...)
end
end
# No LNN found, prepend source LNN
return Expr(:block, source_lnn, body.args...)
elseif body isa Expr && body.head === :block
# Empty block
return Expr(:block, source_lnn)
else
# Non-block body
return Expr(:block, source_lnn, body)
end
end

"""
_fix_try_body_lnn!(expr, source)

Fix LineNumberNodes inside try blocks to point to user source.
Julia's stack trace uses the LAST LNN before error location for line numbers.
By replacing the first LNN in try body with source LNN, we ensure correct
line numbers in stack traces.

Scans first few args to handle Expr(:meta, ...) from @inline etc.
If source.file === :none (REPL/eval), don't clobber valid file LNNs.
Modifies expr in-place and returns it.
"""
function _fix_try_body_lnn!(expr, source::Union{LineNumberNode,Nothing})
source === nothing && return expr
# Don't clobber valid file info with :none from REPL/eval
source.file === :none && return expr
source_lnn = LineNumberNode(source.line, source.file)

if expr isa Expr
if expr.head === :try && length(expr.args) >= 1
try_body = expr.args[1]
if try_body isa Expr && try_body.head === :block && !isempty(try_body.args)
lnn_idx = _find_first_lnn_index(try_body.args)
if lnn_idx !== nothing
existing_lnn = try_body.args[lnn_idx]
if existing_lnn.file != source.file
# Replace macros.jl LNN with source LNN
try_body.args[lnn_idx] = source_lnn
end
end
end
end
# Recurse into all args
for arg in expr.args
_fix_try_body_lnn!(arg, source)
end
end
return expr
end

# ==============================================================================
# Internal: Code Generation
# ==============================================================================

function _generate_pool_code(pool_name, expr, force_enable)
function _generate_pool_code(pool_name, expr, force_enable; source::Union{LineNumberNode,Nothing}=nothing)
# Compile-time check: if pooling disabled, use DisabledPool to preserve backend context
if !USE_POOLING
disabled_pool = _disabled_pool_expr(:cpu)
if Meta.isexpr(expr, [:function, :(=)]) && _is_function_def(expr)
# Function definition: inject local pool = DisabledPool at start of body
return _generate_function_pool_code(pool_name, expr, force_enable, true, :cpu)
return _generate_function_pool_code(pool_name, expr, force_enable, true, :cpu; source)
else
return quote
local $(esc(pool_name)) = $disabled_pool
Expand All @@ -210,7 +324,7 @@ function _generate_pool_code(pool_name, expr, force_enable)

# Check if function definition
if Meta.isexpr(expr, [:function, :(=)]) && _is_function_def(expr)
return _generate_function_pool_code(pool_name, expr, force_enable, false)
return _generate_function_pool_code(pool_name, expr, force_enable, false; source)
end

# Block logic
Expand Down Expand Up @@ -304,12 +418,12 @@ Uses `_get_pool_for_backend(Val{backend}())` for zero-overhead dispatch.

Includes type-specific checkpoint/rewind optimization (same as regular @with_pool).
"""
function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, force_enable::Bool)
function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, force_enable::Bool; source::Union{LineNumberNode,Nothing}=nothing)
# Compile-time check: if pooling disabled, use DisabledPool to preserve backend context
if !USE_POOLING
disabled_pool = _disabled_pool_expr(backend)
if Meta.isexpr(expr, [:function, :(=)]) && _is_function_def(expr)
return _generate_function_pool_code_with_backend(backend, pool_name, expr, true)
return _generate_function_pool_code_with_backend(backend, pool_name, expr, true; source)
else
return quote
local $(esc(pool_name)) = $disabled_pool
Expand All @@ -323,7 +437,7 @@ function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, forc
disabled_pool = _disabled_pool_expr(backend)
# Check if function definition
if Meta.isexpr(expr, [:function, :(=)]) && _is_function_def(expr)
return _generate_function_pool_code_with_backend(backend, pool_name, expr, false)
return _generate_function_pool_code_with_backend(backend, pool_name, expr, false; source)
end

# Block logic with runtime check
Expand Down Expand Up @@ -378,7 +492,7 @@ function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, forc

# Check if function definition
if Meta.isexpr(expr, [:function, :(=)]) && _is_function_def(expr)
return _generate_function_pool_code_with_backend(backend, pool_name, expr, false)
return _generate_function_pool_code_with_backend(backend, pool_name, expr, false; source)
end

# Block logic: Extract types from acquire! calls for optimized checkpoint/rewind
Expand Down Expand Up @@ -444,7 +558,7 @@ end
Generate function code for a specific backend (e.g., :cuda).
Wraps the function body with pool getter, checkpoint, try-finally, rewind.
"""
function _generate_function_pool_code_with_backend(backend::Symbol, pool_name, func_def, disable_pooling::Bool)
function _generate_function_pool_code_with_backend(backend::Symbol, pool_name, func_def, disable_pooling::Bool; source::Union{LineNumberNode,Nothing}=nothing)
def_head = func_def.head
call_expr = func_def.args[1]
body = func_def.args[2]
Expand All @@ -455,6 +569,8 @@ function _generate_function_pool_code_with_backend(backend::Symbol, pool_name, f
local $(esc(pool_name)) = $disabled_pool
$(esc(body))
end
# Ensure new_body has source location for proper stack traces
new_body = _ensure_body_has_toplevel_lnn(new_body, source)
return Expr(def_head, esc(call_expr), new_body)
end

Expand Down Expand Up @@ -508,10 +624,13 @@ function _generate_function_pool_code_with_backend(backend::Symbol, pool_name, f
end
end

# Ensure new_body has source location for proper stack traces
new_body = _ensure_body_has_toplevel_lnn(new_body, source)
_fix_try_body_lnn!(new_body, source) # Fix try block LNNs for accurate stack traces
return Expr(def_head, esc(call_expr), new_body)
end

function _generate_function_pool_code(pool_name, func_def, force_enable, disable_pooling, backend::Symbol=:cpu)
function _generate_function_pool_code(pool_name, func_def, force_enable, disable_pooling, backend::Symbol=:cpu; source::Union{LineNumberNode,Nothing}=nothing)
def_head = func_def.head
call_expr = func_def.args[1]
body = func_def.args[2]
Expand All @@ -522,6 +641,8 @@ function _generate_function_pool_code(pool_name, func_def, force_enable, disable
local $(esc(pool_name)) = $disabled_pool
$(esc(body))
end
# Ensure new_body has source location for proper stack traces
new_body = _ensure_body_has_toplevel_lnn(new_body, source)
return Expr(def_head, esc(call_expr), new_body)
end

Expand Down Expand Up @@ -591,6 +712,9 @@ function _generate_function_pool_code(pool_name, func_def, force_enable, disable
end
end

# Ensure new_body has source location for proper stack traces
new_body = _ensure_body_has_toplevel_lnn(new_body, source)
_fix_try_body_lnn!(new_body, source) # Fix try block LNNs for accurate stack traces
return Expr(def_head, esc(call_expr), new_body)
end

Expand Down
Loading