@@ -3,6 +3,7 @@ use std::sync::{Arc, RwLock};
3
3
use crate :: pre_tokenizers:: from_string;
4
4
use crate :: tokenizer:: PyTokenizer ;
5
5
use crate :: utils:: PyPattern ;
6
+ use compact_str:: ToCompactString ;
6
7
use pyo3:: exceptions;
7
8
use pyo3:: prelude:: * ;
8
9
use pyo3:: types:: * ;
@@ -91,7 +92,10 @@ impl PyDecoder {
91
92
}
92
93
93
94
impl Decoder for PyDecoder {
94
- fn decode_chain ( & self , tokens : Vec < String > ) -> tk:: Result < Vec < String > > {
95
+ fn decode_chain < T : ToCompactString > (
96
+ & self ,
97
+ tokens : Vec < T > ,
98
+ ) -> tk:: Result < Vec < impl ToCompactString > > {
95
99
self . decoder . decode_chain ( tokens)
96
100
}
97
101
}
@@ -139,7 +143,12 @@ impl PyDecoder {
139
143
/// :obj:`str`: The decoded string
140
144
#[ pyo3( text_signature = "(self, tokens)" ) ]
141
145
fn decode ( & self , tokens : Vec < String > ) -> PyResult < String > {
142
- ToPyResult ( self . decoder . decode ( tokens) ) . into ( )
146
+ ToPyResult (
147
+ self . decoder
148
+ . decode ( tokens)
149
+ . map ( |t| t. to_compact_string ( ) . to_string ( ) ) ,
150
+ )
151
+ . into ( )
143
152
}
144
153
145
154
fn __repr__ ( & self ) -> PyResult < String > {
@@ -235,12 +244,12 @@ pub struct PyWordPieceDec {}
235
244
impl PyWordPieceDec {
236
245
#[ getter]
237
246
fn get_prefix ( self_ : PyRef < Self > ) -> String {
238
- getter ! ( self_, WordPiece , prefix. clone( ) )
247
+ getter ! ( self_, WordPiece , prefix. clone( ) . to_string ( ) )
239
248
}
240
249
241
250
#[ setter]
242
251
fn set_prefix ( self_ : PyRef < Self > , prefix : String ) {
243
- setter ! ( self_, WordPiece , prefix, prefix) ;
252
+ setter ! ( self_, WordPiece , prefix, prefix. to_compact_string ( ) ) ;
244
253
}
245
254
246
255
#[ getter]
@@ -256,7 +265,10 @@ impl PyWordPieceDec {
256
265
#[ new]
257
266
#[ pyo3( signature = ( prefix = String :: from( "##" ) , cleanup = true ) , text_signature = "(self, prefix=\" ##\" , cleanup=True)" ) ]
258
267
fn new ( prefix : String , cleanup : bool ) -> ( Self , PyDecoder ) {
259
- ( PyWordPieceDec { } , WordPiece :: new ( prefix, cleanup) . into ( ) )
268
+ (
269
+ PyWordPieceDec { } ,
270
+ WordPiece :: new ( prefix. to_compact_string ( ) , cleanup) . into ( ) ,
271
+ )
260
272
}
261
273
}
262
274
@@ -526,22 +538,33 @@ impl CustomDecoder {
526
538
}
527
539
528
540
impl Decoder for CustomDecoder {
529
- fn decode ( & self , tokens : Vec < String > ) -> tk:: Result < String > {
541
+ fn decode < T : ToCompactString > ( & self , tokens : Vec < T > ) -> tk:: Result < impl ToCompactString > {
542
+ let tokens: Vec < String > = tokens
543
+ . into_iter ( )
544
+ . map ( |t| t. to_compact_string ( ) . to_string ( ) )
545
+ . collect ( ) ;
530
546
Python :: with_gil ( |py| {
531
547
let decoded = self
532
548
. inner
533
549
. call_method ( py, "decode" , ( tokens, ) , None ) ?
534
- . extract ( py) ?;
550
+ . extract :: < String > ( py) ?;
535
551
Ok ( decoded)
536
552
} )
537
553
}
538
554
539
- fn decode_chain ( & self , tokens : Vec < String > ) -> tk:: Result < Vec < String > > {
555
+ fn decode_chain < T : ToCompactString > (
556
+ & self ,
557
+ tokens : Vec < T > ,
558
+ ) -> tk:: Result < Vec < impl ToCompactString > > {
559
+ let tokens: Vec < String > = tokens
560
+ . into_iter ( )
561
+ . map ( |t| t. to_compact_string ( ) . to_string ( ) )
562
+ . collect ( ) ;
540
563
Python :: with_gil ( |py| {
541
564
let decoded = self
542
565
. inner
543
566
. call_method ( py, "decode_chain" , ( tokens, ) , None ) ?
544
- . extract ( py) ?;
567
+ . extract :: < Vec < String > > ( py) ?;
545
568
Ok ( decoded)
546
569
} )
547
570
}
@@ -595,10 +618,21 @@ where
595
618
}
596
619
597
620
impl Decoder for PyDecoderWrapper {
598
- fn decode_chain ( & self , tokens : Vec < String > ) -> tk:: Result < Vec < String > > {
621
+ fn decode_chain < T : ToCompactString > (
622
+ & self ,
623
+ tokens : Vec < T > ,
624
+ ) -> tk:: Result < Vec < impl ToCompactString > > {
599
625
match self {
600
- PyDecoderWrapper :: Wrapped ( inner) => inner. read ( ) . unwrap ( ) . decode_chain ( tokens) ,
601
- PyDecoderWrapper :: Custom ( inner) => inner. read ( ) . unwrap ( ) . decode_chain ( tokens) ,
626
+ PyDecoderWrapper :: Wrapped ( inner) => inner
627
+ . read ( )
628
+ . unwrap ( )
629
+ . decode_chain ( tokens)
630
+ . map ( |v| v. into_iter ( ) . map ( |t| t. to_compact_string ( ) ) . collect ( ) ) ,
631
+ PyDecoderWrapper :: Custom ( inner) => inner
632
+ . read ( )
633
+ . unwrap ( )
634
+ . decode_chain ( tokens)
635
+ . map ( |v| v. into_iter ( ) . map ( |t| t. to_compact_string ( ) ) . collect ( ) ) ,
602
636
}
603
637
}
604
638
}
@@ -663,14 +697,17 @@ impl PyDecodeStream {
663
697
664
698
#[ pyo3( signature = ( tokenizer, id) , text_signature = "(self, tokenizer, id)" ) ]
665
699
fn step ( & mut self , tokenizer : & PyTokenizer , id : u32 ) -> PyResult < Option < String > > {
666
- ToPyResult ( tk:: tokenizer:: step_decode_stream (
667
- & tokenizer. tokenizer ,
668
- id,
669
- self . skip_special_tokens ,
670
- & mut self . ids ,
671
- & mut self . prefix ,
672
- & mut self . prefix_index ,
673
- ) )
700
+ ToPyResult (
701
+ tk:: tokenizer:: step_decode_stream (
702
+ & tokenizer. tokenizer ,
703
+ id,
704
+ self . skip_special_tokens ,
705
+ & mut self . ids ,
706
+ & mut self . prefix . to_compact_string ( ) ,
707
+ & mut self . prefix_index ,
708
+ )
709
+ . map ( |o| o. map ( |s| s. to_string ( ) ) ) ,
710
+ )
674
711
. into ( )
675
712
}
676
713
}
0 commit comments