31
31
32
32
TArray {T} (d:: Integer... ) where T = TArray (T, d)
33
33
TArray {T} (:: UndefInitializer , d:: Integer... ) where T = TArray (T, d)
34
+ TArray {T} (:: UndefInitializer , dim:: NTuple{N,Int} ) where {T,N} = TArray (T, dim)
34
35
TArray {T,N} (d:: Vararg{<:Integer,N} ) where {T,N} = TArray (T, d)
35
36
TArray {T,N} (:: UndefInitializer , d:: Vararg{<:Integer,N} ) where {T,N} = TArray {T,N} (d)
36
37
TArray {T,N} (dim:: NTuple{N,Int} ) where {T,N} = TArray (T, dim)
@@ -43,106 +44,12 @@ function TArray(T::Type, dim)
43
44
res
44
45
end
45
46
46
- #
47
- # Indexing Interface Implementation
48
- #
49
-
50
- function Base. getindex (S:: TArray{T, N} , I:: Vararg{Int,N} ) where {T, N}
51
- t, d = task_local_storage (S. ref)
52
- return d[I... ]
53
- end
54
-
55
- function Base. setindex! (S:: TArray{T, N} , x, I:: Vararg{Int,N} ) where {T, N}
56
- n, d = task_local_storage (S. ref)
57
- cn = n_copies ()
58
- newd = d
59
- if cn > n
60
- # println("[setindex!]: $(S.ref) copying data")
61
- newd = deepcopy (d)
62
- task_local_storage (S. ref, (cn, newd))
63
- end
64
- newd[I... ] = x
65
- end
66
-
67
- function Base. push! (S:: TArray{T} , x) where T
68
- n, d = task_local_storage (S. ref)
69
- cn = n_copies ()
70
- newd = d
71
- if cn > n
72
- newd = deepcopy (d)
73
- task_local_storage (S. ref, (cn, newd))
74
- end
75
- push! (newd, x)
76
- end
47
+ TArray (x:: AbstractArray ) = convert (TArray, x)
77
48
78
- function Base. pop! (S:: TArray )
79
- n, d = task_local_storage (S. ref)
80
- cn = n_copies ()
81
- newd = d
82
- if cn > n
83
- newd = deepcopy (d)
84
- task_local_storage (S. ref, (cn, newd))
85
- end
86
- pop! (d)
87
- end
88
-
89
- function Base. convert (:: Type{TArray} , x:: Array )
90
- return convert (TArray{eltype (x),ndims (x)}, x)
91
- end
92
- function Base. convert (:: Type{TArray{T,N}} , x:: Array{T,N} ) where {T,N}
93
- res = TArray {T,N} ()
94
- n = n_copies ()
95
- task_local_storage (res. ref, (n,x))
96
- return res
97
- end
98
-
99
- function Base. convert (:: Type{Array} , S:: TArray )
100
- return convert (Array{eltype (S), ndims (S)}, S)
101
- end
102
- function Base. convert (:: Type{Array{T,N}} , S:: TArray{T,N} ) where {T,N}
103
- n,d = task_local_storage (S. ref)
104
- c = convert (Array{T, N}, deepcopy (d))
105
- return c
106
- end
107
-
108
- function Base. display (S:: TArray )
109
- arr = S. orig_task. storage[S. ref][2 ]
110
- @warn " display(::TArray) prints the originating task's storage, " *
111
- " not the current task's storage. " *
112
- " Please use show(::TArray) to display the current task's version of a TArray."
113
- display (arr)
114
- end
115
-
116
- Base. show (io:: IO , S:: TArray ) = Base. show (io:: IO , task_local_storage (S. ref)[2 ])
117
-
118
- # Base.get(t::Task, S) = S
119
- # Base.get(t::Task, S::TArray) = (t.storage[S.ref][2])
120
- Base. get (S:: TArray ) = (current_task (). storage[S. ref][2 ])
121
-
122
- # #
123
- # Iterator Interface
124
- IteratorSize (:: Type{TArray{T, N}} ) where {T, N} = HasShape {N} ()
125
- IteratorEltype (:: Type{TArray} ) = HasEltype ()
126
-
127
- # Implements iterate, eltype, length, and size functions,
128
- # as well as firstindex, lastindex, ndims, and axes
129
- for F in (:iterate , :eltype , :length , :size ,
130
- :firstindex , :lastindex , :ndims , :axes )
131
- @eval Base.$ F (a:: TArray , args... ) = $ F (get (a), args... )
132
- end
133
-
134
- #
135
- # Similarity implementation
136
- #
137
-
138
- Base. similar (S:: TArray ) = tzeros (eltype (S), size (S))
139
- Base. similar (S:: TArray , :: Type{T} ) where {T} = tzeros (T, size (S))
140
- Base. similar (S:: TArray , dims:: Dims ) = tzeros (eltype (S), dims)
141
-
142
- # #########
143
- # tzeros #
144
- # #########
49
+ localize (x) = x
50
+ localize (x:: AbstractArray ) = TArray (x)
145
51
52
+ # Constructors
146
53
"""
147
54
tzeros(dims, ...)
148
55
@@ -195,3 +102,184 @@ function tfill(val::Real, dim)
195
102
task_local_storage (res. ref, (n,d))
196
103
return res
197
104
end
105
+
106
+ #
107
+ # Conversion between TArray and Array
108
+ #
109
+ _get (x) = x
110
+ function _get (x:: TArray )
111
+ n, d = task_local_storage (x. ref)
112
+ return d
113
+ end
114
+
115
+ function Base. convert (:: Type{Array} , x:: TArray )
116
+ return convert (Array{eltype (x), ndims (x)}, x)
117
+ end
118
+ function Base. convert (:: Type{Array{T,N}} , x:: TArray{T,N} ) where {T,N}
119
+ c = convert (Array{T, N}, deepcopy (_get (x)))
120
+ return c
121
+ end
122
+
123
+ function Base. convert (:: Type{TArray} , x:: AbstractArray )
124
+ return convert (TArray{eltype (x),ndims (x)}, x)
125
+ end
126
+ function Base. convert (:: Type{TArray{T,N}} , x:: AbstractArray{T,N} ) where {T,N}
127
+ res = TArray {T,N} ()
128
+ n = n_copies ()
129
+ task_local_storage (res. ref, (n,x))
130
+ return res
131
+ end
132
+
133
+ #
134
+ # Representation
135
+ #
136
+ function Base. show (io:: IO , :: MIME"text/plain" , x:: TArray )
137
+ arr = x. orig_task. storage[x. ref][2 ]
138
+ @warn " Here shows the originating task's storage, " *
139
+ " not the current task's storage. " *
140
+ " Please explicitly call show(::TArray) to display the current task's version of a TArray."
141
+ show (io, MIME (" text/plain" ), arr)
142
+ end
143
+
144
+ Base. show (io:: IO , x:: TArray ) = Base. show (io:: IO , task_local_storage (x. ref)[2 ])
145
+
146
+ function Base. summary (io:: IO , x:: TArray )
147
+ print (io, " Task Local Array: " )
148
+ summary (io, _get (x))
149
+ end
150
+
151
+ #
152
+ # Forward many methods to the underlying array
153
+ #
154
+ for F in (:size ,
155
+ :iterate ,
156
+ :firstindex , :lastindex , :axes )
157
+ @eval Base.$ F (a:: TArray , args... ) = $ F (_get (a), args... )
158
+ end
159
+
160
+ #
161
+ # Similarity implementation
162
+ #
163
+
164
+ Base. similar (x:: TArray , :: Type{T} , dims:: Dims ) where T = TArray (similar (_get (x), T, dims))
165
+
166
+ for op in [:(== ), :≈ ]
167
+ @eval Base.$ op (x:: TArray , y:: AbstractArray ) = Base.$ op (_get (x), y)
168
+ @eval Base.$ op (x:: AbstractArray , y:: TArray ) = Base.$ op (x, _get (y))
169
+ @eval Base.$ op (x:: TArray , y:: TArray ) = Base.$ op (_get (x), _get (y))
170
+ end
171
+
172
+ #
173
+ # Array Stdlib
174
+ #
175
+
176
+ # Indexing Interface
177
+ function Base. getindex (x:: TArray{T, N} , I:: Vararg{Int,N} ) where {T, N}
178
+ t, d = task_local_storage (x. ref)
179
+ return d[I... ]
180
+ end
181
+
182
+ function Base. setindex! (x:: TArray{T, N} , e, I:: Vararg{Int,N} ) where {T, N}
183
+ n, d = task_local_storage (x. ref)
184
+ cn = n_copies ()
185
+ newd = d
186
+ if cn > n
187
+ # println("[setindex!]: $(x.ref) copying data")
188
+ newd = deepcopy (d)
189
+ task_local_storage (x. ref, (cn, newd))
190
+ end
191
+ newd[I... ] = e
192
+ end
193
+
194
+ function Base. push! (x:: TArray{T} , e) where T
195
+ n, d = task_local_storage (x. ref)
196
+ cn = n_copies ()
197
+ newd = d
198
+ if cn > n
199
+ newd = deepcopy (d)
200
+ task_local_storage (x. ref, (cn, newd))
201
+ end
202
+ push! (newd, e)
203
+ end
204
+
205
+ function Base. pop! (x:: TArray )
206
+ n, d = task_local_storage (x. ref)
207
+ cn = n_copies ()
208
+ newd = d
209
+ if cn > n
210
+ newd = deepcopy (d)
211
+ task_local_storage (x. ref, (cn, newd))
212
+ end
213
+ pop! (d)
214
+ end
215
+
216
+ # Other methods from stdlib
217
+
218
+ Base. view (x:: TArray , inds... ; kwargs... ) =
219
+ Base. view (_get (x), inds... ; kwargs... ) |> localize
220
+ Base.:- (x:: TArray ) = (- _get (x)) |> localize
221
+ Base. transpose (x:: TArray ) = transpose (_get (x)) |> localize
222
+ Base. adjoint (x:: TArray ) = adjoint (_get (x)) |> localize
223
+ Base. repeat (x:: TArray ; kw... ) = repeat (_get (x); kw... ) |> localize
224
+
225
+ Base. hcat (xs:: Union{TArray{T,1}, TArray{T,2}} ...) where T =
226
+ hcat (_get .(xs)... ) |> localize
227
+ Base. vcat (xs:: Union{TArray{T,1}, TArray{T,2}} ...) where T =
228
+ vcat (_get .(xs)... ) |> localize
229
+ Base. cat (xs:: Union{TArray{T,1}, TArray{T,2}} ...; dims) where T =
230
+ cat (_get .(xs)... ; dims = dims) |> localize
231
+
232
+
233
+ Base. reshape (x:: TArray , dims:: Union{Colon,Int} ...) = reshape (_get (x), dims) |> localize
234
+ Base. reshape (x:: TArray , dims:: Tuple{Vararg{Union{Int,Colon}}} ) =
235
+ reshape (_get (x), Base. _reshape_uncolon (_get (x), dims)) |> localize
236
+ Base. reshape (x:: TArray , dims:: Tuple{Vararg{Int}} ) = reshape (_get (x), dims) |> localize
237
+
238
+ Base. permutedims (x:: TArray , perm) = permutedims (_get (x), perm) |> localize
239
+ Base. PermutedDimsArray (x:: TArray , perm) = PermutedDimsArray (_get (x), perm) |> localize
240
+ Base. reverse (x:: TArray ; dims) = reverse (_get (x), dims = dims) |> localize
241
+
242
+ Base. sum (x:: TArray ; dims = :) = sum (_get (x), dims = dims) |> localize
243
+ Base. sum (f:: Union{Function,Type} ,x:: TArray ) = sum (f .(_get (x))) |> localize
244
+ Base. prod (x:: TArray ; dims= :) = prod (_get (x); dims= dims) |> localize
245
+ Base. prod (f:: Union{Function, Type} , x:: TArray ) = prod (f .(_get (x))) |> localize
246
+
247
+ Base. findfirst (x:: TArray , args... ) = findfirst (_get (x), args... ) |> localize
248
+ Base. maximum (x:: TArray ; dims = :) = maximum (_get (x), dims = dims) |> localize
249
+ Base. minimum (x:: TArray ; dims = :) = minimum (_get (x), dims = dims) |> localize
250
+
251
+ Base.:/ (x:: TArray , y:: TArray ) = _get (x) / _get (y) |> localize
252
+ Base.:/ (x:: AbstractArray , y:: TArray ) = x / _get (y) |> localize
253
+ Base.:/ (x:: TArray , y:: AbstractArray ) = _get (x) / y |> localize
254
+ Base.:\ (x:: TArray , y:: TArray ) = _get (x) \ _get (y) |> localize
255
+ Base.:\ (x:: AbstractArray , y:: TArray ) = x \ _get (y) |> localize
256
+ Base.:\ (x:: TArray , y:: AbstractArray ) = _get (x) \ y |> localize
257
+ Base.:* (x:: TArray , y:: TArray ) = _get (x) * _get (y) |> localize
258
+ Base.:* (x:: AbstractArray , y:: TArray ) = x * _get (y) |> localize
259
+ Base.:* (x:: TArray , y:: AbstractArray ) = _get (x) * y |> localize
260
+
261
+ # broadcast
262
+ Base. BroadcastStyle (:: Type{TArray{T, N}} ) where {T, N} = Broadcast. ArrayStyle {TArray} ()
263
+ Broadcast. broadcasted (:: Broadcast.ArrayStyle{TArray} , f, args... ) = f .(_get .(args)... ) |> localize
264
+
265
+ import LinearAlgebra
266
+ import LinearAlgebra: \ , / , inv, det, logdet, logabsdet, norm
267
+
268
+ LinearAlgebra. inv (x:: TArray ) = inv (_get (x)) |> localize
269
+ LinearAlgebra. det (x:: TArray ) = det (_get (x)) |> localize
270
+ LinearAlgebra. logdet (x:: TArray ) = logdet (_get (x)) |> localize
271
+ LinearAlgebra. logabsdet (x:: TArray ) = logabsdet (_get (x)) |> localize
272
+ LinearAlgebra. norm (x:: TArray , p:: Real = 2 ) =
273
+ LinearAlgebra. norm (_get (x), p) |> localize
274
+
275
+ import LinearAlgebra: dot
276
+ dot (x:: TArray , ys:: TArray ) = dot (_get (x), _get (ys)) |> localize
277
+ dot (x:: AbstractArray , ys:: TArray ) = dot (x, _get (ys)) |> localize
278
+ dot (x:: TArray , ys:: AbstractArray ) = dot (_get (x), ys) |> localize
279
+
280
+ using Statistics
281
+ Statistics. mean (x:: TArray ; dims = :) = mean (_get (x), dims = dims) |> localize
282
+ Statistics. std (x:: TArray ; kw... ) = std (_get (x), kw... ) |> localize
283
+
284
+ # TODO
285
+ # * NNlib
0 commit comments