diff --git a/data_preprocess.py b/data_preprocess.py index f7f74ad..b24fe40 100644 --- a/data_preprocess.py +++ b/data_preprocess.py @@ -44,7 +44,7 @@ # 过滤非cat目标 cat_results = list(filter(lambda target: target["name"] == "cat", results)) # 跳过图片内检测不到cat或有多个cat的图片 - if len(cat_results) != 1: + if len(cat_results) == 0: num_skipped_photos += 1 continue # 裁剪出cat diff --git a/data_split.py b/data_split.py index aff7ae1..f7e372a 100644 --- a/data_split.py +++ b/data_split.py @@ -3,6 +3,8 @@ from sklearn.model_selection import train_test_split import argparse import json +import random +import shutil parser = argparse.ArgumentParser(description="Cat Recognize Data Preprocessor") parser.add_argument( @@ -66,7 +68,7 @@ def resize_image(image_path, output_path, size): ] # 如果图片数量不足 args.filter,则进行补充 - if len(images) < args.filter: + if len(images) != 0 and len(images) < args.filter: print(f"Category '{category}' has less than {args.filter} images. Augmenting...") while len(images) < args.filter: # 随机选择一张图片进行复制 diff --git a/yolo11n.pt b/yolo11n.pt new file mode 100644 index 0000000..45b273b Binary files /dev/null and b/yolo11n.pt differ