-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathjax_srcnn_image_super_resolution.py
122 lines (96 loc) · 4.34 KB
/
jax_srcnn_image_super_resolution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# jax_srcnn_image_super_resolution.py
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import jax.nn
# Define the SRCNN model
def jax_srcnn(inputs, num_filters):
# Patch extraction and representation
conv1 = jax.lax.conv(inputs, kernel_size=(9, 9), feature_map_shape=(num_filters,), padding='SAME')
conv1 = jax.nn.relu(conv1)
# Non-linear mapping
conv2 = jax.lax.conv(conv1, kernel_size=(1, 1), feature_map_shape=(num_filters // 2,), padding='SAME')
conv2 = jax.nn.relu(conv2)
# Reconstruction
outputs = jax.lax.conv(conv2, kernel_size=(5, 5), feature_map_shape=(1,), padding='SAME')
return outputs
# Define the loss function
def jax_loss_fn(params, inputs, targets):
predictions = jax_srcnn(inputs, num_filters)
loss = jnp.mean(jnp.square(predictions - targets))
return loss
# Define the evaluation metrics
def jax_evaluate_model(params, inputs, targets):
predictions = jax_srcnn(inputs, num_filters)
psnr = peak_signal_noise_ratio(targets, predictions)
ssim = structural_similarity(targets, predictions, multichannel=True)
return psnr, ssim
# Train the SRCNN model
def jax_train_model(params, optimizer, train_data, num_epochs, batch_size):
for epoch in range(num_epochs):
epoch_loss = 0
for i in range(0, len(train_data[0]), batch_size):
batch_inputs = train_data[0][i:i+batch_size]
batch_targets = train_data[1][i:i+batch_size]
loss_value, grads = jax.value_and_grad(jax_loss_fn)(params, batch_inputs, batch_targets)
params = optimizer.update(grads, params)
epoch_loss += loss_value
epoch_loss /= (len(train_data[0]) // batch_size)
print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")
return params
# Perform image super-resolution
def jax_image_super_resolution(params, low_res_image, scale_factor):
# Preprocess the low-resolution image
low_res_image = jnp.array(low_res_image) / 255.0
low_res_image = jnp.expand_dims(low_res_image, axis=0)
# Perform super-resolution
high_res_image = jax_srcnn(low_res_image, num_filters)
high_res_image = jnp.squeeze(high_res_image, axis=0)
high_res_image = (high_res_image * 255.0).astype(jnp.uint8)
return high_res_image
# Example usage
# Load and preprocess the training data
train_images = []
train_labels = []
# Simulated low-resolution and high-resolution image pairs
for _ in range(100):
low_res_image = np.random.rand(32, 32, 3) * 255
high_res_image = np.random.rand(128, 128, 3) * 255
train_images.append(low_res_image)
train_labels.append(high_res_image)
train_images = jnp.array(train_images) / 255.0
train_labels = jnp.array(train_labels) / 255.0
train_data = (train_images, train_labels)
# Set the hyperparameters
num_filters = 64
num_epochs = 10
batch_size = 8
# Initialize model parameters and optimizer
params = jax.random.normal(jax.random.PRNGKey(0), (9, 9, 3, num_filters))
optimizer = jax.optim.Adam(learning_rate=0.001)
# Train the SRCNN model
params = jax_train_model(params, optimizer, train_data, num_epochs, batch_size)
# Load a low-resolution image
low_res_image = Image.open("low_res_image.png")
low_res_image = np.array(low_res_image)
# Perform image super-resolution
scale_factor = 4
high_res_image = jax_image_super_resolution(params, low_res_image, scale_factor)
# Save the super-resolved image
output_image = Image.fromarray(high_res_image)
output_image.save("high_res_image.png")
# Evaluate the super-resolution quality
psnr, ssim = jax_evaluate_model(params, jnp.expand_dims(low_res_image / 255.0, axis=0), jnp.expand_dims(high_res_image / 255.0, axis=0))
print(f"PSNR: {psnr:.2f} dB")
print(f"SSIM: {ssim:.4f}")
# Possible Errors and Solutions:
# ValueError: operands could not be broadcast together with shapes (x, y) (a, b)
# Solution: Ensure that the shapes of the predictions and targets match exactly when calculating the loss.
# ImportError: No module named 'PIL'
# Solution: Ensure the PIL library is installed using `pip install pillow`.
# RuntimeError: Invalid argument: Non-scalable parameters
# Solution: Ensure all operations in the model are scalable and support JAX's JIT compilation.
# KeyError: 'low_res_image.png'
# Solution: Ensure the file path and name are correct and the image file exists.