-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgen_util.py
47 lines (42 loc) · 1.84 KB
/
gen_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
from torch.autograd import Variable
import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import similarity, util
def get_events(sv, args, midf, meta_dict):
    """Return, for every channel, the list of event indices to condition on.

    The sequence of "orig" items per channel comes from one of two sources:

    * ``args.synth_data`` true: ``meta_dict['origs']`` framed by the
      ``.original`` of ``sv.special_events['start']`` and
      ``sv.special_events['end']`` (same sequence for every channel).
    * otherwise: the first element returned by
      ``sv.mid2orig(args.condition_piece, include_measure_boundaries=args.measure_tokens, channel=<channel>)``.

    Each item ``o`` is then mapped to ``sv.orig2e[channel][o].i``.

    Args:
        sv: object exposing ``num_channels``, ``special_events``, ``orig2e``
            and ``mid2orig``.
        args: options object; reads ``synth_data`` and, when that is false,
            ``condition_piece`` and ``measure_tokens``.
        midf: not used by this function; kept for interface compatibility.
        meta_dict: mapping; ``'origs'`` is read when ``args.synth_data`` is true.

    Returns:
        A list with one list of indices per channel.
    """
    per_channel = []
    for ch in range(sv.num_channels):
        if not args.synth_data:
            origs, _ = sv.mid2orig(args.condition_piece, include_measure_boundaries=args.measure_tokens, channel=ch)
        else:
            head = [sv.special_events['start'].original]
            origs = head + meta_dict['origs'] + [sv.special_events['end'].original]
        per_channel.append([sv.orig2e[ch][o].i for o in origs])
    return per_channel
def get_conditions(sv, args, meta_dict):
    """Build the per-channel conditioning inputs.

    Every channel is conditioned on the same ``meta_dict['measure_sdm']``
    object, each wrapped in its own single-element list.

    Args:
        sv: object exposing ``num_channels``.
        args: options object; only ``args.use_metaf`` is read.
        meta_dict: mapping that must contain ``'measure_sdm'`` when
            ``args.use_metaf`` is true.

    Returns:
        A list holding one ``[meta_dict['measure_sdm']]`` list per channel
        (an empty list when ``sv.num_channels`` is 0).

    Raises:
        NotImplementedError: if ``args.use_metaf`` is false and there is at
            least one channel. No alternative condition source is
            implemented. Previously this path crashed with an
            UnboundLocalError on an unassigned local.
        KeyError: if ``'measure_sdm'`` is missing from ``meta_dict``.
    """
    conditions = []
    for _ in range(sv.num_channels):
        if not args.use_metaf:
            # TODO See equivalent comment in main.py
            raise NotImplementedError(
                "get_conditions only supports args.use_metaf=True; "
                "no condition source is implemented otherwise")
        conditions.append([meta_dict['measure_sdm']])
    return conditions
def make_data_dict(args, sv, need_conditions):
    """Build the initial inputs for generation.

    Args:
        args: options object; only ``args.cuda`` is read.
        sv: object exposing ``num_channels`` and
            ``special_events["start"].i``.
        need_conditions: when true, also create one zero-valued condition
            tensor per channel.

    Returns:
        A dict (not a tuple) with:

        * ``"data"``: a one-element list holding a 1x1 LongTensor
          ``Variable`` whose value is ``sv.special_events["start"].i``.
        * ``"conditions"`` (present only if ``need_conditions``): a list of
          ``sv.num_channels`` 1x1 zero-valued LongTensor ``Variable``s.

        When ``args.cuda`` is set, every tensor in these lists is moved to
        the GPU.
    """
    start_idx = sv.special_events["start"].i
    # NOTE(review): "data" holds a single start token even when there are
    # several channels, while "conditions" has one entry per channel --
    # confirm this asymmetry is what the caller expects.
    # volatile=True is the pre-0.4 PyTorch way of disabling autograd; newer
    # releases ignore it with a warning.
    data = {
        "data": [Variable(torch.LongTensor(1, 1).zero_() + start_idx, volatile=True)],
    }
    if need_conditions:
        data["conditions"] = [
            Variable(torch.LongTensor(1, 1).zero_(), volatile=True)
            for _ in range(sv.num_channels)
        ]
    if args.cuda:
        # Move the tensors that were actually created. Indexing "data" by
        # channel number raised IndexError for num_channels > 1, and left
        # the start token on the CPU for num_channels == 0.
        for key in ("data", "conditions"):
            for var in data.get(key, ()):
                var.data = var.data.cuda()
    return data