diff --git a/tencentpretrain/utils/dataset.py b/tencentpretrain/utils/dataset.py index ce42ac3..f17ec06 100755 --- a/tencentpretrain/utils/dataset.py +++ b/tencentpretrain/utils/dataset.py @@ -229,6 +229,7 @@ class MlmDataset(Dataset): def __init__(self, args, vocab, tokenizer): super(MlmDataset, self).__init__(args, vocab, tokenizer) self.full_sentences = args.full_sentences + self.json_format_corpus = args.json_format_corpus def worker(self, proc_id, start, end): print("Worker %d is building dataset ... " % proc_id) @@ -243,7 +244,13 @@ def worker(self, proc_id, start, end): pos += 1 while True: line = f.readline() + if self.json_format_corpus: + data = json.loads(line) + line = data.get("text", "") + data.get("content", "") + pos += 1 + if len(line) < 5: + continue document = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(line)) + [self.vocab.get(SEP_TOKEN)]