
Commit 86f3e09

Got Python bindings working.
1 parent fe9f594 commit 86f3e09

8 files changed: +347 additions, -100 deletions


bindings/python/Cargo.toml

Lines changed: 1 addition & 0 deletions

```diff
@@ -19,6 +19,7 @@ numpy = "0.23"
 ndarray = "0.16"
 itertools = "0.12"
 rustc-hash = "2.1.1"
+compact_str = { version = "0.8.1", features = ["serde"] }

 [dependencies.tokenizers]
 path = "../../tokenizers"
```
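
The only dependency added here is compact_str, which provides CompactString (a string type that stores short strings inline rather than heap-allocating them) and the ToCompactString conversion trait used throughout the decoder changes below. A minimal standalone sketch of the conversions involved, not part of this commit:

```rust
// Minimal sketch (not part of this commit) of the compact_str API the diff relies on.
use compact_str::{CompactString, ToCompactString};

fn main() {
    // Typical subword tokens are short, so they fit in CompactString's inline storage.
    let token: CompactString = "##ing".to_compact_string();
    assert_eq!(token, "##ing");

    // The pyo3 layer still hands Python plain Strings, so convert back at the boundary.
    let for_python: String = token.to_string();
    assert_eq!(for_python, "##ing");
}
```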

bindings/python/src/decoders.rs

Lines changed: 57 additions & 20 deletions

```diff
@@ -3,6 +3,7 @@ use std::sync::{Arc, RwLock};
 use crate::pre_tokenizers::from_string;
 use crate::tokenizer::PyTokenizer;
 use crate::utils::PyPattern;
+use compact_str::ToCompactString;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
@@ -91,7 +92,10 @@ impl PyDecoder {
 }

 impl Decoder for PyDecoder {
-    fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+    fn decode_chain<T: ToCompactString>(
+        &self,
+        tokens: Vec<T>,
+    ) -> tk::Result<Vec<impl ToCompactString>> {
         self.decoder.decode_chain(tokens)
     }
 }
@@ -139,7 +143,12 @@ impl PyDecoder {
     /// :obj:`str`: The decoded string
     #[pyo3(text_signature = "(self, tokens)")]
     fn decode(&self, tokens: Vec<String>) -> PyResult<String> {
-        ToPyResult(self.decoder.decode(tokens)).into()
+        ToPyResult(
+            self.decoder
+                .decode(tokens)
+                .map(|t| t.to_compact_string().to_string()),
+        )
+        .into()
     }

     fn __repr__(&self) -> PyResult<String> {
@@ -235,12 +244,12 @@ pub struct PyWordPieceDec {}
 impl PyWordPieceDec {
     #[getter]
     fn get_prefix(self_: PyRef<Self>) -> String {
-        getter!(self_, WordPiece, prefix.clone())
+        getter!(self_, WordPiece, prefix.clone().to_string())
     }

     #[setter]
     fn set_prefix(self_: PyRef<Self>, prefix: String) {
-        setter!(self_, WordPiece, prefix, prefix);
+        setter!(self_, WordPiece, prefix, prefix.to_compact_string());
     }

     #[getter]
@@ -256,7 +265,10 @@ impl PyWordPieceDec {
     #[new]
     #[pyo3(signature = (prefix = String::from("##"), cleanup = true), text_signature = "(self, prefix=\"##\", cleanup=True)")]
     fn new(prefix: String, cleanup: bool) -> (Self, PyDecoder) {
-        (PyWordPieceDec {}, WordPiece::new(prefix, cleanup).into())
+        (
+            PyWordPieceDec {},
+            WordPiece::new(prefix.to_compact_string(), cleanup).into(),
+        )
     }
 }
@@ -526,22 +538,33 @@ impl CustomDecoder {
 }

 impl Decoder for CustomDecoder {
-    fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
+    fn decode<T: ToCompactString>(&self, tokens: Vec<T>) -> tk::Result<impl ToCompactString> {
+        let tokens: Vec<String> = tokens
+            .into_iter()
+            .map(|t| t.to_compact_string().to_string())
+            .collect();
         Python::with_gil(|py| {
             let decoded = self
                 .inner
                 .call_method(py, "decode", (tokens,), None)?
-                .extract(py)?;
+                .extract::<String>(py)?;
             Ok(decoded)
         })
     }

-    fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+    fn decode_chain<T: ToCompactString>(
+        &self,
+        tokens: Vec<T>,
+    ) -> tk::Result<Vec<impl ToCompactString>> {
+        let tokens: Vec<String> = tokens
+            .into_iter()
+            .map(|t| t.to_compact_string().to_string())
+            .collect();
         Python::with_gil(|py| {
             let decoded = self
                 .inner
                 .call_method(py, "decode_chain", (tokens,), None)?
-                .extract(py)?;
+                .extract::<Vec<String>>(py)?;
             Ok(decoded)
         })
     }
@@ -595,10 +618,21 @@ where
 }

 impl Decoder for PyDecoderWrapper {
-    fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+    fn decode_chain<T: ToCompactString>(
+        &self,
+        tokens: Vec<T>,
+    ) -> tk::Result<Vec<impl ToCompactString>> {
         match self {
-            PyDecoderWrapper::Wrapped(inner) => inner.read().unwrap().decode_chain(tokens),
-            PyDecoderWrapper::Custom(inner) => inner.read().unwrap().decode_chain(tokens),
+            PyDecoderWrapper::Wrapped(inner) => inner
+                .read()
+                .unwrap()
+                .decode_chain(tokens)
+                .map(|v| v.into_iter().map(|t| t.to_compact_string()).collect()),
+            PyDecoderWrapper::Custom(inner) => inner
+                .read()
+                .unwrap()
+                .decode_chain(tokens)
+                .map(|v| v.into_iter().map(|t| t.to_compact_string()).collect()),
         }
     }
 }
@@ -663,14 +697,17 @@ impl PyDecodeStream {

     #[pyo3(signature = (tokenizer, id), text_signature = "(self, tokenizer, id)")]
     fn step(&mut self, tokenizer: &PyTokenizer, id: u32) -> PyResult<Option<String>> {
-        ToPyResult(tk::tokenizer::step_decode_stream(
-            &tokenizer.tokenizer,
-            id,
-            self.skip_special_tokens,
-            &mut self.ids,
-            &mut self.prefix,
-            &mut self.prefix_index,
-        ))
+        ToPyResult(
+            tk::tokenizer::step_decode_stream(
+                &tokenizer.tokenizer,
+                id,
+                self.skip_special_tokens,
+                &mut self.ids,
+                &mut self.prefix.to_compact_string(),
+                &mut self.prefix_index,
+            )
+            .map(|o| o.map(|s| s.to_string())),
+        )
         .into()
     }
 }
```
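
The recurring pattern in these hunks is the new generic signature: decode_chain now accepts any token type implementing ToCompactString and returns an opaque impl ToCompactString item type, with plain String kept only at the pyo3 boundary. The sketch below is a standalone illustration of that shape; the free-standing decode_chain function is hypothetical and is not the tokenizers crate's actual Decoder trait:

```rust
// Standalone illustration (not the crate's actual Decoder trait) of the generic
// shape used above: accept anything convertible to a CompactString, work on
// CompactString internally, and return an opaque `impl ToCompactString` item type.
use compact_str::{CompactString, ToCompactString};

fn decode_chain<T: ToCompactString>(tokens: Vec<T>) -> Vec<impl ToCompactString> {
    tokens
        .into_iter()
        .map(|t| t.to_compact_string())
        .collect::<Vec<CompactString>>()
}

fn main() {
    // Callers holding Vec<String> (e.g. values extracted from Python) still compile,
    // because ToCompactString is implemented for anything that implements Display.
    let decoded = decode_chain(vec![String::from("Hel"), String::from("lo")]);
    let joined: String = decoded
        .iter()
        .map(|t| t.to_compact_string().to_string())
        .collect();
    assert_eq!(joined, "Hello");
}
```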

bindings/python/src/encoding.rs

Lines changed: 5 additions & 1 deletion

```diff
@@ -127,7 +127,11 @@ impl PyEncoding {
     /// :obj:`List[str]`: The list of tokens
     #[getter]
     fn get_tokens(&self) -> Vec<String> {
-        self.encoding.get_tokens().to_vec()
+        self.encoding
+            .get_tokens()
+            .into_iter()
+            .map(|x| x.to_string())
+            .collect()
     }

     /// The generated word indices.
```
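
This getter changes because the core Encoding presumably stores its tokens as CompactString after this commit, while pyo3 still converts a Vec<String> into the Python List[str]. A hypothetical standalone equivalent of that boundary conversion (the tokens_for_python helper is illustrative only):

```rust
// Hypothetical standalone equivalent of the getter above, assuming token storage
// is CompactString: convert to Vec<String> only at the Python boundary.
use compact_str::CompactString;

fn tokens_for_python(tokens: &[CompactString]) -> Vec<String> {
    tokens.iter().map(|t| t.to_string()).collect()
}

fn main() {
    let tokens = [CompactString::new("Hello"), CompactString::new("world")];
    assert_eq!(tokens_for_python(&tokens), vec!["Hello", "world"]);
}
```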
