diff --git a/Project.toml b/Project.toml index 20f1d0a..b5f1d4f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Elemental" uuid = "902c3f28-d1ec-5e7e-8399-a24c3845ee38" -version = "0.6.0" +version = "0.7.0" [deps] Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -8,18 +8,10 @@ DistributedArrays = "aaf54ef3-cdf8-58ed-94cc-d582ad619b94" Elemental_jll = "c2e960f2-a21d-557e-aa36-859d46eed7e8" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" [compat] +MPI = "0.20" DistributedArrays = "0.5, 0.6" Elemental_jll = "0.87" -julia = "1.3" - -[extras] -MPIClusterManagers = "e7922434-ae4b-11e9-05c5-9780451d2c66" -Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -TSVD = "9449cd9e-2762-5aa3-a617-5413e99d722e" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["Test", "MPIClusterManagers", "Primes", "TSVD", "Random"] +julia = "1.6" diff --git a/src/Elemental.jl b/src/Elemental.jl index 5b78ada..e09a7d9 100644 --- a/src/Elemental.jl +++ b/src/Elemental.jl @@ -1,5 +1,6 @@ module Elemental +import MPI using Distributed using DistributedArrays using LinearAlgebra @@ -41,7 +42,6 @@ function __init__() end include("core/types.jl") -include("mpi.jl") include("core/matrix.jl") include("core/grid.jl") include("core/sparsematrix.jl") diff --git a/src/core/distmatrix.jl b/src/core/distmatrix.jl index 92b3879..91b0682 100644 --- a/src/core/distmatrix.jl +++ b/src/core/distmatrix.jl @@ -37,11 +37,11 @@ for (elty, ext) in ((:ElInt, :i), # end function comm(A::DistMatrix{$elty}) - cm = Ref{ElComm}() + cm = Ref{MPI.API.MPI_Comm}() ElError(ccall(($(string("ElDistMatrixDistComm_", ext)), libEl), Cuint, - (Ptr{Cvoid}, Ref{ElComm}), + (Ptr{Cvoid}, Ref{MPI.API.MPI_Comm}), A.obj, cm)) - return cm[] + return MPI.Comm(cm[]) end function get(A::DistMatrix{$elty}, i::Integer, j::Integer) diff --git a/src/core/distmultivec.jl b/src/core/distmultivec.jl index a97e2c7..51d8c71 100644 --- a/src/core/distmultivec.jl +++ b/src/core/distmultivec.jl @@ -16,10 +16,10 @@ for (elty, ext) in ((:ElInt, :i), return nothing end - function DistMultiVec(::Type{$elty}, cm::ElComm = MPI.CommWorld[]) + function DistMultiVec(::Type{$elty}, cm::MPI.Comm = MPI.COMM_WORLD) obj = Ref{Ptr{Cvoid}}(C_NULL) ElError(ccall(($(string("ElDistMultiVecCreate_", ext)), libEl), Cuint, - (Ref{Ptr{Cvoid}}, ElComm), + (Ref{Ptr{Cvoid}}, MPI.API.MPI_Comm), obj, cm)) A = DistMultiVec{$elty}(obj[]) finalizer(destroy, A) @@ -27,11 +27,11 @@ for (elty, ext) in ((:ElInt, :i), end function comm(A::DistMultiVec{$elty}) - cm = Ref{ElComm}() + cm = Ref{MPI.API.MPI_Comm}() ElError(ccall(($(string("ElDistMultiVecComm_", ext)), libEl), Cuint, - (Ptr{Cvoid}, Ref{ElComm}), + (Ptr{Cvoid}, Ref{MPI.API.MPI_Comm}), A.obj, cm)) - return cm[] + return MPI.Comm(cm[]) end function get(x::DistMultiVec{$elty}, i::Integer = size(x, 1), j::Integer = 1) @@ -117,7 +117,7 @@ end getindex(x::DistMultiVec, i, j) = get(x, i, j) -function similar(::DistMultiVec, ::Type{T}, sz::Dims, cm::ElComm = MPI.CommWorld[]) where {T} +function similar(::DistMultiVec, ::Type{T}, sz::Dims, cm::MPI.Comm = MPI.COMM_WORLD) where {T} A = DistMultiVec(T, cm) resize!(A, sz...) return A diff --git a/src/core/distsparsematrix.jl b/src/core/distsparsematrix.jl index 131b210..32f469d 100644 --- a/src/core/distsparsematrix.jl +++ b/src/core/distsparsematrix.jl @@ -16,10 +16,10 @@ for (elty, ext) in ((:ElInt, :i), return nothing end - function DistSparseMatrix(::Type{$elty}, comm::ElComm = MPI.CommWorld[]) + function DistSparseMatrix(::Type{$elty}, comm::MPI.COMM_WORLD = MPI.COMM_WORLD) obj = Ref{Ptr{Cvoid}}(C_NULL) ElError(ccall(($(string("ElDistSparseMatrixCreate_", ext)), libEl), Cuint, - (Ref{Ptr{Cvoid}}, ElComm), + (Ref{Ptr{Cvoid}}, MPI.API.MPI_Comm), obj, comm)) A = DistSparseMatrix{$elty}(obj[]) finalizer(destroy, A) @@ -27,9 +27,9 @@ for (elty, ext) in ((:ElInt, :i), end function comm(A::DistSparseMatrix{$elty}) - cm = Ref{ElComm}() + cm = Ref{MPI.API.MPI_Comm}() ElError(ccall(($(string("ElDistSparseMatrixComm_", ext)), libEl), Cuint, - (Ptr{Cvoid}, Ref{ElComm}), + (Ptr{Cvoid}, Ref{MPI.API.MPI_Comm}), A.obj, cm)) return cm[] end @@ -112,7 +112,7 @@ for (elty, ext) in ((:ElInt, :i), end # The other constructors don't have a version with dimensions. Should they, or should this one go? -function DistSparseMatrix(::Type{T}, m::Integer, n::Integer, comm::ElComm = MPI.CommWorld[]) where {T} +function DistSparseMatrix(::Type{T}, m::Integer, n::Integer, comm::MPI.Comm = MPI.COMM_WORLD) where {T} A = DistSparseMatrix(T, comm) resize!(A, m, n) return A diff --git a/src/core/types.jl b/src/core/types.jl index dd3aec9..ad35a19 100644 --- a/src/core/types.jl +++ b/src/core/types.jl @@ -6,14 +6,6 @@ function ElIntType() end const ElInt = ElIntType() -function ElCommType() - sameSizeAsInt = Cint[0] - ElError(ccall((:ElMPICommSameSizeAsInteger, libEl), Cuint, (Ptr{Cint},), - sameSizeAsInt)) - return sameSizeAsInt[1] == 1 ? Cint : Ptr{Cvoid} -end -const ElComm = ElCommType() - function ElGroupType() sameSizeAsInt = Cint[0] ElError(ccall((:ElMPIGroupSameSizeAsInteger, libEl), Cuint, (Ptr{Cint},), diff --git a/src/mpi.jl b/src/mpi.jl deleted file mode 100644 index f1adf43..0000000 --- a/src/mpi.jl +++ /dev/null @@ -1,115 +0,0 @@ -module MPI - -using Libdl -using Elemental: ElComm, ElElementType, ElInt, libEl - -const MPIImpl = Ref{Symbol}() -const CommWorld = Ref{Any}() - -function __init__() - # FixMe! The symbol could probably also be missing for other implementations - # - # NOTE! I'm using RTLD_GLOBAL here to avoid the OPEN-MPI error described in - # https://www.open-mpi.org/faq/?category=troubleshooting#missing-symbols - if Libdl.dlsym_e(Libdl.dlopen(libEl, Libdl.RTLD_GLOBAL), :MPI_Get_library_version) == C_NULL - global MPIImpl[] = :MPICH2 - else - versionBuffer = Vector{UInt8}(undef, 2800) - len = Cint[0] - err = ccall((:MPI_Get_library_version, libEl), Cint, (Ptr{UInt8}, Ptr{Cint}), versionBuffer, len) - versionString = String(versionBuffer[1:len[1]-1]) - if occursin(r"Open MPI", versionString) - global MPIImpl[] = :OpenMPI - elseif occursin(r"MPICH", versionString) - global MPIImpl[] = :MPICH3 - else - error("don't know which MPI implemetation you are using here") - end - end - - CommWorld[] = CommWorldValue() -end - -# Get MPIWorldComm -function CommWorldValue() - r = Ref{ElComm}(0) - ccall((:ElMPICommWorld, libEl), Cuint, (Ref{ElComm},), r) - return r[] -end - -function MPIType(t::DataType) - if MPIImpl[] == :OpenMPI - if t == Float64 - return Libdl.dlsym_e(Libdl.dlopen(libEl), :ompi_mpi_double) - elseif t == Cint - return Libdl.dlsym_e(Libdl.dlopen(libEl), :ompi_mpi_int) - elseif t == Clong - return Libdl.dlsym_e(Libdl.dlopen(libEl), :ompi_mpi_long_int) - else - error("data type not defined yet") - end - elseif MPIImpl[] == :MPICH2 || MPIImpl[] == :MPICH3 - if t == Float64 - return Cint(0x4c00080b) - elseif t == Cint - return Cint(0x4c000405) - elseif t == Clong - return Cint(0x4c000807) - else - error("data type not defined yet") - end - else - error("MPI implementation not covered yet") - end -end - -function MPIOp(f::Function) - if MPIImpl[] == :OpenMPI - if f == (+) - return Libdl.dlsym_e(Libdl.dlopen(libEl), :ompi_mpi_op_sum) - else - error("operation not defined yet") - end - elseif MPIImpl[] == :MPICH2 || MPIImpl[] == :MPICH3 - if f == (+) - return Cint(0x58000003) - else - error("operation not defined yet") - end - else - error("MPI implementaion no covered yet") - end -end - -function commRank(comm::ElComm) - n = Ref{Cint}() - err = ccall((:MPI_Comm_rank, libEl), Cint, (ElComm, Ref{Cint}), comm, n) - if err != 0 - error("error value was $err") - end - return n[] -end - -function commSize(comm::ElComm) - n = Ref{Cint}() - err = ccall((:MPI_Comm_size, libEl), Cint, (ElComm, Ref{Cint}), comm, n) - if err != 0 - error("error value was $err") - end - return n[] -end - -# FixMe! Should be restricted to support element types -function allreduce(sendbuf::Ref{T}, recvbuf::Ref{T}, count::Integer, op::Function, comm::ElComm = CommWorld) where {T} - err = ccall((:MPI_Allreduce, libEl), Cint, - (Ref{T}, Ref{T}, Cint, ElComm, ElComm, ElComm), - sendbuf, recvbuf, count, MPIType(T), MPIOp(op), comm) - if err != 0 - error("error value was $err") - end - return recvbuf -end - -allreduce(value::T, op::Function, comm::ElComm = CommWorld) where {T} = ElInt(allreduce(Ref(value), Ref{T}(), 1, op, comm)[]) - -end # module diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..aa40c58 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,7 @@ +[deps] +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" +MPIClusterManagers = "e7922434-ae4b-11e9-05c5-9780451d2c66" +Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +TSVD = "9449cd9e-2762-5aa3-a617-5413e99d722e" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"