@@ -85,18 +85,7 @@ def augment_label(self, label, transformation, **kwargs):
8585 return label
8686
8787 def augment_image (self , image , transformation , ** kwargs ):
88- flipped_output = tf .cond (
89- transformation ["flip_horizontal" ],
90- lambda : tf .image .flip_left_right (image ),
91- lambda : image ,
92- )
93- flipped_output = tf .cond (
94- transformation ["flip_vertical" ],
95- lambda : tf .image .flip_up_down (flipped_output ),
96- lambda : flipped_output ,
97- )
98- flipped_output .set_shape (image .shape )
99- return flipped_output
88+ return RandomFlip ._flip_image (image , transformation )
10089
10190 def get_random_transformation (self , ** kwargs ):
10291 flip_horizontal = False
@@ -110,6 +99,20 @@ def get_random_transformation(self, **kwargs):
11099 "flip_vertical" : tf .cast (flip_vertical , dtype = tf .bool ),
111100 }
112101
102+ def _flip_image (image , transformation ):
103+ flipped_output = tf .cond (
104+ transformation ["flip_horizontal" ],
105+ lambda : tf .image .flip_left_right (image ),
106+ lambda : image ,
107+ )
108+ flipped_output = tf .cond (
109+ transformation ["flip_vertical" ],
110+ lambda : tf .image .flip_up_down (flipped_output ),
111+ lambda : flipped_output ,
112+ )
113+ flipped_output .set_shape (image .shape )
114+ return flipped_output
115+
113116 def _flip_bounding_boxes_horizontal (bounding_boxes ):
114117 x1 , x2 , x3 , x4 , rest = tf .split (
115118 bounding_boxes , [1 , 1 , 1 , 1 , bounding_boxes .shape [- 1 ] - 4 ], axis = - 1
@@ -186,6 +189,11 @@ def augment_bounding_boxes(
186189 )
187190 return bounding_boxes
188191
192+ def augment_segmentation_mask (
193+ self , segmentation_mask , transformation = None , ** kwargs
194+ ):
195+ return RandomFlip ._flip_image (segmentation_mask , transformation )
196+
189197 def compute_output_shape (self , input_shape ):
190198 return input_shape
191199
0 commit comments