Skip to content

Commit 4fce3d0

Browse files
committed
Add PyString wrapper for Python strings
1 parent be16342 commit 4fce3d0

File tree

13 files changed

+152
-67
lines changed

13 files changed

+152
-67
lines changed

docs/src/conversion-to-julia.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ From Python, the arguments to a Julia function will be converted according to th
3737
| `None` | `Missing` |
3838
| `bytes` | `Vector{UInt8}`, `Vector{Int8}`, `String` |
3939
| `str` | `String`, `Symbol`, `Char`, `Vector{UInt8}`, `Vector{Int8}` |
40+
| `str` | `PyString` |
4041
| `range` | `UnitRange` |
4142
| `collections.abc.Mapping` | `Dict` |
4243
| `collections.abc.Iterable` | `Vector`, `Set`, `Tuple`, `NamedTuple`, `Pair` |

docs/src/pythoncall-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ PyList
184184
PySet
185185
PyDict
186186
PyIterable
187+
PyString
187188
PyArray
188189
PyIO
189190
PyTable

docs/src/pythoncall.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,9 @@ Python: [3, 4, 5, None, 1, 2]
227227

228228
There are wrappers for other container types, such as [`PyDict`](@ref) and [`PySet`](@ref).
229229

230+
`PyString` is a zero-copy wrapper around a Python `str`, exposing it as a Julia
231+
`AbstractString` backed by the UTF-8 pointer cached by Python.
232+
230233
The wrapper [`PyArray`](@ref) provides a Julia array view of any Python array, i.e. anything
231234
satisfying either the buffer protocol or the numpy array interface. This includes things
232235
like `bytes`, `bytearray`, `array.array` and `numpy.ndarray`:

src/API/exports.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ export PyDict
117117
export PyIO
118118
export PyIterable
119119
export PyList
120+
export PyString
120121
export PyPandasDataFrame
121122
export PySet
122123
export PyTable

src/API/types.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,27 @@ struct PyIterable{T}
169169
PyIterable{T}(x) where {T} = new{T}(Py(x))
170170
end
171171

172+
"""
173+
PyString(x)
174+
175+
Wraps the Python `str` `x` as an `AbstractString` without copying.
176+
177+
The UTF-8 data is stored as a pointer and byte length obtained from
178+
`PyUnicode_AsUTF8AndSize`, and remains valid as long as the underlying Python
179+
object is alive.
180+
"""
181+
struct PyString <: AbstractString
182+
py::Py
183+
ptr::Ptr{UInt8}
184+
nbytes::Int
185+
function PyString(x)
186+
py = Py(x)
187+
PythonCall.Core.pyisstr(py) || throw(ArgumentError("PyString expects a Python `str`"))
188+
ptr, n = PythonCall.Core.pystr_utf8_pointer(py)
189+
new(py, ptr, n)
190+
end
191+
end
192+
172193
"""
173194
PyList{T=Py}([x])
174195

src/C/pointers.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ const CAPI_FUNC_SIGS = Dict{Symbol,Pair{Tuple,Type}}(
139139
:PyComplex_AsCComplex => (PyPtr,) => Py_complex,
140140
# STR
141141
:PyUnicode_DecodeUTF8 => (Ptr{Cchar}, Py_ssize_t, Ptr{Cchar}) => PyPtr,
142+
:PyUnicode_AsUTF8AndSize => (PyPtr, Ptr{Py_ssize_t}) => Ptr{Cchar},
142143
:PyUnicode_AsUTF8String => (PyPtr,) => PyPtr,
143144
:PyUnicode_InternInPlace => (Ptr{PyPtr},) => Cvoid,
144145
# BYTES

src/Convert/pyconvert.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ function pyconvert_rule_fast(::Type{T}, x::Py) where {T}
321321
pyisfloat(x) && return pyconvert_return(T(pyfloat_asdouble(x)))
322322
elseif (T == ComplexF64)
323323
pyiscomplex(x) && return pyconvert_return(T(pycomplex_ascomplex(x)))
324-
elseif (T == String) | (T == Char) | (T == Symbol)
324+
elseif (T == String) | (T == Char) | (T == Symbol) | (T == PyString)
325325
pyisstr(x) && return pyconvert_rule_str(T, x)
326326
elseif (T == Vector{UInt8}) | (T == Base.CodeUnits{UInt8,String})
327327
pyisbytes(x) && return pyconvert_rule_bytes(T, x)
@@ -473,6 +473,7 @@ function init_pyconvert()
473473
pyconvert_add_rule("builtins:float", Nothing, pyconvert_rule_float, priority)
474474
pyconvert_add_rule("builtins:float", Missing, pyconvert_rule_float, priority)
475475
pyconvert_add_rule("numbers:Complex", Number, pyconvert_rule_complex, priority)
476+
pyconvert_add_rule("builtins:str", PyString, pyconvert_rule_str, priority)
476477
pyconvert_add_rule("numbers:Integral", Number, pyconvert_rule_int, priority)
477478
pyconvert_add_rule("builtins:str", Symbol, pyconvert_rule_str, priority)
478479
pyconvert_add_rule("builtins:str", Char, pyconvert_rule_str, priority)

src/Convert/rules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ pyconvert_rule_str(::Type{Char}, x::Py) = begin
4848
pyconvert_unconverted()
4949
end
5050
end
51+
pyconvert_rule_str(::Type{PyString}, x::Py) = pyconvert_return(PyString(x))
5152

5253
### bytes
5354

src/Core/builtins.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,13 @@ pystr_asUTF8vector(x::Py) =
573573
pystr_asstring(x::Py) =
574574
(b = pystr_asUTF8bytes(x); ans = pybytes_asUTF8string(b); pydel!(b); ans)
575575

576+
function pystr_utf8_pointer(x::Py)
577+
n = Ref{C.Py_ssize_t}()
578+
p = C.PyUnicode_AsUTF8AndSize(x, n)
579+
p == C_NULL && pythrow()
580+
Ptr{UInt8}(p), Int(n[])
581+
end
582+
576583
function pystr_intern!(x::Py)
577584
ptr = Ref(getptr(x))
578585
C.PyUnicode_InternInPlace(ptr)

src/Utils/Utils.jl

Lines changed: 74 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -180,109 +180,79 @@ size_to_fstrides(elsz::Integer, sz::Tuple{Vararg{Integer}}) =
180180
size_to_cstrides(elsz::Integer, sz::Tuple{Vararg{Integer}}) =
181181
isempty(sz) ? () : (size_to_cstrides(elsz * sz[end], sz[1:end-1])..., elsz)
182182

183-
struct StaticString{T,N} <: AbstractString
184-
codeunits::NTuple{N,T}
185-
StaticString{T,N}(codeunits::NTuple{N,T}) where {T,N} = new{T,N}(codeunits)
186-
end
187-
188-
function Base.String(x::StaticString{T,N}) where {T,N}
189-
ts = x.codeunits
190-
n = N
191-
while n > 0 && iszero(ts[n])
192-
n -= 1
183+
function utf8_allzeros(codeunit::F, n::Int, i::Int) where {F}
184+
@inbounds for j in i:n
185+
iszero(codeunit(j)) || return false
193186
end
194-
cs = T[ts[i] for i = 1:n]
195-
transcode(String, cs)
187+
return true
196188
end
197189

198-
function Base.convert(::Type{StaticString{T,N}}, x::AbstractString) where {T,N}
199-
ts = transcode(T, convert(String, x))
200-
n = length(ts)
201-
n > N && throw(InexactError(:convert, StaticString{T,N}, x))
202-
n > 0 && iszero(ts[n]) && throw(InexactError(:convert, StaticString{T,N}, x))
203-
z = zero(T)
204-
cs = ntuple(i -> i > n ? z : @inbounds(ts[i]), N)
205-
StaticString{T,N}(cs)
206-
end
207-
208-
StaticString{T,N}(x::AbstractString) where {T,N} = convert(StaticString{T,N}, x)
209-
210-
Base.ncodeunits(x::StaticString{T,N}) where {T,N} = N
211-
212-
Base.codeunit(x::StaticString, i::Integer) = x.codeunits[i]
213-
214-
Base.codeunit(x::StaticString{T}) where {T} = T
215-
216-
function Base.isvalid(x::StaticString{UInt8,N}, i::Int) where {N}
217-
if i < 1 || i > N
190+
function utf8_isvalid(codeunit::F, n::Int, i::Int; zeroterminated::Bool = false) where {F}
191+
if i < 1 || i > n
218192
return false
219193
end
220-
cs = x.codeunits
221-
c = @inbounds cs[i]
222-
if all(iszero, (cs[j] for j = i:N))
223-
return false
224-
elseif (c & 0x80) == 0x00
194+
zeroterminated && utf8_allzeros(codeunit, n, i) && return false
195+
c = @inbounds codeunit(i)
196+
if (c & 0x80) == 0x00
225197
return true
226198
elseif (c & 0x40) == 0x00
227199
return false
228200
elseif (c & 0x20) == 0x00
229-
return @inbounds (i ≤ N - 1) && ((cs[i+1] & 0xC0) == 0x80)
201+
return @inbounds (i ≤ n - 1) && ((codeunit(i + 1) & 0xC0) == 0x80)
230202
elseif (c & 0x10) == 0x00
231-
return @inbounds (i ≤ N - 2) &&
232-
((cs[i+1] & 0xC0) == 0x80) &&
233-
((cs[i+2] & 0xC0) == 0x80)
203+
return @inbounds (i ≤ n - 2) &&
204+
((codeunit(i + 1) & 0xC0) == 0x80) &&
205+
((codeunit(i + 2) & 0xC0) == 0x80)
234206
elseif (c & 0x08) == 0x00
235-
return @inbounds (i ≤ N - 3) &&
236-
((cs[i+1] & 0xC0) == 0x80) &&
237-
((cs[i+2] & 0xC0) == 0x80) &&
238-
((cs[i+3] & 0xC0) == 0x80)
207+
return @inbounds (i ≤ n - 3) &&
208+
((codeunit(i + 1) & 0xC0) == 0x80) &&
209+
((codeunit(i + 2) & 0xC0) == 0x80) &&
210+
((codeunit(i + 3) & 0xC0) == 0x80)
239211
else
240212
return false
241213
end
242-
return false
243214
end
244215

245-
function Base.iterate(x::StaticString{UInt8,N}, i::Int = 1) where {N}
246-
i > N && return
247-
cs = x.codeunits
248-
c = @inbounds cs[i]
249-
if all(iszero, (cs[j] for j = i:N))
216+
function utf8_iterate(x, codeunit::F, n::Int, i::Int = 1; zeroterminated::Bool = false) where {F}
217+
i > n && return
218+
c = @inbounds codeunit(i)
219+
if zeroterminated && utf8_allzeros(codeunit, n, i)
250220
return
251221
elseif (c & 0x80) == 0x00
252222
return (reinterpret(Char, UInt32(c) << 24), i + 1)
253223
elseif (c & 0x40) == 0x00
254224
nothing
255225
elseif (c & 0x20) == 0x00
256-
if @inbounds (i ≤ N - 1) && ((cs[i+1] & 0xC0) == 0x80)
226+
if @inbounds (i ≤ n - 1) && ((codeunit(i + 1) & 0xC0) == 0x80)
257227
return (
258-
reinterpret(Char, (UInt32(cs[i]) << 24) | (UInt32(cs[i+1]) << 16)),
228+
reinterpret(Char, (UInt32(codeunit(i)) << 24) | (UInt32(codeunit(i + 1)) << 16)),
259229
i + 2,
260230
)
261231
end
262232
elseif (c & 0x10) == 0x00
263-
if @inbounds (i ≤ N - 2) && ((cs[i+1] & 0xC0) == 0x80) && ((cs[i+2] & 0xC0) == 0x80)
233+
if @inbounds (i ≤ n - 2) && ((codeunit(i + 1) & 0xC0) == 0x80) && ((codeunit(i + 2) & 0xC0) == 0x80)
264234
return (
265235
reinterpret(
266236
Char,
267-
(UInt32(cs[i]) << 24) |
268-
(UInt32(cs[i+1]) << 16) |
269-
(UInt32(cs[i+2]) << 8),
237+
(UInt32(codeunit(i)) << 24) |
238+
(UInt32(codeunit(i + 1)) << 16) |
239+
(UInt32(codeunit(i + 2)) << 8),
270240
),
271241
i + 3,
272242
)
273243
end
274244
elseif (c & 0x08) == 0x00
275-
if @inbounds (i ≤ N - 3) &&
276-
((cs[i+1] & 0xC0) == 0x80) &&
277-
((cs[i+2] & 0xC0) == 0x80) &&
278-
((cs[i+3] & 0xC0) == 0x80)
245+
if @inbounds (i ≤ n - 3) &&
246+
((codeunit(i + 1) & 0xC0) == 0x80) &&
247+
((codeunit(i + 2) & 0xC0) == 0x80) &&
248+
((codeunit(i + 3) & 0xC0) == 0x80)
279249
return (
280250
reinterpret(
281251
Char,
282-
(UInt32(cs[i]) << 24) |
283-
(UInt32(cs[i+1]) << 16) |
284-
(UInt32(cs[i+2]) << 8) |
285-
UInt32(cs[i+3]),
252+
(UInt32(codeunit(i)) << 24) |
253+
(UInt32(codeunit(i + 1)) << 16) |
254+
(UInt32(codeunit(i + 2)) << 8) |
255+
UInt32(codeunit(i + 3)),
286256
),
287257
i + 4,
288258
)
@@ -291,6 +261,45 @@ function Base.iterate(x::StaticString{UInt8,N}, i::Int = 1) where {N}
291261
throw(StringIndexError(x, i))
292262
end
293263

264+
struct StaticString{T,N} <: AbstractString
265+
codeunits::NTuple{N,T}
266+
StaticString{T,N}(codeunits::NTuple{N,T}) where {T,N} = new{T,N}(codeunits)
267+
end
268+
269+
function Base.String(x::StaticString{T,N}) where {T,N}
270+
ts = x.codeunits
271+
n = N
272+
while n > 0 && iszero(ts[n])
273+
n -= 1
274+
end
275+
cs = T[ts[i] for i = 1:n]
276+
transcode(String, cs)
277+
end
278+
279+
function Base.convert(::Type{StaticString{T,N}}, x::AbstractString) where {T,N}
280+
ts = transcode(T, convert(String, x))
281+
n = length(ts)
282+
n > N && throw(InexactError(:convert, StaticString{T,N}, x))
283+
n > 0 && iszero(ts[n]) && throw(InexactError(:convert, StaticString{T,N}, x))
284+
z = zero(T)
285+
cs = ntuple(i -> i > n ? z : @inbounds(ts[i]), N)
286+
StaticString{T,N}(cs)
287+
end
288+
289+
StaticString{T,N}(x::AbstractString) where {T,N} = convert(StaticString{T,N}, x)
290+
291+
Base.ncodeunits(x::StaticString{T,N}) where {T,N} = N
292+
293+
Base.codeunit(x::StaticString, i::Integer) = x.codeunits[i]
294+
295+
Base.codeunit(x::StaticString{T}) where {T} = T
296+
297+
Base.isvalid(x::StaticString{UInt8,N}, i::Int) where {N} =
298+
utf8_isvalid(j -> @inbounds(x.codeunits[j]), N, i; zeroterminated = true)
299+
300+
Base.iterate(x::StaticString{UInt8,N}, i::Int = 1) where {N} =
301+
utf8_iterate(x, j -> @inbounds(x.codeunits[j]), N, i; zeroterminated = true)
302+
294303
function Base.isvalid(x::StaticString{UInt32,N}, i::Int) where {N}
295304
i < 1 && return false
296305
cs = x.codeunits

0 commit comments

Comments
 (0)