Skip to content

Commit 98550c7

Browse files
committed
Add tons of tests, various fixes
1 parent fba3690 commit 98550c7

File tree

4 files changed

+404
-72
lines changed

4 files changed

+404
-72
lines changed

lib/DaggerGraphs/src/adjlist.jl

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -35,26 +35,27 @@ function Base.iterate(adjlist::SimpleAdjListStorage{T}) where T
3535
return nothing
3636
end
3737
edge_idx = something(findfirst(x->!isempty(x), adjlist.fadjlist[idx]))
38-
state = (idx, edge_idx)
38+
state = (T(idx), T(edge_idx))
3939
return Base.iterate(adjlist, state)
4040
end
4141
function Base.iterate(adjlist::SimpleAdjListStorage{T}, state::Tuple{T,T}) where T
4242
src, dst = state
4343
if src > length(adjlist.fadjlist)
4444
return nothing
4545
elseif dst > length(adjlist.fadjlist[src])
46-
src += 1
47-
dst = 1
46+
src += one(T)
47+
dst = one(T)
4848
head = src
4949
src = findfirst(x->!isempty(x), @view(adjlist.fadjlist[head:end]))
5050
if src === nothing
5151
return nothing
5252
end
53+
src = T(src)
5354
# Shift by offset from @view
54-
src += head - 1
55+
src += head - one(T)
5556
end
5657
value = (src, adjlist.fadjlist[src][dst])
57-
dst += 1
58+
dst += one(T)
5859
return (Edge(value), (src, dst))
5960
end
6061
function Graphs.add_edge!(adjlist::SimpleAdjListStorage{T,D}, edge) where {T,D}
@@ -82,18 +83,10 @@ function Graphs.add_edge!(adjlist::SimpleAdjListStorage{T,D}, edge) where {T,D}
8283
# Directed graphs have only forward edges
8384
push!(adjlist.fadjlist[src], dst)
8485
push!(adjlist.badjlist[dst], src)
85-
# Consider this behavior of using Graphs.jl
86-
# g = SimpleDiGraph(3,0); add_edge!(g,1,2)
87-
# println( "$(g.fadjlist) | $(g.badjlist)")
88-
# [[2], Int64[], Int64[]] | [Int64[], [1], Int64[]]
8986
else
9087
# Undirected graphs have both forward and backward edges
9188
push!(adjlist.fadjlist[src], dst)
9289
push!(adjlist.fadjlist[dst], src)
93-
# Consider this behavior of using Graphs.jl
94-
# g = SimpleGraph(3,0); add_edge!(g,1,2)
95-
# println(g.fadjlist)
96-
# [[2], [1], Int64[]]
9790
end
9891

9992
return true
@@ -175,37 +168,59 @@ Graphs.add_edge!(adj::AdjList{T}, src::Integer, dst::Integer) where T =
175168
Graphs.add_edge!(adj::AdjList, edge) = add_edge!(adj.data, edge)
176169
add_edges!(adj::AdjList, edges; all::Bool=true) = add_edges!(adj.data, edges; all)
177170
Graphs.edges(adj::AdjList) = copy(adj.data)
178-
function Graphs.inneighbors(adj::AdjList, v::Integer)
171+
function Graphs.inneighbors(adj::AdjList{T,D}, v::Integer) where {T,D}
179172
neighbors = Int[]
180173
for edge in adj.data
181174
src, dst = Tuple(edge)
182175
if dst == v
183176
push!(neighbors, src)
177+
elseif !D && src == v
178+
push!(neighbors, dst)
184179
end
185180
end
181+
sort!(neighbors)
182+
unique!(neighbors)
186183
return neighbors
187184
end
188-
function Graphs.inneighbors(adj::AdjList{T,D,SimpleAdjListStorage{T}}, v::Integer) where {T,D}
185+
function Graphs.inneighbors(adj::AdjList{T,D,SimpleAdjListStorage{T,D}}, v::Integer) where {T,D}
189186
if D
190-
return copy(adj.data.badjlist[v])
187+
return length(adj.data.badjlist) >= v ? copy(adj.data.badjlist[v]) : T[]
191188
else
192-
return invoke(inneighbors, Tuple{AdjList, Integer}, adj, v)
189+
if length(adj.data.fadjlist) >= v
190+
neighbors = copy(adj.data.fadjlist[v])
191+
sort!(neighbors)
192+
unique!(neighbors)
193+
return neighbors
194+
else
195+
return T[]
196+
end
193197
end
194198
end
195-
function Graphs.outneighbors(adj::AdjList, v::Integer)
199+
function Graphs.outneighbors(adj::AdjList{T,D}, v::Integer) where {T,D}
196200
neighbors = Int[]
197201
for edge in adj.data
198202
src, dst = Tuple(edge)
199203
if src == v
200204
push!(neighbors, dst)
205+
elseif !D && dst == v
206+
push!(neighbors, src)
201207
end
202208
end
209+
sort!(neighbors)
210+
unique!(neighbors)
203211
return neighbors
204212
end
205213
function Graphs.outneighbors(adj::AdjList{T,SimpleAdjListStorage{T,D}}, v::Integer) where {T,D}
206214
if D
207-
return copy(adj.data.fadjlist[v])
215+
return length(adj.data.fadjlist) >= v ? copy(adj.data.fadjlist[v]) : T[]
208216
else
209-
return invoke(outneighbors, Tuple{AdjList, Integer}, adj, v)
217+
if length(adj.data.fadjlist) >= v
218+
neighbors = copy(adj.data.fadjlist[v])
219+
sort!(neighbors)
220+
unique!(neighbors)
221+
return neighbors
222+
else
223+
return T[]
224+
end
210225
end
211226
end

lib/DaggerGraphs/src/dgraph.jl

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,9 @@ mutable struct DGraph{T<:Integer, D} <: Graphs.AbstractGraph{T}
8686
return new{T,D}(Dagger.tochunk(state), Ref(false))
8787
end
8888
end
89-
DGraph(x; kwargs...) = DGraph{Int}(x; kwargs...)
9089
DGraph(; kwargs...) = DGraph{Int}(; kwargs...)
90+
DGraph(x::T; kwargs...) where {T<:Integer} = DGraph{T}(x; kwargs...)
91+
DGraph(x::AbstractGraph{T}; kwargs...) where {T<:Integer} = DGraph{T}(x; kwargs...)
9192
DGraph{T}(n::S; kwargs...) where {T<:Integer,S<:Integer} =
9293
DGraph{T}(T(n); kwargs...)
9394
function DGraph{T}(n::T; freeze::Bool=false, kwargs...) where {T<:Integer}
@@ -96,13 +97,13 @@ function DGraph{T}(n::T; freeze::Bool=false, kwargs...) where {T<:Integer}
9697
freeze && freeze!(g)
9798
return g
9899
end
99-
function DGraph{T}(sg::AbstractGraph{T}; directed::Bool=is_directed(sg), freeze::Bool=false, kwargs...) where {T<:Integer}
100-
g = DGraph{T}(nv(sg); directed, kwargs...)
100+
function DGraph{T}(sg::AbstractGraph{U}; directed::Bool=is_directed(sg), freeze::Bool=false, kwargs...) where {T<:Integer, U<:Integer}
101+
g = DGraph{T}(T(nv(sg)); directed, kwargs...)
101102
foreach(edges(sg)) do edge
102-
103-
add_edge!(g, edge)
103+
edge_conv = Edge(T(src(edge)), T(dst(edge)))
104+
add_edge!(g, edge_conv)
104105
if !is_directed(sg) && directed
105-
add_edge!(g, dst(edge), src(edge))
106+
add_edge!(g, dst(edge_conv), src(edge_conv))
106107
end
107108
end
108109
freeze && freeze!(g)
@@ -221,7 +222,7 @@ function freeze!(g::DGraphState)
221222
return false
222223
end
223224
g.frozen[] = true
224-
for part in nparts(g)
225+
for part in 1:nparts(g)
225226
if Dagger.istask(g.parts[part])
226227
g.parts[part] = fetch(g.parts[part]; raw=true)
227228
end
@@ -516,6 +517,9 @@ function Graphs.has_edge(g::DGraphState{T,D}, src::Integer, dst::Integer) where
516517
end
517518
end
518519
Graphs.is_directed(::DGraph{T,D}) where {T,D} = D
520+
Graphs.is_directed(::Type{<:DGraph{T,D}}) where {T,D} = D
521+
Graphs.is_directed(::DGraphState{T,D}) where {T,D} = D
522+
Graphs.is_directed(::Type{<:DGraphState{T,D}}) where {T,D} = D
519523
Graphs.vertices(g::DGraph{T}) where T = Base.OneTo{T}(nv(g))
520524
Graphs.edges(g::DGraph) = DGraphEdgeIter(g)
521525

@@ -663,30 +667,30 @@ function add_partition!(g::DGraphState{T,D}, part_data::Ref, back_data::Ref,
663667
push!(g.bg_adjs_e_meta, back_edge_meta_data[])
664668
return length(g.parts)
665669
end
666-
function Graphs.add_edge!(g::DGraph, src::Integer, dst::Integer)
670+
function Graphs.add_edge!(g::DGraph{T}, src::Integer, dst::Integer) where T
667671
check_not_frozen(g)
668-
return with_state(g, add_edge!, src, dst)
672+
return with_state(g, add_edge!, T(src), T(dst))
669673
end
670-
function Graphs.add_edge!(g::DGraph, edge::Edge)
674+
function Graphs.add_edge!(g::DGraph{T}, edge::Edge) where T
671675
check_not_frozen(g)
672-
return add_edge!(g, src(edge), dst(edge))
676+
return add_edge!(g, T(src(edge)), T(dst(edge)))
673677
end
674678
function Graphs.add_edge!(g::DGraphState{T,D}, src::Integer, dst::Integer) where {T,D}
675679
check_not_frozen(g)
676680

677-
src_part_idx = findfirst(span->src in span, g.parts_nv)
681+
src_part_idx = T(findfirst(span->src in span, g.parts_nv))
678682
@assert src_part_idx !== nothing "Source vertex $src does not exist"
679683

680-
dst_part_idx = findfirst(span->dst in span, g.parts_nv)
684+
dst_part_idx = T(findfirst(span->dst in span, g.parts_nv))
681685
@assert dst_part_idx !== nothing "Destination vertex $dst does not exist"
682686

683687
if src_part_idx == dst_part_idx
684688
# Edge exists within a single partition
685689
part = g.parts[src_part_idx]
686-
src_shift = src - (g.parts_nv[src_part_idx].start - 1)
687-
dst_shift = dst - (g.parts_nv[dst_part_idx].start - 1)
690+
src_shift = src - (g.parts_nv[src_part_idx].start - one(T))
691+
dst_shift = dst - (g.parts_nv[dst_part_idx].start - one(T))
688692
if exec_fast(add_edge!, part, src_shift, dst_shift)
689-
g.parts_ne[src_part_idx] += 1
693+
g.parts_ne[src_part_idx] += one(T)
690694
else
691695
return false
692696
end
@@ -701,13 +705,13 @@ function Graphs.add_edge!(g::DGraphState{T,D}, src::Integer, dst::Integer) where
701705
end
702706
if D
703707
# TODO: This will cause imbalance for many outgoing edges from a few vertices
704-
g.bg_adjs_ne_src[src_part_idx] += 1
708+
g.bg_adjs_ne_src[src_part_idx] += one(T)
705709
else
706710
owner_part_idx = edge_owner(src, dst, src_part_idx, dst_part_idx)
707-
g.bg_adjs_ne_src[owner_part_idx] += 1
711+
g.bg_adjs_ne_src[owner_part_idx] += one(T)
708712
end
709-
g.bg_adjs_ne[src_part_idx] += 1
710-
g.bg_adjs_ne[dst_part_idx] += 1
713+
g.bg_adjs_ne[src_part_idx] += one(T)
714+
g.bg_adjs_ne[dst_part_idx] += one(T)
711715
end
712716

713717
return true
@@ -771,23 +775,24 @@ function add_edges!(g::Graphs.AbstractSimpleGraph, shift, edges; all::Bool=true)
771775
end
772776

773777
"""
774-
edge_owner(src::Integer, dst::Integer, src_part_idx::Integer, dst_part_idx::Integer)
778+
edge_owner(src::T, dst::T, src_part_idx::T, dst_part_idx::T) where {T<:Integer}
775779
776780
Determine which partition owns the edge `(src, dst)`.
777-
FIXME: I do not like it. Both partitions should own the edge. (i.e. there should be data redundancy for the backgorund graph)
781+
FIXME: I do not like it. Both partitions should own the edge. (i.e. there should be data redundancy for the background graph)
778782
"""
779-
edge_owner(src::Int, dst::Int, src_part_idx::Int, dst_part_idx::Int) =
783+
edge_owner(src::T, dst::T, src_part_idx::T, dst_part_idx::T) where {T<:Integer} =
780784
iseven(hash(Base.unsafe_trunc(UInt, src+dst))) ? src_part_idx : dst_part_idx
781785

782-
Graphs.inneighbors(g::DGraph, v::Integer) = with_state(g, inneighbors, v)
783-
function Graphs.inneighbors(g::DGraphState{T}, v::Integer) where T
786+
Graphs.inneighbors(g::DGraph{T}, v::Integer) where T =
787+
with_state(g, inneighbors, T(v))
788+
function Graphs.inneighbors(g::DGraphState{T}, v::T) where T
784789
part_idx = findfirst(span->v in span, g.parts_nv)
785790
if part_idx === nothing
786791
throw(BoundsError(g, v))
787792
end
788793

789794
neighbors = T[]
790-
shift = g.parts_nv[part_idx].start - 1
795+
shift = g.parts_nv[part_idx].start - one(T)
791796

792797
# Check against local edges
793798
v_shift = v - shift
@@ -797,17 +802,18 @@ function Graphs.inneighbors(g::DGraphState{T}, v::Integer) where T
797802
# Check against background edges
798803
append!(neighbors, exec_fast(inneighbors, g.bg_adjs[part_idx], v))
799804

800-
return neighbors
805+
return sort!(neighbors)
801806
end
802-
Graphs.outneighbors(g::DGraph, v::Integer) = with_state(g, outneighbors, v)
803-
function Graphs.outneighbors(g::DGraphState{T}, v::Integer) where T
807+
Graphs.outneighbors(g::DGraph{T}, v::Integer) where T =
808+
with_state(g, outneighbors, T(v))
809+
function Graphs.outneighbors(g::DGraphState{T}, v::T) where T
804810
part_idx = findfirst(span->v in span, g.parts_nv)
805811
if part_idx === nothing
806812
throw(BoundsError(g, v))
807813
end
808814

809815
neighbors = T[]
810-
shift = g.parts_nv[part_idx].start - 1
816+
shift = g.parts_nv[part_idx].start - one(T)
811817

812818
# Check against local edges
813819
v_shift = v - shift
@@ -817,7 +823,7 @@ function Graphs.outneighbors(g::DGraphState{T}, v::Integer) where T
817823
# Check against background edges
818824
append!(neighbors, exec_fast(outneighbors, g.bg_adjs[part_idx], v))
819825

820-
return neighbors
826+
return sort!(neighbors)
821827
end
822828

823829
"""
@@ -878,8 +884,8 @@ partition_edges(g::DGraph, part::Integer) =
878884
879885
Get the edges of the partition `part` of the graph state `g`.
880886
"""
881-
function partition_edges(g::DGraphState, part::Integer)
882-
shift = g.parts_nv[part].start - 1
887+
function partition_edges(g::DGraphState{T}, part::Integer) where T
888+
shift = g.parts_nv[part].start - one(T)
883889
part_edges = map(edge->Edge(src(edge)+shift, dst(edge)+shift), exec_fast(edges, g.parts[part]))
884890
back_edges = exec_fast(edges, g.bg_adjs[part])
885891
return part_edges, back_edges

lib/DaggerGraphs/src/edgeiter.jl

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,45 +4,49 @@ struct DGraphEdgeIter{T,M} <: Graphs.AbstractEdgeIter
44
end
55
DGraphEdgeIter(g::DGraph{T}; metadata::Bool=false, meta_f=nothing) where T =
66
DGraphEdgeIter{T,metadata}(fetch(g.state), meta_f)
7-
struct DGraphEdgeIterState
7+
struct DGraphEdgeIterState{T}
88
adj::Bool
99
part::Int
1010
idx::Int
1111
cache
1212
cache_meta
13+
seen::Union{Set{Edge{T}},Nothing}
1314
end
1415
Base.length(iter::DGraphEdgeIter) = ne(iter.graph)
1516
Base.eltype(iter::DGraphEdgeIter{T,false}) where T = Edge{T}
1617
Base.eltype(iter::DGraphEdgeIter{T,true}) where T = Tuple{Edge{T},Any}
17-
function Base.iterate(iter::DGraphEdgeIter)
18+
function Base.iterate(iter::DGraphEdgeIter{T}) where T
1819
g = iter.graph
1920
if nv(g) == 0
2021
return nothing
2122
elseif sum(g.parts_ne; init=0) > 0
2223
# Start with partitions
23-
return iterate(iter, DGraphEdgeIterState(false, 1, 1, nothing, nothing))
24+
seen = is_directed(g) ? nothing : Set{Edge{T}}()
25+
return iterate(iter, DGraphEdgeIterState{T}(false, 1, 1, nothing, nothing, seen))
2426
elseif sum(g.bg_adjs_ne_src; init=0) > 0
2527
# Start with background AdjLists
26-
return iterate(iter, DGraphEdgeIterState(true, 1, 1, nothing, nothing))
28+
seen = is_directed(g) ? nothing : Set{Edge{T}}()
29+
return iterate(iter, DGraphEdgeIterState{T}(true, 1, 1, nothing, nothing, seen))
2730
else
2831
return nothing
2932
end
3033
end
31-
function Base.iterate(iter::DGraphEdgeIter{T,M}, state::DGraphEdgeIterState) where {T,M}
34+
function Base.iterate(iter::DGraphEdgeIter{T,M}, state::DGraphEdgeIterState{T}) where {T,M}
3235
g = iter.graph
3336
adj = state.adj
3437
part = state.part
3538
idx = state.idx
3639
cache = state.cache
3740
cache_meta = state.cache_meta
41+
seen = state.seen
3842

3943
edge_metadata_for(meta, edges) = map(edge->meta[edge[1],edge[2]], edges)
4044

4145
@label start
4246
if !adj
4347
if part > length(g.parts)
4448
# Restart with background AdjLists
45-
return iterate(iter, DGraphEdgeIterState(true, 1, 1, nothing, nothing))
49+
return iterate(iter, DGraphEdgeIterState{T}(true, 1, 1, nothing, nothing, seen))
4650
end
4751
if cache === nothing
4852
cache = map(Tuple, fetch(Dagger.@spawn edges(g.parts[part])))
@@ -105,6 +109,18 @@ function Base.iterate(iter::DGraphEdgeIter{T,M}, state::DGraphEdgeIterState) whe
105109
@goto start
106110
end
107111

112+
# Restart if this edge has already been seen (undirected case)
113+
if seen !== nothing
114+
if value in seen
115+
@goto start
116+
end
117+
value_rev = Edge(dst(value), src(value))
118+
if value_rev in seen
119+
@goto start
120+
end
121+
push!(seen, value)
122+
end
123+
108124
return (M ? (value, value_meta) : value,
109-
DGraphEdgeIterState(adj, part, idx, cache, cache_meta))
125+
DGraphEdgeIterState{T}(adj, part, idx, cache, cache_meta, seen))
110126
end

0 commit comments

Comments
 (0)