# jax_nutritional_content_prediction.py
import jax
import jax.numpy as jnp
from jax import grad, jit, random
from jax.example_libraries import stax  # stax moved out of jax.experimental in newer JAX releases
from jax.example_libraries.stax import Dense, Relu
import numpy as np
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.model_selection import train_test_split
# Simulate data for nutritional information and meal preferences.
# The dataset has feature columns (ingredients, category, dietary restriction)
# and four nutritional targets (calories, proteins, carbs, fats).
# Define the number of samples
num_samples = 10000
# Generate random ingredients (simplified for demonstration purposes)
ingredients = np.random.randint(10, size=(num_samples, 5))
# Generate random categories (simplified for demonstration purposes)
categories = np.random.randint(5, size=num_samples)
# Generate random nutritional values
calories = np.random.uniform(100, 1000, size=num_samples)
proteins = np.random.uniform(5, 50, size=num_samples)
carbs = np.random.uniform(10, 100, size=num_samples)
fats = np.random.uniform(5, 50, size=num_samples)
# Generate random dietary restrictions (simplified for demonstration purposes)
dietary_restrictions = np.random.randint(3, size=num_samples)
# Combine everything into one dataset (feature columns first, then the four targets)
data = np.column_stack((ingredients, categories, dietary_restrictions, calories, proteins, carbs, fats))
# Split the dataset into features (X) and targets (y)
X = data[:, :-4]  # Ingredients, category, and dietary restriction
y = data[:, -4:]  # Nutritional values (calories, proteins, carbs, fats)
# Encode categorical features
categorical_features = [5, 6]  # indices of the category and dietary restriction columns
label_encoder = LabelEncoder()
X[:, 5] = label_encoder.fit_transform(X[:, 5])
X[:, 6] = label_encoder.fit_transform(X[:, 6])
one_hot_encoder = OneHotEncoder(sparse_output=False)  # `sparse` was renamed to `sparse_output` in scikit-learn 1.2
categorical_data = one_hot_encoder.fit_transform(X[:, categorical_features])
# Concatenate the numerical (ingredient) features with the one-hot encoded categorical features
X = np.concatenate((X[:, :5], categorical_data), axis=1)
# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Define the neural network architecture (stax infers layer shapes at init time,
# so no input shape is needed here)
def create_model():
    return stax.serial(
        Dense(128), Relu,
        Dense(64), Relu,
        Dense(4)  # Output layer for the 4 nutritional targets
    )
# Create the model
init_fun, apply_fun = create_model()
# Initialize model parameters
key = random.PRNGKey(42)
_, params = init_fun(key, (-1, X_train.shape[1]))
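# For reference: stax returns the parameters as a nested pytree of
# (weights, biases) tuples, one per Dense layer (empty for Relu).
# A quick, purely illustrative shape inspection:
print(jax.tree_util.tree_map(lambda p: p.shape, params))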
# Define the loss function (mean squared error over the 4 targets)
def loss_fn(params, inputs, targets):
    predictions = apply_fun(params, inputs)
    return jnp.mean((predictions - targets) ** 2)
# Define the SGD update step
@jit
def update(params, inputs, targets, learning_rate):
    grads = grad(loss_fn)(params, inputs, targets)
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
# Custom transformation for batch normalization
def custom_batch_norm(x, mean, var, beta, gamma, eps=1e-5):
    return (x - mean) / jnp.sqrt(var + eps) * gamma + beta
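# Quick sanity check for the normalization above (illustrative, on toy data):
# each column of the normalized output should have roughly zero mean and unit variance.
_toy = np.random.uniform(size=(64, 3))
_normed = custom_batch_norm(_toy, jnp.mean(_toy, axis=0), jnp.var(_toy, axis=0),
                            jnp.zeros(3), jnp.ones(3))
print(jnp.mean(_normed, axis=0), jnp.var(_normed, axis=0))  # ~0 and ~1 per column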
# Training settings
num_epochs = 100
batch_size = 32
learning_rate = 0.001
# Training-set statistics, reused to normalize the test set and new inputs at
# inference time so evaluation sees the same transform as training
train_mean = jnp.mean(X_train, axis=0)
train_var = jnp.var(X_train, axis=0)
beta = jnp.zeros_like(train_mean)
gamma = jnp.ones_like(train_var)
# Train the model
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0
    # Create mini-batches
    batches = [(X_train[i:i + batch_size], y_train[i:i + batch_size])
               for i in range(0, X_train.shape[0], batch_size)]
    for inputs, targets in batches:
        # Apply custom batch normalization using the batch's own statistics
        mean = jnp.mean(inputs, axis=0)
        var = jnp.var(inputs, axis=0)
        inputs = custom_batch_norm(inputs, mean, var, beta, gamma)
        # Update parameters and accumulate the batch loss
        params = update(params, inputs, targets, learning_rate)
        batch_loss = loss_fn(params, inputs, targets)
        epoch_loss += batch_loss
        num_batches += 1
    epoch_loss /= num_batches
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")
# Evaluate the model on the test set, normalizing with the training-set
# statistics so the inputs match what the model saw during training
X_test_norm = custom_batch_norm(X_test, train_mean, train_var, beta, gamma)
test_predictions = apply_fun(params, X_test_norm)
test_loss = jnp.mean((test_predictions - y_test) ** 2)
print(f"Test Loss: {test_loss:.4f}")
# Define the inference function
@jit
def predict(params, inputs):
    return apply_fun(params, inputs)
# Example usage: predict nutritional values for a new meal. The input must
# match the 13 encoded features: 5 ingredient ids, 5 one-hot category columns,
# and 3 one-hot dietary-restriction columns (placeholder values below)
new_input = np.array([[5.0, 3.0, 2.0, 1.0, 4.0, 0, 0, 1, 0, 0, 1, 0, 0]])
new_input = custom_batch_norm(new_input, train_mean, train_var, beta, gamma)
predicted_nutrition = predict(params, new_input)
print(f"Predicted Nutrition: {predicted_nutrition}")
# Possible Errors and Solutions:
# 1. ValueError: NaN or infinite values in the input data will break training.
#    Solution: Clean the input data so it contains no NaN or infinite values.
# 2. Dimension Mismatch: If the shapes of the input features or targets do not
#    match what the model expects, an error will occur.
#    Solution: Verify the shapes of your inputs and targets before training.
# 3. Convergence Issues: If the learning rate is too high, the model may not converge.
#    Solution: Decrease the learning rate and try again.
# A sketch of these checks is shown below.
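# A minimal sketch of the checks suggested above, to run before training
# (illustrative; the assertion messages and the learning-rate bound are placeholders):
assert np.isfinite(X).all(), "X contains NaN or infinite values"
assert np.isfinite(y).all(), "y contains NaN or infinite values"
assert X_train.shape[0] == y_train.shape[0], "feature/target sample counts differ"
assert X_train.shape[1] == X_test.shape[1], "train/test feature dimensions differ"
assert learning_rate <= 0.01, "learning rate may be too high for plain SGD"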