
Commit 11cc63c

fmt option in adjacency matrix + propagate copy_xj for Metal (#619)
* fmt in adjacency matrix
* fix
1 parent 6cf3920 commit 11cc63c

9 files changed (+67, -51 lines)

GNNGraphs/src/convert.jl

Lines changed: 1 addition & 1 deletion
@@ -207,7 +207,7 @@ function to_sparse(A::ADJMAT_T, T = nothing; dir = :out, num_nodes = nothing,
         A = sparse(A)
     end
     if !weighted
-        A = map(x -> ifelse(x > 0, T(1), T(0)), A)
+        A = binarize(A, T)
     end
     return A, num_nodes, num_edges
 end
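
For reference, the replaced `map` expression is exactly what a `binarize` helper needs to compute. A minimal sketch of such a helper (the name comes from the diff; the definition below is an illustrative assumption, not necessarily the one defined in GNNGraphs):

    # Assumed sketch of a binarize helper: map strictly positive entries to T(1)
    # and everything else to T(0), preserving the container type of A.
    binarize(A::AbstractMatrix, T::DataType = eltype(A)) = map(x -> ifelse(x > 0, T(1), T(0)), A)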

GNNGraphs/src/query.jl

Lines changed: 30 additions & 20 deletions
@@ -218,7 +218,7 @@ end
 adjacency_list(g::GNNGraph; dir = :out) = adjacency_list(g, 1:(g.num_nodes); dir)

 """
-    adjacency_matrix(g::GNNGraph, T=eltype(g); dir=:out, weighted=true)
+    adjacency_matrix(g::GNNGraph, T=eltype(g); dir=:out, weighted=true, fmt=nothing)

 Return the adjacency matrix `A` for the graph `g`.

@@ -227,29 +227,39 @@ If `dir=:in` instead, `A[i,j] > 0` denotes the presence of an edge from node `j`

 User may specify the eltype `T` of the returned matrix.

-If `weighted=true`, the `A` will contain the edge weights if any, otherwise the elements of `A` will be either 0 or 1.
-"""
-function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType = eltype(g); dir = :out,
-                                 weighted = true)
-    A, n, m = to_sparse(g.graph, T; num_nodes = g.num_nodes, weighted)
+If `weighted=true`, the matrix `A` will contain the edge weights, if any.
+If the graph does not contain edge weights, or if `weighted=false`, the adjacency matrix will contain only 0s and 1s.
+
+The argument `fmt` can be used to specify the desired format of the returned matrix. Possible values are:
+- `nothing`: return the matrix in the same format as the underlying graph representation.
+- `:sparse`: return a sparse matrix (default for COO graphs).
+- `:dense`: return a dense matrix (default for adjacency matrix graphs).
+"""
+function Graphs.adjacency_matrix(g::GNNGraph, T::DataType = eltype(g); dir = :out,
+                                 weighted = true, fmt = nothing)
+    if fmt === nothing
+        if g.graph isa COO_T
+            fmt = :sparse
+        elseif g.graph isa SPARSE_T
+            fmt = :sparse
+        else
+            fmt = :dense
+        end
+    end
+    @assert fmt ∈ [:sparse, :dense]
+    if fmt == :sparse
+        A, n, m = to_sparse(g.graph, T; num_nodes = g.num_nodes, weighted)
+    else
+        A, n, m = to_dense(g.graph, T; num_nodes = g.num_nodes, weighted)
+    end
     @assert size(A) == (n, n)
     return dir == :out ? A : A'
 end

-function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType = eltype(g);
-                                 dir = :out, weighted = true)
-    @assert dir ∈ [:in, :out]
-    A = g.graph
-    if !weighted
-        A = binarize(A, T)
-    end
-    A = T != eltype(A) ? T.(A) : A
-    return dir == :out ? A : A'
-end

 function CRC.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
-                   dir = :out, weighted = true) where {G <: GNNGraph{<:ADJMAT_T}}
-    A = adjacency_matrix(g, T; dir, weighted)
+                   dir=:out, weighted=true, fmt=nothing) where {G <: GNNGraph{<:ADJMAT_T}}
+    A = adjacency_matrix(g, T; dir, weighted, fmt)
     if !weighted
         function adjacency_matrix_pullback_noweight(Δ)
             return (CRC.NoTangent(), CRC.ZeroTangent(), CRC.NoTangent())
@@ -266,8 +276,8 @@ function CRC.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
 end

 function CRC.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
-                   dir = :out, weighted = true) where {G <: GNNGraph{<:COO_T}}
-    A = adjacency_matrix(g, T; dir, weighted)
+                   dir=:out, weighted=true, fmt=nothing) where {G <: GNNGraph{<:COO_T}}
+    A = adjacency_matrix(g, T; dir, weighted, fmt)
     w = get_edge_weight(g)
     if !weighted || w === nothing
         function adjacency_matrix_pullback_noweight(Δ)
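
A short usage sketch of the new `fmt` keyword (graph construction via the COO `GNNGraph(s, t)` constructor; exact output container types depend on the storage backend):

    using GNNGraphs

    s = [1, 2, 3]          # source nodes
    t = [2, 3, 1]          # target nodes
    g = GNNGraph(s, t)     # COO-backed graph

    A1 = adjacency_matrix(g, Float32)                # fmt=nothing: sparse, the default for COO graphs
    A2 = adjacency_matrix(g, Float32; fmt = :dense)  # force a dense matrix
    A3 = adjacency_matrix(g, Float32; weighted = false, fmt = :sparse)  # 0/1 sparse matrix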

GNNGraphs/test/Project.toml

Lines changed: 0 additions & 4 deletions
@@ -2,7 +2,6 @@
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
-GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
@@ -20,6 +19,3 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
 TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
-
-[compat]
-GPUArraysCore = "0.1"

GNNlib/Project.toml

Lines changed: 3 additions & 0 deletions
@@ -16,10 +16,12 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 [weakdeps]
 AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+Metal = "dde4c033-4e86-420c-a63e-0dd931031962"

 [extensions]
 GNNlibAMDGPUExt = "AMDGPU"
 GNNlibCUDAExt = "CUDA"
+GNNlibMetalExt = "Metal"

 [compat]
 AMDGPU = "1"
@@ -28,6 +30,7 @@ ChainRulesCore = "1.24"
 DataStructures = "0.18"
 GNNGraphs = "1.4"
 LinearAlgebra = "1"
+Metal = "1.0"
 MLUtils = "0.4"
 NNlib = "0.9"
 Random = "1"

GNNlib/ext/GNNlibCUDAExt.jl

Lines changed: 8 additions & 18 deletions
@@ -15,7 +15,14 @@ const CUDA_COO_T = Tuple{T, T, V} where {T <: AnyCuArray{<:Integer}, V <: Union{
 ## avoid the fast path on gpu until we have better cuda support
 function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph{<:COO_T}, ::typeof(+),
                           xi, xj::AnyCuMatrix, e)
-    A = _adjacency_matrix(g, eltype(xj); weighted = false)
+
+    if !g.is_coalesced
+        # Revisit after
+        # https://github.com/JuliaGPU/CUDA.jl/issues/1113
+        A = adjacency_matrix(g, eltype(xj); weighted=false, fmt=:dense)
+    else
+        A = adjacency_matrix(g, eltype(xj); weighted=false, fmt=:sparse)
+    end

     return xj * A
 end
@@ -47,21 +54,4 @@ end

 # Flux.Zygote.@nograd compute_degree

-## CUSTOM ADJACENCY_MATRIX IMPLEMENTATION FOR CUDA COO GRAPHS, returning dense matrix when not coalesced, more efficient
-
-function _adjacency_matrix(g::GNNGraph{<:CUDA_COO_T}, T::DataType = eltype(g); dir = :out,
-                           weighted = true)
-    if !g.is_coalesced
-        # Revisit after
-        # https://github.com/JuliaGPU/CUDA.jl/issues/1113
-        A, n, m = to_dense(g.graph, T; num_nodes = g.num_nodes, weighted) # if not coalesced, construction of sparse matrix is slow
-    else
-        A, n, m = to_sparse(g.graph, T; num_nodes = g.num_nodes, weighted, is_coalesced = true)
-    end
-    @assert size(A) == (n, n)
-    return dir == :out ? A : A'
-end
-
-@non_differentiable _adjacency_matrix(x...)
-
 end #module
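
The CUDA specialization now routes through the public `adjacency_matrix` with an explicit `fmt`, choosing a dense matrix when the COO graph is not coalesced (sparse construction is slow there) and a sparse one otherwise, instead of the removed `_adjacency_matrix` helper. Either way, summing `copy_xj` messages over in-neighbors reduces to a matrix product; a CPU reference sketch of the same computation, on a small hypothetical graph:

    using GNNGraphs
    using GNNlib: propagate, copy_xj

    s, t = [1, 2, 3], [2, 3, 1]
    g = GNNGraph(s, t)
    xj = rand(Float32, 4, g.num_nodes)   # features × num_nodes

    # Neighborhood sum via matrix multiplication, mirroring the GPU fast path:
    A = adjacency_matrix(g, Float32; weighted = false, fmt = :dense)
    out_matmul = xj * A

    # The generic gather/scatter path should agree:
    out_generic = propagate(copy_xj, g, +; xj = xj)
    out_matmul ≈ out_generic   # expected to be true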

GNNlib/ext/GNNlibMetalExt.jl

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+module GNNlibMetalExt
+
+using Metal
+using Random, Statistics, LinearAlgebra
+using GNNlib: GNNlib, propagate, copy_xj, e_mul_xj, w_mul_xj
+using GNNGraphs: GNNGraph, COO_T, SPARSE_T, adjacency_matrix
+using ChainRulesCore: @non_differentiable
+
+const METAL_COO_T = Tuple{T, T, V} where {T <: MtlVector{<:Integer}, V <: Union{Nothing, MtlVector}}
+
+###### PROPAGATE SPECIALIZATIONS ####################
+
+## COPY_XJ
+
+## Metal does not support sparse arrays yet, nor scatter operations.
+## We have to use dense adjacency matrix multiplication for now.
+function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph{<:METAL_COO_T}, ::typeof(+),
+                          xi, xj::AbstractMatrix, e)
+    A = adjacency_matrix(g, eltype(xj), weighted=false, fmt=:dense)
+    return xj * A
+end
+
+end #module
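
A hedged usage sketch for the new extension (requires Apple-silicon hardware; it assumes GNNGraphs' MLDataDevices integration moves the COO index vectors to `MtlVector`s so that dispatch hits the `GNNGraph{<:METAL_COO_T}` method above):

    using GNNGraphs, GNNlib, Metal
    using MLDataDevices

    dev = gpu_device()                       # expected to be a Metal device on macOS
    g   = GNNGraph([1, 2, 3], [2, 3, 1]) |> dev
    xj  = MtlArray(rand(Float32, 4, 3))      # features × num_nodes

    # Dispatches to the Metal specialization, which multiplies xj by a
    # dense, unweighted adjacency matrix (no sparse/scatter support on Metal yet).
    y = propagate(copy_xj, g, +; xj = xj)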

GNNlib/test/Project.toml

Lines changed: 0 additions & 4 deletions
@@ -4,7 +4,6 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
 GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
-GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
@@ -17,6 +16,3 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
-
-[compat]
-GPUArraysCore = "0.1"

GraphNeuralNetworks/docs/src/index.md

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ pkg> add GraphNeuralNetworks

 Let's give a brief overview of the package by solving a graph regression problem with synthetic data.

-Other usage examples can be found in the [examples](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/GraphNeuralNetworks/examples) folder, in the [notebooks](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/GraphNeuralNetworks/notebooks) folder, and in the [tutorials](https://juliagraphs.org/GraphNeuralNetworks.jl/tutorials/) section of the documentation.
+Other usage examples can be found in the [examples](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/GraphNeuralNetworks/examples) folder, in the [notebooks](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/GraphNeuralNetworks/notebooks) folder, and in the [tutorials](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GraphNeuralNetworks.jl/dev/tutorials/gnn_intro/) section of the documentation.

 ### Data preparation

GraphNeuralNetworks/test/Project.toml

Lines changed: 1 addition & 3 deletions
@@ -5,10 +5,10 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
 GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
-GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
+Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -17,5 +17,3 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

-[compat]
-GPUArraysCore = "0.1"
