
Commit 73c0938

feat: implementing ship detection model
1 parent 6e50c55 commit 73c0938


7 files changed: +301 -3 lines changed


.gitattributes

Lines changed: 4 additions & 1 deletion
@@ -1,2 +1,5 @@
-assets/model_weights.pth filter=lfs diff=lfs merge=lfs -text
+assets/aircraft_model_weights.pth filter=lfs diff=lfs merge=lfs -text
+assets/ship_model_weights.pth filter=lfs diff=lfs merge=lfs -text
 assets/images/2_planes.tiff filter=lfs diff=lfs merge=lfs -text
+assets/*.pth filter=lfs diff=lfs merge=lfs -text
+assets/images/*.tiff filter=lfs diff=lfs merge=lfs -text

assets/aircraft_model_weights.pth

Lines changed: 3 additions & 0 deletions
version https://git-lfs.github.com/spec/v1
oid sha256:e878fef8cb125159a7b2054ff1c02fce274659700261aec2e2a7e0c1b0c37e22
size 351017331

assets/ship_model_weights.pth

Lines changed: 3 additions & 0 deletions
version https://git-lfs.github.com/spec/v1
oid sha256:e878fef8cb125159a7b2054ff1c02fce274659700261aec2e2a7e0c1b0c37e22
size 351017331
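
Both .pth entries are committed as Git LFS pointer files (the version header, oid, and size shown above), not the network weights themselves. A minimal startup-guard sketch that catches an un-fetched pointer before the predictor tries to load it, assuming the container-default asset path used elsewhere in this commit; the helper name is illustrative:

from pathlib import Path

LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/v1"


def is_unfetched_lfs_pointer(weights_path: str) -> bool:
    """Return True if the file still holds a Git LFS pointer instead of real weights."""
    path = Path(weights_path)
    # A pointer file is ~130 bytes; the real weights are ~351 MB per the size field above.
    if not path.is_file() or path.stat().st_size > 1024:
        return False
    return path.read_bytes().startswith(LFS_POINTER_PREFIX)


if is_unfetched_lfs_pointer("/home/osml-models/assets/ship_model_weights.pth"):
    raise RuntimeError("Weights are still an LFS pointer; run `git lfs pull` before serving.")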

src/aws/osml/models/aircraft/app.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-# Copyright 2023-2024 Amazon.com, Inc. or its affiliates.
+# Copyright 2023-2025 Amazon.com, Inc. or its affiliates.
 
 import json
 import os
@@ -50,7 +50,7 @@ def build_predictor() -> DefaultPredictor:
     cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
     # Path to the model weights
     cfg.MODEL.WEIGHTS = os.getenv(
-        os.path.join("MODEL_WEIGHTS"), os.path.join("/home/osml-models/assets/", "model_weights.pth")
+        os.path.join("MODEL_WEIGHTS"), os.path.join("/home/osml-models/assets/", "aircraft_model_weights.pth")
     )
 
     # Build the detectron2 default predictor
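
The functional change for the aircraft model is only the default weights filename. A small sketch of how that path resolves at runtime, using the env-var name and default directory from the diff (everything else is illustrative):

import os

# Default used when the MODEL_WEIGHTS environment variable is not set
DEFAULT_WEIGHTS = os.path.join("/home/osml-models/assets/", "aircraft_model_weights.pth")

# An explicit MODEL_WEIGHTS value overrides the baked-in asset path
weights_path = os.getenv("MODEL_WEIGHTS", DEFAULT_WEIGHTS)
print(f"Loading aircraft model weights from: {weights_path}")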
Lines changed: 1 addition & 0 deletions
# Copyright 2023-2025 Amazon.com, Inc. or its affiliates.

src/aws/osml/models/ship/app.py

Lines changed: 202 additions & 0 deletions
# Copyright 2023-2025 Amazon.com, Inc. or its affiliates.

import json
import os
import uuid
import warnings
from typing import Dict, Optional, Union

import numpy as np
import torch
from detectron2.engine import DefaultPredictor
from detectron2.structures.instances import Instances
from flask import Request, Response, request
from osgeo import gdal

from aws.osml.models import build_flask_app, build_logger, setup_server
from aws.osml.models.ship.config import build_config

ENABLE_SEGMENTATION = os.environ.get("ENABLE_SEGMENTATION", "False").lower() == "true"
ENABLE_FAULT_DETECTION = os.environ.get("ENABLE_FAULT_DETECTION", "False").lower() == "true"

# Enable exceptions for GDAL
gdal.UseExceptions()

# Create logger instance
logger = build_logger()

# Create our default flask app
app = build_flask_app(logger)


def build_predictor() -> DefaultPredictor:
    """
    Create a single detection predictor to detect ships.

    :return: DefaultPredictor
    """
    # Load the ship detection config built for Detectron2
    # (build_config sets the backbone, anchors, and the single "ship" class)
    cfg = build_config()
    # If we can't find a GPU, fall back to CPU so inference can still run
    if not torch.cuda.is_available():
        cfg.MODEL.DEVICE = "cpu"
        app.logger.warning("GPU not found, running in CPU mode!")

    # Build the detectron2 default predictor
    return DefaultPredictor(cfg)


def instances_to_feature_collection(
    instances: Instances, image_id: Optional[str] = str(uuid.uuid4())
) -> Dict[str, Union[str, list]]:
    """
    Convert Detectron2 detection instances into a GeoJSON FeatureCollection.
    Each detection is a feature in the collection, including image coordinates,
    score, and type identifier as feature properties.

    :param instances: Detectron2 result instances
    :param image_id: Identifier for the processed image (optional)
    :return: FeatureCollection object containing detections
    """
    geojson_feature_collection_dict = {"type": "FeatureCollection", "features": []}
    if instances:
        # Get the bounding boxes for this image
        bboxes = instances.pred_boxes.tensor.cpu().numpy().tolist()

        # Get the scores for this image; this model does not support segmentation
        scores = instances.scores.cpu().numpy().tolist()

        for i in range(0, len(bboxes)):
            feature = {
                "type": "Feature",
                "geometry": {"type": "Point", "coordinates": [0.0, 0.0]},
                "id": str(uuid.uuid4()),
                "properties": {
                    "bounds_imcoords": bboxes[i],
                    "detection_score": float(scores[i]),
                    "feature_types": {"ship": float(scores[i])},
                    "image_id": image_id,
                },
            }
            app.logger.debug(feature)
            geojson_feature_collection_dict["features"].append(feature)
    else:
        app.logger.debug("No features found!")

    return geojson_feature_collection_dict


def request_to_instances(req: Request) -> Union[Instances, None]:
    """
    Use GDAL to open the image. The binary payload from the HTTP request is used to
    create an in-memory VFS for GDAL which is then opened to decode the image into
    a dataset which will give us access to a NumPy array for the pixels. Then
    use that image to create detectron2 detection instances.

    :param req: Request: the flask request object passed into the SM endpoint
    :return: Either a set of detectron2 detection instances or nothing
    """
    # Set up default variables
    temp_ds_name = "/vsimem/" + str(uuid.uuid4())
    gdal_dataset = None
    instances = None
    try:
        # Load the binary memory buffer sent to the model
        gdal.FileFromMemBuffer(temp_ds_name, req.get_data())
        gdal_dataset = gdal.Open(temp_ds_name)

        # Read GDAL dataset and convert to a numpy array
        image_array = gdal_dataset.ReadAsArray()

        # Check if all pixels are zero and raise an exception if so
        if ENABLE_FAULT_DETECTION:
            app.logger.debug(f"Image array min: {image_array.min()}, max: {image_array.max()}")
            if np.all(np.isclose(image_array, 0)):
                err = "All pixels in the image tile are set to 0."
                app.logger.error(err)
                raise Exception(err)

        # Handling of different image shapes
        if image_array.ndim == 2:  # For grayscale images without a channel dimension
            # Reshape to add a channel dimension and replicate across 3 channels for RGB
            image_array = np.stack([image_array] * 3, axis=0)
        elif image_array.shape[0] == 1:  # For grayscale images with a channel dimension
            # Replicate the single channel across 3 channels for RGB
            image_array = np.repeat(image_array, 3, axis=0)
        elif image_array.shape[0] == 4:  # For images with an alpha channel
            # Remove the alpha channel
            image_array = image_array[:3, :, :]

        # Conversion to uint8 (ensure this is done after ensuring 3 channels)
        image_array = (image_array * 255).astype(np.uint8)

        # Transpose the array from (channels, height, width) to (height, width, channels)
        image = np.transpose(image_array, (1, 2, 0))
        app.logger.debug(f"Running D2 on image array: {image}")

        # PyTorch can often give warnings about upcoming changes
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            instances = ship_detector(image)["instances"]
    except Exception as err:
        app.logger.error(f"Unable to load tile from request: {err}")
        raise err
    finally:
        try:
            if gdal_dataset is not None:
                if temp_ds_name is not None:
                    gdal.Unlink(temp_ds_name)
                del gdal_dataset
        except Exception as err:
            app.logger.warning(f"Unable to cleanup gdal dataset: {err}")

    return instances


# Build our ship predictor
ship_detector = build_predictor()


@app.route("/ping", methods=["GET"])
def healthcheck() -> Response:
    """
    This is a health check that will always pass as long as the model server is responding.

    :return: Successful status code (200) indicates all is well
    """
    app.logger.debug("Responding to health check")
    return Response(response="\n", status=200)


@app.route("/invocations", methods=["POST"])
def predict() -> Response:
    """
    This is the model invocation endpoint for the model container's REST
    API. The binary payload, in this case an image, is taken from the request,
    decoded, and run through the ship detector; the resulting detections are
    returned as a GeoJSON FeatureCollection.

    :return: Response: Contains the GeoJSON results or an error status
    """
    app.logger.debug("Invoking model endpoint using the Detectron2 Ship Model!")
    try:
        # Load the image into memory and get detection instances
        app.logger.debug("Loading image request.")
        instances = request_to_instances(request)

        # Generate a geojson feature collection that we can return
        geojson_detects = instances_to_feature_collection(instances)
        app.logger.debug(f"Sending geojson to requester: {json.dumps(geojson_detects)}")

        # Send back the detections
        return Response(response=json.dumps(geojson_detects), status=200)
    except Exception as err:
        app.logger.debug(err)
        return Response(response="Unable to process request!", status=500)


if __name__ == "__main__":  # pragma: no cover
    setup_server(app)
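
A minimal client sketch for exercising this endpoint, assuming the container is listening locally on port 8080 (the SageMaker serving convention) and that a sample_tile.tiff image exists; both assumptions are illustrative, not part of this commit:

import json

import requests

with open("sample_tile.tiff", "rb") as f:
    payload = f.read()

resp = requests.post("http://localhost:8080/invocations", data=payload)
resp.raise_for_status()

# The response body is the GeoJSON FeatureCollection built by instances_to_feature_collection
feature_collection = json.loads(resp.text)
for feature in feature_collection["features"]:
    props = feature["properties"]
    # bounds_imcoords holds the detection box in image coordinates; detection_score its confidence
    print(props["bounds_imcoords"], props["detection_score"])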

src/aws/osml/models/ship/config.py

Lines changed: 86 additions & 0 deletions
# Copyright 2025 Amazon.com, Inc. or its affiliates.

"""Detectron2 configuration module for ship detection on high-resolution imagery.

This configuration uses the R_101_DC5_3x backbone for improved receptive field and spatial detail,
with performance optimizations for AWS p3.2xlarge or similar environments.
"""

import os

from detectron2 import model_zoo
from detectron2.config import get_cfg


def build_config():
    """Set up Detectron2 config optimized for 2048×2048 tile inputs using R_101_DC5 backbone.

    Returns:
        Configured Detectron2 config object
    """
    # -----------------------------
    # Config: Faster R-CNN R101-DC5 (better for small objects than FPN R50)
    # -----------------------------
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_101_DC5_3x.yaml"))

    # Path to the model weights
    cfg.MODEL.WEIGHTS = os.getenv(
        os.path.join("MODEL_WEIGHTS"), os.path.join("/home/osml-models/assets/", "ship_model_weights.pth")
    )

    # Datasets
    cfg.DATASETS.TRAIN = ("ship_train",)
    cfg.DATASETS.TEST = ("ship_test",)

    # One class
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1

    # Anchors: add tiny + elongated ratios for hulls
    cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[72, 96, 160, 256, 384, 512, 704]]
    cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[0.18, 0.38, 0.71, 1.5, 2.56, 4.3, 6.9]]

    # RPN proposal budget (don’t prune tiny ships too early)
    cfg.MODEL.RPN.PRE_NMS_TOPK_TRAIN = 6000
    cfg.MODEL.RPN.PRE_NMS_TOPK_TEST = 4000
    cfg.MODEL.RPN.POST_NMS_TOPK_TRAIN = 2000
    cfg.MODEL.RPN.POST_NMS_TOPK_TEST = 1000
    cfg.MODEL.RPN.NMS_THRESH = 0.5

    # ROI settings
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 1024
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.45
    cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION = 7
    cfg.MODEL.ROI_ALIGN_USE_PRECISE_ROI_POOLER = True

    # Input format & sizes (allow light jitter; we’re not hard-locking to 2024 anymore)
    cfg.INPUT.FORMAT = "BGR"
    cfg.INPUT.MIN_SIZE_TRAIN = (1536, 1792, 2024, 2240)
    cfg.INPUT.MAX_SIZE_TRAIN = 2560
    cfg.INPUT.MIN_SIZE_TEST = 2024
    cfg.INPUT.MAX_SIZE_TEST = 2560

    # Sampler: repeat rare positives if dataset is sparse
    cfg.DATALOADER.SAMPLER_TRAIN = "RepeatFactorTrainingSampler"
    cfg.DATALOADER.REPEAT_THRESHOLD = 0.001

    # Mixed precision + grad clip
    cfg.SOLVER.IMS_PER_BATCH = 2  # tune to VRAM
    cfg.SOLVER.BASE_LR = 0.00025
    cfg.SOLVER.WARMUP_ITERS = 1000
    cfg.SOLVER.WARMUP_FACTOR = 0.001
    cfg.SOLVER.WARMUP_METHOD = "linear"
    cfg.SOLVER.MAX_ITER = 500000
    cfg.SOLVER.STEPS = [30000, 45000]
    cfg.SOLVER.AMP.ENABLED = True
    cfg.SOLVER.CLIP_GRADIENTS.ENABLED = True
    cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 1.0
    cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "value"

    # Evaluation & checkpoints
    cfg.TEST.AUG.ENABLED = True
    cfg.TEST.DETECTIONS_PER_IMAGE = 500
    cfg.TEST.EVAL_PERIOD = 1000
    cfg.SOLVER.CHECKPOINT_PERIOD = 1000

    return cfg
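
A standalone usage sketch for this config, mirroring how ship/app.py builds its predictor; it assumes detectron2 and OpenCV are installed and that ship_tile.png is a local BGR test image (the path and the score-threshold override are illustrative):

import cv2
import torch
from detectron2.engine import DefaultPredictor

from aws.osml.models.ship.config import build_config

cfg = build_config()
if not torch.cuda.is_available():
    cfg.MODEL.DEVICE = "cpu"  # fall back to CPU when no GPU is present
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.6  # example override of the 0.45 default above

predictor = DefaultPredictor(cfg)
image = cv2.imread("ship_tile.png")  # cv2 returns BGR, matching cfg.INPUT.FORMAT
outputs = predictor(image)["instances"]
print(outputs.pred_boxes.tensor.cpu().numpy(), outputs.scores.cpu().numpy())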
