2828BOUNDING_BOXES = "bounding_boxes"
2929KEYPOINTS = "keypoints"
3030RAGGED_BOUNDING_BOXES = "ragged_bounding_boxes"
31+ SEGMENTATION_MASK = "segmentation_mask"
3132IS_DICT = "is_dict"
3233USE_TARGETS = "use_targets"
3334
@@ -143,7 +144,7 @@ def augment_image(self, image, transformation, **kwargs):
143144 `layer.call()`.
144145 transformation: The transformation object produced by
145146 `get_random_transformation`. Used to coordinate the randomness
146- between image, label and bounding box.
147+ between image, label, bounding box, keypoints, and segmentation mask.
147148
148149 Returns:
149150 output 3D tensor, which will be forwarded to `layer.call()`.
@@ -157,7 +158,7 @@ def augment_label(self, label, transformation, **kwargs):
157158 label: 1D label to the layer. Forwarded from `layer.call()`.
158159 transformation: The transformation object produced by
159160 `get_random_transformation`. Used to coordinate the randomness
160- between image, label and bounding box.
161+ between image, label, bounding box, keypoints, and segmentation mask.
161162
162163 Returns:
163164 output 1D tensor, which will be forwarded to `layer.call()`.
@@ -171,7 +172,7 @@ def augment_target(self, target, transformation, **kwargs):
171172 target: 1D label to the layer. Forwarded from `layer.call()`.
172173 transformation: The transformation object produced by
173174 `get_random_transformation`. Used to coordinate the randomness
174- between image, label and bounding box.
175+ between image, label, bounding box, keypoints, and segmentation mask.
175176
176177 Returns:
177178 output 1D tensor, which will be forwarded to `layer.call()`.
@@ -188,7 +189,7 @@ def augment_bounding_boxes(self, bounding_boxes, transformation, **kwargs):
188189 `call()`.
189190 transformation: The transformation object produced by
190191 `get_random_transformation`. Used to coordinate the randomness
191- between image, label and bounding box.
192+ between image, label, bounding box, keypoints, and segmentation mask.
192193
193194 Returns:
194195 output 2D tensor, which will be forwarded to `layer.call()`.
@@ -203,15 +204,36 @@ def augment_keypoints(self, keypoints, transformation, **kwargs):
203204 `layer.call()`.
204205 transformation: The transformation object produced by
205206 `get_random_transformation`. Used to coordinate the randomness
206- between image, label and bounding box.
207+ between image, label, bounding box, keypoints, and segmentation mask.
207208
208209 Returns:
209210 output 2D tensor, which will be forwarded to `layer.call()`.
210211 """
211212 raise NotImplementedError ()
212213
214+ def augment_segmentation_mask (self , segmentation_mask , transformation , ** kwargs ):
215+ """Augment a single image's segmentation mask during training.
216+
217+ Args:
218+ segmentation_mask: 3D segmentation mask input tensor to the layer.
219+ This should generally have the shape [H, W, 1], or in some cases [H, W, C] for multilabeled data.
220+ Forwarded from `layer.call()`.
221+ transformation: The transformation object produced by
222+ `get_random_transformation`. Used to coordinate the randomness
223+ between image, label, bounding box, keypoints, and segmentation mask.
224+
225+ Returns:
226+ output 3D tensor containing the augmented segmentation mask, which will be forwarded to `layer.call()`.
227+ """
228+ raise NotImplementedError ()
229+
213230 def get_random_transformation (
214- self , image = None , label = None , bounding_boxes = None , keypoints = None
231+ self ,
232+ image = None ,
233+ label = None ,
234+ bounding_boxes = None ,
235+ keypoints = None ,
236+ segmentation_mask = None ,
215237 ):
216238 """Produce random transformation config for one single input.
217239
@@ -222,6 +244,7 @@ def get_random_transformation(
222244 image: 3D image tensor from inputs.
223245 label: optional 1D label tensor from inputs.
224246 bounding_boxes: optional 2D bounding boxes tensor from inputs.
247+ segmentation_mask: optional 3D segmentation mask tensor from inputs.
225248
226249 Returns:
227250 Any type of object, which will be forwarded to `augment_image`,
@@ -253,8 +276,13 @@ def _augment(self, inputs):
253276 label = inputs .get (LABELS , None )
254277 bounding_boxes = inputs .get (BOUNDING_BOXES , None )
255278 keypoints = inputs .get (KEYPOINTS , None )
279+ segmentation_mask = inputs .get (SEGMENTATION_MASK , None )
256280 transformation = self .get_random_transformation (
257- image = image , label = label , bounding_boxes = bounding_boxes , keypoints = keypoints
281+ image = image ,
282+ label = label ,
283+ bounding_boxes = bounding_boxes ,
284+ keypoints = keypoints ,
285+ segmentation_mask = segmentation_mask ,
258286 )
259287 image = self .augment_image (
260288 image ,
@@ -288,6 +316,12 @@ def _augment(self, inputs):
288316 image = image ,
289317 )
290318 result [KEYPOINTS ] = keypoints
319+ if segmentation_mask is not None :
320+ segmentation_mask = self .augment_segmentation_mask (
321+ segmentation_mask ,
322+ transformation = transformation ,
323+ )
324+ result [SEGMENTATION_MASK ] = segmentation_mask
291325
292326 # preserve any additional inputs unmodified by this layer.
293327 for key in inputs .keys () - result .keys ():
0 commit comments