Skip to content

Commit 25b3622

Browse files
committed
add exps for ELMo on RACE
1 parent 7106032 commit 25b3622

File tree

4 files changed

+405
-5
lines changed

4 files changed

+405
-5
lines changed

README.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,10 @@ This dir includes the ELMo model. Here we present experiments on the RACE datase
4141

4242
#### Run on RACE
4343

44-
TO DO!
44+
| epochs | batch size | max input length | Accuracy (%) on dev | Accuracy (%) on test |Accuracy (%) on middle test |Accuracy (%) on high test | device |
45+
| :--------: |:------:|:------:|:------:|:------:|:------:|:------:|:------:|
46+
| 3 | 8 | 512/32 | 39.39 | 38.57 | 38.23 | 39.02 | 1 GTX 1080 |
47+
| 25 | 8 | 512/32 | \ | \ | \ | \ | 1 GTX 1080 |
4548

4649
#### Run on SQuAD
4750

bilm-tf/bilm/data.py

100644100755
+9-4
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ class Batcher(object):
194194
'''
195195
Batch sentences of tokenized text into character id matrices.
196196
'''
197-
def __init__(self, lm_vocab_file: str, max_token_length: int):
197+
def __init__(self, lm_vocab_file: str, max_token_length: int, max_sentence_length: int = 0):
198198
'''
199199
lm_vocab_file = the language model vocabulary file (one line per
200200
token)
@@ -204,6 +204,7 @@ def __init__(self, lm_vocab_file: str, max_token_length: int):
204204
lm_vocab_file, max_token_length
205205
)
206206
self._max_token_length = max_token_length
207+
self._max_sentence_length = max_sentence_length
207208

208209
def batch_sentences(self, sentences: List[List[str]]):
209210
'''
@@ -213,16 +214,20 @@ def batch_sentences(self, sentences: List[List[str]]):
213214
'''
214215
n_sentences = len(sentences)
215216
max_length = max(len(sentence) for sentence in sentences) + 2
217+
if self._max_sentence_length > 0:
218+
max_length = min(self._max_sentence_length, max_length)
216219

217220
X_char_ids = np.zeros(
218221
(n_sentences, max_length, self._max_token_length),
219-
dtype=np.int64
222+
dtype=np.int32
220223
)
221224

222225
for k, sent in enumerate(sentences):
223-
length = len(sent) + 2
226+
length = len(sent)
227+
length = min(length, max_length - 2)
224228
char_ids_without_mask = self._lm_vocab.encode_chars(
225-
sent, split=False)
229+
sent[:length], split=False)
230+
length += 2
226231
# add one so that 0 is the mask value
227232
X_char_ids[k, :length, :] = char_ids_without_mask + 1
228233

0 commit comments

Comments
 (0)