diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index 72987f5..f58125b 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -518,10 +518,9 @@ function _forward_eval( @inbounds ix1 = children_arr[idx1] @inbounds ix2 = children_arr[idx2] @assert f.sizes.ndims[ix2] == 0 "Broadcasted ^ requires scalar exponent" - exponent = _scalar_load( - f.forward_storage, - f.sizes.storage_offset[ix2]+1, - ) + # The exponent is asserted below to be a constant node, so read it directly from `const_values` instead of `forward_storage`, which avoids a GPU->CPU scalar transfer + @assert f.nodes[ix2].type == NODE_VALUE + exponent = f.const_values[f.nodes[ix2].index] out = _view_linear(f.forward_storage, f.sizes, k) inp = _view_linear(f.forward_storage, f.sizes, ix1) partials = _view_linear(f.partials_storage, f.sizes, ix1) diff --git a/src/sizes.jl b/src/sizes.jl index bf5469d..d168638 100644 --- a/src/sizes.jl +++ b/src/sizes.jl @@ -70,16 +70,6 @@ function _setindex!(x, value, sizes::Sizes, k::Int, j) return x[sizes.storage_offset[k]+j] = value end -""" - _scalar_load(storage, idx) -> Float64 - -Read a single Float64 from `storage` at linear index `idx`. The default -implementation just calls `getindex`; this is a hook for storage backends -(such as `CuVector`) that disallow scalar indexing and need to dispatch to a -1-element transfer instead. -""" -_scalar_load(storage::AbstractVector, idx::Int) = @inbounds storage[idx] - function _view_scalar(storage::AbstractVector, sizes::Sizes, k::Int) pos = _scalar_pos(sizes, k) return view(storage, reshape(pos:pos, ()))