@@ -38,12 +38,15 @@ def __init__(self, tokens: str):
38
38
id = int (info [0 ])
39
39
else :
40
40
token , id = info [0 ], int (info [1 ])
41
+ assert token not in self .token2id , token
41
42
self .token2id [token ] = id
42
43
43
- self .blank_id = self .token2id ["<blk>" ]
44
- self .sos_id = self .token2id ["<sos>" ]
45
- self .eos_id = self .token2id ["<eos>" ]
46
- self .oov_id = self .token2id ["<unk>" ]
44
+ # Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md
45
+ self .pad_id = self .token2id ["_" ] # padding
46
+ self .sos_id = self .token2id ["^" ] # beginning of an utterance (bos)
47
+ self .eos_id = self .token2id ["$" ] # end of an utterance (eos)
48
+ self .space_id = self .token2id [" " ] # word separator (whitespace)
49
+
47
50
self .vocab_size = len (self .token2id )
48
51
49
52
def texts_to_token_ids (
@@ -80,13 +83,11 @@ def texts_to_token_ids(
80
83
81
84
token_ids = []
82
85
for t in tokens :
83
- if t in self .token2id :
84
- token_ids .append (self .token2id [t ])
85
- else :
86
- token_ids .append (self .oov_id )
86
+ assert t in self .token2id , t
87
+ token_ids .append (self .token2id [t ])
87
88
88
89
if intersperse_blank :
89
- token_ids = intersperse (token_ids , self .blank_id )
90
+ token_ids = intersperse (token_ids , self .pad_id )
90
91
if add_sos :
91
92
token_ids = [self .sos_id ] + token_ids
92
93
if add_eos :
@@ -122,13 +123,11 @@ def tokens_to_token_ids(
122
123
for tokens in tokens_list :
123
124
token_ids = []
124
125
for t in tokens :
125
- if t in self .token2id :
126
- token_ids .append (self .token2id [t ])
127
- else :
128
- token_ids .append (self .oov_id )
126
+ assert t in self .token2id , t
127
+ token_ids .append (self .token2id [t ])
129
128
130
129
if intersperse_blank :
131
- token_ids = intersperse (token_ids , self .blank_id )
130
+ token_ids = intersperse (token_ids , self .pad_id )
132
131
if add_sos :
133
132
token_ids = [self .sos_id ] + token_ids
134
133
if add_eos :
0 commit comments