-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcorpus.py
94 lines (82 loc) · 3.25 KB
/
corpus.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os, random
import torch
import pickle
import util, vocab
class Corpus(object):
    """Train/valid/test splits of eventized MIDI melodies plus their vocabulary.

    Each split (``trains``, ``valids``, ``tests``) is a list with one entry per
    vocabulary channel; each entry is a list of ``(event_indices, meta_dict)``
    tuples as produced by :meth:`eventize`.
    """

    def __init__(self):
        self.vocab = None        # vocabulary object; set by load / load_from_corpus
        self.vocab_fname = ""    # path the vocabulary is saved to / loaded from
        self.my_fname = ""       # path this corpus is pickled to by save()
        # Placeholder splits (one empty channel) until eventize/load fills them.
        self.trains = [[]]
        self.valids = [[]]
        self.tests = [[]]

    def eventize(self, path, args, min_len=8, max_len=400):
        """Convert every MIDI file under ``path`` into per-channel event sequences.

        Args:
            path: directory handed to ``util.getmidfiles``; must exist.
            args: options namespace; reads ``use_metaf``, ``synth_data`` and
                ``measure_tokens``.
            min_len: melodies with fewer tokens than this are skipped.
            max_len: melodies with more tokens than this are skipped.
                (Both bounds count any start/end tokens; the defaults are the
                previously hard-coded values.)

        Returns:
            A list with one list per channel (``self.vocab.num_channels``).
            Each inner list holds ``(event_indices, meta_dict)`` tuples sorted
            longest-first; equal-length entries keep file order.
        """
        assert os.path.exists(path), "No such path: %s" % (path,)
        num_channels = self.vocab.num_channels
        melodies = [[] for _ in range(num_channels)]

        # The metadata table depends only on (path, args), so fetch it once per
        # split rather than once per file and channel.
        # NOTE(review): assumes util.get_meta_dicts has no per-call side effects.
        meta_dicts = util.get_meta_dicts(path, args) if args.use_metaf else None

        for f in util.getmidfiles(path):
            # Metadata does not depend on the channel: resolve (or skip) once per file.
            if args.use_metaf:
                basename = os.path.basename(f)
                if basename not in meta_dicts:
                    print("Skipping %s" % (basename,))
                    continue
                meta_dict = meta_dicts[basename]
            else:
                meta_dict = {'f': f}

            for c in range(num_channels):
                if args.synth_data:
                    # NOTE(review): requires an 'origs' entry, which only the
                    # use_metaf metadata can supply; {'f': f} has none.
                    melody = ([('start', 'start')]
                              + [(str(n), d) for n, d in meta_dict['origs']]
                              + [('end', 'end')])
                else:
                    melody, _ = self.vocab.mid2orig(
                        f, include_measure_boundaries=args.measure_tokens, channel=c)
                if len(melody) < min_len or len(melody) > max_len:
                    print("Skipping %s" % (f,))
                    continue
                events = [self.vocab.orig2e[c][orig].i for orig in melody]
                melodies[c].append((events, meta_dict))

        # list.sort is stable even with reverse=True, so ties keep file order.
        for channel_melodies in melodies:
            channel_melodies.sort(key=lambda item: len(item[0]), reverse=True)
        return melodies

    def save(self):
        """Pickle the splits to ``self.my_fname`` and save the vocabulary to ``self.vocab_fname``."""
        info_dict = {
            "trains": self.trains,
            "valids": self.valids,
            "tests": self.tests,
            "vocab_fname": self.vocab_fname,
        }
        # Pickles are binary data: text mode fails on Python 3 and can corrupt
        # the stream on platforms that translate newlines.
        with open(self.my_fname, "wb") as f:
            pickle.dump(info_dict, f)
        self.vocab.save(self.vocab_fname)

    @classmethod
    def load(cls, filename):
        """Rebuild a Corpus from a pickle written by :meth:`save`.

        The vocabulary is reloaded with ``vocab.PitchDurationVocab.load`` from the
        ``vocab_fname`` recorded in the pickle. ``my_fname`` is set to
        ``filename`` so the loaded corpus can be saved again.
        """
        # SECURITY NOTE: pickle.load can execute arbitrary code; only load
        # corpus files you created or otherwise trust.
        with open(filename, "rb") as f:
            info_dict = pickle.load(f)
        corpus = cls()
        print("Load %s" % (corpus,))
        corpus.my_fname = filename
        corpus.trains = info_dict["trains"]
        corpus.valids = info_dict["valids"]
        corpus.tests = info_dict["tests"]
        corpus.vocab_fname = info_dict["vocab_fname"]
        corpus.vocab = vocab.PitchDurationVocab.load(corpus.vocab_fname)
        return corpus

    @classmethod
    def load_from_corpus(cls, vocab, vocab_fname, corpus_fname, args):
        """Load the corpus pickled at ``corpus_fname``, or build and save it.

        If ``corpus_fname`` is an existing file it is loaded with :meth:`load`.
        In that case ``vocab``, ``vocab_fname`` and ``args`` are ignored and the
        vocabulary comes from the path stored in the pickle. Otherwise the
        ``train/``, ``valid/`` and ``test/`` subdirectories of ``args.path`` are
        eventized with ``vocab``, and the new corpus is saved to
        ``corpus_fname`` (and the vocabulary to ``vocab_fname``).
        """
        if os.path.isfile(corpus_fname):
            print("Loading existing Corpus %s" % (corpus_fname,))
            return cls.load(corpus_fname)
        print("Creating new Corpus %s" % (corpus_fname,))
        corpus = cls()
        corpus.vocab = vocab
        corpus.vocab_fname = vocab_fname
        corpus.my_fname = corpus_fname
        corpus.trains = corpus.eventize(os.path.join(args.path, 'train/'), args)
        corpus.valids = corpus.eventize(os.path.join(args.path, 'valid/'), args)
        corpus.tests = corpus.eventize(os.path.join(args.path, 'test/'), args)
        print("Saving new Corpus %s" % (corpus_fname,))
        corpus.save()
        return corpus