Skip to content

Mono #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/Monodataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import os
from torch.utils.data import Dataset

class MonoData(Dataset):
    """Monolingual ASR dataset over a folder laid out as:

        <path>/transcription.txt   -- one "<utt-id><tab><transcript>" per line
        <path>/audio/<utt-id>.wav  -- the corresponding audio files

    __getitem__ returns {'speech': wav_path, 'text': transcript}; the audio
    itself is loaded later by the collate function, not here.
    """

    def __init__(self, path):
        """Read and index the transcription file under `path`.

        Fix: the original called open(...).read() without ever closing the
        handle; use a context manager so it is released promptly.
        """
        self.path = path
        with open(path + '/transcription.txt', 'r', encoding='UTF-8') as f:
            # Tabs are normalized to spaces so the id/text split below works
            # for both tab- and space-separated transcription files.
            self.file = f.read().replace('\t', ' ').rstrip().split("\n")

    def __len__(self):
        # One dataset item per transcription line.
        return len(self.file)

    def __getitem__(self, index):
        # Split only on the FIRST space: the remainder is the transcript,
        # which may itself contain spaces.
        audio, text = self.file[index].split(' ', 1)
        audio = self.path + '/audio/' + audio + '.wav'
        return {'speech': audio, 'text': text}
59 changes: 40 additions & 19 deletions src/configs.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,69 @@

class config:
    """Central training configuration.

    Attributes are read directly off the class (``config.BATCH_SIZE`` etc.)
    and serialized for logging via ``get_all_params_dict`` — the class is
    never instantiated.

    Fix: the pasted source was unified-diff residue holding BOTH the old and
    the new value of many settings (duplicate ``LR``, ``EPOCHS``,
    ``BATCH_SIZE``, ``Language``, ``data_dir`` …). This is the deduplicated
    post-change version; the old misspelled language names are kept as
    aliases so any existing references still resolve.
    """

    # Candidate checkpoints:
    # 'facebook/wav2vec2-large-xlsr-53'
    # 'facebook/wav2vec2-base-960h'
    model = "facebook/wav2vec2-base-960h"
    fast_LR = 1e-3  # To be used when initial weights are frozen
    LR = 1e-4
    clip_grad_norm = 1.0
    EPOCHS = 1000
    num_iters_checkpoint = 57660  # save a checkpoint every this many iterations
    prev_checkpoint = "./wandb/run-20210418_135042-2wfzsbrd/files/facebook/wav2vec2-base-960h_14"

    output_directory = "./model/"

    # Side effect at class-definition time: make sure the output dir exists.
    os.makedirs(output_directory, exist_ok=True)

    BATCH_SIZE = 6
    SHUFFLE = True
    eval = True   # NOTE(review): shadows the builtin `eval` in this namespace
    train = True
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    max_audio_len = 576000
    freeze_for_epochs = 0
    transliterate = False

    cur_epoch = 0

    # Support for additional languages: (start, stop) Unicode code-point
    # ranges of each script, stop-exclusive (hence the +1).
    Telugu = (3072, 3199 + 1)
    Tamil = (2944, 3071 + 1)
    Odia = (2816, 2943 + 1)
    Gujarati = (2688, 2815 + 1)
    Hindi = (2304, 2431 + 1)
    Bengali = (2433, 2554 + 1)
    Marathi = Hindi  # Marathi shares the Devanagari block with Hindi
    # Backward-compatible aliases for the old (misspelled) attribute names.
    Oriya = Odia
    Gujrati = Gujarati

    # Select the language(s); multiple script ranges may be listed.
    Language = [Hindi, Gujarati, Telugu, Tamil, Odia]

    # --- Mono-Language Training -------------------------------------------

    mono = True  # train on monolingual data (use the MonoData dataset)

    mono_train_path = ["/home/krishnarajule3/ASR/data/Hindi/train", "/home/krishnarajule3/ASR/data/Marathi/train", "/home/krishnarajule3/ASR/data/Odia/train",
                       "/home/krishnarajule3/ASR/data/Gujarati/gu-in-Train", "/home/krishnarajule3/ASR/data/Tamil/ta-in-Train", "/home/krishnarajule3/ASR/data/Telegu/te-in-Train"
                       ]  # path to training folder

    # NOTE(review): the last test entry points at ".../te-in-Train", not a
    # Test folder — looks like a copy-paste slip; confirm against the data dir.
    mono_test_path = ["/home/krishnarajule3/ASR/data/Hindi/test", "/home/krishnarajule3/ASR/data/Marathi/test", "/home/krishnarajule3/ASR/data/Odia/test",
                      "/home/krishnarajule3/ASR/data/Gujarati/gu-in-Test", "/home/krishnarajule3/ASR/data/Tamil/ta-in-Test", "/home/krishnarajule3/ASR/data/Telegu/te-in-Train"
                      ]  # path to testing folder

    # --- Code Switched Training (set mono=False to use code-switched loader.py)

    data_dir = "/home/krishnarajule3/ASR/data/Hindi-English/"
    data_loading_script = "/home/datasets/code_switch_asr"

    use_monolingual = False
    monolingual_data_dir = "/home/krishnarajule3/ASR/data/Hindi/"

def get_all_params_dict(config):
    """Collect the plain data attributes of *config* into a dict.

    Skips anything callable (methods/functions) and dunder entries such as
    ``__module__``; everything else is returned as {name: value}.
    """
    def _is_param(name, value):
        # A "parameter" is any non-callable attribute that is not a dunder.
        return not (callable(value) or (name.startswith('__') and name.endswith('__')))

    return {name: value for name, value in config.__dict__.items() if _is_param(name, value)}




Empty file modified src/model.py
100644 → 100755
Empty file.
12 changes: 8 additions & 4 deletions src/tokenizer.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ class Wav2Vec2Tok(Wav2Vec2Tokenizer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not config.transliterate:
for i in range(config.Language[0], config.Language[1]) :
self._add_tokens(chr(i))
for lang in config.Language:
for i in range(lang[0], lang[1]) :
self._add_tokens(chr(i))
else:
self.en_dict = enchant.Dict("en_US")
for elem in ['̄', '̣', '̐', '́', '़', "'ॉ", '̃', '_', 'ऑ', '^', '…', '°', '̂', '̱', 'ॅ', 'ऍ', ':']:
Expand All @@ -45,7 +46,7 @@ def transliterate(self, text: str)-> str:
def remove_sos(self, texts: List[str]) -> List[str]:
    """Strip sentence-boundary markers from each decoded transcription.

    Fix: the diff-merged source kept BOTH the old ``<s>``/``</s>`` line and
    the new ``<S>``/``</S>`` line as separate appends, so every text was
    added to the result twice. Append once, removing the markers in either
    case.

    Args:
        texts: decoded strings that may contain <s>/</s> (any case) markers.

    Returns:
        The same strings with all boundary markers removed, one per input.
    """
    processed_texts = []
    for text in texts:
        for marker in ('<s>', '</s>', '<S>', '</S>'):
            text = text.replace(marker, '')
        processed_texts.append(text)
    return processed_texts

def revert_transliteration(self, texts: List[str])->str:
Expand Down Expand Up @@ -76,7 +77,10 @@ def tokenize(self, text: str, **kwargs) -> List[int]:
"""
if config.transliterate:
text = self.transliterate(text)

else:
for k,v in self.mappings.items():
text = text.replace(k, v)

text = ' '.join(text.split())
text = text.replace(' ', self.word_delimiter_token)
tokens = [self.bos_token_id]
Expand Down
145 changes: 104 additions & 41 deletions src/train2.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
import os
import itertools
import soundfile as sf
import wandb
import random
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
import itertools
import soundfile as sf

import argparse
from tqdm import tqdm
from configs import config, get_all_params_dict
from model import get_model
from tokenizer import Wav2Vec2Tok
from datasets import load_dataset, load_metric
import wandb
import random
from Monodataset import MonoData


def find_lengths(logits, pad_id: int) -> torch.FloatTensor:
"""
Expand Down Expand Up @@ -134,34 +137,59 @@ def eval_model(model, tokenizer, val_dataloader):
return (epoch_loss / num_valid_batches)

def compute_metric(model, tokenizer, test_dataset):
    """Compute the average word error rate (WER) of `model` over one or more
    test datasets.

    Fix: the pasted source was unified-diff residue containing BOTH the old
    single-dataset loop (dead ``metric.add_batch`` / trailing
    ``metric.compute()``) and the new per-dataset loop. This is the
    reconstructed post-change version.

    Args:
        model: acoustic model whose forward pass yields ``.logits``.
        tokenizer: project Wav2Vec2Tok — provides batch_decode,
            revert_transliteration and the `mappings` dict.
        test_dataset: a single dataset or a list of datasets. Each item `d`
            exposes d["speech"] (raw audio, or a wav path when config.mono)
            and d["text"] (reference transcript).

    Returns:
        Mean over datasets of the per-dataset mean per-sample WER.
    """
    wer_score = []
    model.eval()
    # Accept a bare dataset as well as a list of datasets.
    if not isinstance(test_dataset, list):
        test_dataset = [test_dataset]

    with torch.no_grad():
        for dataset in test_dataset:
            metric = load_metric('wer')
            pbar = tqdm(dataset, desc="Computing metric")
            score = []
            # One random sample (plus the first) is printed for inspection.
            show_sample_no = random.randint(1, len(dataset) - 1)
            for i, d in enumerate(pbar):

                if not config.mono:
                    # Code-switched loader yields audio arrays directly.
                    input_values = tokenizer(d["speech"], return_tensors="pt",
                                             padding='longest').input_values.to(config.device)
                else:
                    # Mono loader yields wav paths; load the audio here.
                    input_values = tokenizer(sf.read(d["speech"])[0], return_tensors="pt",
                                             padding='longest').input_values.to(config.device)

                logits = model(input_values).logits

                predicted_ids = torch.argmax(logits, dim=-1).cpu()
                transcriptions = tokenizer.batch_decode(predicted_ids)

                if config.transliterate:
                    transcriptions = tokenizer.revert_transliteration(transcriptions)
                else:
                    # Undo the tokenizer's character remappings (batch size
                    # is 1 here, hence transcriptions[0]).
                    for k, v in tokenizer.mappings.items():
                        transcriptions[0] = transcriptions[0].replace(v.strip(), k)

                # Drop sentence-boundary markers before scoring.
                transcriptions[0] = transcriptions[0].replace('<s>', '').replace('</s>', '')
                reference = d['text'].upper()

                if i == show_sample_no or i == 0:
                    print("Sample prediction: ", transcriptions[0])
                    print("Sample reference: ", reference)

                score.append(metric.compute(predictions=transcriptions,
                                            references=[reference]))

            score = sum(score) / len(score)
            print("Evaluation metric: ", score)
            wer_score.append(score)

    print(wer_score)

    score = sum(wer_score) / len(wer_score)
    print('Avg Score: ', score)
    return score

def collate_fn(batch, tokenizer):
Expand All @@ -175,6 +203,18 @@ def collate_fn(batch, tokenizer):

return (input_values.to(config.device), labels.to(config.device), label_lengths.to(config.device))

def mono_collate_fn(batch, tokenizer):
    """Collate MonoData items into a training batch.

    Loads each sample's wav file from disk, tokenizes/pads the audio to the
    longest utterance, and tokenizes the transcripts; everything is moved to
    config.device.
    """
    audio_signals = []
    transcripts = []
    for sample in batch:
        # sample['speech'] is a wav path; sf.read returns (data, samplerate)
        # and only the data array is needed.
        audio_signals.append(sf.read(sample['speech'])[0])
        transcripts.append(sample['text'])

    input_values = tokenizer(audio_signals, return_tensors="pt",
                             padding='longest').input_values

    labels, label_lengths = tokenizer.batch_tokenize(transcripts)

    return (input_values.to(config.device), labels.to(config.device), label_lengths.to(config.device))

if __name__ =='__main__':
all_params_dict = get_all_params_dict(config)

Expand All @@ -190,25 +230,48 @@ def collate_fn(batch, tokenizer):
if(config.prev_checkpoint!=""):
model=load_checkpoint(model,config.prev_checkpoint)

params = {'batch_size': config.BATCH_SIZE,}
params = {'batch_size': config.BATCH_SIZE,'shuffle': config.SHUFFLE}

print("running on ", config.device)

train_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="train[2%:]", writer_batch_size=1000)
val_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="train[:2%]", writer_batch_size=1000)
test_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="test", writer_batch_size=1000)

if not config.mono:
train_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="train[2%:]", writer_batch_size=1000)
val_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="train[:2%]", writer_batch_size=1000)
test_dataset = load_dataset(config.data_loading_script, data_dir=config.data_dir, split="test", writer_batch_size=1000)
else:
train_dataset=[]
test_dataset=[]
for i,j in zip(config.mono_train_path,config.mono_test_path):
train_dataset.append(MonoData(path=i))
test_dataset.append(MonoData(path=j))

if(len(config.mono_train_path)>1):
train_dataset=torch.utils.data.ConcatDataset(train_dataset)
else:
train_dataset=train_dataset[0]
if(len(config.mono_test_path)>1):
val_dataset=torch.utils.data.ConcatDataset(test_dataset)
else:
val_dataset=test_dataset[0]

if config.use_monolingual:
mono_dataset = load_dataset(config.data_loading_script, data_dir=config.monolingual_data_dir, split="train", writer_batch_size=1000)
mono_dataloader = torch.utils.data.DataLoader(dataset=mono_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params)
else:
mono_dataloader = None

print(compute_metric(model, tokenizer, test_dataset))

if(config.train):
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params)
val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params)
train_model(model, tokenizer, train_dataloader, val_dataloader, test_dataset, mono_dataloader)

if not config.mono:
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params)
val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, collate_fn= lambda b: collate_fn(b, tokenizer), **params)
train_model(model, tokenizer, train_dataloader, val_dataloader, test_dataset, mono_dataloader)
else:
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, collate_fn= lambda b: mono_collate_fn(b, tokenizer), **params)
val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, collate_fn= lambda b: mono_collate_fn(b, tokenizer), **params)
train_model(model, tokenizer, train_dataloader, val_dataloader, test_dataset)

if(config.eval):
print(compute_metric(model, tokenizer, test_dataset))

Expand Down