Skip to content

Commit

Permalink
Merge pull request #13 from humphrem/detector-defaults
Browse files Browse the repository at this point in the history
Handle min_duration, confidence, and buffer in detector, with defaults
  • Loading branch information
humphrem authored Sep 3, 2023
2 parents 0acac82 + 9d5a9a1 commit 68e288f
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 45 deletions.
69 changes: 37 additions & 32 deletions action.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,39 @@

# We use converted ONNX models for YOLO-Fish (https://github.com/tamim662/YOLO-Fish)
# and Megadetector (https://github.com/microsoft/CameraTraps)
def load_detector(environment, logger):
def load_detector(environment, min_duration, buffer, confidence, logger):
"""
Load the appropriate detector based on the environment provided.
Args:
environment (str): The type of detector to load. Must be either "aquatic" or "terrestrial".
min_duration (float): The minimum duration of a generated clip
buffer (float): An optional number of seconds to add before/after a clip
confidence (float): The confidence level to use
logger (logging.Logger): The logger to use for logging messages.
Returns:
detector (object): An instance of the appropriate detector.
Raises:
TypeError: If the environment provided is not "aquatic" or "terrestrial".
TypeError: If the any args are not of the correct type
"""

# Make sure any user-provided flags for detection are valid before we use them
if buffer and buffer < 0.0:
raise TypeError("Error: minimum buffer cannot be negative")

if min_duration and min_duration <= 0.0:
raise TypeError("Error: minimum duration must be greater than 0.0")

if confidence and (confidence <= 0.0 or confidence > 1.0):
raise TypeError("Error: confidence must be greater than 0.0 and less than 1.0")

detector = None
if environment == "terrestrial":
detector = MegadetectorDetector(logger)
detector = MegadetectorDetector(logger, min_duration, buffer, confidence)
elif environment == "aquatic":
detector = YoloFishDetector(logger)
detector = YoloFishDetector(logger, min_duration, buffer, confidence)
else:
raise TypeError("environment must be one of aquatic or terrestrial")

Expand Down Expand Up @@ -79,9 +93,9 @@ def process_frames(
Returns:
None
"""
confidence_threshold = args.confidence
buffer_seconds = args.buffer
min_detection_duration = args.min_duration
confidence_threshold = detector.confidence
buffer_seconds = detector.buffer
min_detection_duration = detector.min_duration
show_detections = args.show_detections

# Number of frames per minute of video time
Expand Down Expand Up @@ -121,7 +135,7 @@ def process_frames(

# Before ending this detection period, check if there is a anything
# else detected in what comes after it, and extend if necessary
boxes = detector.detect(frame, confidence_threshold)
boxes = detector.detect(frame)
if len(boxes) > 0:
detection_highest_confidence = max(
detection_highest_confidence, max(box[4] for box in boxes)
Expand Down Expand Up @@ -153,7 +167,7 @@ def process_frames(
# vs. every frame for speed (e.g., every 15 of 30fps). We also check
# the last frame, so we don't miss anything at the edge.
if frame_count % frames_to_skip == 0 or frame_count == total_frames - 1:
boxes = detector.detect(frame, confidence_threshold)
boxes = detector.detect(frame)

# If there are one ore more detections
if len(boxes) > 0:
Expand Down Expand Up @@ -209,30 +223,24 @@ def main(args):
logger = logging.getLogger(__name__)
logging.basicConfig(level=args.log_level, format="%(message)s")

video_paths = get_video_paths(args.filename)
logger.debug(f"Got input files: {video_paths}")
confidence_threshold = args.confidence
buffer_seconds = args.buffer
min_detection_duration = args.min_duration
delete_clips = args.delete_clips
output_dir = args.output_dir
environment = args.environment
video_paths = get_video_paths(args.filename)
logger.debug(f"Got input files: {video_paths}")

# Validate argument parameters from user before using them
if len(video_paths) < 1:
logger.error("Error: you must specify one or more video filenames to process")
sys.exit(1)

if buffer_seconds < 0.0:
logger.error("Error: minimum buffer cannot be negative")
sys.exit(1)

if min_detection_duration <= 0.0:
logger.error("Error: minimum duration must be greater than 0.0")
sys.exit(1)

if confidence_threshold <= 0.0 or confidence_threshold > 1.0:
logger.error("Error: confidence must be greater than 0.0 and less than 1.0")
# Load YOLO-Fish or Megadetector, based on `-e` value
detector = None
try:
detector = load_detector(
args.environment, args.min_duration, args.buffer, args.confidence, logger
)
except Exception as e:
logger.error(f"There was an error: {e}")
sys.exit(1)

cap = None
Expand All @@ -250,9 +258,6 @@ def main(args):
# Create a queue manager for clips to be processed by ffmpeg
clips = ClipManager(logger, output_dir)

# Load YOLO-Fish or Megadetector, based on `-e` value
detector = load_detector(environment, logger)

# Keep track of total time to process all files, recording start time
total_time_start = time.time()

Expand Down Expand Up @@ -281,7 +286,7 @@ def main(args):
f"\nStarting file {i} of {len(video_paths)}: {video_path} - {format_time(duration)} - {total_frames} frames at {fps} fps, skipping every {frames_to_skip} frames"
)
logger.info(
f"Using confidence threshold {confidence_threshold}, minimum clip duration of {min_detection_duration} seconds, and {buffer_seconds} seconds of buffer."
f"Using confidence threshold {detector.confidence}, minimum clip duration of {detector.min_duration} seconds, and {detector.buffer} seconds of buffer."
)

# If we're not using a common clips dir, reset the counter for future clips
Expand Down Expand Up @@ -367,24 +372,24 @@ def main(args):
parser.add_argument(
"-b",
"--buffer",
default=0.0,
default=None,
type=float,
dest="buffer",
help="Number of seconds to add before and after detection (e.g., 1.0), cannot be negative",
)
parser.add_argument(
"-c",
"--confidence",
default=0.25,
type=float,
default=None,
dest="confidence",
help="Confidence threshold for detection (e.g., 0.45), must be greater than 0.0 and less than 1.0",
)
parser.add_argument(
"-m",
"--minimum-duration",
default=1.0,
type=float,
default=None,
dest="min_duration",
help="Minimum duration for clips in seconds (e.g., 2.0), must be greater than 0.0",
)
Expand Down
17 changes: 12 additions & 5 deletions base_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def __init__(
model_path,
input_image_height,
input_image_width,
min_duration,
buffer,
confidence,
class_name,
providers=["CPUExecutionProvider"],
):
Expand All @@ -38,13 +41,19 @@ def __init__(
model_path: Path to the model file.
input_image_height: Height of the input image.
input_image_width: Width of the input image.
min_duration (float): The minimum duration of a generated clip
buffer (float): An optional number of seconds to add before/after a clip
confidence (float): The confidence level to use
class_name: Name of the class to be detected.
providers: List of providers for ONNX runtime. Default is CPUExecutionProvider.
"""
self.logger = logger
self.model_path = model_path
self.input_image_height = input_image_height
self.input_image_width = input_image_width
self.min_duration = min_duration
self.buffer = buffer
self.confidence = confidence
self.class_name = class_name
self.providers = providers
self.session = None
Expand All @@ -65,13 +74,12 @@ def load(self):
self.model_path, providers=self.providers, sess_options=sess_options
)

def detect(self, image_src, confidence_threshold):
def detect(self, image_src):
"""
Detect objects in the provided image.
Args:
image_src: Source image for object detection.
confidence_threshold: Confidence threshold for detection.
Returns:
List of bounding boxes for detected objects.
Expand All @@ -95,7 +103,7 @@ def detect(self, image_src, confidence_threshold):
self.logger.debug(f"Detection took {end_time - start_time}s")

# Process detection outputs into bounding boxes
boxes = self.post_processing(outputs, confidence_threshold)
boxes = self.post_processing(outputs)

# Return boxes[0] if it exists, otherwise return an empty list
return boxes[0] if boxes else []
Expand Down Expand Up @@ -156,14 +164,13 @@ def draw_detections(self, img, boxes, title):
cv2.imshow(title, img)
cv2.waitKey(1)

def post_processing(self, outputs, confidence_threshold):
def post_processing(self, outputs):
"""
Post-process the detection outputs. See YoloFishDetector and
MegadetectorDetector for implementations.
Args:
outputs: Outputs from the detection model.
confidence_threshold: Confidence threshold for detection.
Raises:
NotImplementedError: This method must be implemented by a subclass.
Expand Down
18 changes: 14 additions & 4 deletions megadetector_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,43 +137,53 @@ class MegadetectorDetector(BaseDetector):
post-processing of the model's outputs.
"""

def __init__(self, logger):
def __init__(self, logger, min_duration, buffer, confidence):
"""
Initialize the MegadetectorDetector class.
Args:
logger (Logger): Logger object for logging.
min_duration (float): The minimum duration of a generated clip (defaults to 15.0)
buffer (float): An optional number of seconds to add before/after a clip (defaults to 5.0)
confidence (float): The confidence level to use (defaults to 0.40)
"""
logger.info(
"Initializing Megadetector Model and Optimizing (this will take a minute...)"
)
# Use some defaults if any of these aren't already set
min_duration = 15.0 if min_duration is None else min_duration
buffer = 5.0 if buffer is None else buffer
confidence = 0.40 if confidence is None else confidence
providers = get_available_providers()
super().__init__(
logger,
megadetector_model_path,
megadetector_image_width,
megadetector_image_height,
min_duration,
buffer,
confidence,
"Animal",
providers,
)

def post_processing(self, outputs, confidence_threshold):
def post_processing(self, outputs):
"""
Perform post-processing on the model's outputs. This includes non-max
suppression, filtering based on confidence threshold, and conversion of
absolute coordinates to relative.
Args:
outputs (numpy array): Outputs from the model.
confidence_threshold (float): Confidence threshold for filtering.
Returns:
list: Post-processed predictions.
"""
preds = []
for p in outputs[0]:
# TODO: what to use for IOU_THRESHOLD? Defaults to 0.45 in run-onnx.py
p = non_max_suppression(p, confidence_threshold, 0.45)
p = non_max_suppression(p, self.confidence, 0.45)
# Filter out predictions for animals (class=0)
p = [pred for pred in p if pred[5] == 0]
# If there are predictions, convert absolute coordinates to relative
Expand Down
17 changes: 13 additions & 4 deletions yolo_fish_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,30 +75,39 @@ class YoloFishDetector(BaseDetector):
post-processing of the model's outputs.
"""

def __init__(self, logger):
def __init__(self, logger, min_duration, buffer, confidence):
"""
Initialize the YoloFishDetector class.
Args:
logger (Logger): Logger object for logging.
min_duration (float): The minimum duration of a generated clip (defaults to 3.0)
buffer (float): An optional number of seconds to add before/after a clip (defaults to 1.0)
confidence (float): The confidence level to use (defaults to 0.45)
"""
logger.info("Initializing YOLO-Fish Model")
# Use some defaults if any of these aren't already set
min_duration = 3.0 if min_duration is None else min_duration
buffer = 1.0 if buffer is None else buffer
confidence = 0.45 if confidence is None else confidence
super().__init__(
logger,
yolo_fish_model_path,
yolo_fish_image_width,
yolo_fish_image_height,
min_duration,
buffer,
confidence,
"Fish",
)

def post_processing(self, outputs, confidence_threshold):
def post_processing(self, outputs):
"""
Perform post-processing on the model outputs. This includes non-max
suppression, and filtering based on confidence threshold.
Args:
outputs (list): List of model outputs.
confidence_threshold (float): Confidence threshold for filtering detections.
Returns:
list: List of bounding boxes for detected objects.
Expand Down Expand Up @@ -131,7 +140,7 @@ def post_processing(self, outputs, confidence_threshold):
nms_threshold = 0.6
bboxes_batch = []
for i in range(box_array.shape[0]):
argwhere = box_array[i, :, 4] > confidence_threshold
argwhere = box_array[i, :, 4] > self.confidence
filtered_boxes = box_array[i, argwhere, :]

bboxes = []
Expand Down

0 comments on commit 68e288f

Please sign in to comment.