Skip to content

Commit

Permalink
feat: 修改分割数据集逻辑,随机复制数据到 filter 数量
Browse files Browse the repository at this point in the history
  • Loading branch information
Dusker233 committed Jan 14, 2025
1 parent 43878cc commit d6e00d7
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 7 deletions.
5 changes: 0 additions & 5 deletions .gitconfig

This file was deleted.

4 changes: 2 additions & 2 deletions data_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

if __name__ == "__main__":
print("loading YOLO11 model...")
model = YOLO("yolo11m-cls.pt")
# model = torch.hub.load("yolov5", "custom", "yolov5/yolov5m.pt", source="local")
# model = YOLO("yolo11m-cls.pt")
model = torch.hub.load("yolov5", "custom", "yolov5/yolov5m.pt", source="local")

num_photos = 0
num_skipped_photos = 0
Expand Down
13 changes: 13 additions & 0 deletions data_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,19 @@ def resize_image(image_path, output_path, size):
if os.path.isfile(os.path.join(category_path, f))
]

# 如果图片数量不足 args.filter,则进行补充
if len(images) < args.filter:
print(f"Category '{category}' has less than {args.filter} images. Augmenting...")
while len(images) < args.filter:
# 随机选择一张图片进行复制
random_image = random.choice(images)
new_image_name = f"copy_{len(images)}_{random_image}"
shutil.copy(
os.path.join(category_path, random_image),
os.path.join(category_path, new_image_name),
)
images.append(new_image_name)

# 检查图片数量是否至少为args.filter张
if len(images) < args.filter:
print(f"Skipping category '{category}' with less than {args.filter} images.")
Expand Down
Binary file added yolo11m.pt
Binary file not shown.

0 comments on commit d6e00d7

Please sign in to comment.