Skip to content

Commit 305c083

Browse files
authored
Add segmask support for random_flip (#775)
* Add segmask support for random_flip * Set seed for test case * Undo RRC changes * Make shared _flip_image method
1 parent 02dda74 commit 305c083

File tree

2 files changed

+42
-12
lines changed

2 files changed

+42
-12
lines changed

keras_cv/layers/preprocessing/random_flip.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,18 +85,7 @@ def augment_label(self, label, transformation, **kwargs):
8585
return label
8686

8787
def augment_image(self, image, transformation, **kwargs):
88-
flipped_output = tf.cond(
89-
transformation["flip_horizontal"],
90-
lambda: tf.image.flip_left_right(image),
91-
lambda: image,
92-
)
93-
flipped_output = tf.cond(
94-
transformation["flip_vertical"],
95-
lambda: tf.image.flip_up_down(flipped_output),
96-
lambda: flipped_output,
97-
)
98-
flipped_output.set_shape(image.shape)
99-
return flipped_output
88+
return RandomFlip._flip_image(image, transformation)
10089

10190
def get_random_transformation(self, **kwargs):
10291
flip_horizontal = False
@@ -110,6 +99,20 @@ def get_random_transformation(self, **kwargs):
11099
"flip_vertical": tf.cast(flip_vertical, dtype=tf.bool),
111100
}
112101

102+
def _flip_image(image, transformation):
103+
flipped_output = tf.cond(
104+
transformation["flip_horizontal"],
105+
lambda: tf.image.flip_left_right(image),
106+
lambda: image,
107+
)
108+
flipped_output = tf.cond(
109+
transformation["flip_vertical"],
110+
lambda: tf.image.flip_up_down(flipped_output),
111+
lambda: flipped_output,
112+
)
113+
flipped_output.set_shape(image.shape)
114+
return flipped_output
115+
113116
def _flip_bounding_boxes_horizontal(bounding_boxes):
114117
x1, x2, x3, x4, rest = tf.split(
115118
bounding_boxes, [1, 1, 1, 1, bounding_boxes.shape[-1] - 4], axis=-1
@@ -186,6 +189,11 @@ def augment_bounding_boxes(
186189
)
187190
return bounding_boxes
188191

192+
def augment_segmentation_mask(
193+
self, segmentation_mask, transformation=None, **kwargs
194+
):
195+
return RandomFlip._flip_image(segmentation_mask, transformation)
196+
189197
def compute_output_shape(self, input_shape):
190198
return input_shape
191199

keras_cv/layers/preprocessing/random_flip_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,25 @@ def test_augment_bbox_batched_input(self):
134134
)
135135
expected_output = np.reshape(expected_output, (2, 2, 5))
136136
self.assertAllClose(expected_output, output["bounding_boxes"])
137+
138+
def test_augment_segmentation_mask(self):
139+
np.random.seed(1337)
140+
image = np.random.random((1, 20, 20, 3)).astype(np.float32)
141+
mask = np.random.randint(2, size=(1, 20, 20, 1)).astype(np.float32)
142+
143+
input = {"images": image, "segmentation_masks": mask}
144+
145+
# Flip both vertically and horizontally
146+
mock_random = [0.6, 0.6]
147+
layer = RandomFlip()
148+
149+
with unittest.mock.patch.object(
150+
layer._random_generator,
151+
"random_uniform",
152+
side_effect=mock_random,
153+
):
154+
output = layer(input, training=True)
155+
156+
expected_mask = np.flip(np.flip(mask, axis=1), axis=2)
157+
158+
self.assertAllClose(expected_mask, output["segmentation_masks"])

0 commit comments

Comments
 (0)