diff --git a/dataset.py b/dataset.py index eaa0cd329d..76f3340c80 100644 --- a/dataset.py +++ b/dataset.py @@ -17,10 +17,10 @@ def get_loaders(): # Define data transformations (resize, normalize, etc.) with data augmentation transform = transforms.Compose([ transforms.Resize((256, 256)), # Resize images to a consistent size - transforms.RandomHorizontalFlip(), # Randomly flip the image horizontally - transforms.RandomRotation(10), # Randomly rotate the image up to 10 degrees transforms.ToTensor(), # Convert images to tensors - transforms.Normalize((0.5,), (0.5,)) # Normalize pixel values (adjust mean and std as needed) + transforms.Normalize((0.1232,), (0.2308,)), # Normalize pixel values + transforms.RandomHorizontalFlip(p=0.5), # Randomly flip images horizontally + transforms.RandomRotation(degrees=15), # Randomly rotate images ]) # Create the ImageFolder dataset for training @@ -36,16 +36,11 @@ def get_loaders(): train_data, validation_data = train_test_split(train_data, train_size=train_size, test_size=1 - train_size, shuffle=True, random_state=42) # Create data loaders with reduced batch size and multi-processing - batch_size = 64 # Adjust as needed + batch_size = 32 # Adjust as needed num_workers = 4 # Use multiple workers for data loading - train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True) - validation_loader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, num_workers=num_workers, pin_memory=True) - test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, num_workers=num_workers, pin_memory=True) + train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)#, num_workers=num_workers, pin_memory=True) + validation_loader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size)#, num_workers=num_workers, pin_memory=True) + test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)#, num_workers=num_workers, pin_memory=True) return train_loader, validation_loader, test_loader - -if __name__ == "__main__": - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - train_loader, validation_loader, test_loader = get_loaders() - print("done da sdasd")