new_feature: Added a working model for Image Segmentation; it can be run with python train.py --classes num_classes
ishan121028 committed Oct 2, 2023
1 parent 17925c3 commit 216acfc
Showing 4 changed files with 373 additions and 0 deletions.
39 changes: 39 additions & 0 deletions Segmentation/dataset.py
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode

class CustomDataset(Dataset):
    """Pairs images with their segmentation masks for use with a DataLoader."""

    def __init__(self, image_paths: list, mask_paths: list, transform=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        # Load the image and mask for this index
        image_path = self.image_paths[index]
        mask_path = self.mask_paths[index]

        image = Image.open(image_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')

        if self.transform is not None:
            image = self.transform(image)

        # Resize the mask with nearest-neighbour interpolation so class
        # indices are not blended, and convert without rescaling: ToTensor
        # would divide by 255, and the cast to long would then floor most
        # class indices to 0.
        mask_transform = transforms.Compose([
            transforms.Resize((224, 224), interpolation=InterpolationMode.NEAREST),
            transforms.PILToTensor()
        ])
        mask = mask_transform(mask).to(torch.long)

        # Further preprocessing (e.g. mapping raw pixel values to class
        # indices) may be required depending on how the masks are encoded.

        return image, mask
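
The class mapping hinted at in the comment above is left to the caller. A minimal sketch of what it might look like; the value_to_class encoding and the file paths are hypothetical, not part of this commit:

import torch
from dataset import CustomDataset

# Hypothetical paths and label encoding, for illustration only
dataset = CustomDataset(image_paths=['data/train/images/0001.jpg'],
                        mask_paths=['data/train/masks/0001.png'])
_, mask = dataset[0]                     # mask: (1, 224, 224) long tensor

value_to_class = {0: 0, 127: 1, 255: 2}  # raw pixel value -> class index
remapped = torch.zeros_like(mask)
for value, cls in value_to_class.items():
    remapped[mask == value] = cls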
103 changes: 103 additions & 0 deletions Segmentation/models/UNet.py
import torch
from torch import nn
from torch.nn import functional as F

class Model(nn.Module):
    def __init__(self, n_class):
        super().__init__()

        # Encoder
        # Each encoder block applies two 3x3 convolutions followed by a 2x2
        # max-pooling layer; the final (bottleneck) block has no pooling.
        # Because every convolution uses padding=1 ("same" convolutions,
        # unlike the unpadded original U-Net), spatial size is preserved
        # within a block and the skip connections below concatenate without
        # cropping. Shapes are annotated for a 512x512x3 input.
        # -------
        # input: 512x512x3
        self.e11 = nn.Conv2d(3, 64, kernel_size=3, padding=1)       # output: 512x512x64
        self.e12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)      # output: 512x512x64
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)          # output: 256x256x64

        # input: 256x256x64
        self.e21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)     # output: 256x256x128
        self.e22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)    # output: 256x256x128
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)          # output: 128x128x128

        # input: 128x128x128
        self.e31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)    # output: 128x128x256
        self.e32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)    # output: 128x128x256
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)          # output: 64x64x256

        # input: 64x64x256
        self.e41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)    # output: 64x64x512
        self.e42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)    # output: 64x64x512
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)          # output: 32x32x512

        # input: 32x32x512
        self.e51 = nn.Conv2d(512, 1024, kernel_size=3, padding=1)   # output: 32x32x1024
        self.e52 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)  # output: 32x32x1024

        # Decoder
        # Each decoder block upsamples with a 2x2 transposed convolution,
        # concatenates the matching encoder feature map, and applies two
        # 3x3 convolutions.
        self.upconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.d11 = nn.Conv2d(1024, 512, kernel_size=3, padding=1)
        self.d12 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

        self.upconv2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.d21 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.d22 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.d31 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.d32 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        self.upconv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.d41 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.d42 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        # Output layer: 1x1 convolution maps 64 channels to per-class logits
        self.outconv = nn.Conv2d(64, n_class, kernel_size=1)

    def forward(self, x):
        # Encoder
        xe11 = F.relu(self.e11(x))
        xe12 = F.relu(self.e12(xe11))
        xp1 = self.pool1(xe12)

        xe21 = F.relu(self.e21(xp1))
        xe22 = F.relu(self.e22(xe21))
        xp2 = self.pool2(xe22)

        xe31 = F.relu(self.e31(xp2))
        xe32 = F.relu(self.e32(xe31))
        xp3 = self.pool3(xe32)

        xe41 = F.relu(self.e41(xp3))
        xe42 = F.relu(self.e42(xe41))
        xp4 = self.pool4(xe42)

        xe51 = F.relu(self.e51(xp4))
        xe52 = F.relu(self.e52(xe51))

        # Decoder: upsample, concatenate the skip connection, convolve
        xu1 = self.upconv1(xe52)
        xu11 = torch.cat([xu1, xe42], dim=1)
        xd11 = F.relu(self.d11(xu11))
        xd12 = F.relu(self.d12(xd11))

        xu2 = self.upconv2(xd12)
        xu22 = torch.cat([xu2, xe32], dim=1)
        xd21 = F.relu(self.d21(xu22))
        xd22 = F.relu(self.d22(xd21))

        xu3 = self.upconv3(xd22)
        xu33 = torch.cat([xu3, xe22], dim=1)
        xd31 = F.relu(self.d31(xu33))
        xd32 = F.relu(self.d32(xd31))

        xu4 = self.upconv4(xd32)
        xu44 = torch.cat([xu4, xe12], dim=1)
        xd41 = F.relu(self.d41(xu44))
        xd42 = F.relu(self.d42(xd41))

        # Output layer
        out = self.outconv(xd42)

        return out
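
Because every convolution is padded, the model preserves spatial resolution end to end, so input height and width only need to be divisible by 16 (the encoder halves resolution four times). A quick shape check, as a sketch rather than part of the commit:

import torch
from models.UNet import Model

model = Model(n_class=2)
x = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    out = model(x)
print(out.shape)  # expected: torch.Size([1, 2, 512, 512])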
102 changes: 102 additions & 0 deletions Segmentation/train.py
import argparse
import glob
import logging
import os
import torch.nn as nn
import torch.optim as optim
from models.UNet import Model
from utils import parse_folder, get_data_loader, train_unet, generate_model_summary

def main():
    parser = argparse.ArgumentParser(description='Train an Image Segmentation Model')
    # argparse's type=bool treats any non-empty string (including "False")
    # as True, so parse the flag explicitly.
    parser.add_argument('--logging', type=lambda s: s.lower() in ('true', '1', 'yes'),
                        default=True, help='Enable or disable logging')
    parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs')
    parser.add_argument('--data_dir', type=str, default='data', help='Path to the dataset directory')
    parser.add_argument('--loss_function', type=str, default='CrossEntropy', choices=['CrossEntropy', 'MSELoss'])
    parser.add_argument('--optimizer', type=str, default='Adam', choices=['Adam', 'SGD'])
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('--logging_directory', type=str, default='logs', help='Directory for logging')
    parser.add_argument('--checkpoint_directory', type=str, default='checkpoints', help='Directory for saving checkpoints')
    parser.add_argument('--classes', type=int, default=2, help='Number of classes to segment into')
    args = parser.parse_args()

    # Create the logging and checkpoint directories
    if args.logging and not os.path.exists(args.logging_directory):
        os.makedirs(args.logging_directory)
    os.makedirs(args.checkpoint_directory, exist_ok=True)

    # Initialize logger
    log_file = os.path.join(args.logging_directory, 'training.log')
    logging.basicConfig(filename=log_file, level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    # Parse the dataset folder
    if not parse_folder(args.data_dir):
        print("Dataset was not parsed correctly, re-check the arrangement!")
        return None

    transform = None  # None selects the default transform in get_data_loader

    train_path = os.path.join(args.data_dir, "train")
    test_path = os.path.join(args.data_dir, "test")

    # Construct full paths for train and test images and masks using glob;
    # sorting keeps each image aligned with its mask.
    train_image_paths = sorted(glob.glob(os.path.join(train_path, "images", "*.jpg")))
    train_mask_paths = sorted(glob.glob(os.path.join(train_path, "masks", "*.png")))

    test_image_paths = sorted(glob.glob(os.path.join(test_path, "images", "*.jpg")))
    test_mask_paths = sorted(glob.glob(os.path.join(test_path, "masks", "*.png")))

    # Create data loaders; the test loader is left unshuffled so that
    # evaluation order is stable.
    train_data_loader = get_data_loader(image_paths=train_image_paths,
                                        mask_paths=train_mask_paths,
                                        batch_size=args.batch_size,
                                        shuffle=True, transform=transform)

    test_data_loader = get_data_loader(image_paths=test_image_paths,
                                       mask_paths=test_mask_paths,
                                       batch_size=args.batch_size,
                                       shuffle=False, transform=transform)

    # Initialize the U-Net model with the requested number of classes
    model = Model(n_class=args.classes)

    # Define the loss function and optimizer
    # Note: MSELoss would require float targets shaped like the logits;
    # CrossEntropy matches the class-index masks produced by dataset.py.
    if args.loss_function == "CrossEntropy":
        criterion = nn.CrossEntropyLoss()
    elif args.loss_function == "MSELoss":
        criterion = nn.MSELoss()
    else:
        print("Choose a suitable criterion from the choices.")
        return None

    if args.optimizer == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    elif args.optimizer == "SGD":
        optimizer = optim.SGD(model.parameters(), lr=args.learning_rate)
    else:
        print("Choose a suitable optimizer from the choices.")
        return None

    generate_model_summary(model=model, input_size=(3, 512, 512))

    # Train the model, passing the selected criterion and optimizer so the
    # CLI choices actually take effect.
    train_unet(
        model=model,
        train_data_loader=train_data_loader,
        test_data_loader=test_data_loader,
        num_epochs=args.epochs,
        learning_rate=args.learning_rate,
        checkpoint_dir=args.checkpoint_directory,
        logger=logging,
        criterion=criterion,
        optimizer=optimizer
    )


if __name__ == "__main__":
    main()
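
For reference, an invocation consistent with the flags defined above (the values are illustrative):

python train.py --data_dir data --classes 2 --epochs 10 \
    --optimizer Adam --loss_function CrossEntropy \
    --learning_rate 0.001 --batch_size 4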
129 changes: 129 additions & 0 deletions Segmentation/utils.py
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchsummary
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms
from dataset import CustomDataset

def get_data_loader(image_paths: list, mask_paths: list, batch_size: int, shuffle: bool = True, transform=None) -> DataLoader:
    """
    Create and return a data loader for a custom dataset.

    Args:
        image_paths (list): Paths to the input images.
        mask_paths (list): Paths to the corresponding segmentation masks.
        batch_size (int): Batch size for the data loader.
        shuffle (bool): Whether to shuffle the data (default is True).
        transform: Optional image transform; when None, a default
            resize/normalize pipeline is applied.

    Returns:
        DataLoader: PyTorch data loader.
    """
    # Define data transformations (adjust as needed)
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize((224, 224)),  # Resize images to a fixed size
            transforms.ToTensor(),          # Convert images to PyTorch tensors
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))  # Normalize with ImageNet stats
        ])

    # Create a custom dataset
    dataset = CustomDataset(image_paths=image_paths, mask_paths=mask_paths, transform=transform)

    # Create a data loader
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle
    )

    return data_loader
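
A minimal usage sketch for get_data_loader, mirroring how train.py builds its loaders; the data directory is a hypothetical example:

import glob
import os
from utils import get_data_loader

image_paths = sorted(glob.glob(os.path.join("data", "train", "images", "*.jpg")))
mask_paths = sorted(glob.glob(os.path.join("data", "train", "masks", "*.png")))
loader = get_data_loader(image_paths=image_paths, mask_paths=mask_paths,
                         batch_size=4, shuffle=True)
images, masks = next(iter(loader))  # images: (4, 3, 224, 224), masks: (4, 1, 224, 224)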

def parse_folder(dataset_path):
    """Verify that dataset_path contains train/test/eval splits, each with
    images/ and masks/ subdirectories (see the expected layout below)."""
    try:
        if not os.path.exists(dataset_path):
            print(f"The '{dataset_path}' folder does not exist in the current directory.")
            return False

        # Store paths to the train, test, and eval folders
        train_path = os.path.join(dataset_path, "train")
        test_path = os.path.join(dataset_path, "test")
        eval_path = os.path.join(dataset_path, "eval")

        if not (os.path.exists(train_path) and os.path.exists(test_path) and os.path.exists(eval_path)):
            return False

        print(f"Train path: {train_path}")
        print(f"Test path: {test_path}")
        print(f"Eval path: {eval_path}")

        # Every entry in the dataset root must contain images/ and masks/
        for split_dir in os.listdir(dataset_path):
            masks_path = os.path.join(dataset_path, split_dir, "masks")
            images_path = os.path.join(dataset_path, split_dir, "images")
            if not (os.path.exists(masks_path) and os.path.exists(images_path)):
                return False

        return True
    except Exception as e:
        print("An error occurred:", str(e))
        return False
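
Based on the checks above and the glob patterns in train.py, parse_folder expects the dataset directory to be laid out as follows:

data/
├── train/
│   ├── images/   # *.jpg input images
│   └── masks/    # *.png segmentation masks
├── test/
│   ├── images/
│   └── masks/
└── eval/
    ├── images/
    └── masks/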


def train_unet(model, train_data_loader, test_data_loader, num_epochs, learning_rate,
               checkpoint_dir, logger=None, criterion=None, optimizer=None):
    # Default to cross-entropy and Adam when the caller does not supply a
    # criterion/optimizer (train.py passes in the ones chosen on the CLI).
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    best_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        for inputs, targets in tqdm(train_data_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False):
            optimizer.zero_grad()
            outputs = model(inputs)
            # CrossEntropyLoss expects class-index targets of shape (N, H, W)
            targets = targets.squeeze(1)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        average_train_loss = train_loss / len(train_data_loader)

        if logger:
            logger.info(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {average_train_loss:.4f}')

        # Validation
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for inputs, targets in tqdm(test_data_loader, desc='Validation', leave=False):
                outputs = model(inputs)
                targets = targets.squeeze(1)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

        average_val_loss = val_loss / len(test_data_loader)

        if logger:
            logger.info(f'Epoch {epoch + 1}/{num_epochs}, Validation Loss: {average_val_loss:.4f}')

        # Save a model checkpoint whenever the validation loss improves
        if average_val_loss < best_loss:
            best_loss = average_val_loss
            checkpoint_path = os.path.join(checkpoint_dir, f'unet_model_epoch_{epoch + 1}.pth')
            torch.save(model.state_dict(), checkpoint_path)
            if logger:
                logger.info(f'Saved checkpoint to {checkpoint_path}')

    print('Finished Training')

def generate_model_summary(model, input_size):
    """Print a layer-by-layer summary of the model for the given input size."""
    torchsummary.summary(model, input_size=input_size)
