From 076a35f2063ce2944c0e942240d7b8170f242064 Mon Sep 17 00:00:00 2001 From: Jaskaran Singh Date: Wed, 24 Mar 2021 18:33:50 +0530 Subject: [PATCH 1/7] mono loader added --- src/configs.py | 4 +++- src/train2.py | 62 +++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 54 insertions(+), 12 deletions(-) diff --git a/src/configs.py b/src/configs.py index c34489e..b5db660 100644 --- a/src/configs.py +++ b/src/configs.py @@ -40,7 +40,9 @@ class config: Language=Gujrati #select the language - + mono=True + mono_train_path="./" + mono_test_path="./" def get_all_params_dict(config): params = {} diff --git a/src/train2.py b/src/train2.py index f1211ef..c659e96 100644 --- a/src/train2.py +++ b/src/train2.py @@ -15,6 +15,32 @@ import wandb import random + +class MonoData(Dataset): + def __init__(self,path): + self.path=path + self.file=open(path+'/transcription.txt','r',encoding='UTF-8').read().split("\n") + + def __len__(self): + return len(self.file) + + def __getitem__(self,index): + audio,text=self.file[index].split(' ',1) + audio=self.path+'/Audios/'+audio + return audio,text + +def mono_collate_fn(batch, tokenizer): + + speech_lis = [sf.read(elem[0])[0] for elem in batch] + text_lis = [elem[1] for elem in batch] + + input_values = tokenizer(speech_lis, return_tensors="pt", + padding='longest').input_values + + labels, label_lengths = tokenizer.batch_tokenize(text_lis) + + return (input_values.to(config.device), labels.to(config.device), label_lengths.to(config.device)) + def find_lengths(logits, pad_id: int) -> torch.FloatTensor: """ Function to find lengths of output sequences @@ -141,10 +167,14 @@ def compute_metric(model, tokenizer, test_dataset): show_sample_no = random.randint(1, len(test_dataset)-1) with torch.no_grad(): for i, d in enumerate(pbar): - - input_values = tokenizer(d["speech"], return_tensors="pt", + + if not config.mono: + input_values = tokenizer(d["speech"], return_tensors="pt", padding='longest').input_values.to(config.device) - + else: + input_values = tokenizer(sf.read(d["speech"])[0], return_tensors="pt", + padding='longest').input_values.to(config.device) + logits = model(input_values).logits predicted_ids = torch.argmax(logits, dim=-1).cpu() @@ -193,11 +223,16 @@ def collate_fn(batch, tokenizer): params = {'batch_size': config.BATCH_SIZE,} print("running on ", config.device) - - train_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="train[2%:]", writer_batch_size=1000) - val_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="train[:2%]", writer_batch_size=1000) - test_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="test", writer_batch_size=1000) + if not config.mono: + train_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="train[2%:]", writer_batch_size=1000) + val_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="train[:2%]", writer_batch_size=1000) + test_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="test", writer_batch_size=1000) + else: + train_dataset=MonoData(path=config.mono_train_path) + test_dataset=MonoData(path=config.mono_test_path) + val_dataset=test_dataset + if config.use_monolingual: mono_dataset = load_dataset(config.data_loading_script, data_dir=config.monolingual_data_dir, split="train", writer_batch_size=1000) mono_dataloader = torch.utils.data.DataLoader(dataset=mono_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params) @@ -205,10 +240,15 @@ def collate_fn(batch, tokenizer): mono_dataloader = None if(config.train): - train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params) - val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params) - train_model(model, tokenizer, train_dataloader, val_dataloader, test_dataset, mono_dataloader) - + if not config.mono: + train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params) + val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params) + train_model(model, tokenizer, train_dataloader, val_dataloader, test_dataset, mono_dataloader) + else: + train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, collate_fn= lambda b: mono_collate_fn(b, tokenizer), **params) + val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, collate_fn= lambda b: mono_collate_fn(b, tokenizer), **params) + train_model(model, tokenizer, train_dataloader, val_dataloader, test_dataset, mono_dataloader) + if(config.eval): print(compute_metric(model, tokenizer, test_dataset)) From 7c37bc2c334b9d3db5327170651ad30222e3d664 Mon Sep 17 00:00:00 2001 From: Jaskaran Singh Date: Wed, 24 Mar 2021 13:43:56 +0000 Subject: [PATCH 2/7] fix --- src/configs.py | 19 ++++++++++--------- src/train2.py | 14 ++++++++------ 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/configs.py b/src/configs.py index b5db660..20841f0 100644 --- a/src/configs.py +++ b/src/configs.py @@ -11,19 +11,19 @@ class config: model="facebook/wav2vec2-base-960h" fast_LR=1e-3 #To be used when initial weights are frozen - LR=1e-6 + LR=1e-5 clip_grad_norm=1.0 - EPOCHS=0 - num_iters_checkpoint=70 + EPOCHS=100 + num_iters_checkpoint=70000 prev_checkpoint="" output_directory="./model/" os.makedirs(output_directory, exist_ok=True) - BATCH_SIZE=5 + BATCH_SIZE=1 SHUFFLE=False - eval=False - train=False + eval=True + train=True device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') max_audio_len=576000 freeze_for_epochs=0 @@ -38,11 +38,11 @@ class config: Hindi=(2304,2431+1) Marathi=Hindi - Language=Gujrati #select the language + Language=Marathi #select the language mono=True - mono_train_path="./" - mono_test_path="./" + mono_train_path="/home/krishnarajule3/ASR/data/Marathi/train" + mono_test_path="/home/krishnarajule3/ASR/data/Marathi/test" def get_all_params_dict(config): params = {} @@ -50,3 +50,4 @@ def get_all_params_dict(config): if not ( callable(v) or (k.startswith('__') and k.endswith('__'))): params[k]=v return params + diff --git a/src/train2.py b/src/train2.py index c659e96..f7b9089 100644 --- a/src/train2.py +++ b/src/train2.py @@ -16,17 +16,17 @@ import random -class MonoData(Dataset): +class MonoData(torch.utils.data.Dataset): def __init__(self,path): self.path=path - self.file=open(path+'/transcription.txt','r',encoding='UTF-8').read().split("\n") + self.file=open(path+'/transcription.txt','r',encoding='UTF-8').read().rstrip().split("\n") def __len__(self): return len(self.file) def __getitem__(self,index): audio,text=self.file[index].split(' ',1) - audio=self.path+'/Audios/'+audio + audio=self.path+'/audio/'+audio+'.wav' return audio,text def mono_collate_fn(batch, tokenizer): @@ -172,7 +172,7 @@ def compute_metric(model, tokenizer, test_dataset): input_values = tokenizer(d["speech"], return_tensors="pt", padding='longest').input_values.to(config.device) else: - input_values = tokenizer(sf.read(d["speech"])[0], return_tensors="pt", + input_values = tokenizer(sf.read(d[0])[0], return_tensors="pt", padding='longest').input_values.to(config.device) logits = model(input_values).logits @@ -180,8 +180,10 @@ def compute_metric(model, tokenizer, test_dataset): predicted_ids = torch.argmax(logits, dim=-1).cpu() transcriptions = tokenizer.batch_decode(predicted_ids) transcriptions = tokenizer.revert_transliteration(transcriptions) - - reference = d['text'].upper() + if not config.mono: + reference = d['text'].upper() + else: + reference= d[1].upper() if i==show_sample_no or i==0: print("Sample prediction: ", transcriptions[0]) From e2cb63f107cb8323a5fb928622fc92ab5d69f924 Mon Sep 17 00:00:00 2001 From: Jaskaran Singh Date: Sun, 28 Mar 2021 19:06:09 +0000 Subject: [PATCH 3/7] fix --- src/configs.py | 27 +++++++++++++++------------ src/model.py | 0 src/tokenizer.py | 2 +- src/train2.py | 7 +++++-- 4 files changed, 21 insertions(+), 15 deletions(-) mode change 100644 => 100755 src/configs.py mode change 100644 => 100755 src/model.py mode change 100644 => 100755 src/tokenizer.py mode change 100644 => 100755 src/train2.py diff --git a/src/configs.py b/src/configs.py old mode 100644 new mode 100755 index 20841f0..8e996c8 --- a/src/configs.py +++ b/src/configs.py @@ -12,37 +12,38 @@ class config: model="facebook/wav2vec2-base-960h" fast_LR=1e-3 #To be used when initial weights are frozen LR=1e-5 - clip_grad_norm=1.0 - EPOCHS=100 - num_iters_checkpoint=70000 - prev_checkpoint="" + clip_grad_norm=3.0 + EPOCHS=1000 + num_iters_checkpoint=50000 + prev_checkpoint="/home/jaskaransingh101010/indic-asr/src/wandb/run-20210324_174000-ymcpyuqp/files/facebook/wav2vec2-base-960h_99/" output_directory="./model/" os.makedirs(output_directory, exist_ok=True) - BATCH_SIZE=1 - SHUFFLE=False + BATCH_SIZE=2 + SHUFFLE=True eval=True train=True device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') max_audio_len=576000 freeze_for_epochs=0 transliterate=False - + cur_epoch=0 + #Support for additional languages Telugu=(3072,3199+1) Tamil=(2944,3071+1) - Oriya=(2816,2943+1) - Gujrati=(2688,2815+1) + Odia=(2816,2943+1) + Gujarati=(2688,2815+1) Hindi=(2304,2431+1) Marathi=Hindi - Language=Marathi #select the language + Language=Hindi #select the language mono=True - mono_train_path="/home/krishnarajule3/ASR/data/Marathi/train" - mono_test_path="/home/krishnarajule3/ASR/data/Marathi/test" + mono_train_path="/home/krishnarajule3/ASR/data/Hindi/train" + mono_test_path="/home/krishnarajule3/ASR/data/Hindi/test" def get_all_params_dict(config): params = {} @@ -51,3 +52,5 @@ def get_all_params_dict(config): params[k]=v return params + + diff --git a/src/model.py b/src/model.py old mode 100644 new mode 100755 diff --git a/src/tokenizer.py b/src/tokenizer.py old mode 100644 new mode 100755 index 52938c5..bb5fd50 --- a/src/tokenizer.py +++ b/src/tokenizer.py @@ -45,7 +45,7 @@ def transliterate(self, text: str)-> str: def remove_sos(self, texts: List[str]) -> List[str]: processed_texts = [] for text in texts: - processed_texts.append(text.replace('','').replace('','')) + processed_texts.append(text.replace('','').replace('','')) return processed_texts def revert_transliteration(self, texts: List[str])->str: diff --git a/src/train2.py b/src/train2.py old mode 100644 new mode 100755 index f7b9089..4a36e47 --- a/src/train2.py +++ b/src/train2.py @@ -163,7 +163,7 @@ def compute_metric(model, tokenizer, test_dataset): metric = load_metric('wer') pbar = tqdm(test_dataset, desc="Computing metric") - + score=[] show_sample_no = random.randint(1, len(test_dataset)-1) with torch.no_grad(): for i, d in enumerate(pbar): @@ -192,7 +192,8 @@ def compute_metric(model, tokenizer, test_dataset): metric.add_batch(predictions=transcriptions, references=[reference]) - score = metric.compute() + score.append(metric.compute()) + score=sum(score)/len(score) print("Evaluation metric: ", score) return score @@ -241,6 +242,8 @@ def collate_fn(batch, tokenizer): else: mono_dataloader = None + print(compute_metric(model, tokenizer, test_dataset)) + if(config.train): if not config.mono: train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params) From ef6021309813cd8c138377960e2a330e40ace34c Mon Sep 17 00:00:00 2001 From: Jaskaran Singh Date: Tue, 6 Apr 2021 16:00:48 +0530 Subject: [PATCH 4/7] mono added --- src/Monodataset.py | 14 ++++++++++++++ src/configs.py | 27 +++++++++++++++++---------- src/tokenizer.py | 10 +++++++--- src/train2.py | 38 +++++++++++++++----------------------- 4 files changed, 53 insertions(+), 36 deletions(-) create mode 100644 src/Monodataset.py diff --git a/src/Monodataset.py b/src/Monodataset.py new file mode 100644 index 0000000..e041100 --- /dev/null +++ b/src/Monodataset.py @@ -0,0 +1,14 @@ +import os + +class MonoData(Dataset): + def __init__(self,path): + self.path=path + self.file=open(path+'/transcription.txt','r',encoding='UTF-8').read().split("\n") + + def __len__(self): + return len(self.file) + + def __getitem__(self,index): + audio,text=self.file[index].split(' ',1) + audio=self.path+'/Audios/'+audio + return audio,text diff --git a/src/configs.py b/src/configs.py index b5db660..47ebab5 100644 --- a/src/configs.py +++ b/src/configs.py @@ -3,14 +3,10 @@ class config: - data_dir="/home/krishnarajule3/ASR/data/Hindi-English/" - data_loading_script="/home/datasets/code_switch_asr" - - use_monolingual=False - monolingual_data_dir="/home/krishnarajule3/ASR/data/Hindi/" + model="facebook/wav2vec2-base-960h" - fast_LR=1e-3 #To be used when initial weights are frozen + fast_LR=1e-3 #To be used when initial weights are frozen LR=1e-6 clip_grad_norm=1.0 EPOCHS=0 @@ -36,14 +32,25 @@ class config: Oriya=(2816,2943+1) Gujrati=(2688,2815+1) Hindi=(2304,2431+1) + Bengali=(2433,2554+1) Marathi=Hindi - Language=Gujrati #select the language + Language=[Gujrati] #select the language (can add multiple languages to the list) + + #Mono-Language Training - mono=True - mono_train_path="./" - mono_test_path="./" + mono=True #to specify training for the monolingual language (to use mono dataset) + mono_train_path="./" #path to training folder + mono_test_path="./" #path to testing folder + + #Code Switched Training (set mono=False, to use code-switched loader.py) + + data_dir="/home/krishnarajule3/ASR/data/Hindi-English/" + data_loading_script="/home/datasets/code_switch_asr" + use_monolingual=False + monolingual_data_dir="/home/krishnarajule3/ASR/data/Hindi/" + def get_all_params_dict(config): params = {} for k, v in config.__dict__.items(): diff --git a/src/tokenizer.py b/src/tokenizer.py index 52938c5..78b2d60 100644 --- a/src/tokenizer.py +++ b/src/tokenizer.py @@ -18,8 +18,9 @@ class Wav2Vec2Tok(Wav2Vec2Tokenizer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if not config.transliterate: - for i in range(config.Language[0], config.Language[1]) : - self._add_tokens(chr(i)) + for lang in config.Language: + for i in range(lang[0], lang[1]) : + self._add_tokens(chr(i)) else: self.en_dict = enchant.Dict("en_US") for elem in ['̄', '̣', '̐', '́', '़', "'ॉ", '̃', '_', 'ऑ', '^', '…', '°', '̂', '̱', 'ॅ', 'ऍ', ':']: @@ -76,7 +77,10 @@ def tokenize(self, text: str, **kwargs) -> List[int]: """ if config.transliterate: text = self.transliterate(text) - + else: + for k,v in self.mappings.items(): + text = text.replace(k, v) + text = ' '.join(text.split()) text = text.replace(' ', self.word_delimiter_token) tokens = [self.bos_token_id] diff --git a/src/train2.py b/src/train2.py index c659e96..44c377a 100644 --- a/src/train2.py +++ b/src/train2.py @@ -1,33 +1,21 @@ +import os +import itertools +import soundfile as sf +import wandb +import random +import argparse + import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F -import os -import itertools -import soundfile as sf -import argparse from tqdm import tqdm from configs import config, get_all_params_dict from model import get_model from tokenizer import Wav2Vec2Tok from datasets import load_dataset, load_metric -import wandb -import random - - -class MonoData(Dataset): - def __init__(self,path): - self.path=path - self.file=open(path+'/transcription.txt','r',encoding='UTF-8').read().split("\n") - - def __len__(self): - return len(self.file) - - def __getitem__(self,index): - audio,text=self.file[index].split(' ',1) - audio=self.path+'/Audios/'+audio - return audio,text +from Monodataset import MonoData def mono_collate_fn(batch, tokenizer): @@ -179,8 +167,12 @@ def compute_metric(model, tokenizer, test_dataset): predicted_ids = torch.argmax(logits, dim=-1).cpu() transcriptions = tokenizer.batch_decode(predicted_ids) - transcriptions = tokenizer.revert_transliteration(transcriptions) - + if config.transliterate: + transcriptions = tokenizer.revert_transliteration(transcriptions) + else: + for k,v in self.mappings.items(): + text = text.replace(v.strip(),k) + reference = d['text'].upper() if i==show_sample_no or i==0: @@ -247,7 +239,7 @@ def collate_fn(batch, tokenizer): else: train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, collate_fn= lambda b: mono_collate_fn(b, tokenizer), **params) val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, collate_fn= lambda b: mono_collate_fn(b, tokenizer), **params) - train_model(model, tokenizer, train_dataloader, val_dataloader, test_dataset, mono_dataloader) + train_model(model, tokenizer, train_dataloader, val_dataloader, test_dataset) if(config.eval): print(compute_metric(model, tokenizer, test_dataset)) From f04c466da318866cbce22ded1372b3853dbdc62a Mon Sep 17 00:00:00 2001 From: Jaskaran Singh Date: Fri, 9 Apr 2021 14:54:48 +0530 Subject: [PATCH 5/7] multilingual asr support --- src/configs.py | 10 ++--- src/train2.py | 104 +++++++++++++++++++++++++++++-------------------- 2 files changed, 67 insertions(+), 47 deletions(-) diff --git a/src/configs.py b/src/configs.py index b58ed48..c1dbb40 100644 --- a/src/configs.py +++ b/src/configs.py @@ -3,7 +3,7 @@ class config: - +# 'facebook/wav2vec2-large-xlsr-53' model="facebook/wav2vec2-base-960h" fast_LR=1e-3 #To be used when initial weights are frozen @@ -37,13 +37,13 @@ class config: Bengali=(2433,2554+1) Marathi=Hindi - Language=[Gujrati] #select the language (can add multiple languages to the list) + Language=[Hindi,Gujrati,Telugu,Tamil,Odia] #select the language (can add multiple languages to the list) #Mono-Language Training - mono=True #to specify training for the monolingual language (to use mono dataset) - mono_train_path="./" #path to training folder - mono_test_path="./" #path to testing folder + mono=True #to specify training for the monolingual language (to use mono dataset) + mono_train_path=["./",] #path to training folder + mono_test_path=["./",] #path to testing folder #Code Switched Training (set mono=False, to use code-switched loader.py) diff --git a/src/train2.py b/src/train2.py index 250a2b4..462bb4d 100644 --- a/src/train2.py +++ b/src/train2.py @@ -149,44 +149,55 @@ def eval_model(model, tokenizer, val_dataloader): return (epoch_loss / num_valid_batches) def compute_metric(model, tokenizer, test_dataset): - metric = load_metric('wer') - - pbar = tqdm(test_dataset, desc="Computing metric") - score=[] - show_sample_no = random.randint(1, len(test_dataset)-1) + + wer_score=[] + model.eval() + if not isinstance(test_dataset,list): + test_dataset=[test_dataset] + with torch.no_grad(): - for i, d in enumerate(pbar): - - if not config.mono: - input_values = tokenizer(d["speech"], return_tensors="pt", - padding='longest').input_values.to(config.device) - else: - input_values = tokenizer(sf.read(d[0])[0], return_tensors="pt", - padding='longest').input_values.to(config.device) - - logits = model(input_values).logits - - predicted_ids = torch.argmax(logits, dim=-1).cpu() - transcriptions = tokenizer.batch_decode(predicted_ids) - - if config.transliterate: - transcriptions = tokenizer.revert_transliteration(transcriptions) - else: - for k,v in self.mappings.items(): - text = text.replace(v.strip(),k) - - reference = d['text'].upper() - - if i==show_sample_no or i==0: - print("Sample prediction: ", transcriptions[0]) - print("Sample reference: ", reference) - - metric.add_batch(predictions=transcriptions, - references=[reference]) + for dataset in test_dataset: + metric = load_metric('wer') + pbar = tqdm(dataset, desc="Computing metric") + score=[] + show_sample_no = random.randint(1, len(dataset)-1) + with torch.no_grad(): + for i, d in enumerate(pbar): + + if not config.mono: + input_values = tokenizer(d["speech"], return_tensors="pt", + padding='longest').input_values.to(config.device) + else: + input_values = tokenizer(sf.read(d[0])[0], return_tensors="pt", + padding='longest').input_values.to(config.device) + + logits = model(input_values).logits + + predicted_ids = torch.argmax(logits, dim=-1).cpu() + transcriptions = tokenizer.batch_decode(predicted_ids) + + if config.transliterate: + transcriptions = tokenizer.revert_transliteration(transcriptions) + else: + for k,v in self.mappings.items(): + text = text.replace(v.strip(),k) + + reference = d['text'].upper() + + if i==show_sample_no or i==0: + print("Sample prediction: ", transcriptions[0]) + print("Sample reference: ", reference) + + metric.add_batch(predictions=transcriptions, + references=[reference]) + + score=metric.compute() + print("Evaluation metric: ", score) + wer_score.append(score) + + print(wer_score) - score.append(metric.compute()) - score=sum(score)/len(score) - print("Evaluation metric: ", score) + score=sum(wer_score)/len(wer_score) return score def collate_fn(batch, tokenizer): @@ -215,7 +226,7 @@ def collate_fn(batch, tokenizer): if(config.prev_checkpoint!=""): model=load_checkpoint(model,config.prev_checkpoint) - params = {'batch_size': config.BATCH_SIZE,} + params = {'batch_size': config.BATCH_SIZE,'shuffle': config.SHUFFLE} print("running on ", config.device) @@ -224,9 +235,20 @@ def collate_fn(batch, tokenizer): val_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="train[:2%]", writer_batch_size=1000) test_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="test", writer_batch_size=1000) else: - train_dataset=MonoData(path=config.mono_train_path) - test_dataset=MonoData(path=config.mono_test_path) - val_dataset=test_dataset + train_dataset=[] + test_dataset=[] + for i,j in zip(config.mono_train_path,config.mono_test_path): + train_dataset.append(MonoData(path=i)) + test_dataset.append(MonoData(path=j)) + + if(len(config.mono_train_path)>1): + train_dataset=ConcatDataset(train_dataset) + else: + train_dataset=train_dataset[0] + if(len(config.mono_test_path)>1): + val_dataset=ConcatDataset(test_dataset) + else: + val_dataset=test_dataset[0] if config.use_monolingual: mono_dataset = load_dataset(config.data_loading_script, data_dir=config.monolingual_data_dir, split="train", writer_batch_size=1000) @@ -234,8 +256,6 @@ def collate_fn(batch, tokenizer): else: mono_dataloader = None - print(compute_metric(model, tokenizer, test_dataset)) - if(config.train): if not config.mono: train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params) From 30cc841891112aa5de861c8d1f939794839c72b3 Mon Sep 17 00:00:00 2001 From: Jaskaran Singh Date: Fri, 9 Apr 2021 12:59:04 +0000 Subject: [PATCH 6/7] fix after running --- src/Monodataset.py | 7 ++++--- src/configs.py | 22 ++++++++++++++-------- src/train2.py | 45 +++++++++++++++++++++++++-------------------- 3 files changed, 43 insertions(+), 31 deletions(-) mode change 100644 => 100755 src/Monodataset.py mode change 100644 => 100755 src/configs.py mode change 100644 => 100755 src/train2.py diff --git a/src/Monodataset.py b/src/Monodataset.py old mode 100644 new mode 100755 index e041100..f933dd9 --- a/src/Monodataset.py +++ b/src/Monodataset.py @@ -1,14 +1,15 @@ import os +from torch.utils.data import Dataset class MonoData(Dataset): def __init__(self,path): self.path=path - self.file=open(path+'/transcription.txt','r',encoding='UTF-8').read().split("\n") + self.file=open(path+'/transcription.txt','r',encoding='UTF-8').read().replace('\t',' ').rstrip().split("\n") def __len__(self): return len(self.file) def __getitem__(self,index): audio,text=self.file[index].split(' ',1) - audio=self.path+'/Audios/'+audio - return audio,text + audio=self.path+'/audio/'+audio+'.wav' + return {'speech':audio,'text':text} diff --git a/src/configs.py b/src/configs.py old mode 100644 new mode 100755 index c1dbb40..03a2dde --- a/src/configs.py +++ b/src/configs.py @@ -4,20 +4,20 @@ class config: # 'facebook/wav2vec2-large-xlsr-53' - +# 'facebook/wav2vec2-base-960h' model="facebook/wav2vec2-base-960h" fast_LR=1e-3 #To be used when initial weights are frozen - LR=1e-6 + LR=1e-5 clip_grad_norm=1.0 - EPOCHS=0 - num_iters_checkpoint=70 + EPOCHS=1000 + num_iters_checkpoint=56000 prev_checkpoint="" output_directory="./model/" os.makedirs(output_directory, exist_ok=True) - BATCH_SIZE=2 + BATCH_SIZE=8 SHUFFLE=True eval=True train=True @@ -37,13 +37,19 @@ class config: Bengali=(2433,2554+1) Marathi=Hindi - Language=[Hindi,Gujrati,Telugu,Tamil,Odia] #select the language (can add multiple languages to the list) + Language=[Hindi,Gujarati,Telugu,Tamil,Odia] #select the language (can add multiple languages to the list) #Mono-Language Training mono=True #to specify training for the monolingual language (to use mono dataset) - mono_train_path=["./",] #path to training folder - mono_test_path=["./",] #path to testing folder + + mono_train_path=["/home/krishnarajule3/ASR/data/Hindi/train","/home/krishnarajule3/ASR/data/Marathi/train","/home/krishnarajule3/ASR/data/Odia/train", + "/home/krishnarajule3/ASR/data/Gujarati/gu-in-Train","/home/krishnarajule3/ASR/data/Tamil/ta-in-Train","/home/krishnarajule3/ASR/data/Telegu/te-in-Train" + ] #path to training folder + + mono_test_path=["/home/krishnarajule3/ASR/data/Hindi/test","/home/krishnarajule3/ASR/data/Marathi/test","/home/krishnarajule3/ASR/data/Odia/test", + "/home/krishnarajule3/ASR/data/Gujarati/gu-in-Test","/home/krishnarajule3/ASR/data/Tamil/ta-in-Test","/home/krishnarajule3/ASR/data/Telegu/te-in-Train" + ] #path to testing folder #Code Switched Training (set mono=False, to use code-switched loader.py) diff --git a/src/train2.py b/src/train2.py old mode 100644 new mode 100755 index 462bb4d..a936f7b --- a/src/train2.py +++ b/src/train2.py @@ -18,18 +18,6 @@ from Monodataset import MonoData -def mono_collate_fn(batch, tokenizer): - - speech_lis = [sf.read(elem[0])[0] for elem in batch] - text_lis = [elem[1] for elem in batch] - - input_values = tokenizer(speech_lis, return_tensors="pt", - padding='longest').input_values - - labels, label_lengths = tokenizer.batch_tokenize(text_lis) - - return (input_values.to(config.device), labels.to(config.device), label_lengths.to(config.device)) - def find_lengths(logits, pad_id: int) -> torch.FloatTensor: """ Function to find lengths of output sequences @@ -168,7 +156,7 @@ def compute_metric(model, tokenizer, test_dataset): input_values = tokenizer(d["speech"], return_tensors="pt", padding='longest').input_values.to(config.device) else: - input_values = tokenizer(sf.read(d[0])[0], return_tensors="pt", + input_values = tokenizer(sf.read(d["speech"])[0], return_tensors="pt", padding='longest').input_values.to(config.device) logits = model(input_values).logits @@ -179,19 +167,22 @@ def compute_metric(model, tokenizer, test_dataset): if config.transliterate: transcriptions = tokenizer.revert_transliteration(transcriptions) else: - for k,v in self.mappings.items(): - text = text.replace(v.strip(),k) + for k,v in tokenizer.mappings.items(): + transcriptions[0]= transcriptions[0].replace(v.strip(),k) + transcriptions[0]=transcriptions[0].replace('','').replace('','') reference = d['text'].upper() if i==show_sample_no or i==0: print("Sample prediction: ", transcriptions[0]) print("Sample reference: ", reference) - metric.add_batch(predictions=transcriptions, - references=[reference]) + #metric.add_batch(predictions=transcriptions, + # references=[reference]) + + score.append(metric.compute(predictions=transcriptions,references=[reference])) - score=metric.compute() + score=sum(score)/len(score) print("Evaluation metric: ", score) wer_score.append(score) @@ -211,6 +202,18 @@ def collate_fn(batch, tokenizer): return (input_values.to(config.device), labels.to(config.device), label_lengths.to(config.device)) +def mono_collate_fn(batch, tokenizer): + + speech_lis = [sf.read(elem['speech'])[0] for elem in batch] + text_lis = [elem['text'] for elem in batch] + + input_values = tokenizer(speech_lis, return_tensors="pt", + padding='longest').input_values + + labels, label_lengths = tokenizer.batch_tokenize(text_lis) + + return (input_values.to(config.device), labels.to(config.device), label_lengths.to(config.device)) + if __name__ =='__main__': all_params_dict = get_all_params_dict(config) @@ -242,11 +245,11 @@ def collate_fn(batch, tokenizer): test_dataset.append(MonoData(path=j)) if(len(config.mono_train_path)>1): - train_dataset=ConcatDataset(train_dataset) + train_dataset=torch.utils.data.ConcatDataset(train_dataset) else: train_dataset=train_dataset[0] if(len(config.mono_test_path)>1): - val_dataset=ConcatDataset(test_dataset) + val_dataset=torch.utils.data.ConcatDataset(test_dataset) else: val_dataset=test_dataset[0] @@ -256,6 +259,8 @@ def collate_fn(batch, tokenizer): else: mono_dataloader = None + #print(compute_metric(model, tokenizer, test_dataset)) + if(config.train): if not config.mono: train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params) From 492c28c2c05f98a3cc451d3f2e79b2af7c9ad2f9 Mon Sep 17 00:00:00 2001 From: Jaskaran Singh Date: Tue, 20 Apr 2021 19:11:17 +0000 Subject: [PATCH 7/7] final --- src/configs.py | 9 +++++---- src/train2.py | 5 +++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/configs.py b/src/configs.py index 03a2dde..4716b4c 100755 --- a/src/configs.py +++ b/src/configs.py @@ -7,17 +7,17 @@ class config: # 'facebook/wav2vec2-base-960h' model="facebook/wav2vec2-base-960h" fast_LR=1e-3 #To be used when initial weights are frozen - LR=1e-5 + LR=1e-4 clip_grad_norm=1.0 EPOCHS=1000 - num_iters_checkpoint=56000 - prev_checkpoint="" + num_iters_checkpoint=57660 + prev_checkpoint="./wandb/run-20210418_135042-2wfzsbrd/files/facebook/wav2vec2-base-960h_14" output_directory="./model/" os.makedirs(output_directory, exist_ok=True) - BATCH_SIZE=8 + BATCH_SIZE=6 SHUFFLE=True eval=True train=True @@ -68,3 +68,4 @@ def get_all_params_dict(config): + diff --git a/src/train2.py b/src/train2.py index a936f7b..101e4d0 100755 --- a/src/train2.py +++ b/src/train2.py @@ -170,7 +170,7 @@ def compute_metric(model, tokenizer, test_dataset): for k,v in tokenizer.mappings.items(): transcriptions[0]= transcriptions[0].replace(v.strip(),k) - transcriptions[0]=transcriptions[0].replace('','').replace('','') + transcriptions[0]=transcriptions[0].replace('','').replace('','') reference = d['text'].upper() if i==show_sample_no or i==0: @@ -189,6 +189,7 @@ def compute_metric(model, tokenizer, test_dataset): print(wer_score) score=sum(wer_score)/len(wer_score) + print('Avg Score: ',score) return score def collate_fn(batch, tokenizer): @@ -259,7 +260,7 @@ def mono_collate_fn(batch, tokenizer): else: mono_dataloader = None - #print(compute_metric(model, tokenizer, test_dataset)) + print(compute_metric(model, tokenizer, test_dataset)) if(config.train): if not config.mono: