Skip to content

Commit 8eed6ee

Browse files
authored
Add tests for ArrayValue __array__ copy semantics (#710)
* Add numpy-backed ArrayValue tests * Guard ArrayValue numpy tests on devdeps * Use Setup module for ArrayValue numpy tests
1 parent 52db81d commit 8eed6ee

File tree

3 files changed

+43
-8
lines changed

3 files changed

+43
-8
lines changed

CondaPkg.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ version = ">=3.10,<4"
1414

1515
[dev.deps]
1616
matplotlib = ""
17+
numpy = ""
1718
pyside6 = ""
1819
python = "<3.14"

src/JlWrap/array.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -355,22 +355,20 @@ class ArrayValue(AnyValue):
355355
@property
356356
def __array_interface__(self):
357357
return self._jl_callmethod($(pyjl_methodnum(pyjlarray_array_interface)))
358-
def __array__(self, dtype=None):
358+
def __array__(self, dtype=None, copy=None):
359+
import numpy
359360
# convert to an array-like object
360361
arr = self
361362
if not (hasattr(arr, "__array_interface__") or hasattr(arr, "__array_struct__")):
363+
if copy is False:
364+
raise ValueError("copy=False is not supported when collecting ArrayValue data")
362365
# the first attempt collects into an Array
363366
arr = self._jl_callmethod($(pyjl_methodnum(pyjlarray_array__array)))
364367
if not (hasattr(arr, "__array_interface__") or hasattr(arr, "__array_struct__")):
365368
# the second attempt collects into a PyObjectArray
366369
arr = self._jl_callmethod($(pyjl_methodnum(pyjlarray_array__pyobjectarray)))
367370
# convert to a numpy array if numpy is available
368-
try:
369-
import numpy
370-
arr = numpy.array(arr, dtype=dtype)
371-
except ImportError:
372-
pass
373-
return arr
371+
return numpy.array(arr, dtype=dtype, copy=copy)
374372
def to_numpy(self, dtype=None, copy=True, order="K"):
375373
import numpy
376374
return numpy.array(self, dtype=dtype, copy=copy, order=order)

test/JlWrap.jl

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@
215215
end
216216
end
217217

218-
@testitem "array" begin
218+
@testitem "array" setup=[Setup] begin
219219
@testset "type" begin
220220
@test pyis(pytype(pyjl(fill(nothing))), PythonCall.pyjlarraytype)
221221
@test pyis(pytype(pyjl([1 2; 3 4])), PythonCall.pyjlarraytype)
@@ -313,6 +313,42 @@ end
313313
@test pyjlvalue(x) == [0 2; 3 4]
314314
@test pyjlvalue(y) == [1 2; 3 4]
315315
end
316+
@testset "__array__" begin
317+
if Setup.devdeps
318+
np = pyimport("numpy")
319+
320+
numeric = pyjl(Float64[1, 2, 3])
321+
numeric_array = numeric.__array__()
322+
@test pyisinstance(numeric_array, np.ndarray)
323+
@test pyconvert(Vector{Float64}, numeric_array) == [1.0, 2.0, 3.0]
324+
325+
numeric_no_copy = numeric.__array__(copy=false)
326+
numeric_data = pyjlvalue(numeric)
327+
numeric_data[1] = 42.0
328+
@test pyconvert(Vector{Float64}, numeric_no_copy) == [42.0, 2.0, 3.0]
329+
330+
string_array = pyjl(["a", "b"])
331+
string_result = string_array.__array__()
332+
@test pyisinstance(string_result, np.ndarray)
333+
@test pyconvert(Vector{String}, pybuiltins.list(string_result)) == ["a", "b"]
334+
335+
err = try
336+
string_array.__array__(copy=false)
337+
nothing
338+
catch err
339+
err
340+
end
341+
@test err !== nothing
342+
@test err isa PythonCall.PyException
343+
@test pyis(err._t, pybuiltins.ValueError)
344+
@test occursin(
345+
"copy=False is not supported when collecting ArrayValue data",
346+
sprint(showerror, err),
347+
)
348+
else
349+
@test_skip Setup.devdeps
350+
end
351+
end
316352
@testset "array_interface" begin
317353
x = pyjl(Float32[1 2 3; 4 5 6]).__array_interface__
318354
@test pyisinstance(x, pybuiltins.dict)

0 commit comments

Comments
 (0)