Commit 42a82ba

Merge pull request #11 from mfl28/bidirectional-rnn
Add Bidirectional Recurrent Neural Network tutorial
2 parents e3b2c21 + a696366

6 files changed: 200 additions, 1 deletion

CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -21,6 +21,7 @@ add_subdirectory("tutorials/basics/pytorch_basics")
 add_subdirectory("tutorials/intermediate/convolutional_neural_network")
 add_subdirectory("tutorials/intermediate/deep_residual_network")
 add_subdirectory("tutorials/intermediate/recurrent_neural_network")
+add_subdirectory("tutorials/intermediate/bidirectional_recurrent_neural_network")

 # The following code block is suggested to be used on Windows.
 # According to https://github.com/pytorch/pytorch/issues/25457,

README.md

Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@ $ ./scripts.sh build
 * [Convolutional Neural Network](https://github.com/prabhuomkar/pytorch-cpp/tree/master/tutorials/intermediate/convolutional_neural_network/src/main.cpp)
 * [Deep Residual Network](https://github.com/prabhuomkar/pytorch-cpp/tree/master/tutorials/intermediate/deep_residual_network/src/main.cpp)
 * [Recurrent Neural Network](https://github.com/prabhuomkar/pytorch-cpp/tree/master/tutorials/intermediate/recurrent_neural_network/src/main.cpp)
-* [Bidirectional Recurrent Neural Network]()
+* [Bidirectional Recurrent Neural Network](https://github.com/prabhuomkar/pytorch-cpp/tree/master/tutorials/intermediate/bidirectional_recurrent_neural_network/src/main.cpp)
 * [Language Model (RNN-LM)]()

 #### 3. Advanced
tutorials/intermediate/bidirectional_recurrent_neural_network/CMakeLists.txt

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
+
+project(bidirectional-recurrent-neural-network VERSION 1.0.0 LANGUAGES CXX)
+
+# Files
+set(SOURCES src/main.cpp
+            src/bi_rnn.cpp
+)
+
+set(HEADERS include/bi_rnn.h
+)
+
+set(EXECUTABLE_NAME bidirectional-recurrent-neural-network)
+
+
+add_executable(${EXECUTABLE_NAME} ${SOURCES} ${HEADERS})
+target_include_directories(${EXECUTABLE_NAME} PRIVATE include)
+
+target_link_libraries(${EXECUTABLE_NAME} "${TORCH_LIBRARIES}")
+
+set_target_properties(${EXECUTABLE_NAME} PROPERTIES
+    CXX_STANDARD 11
+    CXX_STANDARD_REQUIRED YES
+)
+
+# The following code block is suggested to be used on Windows.
+# According to https://github.com/pytorch/pytorch/issues/25457,
+# the DLLs need to be copied to avoid memory errors.
+# See https://pytorch.org/cppdocs/installing.html.
+if (MSVC)
+    file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
+    add_custom_command(TARGET ${EXECUTABLE_NAME}
+                       POST_BUILD
+                       COMMAND ${CMAKE_COMMAND} -E copy_if_different
+                       ${TORCH_DLLS}
+                       $<TARGET_FILE_DIR:${EXECUTABLE_NAME}>)
+endif (MSVC)
tutorials/intermediate/bidirectional_recurrent_neural_network/include/bi_rnn.h

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+// Copyright 2019 Markus Fleischhacker
+#pragma once
+
+#include <torch/torch.h>
+
+class BiRNNImpl : public torch::nn::Module {
+ public:
+    BiRNNImpl(int64_t input_size, int64_t hidden_size, int64_t num_layers, int64_t num_classes);
+    torch::Tensor forward(torch::Tensor x);
+
+ private:
+    torch::nn::LSTM lstm;
+    torch::nn::Linear fc;
+};
+
+TORCH_MODULE(BiRNN);
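
The TORCH_MODULE(BiRNN) macro above generates a module holder named BiRNN that wraps a std::shared_ptr<BiRNNImpl>, which is why the tutorial constructs the model as BiRNN model(...) and accesses members with ->. A minimal usage sketch, not part of this commit (the batch size of 4 is an illustrative value; the other arguments are the hyperparameters used in main.cpp):

#include <torch/torch.h>
#include <iostream>
#include "bi_rnn.h"

int main() {
    // Construct the holder; arguments are forwarded to BiRNNImpl's constructor.
    BiRNN model(/*input_size=*/28, /*hidden_size=*/128, /*num_layers=*/2, /*num_classes=*/10);
    // One dummy batch of 4 "images", each read as 28 time steps of 28 features.
    auto x = torch::randn({4, 28, 28});
    auto log_probs = model->forward(x);       // shape: (4, num_classes)
    std::cout << log_probs.sizes() << '\n';   // expected: [4, 10]
}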
tutorials/intermediate/bidirectional_recurrent_neural_network/src/bi_rnn.cpp

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+// Copyright 2019 Markus Fleischhacker
+#include "bi_rnn.h"
+#include <torch/torch.h>
+
+BiRNNImpl::BiRNNImpl(int64_t input_size, int64_t hidden_size, int64_t num_layers, int64_t num_classes)
+    : lstm(torch::nn::LSTMOptions(input_size, hidden_size).layers(num_layers).batch_first(true).bidirectional(true)),
+      fc(hidden_size * 2, num_classes) {
+    register_module("lstm", lstm);
+    register_module("fc", fc);
+}
+
+torch::Tensor BiRNNImpl::forward(torch::Tensor x) {
+    auto out = lstm->forward(x)
+        .output
+        .slice(1, -1)
+        .squeeze(1);
+    out = fc->forward(out);
+    return torch::log_softmax(out, 1);
+}
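
In the forward pass above, the LSTM is configured with batch_first(true) and bidirectional(true), so its output tensor has shape (batch, sequence_length, 2 * hidden_size); slice(1, -1) keeps only the last time step along dimension 1, squeeze(1) drops that singleton dimension, and the fully connected layer then maps the 2 * hidden_size features to num_classes. A small sketch of the slice/squeeze step on a stand-in tensor, using the tutorial's hyperparameters (not part of the commit):

#include <torch/torch.h>
#include <iostream>

int main() {
    // Stand-in for the LSTM output: (batch=100, sequence_length=28, 2 * hidden_size=256).
    auto out = torch::randn({100, 28, 256});
    auto last = out.slice(1, -1)   // keep only the final time step -> (100, 1, 256)
                   .squeeze(1);    // drop the singleton dimension  -> (100, 256)
    std::cout << last.sizes() << '\n';  // prints [100, 256]
}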
tutorials/intermediate/bidirectional_recurrent_neural_network/src/main.cpp

Lines changed: 126 additions & 0 deletions

@@ -0,0 +1,126 @@
+// Copyright 2019 Markus Fleischhacker
+#include <torch/torch.h>
+#include <iostream>
+#include <iomanip>
+#include "bi_rnn.h"
+
+int main() {
+    std::cout << "Bidirectional Recurrent Neural Network\n\n";
+
+    // Device
+    torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
+
+    // Hyper parameters
+    const int64_t sequence_length = 28;
+    const int64_t input_size = 28;
+    const int64_t hidden_size = 128;
+    const int64_t num_layers = 2;
+    const int64_t num_classes = 10;
+    const int64_t batch_size = 100;
+    const int64_t num_epochs = 2;
+    const double learning_rate = 0.003;
+
+    const std::string MNIST_data_path = "../../../../tutorials/intermediate/"
+        "bidirectional_recurrent_neural_network/data/";
+
+    // MNIST dataset
+    auto train_dataset = torch::data::datasets::MNIST(MNIST_data_path)
+        .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
+        .map(torch::data::transforms::Stack<>());
+
+    // Number of samples in the training set
+    auto num_train_samples = train_dataset.size().value();
+
+    auto test_dataset = torch::data::datasets::MNIST(MNIST_data_path, torch::data::datasets::MNIST::Mode::kTest)
+        .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
+        .map(torch::data::transforms::Stack<>());
+
+    // Number of samples in the test set
+    auto num_test_samples = test_dataset.size().value();
+
+    // Data loader
+    auto train_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(
+        std::move(train_dataset), batch_size);
+    auto test_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
+        std::move(test_dataset), batch_size);
+
+    // Model
+    BiRNN model(input_size, hidden_size, num_layers, num_classes);
+    model->to(device);
+
+    // Optimizer
+    auto optimizer = torch::optim::Adam(model->parameters(), torch::optim::AdamOptions(learning_rate));
+
+    // Set floating point output precision
+    std::cout << std::fixed << std::setprecision(4);
+
+    std::cout << "Training...\n";
+
+    // Train the model
+    for (size_t epoch = 0; epoch != num_epochs; ++epoch) {
+        // Initialize running metrics
+        float running_loss = 0.0;
+        size_t num_correct = 0;
+
+        for (auto& batch : *train_loader) {
+            // Transfer images and target labels to device
+            auto data = batch.data.view({-1, sequence_length, input_size}).to(device);
+            auto target = batch.target.to(device);
+
+            // Forward pass
+            auto output = model->forward(data);
+
+            // Calculate loss
+            auto loss = torch::nll_loss(output, target);
+            // Update running loss
+            running_loss += loss.item().toFloat() * data.size(0);
+
+            // Calculate prediction
+            auto prediction = output.argmax(1);
+
+            // Update number of correctly classified samples
+            num_correct += prediction.eq(target).sum().item().toLong();
+
+            // Backward pass and optimize
+            optimizer.zero_grad();
+            loss.backward();
+            optimizer.step();
+        }
+
+        auto sample_mean_loss = running_loss / num_train_samples;
+        auto accuracy = static_cast<float>(num_correct) / num_train_samples;
+
+        std::cout << "Epoch [" << (epoch + 1) << "/" << num_epochs << "], Trainset - Loss: "
+            << sample_mean_loss << ", Accuracy: " << accuracy << '\n';
+    }
+
+    std::cout << "Training finished!\n\n";
+    std::cout << "Testing...\n";
+
+    // Test the model
+    model->eval();
+    torch::NoGradGuard no_grad;
+
+    float running_loss = 0.0;
+    size_t num_correct = 0;
+
+    for (const auto& batch : *test_loader) {
+        auto data = batch.data.view({-1, sequence_length, input_size}).to(device);
+        auto target = batch.target.to(device);
+
+        auto output = model->forward(data);
+
+        auto loss = torch::nll_loss(output, target);
+        running_loss += loss.item().toFloat() * data.size(0);
+
+        auto prediction = output.argmax(1);
+        num_correct += prediction.eq(target).sum().item().toLong();
+    }
+
+    std::cout << "Testing finished!\n";
+
+    auto test_accuracy = static_cast<float>(num_correct) / num_test_samples;
+    auto test_sample_mean_loss = running_loss / num_test_samples;
+
+    std::cout << "Testset - Loss: " << test_sample_mean_loss << ", Accuracy: " << test_accuracy << '\n';
+}
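
Both loops above reshape each MNIST batch with view({-1, sequence_length, input_size}), treating every 28x28 image as a sequence of 28 rows with 28 pixel values per step; that is what turns the image classification task into sequence input the bidirectional LSTM can consume. A minimal sketch of that reshape on a dummy batch (not part of the commit; random values stand in for MNIST images):

#include <torch/torch.h>
#include <iostream>

int main() {
    // MNIST batches arrive as (batch, channels=1, 28, 28) from the data loader.
    auto images = torch::randn({100, 1, 28, 28});
    // Read each image row by row: one time step per row, 28 features per step.
    auto sequences = images.view({-1, 28, 28});
    std::cout << sequences.sizes() << '\n';  // prints [100, 28, 28]
}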
