diff --git a/.gitconfig b/.gitconfig deleted file mode 100644 index f125a93..0000000 --- a/.gitconfig +++ /dev/null @@ -1,5 +0,0 @@ -[http] - sslVerify = false - postBuffer = 1048576000 -[url "https://gitclone.com/"] - insteadOf = https:// diff --git a/data_preprocess.py b/data_preprocess.py index 38af09d..f7f74ad 100644 --- a/data_preprocess.py +++ b/data_preprocess.py @@ -10,8 +10,8 @@ if __name__ == "__main__": print("loading YOLO11 model...") - model = YOLO("yolo11m-cls.pt") - # model = torch.hub.load("yolov5", "custom", "yolov5/yolov5m.pt", source="local") + # model = YOLO("yolo11m-cls.pt") + model = torch.hub.load("yolov5", "custom", "yolov5/yolov5m.pt", source="local") num_photos = 0 num_skipped_photos = 0 diff --git a/data_split.py b/data_split.py index be1d07e..aff7ae1 100644 --- a/data_split.py +++ b/data_split.py @@ -65,6 +65,19 @@ def resize_image(image_path, output_path, size): if os.path.isfile(os.path.join(category_path, f)) ] + # 如果图片数量不足 args.filter,则进行补充 + if len(images) < args.filter: + print(f"Category '{category}' has less than {args.filter} images. Augmenting...") + while len(images) < args.filter: + # 随机选择一张图片进行复制 + random_image = random.choice(images) + new_image_name = f"copy_{len(images)}_{random_image}" + shutil.copy( + os.path.join(category_path, random_image), + os.path.join(category_path, new_image_name), + ) + images.append(new_image_name) + # 检查图片数量是否至少为args.filter张 if len(images) < args.filter: print(f"Skipping category '{category}' with less than {args.filter} images.") diff --git a/yolo11m.pt b/yolo11m.pt new file mode 100644 index 0000000..0559400 Binary files /dev/null and b/yolo11m.pt differ