From 9d5a9a19aeed62bf8759a4fe0f443e154a919557 Mon Sep 17 00:00:00 2001 From: David Humphrey Date: Sun, 3 Sep 2023 13:39:20 -0400 Subject: [PATCH] Handle min_duration, confidence, and buffer in detector, with defaults --- action.py | 69 +++++++++++++++++++++------------------- base_detector.py | 17 +++++++--- megadetector_detector.py | 18 ++++++++--- yolo_fish_detector.py | 17 +++++++--- 4 files changed, 76 insertions(+), 45 deletions(-) diff --git a/action.py b/action.py index 31df9d4..86135a0 100755 --- a/action.py +++ b/action.py @@ -24,25 +24,39 @@ # We use converted ONNX models for YOLO-Fish (https://github.com/tamim662/YOLO-Fish) # and Megadetector (https://github.com/microsoft/CameraTraps) -def load_detector(environment, logger): +def load_detector(environment, min_duration, buffer, confidence, logger): """ Load the appropriate detector based on the environment provided. Args: environment (str): The type of detector to load. Must be either "aquatic" or "terrestrial". + min_duration (float): The minimum duration of a generated clip + buffer (float): An optional number of seconds to add before/after a clip + confidence (float): The confidence level to use logger (logging.Logger): The logger to use for logging messages. Returns: detector (object): An instance of the appropriate detector. Raises: - TypeError: If the environment provided is not "aquatic" or "terrestrial". + TypeError: If the any args are not of the correct type """ + + # Make sure any user-provided flags for detection are valid before we use them + if buffer and buffer < 0.0: + raise TypeError("Error: minimum buffer cannot be negative") + + if min_duration and min_duration <= 0.0: + raise TypeError("Error: minimum duration must be greater than 0.0") + + if confidence and (confidence <= 0.0 or confidence > 1.0): + raise TypeError("Error: confidence must be greater than 0.0 and less than 1.0") + detector = None if environment == "terrestrial": - detector = MegadetectorDetector(logger) + detector = MegadetectorDetector(logger, min_duration, buffer, confidence) elif environment == "aquatic": - detector = YoloFishDetector(logger) + detector = YoloFishDetector(logger, min_duration, buffer, confidence) else: raise TypeError("environment must be one of aquatic or terrestrial") @@ -79,9 +93,9 @@ def process_frames( Returns: None """ - confidence_threshold = args.confidence - buffer_seconds = args.buffer - min_detection_duration = args.min_duration + confidence_threshold = detector.confidence + buffer_seconds = detector.buffer + min_detection_duration = detector.min_duration show_detections = args.show_detections # Number of frames per minute of video time @@ -121,7 +135,7 @@ def process_frames( # Before ending this detection period, check if there is a anything # else detected in what comes after it, and extend if necessary - boxes = detector.detect(frame, confidence_threshold) + boxes = detector.detect(frame) if len(boxes) > 0: detection_highest_confidence = max( detection_highest_confidence, max(box[4] for box in boxes) @@ -153,7 +167,7 @@ def process_frames( # vs. every frame for speed (e.g., every 15 of 30fps). We also check # the last frame, so we don't miss anything at the edge. if frame_count % frames_to_skip == 0 or frame_count == total_frames - 1: - boxes = detector.detect(frame, confidence_threshold) + boxes = detector.detect(frame) # If there are one ore more detections if len(boxes) > 0: @@ -209,30 +223,24 @@ def main(args): logger = logging.getLogger(__name__) logging.basicConfig(level=args.log_level, format="%(message)s") - video_paths = get_video_paths(args.filename) - logger.debug(f"Got input files: {video_paths}") - confidence_threshold = args.confidence - buffer_seconds = args.buffer - min_detection_duration = args.min_duration delete_clips = args.delete_clips output_dir = args.output_dir - environment = args.environment + video_paths = get_video_paths(args.filename) + logger.debug(f"Got input files: {video_paths}") # Validate argument parameters from user before using them if len(video_paths) < 1: logger.error("Error: you must specify one or more video filenames to process") sys.exit(1) - if buffer_seconds < 0.0: - logger.error("Error: minimum buffer cannot be negative") - sys.exit(1) - - if min_detection_duration <= 0.0: - logger.error("Error: minimum duration must be greater than 0.0") - sys.exit(1) - - if confidence_threshold <= 0.0 or confidence_threshold > 1.0: - logger.error("Error: confidence must be greater than 0.0 and less than 1.0") + # Load YOLO-Fish or Megadetector, based on `-e` value + detector = None + try: + detector = load_detector( + args.environment, args.min_duration, args.buffer, args.confidence, logger + ) + except Exception as e: + logger.error(f"There was an error: {e}") sys.exit(1) cap = None @@ -250,9 +258,6 @@ def main(args): # Create a queue manager for clips to be processed by ffmpeg clips = ClipManager(logger, output_dir) - # Load YOLO-Fish or Megadetector, based on `-e` value - detector = load_detector(environment, logger) - # Keep track of total time to process all files, recording start time total_time_start = time.time() @@ -281,7 +286,7 @@ def main(args): f"\nStarting file {i} of {len(video_paths)}: {video_path} - {format_time(duration)} - {total_frames} frames at {fps} fps, skipping every {frames_to_skip} frames" ) logger.info( - f"Using confidence threshold {confidence_threshold}, minimum clip duration of {min_detection_duration} seconds, and {buffer_seconds} seconds of buffer." + f"Using confidence threshold {detector.confidence}, minimum clip duration of {detector.min_duration} seconds, and {detector.buffer} seconds of buffer." ) # If we're not using a common clips dir, reset the counter for future clips @@ -367,7 +372,7 @@ def main(args): parser.add_argument( "-b", "--buffer", - default=0.0, + default=None, type=float, dest="buffer", help="Number of seconds to add before and after detection (e.g., 1.0), cannot be negative", @@ -375,16 +380,16 @@ def main(args): parser.add_argument( "-c", "--confidence", - default=0.25, type=float, + default=None, dest="confidence", help="Confidence threshold for detection (e.g., 0.45), must be greater than 0.0 and less than 1.0", ) parser.add_argument( "-m", "--minimum-duration", - default=1.0, type=float, + default=None, dest="min_duration", help="Minimum duration for clips in seconds (e.g., 2.0), must be greater than 0.0", ) diff --git a/base_detector.py b/base_detector.py index 6a5aa31..3ed28ec 100644 --- a/base_detector.py +++ b/base_detector.py @@ -27,6 +27,9 @@ def __init__( model_path, input_image_height, input_image_width, + min_duration, + buffer, + confidence, class_name, providers=["CPUExecutionProvider"], ): @@ -38,6 +41,9 @@ def __init__( model_path: Path to the model file. input_image_height: Height of the input image. input_image_width: Width of the input image. + min_duration (float): The minimum duration of a generated clip + buffer (float): An optional number of seconds to add before/after a clip + confidence (float): The confidence level to use class_name: Name of the class to be detected. providers: List of providers for ONNX runtime. Default is CPUExecutionProvider. """ @@ -45,6 +51,9 @@ def __init__( self.model_path = model_path self.input_image_height = input_image_height self.input_image_width = input_image_width + self.min_duration = min_duration + self.buffer = buffer + self.confidence = confidence self.class_name = class_name self.providers = providers self.session = None @@ -65,13 +74,12 @@ def load(self): self.model_path, providers=self.providers, sess_options=sess_options ) - def detect(self, image_src, confidence_threshold): + def detect(self, image_src): """ Detect objects in the provided image. Args: image_src: Source image for object detection. - confidence_threshold: Confidence threshold for detection. Returns: List of bounding boxes for detected objects. @@ -95,7 +103,7 @@ def detect(self, image_src, confidence_threshold): self.logger.debug(f"Detection took {end_time - start_time}s") # Process detection outputs into bounding boxes - boxes = self.post_processing(outputs, confidence_threshold) + boxes = self.post_processing(outputs) # Return boxes[0] if it exists, otherwise return an empty list return boxes[0] if boxes else [] @@ -156,14 +164,13 @@ def draw_detections(self, img, boxes, title): cv2.imshow(title, img) cv2.waitKey(1) - def post_processing(self, outputs, confidence_threshold): + def post_processing(self, outputs): """ Post-process the detection outputs. See YoloFishDetector and MegadetectorDetector for implementations. Args: outputs: Outputs from the detection model. - confidence_threshold: Confidence threshold for detection. Raises: NotImplementedError: This method must be implemented by a subclass. diff --git a/megadetector_detector.py b/megadetector_detector.py index f6f05e5..aa3fa4b 100644 --- a/megadetector_detector.py +++ b/megadetector_detector.py @@ -137,27 +137,38 @@ class MegadetectorDetector(BaseDetector): post-processing of the model's outputs. """ - def __init__(self, logger): + def __init__(self, logger, min_duration, buffer, confidence): """ Initialize the MegadetectorDetector class. Args: logger (Logger): Logger object for logging. + min_duration (float): The minimum duration of a generated clip (defaults to 15.0) + buffer (float): An optional number of seconds to add before/after a clip (defaults to 5.0) + confidence (float): The confidence level to use (defaults to 0.40) + """ logger.info( "Initializing Megadetector Model and Optimizing (this will take a minute...)" ) + # Use some defaults if any of these aren't already set + min_duration = 15.0 if min_duration is None else min_duration + buffer = 5.0 if buffer is None else buffer + confidence = 0.40 if confidence is None else confidence providers = get_available_providers() super().__init__( logger, megadetector_model_path, megadetector_image_width, megadetector_image_height, + min_duration, + buffer, + confidence, "Animal", providers, ) - def post_processing(self, outputs, confidence_threshold): + def post_processing(self, outputs): """ Perform post-processing on the model's outputs. This includes non-max suppression, filtering based on confidence threshold, and conversion of @@ -165,7 +176,6 @@ def post_processing(self, outputs, confidence_threshold): Args: outputs (numpy array): Outputs from the model. - confidence_threshold (float): Confidence threshold for filtering. Returns: list: Post-processed predictions. @@ -173,7 +183,7 @@ def post_processing(self, outputs, confidence_threshold): preds = [] for p in outputs[0]: # TODO: what to use for IOU_THRESHOLD? Defaults to 0.45 in run-onnx.py - p = non_max_suppression(p, confidence_threshold, 0.45) + p = non_max_suppression(p, self.confidence, 0.45) # Filter out predictions for animals (class=0) p = [pred for pred in p if pred[5] == 0] # If there are predictions, convert absolute coordinates to relative diff --git a/yolo_fish_detector.py b/yolo_fish_detector.py index ab09b80..ef91075 100644 --- a/yolo_fish_detector.py +++ b/yolo_fish_detector.py @@ -75,30 +75,39 @@ class YoloFishDetector(BaseDetector): post-processing of the model's outputs. """ - def __init__(self, logger): + def __init__(self, logger, min_duration, buffer, confidence): """ Initialize the YoloFishDetector class. Args: logger (Logger): Logger object for logging. + min_duration (float): The minimum duration of a generated clip (defaults to 3.0) + buffer (float): An optional number of seconds to add before/after a clip (defaults to 1.0) + confidence (float): The confidence level to use (defaults to 0.45) """ logger.info("Initializing YOLO-Fish Model") + # Use some defaults if any of these aren't already set + min_duration = 3.0 if min_duration is None else min_duration + buffer = 1.0 if buffer is None else buffer + confidence = 0.45 if confidence is None else confidence super().__init__( logger, yolo_fish_model_path, yolo_fish_image_width, yolo_fish_image_height, + min_duration, + buffer, + confidence, "Fish", ) - def post_processing(self, outputs, confidence_threshold): + def post_processing(self, outputs): """ Perform post-processing on the model outputs. This includes non-max suppression, and filtering based on confidence threshold. Args: outputs (list): List of model outputs. - confidence_threshold (float): Confidence threshold for filtering detections. Returns: list: List of bounding boxes for detected objects. @@ -131,7 +140,7 @@ def post_processing(self, outputs, confidence_threshold): nms_threshold = 0.6 bboxes_batch = [] for i in range(box_array.shape[0]): - argwhere = box_array[i, :, 4] > confidence_threshold + argwhere = box_array[i, :, 4] > self.confidence filtered_boxes = box_array[i, argwhere, :] bboxes = []