-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
new_feature: Added a working model for Image Segmentation, can be used using python train.py --classes num_classes
- Loading branch information
1 parent
17925c3
commit 216acfc
Showing 4 changed files with 373 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import os | ||
from PIL import Image | ||
import torch | ||
from torch.utils.data import Dataset | ||
import cv2 | ||
import torchvision.transforms as transforms | ||
|
||
class CustomDataset(Dataset):
    """Dataset of (image, mask) pairs read from two parallel lists of file paths.

    ``image_paths[i]`` and ``mask_paths[i]`` are treated as one sample, so both
    lists must be in the same order; the length of the dataset is the length of
    ``image_paths``.

    Args:
        image_paths: Paths of the input images (opened with PIL, converted to RGB).
        mask_paths: Paths of the segmentation masks (opened with PIL, converted
            to single-channel 'L').
        transform: Optional callable applied to the PIL image only. When it is
            ``None`` the image is returned as a PIL image.
        mask_size: ``(height, width)`` the mask is resized to. It should match
            the spatial size produced by ``transform`` so that image and mask
            stay aligned. Defaults to ``(224, 224)``.
    """

    def __init__(self, image_paths: list, mask_paths: list, transform=None,
                 mask_size: tuple = (224, 224)):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.mask_size = mask_size

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        """Return ``(image, mask)`` for sample ``index``.

        ``mask`` is a ``torch.long`` tensor produced by resizing the mask to
        ``mask_size``, applying ``ToTensor`` and casting to integers.
        """
        image_path = self.image_paths[index]
        mask_path = self.mask_paths[index]

        # convert() returns a fully loaded copy, so the underlying files can be
        # closed right away instead of waiting for garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert('L')

        if self.transform is not None:
            image = self.transform(image)

        # NOTE(review): ToTensor rescales 8-bit pixel values to [0, 1] and the
        # integer cast below truncates, so only pixels equal to 255 end up as
        # class 1 and everything else becomes class 0. This works for 0/255
        # binary masks; confirm the mask encoding before training with more
        # than two classes.
        mask_transform = transforms.Compose([
            transforms.Resize(self.mask_size),
            transforms.ToTensor(),
        ])
        mask = mask_transform(mask)
        mask = mask.to(torch.long)

        return image, mask
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import torch | ||
from torch import nn | ||
from torch.nn import functional as F | ||
|
||
class Model(nn.Module):
    """U-Net style encoder/decoder for per-pixel classification.

    The network takes a 3-channel image batch ``(N, 3, H, W)`` and returns raw
    class scores of shape ``(N, n_class, H, W)`` (no softmax is applied).

    Every 3x3 convolution uses ``padding=1``, so convolutions keep the spatial
    size; only the four 2x2 max-pool layers (halving) and the four transposed
    convolutions (doubling) change it. For the skip connections to line up,
    ``H`` and ``W`` should be divisible by 16.

    Args:
        n_class: Number of output channels (one score map per class).
    """

    def __init__(self, n_class):
        super().__init__()

        # Encoder
        # Each stage is two 3x3 convolutions; the first four stages are followed
        # by a 2x2 max-pool. Sizes in the comments are relative to an H x W input.
        self.e11 = nn.Conv2d(3, 64, kernel_size=3, padding=1)  # H x W, 64 ch
        self.e12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)  # H x W, 64 ch
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # H/2 x W/2

        self.e21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)  # H/2 x W/2, 128 ch
        self.e22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)  # H/2 x W/2, 128 ch
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # H/4 x W/4

        self.e31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)  # H/4 x W/4, 256 ch
        self.e32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)  # H/4 x W/4, 256 ch
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # H/8 x W/8

        self.e41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)  # H/8 x W/8, 512 ch
        self.e42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)  # H/8 x W/8, 512 ch
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)  # H/16 x W/16

        # Bottleneck stage: no pooling after it.
        self.e51 = nn.Conv2d(512, 1024, kernel_size=3, padding=1)  # H/16 x W/16, 1024 ch
        self.e52 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)  # H/16 x W/16, 1024 ch

        # Decoder
        # Each stage upsamples by 2 with a transposed convolution, concatenates
        # the matching encoder feature map (which doubles the channel count
        # again), then applies two 3x3 convolutions.
        self.upconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.d11 = nn.Conv2d(1024, 512, kernel_size=3, padding=1)
        self.d12 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

        self.upconv2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.d21 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.d22 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.d31 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.d32 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        self.upconv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.d41 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.d42 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        # Output layer: 1x1 convolution mapping 64 features to class scores.
        self.outconv = nn.Conv2d(64, n_class, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x`` and return per-pixel class scores."""
        # Encoder; xe12, xe22, xe32 and xe42 are kept for the skip connections.
        xe11 = F.relu(self.e11(x))
        xe12 = F.relu(self.e12(xe11))
        xp1 = self.pool1(xe12)

        xe21 = F.relu(self.e21(xp1))
        xe22 = F.relu(self.e22(xe21))
        xp2 = self.pool2(xe22)

        xe31 = F.relu(self.e31(xp2))
        xe32 = F.relu(self.e32(xe31))
        xp3 = self.pool3(xe32)

        xe41 = F.relu(self.e41(xp3))
        xe42 = F.relu(self.e42(xe41))
        xp4 = self.pool4(xe42)

        xe51 = F.relu(self.e51(xp4))
        xe52 = F.relu(self.e52(xe51))

        # Decoder; concatenation is along the channel dimension (dim=1).
        xu1 = self.upconv1(xe52)
        xu11 = torch.cat([xu1, xe42], dim=1)
        xd11 = F.relu(self.d11(xu11))
        xd12 = F.relu(self.d12(xd11))

        xu2 = self.upconv2(xd12)
        xu22 = torch.cat([xu2, xe32], dim=1)
        xd21 = F.relu(self.d21(xu22))
        xd22 = F.relu(self.d22(xd21))

        xu3 = self.upconv3(xd22)
        xu33 = torch.cat([xu3, xe22], dim=1)
        xd31 = F.relu(self.d31(xu33))
        xd32 = F.relu(self.d32(xd31))

        xu4 = self.upconv4(xd32)
        xu44 = torch.cat([xu4, xe12], dim=1)
        xd41 = F.relu(self.d41(xu44))
        xd42 = F.relu(self.d42(xd41))

        # Output layer
        out = self.outconv(xd42)

        return out
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import argparse | ||
import logging | ||
import os | ||
import torch | ||
import torch.optim as optim | ||
import torch.nn as nn | ||
from torchvision import transforms | ||
from torch.utils.data import DataLoader | ||
from utils import parse_folder, get_data_loader, train_unet | ||
from models.UNet import Model | ||
import glob | ||
from utils import generate_model_summary | ||
|
||
def _str2bool(value):
    """Parse a command-line boolean such as 'true', 'False', '1' or 'no'.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    'False') as True, so an explicit parser is needed.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}')


def _build_parser():
    """Build the command-line parser for the training script."""
    parser = argparse.ArgumentParser(description='Train an Image Segmentation Model')
    parser.add_argument('--logging', type=_str2bool, default=True, help='Enable or disable logging')
    parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs')
    parser.add_argument('--data_dir', type=str, default='data', help='Path to the dataset directory')
    parser.add_argument('--loss_function', type=str, default='CrossEntropy', choices=['CrossEntropy', 'MSELoss'])
    parser.add_argument('--optimizer', type=str, default='Adam', choices=['Adam', 'SGD'])
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('--logging_directory', type=str, default='logs', help='Directory for logging')
    parser.add_argument('--checkpoint_directory', type=str, default='checkpoints', help='Directory for saving checkpoints')
    parser.add_argument('--classes', type=int, default=2, help='No. of classes you want to segment your model into.')
    return parser


def main():
    """Command-line entry point: parse arguments, build data loaders and train.

    Expects ``<data_dir>/{train,test}/images/*.jpg`` and
    ``<data_dir>/{train,test}/masks/*.png``. Returns ``None``; when the dataset
    layout check fails a message is printed and training is skipped.
    """
    args = _build_parser().parse_args()

    # Only touch the filesystem / logging configuration when logging is on;
    # otherwise train_unet() receives logger=None and stays silent.
    logger = None
    if args.logging:
        os.makedirs(args.logging_directory, exist_ok=True)
        log_file = os.path.join(args.logging_directory, 'training.log')
        logging.basicConfig(filename=log_file, level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
        logger = logging

    # Parse the dataset folder
    if not parse_folder(args.data_dir):
        print("Dataset was not parsed correctlty, re-check the arrangement!")
        return None

    # None lets get_data_loader() fall back to its default transform.
    transform = None

    train_path = os.path.join(args.data_dir, "train")
    test_path = os.path.join(args.data_dir, "test")

    # glob returns files in arbitrary order. Images and masks are paired by
    # index, so both lists are sorted to keep matching file names aligned.
    train_image_paths = sorted(glob.glob(os.path.join(train_path, "images", "*.jpg")))
    train_mask_paths = sorted(glob.glob(os.path.join(train_path, "masks", "*.png")))

    test_image_paths = sorted(glob.glob(os.path.join(test_path, "images", "*.jpg")))
    test_mask_paths = sorted(glob.glob(os.path.join(test_path, "masks", "*.png")))

    # Create data loaders
    train_data_loader = get_data_loader(image_paths=train_image_paths,
                                        mask_paths=train_mask_paths,
                                        batch_size=args.batch_size,
                                        shuffle=True, transform=transform)

    # Evaluation only averages the loss, so shuffling is unnecessary there.
    test_data_loader = get_data_loader(image_paths=test_image_paths,
                                       mask_paths=test_mask_paths,
                                       batch_size=args.batch_size,
                                       shuffle=False, transform=transform)

    # Initialize U-Net model with the requested number of classes.
    model = Model(n_class=args.classes)

    # NOTE(review): --loss_function and --optimizer are validated by argparse
    # but are not forwarded by the train_unet() call below, so they currently
    # have no effect on training.

    # Checkpoints are written into this directory during training.
    os.makedirs(args.checkpoint_directory, exist_ok=True)

    # Summarise at the resolution the default data transform produces.
    generate_model_summary(model=model, input_size=(3, 224, 224))

    # Train the model
    train_unet(
        model=model,
        train_data_loader=train_data_loader,
        test_data_loader=test_data_loader,
        num_epochs=args.epochs,
        learning_rate=args.learning_rate,
        checkpoint_dir=args.checkpoint_directory,
        logger=logger
    )


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
import os | ||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
from tqdm import tqdm | ||
from torch.utils.data import DataLoader | ||
from torchvision import transforms | ||
from dataset import CustomDataset | ||
import torchsummary | ||
|
||
def get_data_loader(image_paths: list, mask_paths: list, batch_size: int, shuffle: bool = True, transform=None) -> DataLoader:
    """
    Create and return a data loader over paired image and mask files.

    Args:
        image_paths (list): Paths of the input images.
        mask_paths (list): Paths of the masks, in the same order as ``image_paths``.
        batch_size (int): Batch size for the data loader.
        shuffle (bool): Whether to shuffle the data (default is True).
        transform: Optional callable applied to each image. When ``None``, a
            default pipeline is used: resize to 224x224, convert to tensor and
            normalize with ImageNet statistics.

    Returns:
        DataLoader: PyTorch data loader wrapping a ``CustomDataset``.
    """
    # Fall back to the default image pipeline when the caller supplies none.
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize((224, 224)),  # Resize images to a fixed size
            transforms.ToTensor(),  # Convert images to PyTorch tensors
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))  # Normalize with ImageNet stats
        ])

    # Create a custom dataset
    dataset = CustomDataset(image_paths=image_paths, mask_paths=mask_paths, transform=transform)

    # Create a data loader
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle
    )

    return data_loader
|
||
def parse_folder(dataset_path):
    """Check that ``dataset_path`` has the directory layout training expects.

    The layout is considered valid when:

    * ``dataset_path`` exists,
    * it contains ``train``, ``test`` and ``eval`` directories, and
    * every sub-directory of ``dataset_path`` contains both an ``images`` and
      a ``masks`` directory.

    Plain files in ``dataset_path`` (README, .DS_Store, ...) are ignored.

    Args:
        dataset_path: Path of the dataset root directory.

    Returns:
        bool: True when the layout is valid, False otherwise. Problems are
        reported with ``print`` rather than raised.
    """
    if not os.path.exists(dataset_path):
        print(f"The '{dataset_path}' folder does not exist in the current directory.")
        return False

    train_path = os.path.join(dataset_path, "train")
    test_path = os.path.join(dataset_path, "test")
    eval_path = os.path.join(dataset_path, "eval")

    missing = [path for path in (train_path, test_path, eval_path) if not os.path.isdir(path)]
    if missing:
        print("Missing dataset folders:", ", ".join(missing))
        return False

    print(f"Train path: {train_path}")
    print(f"Test path: {test_path}")
    print(f"Eval path: {eval_path}")

    try:
        entries = os.listdir(dataset_path)
    except OSError as e:
        # e.g. permission problems while reading the directory
        print("An error occurred:", str(e))
        return False

    for entry in entries:
        entry_path = os.path.join(dataset_path, entry)
        if not os.path.isdir(entry_path):
            # Stray files must not invalidate an otherwise correct dataset.
            continue

        masks_path = os.path.join(entry_path, "masks")
        images_path = os.path.join(entry_path, "images")
        if not (os.path.isdir(masks_path) and os.path.isdir(images_path)):
            return False

    return True
|
||
|
||
def train_unet(model, train_data_loader, test_data_loader, num_epochs, learning_rate, checkpoint_dir, logger=None):
    """Train ``model`` with cross-entropy loss and Adam, validating every epoch.

    After each epoch the average validation loss is computed on
    ``test_data_loader``; whenever it improves on the best value so far, the
    model's ``state_dict`` is saved to
    ``<checkpoint_dir>/unet_model_epoch_<epoch>.pth``.

    Args:
        model: The ``nn.Module`` to train (updated in place).
        train_data_loader: Iterable of ``(inputs, targets)`` training batches.
        test_data_loader: Iterable of ``(inputs, targets)`` validation batches.
        num_epochs: Number of passes over the training data.
        learning_rate: Learning rate passed to Adam.
        checkpoint_dir: Directory for checkpoints; created if it is missing.
        logger: Optional object with an ``info`` method (for example the
            ``logging`` module). When falsy, nothing is logged.

    Returns:
        None
    """
    # torch.save does not create parent directories, so make sure it exists.
    os.makedirs(checkpoint_dir, exist_ok=True)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Run batches on whatever device the caller placed the model on.
    device = next(model.parameters()).device

    best_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        for inputs, targets in tqdm(train_data_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False):
            inputs = inputs.to(device)
            # Drop the channel dimension of the targets before the loss.
            targets = targets.to(device).squeeze(1)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        average_train_loss = train_loss / len(train_data_loader)

        if logger:
            logger.info(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {average_train_loss:.4f}')

        # Validation
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for inputs, targets in tqdm(test_data_loader, desc='Validation', leave=False):
                inputs = inputs.to(device)
                targets = targets.to(device).squeeze(1)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

        average_val_loss = val_loss / len(test_data_loader)

        if logger:
            logger.info(f'Epoch {epoch + 1}/{num_epochs}, Validation Loss: {average_val_loss:.4f}')

        # Save model checkpoint if validation loss improves
        if average_val_loss < best_loss:
            best_loss = average_val_loss
            checkpoint_path = os.path.join(checkpoint_dir, f'unet_model_epoch_{epoch + 1}.pth')
            torch.save(model.state_dict(), checkpoint_path)
            if logger:
                logger.info(f'Saved checkpoint to {checkpoint_path}')

    print('Finished Training')
|
||
def generate_model_summary(model, input_size):
    """Delegate to ``torchsummary.summary`` for ``model`` at ``input_size``.

    Args:
        model: The network to summarise.
        input_size: Input size forwarded unchanged as the ``input_size``
            keyword (the caller in this project passes a
            ``(channels, height, width)`` tuple).
    """
    summarize = torchsummary.summary
    summarize(model, input_size=input_size)