1+ using MPI
2+ using Enzyme
3+ using Test
4+
5+ struct Context
6+ x:: Vector{Float64}
7+ end
8+
9+ function halo (context)
10+ x = context. x
11+ np = MPI. Comm_size (MPI. COMM_WORLD)
12+ rank = MPI. Comm_rank (MPI. COMM_WORLD)
13+ requests = Vector {MPI.Request} ()
14+ if rank != 0
15+ buf = @view x[1 : 1 ]
16+ push! (requests, MPI. Isend (x[2 : 2 ], MPI. COMM_WORLD; dest= rank- 1 , tag= 0 ))
17+ push! (requests, MPI. Irecv! (buf, MPI. COMM_WORLD; source= rank- 1 , tag= 0 ))
18+ end
19+ if rank != np- 1
20+ buf = @view x[end : end ]
21+ push! (requests, MPI. Isend (x[end - 1 : end - 1 ], MPI. COMM_WORLD; dest= rank+ 1 , tag= 0 ))
22+ push! (requests, MPI. Irecv! (buf, MPI. COMM_WORLD; source= rank+ 1 , tag= 0 ))
23+ end
24+ for request in requests
25+ MPI. Wait (request)
26+ end
27+ return nothing
28+ end
29+
30+ MPI. Init ()
31+ np = MPI. Comm_size (MPI. COMM_WORLD)
32+ rank = MPI. Comm_rank (MPI. COMM_WORLD)
33+ n = np* 10
34+ n1 = Int (round (rank / np * (n+ np))) - rank
35+ n2 = Int (round ((rank + 1 ) / np * (n+ np))) - rank
36+ nl = rank == 0 ? n1+ 1 : n1
37+ nr = rank == np- 1 ? n2- 1 : n2
38+ nlocal = nr- nl+ 1
39+ context = Context (zeros (nlocal))
40+ fill! (context. x, Float64 (rank))
41+ halo (context)
42+ if rank != 0
43+ @test context. x[1 ] == Float64 (rank- 1 )
44+ end
45+ if rank != np- 1
46+ @test context. x[end ] == Float64 (rank+ 1 )
47+ end
48+
49+ dcontext = Context (zeros (nlocal))
50+ fill! (dcontext. x, Float64 (rank))
51+ autodiff (halo, Duplicated (context, dcontext))
52+ MPI. Barrier (MPI. COMM_WORLD)
53+ if rank != 0
54+ @test dcontext. x[2 ] == Float64 (rank + rank - 1 )
55+ end
56+ if rank != np- 1
57+ @test dcontext. x[end - 1 ] == Float64 (rank + rank + 1 )
58+ end
59+ if ! isinteractive ()
60+ MPI. Finalize ()
61+ end
0 commit comments