Skip to content

Commit 9aa289c

Browse files
committed
Adding MPI test
1 parent 5607194 commit 9aa289c

File tree

3 files changed

+118
-44
lines changed

3 files changed

+118
-44
lines changed

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
77
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
88
LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10+
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
1011
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1112
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1213
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

test/mpi.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
using MPI
2+
using Enzyme
3+
using Test
4+
5+
struct Context
6+
x::Vector{Float64}
7+
end
8+
9+
function halo(context)
10+
x = context.x
11+
np = MPI.Comm_size(MPI.COMM_WORLD)
12+
rank = MPI.Comm_rank(MPI.COMM_WORLD)
13+
requests = Vector{MPI.Request}()
14+
if rank != 0
15+
buf = @view x[1:1]
16+
push!(requests, MPI.Isend(x[2:2], MPI.COMM_WORLD; dest=rank-1, tag=0))
17+
push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source=rank-1, tag=0))
18+
end
19+
if rank != np-1
20+
buf = @view x[end:end]
21+
push!(requests, MPI.Isend(x[end-1:end-1], MPI.COMM_WORLD; dest=rank+1, tag=0))
22+
push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source=rank+1, tag=0))
23+
end
24+
for request in requests
25+
MPI.Wait(request)
26+
end
27+
return nothing
28+
end
29+
30+
MPI.Init()
31+
np = MPI.Comm_size(MPI.COMM_WORLD)
32+
rank = MPI.Comm_rank(MPI.COMM_WORLD)
33+
n = np*10
34+
n1 = Int(round(rank / np * (n+np))) - rank
35+
n2 = Int(round((rank + 1) / np * (n+np))) - rank
36+
nl = rank == 0 ? n1+1 : n1
37+
nr = rank == np-1 ? n2-1 : n2
38+
nlocal = nr-nl+1
39+
context = Context(zeros(nlocal))
40+
fill!(context.x, Float64(rank))
41+
halo(context)
42+
if rank != 0
43+
@test context.x[1] == Float64(rank-1)
44+
end
45+
if rank != np-1
46+
@test context.x[end] == Float64(rank+1)
47+
end
48+
49+
dcontext = Context(zeros(nlocal))
50+
fill!(dcontext.x, Float64(rank))
51+
autodiff(halo, Duplicated(context, dcontext))
52+
MPI.Barrier(MPI.COMM_WORLD)
53+
if rank != 0
54+
@test dcontext.x[2] == Float64(rank + rank - 1)
55+
end
56+
if rank != np-1
57+
@test dcontext.x[end-1] == Float64(rank + rank + 1)
58+
end
59+
if !isinteractive()
60+
MPI.Finalize()
61+
end

0 commit comments

Comments
 (0)