jax_nutritional_content_prediction_explicit_diff.py
# Import necessary libraries
import jax
import jax.numpy as jnp
from jax import jit, jacrev
# stax now lives in jax.example_libraries (it was removed from jax.experimental)
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Relu
import numpy as np
# Simulate some example data for demonstration purposes
np.random.seed(42)
X_train = np.random.rand(100, 8) # 100 samples, 8 features
y_train = np.random.rand(100, 4) # 100 samples, 4 targets (calories, proteins, carbs, fats)
X_test = np.random.rand(20, 8) # 20 samples, 8 features
y_test = np.random.rand(20, 4) # 20 samples, 4 targets
# Define the neural network architecture
def create_model():
    return stax.serial(
        Dense(128), Relu,
        Dense(64), Relu,
        Dense(4)  # Output layer for 4 targets
    )
# Create the model
init_fun, apply_fun = create_model()
# Initialize model parameters
key = jax.random.PRNGKey(42)
_, init_params = init_fun(key, (-1, X_train.shape[1]))
# Define the loss function
def loss_fn(params, inputs, targets):
    predictions = apply_fun(params, inputs)
    return jnp.mean((predictions - targets) ** 2)
# Define the update function using explicit differentiation
@jit
def update(params, inputs, targets, learning_rate):
    grads = jacrev(loss_fn)(params, inputs, targets)
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
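# Sanity check (an illustrative sketch, not part of the original training procedure):
# for a scalar-valued loss, the "explicit" jacrev-based gradients used above should
# match jax.grad on the same batch.
_grads_jacrev = jacrev(loss_fn)(init_params, X_train[:4], y_train[:4])
_grads_grad = jax.grad(loss_fn)(init_params, X_train[:4], y_train[:4])
assert all(
    jnp.allclose(a, b)
    for a, b in zip(jax.tree_util.tree_leaves(_grads_jacrev),
                    jax.tree_util.tree_leaves(_grads_grad))
)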
# Training settings
num_epochs = 100
batch_size = 32
learning_rate = 0.001
# Train the model
for epoch in range(num_epochs):
    # Shuffle the training data
    perm = np.random.permutation(len(X_train))
    X_train = X_train[perm]
    y_train = y_train[perm]

    # Mini-batch training
    epoch_loss = 0.0
    num_batches = 0
    for i in range(0, len(X_train), batch_size):
        X_batch = X_train[i:i+batch_size]
        y_batch = y_train[i:i+batch_size]

        # Update parameters
        init_params = update(init_params, X_batch, y_batch, learning_rate)
        batch_loss = loss_fn(init_params, X_batch, y_batch)
        epoch_loss += batch_loss
        num_batches += 1

    # Calculate average loss for the epoch
    epoch_loss /= num_batches
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss:.4f}")
# Evaluate the model
test_predictions = apply_fun(init_params, X_test)
test_loss = jnp.mean((test_predictions - y_test) ** 2)
print(f"Test Loss: {test_loss:.4f}")
# Define the inference function
@jit
def predict(params, inputs):
    return apply_fun(params, inputs)
# Example usage
new_input = np.array([[5, 3, 2, 1, 4, 2, 0, 1]]) # Replace with actual input features
predicted_nutrition = predict(init_params, new_input)
print(f"Predicted Nutrition: {predicted_nutrition}")
"""
Possible Errors and Solutions:
1. ValueError: If there are NaN values in the dataset, you might get a ValueError.
Solution: Ensure that your dataset does not contain NaN values by using `np.nan_to_num` or similar preprocessing steps.
2. Dimension Mismatch: If the dimensions of weights or features do not align, an error will occur.
Solution: Check the shapes of your arrays to ensure they are correct, especially after splitting the data.
3. Convergence Issues: If the learning rate is too high, the model may not converge and result in a high loss.
Solution: Reduce the learning rate and observe the change in loss over epochs.
4. Memory Issues: For large datasets, you might encounter memory issues.
Solution: Use batch processing or reduce the dataset size.
"""