From 0d5246a9b80936c54613ffade460c6e3a01c4527 Mon Sep 17 00:00:00 2001 From: furukawa Date: Sun, 7 Nov 2021 11:41:34 +0900 Subject: [PATCH 1/6] init --- data/librispeech.py | 43 +++++++------------------------------------ data/loaders.py | 20 ++++---------------- main.py | 10 +++++----- modules/audio/cpc.py | 2 ++ 4 files changed, 18 insertions(+), 57 deletions(-) diff --git a/data/librispeech.py b/data/librispeech.py index 06a3a92..7b8def3 100644 --- a/data/librispeech.py +++ b/data/librispeech.py @@ -4,6 +4,7 @@ import numpy as np from torch.utils.data import Dataset from collections import defaultdict +from glob import glob def default_loader(path): @@ -37,24 +38,14 @@ def __init__( self.root = root self.opt = opt - self.file_list, self.speaker_dict = flist_reader(flist) + self.file_list = glob(os.path.join(root, '*.npy')) self.loader = loader self.audio_length = audio_length - self.mean = -1456218.7500 - self.std = 135303504.0 - def __getitem__(self, index): - speaker_id, dir_id, sample_id = self.file_list[index] - filename = "{}-{}-{}".format(speaker_id, dir_id, sample_id) - audio, samplerate = self.loader( - os.path.join(self.root, speaker_id, dir_id, "{}.flac".format(filename)) - ) - - assert ( - samplerate == 16000 - ), "Watch out, samplerate is not consistent throughout the dataset!" + filename = self.file_list[index] + audio = torch.from_numpy(np.load(filename)).unsqueeze(0) # discard last part that is not a full 10ms max_length = audio.size(1) // 160 * 160 @@ -66,41 +57,21 @@ def __getitem__(self, index): audio = audio[:, start_idx : start_idx + self.audio_length] # normalize the audio samples - audio = (audio - self.mean) / self.std - return audio, filename, speaker_id, start_idx + return audio, filename, start_idx def __len__(self): return len(self.file_list) - def get_audio_by_speaker(self, speaker_id, batch_size): - batch_size = min(len(self.speaker_dict[speaker_id]), batch_size) - batch = torch.zeros(batch_size, 1, self.audio_length) - for idx in range(batch_size): - batch[idx, 0, :], _, _, _ = self.__getitem__( - self.speaker_dict[speaker_id][idx] - ) - - return batch - def get_full_size_test_item(self, index): """ get audio samples that cover the full length of the input files used for testing the phone classification performance """ - speaker_id, dir_id, sample_id = self.file_list[index] - filename = "{}-{}-{}".format(speaker_id, dir_id, sample_id) - audio, samplerate = self.loader( - os.path.join(self.root, speaker_id, dir_id, "{}.flac".format(filename)) - ) - - assert ( - samplerate == 16000 - ), "Watch out, samplerate is not consistent throughout the dataset!" 
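[review note, PATCH 1/6] Switching from the FLAC flist to glob'ing pre-extracted .npy arrays silently drops two safety nets: the 16 kHz sample-rate assert and the hard-coded mean/std normalization. It also leaves the random crop unguarded: np.random.choice over np.arange(160, max_length - self.audio_length, 160) raises ValueError for any cached clip shorter than audio_length + 320 samples. A minimal one-off sanity check for the cache, assuming the .npy files hold 1-D float waveforms at 16 kHz (check_npy_cache and its cache_dir argument are mine, not part of this patch):

    import os
    from glob import glob

    import numpy as np

    def check_npy_cache(cache_dir, audio_length=20480):
        """Flag cached clips too short for the random 20480-sample crop."""
        for path in sorted(glob(os.path.join(cache_dir, "*.npy"))):
            audio = np.load(path)
            assert audio.ndim == 1, f"{path}: expected 1-D waveform, got {audio.shape}"
            # __getitem__ draws start_idx from np.arange(160, max_length - audio_length, 160),
            # which is empty (and np.random.choice raises) when the range bound is <= 160
            max_length = len(audio) // 160 * 160
            if max_length - audio_length <= 160:
                print(f"too short for a random crop: {path} ({len(audio)} samples)")

Separately, in get_full_size_test_item below, audio = audio[:max_length] slices dim 0 of the (1, T) tensor after unsqueeze(0), i.e. it is a no-op; audio = audio[:, :max_length] is presumably what was meant.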
+ filename = self.file_list[index] + audio = torch.from_numpy(np.load(filename)).unsqueeze(0) ## discard last part that is not a full 10ms max_length = audio.size(1) // 160 * 160 audio = audio[:max_length] - audio = (audio - self.mean) / self.std - return audio, filename diff --git a/data/loaders.py b/data/loaders.py index a24da38..f92df7b 100644 --- a/data/loaders.py +++ b/data/loaders.py @@ -9,10 +9,7 @@ def librispeech_loader(opt, num_workers=16): print("Using Train / Val Split") train_dataset = LibriDataset( opt, - os.path.join( - opt.data_input_dir, - "LibriSpeech/train-clean-100", - ), + opt.data_input_dir, os.path.join( opt.data_input_dir, "LibriSpeech100_labels_split/train_val_train.txt" ), @@ -20,10 +17,7 @@ def librispeech_loader(opt, num_workers=16): test_dataset = LibriDataset( opt, - os.path.join( - opt.data_input_dir, - "LibriSpeech/train-clean-100", - ), + opt.data_input_dir, os.path.join( opt.data_input_dir, "LibriSpeech100_labels_split/train_val_val.txt" ), @@ -33,10 +27,7 @@ def librispeech_loader(opt, num_workers=16): print("Using Train+Val / Test Split") train_dataset = LibriDataset( opt, - os.path.join( - opt.data_input_dir, - "LibriSpeech/train-clean-100", - ), + opt.data_input_dir, os.path.join( opt.data_input_dir, "LibriSpeech100_labels_split/train_split.txt" ), @@ -44,10 +35,7 @@ def librispeech_loader(opt, num_workers=16): test_dataset = LibriDataset( opt, - os.path.join( - opt.data_input_dir, - "LibriSpeech/train-clean-100", - ), + opt.data_input_dir, os.path.join( opt.data_input_dir, "LibriSpeech100_labels_split/test_split.txt" ), diff --git a/main.py b/main.py index 51cc927..9fcfc49 100644 --- a/main.py +++ b/main.py @@ -29,7 +29,7 @@ def train(args, model, optimizer, writer): ) total_step = len(train_loader) - print_idx = 100 + print_idx = 10 # at which step to validate training validation_idx = 1000 @@ -40,12 +40,12 @@ def train(args, model, optimizer, writer): global_step = 0 for epoch in range(args.start_epoch, args.start_epoch + args.num_epochs): loss_epoch = 0 - for step, (audio, filename, _, start_idx) in enumerate(train_loader): + for step, (audio, filename, start_idx) in enumerate(train_loader): start_time = time.time() - if step % validation_idx == 0: - validate_speakers(args, train_dataset, model, optimizer, epoch, step, global_step, writer) + # if step % validation_idx == 0: + # validate_speakers(args, train_dataset, model, optimizer, epoch, step, global_step, writer) audio = audio.to(args.device) @@ -132,7 +132,7 @@ def main(_run, _log): args.time = time.ctime() # Device configuration - args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") args.current_epoch = args.start_epoch diff --git a/modules/audio/cpc.py b/modules/audio/cpc.py index 611f4d8..8f2a457 100644 --- a/modules/audio/cpc.py +++ b/modules/audio/cpc.py @@ -58,7 +58,9 @@ def get_latent_representations(self, x): def forward(self, x): + # x: (b, 1, 20480) z, c = self.get_latent_representations(x) + # z: (b, 128, 512) c: (b, 128, 256) loss, accuracy = self.loss.get(x, z, c) return loss, accuracy, z, c From 37df41e6788ab9a85844e4a277b9b9f7933f1509 Mon Sep 17 00:00:00 2001 From: furukawa Date: Sun, 7 Nov 2021 23:25:53 +0900 Subject: [PATCH 2/6] import --- modules/audio/infonce.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/audio/infonce.py b/modules/audio/infonce.py index e889058..35d98f3 100644 --- a/modules/audio/infonce.py +++ b/modules/audio/infonce.py @@ -3,7 
+3,7 @@ Calculates the 'Info Noise-Contrastive-Estimation' as explained by Van den Oord et al. (2018), implementation by Bas Veeling & Sindy Lowe """ - +import numpy as np import torch import torch.nn as nn From a5d79d5bf1708e78b266460f89bad0ffaa1a4615 Mon Sep 17 00:00:00 2001 From: furukawa Date: Sun, 7 Nov 2021 23:26:17 +0900 Subject: [PATCH 3/6] add resnet implementation --- modules/audio/resnet.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 modules/audio/resnet.py diff --git a/modules/audio/resnet.py b/modules/audio/resnet.py new file mode 100644 index 0000000..399ede8 --- /dev/null +++ b/modules/audio/resnet.py @@ -0,0 +1,27 @@ +import torch.nn as nn +import torchvision.models as models + + +class ResNetSimCLR(nn.Module): + + def __init__(self, base_model, out_dim): + super(ResNetSimCLR, self).__init__() + self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, num_classes=out_dim), + "resnet50": models.resnet50(pretrained=True)} + + self.backbone = self._get_basemodel(base_model) + self.backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, + bias=False) + num_features = self.backbone.fc.in_features + self.backbone.fc = nn.Linear(num_features, out_dim) + dim_mlp = self.backbone.fc.in_features + + # add mlp projection head + self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc) + + def _get_basemodel(self, model_name): + model = self.resnet_dict[model_name] + return model + + def forward(self, x): + return self.backbone(x) \ No newline at end of file From 3a809f632b74b5428a63649073d16803c9c2a445 Mon Sep 17 00:00:00 2001 From: furukawa Date: Sun, 7 Nov 2021 23:29:50 +0900 Subject: [PATCH 4/6] implement triplet loss --- data/librispeech.py | 27 +++++++++++++++++++++------ main.py | 14 ++++++++++---- modules/audio/cpc.py | 18 ++++++++++++------ modules/audio/model.py | 6 +++--- 4 files changed, 46 insertions(+), 19 deletions(-) diff --git a/data/librispeech.py b/data/librispeech.py index 7b8def3..3252089 100644 --- a/data/librispeech.py +++ b/data/librispeech.py @@ -42,26 +42,41 @@ def __init__( self.loader = loader self.audio_length = audio_length + self.mel_length = 320 + self.melspec = torchaudio.transforms.MelSpectrogram(n_fft=1024, win_length=1024, hop_length=256, + f_min=125, f_max=7600, n_mels=80, power=1, normalized=True) def __getitem__(self, index): filename = self.file_list[index] - audio = torch.from_numpy(np.load(filename)).unsqueeze(0) + raw_data = np.load(filename) + audio = torch.from_numpy(raw_data).unsqueeze(0) # discard last part that is not a full 10ms max_length = audio.size(1) // 160 * 160 - start_idx = np.random.choice( - np.arange(160, max_length - self.audio_length - 0, 160) - ) - audio = audio[:, start_idx : start_idx + self.audio_length] + start_idx = np.random.choice(np.arange(160, max_length - self.mel_length * 256 - 0, 160)) + pos_idx = np.random.choice(np.arange(160, max_length - self.mel_length * 256 - 0, 160)) + neg_audio_idx = np.random.choice(np.arange(len(self.file_list))) + neg_audio = torch.from_numpy(np.load(self.file_list[neg_audio_idx])).unsqueeze(0) + neg_max_length = neg_audio.size(1) // 160 * 160 + neg_idx = np.random.choice(np.arange(160, neg_max_length - self.mel_length * 256 - 0, 160)) + + anc_audio = audio[:, start_idx : start_idx + self.mel_length * 256] + pos_audio = audio[:, pos_idx : pos_idx + self.mel_length * 256] + neg_audio = neg_audio[:, neg_idx : neg_idx + self.mel_length * 256] + audio = audio[:, :self.audio_length] + anc_mel, 
pos_mel, neg_mel = self.get_mel(anc_audio), self.get_mel(pos_audio), self.get_mel(neg_audio) # normalize the audio samples - return audio, filename, start_idx + return audio, filename, start_idx, anc_mel, pos_mel, neg_mel def __len__(self): return len(self.file_list) + def get_mel(self, audio): + return torch.log(self.melspec(audio).clamp(min=1e-10))[:, :, :320] + def get_full_size_test_item(self, index): """ get audio samples that cover the full length of the input files diff --git a/main.py b/main.py index 9fcfc49..a68f23d 100644 --- a/main.py +++ b/main.py @@ -40,7 +40,7 @@ def train(args, model, optimizer, writer): global_step = 0 for epoch in range(args.start_epoch, args.start_epoch + args.num_epochs): loss_epoch = 0 - for step, (audio, filename, start_idx) in enumerate(train_loader): + for step, (audio, filename, start_idx, anc_mel, pos_mel, neg_mel) in enumerate(train_loader): start_time = time.time() @@ -50,10 +50,12 @@ def train(args, model, optimizer, writer): audio = audio.to(args.device) # forward - loss = model(audio) + nce_loss, trpl_loss = model(audio, anc_mel, pos_mel, neg_mel) + nce_loss = nce_loss.mean() + trpl_loss = trpl_loss.mean() # accumulate losses for all GPUs - loss = loss.mean() + loss = nce_loss + trpl_loss # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10) @@ -71,18 +73,22 @@ def train(args, model, optimizer, writer): examples_per_second = args.batch_size / (time.time() - start_time) print( "[Epoch {}/{}] Train step {:04d}/{:04d} \t Examples/s = {:.2f} \t " - "Loss = {:.4f} \t Time/step = {:.4f}".format( + "Loss = {:.4f} \t NCE = {:.4f} \t Triplet = {:.4f} \t Time/step = {:.4f}".format( epoch, args.num_epochs, step, len(train_loader), examples_per_second, loss, + nce_loss, + trpl_loss, time.time() - start_time, ) ) writer.add_scalar("Loss/train_step", loss, global_step) + writer.add_scalar("Loss/nce", nce_loss, global_step) + writer.add_scalar("Loss/triplet", trpl_loss, global_step) loss_epoch += loss global_step += 1 diff --git a/modules/audio/cpc.py b/modules/audio/cpc.py index 8f2a457..990f9f1 100644 --- a/modules/audio/cpc.py +++ b/modules/audio/cpc.py @@ -2,6 +2,7 @@ from .encoder import Encoder from .autoregressor import Autoregressor from .infonce import InfoNCE +from .resnet import ResNetSimCLR class CPC(torch.nn.Module): def __init__( @@ -23,6 +24,8 @@ def __init__( self.autoregressor = Autoregressor(args, input_dim=genc_hidden, hidden_dim=gar_hidden) self.loss = InfoNCE(args, gar_hidden, genc_hidden) + self.static_encoder = ResNetSimCLR('resnet50', 256) + self.triplet_loss = torch.nn.TripletMarginLoss(5.0) def get_latent_size(self, input_size): x = torch.zeros(input_size).to(self.args.device) @@ -51,16 +54,19 @@ def get_latent_representations(self, x): z = z.permute(0, 2, 1) # swap L and C # calculate latent representation from the autoregressor - c = self.autoregressor(z) + # c = self.autoregressor(z) # TODO checked - return z, c + return z - def forward(self, x): + def forward(self, x, anc, pos, neg): # x: (b, 1, 20480) - z, c = self.get_latent_representations(x) + z = self.get_latent_representations(x) + ca, cp, cn = self.static_encoder(anc), self.static_encoder(pos), self.static_encoder(neg) + ca, cp, cn = ca.squeeze(), cp.squeeze(), cn.squeeze() # z: (b, 128, 512) c: (b, 128, 256) - loss, accuracy = self.loss.get(x, z, c) - return loss, accuracy, z, c + loss, accuracy = self.loss.get(x, z, ca.unsqueeze(1).expand(-1, 128, -1)) + trpl_loss = self.triplet_loss(ca, cp, cn) + return loss, trpl_loss, accuracy, z, ca diff --git 
a/modules/audio/model.py b/modules/audio/model.py index 9e2f70e..6705b0c 100644 --- a/modules/audio/model.py +++ b/modules/audio/model.py @@ -26,8 +26,8 @@ def __init__( gar_hidden, ) - def forward(self, x): + def forward(self, x, anc_mel, pos_mel, neg_mel): """Forward through the network""" - loss, accuracy, _, z = self.model(x) - return loss + loss, trpl_loss, accuracy, _, z = self.model(x, anc_mel, pos_mel, neg_mel) + return loss, trpl_loss From 5c1e1a0ea9d121ae9dbd974dfd13cbab415830bb Mon Sep 17 00:00:00 2001 From: furukawa Date: Fri, 19 Nov 2021 22:36:00 +0900 Subject: [PATCH 5/6] impl tsne --- data/librispeech.py | 64 +++++++++++++++++++++++++++++++-------------- main.py | 9 ++++--- 2 files changed, 49 insertions(+), 24 deletions(-) diff --git a/data/librispeech.py b/data/librispeech.py index 7b8def3..cc25a3c 100644 --- a/data/librispeech.py +++ b/data/librispeech.py @@ -5,40 +5,53 @@ from torch.utils.data import Dataset from collections import defaultdict from glob import glob +import pandas as pd + +csv_input = pd.read_csv(filepath_or_buffer='/groups/1/gcc50521/furukawa/musicnet_metadata.csv', sep=",") +genre_to_id = { + 'Solo Piano': 0, 'String Quartet': 1, 'Accompanied Violin': 2, 'Piano Quartet': 3, 'Accompanied Cello': 4, + 'String Sextet': 5, 'Piano Trio': 6, 'Piano Quintet': 7, 'Wind Quintet': 8, 'Horn Piano Trio': 9, 'Wind Octet': 10, + 'Clarinet-Cello-Piano Trio': 11, 'Pairs Clarinet-Horn-Bassoon': 12, 'Clarinet Quintet': 13, 'Solo Cello': 14, + 'Accompanied Clarinet': 15, 'Solo Violin': 16, 'Violin and Harpsichord': 17, 'Viola Quinte': 18, 'Solo Flute': 19 +} +id_to_genre = {} +for idx, row in csv_input.iterrows(): + genre = row['ensemble'] + song_id = str(row['id']) + id_to_genre[song_id] = genre def default_loader(path): return torchaudio.load(path, normalization=False) -def default_flist_reader(flist): - item_list = [] +def default_flist_reader(root_dir): speaker_dict = defaultdict(list) - index = 0 - with open(flist, "r") as rf: - for line in rf.readlines(): - speaker_id, dir_id, sample_id = line.replace("\n", "").split("-") - item_list.append((speaker_id, dir_id, sample_id)) - speaker_dict[speaker_id].append(index) - index += 1 + item_list = [] + for index, x in enumerate(sorted(glob(os.path.join(root_dir, '*.npy')))): + filename = x.split('/')[-1] + speaker_id = id_to_genre[filename[:4]] + item_list.append(speaker_id) + speaker_dict[speaker_id].append(index) - return item_list, speaker_dict + return speaker_dict, item_list class LibriDataset(Dataset): def __init__( - self, - opt, - root, - flist, - audio_length=20480, - flist_reader=default_flist_reader, - loader=default_loader, + self, + opt, + root, + flist, + audio_length=20480, + flist_reader=default_flist_reader, + loader=default_loader, ): self.root = root self.opt = opt - self.file_list = glob(os.path.join(root, '*.npy')) + self.file_list = sorted(glob(os.path.join(root, '*.npy'))) + self.speaker_dict, self.item_list = flist_reader(root) self.loader = loader self.audio_length = audio_length @@ -46,6 +59,7 @@ def __init__( def __getitem__(self, index): filename = self.file_list[index] audio = torch.from_numpy(np.load(filename)).unsqueeze(0) + speaker_id = self.item_list[index] # discard last part that is not a full 10ms max_length = audio.size(1) // 160 * 160 @@ -54,14 +68,24 @@ def __getitem__(self, index): np.arange(160, max_length - self.audio_length - 0, 160) ) - audio = audio[:, start_idx : start_idx + self.audio_length] + audio = audio[:, start_idx: start_idx + self.audio_length] # normalize the audio 
samples - return audio, filename, start_idx + return audio, filename, speaker_id, start_idx def __len__(self): return len(self.file_list) + def get_audio_by_speaker(self, speaker_id, batch_size): + batch_size = min(len(self.speaker_dict[speaker_id]), batch_size) + batch = torch.zeros(batch_size, 1, self.audio_length) + for idx in range(batch_size): + batch[idx, 0, :], _, _, _ = self.__getitem__( + self.speaker_dict[speaker_id][idx] + ) + + return batch + def get_full_size_test_item(self, index): """ get audio samples that cover the full length of the input files diff --git a/main.py b/main.py index 9fcfc49..d07d9be 100644 --- a/main.py +++ b/main.py @@ -40,12 +40,12 @@ def train(args, model, optimizer, writer): global_step = 0 for epoch in range(args.start_epoch, args.start_epoch + args.num_epochs): loss_epoch = 0 - for step, (audio, filename, start_idx) in enumerate(train_loader): + for step, (audio, filename, _, start_idx) in enumerate(train_loader): start_time = time.time() - # if step % validation_idx == 0: - # validate_speakers(args, train_dataset, model, optimizer, epoch, step, global_step, writer) + if step % validation_idx == 0: + validate_speakers(args, train_dataset, model, optimizer, epoch, step, global_step, writer) audio = audio.to(args.device) @@ -113,7 +113,8 @@ def train(args, model, optimizer, writer): save_model(args, model, optimizer, best=True) # save current model state - save_model(args, model, optimizer) + if args.current_epoch % 50 == 0: + save_model(args, model, optimizer) args.current_epoch += 1 From 8e10eddc080f1ad7eb5639ae5a895bf06a829e6a Mon Sep 17 00:00:00 2001 From: Kohei Furukawa Date: Thu, 25 Nov 2021 12:31:59 +0900 Subject: [PATCH 6/6] configure logistic regression --- data/librispeech.py | 10 ++++++---- main.py | 12 +++--------- modules/audio/cpc.py | 18 ++++++------------ modules/audio/model.py | 6 +++--- modules/audio/speaker_loss.py | 3 ++- train_classifier.sh | 5 +++-- validation/validate_speakers.py | 4 ++-- 7 files changed, 25 insertions(+), 33 deletions(-) diff --git a/data/librispeech.py b/data/librispeech.py index f3a721d..b381291 100644 --- a/data/librispeech.py +++ b/data/librispeech.py @@ -12,7 +12,8 @@ 'Solo Piano': 0, 'String Quartet': 1, 'Accompanied Violin': 2, 'Piano Quartet': 3, 'Accompanied Cello': 4, 'String Sextet': 5, 'Piano Trio': 6, 'Piano Quintet': 7, 'Wind Quintet': 8, 'Horn Piano Trio': 9, 'Wind Octet': 10, 'Clarinet-Cello-Piano Trio': 11, 'Pairs Clarinet-Horn-Bassoon': 12, 'Clarinet Quintet': 13, 'Solo Cello': 14, - 'Accompanied Clarinet': 15, 'Solo Violin': 16, 'Violin and Harpsichord': 17, 'Viola Quinte': 18, 'Solo Flute': 19 + 'Accompanied Clarinet': 15, 'Solo Violin': 16, 'Violin and Harpsichord': 17, 'Viola Quintet': 18, 'Solo Flute': 19, + 'Wind and Strings Octet': 20 } id_to_genre = {} for idx, row in csv_input.iterrows(): @@ -55,9 +56,6 @@ def __init__( self.loader = loader self.audio_length = audio_length - self.mel_length = 320 - self.melspec = torchaudio.transforms.MelSpectrogram(n_fft=1024, win_length=1024, hop_length=256, - f_min=125, f_max=7600, n_mels=80, power=1, normalized=True) def __getitem__(self, index): filename = self.file_list[index] @@ -67,6 +65,10 @@ def __getitem__(self, index): # discard last part that is not a full 10ms max_length = audio.size(1) // 160 * 160 + start_idx = np.random.choice( + np.arange(160, max_length - self.audio_length - 0, 160) + ) + audio = audio[:, start_idx: start_idx + self.audio_length] # normalize the audio samples diff --git a/main.py b/main.py index 3f77c55..d07d9be 
100644 --- a/main.py +++ b/main.py @@ -50,12 +50,10 @@ def train(args, model, optimizer, writer): audio = audio.to(args.device) # forward - nce_loss, trpl_loss = model(audio, anc_mel, pos_mel, neg_mel) - nce_loss = nce_loss.mean() - trpl_loss = trpl_loss.mean() + loss = model(audio) # accumulate losses for all GPUs - loss = nce_loss + trpl_loss + loss = loss.mean() # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10) @@ -73,22 +71,18 @@ def train(args, model, optimizer, writer): examples_per_second = args.batch_size / (time.time() - start_time) print( "[Epoch {}/{}] Train step {:04d}/{:04d} \t Examples/s = {:.2f} \t " - "Loss = {:.4f} \t NCE = {:.4f} \t Triplet = {:.4f} \t Time/step = {:.4f}".format( + "Loss = {:.4f} \t Time/step = {:.4f}".format( epoch, args.num_epochs, step, len(train_loader), examples_per_second, loss, - nce_loss, - trpl_loss, time.time() - start_time, ) ) writer.add_scalar("Loss/train_step", loss, global_step) - writer.add_scalar("Loss/nce", nce_loss, global_step) - writer.add_scalar("Loss/triplet", trpl_loss, global_step) loss_epoch += loss global_step += 1 diff --git a/modules/audio/cpc.py b/modules/audio/cpc.py index 990f9f1..8f2a457 100644 --- a/modules/audio/cpc.py +++ b/modules/audio/cpc.py @@ -2,7 +2,6 @@ from .encoder import Encoder from .autoregressor import Autoregressor from .infonce import InfoNCE -from .resnet import ResNetSimCLR class CPC(torch.nn.Module): def __init__( @@ -24,8 +23,6 @@ def __init__( self.autoregressor = Autoregressor(args, input_dim=genc_hidden, hidden_dim=gar_hidden) self.loss = InfoNCE(args, gar_hidden, genc_hidden) - self.static_encoder = ResNetSimCLR('resnet50', 256) - self.triplet_loss = torch.nn.TripletMarginLoss(5.0) def get_latent_size(self, input_size): x = torch.zeros(input_size).to(self.args.device) @@ -54,19 +51,16 @@ def get_latent_representations(self, x): z = z.permute(0, 2, 1) # swap L and C # calculate latent representation from the autoregressor - # c = self.autoregressor(z) + c = self.autoregressor(z) # TODO checked - return z + return z, c - def forward(self, x, anc, pos, neg): + def forward(self, x): # x: (b, 1, 20480) - z = self.get_latent_representations(x) - ca, cp, cn = self.static_encoder(anc), self.static_encoder(pos), self.static_encoder(neg) - ca, cp, cn = ca.squeeze(), cp.squeeze(), cn.squeeze() + z, c = self.get_latent_representations(x) # z: (b, 128, 512) c: (b, 128, 256) - loss, accuracy = self.loss.get(x, z, ca.unsqueeze(1).expand(-1, 128, -1)) - trpl_loss = self.triplet_loss(ca, cp, cn) - return loss, trpl_loss, accuracy, z, ca + loss, accuracy = self.loss.get(x, z, c) + return loss, accuracy, z, c diff --git a/modules/audio/model.py b/modules/audio/model.py index 6705b0c..9e2f70e 100644 --- a/modules/audio/model.py +++ b/modules/audio/model.py @@ -26,8 +26,8 @@ def __init__( gar_hidden, ) - def forward(self, x, anc_mel, pos_mel, neg_mel): + def forward(self, x): """Forward through the network""" - loss, trpl_loss, accuracy, _, z = self.model(x, anc_mel, pos_mel, neg_mel) - return loss, trpl_loss + loss, accuracy, _, z = self.model(x) + return loss diff --git a/modules/audio/speaker_loss.py b/modules/audio/speaker_loss.py index 18becc3..8e8a005 100644 --- a/modules/audio/speaker_loss.py +++ b/modules/audio/speaker_loss.py @@ -2,6 +2,7 @@ import torch from data import loaders +from data.librispeech import genre_to_id, id_to_genre class Speaker_Loss(nn.Module): def __init__(self, args, hidden_dim, calc_accuracy): @@ -38,7 +39,7 @@ def calc_supervised_speaker_loss(self, c, filename): targets 
= torch.zeros(len(filename)).long() for idx, _ in enumerate(filename): - targets[idx] = self.speaker_id_dict[filename[idx].split("-")[0]] + targets[idx] = torch.tensor(genre_to_id[id_to_genre[filename[idx].split("/")[-1][:4]]]) targets = targets.to(self.args.device).squeeze() # forward pass diff --git a/train_classifier.sh b/train_classifier.sh index 9182787..85675ae 100755 --- a/train_classifier.sh +++ b/train_classifier.sh @@ -2,8 +2,9 @@ python -m testing.logistic_regression_speaker \ with \ - model_path=./logs/cpc_audio_baseline \ - model_num=299 \ + data_input_dir=/groups/1/gcc50521/furukawa/musicnet_npy_10sec \ + model_path=/groups/1/gcc50521/furukawa/cpc_logs/26 \ + model_num=450 \ fp16=False # python -m testing.logistic_regression_phones \ diff --git a/validation/validate_speakers.py b/validation/validate_speakers.py index 7d601a9..a7801cd 100644 --- a/validation/validate_speakers.py +++ b/validation/validate_speakers.py @@ -25,7 +25,7 @@ def tsne(args, features): def validate_speakers(args, dataset, model, optimizer, epoch, step, global_step, writer): - max_speakers = 10 + max_speakers = 20 batch_size = 40 input_size = (args.batch_size, 1, 20480) @@ -40,7 +40,7 @@ def validate_speakers(args, dataset, model, optimizer, epoch, step, global_step, labels = torch.zeros(max_speakers, batch_size).to(args.device) for idx, speaker_idx in enumerate(dataset.speaker_dict): - if idx == 10: + if idx == 20: break model_in = dataset.get_audio_by_speaker(speaker_idx, batch_size=batch_size)
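[series note] After PATCH 6/6 the model path is back to plain CPC + InfoNCE (the ResNet static encoder and triplet loss from PATCH 4/6 are reverted), and the MusicNet genre labels only feed the linear probe and t-SNE validation. The lookup in speaker_loss.py assumes the basename of every cached file starts with a 4-digit MusicNet id. A self-contained sketch of that mapping for quick verification (the two filenames are hypothetical; genre_to_id and id_to_genre are the dicts defined in data/librispeech.py):

    import torch

    from data.librispeech import genre_to_id, id_to_genre

    # hypothetical cache entries; real ids come from musicnet_metadata.csv
    filenames = [
        "/groups/1/gcc50521/furukawa/musicnet_npy_10sec/1727.npy",
        "/groups/1/gcc50521/furukawa/musicnet_npy_10sec/2303.npy",
    ]
    targets = torch.zeros(len(filenames)).long()
    for idx, path in enumerate(filenames):
        musicnet_id = path.split("/")[-1][:4]  # first 4 chars of the basename
        targets[idx] = genre_to_id[id_to_genre[musicnet_id]]
    print(targets)  # genre class ids in 0..20

One mismatch worth flagging: PATCH 6/6 grows genre_to_id to 21 classes ('Viola Quinte' is corrected to 'Viola Quintet', and 'Wind and Strings Octet': 20 is added), but validate_speakers.py still breaks at idx == 20 with max_speakers = 20, so one genre never reaches the t-SNE plot.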