diff --git a/src/Monodataset.py b/src/Monodataset.py
new file mode 100755
index 0000000..f933dd9
--- /dev/null
+++ b/src/Monodataset.py
@@ -0,0 +1,15 @@
+import os
+from torch.utils.data import Dataset
+
+class MonoData(Dataset):
+    def __init__(self, path):
+        self.path = path
+        self.file = open(path + '/transcription.txt', 'r', encoding='UTF-8').read().replace('\t', ' ').rstrip().split("\n")
+
+    def __len__(self):
+        return len(self.file)
+
+    def __getitem__(self, index):
+        audio, text = self.file[index].split(' ', 1)
+        audio = self.path + '/audio/' + audio + '.wav'
+        return {'speech': audio, 'text': text}
diff --git a/src/configs.py b/src/configs.py
old mode 100644
new mode 100755
index c34489e..4716b4c
--- a/src/configs.py
+++ b/src/configs.py
@@ -3,48 +3,69 @@ class config:
-    data_dir="/home/krishnarajule3/ASR/data/Hindi-English/"
-    data_loading_script="/home/datasets/code_switch_asr"
-
-    use_monolingual=False
-    monolingual_data_dir="/home/krishnarajule3/ASR/data/Hindi/"
-
+    # 'facebook/wav2vec2-large-xlsr-53'
+    # 'facebook/wav2vec2-base-960h'
     model="facebook/wav2vec2-base-960h"
-    fast_LR=1e-3 #To be used when initial weights are frozen
-    LR=1e-6
+    fast_LR=1e-3 #To be used when initial weights are frozen
+    LR=1e-4
     clip_grad_norm=1.0
-    EPOCHS=0
-    num_iters_checkpoint=70
-    prev_checkpoint=""
+    EPOCHS=1000
+    num_iters_checkpoint=57660
+    prev_checkpoint="./wandb/run-20210418_135042-2wfzsbrd/files/facebook/wav2vec2-base-960h_14"
+
     output_directory="./model/"
     os.makedirs(output_directory, exist_ok=True)
-    BATCH_SIZE=5
-    SHUFFLE=False
-    eval=False
-    train=False
+    BATCH_SIZE=6
+    SHUFFLE=True
+    eval=True
+    train=True
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     max_audio_len=576000
     freeze_for_epochs=0
     transliterate=False
-
+    cur_epoch=0
+
     #Support for additional languages
     Telugu=(3072,3199+1)
     Tamil=(2944,3071+1)
-    Oriya=(2816,2943+1)
-    Gujrati=(2688,2815+1)
+    Odia=(2816,2943+1)
+    Gujarati=(2688,2815+1)
     Hindi=(2304,2431+1)
+    Bengali=(2433,2554+1)
     Marathi=Hindi
-    Language=Gujrati #select the language
+    Language=[Hindi,Gujarati,Telugu,Tamil,Odia] #select the language (can add multiple languages to the list)
+
+    #Mono-Language Training
+
+    mono=True #to specify training for the monolingual language (to use mono dataset)
+
+    mono_train_path=["/home/krishnarajule3/ASR/data/Hindi/train","/home/krishnarajule3/ASR/data/Marathi/train","/home/krishnarajule3/ASR/data/Odia/train",
+                     "/home/krishnarajule3/ASR/data/Gujarati/gu-in-Train","/home/krishnarajule3/ASR/data/Tamil/ta-in-Train","/home/krishnarajule3/ASR/data/Telegu/te-in-Train"
+                     ] #path to training folder
+
+    mono_test_path=["/home/krishnarajule3/ASR/data/Hindi/test","/home/krishnarajule3/ASR/data/Marathi/test","/home/krishnarajule3/ASR/data/Odia/test",
+                    "/home/krishnarajule3/ASR/data/Gujarati/gu-in-Test","/home/krishnarajule3/ASR/data/Tamil/ta-in-Test","/home/krishnarajule3/ASR/data/Telegu/te-in-Train"
+                    ] #path to testing folder
+
+    #Code Switched Training (set mono=False, to use code-switched loader.py)
+    data_dir="/home/krishnarajule3/ASR/data/Hindi-English/"
+    data_loading_script="/home/datasets/code_switch_asr"
+    use_monolingual=False
+    monolingual_data_dir="/home/krishnarajule3/ASR/data/Hindi/"
 
 def get_all_params_dict(config):
     params = {}
     for k, v in config.__dict__.items():
         if not ( callable(v) or (k.startswith('__') and k.endswith('__'))):
             params[k]=v
     return params
diff --git a/src/model.py b/src/model.py
old mode 100644
new mode 100755
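A minimal usage sketch of the `MonoData` dataset added at the top of this patch (not part of the diff). The corpus root below is a placeholder path; the only assumed layout is the one the class itself reads: a `transcription.txt` of `<utterance-id> <transcript>` lines next to an `audio/` folder of `.wav` files.

```python
# Hypothetical smoke test for MonoData; "/data/Hindi/train" is a placeholder.
from Monodataset import MonoData

ds = MonoData(path="/data/Hindi/train")
print(len(ds))           # one item per line of transcription.txt
sample = ds[0]
print(sample['speech'])  # "/data/Hindi/train/audio/<utt-id>.wav"
print(sample['text'])    # the transcript that follows the first space
```

Note that `__getitem__` returns only the audio *path*; the actual waveform is read later by `mono_collate_fn` in `train2.py`.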
diff --git a/src/tokenizer.py b/src/tokenizer.py
old mode 100644
new mode 100755
index 52938c5..a544b1c
--- a/src/tokenizer.py
+++ b/src/tokenizer.py
@@ -18,8 +18,9 @@ class Wav2Vec2Tok(Wav2Vec2Tokenizer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if not config.transliterate:
-            for i in range(config.Language[0], config.Language[1]):
-                self._add_tokens(chr(i))
+            for lang in config.Language:
+                for i in range(lang[0], lang[1]):
+                    self._add_tokens(chr(i))
         else:
             self.en_dict = enchant.Dict("en_US")
             for elem in ['̄', '̣', '̐', '́', '़', "'ॉ", '̃', '_', 'ऑ', '^', '…', '°', '̂', '̱', 'ॅ', 'ऍ', ':']:
@@ -45,7 +46,7 @@ def transliterate(self, text: str)-> str:
     def remove_sos(self, texts: List[str]) -> List[str]:
         processed_texts = []
         for text in texts:
-            processed_texts.append(text.replace('<s>','').replace('</s>',''))
+            processed_texts.append(text.replace('<s>','').replace('</s>',''))
         return processed_texts
 
     def revert_transliteration(self, texts: List[str])->str:
@@ -76,7 +77,10 @@ def tokenize(self, text: str, **kwargs) -> List[int]:
         """
         if config.transliterate:
             text = self.transliterate(text)
-
+        else:
+            for k,v in self.mappings.items():
+                text = text.replace(k, v)
+
         text = ' '.join(text.split())
         text = text.replace(' ', self.word_delimiter_token)
         tokens = [self.bos_token_id]
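Since `config.Language` is now a list of Unicode code-point ranges rather than a single tuple, `Wav2Vec2Tok.__init__` extends the vocabulary with one character token per code point in every selected block. An illustration (values taken from `configs.py`; no tokenizer needed to see what gets added):

```python
# Illustration of the vocabulary expansion in Wav2Vec2Tok.__init__.
Language = [(2304, 2431 + 1), (2688, 2815 + 1)]  # Hindi, Gujarati ranges from configs.py

new_tokens = [chr(i) for lang in Language for i in range(lang[0], lang[1])]
print(len(new_tokens))  # 256 - one token per code point across the two blocks
print(new_tokens[:5])   # ['ऀ', 'ँ', 'ं', 'ः', 'ऄ']
```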
diff --git a/src/train2.py b/src/train2.py
old mode 100644
new mode 100755
index f1211ef..101e4d0
--- a/src/train2.py
+++ b/src/train2.py
@@ -1,19 +1,22 @@
+import os
+import itertools
+import soundfile as sf
+import wandb
+import random
+import argparse
+
 import torch
 import torch.nn as nn
 import torch.optim as optim
 import torch.nn.functional as F
-import os
-import itertools
-import soundfile as sf
-import argparse
 from tqdm import tqdm
 from configs import config, get_all_params_dict
 from model import get_model
 from tokenizer import Wav2Vec2Tok
 from datasets import load_dataset, load_metric
-import wandb
-import random
+from Monodataset import MonoData
+
 
 
 def find_lengths(logits, pad_id: int) -> torch.FloatTensor:
     """
@@ -134,34 +137,59 @@ def eval_model(model, tokenizer, val_dataloader):
     return (epoch_loss / num_valid_batches)
 
 def compute_metric(model, tokenizer, test_dataset):
-    metric = load_metric('wer')
-
-    pbar = tqdm(test_dataset, desc="Computing metric")
-
-    show_sample_no = random.randint(1, len(test_dataset)-1)
+
+    wer_score=[]
+    model.eval()
+    if not isinstance(test_dataset,list):
+        test_dataset=[test_dataset]
+
     with torch.no_grad():
-        for i, d in enumerate(pbar):
-
-            input_values = tokenizer(d["speech"], return_tensors="pt",
-                                     padding='longest').input_values.to(config.device)
-
-            logits = model(input_values).logits
-
-            predicted_ids = torch.argmax(logits, dim=-1).cpu()
-            transcriptions = tokenizer.batch_decode(predicted_ids)
-            transcriptions = tokenizer.revert_transliteration(transcriptions)
-
-            reference = d['text'].upper()
-
-            if i==show_sample_no or i==0:
-                print("Sample prediction: ", transcriptions[0])
-                print("Sample reference: ", reference)
-
-            metric.add_batch(predictions=transcriptions,
-                             references=[reference])
+        for dataset in test_dataset:
+            metric = load_metric('wer')
+            pbar = tqdm(dataset, desc="Computing metric")
+            score=[]
+            show_sample_no = random.randint(1, len(dataset)-1)
+            with torch.no_grad():
+                for i, d in enumerate(pbar):
+
+                    if not config.mono:
+                        input_values = tokenizer(d["speech"], return_tensors="pt",
+                                                 padding='longest').input_values.to(config.device)
+                    else:
+                        input_values = tokenizer(sf.read(d["speech"])[0], return_tensors="pt",
+                                                 padding='longest').input_values.to(config.device)
+
+                    logits = model(input_values).logits
+
+                    predicted_ids = torch.argmax(logits, dim=-1).cpu()
+                    transcriptions = tokenizer.batch_decode(predicted_ids)
+
+                    if config.transliterate:
+                        transcriptions = tokenizer.revert_transliteration(transcriptions)
+                    else:
+                        for k,v in tokenizer.mappings.items():
+                            transcriptions[0] = transcriptions[0].replace(v.strip(),k)
+
+                    transcriptions[0]=transcriptions[0].replace('<s>','').replace('</s>','')
+                    reference = d['text'].upper()
+
+                    if i==show_sample_no or i==0:
+                        print("Sample prediction: ", transcriptions[0])
+                        print("Sample reference: ", reference)
+
+                    #metric.add_batch(predictions=transcriptions,
+                    #                 references=[reference])
+
+                    score.append(metric.compute(predictions=transcriptions,references=[reference]))
+
+            score=sum(score)/len(score)
+            print("Evaluation metric: ", score)
+            wer_score.append(score)
 
-    score = metric.compute()
-    print("Evaluation metric: ", score)
+    print(wer_score)
+
+    score=sum(wer_score)/len(wer_score)
+    print('Avg Score: ', score)
     return score
 
 def collate_fn(batch, tokenizer):
@@ -175,6 +203,18 @@ def collate_fn(batch, tokenizer):
 
     return (input_values.to(config.device), labels.to(config.device), label_lengths.to(config.device))
 
+def mono_collate_fn(batch, tokenizer):
+
+    speech_lis = [sf.read(elem['speech'])[0] for elem in batch]
+    text_lis = [elem['text'] for elem in batch]
+
+    input_values = tokenizer(speech_lis, return_tensors="pt",
+                             padding='longest').input_values
+
+    labels, label_lengths = tokenizer.batch_tokenize(text_lis)
+
+    return (input_values.to(config.device), labels.to(config.device), label_lengths.to(config.device))
+
 if __name__ =='__main__':
 
     all_params_dict = get_all_params_dict(config)
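A toy illustration (made-up numbers) of the aggregation the rewritten `compute_metric` performs: it now averages per-utterance WERs within each test set, then averages those per-set scores. This is a macro average, weighting every utterance equally, which can differ from the previous corpus-level WER pooled over all words via `metric.add_batch`:

```python
# Sketch of compute_metric's new aggregation with illustrative values.
per_utt_wer = {'hindi': [0.10, 0.30], 'tamil': [0.20]}  # per-utterance WERs per set

per_set = {k: sum(v) / len(v) for k, v in per_utt_wer.items()}  # "Evaluation metric"
avg_score = sum(per_set.values()) / len(per_set)                # "Avg Score"
print(per_set)    # {'hindi': 0.2, 'tamil': 0.2}
print(avg_score)  # 0.2
```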
@@ -190,25 +230,48 @@ def collate_fn(batch, tokenizer):
     if(config.prev_checkpoint!=""):
         model=load_checkpoint(model,config.prev_checkpoint)
 
-    params = {'batch_size': config.BATCH_SIZE,}
+    params = {'batch_size': config.BATCH_SIZE, 'shuffle': config.SHUFFLE}
 
     print("running on ", config.device)
-
-    train_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="train[2%:]", writer_batch_size=1000)
-    val_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="train[:2%]", writer_batch_size=1000)
-    test_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="test", writer_batch_size=1000)
+    if not config.mono:
+        train_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="train[2%:]", writer_batch_size=1000)
+        val_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="train[:2%]", writer_batch_size=1000)
+        test_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="test", writer_batch_size=1000)
+    else:
+        train_dataset=[]
+        test_dataset=[]
+        for i,j in zip(config.mono_train_path,config.mono_test_path):
+            train_dataset.append(MonoData(path=i))
+            test_dataset.append(MonoData(path=j))
+
+        if(len(config.mono_train_path)>1):
+            train_dataset=torch.utils.data.ConcatDataset(train_dataset)
+        else:
+            train_dataset=train_dataset[0]
+        if(len(config.mono_test_path)>1):
+            val_dataset=torch.utils.data.ConcatDataset(test_dataset)
+        else:
+            val_dataset=test_dataset[0]
+
     if config.use_monolingual:
         mono_dataset = load_dataset(config.data_loading_script, data_dir=config.monolingual_data_dir, split="train", writer_batch_size=1000)
         mono_dataloader = torch.utils.data.DataLoader(dataset=mono_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params)
     else:
         mono_dataloader = None
 
+    print(compute_metric(model, tokenizer, test_dataset))
+
     if(config.train):
-        train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params)
-        val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params)
-        train_model(model, tokenizer, train_dataloader, val_dataloader, test_dataset, mono_dataloader)
-
+        if not config.mono:
+            train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params)
+            val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params)
+            train_model(model, tokenizer, train_dataloader, val_dataloader, test_dataset, mono_dataloader)
+        else:
+            train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, collate_fn= lambda b: mono_collate_fn(b, tokenizer), **params)
+            val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, collate_fn= lambda b: mono_collate_fn(b, tokenizer), **params)
+            train_model(model, tokenizer, train_dataloader, val_dataloader, test_dataset)
+
     if(config.eval):
         print(compute_metric(model, tokenizer, test_dataset))
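A minimal sketch of the `mono=True` data path the `__main__` block now wires up, with placeholder corpus roots. `MonoData` yields `.wav` paths; `mono_collate_fn` reads each file with `soundfile` and pads/tokenizes the whole batch. Whether `Wav2Vec2Tok.from_pretrained(config.model)` is how the repo builds its tokenizer is an assumption here (it subclasses `Wav2Vec2Tokenizer`, which provides `from_pretrained`):

```python
# Sketch of the monolingual training pipeline under placeholder paths.
import torch
from Monodataset import MonoData
from tokenizer import Wav2Vec2Tok
from train2 import mono_collate_fn
from configs import config

train_sets = [MonoData(path=p) for p in ["/data/Hindi/train", "/data/Odia/train"]]
train_dataset = torch.utils.data.ConcatDataset(train_sets)  # as in __main__ when >1 path

tokenizer = Wav2Vec2Tok.from_pretrained(config.model)
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    collate_fn=lambda b: mono_collate_fn(b, tokenizer),  # reads + pads audio per batch
    batch_size=config.BATCH_SIZE,
    shuffle=config.SHUFFLE)
```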