Skip to content

Commit

Permalink
added data augmentation and normalisation to the dataset and adjusted…
Browse files Browse the repository at this point in the history
… batch size for more stable gradients
  • Loading branch information
lorenzopolicar committed Oct 17, 2023
1 parent bedc7fd commit 43cec5e
Showing 1 changed file with 7 additions and 12 deletions.
19 changes: 7 additions & 12 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ def get_loaders():
# Define data transformations (resize, normalize, etc.) with data augmentation
transform = transforms.Compose([
transforms.Resize((256, 256)), # Resize images to a consistent size
transforms.RandomHorizontalFlip(), # Randomly flip the image horizontally
transforms.RandomRotation(10), # Randomly rotate the image up to 10 degrees
transforms.ToTensor(), # Convert images to tensors
transforms.Normalize((0.5,), (0.5,)) # Normalize pixel values (adjust mean and std as needed)
transforms.Normalize((0.1232,), (0.2308,)), # Normalize pixel values
transforms.RandomHorizontalFlip(p=0.5), # Randomly flip images horizontally
transforms.RandomRotation(degrees=15), # Randomly rotate images
])

# Create the ImageFolder dataset for training
Expand All @@ -36,16 +36,11 @@ def get_loaders():
train_data, validation_data = train_test_split(train_data, train_size=train_size, test_size=1 - train_size, shuffle=True, random_state=42)

# Create data loaders with reduced batch size and multi-processing
batch_size = 64 # Adjust as needed
batch_size = 32 # Adjust as needed
num_workers = 4 # Use multiple workers for data loading

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
validation_loader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, num_workers=num_workers, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, num_workers=num_workers, pin_memory=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)#, num_workers=num_workers, pin_memory=True)
validation_loader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size)#, num_workers=num_workers, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)#, num_workers=num_workers, pin_memory=True)

return train_loader, validation_loader, test_loader

# Script entry point: manual smoke check that the data loaders can be constructed.
if __name__ == "__main__":
# Select the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced again in the visible lines — confirm it is still needed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the train/validation/test loaders; no batches are iterated here.
train_loader, validation_loader, test_loader = get_loaders()
# Placeholder debug marker printed once get_loaders() has returned.
print("done da sdasd")

0 comments on commit 43cec5e

Please sign in to comment.