Skip to content

Commit fc8afbf

Browse files
authored
Complete Array API for TArray (#82)
* Add many methods for `TArray` Most of the newly added methods are from Tracker.jl's TrackedArray. * remove some unnecessary methods * unit test cases and benchmarks
1 parent 4e8dda9 commit fc8afbf

File tree

6 files changed

+256
-101
lines changed

6 files changed

+256
-101
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@ version = "0.5.0"
77

88
[deps]
99
Libtask_jll = "3ae2931a-708c-5973-9c38-ccf7496fb450"
10+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
11+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1012

1113
[compat]
1214
Libtask_jll = "0.4"
1315
julia = "1.3"
1416

1517
[extras]
1618
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
19+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
1720

1821
[targets]
19-
test = ["Test"]
22+
test = ["Test", "BenchmarkTools"]

deps/methods_of_array.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using LinearAlgebra
2+
using Statistics
3+
4+
using InteractiveUtils
5+
6+
const MOD_METHODS = Dict{Module, Vector{Symbol}}()
7+
8+
methods = methodswith(AbstractArray)
9+
10+
for method in methods
11+
mod = method.module
12+
names = get!(MOD_METHODS, mod, Vector{Symbol}())
13+
push!(names, method.name)
14+
end
15+
16+
for (k, v) in MOD_METHODS
17+
print(k)
18+
print(":\n\t")
19+
show(v)
20+
print("\n\n")
21+
end

src/tarray.jl

Lines changed: 186 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ end
3131

3232
TArray{T}(d::Integer...) where T = TArray(T, d)
3333
TArray{T}(::UndefInitializer, d::Integer...) where T = TArray(T, d)
34+
TArray{T}(::UndefInitializer, dim::NTuple{N,Int}) where {T,N} = TArray(T, dim)
3435
TArray{T,N}(d::Vararg{<:Integer,N}) where {T,N} = TArray(T, d)
3536
TArray{T,N}(::UndefInitializer, d::Vararg{<:Integer,N}) where {T,N} = TArray{T,N}(d)
3637
TArray{T,N}(dim::NTuple{N,Int}) where {T,N} = TArray(T, dim)
@@ -43,106 +44,12 @@ function TArray(T::Type, dim)
4344
res
4445
end
4546

46-
#
47-
# Indexing Interface Implementation
48-
#
49-
50-
function Base.getindex(S::TArray{T, N}, I::Vararg{Int,N}) where {T, N}
51-
t, d = task_local_storage(S.ref)
52-
return d[I...]
53-
end
54-
55-
function Base.setindex!(S::TArray{T, N}, x, I::Vararg{Int,N}) where {T, N}
56-
n, d = task_local_storage(S.ref)
57-
cn = n_copies()
58-
newd = d
59-
if cn > n
60-
# println("[setindex!]: $(S.ref) copying data")
61-
newd = deepcopy(d)
62-
task_local_storage(S.ref, (cn, newd))
63-
end
64-
newd[I...] = x
65-
end
66-
67-
function Base.push!(S::TArray{T}, x) where T
68-
n, d = task_local_storage(S.ref)
69-
cn = n_copies()
70-
newd = d
71-
if cn > n
72-
newd = deepcopy(d)
73-
task_local_storage(S.ref, (cn, newd))
74-
end
75-
push!(newd, x)
76-
end
47+
TArray(x::AbstractArray) = convert(TArray, x)
7748

78-
function Base.pop!(S::TArray)
79-
n, d = task_local_storage(S.ref)
80-
cn = n_copies()
81-
newd = d
82-
if cn > n
83-
newd = deepcopy(d)
84-
task_local_storage(S.ref, (cn, newd))
85-
end
86-
pop!(d)
87-
end
88-
89-
function Base.convert(::Type{TArray}, x::Array)
90-
return convert(TArray{eltype(x),ndims(x)}, x)
91-
end
92-
function Base.convert(::Type{TArray{T,N}}, x::Array{T,N}) where {T,N}
93-
res = TArray{T,N}()
94-
n = n_copies()
95-
task_local_storage(res.ref, (n,x))
96-
return res
97-
end
98-
99-
function Base.convert(::Type{Array}, S::TArray)
100-
return convert(Array{eltype(S), ndims(S)}, S)
101-
end
102-
function Base.convert(::Type{Array{T,N}}, S::TArray{T,N}) where {T,N}
103-
n,d = task_local_storage(S.ref)
104-
c = convert(Array{T, N}, deepcopy(d))
105-
return c
106-
end
107-
108-
function Base.display(S::TArray)
109-
arr = S.orig_task.storage[S.ref][2]
110-
@warn "display(::TArray) prints the originating task's storage, " *
111-
"not the current task's storage. " *
112-
"Please use show(::TArray) to display the current task's version of a TArray."
113-
display(arr)
114-
end
115-
116-
Base.show(io::IO, S::TArray) = Base.show(io::IO, task_local_storage(S.ref)[2])
117-
118-
# Base.get(t::Task, S) = S
119-
# Base.get(t::Task, S::TArray) = (t.storage[S.ref][2])
120-
Base.get(S::TArray) = (current_task().storage[S.ref][2])
121-
122-
##
123-
# Iterator Interface
124-
IteratorSize(::Type{TArray{T, N}}) where {T, N} = HasShape{N}()
125-
IteratorEltype(::Type{TArray}) = HasEltype()
126-
127-
# Implements iterate, eltype, length, and size functions,
128-
# as well as firstindex, lastindex, ndims, and axes
129-
for F in (:iterate, :eltype, :length, :size,
130-
:firstindex, :lastindex, :ndims, :axes)
131-
@eval Base.$F(a::TArray, args...) = $F(get(a), args...)
132-
end
133-
134-
#
135-
# Similarity implementation
136-
#
137-
138-
Base.similar(S::TArray) = tzeros(eltype(S), size(S))
139-
Base.similar(S::TArray, ::Type{T}) where {T} = tzeros(T, size(S))
140-
Base.similar(S::TArray, dims::Dims) = tzeros(eltype(S), dims)
141-
142-
##########
143-
# tzeros #
144-
##########
49+
localize(x) = x
50+
localize(x::AbstractArray) = TArray(x)
14551

52+
# Constructors
14653
"""
14754
tzeros(dims, ...)
14855
@@ -195,3 +102,184 @@ function tfill(val::Real, dim)
195102
task_local_storage(res.ref, (n,d))
196103
return res
197104
end
105+
106+
#
107+
# Conversion between TArray and Array
108+
#
109+
_get(x) = x
110+
function _get(x::TArray)
111+
n, d = task_local_storage(x.ref)
112+
return d
113+
end
114+
115+
function Base.convert(::Type{Array}, x::TArray)
116+
return convert(Array{eltype(x), ndims(x)}, x)
117+
end
118+
function Base.convert(::Type{Array{T,N}}, x::TArray{T,N}) where {T,N}
119+
c = convert(Array{T, N}, deepcopy(_get(x)))
120+
return c
121+
end
122+
123+
function Base.convert(::Type{TArray}, x::AbstractArray)
124+
return convert(TArray{eltype(x),ndims(x)}, x)
125+
end
126+
function Base.convert(::Type{TArray{T,N}}, x::AbstractArray{T,N}) where {T,N}
127+
res = TArray{T,N}()
128+
n = n_copies()
129+
task_local_storage(res.ref, (n,x))
130+
return res
131+
end
132+
133+
#
134+
# Representation
135+
#
136+
function Base.show(io::IO, ::MIME"text/plain", x::TArray)
137+
arr = x.orig_task.storage[x.ref][2]
138+
@warn "Here shows the originating task's storage, " *
139+
"not the current task's storage. " *
140+
"Please explicitly call show(::TArray) to display the current task's version of a TArray."
141+
show(io, MIME("text/plain"), arr)
142+
end
143+
144+
Base.show(io::IO, x::TArray) = Base.show(io::IO, task_local_storage(x.ref)[2])
145+
146+
function Base.summary(io::IO, x::TArray)
147+
print(io, "Task Local Array: ")
148+
summary(io, _get(x))
149+
end
150+
151+
#
152+
# Forward many methods to the underlying array
153+
#
154+
for F in (:size,
155+
:iterate,
156+
:firstindex, :lastindex, :axes)
157+
@eval Base.$F(a::TArray, args...) = $F(_get(a), args...)
158+
end
159+
160+
#
161+
# Similarity implementation
162+
#
163+
164+
Base.similar(x::TArray, ::Type{T}, dims::Dims) where T = TArray(similar(_get(x), T, dims))
165+
166+
for op in [:(==), :]
167+
@eval Base.$op(x::TArray, y::AbstractArray) = Base.$op(_get(x), y)
168+
@eval Base.$op(x::AbstractArray, y::TArray) = Base.$op(x, _get(y))
169+
@eval Base.$op(x::TArray, y::TArray) = Base.$op(_get(x), _get(y))
170+
end
171+
172+
#
173+
# Array Stdlib
174+
#
175+
176+
# Indexing Interface
177+
function Base.getindex(x::TArray{T, N}, I::Vararg{Int,N}) where {T, N}
178+
t, d = task_local_storage(x.ref)
179+
return d[I...]
180+
end
181+
182+
function Base.setindex!(x::TArray{T, N}, e, I::Vararg{Int,N}) where {T, N}
183+
n, d = task_local_storage(x.ref)
184+
cn = n_copies()
185+
newd = d
186+
if cn > n
187+
# println("[setindex!]: $(x.ref) copying data")
188+
newd = deepcopy(d)
189+
task_local_storage(x.ref, (cn, newd))
190+
end
191+
newd[I...] = e
192+
end
193+
194+
function Base.push!(x::TArray{T}, e) where T
195+
n, d = task_local_storage(x.ref)
196+
cn = n_copies()
197+
newd = d
198+
if cn > n
199+
newd = deepcopy(d)
200+
task_local_storage(x.ref, (cn, newd))
201+
end
202+
push!(newd, e)
203+
end
204+
205+
function Base.pop!(x::TArray)
206+
n, d = task_local_storage(x.ref)
207+
cn = n_copies()
208+
newd = d
209+
if cn > n
210+
newd = deepcopy(d)
211+
task_local_storage(x.ref, (cn, newd))
212+
end
213+
pop!(d)
214+
end
215+
216+
# Other methods from stdlib
217+
218+
Base.view(x::TArray, inds...; kwargs...) =
219+
Base.view(_get(x), inds...; kwargs...) |> localize
220+
Base.:-(x::TArray) = (- _get(x)) |> localize
221+
Base.transpose(x::TArray) = transpose(_get(x)) |> localize
222+
Base.adjoint(x::TArray) = adjoint(_get(x)) |> localize
223+
Base.repeat(x::TArray; kw...) = repeat(_get(x); kw...) |> localize
224+
225+
Base.hcat(xs::Union{TArray{T,1}, TArray{T,2}}...) where T =
226+
hcat(_get.(xs)...) |> localize
227+
Base.vcat(xs::Union{TArray{T,1}, TArray{T,2}}...) where T =
228+
vcat(_get.(xs)...) |> localize
229+
Base.cat(xs::Union{TArray{T,1}, TArray{T,2}}...; dims) where T =
230+
cat(_get.(xs)...; dims = dims) |> localize
231+
232+
233+
Base.reshape(x::TArray, dims::Union{Colon,Int}...) = reshape(_get(x), dims) |> localize
234+
Base.reshape(x::TArray, dims::Tuple{Vararg{Union{Int,Colon}}}) =
235+
reshape(_get(x), Base._reshape_uncolon(_get(x), dims)) |> localize
236+
Base.reshape(x::TArray, dims::Tuple{Vararg{Int}}) = reshape(_get(x), dims) |> localize
237+
238+
Base.permutedims(x::TArray, perm) = permutedims(_get(x), perm) |> localize
239+
Base.PermutedDimsArray(x::TArray, perm) = PermutedDimsArray(_get(x), perm) |> localize
240+
Base.reverse(x::TArray; dims) = reverse(_get(x), dims = dims) |> localize
241+
242+
Base.sum(x::TArray; dims = :) = sum(_get(x), dims = dims) |> localize
243+
Base.sum(f::Union{Function,Type},x::TArray) = sum(f.(_get(x))) |> localize
244+
Base.prod(x::TArray; dims=:) = prod(_get(x); dims=dims) |> localize
245+
Base.prod(f::Union{Function, Type}, x::TArray) = prod(f.(_get(x))) |> localize
246+
247+
Base.findfirst(x::TArray, args...) = findfirst(_get(x), args...) |> localize
248+
Base.maximum(x::TArray; dims = :) = maximum(_get(x), dims = dims) |> localize
249+
Base.minimum(x::TArray; dims = :) = minimum(_get(x), dims = dims) |> localize
250+
251+
Base.:/(x::TArray, y::TArray) = _get(x) / _get(y) |> localize
252+
Base.:/(x::AbstractArray, y::TArray) = x / _get(y) |> localize
253+
Base.:/(x::TArray, y::AbstractArray) = _get(x) / y |> localize
254+
Base.:\(x::TArray, y::TArray) = _get(x) \ _get(y) |> localize
255+
Base.:\(x::AbstractArray, y::TArray) = x \ _get(y) |> localize
256+
Base.:\(x::TArray, y::AbstractArray) = _get(x) \ y |> localize
257+
Base.:*(x::TArray, y::TArray) = _get(x) * _get(y) |> localize
258+
Base.:*(x::AbstractArray, y::TArray) = x * _get(y) |> localize
259+
Base.:*(x::TArray, y::AbstractArray) = _get(x) * y |> localize
260+
261+
# broadcast
262+
Base.BroadcastStyle(::Type{TArray{T, N}}) where {T, N} = Broadcast.ArrayStyle{TArray}()
263+
Broadcast.broadcasted(::Broadcast.ArrayStyle{TArray}, f, args...) = f.(_get.(args)...) |> localize
264+
265+
import LinearAlgebra
266+
import LinearAlgebra: \, /, inv, det, logdet, logabsdet, norm
267+
268+
LinearAlgebra.inv(x::TArray) = inv(_get(x)) |> localize
269+
LinearAlgebra.det(x::TArray) = det(_get(x)) |> localize
270+
LinearAlgebra.logdet(x::TArray) = logdet(_get(x)) |> localize
271+
LinearAlgebra.logabsdet(x::TArray) = logabsdet(_get(x)) |> localize
272+
LinearAlgebra.norm(x::TArray, p::Real = 2) =
273+
LinearAlgebra.norm(_get(x), p) |> localize
274+
275+
import LinearAlgebra: dot
276+
dot(x::TArray, ys::TArray) = dot(_get(x), _get(ys)) |> localize
277+
dot(x::AbstractArray, ys::TArray) = dot(x, _get(ys)) |> localize
278+
dot(x::TArray, ys::AbstractArray) = dot(_get(x), ys) |> localize
279+
280+
using Statistics
281+
Statistics.mean(x::TArray; dims = :) = mean(_get(x), dims = dims) |> localize
282+
Statistics.std(x::TArray; kw...) = std(_get(x), kw...) |> localize
283+
284+
# TODO
285+
# * NNlib

test/benchmarks.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using BenchmarkTools
2+
using Libtask
3+
4+
println("= Benchmarks on Arrays =")
5+
A = rand(100, 100)
6+
x, y = abs.(rand(Int, 2) .% 100)
7+
print("indexing: ")
8+
@btime $A[$x, $y] + $A[$x, $y]
9+
print("set indexing: ")
10+
@btime $A[$x, $y] = 1
11+
print("broadcast: ")
12+
@btime $A .+ $A
13+
14+
println("= Benchmarks on TArrays =")
15+
TA = Libtask.localize(deepcopy(A))
16+
print("indexing: ")
17+
@btime $TA[$x, $y] + $TA[$x, $y]
18+
print("set indexing: ")
19+
@btime $TA[$x, $y] = 1
20+
print("broadcast: ")
21+
@btime $TA .+ $TA

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,7 @@ using Test
44
include("ctask.jl")
55
include("tarray.jl")
66
include("tref.jl")
7+
8+
if get(ENV, "BENCHMARK", nothing) != nothing
9+
include("benchmarks.jl")
10+
end

0 commit comments

Comments
 (0)