Skip to content

Commit 0f2cbdb

Browse files
committed
Adding MPI test
1 parent 1eeb898 commit 0f2cbdb

File tree

3 files changed

+94
-23
lines changed

3 files changed

+94
-23
lines changed

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
99
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
1010
LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
12-
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
12+
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1515
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

test/mpi.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
using MPI
2+
using Enzyme
3+
using Test
4+
5+
struct Context
6+
x::Vector{Float64}
7+
end
8+
9+
function halo(context)
10+
x = context.x
11+
np = MPI.Comm_size(MPI.COMM_WORLD)
12+
rank = MPI.Comm_rank(MPI.COMM_WORLD)
13+
requests = Vector{MPI.Request}()
14+
if rank != 0
15+
buf = @view x[1:1]
16+
push!(requests, MPI.Isend(x[2:2], MPI.COMM_WORLD; dest=rank-1, tag=0))
17+
push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source=rank-1, tag=0))
18+
end
19+
if rank != np-1
20+
buf = @view x[end:end]
21+
push!(requests, MPI.Isend(x[end-1:end-1], MPI.COMM_WORLD; dest=rank+1, tag=0))
22+
push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source=rank+1, tag=0))
23+
end
24+
for request in requests
25+
MPI.Wait(request)
26+
end
27+
return nothing
28+
end
29+
30+
MPI.Init()
31+
np = MPI.Comm_size(MPI.COMM_WORLD)
32+
rank = MPI.Comm_rank(MPI.COMM_WORLD)
33+
n = np*10
34+
n1 = Int(round(rank / np * (n+np))) - rank
35+
n2 = Int(round((rank + 1) / np * (n+np))) - rank
36+
nl = rank == 0 ? n1+1 : n1
37+
nr = rank == np-1 ? n2-1 : n2
38+
nlocal = nr-nl+1
39+
context = Context(zeros(nlocal))
40+
fill!(context.x, Float64(rank))
41+
halo(context)
42+
if rank != 0
43+
@test context.x[1] == Float64(rank-1)
44+
end
45+
if rank != np-1
46+
@test context.x[end] == Float64(rank+1)
47+
end
48+
49+
dcontext = Context(zeros(nlocal))
50+
fill!(dcontext.x, Float64(rank))
51+
autodiff(Reverse, halo, Duplicated(context, dcontext))
52+
MPI.Barrier(MPI.COMM_WORLD)
53+
if rank != 0
54+
@test dcontext.x[2] == Float64(rank + rank - 1)
55+
end
56+
if rank != np-1
57+
@test dcontext.x[end-1] == Float64(rank + rank + 1)
58+
end
59+
if !isinteractive()
60+
MPI.Finalize()
61+
end

test/runtests.jl

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ using ForwardDiff
2020
using Aqua
2121
using Statistics
2222
using LinearAlgebra
23-
using InlineStrings
23+
using MPI
2424

2525
using Enzyme_jll
2626
@info "Testing against" Enzyme_jll.libEnzyme
@@ -236,7 +236,7 @@ make3() = (1.0, 2.0, 3.0)
236236
test_scalar(x->rem(x, 1), 0.7)
237237
test_scalar(x->rem2pi(x,RoundDown), 0.7)
238238
test_scalar(x->fma(x,x+1,x/3), 2.3)
239-
239+
240240
@test autodiff(Forward, sincos, Duplicated(1.0, 1.0))[1][1] cos(1.0)
241241

242242
@test autodiff(Reverse, (x)->log(x), Active(2.0)) == ((0.5,),)
@@ -588,7 +588,7 @@ end
588588

589589
bias = Float32[0.0;;;]
590590
res = Enzyme.autodiff(Reverse, f, Active, Active(x[1]), Const(bias))
591-
591+
592592
@test bias[1][1] 0.0
593593
@test res[1][1] cos(x[1])
594594
end
@@ -931,7 +931,7 @@ end
931931

932932
@inline function myquantile(v::AbstractVector, p::Real; alpha)
933933
n = length(v)
934-
934+
935935
m = 1.0 + p * (1.0 - alpha - 1.0)
936936
aleph = n*p + oftype(p, m)
937937
j = clamp(trunc(Int, aleph), 1, n-1)
@@ -944,7 +944,7 @@ end
944944
a = @inbounds v[j]
945945
b = @inbounds v[j + 1]
946946
end
947-
947+
948948
return a + γ*(b-a)
949949
end
950950

@@ -1166,18 +1166,18 @@ end
11661166
@test 1.0 Enzyme.autodiff(Forward, inactive_gen, Duplicated(1E4, 1.0))[1]
11671167

11681168
function whocallsmorethan30args(R)
1169-
temp = diag(R)
1170-
R_inv = [temp[1] 0. 0. 0. 0. 0.;
1171-
0. temp[2] 0. 0. 0. 0.;
1172-
0. 0. temp[3] 0. 0. 0.;
1173-
0. 0. 0. temp[4] 0. 0.;
1174-
0. 0. 0. 0. temp[5] 0.;
1169+
temp = diag(R)
1170+
R_inv = [temp[1] 0. 0. 0. 0. 0.;
1171+
0. temp[2] 0. 0. 0. 0.;
1172+
0. 0. temp[3] 0. 0. 0.;
1173+
0. 0. 0. temp[4] 0. 0.;
1174+
0. 0. 0. 0. temp[5] 0.;
11751175
]
1176-
1176+
11771177
return sum(R_inv)
11781178
end
1179-
1180-
R = zeros(6,6)
1179+
1180+
R = zeros(6,6)
11811181
dR = zeros(6, 6)
11821182
autodiff(Reverse, whocallsmorethan30args, Active, Duplicated(R, dR))
11831183

@@ -1845,7 +1845,7 @@ end
18451845
end
18461846
# TODO: Add test for NoShadowException
18471847
end
1848-
1848+
18491849
function indirectfltret(a)::DataType
18501850
a[] *= 2
18511851
return Float64
@@ -2313,7 +2313,7 @@ end
23132313
Enzyme.API.runtimeActivity!(false)
23142314
@test res[1] 0.2
23152315
# broken as the return of an apply generic is {primal, primal}
2316-
# but since the return is abstractfloat doing the
2316+
# but since the return is abstractfloat doing the
23172317
@static if VERSION v"1.9-" && !(VERSION v"1.10-" )
23182318
@test_broken res[2] 1.0
23192319
else
@@ -2383,6 +2383,16 @@ end
23832383
)
23842384
@test ad_eta[1] 0.0
23852385
end
2386+
@testset "MPI" begin
2387+
testdir = @__DIR__
2388+
# Test parsing
2389+
include("mpi.jl")
2390+
mpiexec() do cmd
2391+
run(`$cmd -n 2 $(Base.julia_cmd()) --project=$testdir $testdir/mpi.jl`)
2392+
end
2393+
@test true
2394+
end
2395+
23862396

23872397
@testset "Tape Width" begin
23882398
struct Roo
@@ -2452,10 +2462,10 @@ end
24522462
Duplicated(inters, dinters),
24532463
)
24542464

2455-
@test dinters[1].k 0.1
2456-
@test dinters[1].t0 1.0
2457-
@test dinters[2].k 0.3
2458-
@test dinters[2].t0 2.0
2465+
@test dinters[1].k 0.1
2466+
@test dinters[1].t0 1.0
2467+
@test dinters[2].k 0.3
2468+
@test dinters[2].t0 2.0
24592469
end
24602470

24612471
@testset "Statistics" begin
@@ -2524,7 +2534,7 @@ end
25242534
y = A \ b
25252535
@test dA (-z * transpose(y))
25262536
@test db z
2527-
2537+
25282538
db = zero(b)
25292539

25302540
forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Const{typeof(A)}, Duplicated{typeof(b)})
@@ -2540,7 +2550,7 @@ end
25402550

25412551
y = A \ b
25422552
@test db z
2543-
2553+
25442554
dA = zero(A)
25452555

25462556
forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Const{typeof(b)})

0 commit comments

Comments
 (0)