This repository was archived by the owner on May 7, 2025. It is now read-only.

`Err` value: IrError(OutputNodeNotFound("/linear/MatMul_output_0")) on linear model #191

@maxwellflitton

Description

Describe the bug
I've trained a simple linear model in PyTorch and exported it to ONNX. Running it through the ONNX library works fine. However, when I try to run it with wonnx, I get the error `Err` value: IrError(OutputNodeNotFound("/linear/MatMul_output_0")). Looking at the model in Netron, everything seems to make sense, and in my session config I set the output name to "5", since that is the graph's output. I don't know why wonnx fails here when onnx works fine.
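For context, the error name suggests that, while building its internal graph, wonnx looks up the node that produces a given tensor name and does not find one. The toy sketch below illustrates that kind of lookup only; `Node`, `find_producer` and the node names "/linear/MatMul" and "/linear/Add" are hypothetical and are not wonnx's actual IR or code.

```rust
use std::collections::HashMap;

// Hypothetical toy graph node: NOT wonnx's internal IR, just enough
// structure to show how an output-name lookup can fail.
struct Node {
    name: String,
    outputs: Vec<String>,
}

#[derive(Debug, PartialEq)]
enum IrError {
    OutputNodeNotFound(String),
}

/// Resolve a tensor name to the name of the node that produces it.
fn find_producer<'a>(nodes: &'a [Node], output: &str) -> Result<&'a str, IrError> {
    // Index every produced tensor name by the node that emits it.
    let producers: HashMap<&str, &Node> = nodes
        .iter()
        .flat_map(|n| n.outputs.iter().map(move |o| (o.as_str(), n)))
        .collect();
    producers
        .get(output)
        .copied()
        .map(|n| n.name.as_str())
        .ok_or_else(|| IrError::OutputNodeNotFound(output.to_string()))
}

fn main() {
    // Tensor names come from the issue; the two-node layout is an assumption.
    let nodes = vec![
        Node {
            name: "/linear/MatMul".to_string(),
            outputs: vec!["/linear/MatMul_output_0".to_string()],
        },
        Node {
            name: "/linear/Add".to_string(),
            outputs: vec!["5".to_string()],
        },
    ];
    // The requested graph output resolves to the node that produces it.
    println!("{:?}", find_producer(&nodes, "5")); // Ok("/linear/Add")
    // If the MatMul node were absent from the graph being searched, the
    // Add node's input could no longer be resolved.
    println!("{:?}", find_producer(&nodes[1..], "/linear/MatMul_output_0"));
    // Err(OutputNodeNotFound("/linear/MatMul_output_0"))
}
```

This only shows one way such a lookup can fail; it does not establish why wonnx fails on this particular model.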

To Reproduce
Train a simple linear model with the following code:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

squarefoot = np.array([1000, 1200, 1500, 1800, 2000, 2200, 2500, 2800, 3000, 3200], dtype=np.float32)
num_floors = np.array([1, 1, 1.5, 1.5, 2, 2, 2.5, 2.5, 3, 3], dtype=np.float32)
house_price = np.array([200000, 230000, 280000, 320000, 350000, 380000, 420000, 470000, 500000, 520000], dtype=np.float32)

squarefoot_mean = squarefoot.mean()
squarefoot_std = squarefoot.std()
num_floors_mean = num_floors.mean()
num_floors_std = num_floors.std()
house_price_mean = house_price.mean()
house_price_std = house_price.std()

# Normalize the data with the statistics computed above (optional, but recommended for better convergence)
squarefoot = (squarefoot - squarefoot_mean) / squarefoot_std
num_floors = (num_floors - num_floors_mean) / num_floors_std
house_price = (house_price - house_price_mean) / house_price_std

# Convert numpy arrays to PyTorch tensors
squarefoot_tensor = torch.from_numpy(squarefoot)
num_floors_tensor = torch.from_numpy(num_floors)
house_price_tensor = torch.from_numpy(house_price)

# Stack the two features into the training input matrix of shape [10, 2]
X = torch.stack([squarefoot_tensor, num_floors_tensor], dim=1)

# Define the linear regression model
class LinearRegressionModel(nn.Module):
    def __init__(self):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(2, 1)  # 2 input features, 1 output

    def forward(self, x):
        return self.linear(x)

# Initialize the model
model = LinearRegressionModel()

# Define the loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# Training loop
num_epochs = 1000
for epoch in range(num_epochs):
    # Forward pass
    y_pred = model(X)

    # Compute the loss
    loss = criterion(y_pred.squeeze(), house_price_tensor)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Print the progress
    if (epoch + 1) % 100 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

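test_squarefoot = torch.tensor([2800, 3200], dtype=torch.float32)
test_num_floors = torch.tensor([2.5, 3], dtype=torch.float32)
test_inputs = torch.stack([test_squarefoot, test_num_floors], dim=1)  # batch of shape [2, 2]
# Overrides the batch above with a single raw (unnormalized) sample of shape [2];
# this tensor is also the dummy input passed to torch.onnx.export below
test_inputs = torch.tensor([2800, 3], dtype=torch.float32)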

# Test the model
with torch.no_grad():
    predicted_prices = model(test_inputs)
    predicted_prices = predicted_prices.squeeze().numpy()
    print("Predicted Prices:", predicted_prices)
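Separately from the loading error: the script trains on z-scored features computed with NumPy's default population standard deviation (ddof=0), while `test_inputs` and the [1000.0, 2.0] sample used later in Rust are raw values. A minimal Rust sketch of the same statistics, and of mapping a sample into and a prediction out of normalized units; `mean_std`, `normalize` and `denormalize` are hypothetical helpers, not part of wonnx or the original code.

```rust
/// Population mean and standard deviation (divide by n), matching NumPy's
/// default `std()` (ddof=0) used in the training script.
fn mean_std(xs: &[f32]) -> (f32, f32) {
    let n = xs.len() as f32;
    let mean = xs.iter().sum::<f32>() / n;
    let var = xs.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
    (mean, var.sqrt())
}

/// z-score a raw value with training-set statistics.
fn normalize(x: f32, mean: f32, std: f32) -> f32 {
    (x - mean) / std
}

/// Map a normalized value (e.g. a model prediction) back to original units.
fn denormalize(z: f32, mean: f32, std: f32) -> f32 {
    z * std + mean
}

fn main() {
    // Same training data as the Python script.
    let squarefoot: [f32; 10] = [
        1000.0, 1200.0, 1500.0, 1800.0, 2000.0, 2200.0, 2500.0, 2800.0, 3000.0, 3200.0,
    ];
    let num_floors: [f32; 10] = [1.0, 1.0, 1.5, 1.5, 2.0, 2.0, 2.5, 2.5, 3.0, 3.0];
    let house_price: [f32; 10] = [
        200000.0, 230000.0, 280000.0, 320000.0, 350000.0, 380000.0, 420000.0, 470000.0,
        500000.0, 520000.0,
    ];

    let (sf_mean, sf_std) = mean_std(&squarefoot);
    let (nf_mean, nf_std) = mean_std(&num_floors);
    let (hp_mean, hp_std) = mean_std(&house_price);

    // The raw sample from the Rust snippet: 1000 sq ft, 2 floors.
    let model_input = [
        normalize(1000.0, sf_mean, sf_std),
        normalize(2.0, nf_mean, nf_std),
    ];
    println!("normalized input: {:?}", model_input);

    // A prediction of 0.0 in normalized units corresponds to the mean price.
    println!("price at z = 0: {}", denormalize(0.0, hp_mean, hp_std));
}
```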

I then export the model to ONNX with the following code:

# export to ONNX and save file
torch.onnx.export(model, test_inputs, "./linear_test.onnx")
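Because the dummy input passed to `torch.onnx.export` is the 1D tensor [2800, 3] (shape [2]), the exporter appears to have emitted `nn.Linear` as a MatMul followed by an Add rather than a single Gemm; the input name `onnx::MatMul_0` and the intermediate `/linear/MatMul_output_0` in the error are consistent with that. A minimal Rust sketch of what those two nodes compute for one sample, with made-up weights; `linear_1d` is a hypothetical helper, not part of wonnx, and the assumption that the Add produces output "5" is not confirmed by the issue.

```rust
/// `nn.Linear` applied to one 1D sample, written as the two ONNX-style steps
/// the exported graph appears to contain: MatMul (x times W transposed), then Add (+ b).
/// `weight` uses PyTorch's layout: shape [out_features, in_features].
fn linear_1d(x: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
    assert!(weight.iter().all(|row| row.len() == x.len()), "in_features mismatch");
    assert_eq!(weight.len(), bias.len(), "out_features mismatch");
    // MatMul: for this model, input [2] times W^T [2, 1] gives shape [1].
    let matmul_output: Vec<f32> = weight
        .iter()
        .map(|row| row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f32>())
        .collect();
    // Add: element-wise bias addition gives the final output
    // (presumably the tensor named "5" in the issue).
    matmul_output.iter().zip(bias).map(|(m, b)| m + b).collect()
}

fn main() {
    // Made-up parameters for illustration; the real ones are the model's initializers.
    let weight: Vec<Vec<f32>> = vec![vec![0.5, -1.0]];
    let bias: [f32; 1] = [0.25];
    let x: [f32; 2] = [1.0, 2.0];
    let y = linear_1d(&x, &weight, &bias);
    println!("{:?}", y); // [-1.25]
}
```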

I then load the model in Rust with the following code:

use std::collections::HashMap;
use std::fs;

use wonnx::utils::InputTensor;
use wonnx::{Session, SessionConfig};

pub async fn load_model() {
    // Read the exported model into memory.
    let buffer = fs::read("./linear_test.onnx").unwrap();

    // Request "5" (the graph output shown in Netron) as the only output.
    let config = SessionConfig::new().with_outputs(Some(vec!["5".to_string()]));
    // This is the call that fails with IrError(OutputNodeNotFound(...)).
    let session = Session::from_bytes_with_config(&buffer, &config).await.unwrap();

    // Single sample: [square footage, number of floors].
    let mut inputs = HashMap::new();
    inputs.insert("onnx::MatMul_0".to_string(), InputTensor::F32(vec![1000.0, 2.0].into()));
    let outputs = session.run(&inputs).await.unwrap();
    println!("file: {:?}", outputs);
}

and I get the error with the following line:

let session = Session::from_bytes_with_config(&buffer, &config).await.unwrap();

Expected behavior
To load the model and run a simple inference, as it does with the ONNX library.

Screenshots
When I inspect the ONNX file, all the weights seem to match up. Am I missing something here?

[Four screenshots from inspecting the ONNX file, taken 2023-10-03; images not included.]

Desktop

  • OS: macOS (Ventura 13.4)
  • Chip: Apple M2 Max
  • RAM: 96GB
  • Hard Drive: 3.62TB available of 4TB
  • model: 16-inch 2023
