Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement threaded broadcast array type to make time integration multithreaded #722

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/general/semidiscretization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,13 @@ function semidiscretize(semi, tspan; reset_threads=true, data_type=nothing)

if isnothing(data_type)
# Use CPU vectors and the optimized CPU code
u0_ode = Vector{ELTYPE}(undef, sum(sizes_u))
v0_ode = Vector{ELTYPE}(undef, sum(sizes_v))
u0_ode_ = Vector{ELTYPE}(undef, sum(sizes_u))
v0_ode_ = Vector{ELTYPE}(undef, sum(sizes_v))
u0_ode = ThreadedBroadcastArray(u0_ode_)
v0_ode = ThreadedBroadcastArray(v0_ode_)

# u0_ode = Vector{ELTYPE}(undef, sum(sizes_u))
# v0_ode = Vector{ELTYPE}(undef, sum(sizes_v))
else
# Use the specified data type, e.g., `CuArray` or `ROCArray`
u0_ode = data_type{ELTYPE}(undef, sum(sizes_u))
Expand Down Expand Up @@ -395,6 +400,10 @@ end
return PtrArray(pointer(view(array, range)), size)
end

# Wrap the underlying storage of a `ThreadedBroadcastArray` with the `wrap_array` method
# matching the parent array type, then re-wrap the result so that the returned array
# is a `ThreadedBroadcastArray` again.
@inline function wrap_array(array::ThreadedBroadcastArray, range, size)
    inner = wrap_array(parent(array), range, size)

    return ThreadedBroadcastArray(inner)
end

@inline function wrap_array(array, range, size)
# For non-`Array`s (typically GPU arrays), just reshape. Calling the `PtrArray` code
# above for a `CuArray` yields another `CuArray` (instead of a `PtrArray`)
Expand Down
74 changes: 74 additions & 0 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,77 @@ function compute_git_hash()
return "UnknownVersion"
end
end

# Thin wrapper around a regular array. Broadcasting and common bulk operations
# such as `fill!` and `copyto!` are redefined for this type to run multithreaded
# with `@threaded`.
# See https://github.com/trixi-framework/TrixiParticles.jl/pull/722 for more details
# and benchmarks.
struct ThreadedBroadcastArray{T, N, A <: AbstractArray{T, N}} <: AbstractArray{T, N}
    # The wrapped array that actually holds the data
    array::A

    # Element type, dimensionality and storage type are all taken from the wrapped array,
    # so the caller never has to spell out the type parameters.
    ThreadedBroadcastArray(array::A) where {T, N, A <: AbstractArray{T, N}} =
        new{T, N, A}(array)
end

# Basic array interface: everything is forwarded to the wrapped array.
Base.parent(A::ThreadedBroadcastArray) = A.array
Base.pointer(A::ThreadedBroadcastArray) = pointer(parent(A))
Base.size(A::ThreadedBroadcastArray) = size(parent(A))

# Use the index style of the wrapped array type. Without this method, the wrapper falls
# back to `IndexCartesian`, even when wrapping a plain `Array`, so `eachindex` of a
# multidimensional wrapper yields `CartesianIndices` instead of a linear range.
function Base.IndexStyle(::Type{<:ThreadedBroadcastArray{T, N, A}}) where {T, N, A}
    return IndexStyle(A)
end

# Allocate an uninitialized array with element type `T` and the same shape as `A`.
# The result is wrapped again, so that operations on it stay multithreaded.
function Base.similar(A::ThreadedBroadcastArray, ::Type{T}) where {T}
    return ThreadedBroadcastArray(similar(A.array, T))
end

# Same as above, but with explicitly specified dimensions. Without this method,
# `similar(A, T, dims)` and `similar(A, dims)` fall back to the generic `AbstractArray`
# method and return an array without the `ThreadedBroadcastArray` wrapper.
function Base.similar(A::ThreadedBroadcastArray, ::Type{T}, dims::Dims) where {T}
    return ThreadedBroadcastArray(similar(A.array, T, dims))
end

# Read access is forwarded unchanged to the wrapped array. `@propagate_inbounds` lets
# an `@inbounds` at the call site apply to the inner indexing operation as well.
Base.@propagate_inbounds Base.getindex(A::ThreadedBroadcastArray, i...) = A.array[i...]

# Write access is forwarded unchanged to the wrapped array. All arguments after `A`
# (the value followed by the indices) are passed on as they are.
# Returns the wrapper `A` itself, not the wrapped array.
Base.@propagate_inbounds function Base.setindex!(A::ThreadedBroadcastArray, x...)
    data = A.array
    setindex!(data, x...)

    return A
end

# Multithreaded version of `fill!`: set every entry of the wrapped array to `x`.
# Returns the wrapper `A`.
function Base.fill!(A::ThreadedBroadcastArray{T}, x) where {T}
    # Convert the value to the element type once, before entering the threaded loop
    value = x isa T ? x : convert(T, x)::T
    data = A.array

    @threaded data for index in eachindex(data)
        @inbounds data[index] = value
    end

    return A
end

# Multithreaded version of `copyto!(dest, src)`: copy all elements of `src` to the
# leading entries of `dest` and return `dest`.
#
# Throws a `BoundsError` when `src` has more elements than `dest`. The loops below run
# with `@inbounds`, and `zip` stops at the shorter iterator, so without this check
# a `src` that is too long would be truncated silently.
function Base.copyto!(dest::ThreadedBroadcastArray, src::AbstractArray)
    # Nothing to copy
    isempty(src) && return dest

    length(src) <= length(dest) || throw(BoundsError(dest, LinearIndices(src)))

    dest_indices = eachindex(dest)
    src_indices = eachindex(src)

    if dest_indices == src_indices
        # Shared-iterator implementation
        @threaded dest.array for I in dest_indices
            @inbounds dest.array[I] = src[I]
        end
    else
        # Dual-iterator implementation
        # NOTE(review): confirm that `@threaded` supports `zip` iterators on all backends.
        @threaded dest.array for (Idest, Isrc) in zip(dest_indices, src_indices)
            @inbounds dest.array[Idest] = src[Isrc]
        end
    end

    return dest
end

# Give `ThreadedBroadcastArray` its own broadcast style, so that broadcast expressions
# involving this type dispatch to the `copyto!` and `similar` methods defined for
# `Broadcast.ArrayStyle{ThreadedBroadcastArray}`.
Broadcast.BroadcastStyle(::Type{ThreadedBroadcastArray{T, N, A}}) where {T, N, A} =
    Broadcast.ArrayStyle{ThreadedBroadcastArray}()

# Multithreaded evaluation of a broadcast expression into `dest`.
# Throws a `DimensionMismatch` (via `Broadcast.throwdm`) when the axes of the wrapped
# array and of the broadcast expression differ. Returns the wrapper `dest`.
function Broadcast.copyto!(dest::ThreadedBroadcastArray,
                           bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{ThreadedBroadcastArray}})
    data = dest.array

    # The axes must agree, since the loop below runs without bounds checks
    if !(axes(data) == axes(bc))
        Broadcast.throwdm(axes(data), axes(bc))
    end

    @threaded data for index in eachindex(data)
        @inbounds data[index] = bc[index]
    end

    return dest
end

# Allocate the output of an out-of-place broadcast with the `ThreadedBroadcastArray`
# style: a new `Array` with element type `T` and the requested dimensions, wrapped in
# a `ThreadedBroadcastArray`.
function Base.similar(::Broadcast.Broadcasted{Broadcast.ArrayStyle{ThreadedBroadcastArray}},
                      ::Type{T}, dims) where {T}
    storage = similar(Array{T}, dims)

    return ThreadedBroadcastArray(storage)
end
2 changes: 1 addition & 1 deletion test/schemes/solid/total_lagrangian_sph/rhs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@

v_ode = ode.u0.x[1]
if isnothing(data_type)
u_ode = vec(u)
u_ode = TrixiParticles.ThreadedBroadcastArray(vec(u))
else
u_ode = data_type(vec(u))
end
Expand Down
Loading