Skip to content

Commit 8eb62f2

Browse files
jw3126stevengj
authored andcommitted
add pyiterate (#594)
* add pyiterate * add PyIterator * Skip PyIterator tests on old julia versions * fix IteratorSize * fix * fix * compute IteratorSize in PyIterator constructor * fix * fix * Update src/pyiterator.jl Co-Authored-By: jw3126 <[email protected]> * Update src/pyiterator.jl Co-Authored-By: jw3126 <[email protected]> * Update src/pyiterator.jl Co-Authored-By: jw3126 <[email protected]> * Update src/pyiterator.jl Co-Authored-By: jw3126 <[email protected]> * fix * fix * Update pyiterator.jl
1 parent 48d730f commit 8eb62f2

File tree

2 files changed

+107
-2
lines changed

2 files changed

+107
-2
lines changed

src/pyiterator.jl

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#########################################################################
44
# Iterating over Python objects in Julia
55

6+
Base.IteratorSize(::Type{PyObject}) = Base.SizeUnknown()
67
function _start(po::PyObject)
78
sigatomic_begin()
89
try
@@ -29,16 +30,81 @@ end
2930

3031
Base.done(po::PyObject, s) = ispynull(s[1])
3132
else
32-
function Base.iterate(po::PyObject, s=_start(po))
33+
"""
34+
PyIterator{T}(pyobject)
35+
36+
Wrap `pyobject::PyObject` into an iterator, that produces items of type `T`. To be more precise `convert(T, item)` is applied in each iteration. This can be useful to avoid automatic conversion of items into corresponding julia types.
37+
```jldoctest
38+
julia> using PyCall
39+
40+
julia> l = PyObject([PyObject(1), PyObject(2)])
41+
PyObject [1, 2]
42+
43+
julia> piter = PyCall.PyIterator{PyAny}(l)
44+
PyCall.PyIterator{PyAny,Base.HasLength()}(PyObject [1, 2])
45+
46+
julia> collect(piter)
47+
2-element Array{Any,1}:
48+
1
49+
2
50+
51+
julia> piter = PyCall.PyIterator(l)
52+
PyCall.PyIterator{PyObject,Base.HasLength()}(PyObject [1, 2])
53+
54+
julia> collect(piter)
55+
2-element Array{PyObject,1}:
56+
PyObject 1
57+
PyObject 2
58+
```
59+
"""
60+
struct PyIterator{T,S}
61+
o::PyObject
62+
end
63+
64+
function _compute_IteratorSize(o::PyObject)
65+
S = try
66+
length(o)
67+
Base.HasLength
68+
catch err
69+
if !(err isa PyError && pyisinstance(err.val, @pyglobalobjptr :PyExc_TypeError))
70+
rethrow()
71+
end
72+
Base.SizeUnknown
73+
end
74+
end
75+
function PyIterator(o::PyObject)
76+
PyIterator{PyObject}(o)
77+
end
78+
function (::Type{PyIterator{T}})(o::PyObject) where {T}
79+
S = _compute_IteratorSize(o)
80+
PyIterator{T,S}(o)
81+
end
82+
83+
Base.eltype(::Type{<:PyIterator{T}}) where T = T
84+
Base.eltype(::Type{<:PyIterator{PyAny}}) = Any
85+
Base.length(piter::PyIterator) = length(piter.o)
86+
87+
Base.IteratorSize(::Type{<: PyIterator{T,S}}) where {T,S} = S()
88+
89+
_start(piter::PyIterator) = _start(piter.o)
90+
91+
function Base.iterate(piter::PyIterator{T}, s=_start(piter)) where {T}
3392
ispynull(s[1]) && return nothing
3493
sigatomic_begin()
3594
try
3695
nxt = PyObject(@pycheck ccall((@pysym :PyIter_Next), PyPtr, (PyPtr,), s[2]))
37-
return (convert(PyAny, s[1]), (nxt, s[2]))
96+
return (convert(T,s[1]), (nxt, s[2]))
3897
finally
3998
sigatomic_end()
4099
end
41100
end
101+
function Base.iterate(po::PyObject, s=_start(po))
102+
# avoid the constructor that calls length
103+
# since that might be an expensive operation
104+
# even if length is cheap, this adds 10% performance
105+
piter = PyIterator{PyAny, Base.SizeUnknown}(po)
106+
iterate(piter, s)
107+
end
42108
end
43109

44110
# issue #216

test/runtests.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,45 @@ def try_call(f):
687687
pybuiltin("Exception"))
688688
end
689689

690+
@static if VERSION < v"0.7.0-DEV.5126" # julia#25261
691+
# PyIterator not defined in this julia version
692+
else
693+
@testset "PyIterator" begin
694+
arr = [1,2]
695+
o = PyObject(arr)
696+
c_pyany = collect(PyCall.PyIterator{PyAny}(o))
697+
@test c_pyany == arr
698+
@test c_pyany[1] isa Integer
699+
@test c_pyany[2] isa Integer
700+
701+
c_f64 = collect(PyCall.PyIterator{Float64}(o))
702+
@test c_f64 == arr
703+
@test eltype(c_f64) == Float64
704+
705+
i1 = PyObject([1])
706+
i2 = PyObject([2])
707+
l = PyObject([i1,i2])
708+
709+
piter = PyCall.PyIterator(l)
710+
@test length(piter) == 2
711+
@test length(collect(piter)) == 2
712+
r1, r2 = collect(piter)
713+
@test r1.o === i1.o
714+
@test r2.o === i2.o
715+
716+
@test Base.IteratorSize(PyCall.PyIterator(PyObject(1))) == Base.SizeUnknown()
717+
@test Base.IteratorSize(PyCall.PyIterator(PyObject([1]))) == Base.HasLength()
718+
719+
# 594
720+
@test collect(zip(py"iter([1, 2, 3])", 1:3)) ==
721+
[(1, 1), (2, 2), (3, 3)]
722+
@test collect(zip(PyCall.PyIterator{Int}(py"iter([1, 2, 3])"), 1:3)) ==
723+
[(1, 1), (2, 2), (3, 3)]
724+
@test collect(zip(PyCall.PyIterator(py"[1, 2, 3]"o), 1:3)) ==
725+
[(1, 1), (2, 2), (3, 3)]
726+
end
727+
end
728+
690729
@testset "atexit" begin
691730
if VERSION < v"0.7-"
692731
setup = ""

0 commit comments

Comments
 (0)