tokenizer.go
package llmgo

import (
"encoding/binary"
"errors"
)
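
// Tokenizer converts between text and GPT-2 token ids. tokenTable maps a
// token id to its string, and trie maps token byte sequences back to ids
// for encoding.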
type Tokenizer struct {
vocabSize uint32
tokenTable []string
trie trie
init bool
}
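
// newTokenizer builds a Tokenizer from an in-memory vocabulary; each token's
// id is its index in vocab.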
func newTokenizer(vocab []string) Tokenizer {
tokenizer := Tokenizer{
vocabSize: uint32(len(vocab)),
tokenTable: vocab,
trie: newTrie(vocab),
init: true,
}
return tokenizer
}
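
// NewTokenizer loads a tokenizer from a binary vocabulary file: a 256-word
// uint32 header (magic 20240328 at [0], version 1 at [1], vocabulary size at
// [2]) followed by each token as a one-byte length and that many bytes.
// Open is assumed to be a helper defined elsewhere in this package.
//
// A sketch of typical usage (the file name is only an assumption):
//
//	tok, err := NewTokenizer("gpt2_tokenizer.bin")
//	if err != nil { ... }
//	ids, _ := tok.Encode("hello world")
//	text, _ := tok.Decode(ids)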
func NewTokenizer(filename string) (Tokenizer, error) {
f, err := Open(filename)
if err != nil {
return Tokenizer{}, err
}
defer f.Close()
header := make([]uint32, 256)
if err := binary.Read(f, binary.LittleEndian, header); err != nil {
return Tokenizer{}, err
}
if header[0] != 20240328 || header[1] != 1 {
return Tokenizer{}, errors.New("incorrect header for tokenizer")
}
tok := Tokenizer{
vocabSize: header[2],
tokenTable: make([]string, header[2]),
init: true,
trie: newTrie(nil),
}
var length byte
for i := range tok.tokenTable {
if err := binary.Read(f, binary.LittleEndian, &length); err != nil {
return tok, err
}
		// length is a single unsigned byte, so zero is the only invalid value.
		if length == 0 {
			return tok, errors.New("tokenizer file contains a zero-length token")
		}
tokenBytes := make([]byte, length)
if err := binary.Read(f, binary.LittleEndian, tokenBytes); err != nil {
return tok, err
}
tok.tokenTable[i] = string(tokenBytes)
tok.trie.Insert(tokenBytes, int32(i))
}
return tok, nil
}
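
// Decode concatenates the strings of the given token ids, skipping GPT2_EOT
// (the end-of-text id, assumed to be defined elsewhere in this package).
// It returns an error if any id is out of range.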
func (t Tokenizer) Decode(tokens []int32) (string, error) {
s := ""
for _, token := range tokens {
		if token < 0 || token >= int32(len(t.tokenTable)) {
			return "", errors.New("invalid token id")
		}
if token != GPT2_EOT {
s += t.tokenTable[token]
}
}
return s, nil
}
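
// Encode converts text into token ids using greedy longest-match lookups in
// the vocabulary trie. Bytes that match no token are emitted as GPT2_EOT.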
func (t Tokenizer) Encode(text string) ([]int32, error) {
_, tokens := t.trie.Tokenize([]byte(text))
return tokens, nil
}
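
// trie is a byte-level prefix tree mapping token byte sequences to their ids.
// A node with end == true marks the last byte of a token whose id is data.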
type trie struct {
children map[byte]*trie
data int32
end bool
key byte
}
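
// newTrie builds a trie from a vocabulary, inserting each word with its index
// as the token id.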
func newTrie(data []string) trie {
t := trie{
children: map[byte]*trie{},
end: false,
}
for i, word := range data {
t.Insert([]byte(word), int32(i))
}
return t
}
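
// Insert adds word to the trie and records data (the token id) on the node
// reached by its final byte.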
func (t *trie) Insert(word []byte, data int32) error {
cur := t
if len(word) == 0 {
return errors.New("zero length word not supported")
}
var index byte
for i := 0; i < len(word); i++ {
		index = word[i]
if cur.children[index] == nil {
cur.children[index] = &trie{
children: map[byte]*trie{},
}
}
cur = cur.children[index]
}
cur.end = true
cur.data = data
cur.key = index
return nil
}
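
// Tokenize splits input into the longest vocabulary entries it can match,
// returning the byte slices and their token ids. At each position it walks
// the trie as far as possible, remembering the last complete token seen; a
// byte that matches nothing is emitted on its own with the id GPT2_EOT.
// For example, with the vocabulary {"a", "ab", "b"}, the input "abb" splits
// into "ab" and "b" because "ab" is the longest match at the start.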
func (t *trie) Tokenize(input []byte) ([][]byte, []int32) {
var cur = t
var token = GPT2_EOT
endIdx, next := 1, 0
split, tokens := make([][]byte, 0), make([]int32, 0)
for len(input) != 0 {
switch {
case next == len(input), cur.children[input[next]] == nil:
split = append(split, input[:endIdx])
tokens = append(tokens, token)
input = input[endIdx:]
token = GPT2_EOT
cur = t
next = 0
endIdx = 1
default:
cur = cur.children[input[next]]
next += 1
if cur.end {
endIdx = next
token = cur.data
}
}
}
return split, tokens
}