@@ -13,80 +13,60 @@ fn _byte_pair_merge(
1313 piece : & [ u8 ] ,
1414) -> Vec < ( usize , Rank ) > {
1515 // This is a vector of (start, rank).
16- // The rank is of the byte pair starting at position start.
17- // The rank of the last item in the vector is not a valid value.
18- let mut parts: Vec < ( usize , Rank ) > = ( 0 ..piece. len ( ) + 1 ) . map ( |i| ( i, Rank :: MAX ) ) . collect ( ) ;
16+ // The rank is of the pair starting at position start.
17+ let mut parts = Vec :: with_capacity ( piece. len ( ) + 1 ) ;
18+
19+ // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
20+ // the way we currently do, this is equivalent. An easy way to break this would be to decouple
21+ // merge priority from token index or to prevent specific token merges.
22+ let mut min_rank: ( Rank , usize ) = ( Rank :: MAX , usize:: MAX ) ;
23+ for i in 0 ..piece. len ( ) - 1 {
24+ let rank = * ranks. get ( & piece[ i..i + 2 ] ) . unwrap_or ( & Rank :: MAX ) ;
25+ if rank < min_rank. 0 {
26+ min_rank = ( rank, i) ;
27+ }
28+ parts. push ( ( i, rank) ) ;
29+ }
30+ parts. push ( ( piece. len ( ) - 1 , Rank :: MAX ) ) ;
31+ parts. push ( ( piece. len ( ) , Rank :: MAX ) ) ;
1932
2033 let get_rank = {
2134 #[ inline( always) ]
22- |parts : & Vec < ( usize , Rank ) > , start_idx : usize , skip : usize | {
23- if ( start_idx + skip + 2 ) < parts. len ( ) {
24- ranks
25- . get ( & piece[ parts[ start_idx] . 0 ..parts[ start_idx + skip + 2 ] . 0 ] )
26- . copied ( )
35+ |parts : & Vec < ( usize , Rank ) > , i : usize | {
36+ if ( i + 3 ) < parts. len ( ) {
37+ // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
38+ // parts[i + 1], see comment in the main loop.
39+ * ranks
40+ . get ( & piece[ parts[ i] . 0 ..parts[ i + 3 ] . 0 ] )
41+ . unwrap_or ( & Rank :: MAX )
2742 } else {
28- None
43+ Rank :: MAX
2944 }
3045 }
3146 } ;
3247
33- // We look up the ranks once in the beginning and iteratively update
34- // them during each merge, which reduces the number of rank lookups.
35- for i in 0 ..parts. len ( ) - 2 {
36- match get_rank ( & parts, i, 0 ) {
37- Some ( rank) => {
38- // Rank::MAX is a sentinel value and cannot be a valid rank
39- debug_assert ! ( rank != Rank :: MAX ) ;
40- parts[ i] . 1 = rank;
41- }
42- None => {
43- continue ;
44- }
45- } ;
46- }
47-
4848 // If you have n parts and m merges, this does O(mn) work.
4949 // We could do something with a heap and do O(m log n) work.
50- // It is important to consider that n is often small (<100), and as such
51- // the cache-locality benefits outweigh the algorithmic complexity downsides
52- // of the `parts` vector data structure above.
53-
54- // Note that we hash bytes, not token pairs. As long as we train BPE the way we
55- // currently do, this is equivalent. An easy way to break this would be to decouple
56- // merge priority from token index or to prevent specific token merges.
57- loop {
58- if parts. len ( ) == 1 {
59- break ;
50+ // n is often very small so considerations like cache-locality outweigh the algorithmic
51+ // complexity downsides of the `parts` vector.
52+ while min_rank. 0 != Rank :: MAX {
53+ let i = min_rank. 1 ;
54+ // Update parts[i] and parts[i - 1] before removing parts[i + 1], since
55+ // `parts.remove(i + 1)` will thrash the cache.
56+ if i > 0 {
57+ parts[ i - 1 ] . 1 = get_rank ( & parts, i - 1 ) ;
6058 }
59+ parts[ i] . 1 = get_rank ( & parts, i) ;
60+ parts. remove ( i + 1 ) ;
6161
62- // Rank::MAX is a sentinel rank value allowing us to
63- // take the min more quickly
64- let mut min_rank: ( Rank , usize ) = ( Rank :: MAX , 0 ) ;
62+ min_rank = ( Rank :: MAX , usize:: MAX ) ;
6563 for ( i, & ( _, rank) ) in parts[ ..parts. len ( ) - 1 ] . iter ( ) . enumerate ( ) {
6664 if rank < min_rank. 0 {
6765 min_rank = ( rank, i) ;
6866 }
6967 }
70-
71- if min_rank. 0 != Rank :: MAX {
72- let i = min_rank. 1 ;
73-
74- // NOTE: We are about to remove parts[i + 1]. We do not do it
75- // yet because there are cache-locality benefits to updating
76- // parts[i] and parts[i-1] before removing, which could thrash
77- // the cache. Thus, we update the rank calculation by skipping over
78- // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
79- parts[ i] . 1 = get_rank ( & parts, i, 1 ) . unwrap_or ( Rank :: MAX ) ;
80- if i > 0 {
81- parts[ i - 1 ] . 1 = get_rank ( & parts, i - 1 , 1 ) . unwrap_or ( Rank :: MAX ) ;
82- }
83-
84- parts. remove ( i + 1 ) ;
85- } else {
86- break ;
87- }
8868 }
89-
69+
9070 parts
9171}
9272
0 commit comments