"""
Training the Perceiver Transformer for Alzheimer's classification using PyTorch.
"""
import logging

import matplotlib.pyplot as plt
import numpy as np
import torch

import dataset
import module
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the data loaders (80/20 train/validation split)
    train_data, valid_data, test_data = dataset.get_loaders()
    print(f"Train Data Size: {len(train_data.dataset)}\nValid Data Size: {len(valid_data.dataset)}")
    print(f"1st Train Data: {next(iter(train_data))[0].shape}\n1st Valid Data: {next(iter(valid_data))[0].shape}")
    sample, _ = next(iter(train_data))
    # Print the shape of the input image
    print("Input image shape:", sample.shape)
    logging.info("Loaded data")

    # Create the model and its training components
    model, optimizer, criterion, scheduler = module.create_model(
        input_shape=(256, 256),
        latent_dim=8,            # Larger latent space gives more representational capacity
        embed_dim=16,
        attention_mlp_dim=16,
        transformer_mlp_dim=16,
        transformer_heads=4,     # More attention heads for enhanced feature capturing
        dropout=0.1,
        transformer_layers=4,
        n_blocks=4,
        n_classes=2,
        lr=0.005,
    )
    model = model.to(device)

    EPOCHS = 30
    # Track the minimum validation loss seen so far
    min_valid_loss = np.inf
    # Track accuracy and loss during training
    history = {'train_loss': [], 'train_acc': [], 'valid_loss': [], 'valid_acc': []}
    early_stopping_patience = 5
    epochs_without_improvement = 0

    for epoch in range(EPOCHS):
        # Free up GPU memory before each epoch
        torch.cuda.empty_cache()

        train_loss, train_acc = train(model, train_data, criterion, optimizer, device)
        valid_loss, valid_acc = validate(model, valid_data, criterion, device)

        # Append metric history
        history['train_acc'].append(train_acc)
        history['train_loss'].append(train_loss)
        history['valid_acc'].append(valid_acc)
        history['valid_loss'].append(valid_loss)
        logging.info(f"Epoch: {epoch + 1}\nTrain Loss: {train_loss}\nTrain Accuracy: {train_acc}\n"
                     f"Valid Loss: {valid_loss}\nValid Accuracy: {valid_acc}\n")

        scheduler.step()  # Step the learning rate scheduler after each epoch

        if valid_loss < min_valid_loss:
            min_valid_loss = valid_loss
            best_model_state = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': valid_loss
            }
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= early_stopping_patience:
                logging.info(f"No improvement in validation loss for {early_stopping_patience} epochs.")
                # break  # Exit the training loop

    # Save the best model state
    torch.save(best_model_state, 'saved/best_model.pth')

    # Plot training history
    plot_training_history(history)
def train(model, data_loader, criterion, optimizer, device):
    """Run one training epoch and return the average loss and accuracy."""
    model.train()
    total_loss = 0.0
    correct_predictions = 0
    total_samples = len(data_loader.dataset)
    for batch, labels in data_loader:
        batch, labels = batch.to(device), labels.to(device)
        optimizer.zero_grad()
        prediction = model(batch)
        loss = criterion(prediction, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * len(batch)
        correct_predictions += torch.sum(torch.argmax(prediction, dim=1) == labels).item()
    avg_loss = total_loss / total_samples
    accuracy = correct_predictions / total_samples
    return avg_loss, accuracy
def validate(model, data_loader, criterion, device):
    """Evaluate the model on a validation set and return the average loss and accuracy."""
    model.eval()
    total_loss = 0.0
    correct_predictions = 0
    total_samples = len(data_loader.dataset)
    with torch.no_grad():
        for batch, labels in data_loader:
            batch, labels = batch.to(device), labels.to(device)
            prediction = model(batch)
            loss = criterion(prediction, labels)
            total_loss += loss.item() * len(batch)
            correct_predictions += torch.sum(torch.argmax(prediction, dim=1) == labels).item()
    avg_loss = total_loss / total_samples
    accuracy = correct_predictions / total_samples
    return avg_loss, accuracy
def plot_training_history(history):
    """Plot and save the training/validation accuracy and loss curves."""
    plt.figure(figsize=(12, 8), dpi=80)
    plt.plot(history['train_acc'])
    plt.plot(history['valid_acc'])
    plt.xlim([0, len(history['train_acc'])])
    plt.xticks(range(len(history['train_acc'])))
    plt.ylim([0, 1])
    plt.title('Accuracy')
    plt.legend(['Training', 'Validation'])
    plt.savefig('plots/accuracy.png')  # Save before show() so the figure is not cleared
    plt.show()

    plt.figure(figsize=(12, 8), dpi=80)
    plt.plot(history['train_loss'])
    plt.plot(history['valid_loss'])
    plt.xlim([0, len(history['train_loss'])])
    plt.xticks(range(len(history['train_loss'])))
    plt.ylim([0, 1])
    plt.title('Loss')
    plt.legend(['Training', 'Validation'])
    plt.savefig('plots/loss.png')  # Save before show() so the figure is not cleared
    plt.show()
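

# The helper below is a hypothetical, minimal sketch (not part of the original training
# loop) of how the checkpoint written by main() could be restored for later evaluation.
# It assumes module.create_model accepts the same keyword arguments used in main() and
# that the checkpoint keys match the dict saved above; the function name is illustrative.
def load_best_model(checkpoint_path='saved/best_model.pth', device='cpu'):
    """Rebuild the model and restore the best checkpoint saved by main()."""
    model, optimizer, _, scheduler = module.create_model(
        input_shape=(256, 256),
        latent_dim=8,
        embed_dim=16,
        attention_mlp_dim=16,
        transformer_mlp_dim=16,
        transformer_heads=4,
        dropout=0.1,
        transformer_layers=4,
        n_blocks=4,
        n_classes=2,
        lr=0.005,
    )
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    model = model.to(device)
    model.eval()  # Switch to evaluation mode for inference
    return model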
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()