Skip to content

Commit b380c08

Browse files
authored
Add support for segmentation mask in BaseImageAugmentationLayer (#748)
* Add segmap to base image augmentation layer * Address Scott review comments * More review comments
1 parent 349eadc commit b380c08

File tree

2 files changed

+81
-11
lines changed

2 files changed

+81
-11
lines changed

keras_cv/layers/preprocessing/base_image_augmentation_layer.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
BOUNDING_BOXES = "bounding_boxes"
2929
KEYPOINTS = "keypoints"
3030
RAGGED_BOUNDING_BOXES = "ragged_bounding_boxes"
31+
SEGMENTATION_MASK = "segmentation_mask"
3132
IS_DICT = "is_dict"
3233
USE_TARGETS = "use_targets"
3334

@@ -143,7 +144,7 @@ def augment_image(self, image, transformation, **kwargs):
143144
`layer.call()`.
144145
transformation: The transformation object produced by
145146
`get_random_transformation`. Used to coordinate the randomness
146-
between image, label and bounding box.
147+
between image, label, bounding box, keypoints, and segmentation mask.
147148
148149
Returns:
149150
output 3D tensor, which will be forward to `layer.call()`.
@@ -157,7 +158,7 @@ def augment_label(self, label, transformation, **kwargs):
157158
label: 1D label to the layer. Forwarded from `layer.call()`.
158159
transformation: The transformation object produced by
159160
`get_random_transformation`. Used to coordinate the randomness
160-
between image, label and bounding box.
161+
between image, label, bounding box, keypoints, and segmentation mask.
161162
162163
Returns:
163164
output 1D tensor, which will be forward to `layer.call()`.
@@ -171,7 +172,7 @@ def augment_target(self, target, transformation, **kwargs):
171172
target: 1D label to the layer. Forwarded from `layer.call()`.
172173
transformation: The transformation object produced by
173174
`get_random_transformation`. Used to coordinate the randomness
174-
between image, label and bounding box.
175+
between image, label, bounding box, keypoints, and segmentation mask.
175176
176177
Returns:
177178
output 1D tensor, which will be forward to `layer.call()`.
@@ -188,7 +189,7 @@ def augment_bounding_boxes(self, bounding_boxes, transformation, **kwargs):
188189
`call()`.
189190
transformation: The transformation object produced by
190191
`get_random_transformation`. Used to coordinate the randomness
191-
between image, label and bounding box.
192+
between image, label, bounding box, keypoints, and segmentation mask.
192193
193194
Returns:
194195
output 2D tensor, which will be forward to `layer.call()`.
@@ -203,15 +204,36 @@ def augment_keypoints(self, keypoints, transformation, **kwargs):
203204
`layer.call()`.
204205
transformation: The transformation object produced by
205206
`get_random_transformation`. Used to coordinate the randomness
206-
between image, label and bounding box.
207+
between image, label, bounding box, keypoints, and segmentation mask.
207208
208209
Returns:
209210
output 2D tensor, which will be forward to `layer.call()`.
210211
"""
211212
raise NotImplementedError()
212213

214+
def augment_segmentation_mask(self, segmentation_mask, transformation, **kwargs):
215+
"""Augment a single image's segmentation mask during training.
216+
217+
Args:
218+
segmentation_mask: 3D segmentation mask input tensor to the layer.
219+
This should generally have the shape [H, W, 1], or in some cases [H, W, C] for multilabeled data.
220+
Forwarded from `layer.call()`.
221+
transformation: The transformation object produced by
222+
`get_random_transformation`. Used to coordinate the randomness
223+
between image, label, bounding box, keypoints, and segmentation mask.
224+
225+
Returns:
226+
output 3D tensor containing the augmented segmentation mask, which will be forward to `layer.call()`.
227+
"""
228+
raise NotImplementedError()
229+
213230
def get_random_transformation(
214-
self, image=None, label=None, bounding_boxes=None, keypoints=None
231+
self,
232+
image=None,
233+
label=None,
234+
bounding_boxes=None,
235+
keypoints=None,
236+
segmentation_mask=None,
215237
):
216238
"""Produce random transformation config for one single input.
217239
@@ -222,6 +244,7 @@ def get_random_transformation(
222244
image: 3D image tensor from inputs.
223245
label: optional 1D label tensor from inputs.
224246
bounding_box: optional 2D bounding boxes tensor from inputs.
247+
segmentation_mask: optional 3D segmentation mask tensor from inputs.
225248
226249
Returns:
227250
Any type of object, which will be forwarded to `augment_image`,
@@ -253,8 +276,13 @@ def _augment(self, inputs):
253276
label = inputs.get(LABELS, None)
254277
bounding_boxes = inputs.get(BOUNDING_BOXES, None)
255278
keypoints = inputs.get(KEYPOINTS, None)
279+
segmentation_mask = inputs.get(SEGMENTATION_MASK, None)
256280
transformation = self.get_random_transformation(
257-
image=image, label=label, bounding_boxes=bounding_boxes, keypoints=keypoints
281+
image=image,
282+
label=label,
283+
bounding_boxes=bounding_boxes,
284+
keypoints=keypoints,
285+
segmentation_mask=segmentation_mask,
258286
)
259287
image = self.augment_image(
260288
image,
@@ -288,6 +316,12 @@ def _augment(self, inputs):
288316
image=image,
289317
)
290318
result[KEYPOINTS] = keypoints
319+
if segmentation_mask is not None:
320+
segmentation_mask = self.augment_segmentation_mask(
321+
segmentation_mask,
322+
transformation=transformation,
323+
)
324+
result[SEGMENTATION_MASK] = segmentation_mask
291325

292326
# preserve any additional inputs unmodified by this layer.
293327
for key in inputs.keys() - result.keys():

keras_cv/layers/preprocessing/base_image_augmentation_layer_test.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ def augment_bounding_boxes(self, bounding_boxes, transformation, **kwargs):
4444
def augment_keypoints(self, keypoints, transformation, **kwargs):
4545
return keypoints + transformation
4646

47+
def augment_segmentation_mask(self, segmentation_mask, transformation, **kwargs):
48+
return segmentation_mask + transformation
49+
4750

4851
class VectorizeDisabledLayer(BaseImageAugmentationLayer):
4952
def __init__(self, **kwargs):
@@ -148,15 +151,22 @@ def test_augment_image_and_localization_data(self):
148151
images = np.random.random(size=(8, 8, 3)).astype("float32")
149152
bounding_boxes = np.random.random(size=(3, 5)).astype("float32")
150153
keypoints = np.random.random(size=(3, 5, 2)).astype("float32")
154+
segmentation_mask = np.random.random(size=(8, 8, 1)).astype("float32")
151155

152156
output = add_layer(
153-
{"images": images, "bounding_boxes": bounding_boxes, "keypoints": keypoints}
157+
{
158+
"images": images,
159+
"bounding_boxes": bounding_boxes,
160+
"keypoints": keypoints,
161+
"segmentation_mask": segmentation_mask,
162+
}
154163
)
155164

156165
expected_output = {
157166
"images": images + 2.0,
158167
"bounding_boxes": bounding_boxes + 2.0,
159168
"keypoints": keypoints + 2.0,
169+
"segmentation_mask": segmentation_mask + 2.0,
160170
}
161171
self.assertAllClose(output, expected_output)
162172

@@ -165,53 +175,78 @@ def test_augment_batch_image_and_localization_data(self):
165175
images = np.random.random(size=(2, 8, 8, 3)).astype("float32")
166176
bounding_boxes = np.random.random(size=(2, 3, 5)).astype("float32")
167177
keypoints = np.random.random(size=(2, 3, 5, 2)).astype("float32")
178+
segmentation_mask = np.random.random(size=(2, 8, 8, 1)).astype("float32")
168179

169180
output = add_layer(
170-
{"images": images, "bounding_boxes": bounding_boxes, "keypoints": keypoints}
181+
{
182+
"images": images,
183+
"bounding_boxes": bounding_boxes,
184+
"keypoints": keypoints,
185+
"segmentation_mask": segmentation_mask,
186+
}
171187
)
172188

173189
bounding_boxes_diff = output["bounding_boxes"] - bounding_boxes
174190
keypoints_diff = output["keypoints"] - keypoints
191+
segmentation_mask_diff = output["segmentation_mask"] - segmentation_mask
175192
self.assertNotAllClose(bounding_boxes_diff[0], bounding_boxes_diff[1])
176193
self.assertNotAllClose(keypoints_diff[0], keypoints_diff[1])
194+
self.assertNotAllClose(segmentation_mask_diff[0], segmentation_mask_diff[1])
177195

178196
@tf.function
179197
def in_tf_function(inputs):
180198
return add_layer(inputs)
181199

182200
output = in_tf_function(
183-
{"images": images, "bounding_boxes": bounding_boxes, "keypoints": keypoints}
201+
{
202+
"images": images,
203+
"bounding_boxes": bounding_boxes,
204+
"keypoints": keypoints,
205+
"segmentation_mask": segmentation_mask,
206+
}
184207
)
185208

186209
bounding_boxes_diff = output["bounding_boxes"] - bounding_boxes
187210
keypoints_diff = output["keypoints"] - keypoints
211+
segmentation_mask_diff = output["segmentation_mask"] - segmentation_mask
188212
self.assertNotAllClose(bounding_boxes_diff[0], bounding_boxes_diff[1])
189213
self.assertNotAllClose(keypoints_diff[0], keypoints_diff[1])
214+
self.assertNotAllClose(segmentation_mask_diff[0], segmentation_mask_diff[1])
190215

191216
def test_augment_all_data_in_tf_function(self):
192217
add_layer = RandomAddLayer()
193218
images = np.random.random(size=(2, 8, 8, 3)).astype("float32")
194219
bounding_boxes = np.random.random(size=(2, 3, 5)).astype("float32")
195220
keypoints = np.random.random(size=(2, 3, 5, 2)).astype("float32")
221+
segmentation_mask = np.random.random(size=(2, 8, 8, 1)).astype("float32")
196222

197223
@tf.function
198224
def in_tf_function(inputs):
199225
return add_layer(inputs)
200226

201227
output = in_tf_function(
202-
{"images": images, "bounding_boxes": bounding_boxes, "keypoints": keypoints}
228+
{
229+
"images": images,
230+
"bounding_boxes": bounding_boxes,
231+
"keypoints": keypoints,
232+
"segmentation_mask": segmentation_mask,
233+
}
203234
)
204235

205236
bounding_boxes_diff = output["bounding_boxes"] - bounding_boxes
206237
keypoints_diff = output["keypoints"] - keypoints
238+
segmentation_mask_diff = output["segmentation_mask"] - segmentation_mask
207239
self.assertNotAllClose(bounding_boxes_diff[0], bounding_boxes_diff[1])
208240
self.assertNotAllClose(keypoints_diff[0], keypoints_diff[1])
241+
self.assertNotAllClose(segmentation_mask_diff[0], segmentation_mask_diff[1])
209242

210243
def test_raise_error_missing_class_id(self):
211244
add_layer = RandomAddLayer()
212245
images = np.random.random(size=(2, 8, 8, 3)).astype("float32")
213246
bounding_boxes = np.random.random(size=(2, 3, 4)).astype("float32")
214247
keypoints = np.random.random(size=(2, 3, 5, 2)).astype("float32")
248+
segmentation_mask = np.random.random(size=(2, 8, 8, 1)).astype("float32")
249+
215250
with self.assertRaisesRegex(
216251
ValueError,
217252
"Bounding boxes are missing class_id. If you would like to pad the "
@@ -223,5 +258,6 @@ def test_raise_error_missing_class_id(self):
223258
"images": images,
224259
"bounding_boxes": bounding_boxes,
225260
"keypoints": keypoints,
261+
"segmentation_mask": segmentation_mask,
226262
}
227263
)

0 commit comments

Comments
 (0)