diff --git a/src/Monodataset.py b/src/Monodataset.py
new file mode 100755
index 0000000..f933dd9
--- /dev/null
+++ b/src/Monodataset.py
@@ -0,0 +1,15 @@
+import os
+from torch.utils.data import Dataset
+
+class MonoData(Dataset):
+ def __init__(self,path):
+ self.path=path
+ self.file=open(path+'/transcription.txt','r',encoding='UTF-8').read().replace('\t',' ').rstrip().split("\n")
+
+ def __len__(self):
+ return len(self.file)
+
+ def __getitem__(self,index):
+ audio,text=self.file[index].split(' ',1)
+ audio=self.path+'/audio/'+audio+'.wav'
+ return {'speech':audio,'text':text}
diff --git a/src/configs.py b/src/configs.py
old mode 100644
new mode 100755
index c34489e..4716b4c
--- a/src/configs.py
+++ b/src/configs.py
@@ -3,48 +3,69 @@
class config:
- data_dir="/home/krishnarajule3/ASR/data/Hindi-English/"
- data_loading_script="/home/datasets/code_switch_asr"
-
- use_monolingual=False
- monolingual_data_dir="/home/krishnarajule3/ASR/data/Hindi/"
-
+# 'facebook/wav2vec2-large-xlsr-53'
+# 'facebook/wav2vec2-base-960h'
model="facebook/wav2vec2-base-960h"
- fast_LR=1e-3 #To be used when initial weights are frozen
- LR=1e-6
+ fast_LR=1e-3 #To be used when initial weights are frozen
+ LR=1e-4
clip_grad_norm=1.0
- EPOCHS=0
- num_iters_checkpoint=70
- prev_checkpoint=""
+ EPOCHS=1000
+ num_iters_checkpoint=57660
+ prev_checkpoint="./wandb/run-20210418_135042-2wfzsbrd/files/facebook/wav2vec2-base-960h_14"
+
output_directory="./model/"
os.makedirs(output_directory, exist_ok=True)
- BATCH_SIZE=5
- SHUFFLE=False
- eval=False
- train=False
+ BATCH_SIZE=6
+ SHUFFLE=True
+ eval=True
+ train=True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
max_audio_len=576000
freeze_for_epochs=0
transliterate=False
-
+ cur_epoch=0
+
#Support for additional languages
Telugu=(3072,3199+1)
Tamil=(2944,3071+1)
- Oriya=(2816,2943+1)
- Gujrati=(2688,2815+1)
+ Odia=(2816,2943+1)
+ Gujarati=(2688,2815+1)
Hindi=(2304,2431+1)
+    Bengali=(2432,2559+1)
Marathi=Hindi
- Language=Gujrati #select the language
+ Language=[Hindi,Gujarati,Telugu,Tamil,Odia] #select the language (can add multiple languages to the list)
+
+ #Mono-Language Training
+
+ mono=True #to specify training for the monolingual language (to use mono dataset)
+
+ mono_train_path=["/home/krishnarajule3/ASR/data/Hindi/train","/home/krishnarajule3/ASR/data/Marathi/train","/home/krishnarajule3/ASR/data/Odia/train",
+ "/home/krishnarajule3/ASR/data/Gujarati/gu-in-Train","/home/krishnarajule3/ASR/data/Tamil/ta-in-Train","/home/krishnarajule3/ASR/data/Telegu/te-in-Train"
+ ] #path to training folder
+
+ mono_test_path=["/home/krishnarajule3/ASR/data/Hindi/test","/home/krishnarajule3/ASR/data/Marathi/test","/home/krishnarajule3/ASR/data/Odia/test",
+    "/home/krishnarajule3/ASR/data/Gujarati/gu-in-Test","/home/krishnarajule3/ASR/data/Tamil/ta-in-Test","/home/krishnarajule3/ASR/data/Telegu/te-in-Test"
+ ] #path to testing folder
+ #Code Switched Training (set mono=False, to use code-switched loader.py)
+ data_dir="/home/krishnarajule3/ASR/data/Hindi-English/"
+ data_loading_script="/home/datasets/code_switch_asr"
+ use_monolingual=False
+ monolingual_data_dir="/home/krishnarajule3/ASR/data/Hindi/"
+
def get_all_params_dict(config):
params = {}
for k, v in config.__dict__.items():
if not ( callable(v) or (k.startswith('__') and k.endswith('__'))):
params[k]=v
return params
+
+
+
+
diff --git a/src/model.py b/src/model.py
old mode 100644
new mode 100755
diff --git a/src/tokenizer.py b/src/tokenizer.py
old mode 100644
new mode 100755
index 52938c5..a544b1c
--- a/src/tokenizer.py
+++ b/src/tokenizer.py
@@ -18,8 +18,9 @@ class Wav2Vec2Tok(Wav2Vec2Tokenizer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not config.transliterate:
- for i in range(config.Language[0], config.Language[1]) :
- self._add_tokens(chr(i))
+ for lang in config.Language:
+ for i in range(lang[0], lang[1]) :
+ self._add_tokens(chr(i))
else:
self.en_dict = enchant.Dict("en_US")
for elem in ['̄', '̣', '̐', '́', '़', "'ॉ", '̃', '_', 'ऑ', '^', '…', '°', '̂', '̱', 'ॅ', 'ऍ', ':']:
@@ -45,7 +46,7 @@ def transliterate(self, text: str)-> str:
def remove_sos(self, texts: List[str]) -> List[str]:
processed_texts = []
for text in texts:
- processed_texts.append(text.replace('','').replace('',''))
+ processed_texts.append(text.replace('','').replace('',''))
return processed_texts
def revert_transliteration(self, texts: List[str])->str:
@@ -76,7 +77,10 @@ def tokenize(self, text: str, **kwargs) -> List[int]:
"""
if config.transliterate:
text = self.transliterate(text)
-
+ else:
+ for k,v in self.mappings.items():
+ text = text.replace(k, v)
+
text = ' '.join(text.split())
text = text.replace(' ', self.word_delimiter_token)
tokens = [self.bos_token_id]
diff --git a/src/train2.py b/src/train2.py
old mode 100644
new mode 100755
index f1211ef..101e4d0
--- a/src/train2.py
+++ b/src/train2.py
@@ -1,19 +1,22 @@
+import os
+import itertools
+import soundfile as sf
+import wandb
+import random
+import argparse
+
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
-import os
-import itertools
-import soundfile as sf
-import argparse
from tqdm import tqdm
from configs import config, get_all_params_dict
from model import get_model
from tokenizer import Wav2Vec2Tok
from datasets import load_dataset, load_metric
-import wandb
-import random
+from Monodataset import MonoData
+
def find_lengths(logits, pad_id: int) -> torch.FloatTensor:
"""
@@ -134,34 +137,59 @@ def eval_model(model, tokenizer, val_dataloader):
return (epoch_loss / num_valid_batches)
def compute_metric(model, tokenizer, test_dataset):
- metric = load_metric('wer')
-
- pbar = tqdm(test_dataset, desc="Computing metric")
-
- show_sample_no = random.randint(1, len(test_dataset)-1)
+
+ wer_score=[]
+ model.eval()
+ if not isinstance(test_dataset,list):
+ test_dataset=[test_dataset]
+
with torch.no_grad():
- for i, d in enumerate(pbar):
-
- input_values = tokenizer(d["speech"], return_tensors="pt",
- padding='longest').input_values.to(config.device)
-
- logits = model(input_values).logits
-
- predicted_ids = torch.argmax(logits, dim=-1).cpu()
- transcriptions = tokenizer.batch_decode(predicted_ids)
- transcriptions = tokenizer.revert_transliteration(transcriptions)
-
- reference = d['text'].upper()
-
- if i==show_sample_no or i==0:
- print("Sample prediction: ", transcriptions[0])
- print("Sample reference: ", reference)
-
- metric.add_batch(predictions=transcriptions,
- references=[reference])
+ for dataset in test_dataset:
+ metric = load_metric('wer')
+ pbar = tqdm(dataset, desc="Computing metric")
+ score=[]
+ show_sample_no = random.randint(1, len(dataset)-1)
+ with torch.no_grad():
+ for i, d in enumerate(pbar):
+
+ if not config.mono:
+ input_values = tokenizer(d["speech"], return_tensors="pt",
+ padding='longest').input_values.to(config.device)
+ else:
+ input_values = tokenizer(sf.read(d["speech"])[0], return_tensors="pt",
+ padding='longest').input_values.to(config.device)
+
+ logits = model(input_values).logits
+
+ predicted_ids = torch.argmax(logits, dim=-1).cpu()
+ transcriptions = tokenizer.batch_decode(predicted_ids)
+
+ if config.transliterate:
+ transcriptions = tokenizer.revert_transliteration(transcriptions)
+ else:
+ for k,v in tokenizer.mappings.items():
+ transcriptions[0]= transcriptions[0].replace(v.strip(),k)
+
+ transcriptions[0]=transcriptions[0].replace('','').replace('','')
+ reference = d['text'].upper()
+
+ if i==show_sample_no or i==0:
+ print("Sample prediction: ", transcriptions[0])
+ print("Sample reference: ", reference)
+
+ #metric.add_batch(predictions=transcriptions,
+ # references=[reference])
+
+ score.append(metric.compute(predictions=transcriptions,references=[reference]))
+
+ score=sum(score)/len(score)
+ print("Evaluation metric: ", score)
+ wer_score.append(score)
- score = metric.compute()
- print("Evaluation metric: ", score)
+ print(wer_score)
+
+ score=sum(wer_score)/len(wer_score)
+ print('Avg Score: ',score)
return score
def collate_fn(batch, tokenizer):
@@ -175,6 +203,18 @@ def collate_fn(batch, tokenizer):
return (input_values.to(config.device), labels.to(config.device), label_lengths.to(config.device))
+def mono_collate_fn(batch, tokenizer):
+
+ speech_lis = [sf.read(elem['speech'])[0] for elem in batch]
+ text_lis = [elem['text'] for elem in batch]
+
+ input_values = tokenizer(speech_lis, return_tensors="pt",
+ padding='longest').input_values
+
+ labels, label_lengths = tokenizer.batch_tokenize(text_lis)
+
+ return (input_values.to(config.device), labels.to(config.device), label_lengths.to(config.device))
+
if __name__ =='__main__':
all_params_dict = get_all_params_dict(config)
@@ -190,25 +230,48 @@ def collate_fn(batch, tokenizer):
if(config.prev_checkpoint!=""):
model=load_checkpoint(model,config.prev_checkpoint)
- params = {'batch_size': config.BATCH_SIZE,}
+ params = {'batch_size': config.BATCH_SIZE,'shuffle': config.SHUFFLE}
print("running on ", config.device)
-
- train_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="train[2%:]", writer_batch_size=1000)
- val_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="train[:2%]", writer_batch_size=1000)
- test_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="test", writer_batch_size=1000)
+ if not config.mono:
+ train_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="train[2%:]", writer_batch_size=1000)
+ val_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="train[:2%]", writer_batch_size=1000)
+ test_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="test", writer_batch_size=1000)
+ else:
+ train_dataset=[]
+ test_dataset=[]
+ for i,j in zip(config.mono_train_path,config.mono_test_path):
+ train_dataset.append(MonoData(path=i))
+ test_dataset.append(MonoData(path=j))
+
+ if(len(config.mono_train_path)>1):
+ train_dataset=torch.utils.data.ConcatDataset(train_dataset)
+ else:
+ train_dataset=train_dataset[0]
+ if(len(config.mono_test_path)>1):
+ val_dataset=torch.utils.data.ConcatDataset(test_dataset)
+ else:
+ val_dataset=test_dataset[0]
+
if config.use_monolingual:
mono_dataset = load_dataset(config.data_loading_script, data_dir=config.monolingual_data_dir, split="train", writer_batch_size=1000)
mono_dataloader = torch.utils.data.DataLoader(dataset=mono_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params)
else:
mono_dataloader = None
+ print(compute_metric(model, tokenizer, test_dataset))
+
if(config.train):
- train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params)
- val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params)
- train_model(model, tokenizer, train_dataloader, val_dataloader, test_dataset, mono_dataloader)
-
+ if not config.mono:
+ train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params)
+ val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params)
+ train_model(model, tokenizer, train_dataloader, val_dataloader, test_dataset, mono_dataloader)
+ else:
+ train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, collate_fn= lambda b: mono_collate_fn(b, tokenizer), **params)
+ val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, collate_fn= lambda b: mono_collate_fn(b, tokenizer), **params)
+ train_model(model, tokenizer, train_dataloader, val_dataloader, test_dataset)
+
if(config.eval):
print(compute_metric(model, tokenizer, test_dataset))