Implementation of the N-BEATS model (paper) in Julia (Flux). The package is not yet in the General registry, so to use it, install it directly from the repository:
using Pkg
# `Pkg.add` interprets a bare string as the name of a registered package, so an
# unregistered repository has to be passed through the `url` keyword.
Pkg.add(url="https://github.com/MartinuzziFrancesco/NeuralBasisExpansions.jl")
The package is still undergoing heavy testing; expect unexpected behavior.
A full sine-wave example, including the helper functions, is given in the example
folder, under readme.jl.
# End-to-end example: build an NBeatsNet, train it on batched sine data and
# report metrics on one held-out batch.
# NOTE: `generate_sine_data`, `batch_data` and `evaluate_predictions` are not
# defined in this snippet; the surrounding text points to readme.jl in the
# example folder for the helper functions. `mean` must also be in scope
# (presumably from Statistics — confirm against readme.jl).

# Model parameters
forecast_length = 5
backcast_length = 2 * forecast_length  # look-back window is twice the horizon
batch_size = 32
hidden_units = 128
theta_dims = (4, 8)
blocks_per_stack = 3

# Generate the data, split it 800 / remainder, then batch each split
data = generate_sine_data(1000, backcast_length, forecast_length)
train_data = data[1:800]
test_data = data[801:end]
train_batches = batch_data(train_data, batch_size)
test_batches = batch_data(test_data, batch_size)

# Create the NBeatsNet model (one generic stack followed by one trend stack)
model = NBeatsNet(;
    stacks=[generic_basis, trend_basis],
    blocks_stacks=blocks_per_stack,
    forecast_length,
    backcast_length,
    thetas_dim=theta_dims,
    hidden_units,
)

# Loss function and optimizer
# The second output of the model is compared against the target `y`.
function loss_fn(x, y)
    forecast = model(x)[2]
    return Flux.mse(forecast, y)
end
optimizer = Flux.ADAM(0.001)

# Training loop: one pass over the training batches per epoch, followed by the
# average per-batch loss on both splits.
epochs = 50
for epoch in 1:epochs
    Flux.train!(loss_fn, Flux.params(model), train_batches, optimizer)
    train_loss = mean([loss_fn(x, y) for (x, y) in train_batches])
    test_loss = mean([loss_fn(x, y) for (x, y) in test_batches])
    println("Epoch $epoch: Train Loss = $train_loss, Test Loss = $test_loss")
end

# Forecast using the model (example): evaluate on the first test batch
x_test, y_true = test_batches[1]
y_pred = model(x_test)[2]
mse, mae, r_squared = evaluate_predictions(y_true, y_pred)
println("Mean Squared Error: $mse")
println("Mean Absolute Error: $mae")
println("R-squared: $r_squared")
A quick example with random data to test the model:
# Smoke test: build a trend + seasonality NBeatsNet and run one forward pass on
# random input.

# Model parameters
forecast_length = 5
backcast_length = 10
blocks_stacks = 3
thetas_dim = (4, 8)
hidden_units = 256

# Keyword names match the variables above, so the shorthand `; name` form
# passes each variable under its own name.
nbeats_net = NBeatsNet(;
    stacks=[trend_basis, seasonality_basis],
    blocks_stacks,
    forecast_length,
    backcast_length,
    thetas_dim,
    share_weights=false,
    hidden_units,
)

# Create a batch of input data with shape (backcast_length, batch_size)
batch_size = 3 # Number of instances in the batch
input_data = randn(Float32, backcast_length, batch_size)

# The model returns two outputs, destructured here as backcast and forecast
backcast_output, forecast_output = nbeats_net(input_data)