-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathjax_image_inpainting.py
169 lines (140 loc) · 6.61 KB
/
jax_image_inpainting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.training import train_state
import optax
from tensorflow_datasets import load_dataset
import matplotlib.pyplot as plt
# Define the generator network
class Generator(nn.Module):
    """Encoder-decoder CNN that inpaints a masked RGB image.

    Three stride-2 convolutions halve the spatial dims while widening the
    channels; three stride-2 transposed convolutions mirror them back, and a
    final tanh bounds the output to [-1, 1] (matching the dataset's
    normalisation to that range).
    """

    @nn.compact
    def __call__(self, x):
        # Encoder: 64 -> 128 -> 256 channels, each stage downsampling by 2.
        for width in (64, 128, 256):
            x = nn.Conv(features=width, kernel_size=(4, 4),
                        strides=(2, 2), padding="SAME")(x)
            x = nn.leaky_relu(x, negative_slope=0.2)
        # Decoder: 128 -> 64 channels, each stage upsampling by 2.
        for width in (128, 64):
            x = nn.ConvTranspose(features=width, kernel_size=(4, 4),
                                 strides=(2, 2), padding="SAME")(x)
            x = nn.leaky_relu(x, negative_slope=0.2)
        # Final upsample back to 3 RGB channels, squashed into [-1, 1].
        x = nn.ConvTranspose(features=3, kernel_size=(4, 4),
                             strides=(2, 2), padding="SAME")(x)
        return nn.tanh(x)
# Define the discriminator network
class Discriminator(nn.Module):
    """Convolutional critic scoring images as real vs. generated.

    Three stride-2 leaky-ReLU convolutions, then a single-channel VALID
    convolution producing an unbounded score map (no sigmoid — the
    least-squares losses below consume raw scores).
    """

    @nn.compact
    def __call__(self, x):
        # Feature extractor: 64 -> 128 -> 256 channels, downsampling by 2.
        for width in (64, 128, 256):
            x = nn.Conv(features=width, kernel_size=(4, 4),
                        strides=(2, 2), padding="SAME")(x)
            x = nn.leaky_relu(x, negative_slope=0.2)
        # One-channel score head; VALID padding trims the spatial border.
        return nn.Conv(features=1, kernel_size=(4, 4),
                       strides=(1, 1), padding="VALID")(x)
# Define the loss functions
def generator_loss(params_generator, params_discriminator, images, masks):
    """Least-squares GAN generator loss.

    Inpaints the masked images and penalises the squared distance of the
    discriminator's scores on those fakes from the "real" target of 1.
    """
    inpainted = Generator().apply(params_generator, images * masks)
    fake_scores = Discriminator().apply(params_discriminator, inpainted)
    return jnp.mean(jnp.square(fake_scores - 1))
def discriminator_loss(params_generator, params_discriminator, images, masks):
    """Least-squares GAN discriminator loss.

    Pushes scores on real images toward 1 and scores on inpainted fakes
    toward 0, averaging the two squared-error terms.
    """
    inpainted = Generator().apply(params_generator, images * masks)
    real_scores = Discriminator().apply(params_discriminator, images)
    fake_scores = Discriminator().apply(params_discriminator, inpainted)
    loss_on_reals = jnp.mean(jnp.square(real_scores - 1))
    loss_on_fakes = jnp.mean(jnp.square(fake_scores))
    return (loss_on_reals + loss_on_fakes) / 2
# Define the train step
@jax.jit
def train_step(state_generator, state_discriminator, images, masks):
    """Run one jitted LSGAN update for both networks.

    Args:
        state_generator: flax TrainState holding the generator params/optimizer.
        state_discriminator: flax TrainState for the discriminator.
        images: batch of normalised images.
        masks: binary masks (1 = keep pixel, 0 = hole), same shape as images.

    Returns:
        Tuple of (updated generator state, updated discriminator state,
        generator loss, discriminator loss).
    """
    # Generator update: value_and_grad differentiates w.r.t. argument 0
    # (the generator params) by default, which is what we want here.
    loss_generator, grads_generator = jax.value_and_grad(generator_loss)(
        state_generator.params, state_discriminator.params, images, masks)
    state_generator = state_generator.apply_gradients(grads=grads_generator)
    # Discriminator update: must differentiate w.r.t. argument 1 (the
    # discriminator params). The original code left argnums at its default
    # of 0, producing gradients w.r.t. the *generator* params and applying
    # them to the discriminator — the discriminator never trained correctly.
    loss_discriminator, grads_discriminator = jax.value_and_grad(
        discriminator_loss, argnums=1)(
        state_generator.params, state_discriminator.params, images, masks)
    state_discriminator = state_discriminator.apply_gradients(grads=grads_discriminator)
    return state_generator, state_discriminator, loss_generator, loss_discriminator
# Load and preprocess the dataset
def load_dataset(dataset_name, batch_size):
    """Load a TFDS image dataset as batches normalised to [-1, 1].

    Args:
        dataset_name: TFDS builder name (e.g. "celeb_a").
        batch_size: number of images per batch.

    Returns:
        A tf.data.Dataset yielding float32 image batches in [-1, 1].

    Bug fix: the original body called ``load_dataset(dataset_name, ...)``,
    but this function definition shadows the module-level import, so that
    call recursed into itself infinitely. (The import itself was also wrong:
    tensorflow_datasets exposes ``tfds.load``, not ``load_dataset``.)
    """
    # Local imports keep the fix self-contained; tensorflow is always
    # available wherever tensorflow_datasets is installed.
    import tensorflow as tf
    import tensorflow_datasets as tfds

    ds = tfds.load(dataset_name, split="train")
    # Cast to float before normalising: TFDS images are uint8, and dividing
    # a uint8 tensor by a float fails inside the tf.data graph.
    ds = ds.map(lambda x: (tf.cast(x["image"], tf.float32) / 127.5 - 1))
    ds = ds.cache().shuffle(1000).batch(batch_size).prefetch(1)
    return ds
# Create random masks
def create_random_masks(images, mask_size):
    """Build per-image binary masks, each with one randomly placed hole.

    Args:
        images: array of shape (batch, height, width, channels); only the
            shape is used, not the pixel values.
        mask_size: (mask_height, mask_width) of the rectangle zeroed out.

    Returns:
        jnp.ndarray shaped like ``images`` containing 1 where pixels are
        kept and 0 inside each image's hole.

    Raises:
        ValueError: if the mask is larger than the image in either dimension.
    """
    batch_size, height, width = images.shape[0], images.shape[1], images.shape[2]
    mask_height, mask_width = mask_size
    if mask_height > height or mask_width > width:
        raise ValueError("mask_size must not exceed the image dimensions")
    masks = jnp.ones_like(images)
    for i in range(batch_size):
        # "+ 1" makes the exclusive upper bound inclusive of the flush-to-edge
        # placement; the original bound also made randint(0, 0) raise when the
        # mask exactly matched the image size.
        y = np.random.randint(0, height - mask_height + 1)
        x = np.random.randint(0, width - mask_width + 1)
        masks = masks.at[i, y:y + mask_height, x:x + mask_width, :].set(0)
    return masks
# Set hyperparameters
num_epochs = 100
batch_size = 64
learning_rate = 2e-4   # Adam step size shared by generator and discriminator
mask_size = (32, 32)   # height/width of the rectangular hole cut from each image

# Load the dataset
dataset = load_dataset("celeb_a", batch_size)

# Initialize the generator and discriminator with dummy 128x128 RGB inputs.
# NOTE(review): raw celeb_a images are 218x178, not 128x128 — presumably a
# resize step is expected in the pipeline; confirm before training.
generator = Generator()
discriminator = Discriminator()
generator_params = generator.init(jax.random.PRNGKey(0), jnp.zeros((1, 128, 128, 3)))
discriminator_params = discriminator.init(jax.random.PRNGKey(1), jnp.zeros((1, 128, 128, 3)))

# Create train states. The single optax.adam descriptor is shared, but each
# TrainState.create call materialises its own optimizer state.
tx = optax.adam(learning_rate)
state_generator = train_state.TrainState.create(apply_fn=generator.apply, params=generator_params, tx=tx)
state_discriminator = train_state.TrainState.create(apply_fn=discriminator.apply, params=discriminator_params, tx=tx)
# Training loop: one generator + one discriminator update per batch, with a
# per-epoch progress line reporting the last batch's losses.
for epoch in range(num_epochs):
    for batch in dataset:
        # tf.data yields TensorFlow tensors; convert to NumPy so jax.jit
        # can trace the batch and the jnp .at[] mask updates work.
        batch = np.asarray(batch)
        masks = create_random_masks(batch, mask_size)
        state_generator, state_discriminator, loss_generator, loss_discriminator = train_step(
            state_generator, state_discriminator, batch, masks)
    print(f"Epoch [{epoch+1}/{num_epochs}], "
          f"Generator Loss: {loss_generator:.4f}, "
          f"Discriminator Loss: {loss_discriminator:.4f}")
# Inpainting example: mask one batch, inpaint it with the trained generator,
# and show masked / generated / original side by side.
image = np.asarray(next(iter(dataset)))  # TF tensor -> NumPy for JAX ops
mask = create_random_masks(image, mask_size)
masked_image = image * mask
generated_image = Generator().apply(state_generator.params, masked_image)

def _to_display(img):
    # Pixels are normalised to [-1, 1] but imshow expects floats in [0, 1];
    # the original passed them through unscaled, clipping all negative
    # values to black. Rescale (and clip for float safety) before plotting.
    return np.clip((np.asarray(img) + 1.0) / 2.0, 0.0, 1.0)

plt.subplot(1, 3, 1)
plt.imshow(_to_display(masked_image[0]))
plt.title("Masked Image")
plt.axis("off")
plt.subplot(1, 3, 2)
plt.imshow(_to_display(generated_image[0]))
plt.title("Generated Image")
plt.axis("off")
plt.subplot(1, 3, 3)
plt.imshow(_to_display(image[0]))
plt.title("Original Image")
plt.axis("off")
plt.tight_layout()
plt.show()
# Possible Errors and Solutions:
# 1. Import Errors:
# Error: "ModuleNotFoundError: No module named 'jax'"
# Solution: Ensure JAX and other required libraries are properly installed. Use `pip install jax jaxlib flax optax tensorflow-datasets matplotlib`.
# 2. Data Loading Errors:
# Error: "ValueError: Failed to find data set"
# Solution: Make sure the dataset name is correct and the dataset is available. Use the command `tfds.list_builders()` to check available datasets.
# 3. Shape Mismatch Errors:
# Error: "ValueError: shapes (X,Y) and (Y,Z) not aligned"
# Solution: Verify the shapes of inputs and weights in matrix multiplication. Adjust dimensions if necessary.
# 4. Gradient Issues:
# Error: "ValueError: gradients must be arrays"
# Solution: Ensure that the loss function returns a scalar value for proper gradient computation.
# 5. Performance Issues:
# Solution: Use smaller batch sizes or fewer epochs if the training process is too slow. Consider using GPU for faster computation.