Skip to content

Commit 42a82ff

Browse files
Fix convert (#70)
* fix convert * fix the fix * add tests * two more tests * Update src/tarray.jl Co-authored-by: David Widmann <[email protected]> * david's suggestion * Update src/tarray.jl Co-authored-by: David Widmann <[email protected]> Co-authored-by: David Widmann <[email protected]>
1 parent e9db9e0 commit 42a82ff

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

src/tarray.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,21 @@ function Base.pop!(S::TArray)
8888
end
8989

9090
function Base.convert(::Type{TArray}, x::Array)
91-
res = TArray{typeof(x[1]),ndims(x)}();
91+
return convert(TArray{eltype(x),ndims(x)}, x)
92+
end
93+
function Base.convert(::Type{TArray{T,N}}, x::Array{T,N}) where {T,N}
94+
res = TArray{T,N}()
9295
n = n_copies()
9396
task_local_storage(res.ref, (n,x))
9497
return res
9598
end
9699

97-
function Base.convert(::Array, x::Type{TArray})
100+
function Base.convert(::Type{Array}, S::TArray)
101+
return convert(Array{eltype(S), ndims(S)}, S)
102+
end
103+
function Base.convert(::Type{Array{T,N}}, S::TArray{T,N}) where {T,N}
98104
n,d = task_local_storage(S.ref)
99-
c = deepcopy(d)
105+
c = convert(Array{T, N}, deepcopy(d))
100106
return c
101107
end
102108

test/tarray.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@
4646
ta5[i] = i
4747
end
4848
@test Array(ta5) == [1, 2, 3, 4]
49+
@test convert(Array, ta5) == [1, 2, 3, 4]
50+
@test convert(Array{Int, 1}, ta5) == [1, 2, 3, 4]
51+
@test ta5 == convert(TArray, [1, 2, 3, 4])
52+
@test ta5 == convert(TArray{Int, 1}, [1, 2, 3, 4])
53+
@test_throws MethodError convert(TArray{Int, 2}, [1, 2, 3, 4])
54+
@test_throws MethodError convert(Array{Int, 2}, ta5)
4955

5056
@test Array(tzeros(4)) == zeros(4)
5157

@@ -60,7 +66,6 @@
6066
ta7 = TArray{Int, 2}((2, 2))
6167
end
6268

63-
6469
@testset "task copy" begin
6570
function f()
6671
t = TArray(Int, 1)

0 commit comments

Comments
 (0)