// Copyright 2019 Markus Fleischhacker
#include <torch/torch.h>
#include <iostream>
#include <iomanip>
#include "bi_rnn.h"
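
// bi_rnn.h is not part of this file. As a rough sketch (an assumption based on
// how the model is used below, not the actual header), it presumably declares a
// module along these lines:
//
//   class BiRNNImpl : public torch::nn::Module {
//    public:
//     BiRNNImpl(int64_t input_size, int64_t hidden_size, int64_t num_layers, int64_t num_classes);
//     torch::Tensor forward(torch::Tensor x);  // assumed to return log-probabilities, see nll_loss below
//   };
//   TORCH_MODULE(BiRNN);  // generates the BiRNN holder used via operator-> in main()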

int main() {
    std::cout << "Bidirectional Recurrent Neural Network\n\n";

    // Device
    torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);

    // Hyperparameters
    const int64_t sequence_length = 28;
    const int64_t input_size = 28;
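    // Each 28x28 MNIST image is fed to the RNN as a sequence of 28 rows
    // (time steps) with 28 pixels (features) per step; see the view() call below.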
    const int64_t hidden_size = 128;
    const int64_t num_layers = 2;
    const int64_t num_classes = 10;
    const int64_t batch_size = 100;
    const int64_t num_epochs = 2;
    const double learning_rate = 0.003;

    const std::string MNIST_data_path = "../../../../tutorials/intermediate/"
        "bidirectional_recurrent_neural_network/data/";

    // MNIST dataset
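    // (0.1307 and 0.3081 are the mean and standard deviation of the MNIST training set)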
    auto train_dataset = torch::data::datasets::MNIST(MNIST_data_path)
        .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
        .map(torch::data::transforms::Stack<>());

    // Number of samples in the training set
    auto num_train_samples = train_dataset.size().value();

    auto test_dataset = torch::data::datasets::MNIST(MNIST_data_path, torch::data::datasets::MNIST::Mode::kTest)
        .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
        .map(torch::data::transforms::Stack<>());

    // Number of samples in the test set
    auto num_test_samples = test_dataset.size().value();

    // Data loader
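    // (the RandomSampler reshuffles the training set each epoch; the test set is read in order)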
    auto train_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(
        std::move(train_dataset), batch_size);
    auto test_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
        std::move(test_dataset), batch_size);

    // Model
    BiRNN model(input_size, hidden_size, num_layers, num_classes);
    model->to(device);

    // Optimizer
    auto optimizer = torch::optim::Adam(model->parameters(), torch::optim::AdamOptions(learning_rate));

    // Set floating point output precision
    std::cout << std::fixed << std::setprecision(4);

    std::cout << "Training...\n";

    // Train the model
    for (size_t epoch = 0; epoch != num_epochs; ++epoch) {
        // Initialize running metrics
        float running_loss = 0.0;
        size_t num_correct = 0;

        for (auto& batch : *train_loader) {
            // Reshape images to (batch_size, sequence_length, input_size)
            // and transfer them, with the target labels, to the device
            auto data = batch.data.view({-1, sequence_length, input_size}).to(device);
            auto target = batch.target.to(device);

            // Forward pass
            auto output = model->forward(data);

            // Calculate loss
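            // (torch::nll_loss expects log-probabilities, so BiRNN::forward is
            // assumed to end in log_softmax)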
            auto loss = torch::nll_loss(output, target);
            // Update running loss
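            // (nll_loss returns the batch mean by default, so scale by the
            // batch size to accumulate a sum over samples)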
            running_loss += loss.item().toFloat() * data.size(0);

            // Calculate prediction
            auto prediction = output.argmax(1);

            // Update number of correctly classified samples
            num_correct += prediction.eq(target).sum().item().toLong();

            // Backward pass and optimize
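            // (zero_grad() clears gradients accumulated by the previous
            // iteration before backward() adds the new ones)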
            optimizer.zero_grad();
            loss.backward();
            optimizer.step();
        }

        auto sample_mean_loss = running_loss / num_train_samples;
        auto accuracy = static_cast<float>(num_correct) / num_train_samples;

        std::cout << "Epoch [" << (epoch + 1) << "/" << num_epochs << "], Trainset - Loss: "
            << sample_mean_loss << ", Accuracy: " << accuracy << '\n';
    }

    std::cout << "Training finished!\n\n";
    std::cout << "Testing...\n";

    // Test the model
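    // eval() switches modules such as dropout to inference behavior;
    // NoGradGuard disables gradient tracking for everything below.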
    model->eval();
    torch::NoGradGuard no_grad;

    float running_loss = 0.0;
    size_t num_correct = 0;

    for (const auto& batch : *test_loader) {
        auto data = batch.data.view({-1, sequence_length, input_size}).to(device);
        auto target = batch.target.to(device);

        auto output = model->forward(data);

        auto loss = torch::nll_loss(output, target);
        running_loss += loss.item().toFloat() * data.size(0);

        auto prediction = output.argmax(1);
        num_correct += prediction.eq(target).sum().item().toLong();
    }

    std::cout << "Testing finished!\n";

    auto test_accuracy = static_cast<float>(num_correct) / num_test_samples;
    auto test_sample_mean_loss = running_loss / num_test_samples;

    std::cout << "Testset - Loss: " << test_sample_mean_loss << ", Accuracy: " << test_accuracy << '\n';
}