From 6be474d4ce678a9a600817df875ed2ae36af73bf Mon Sep 17 00:00:00 2001 From: Dusker233 <1187305740@qq.com> Date: Mon, 20 Jan 2025 17:32:41 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=B0=86=20app.py=20=E5=92=8C=20prepro?= =?UTF-8?q?cess=20=E4=B8=AD=E8=AF=86=E7=8C=AB=E5=88=87=E5=89=B2=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E6=9B=B4=E6=8D=A2=E4=B8=BA=20yolo11?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app.py | 29 ++++++++++++++++++----------- data_preprocess.py | 25 +++++++++++++++++-------- process.sh | 7 ++++++- 3 files changed, 41 insertions(+), 20 deletions(-) diff --git a/app.py b/app.py index 0683926..7abcb4c 100644 --- a/app.py +++ b/app.py @@ -12,6 +12,7 @@ import time from base64 import b64encode from hashlib import sha256 +from ultralytics import YOLO load_dotenv("./env", override=True) @@ -32,7 +33,8 @@ "export" ), "*** export directory not found! you should export the training checkpoint to ONNX model." -crop_model = torch.hub.load("yolov5", "custom", "yolov5/yolov5m.onnx", source="local") +# crop_model = torch.hub.load("yolov5", "custom", "yolov5/yolov5m.onnx", source="local") +crop_model = YOLO("yolo11m.pt") with open("export/cat.json", "r") as fp: cat_ids = json.load(fp) @@ -82,9 +84,14 @@ def recognize_cat_photo(): src_img = Image.open(photo).convert("RGB") # 使用 YOLOv5 进行目标检测,结果为[{xmin, ymin, xmax, ymax, confidence, class, name}]格式 - results = crop_model(src_img).pandas().xyxy[0].to_dict("records") + results = crop_model(src_img) # 过滤非cat目标 - cat_results = list(filter(lambda target: target["name"] == "cat", results)) + cat_results = [] + for result in results: + for box in result.boxes: + # print(result.names[box.cls.tolist()[0]], box.xyxy.tolist()) + if result.names[box.cls.tolist()[0]] == "cat": + cat_results.append(box.xyxy.tolist()) if len(cat_results) >= 1: cat_idx = ( @@ -97,10 +104,10 @@ def recognize_cat_photo(): # 裁剪出(指定的)cat cat_result = cat_results[cat_idx] crop_box = ( - cat_result["xmin"], - cat_result["ymin"], - cat_result["xmax"], - cat_result["ymax"], + cat_result[0][0], + cat_result[0][1], + cat_result[0][2], + cat_result[0][3], ) # 裁剪后直接resize到正方形 src_img = src_img.crop(crop_box).resize((IMG_SIZE, IMG_SIZE)) @@ -143,10 +150,10 @@ def recognize_cat_photo(): { "catBoxes": [ { - "xmin": item["xmin"], - "ymin": item["ymin"], - "xmax": item["xmax"], - "ymax": item["ymax"], + "xmin": item[0][0], + "ymin": item[0][1], + "xmax": item[0][2], + "ymax": item[0][3], } for item in cat_results ][:CAT_BOX_MAX_RET_NUM], diff --git a/data_preprocess.py b/data_preprocess.py index b24fe40..f64e8e6 100644 --- a/data_preprocess.py +++ b/data_preprocess.py @@ -7,11 +7,13 @@ SRC = "data/photos" DEST = "data/crop_photos" +# SRC = "test/photos" +# DEST = "test/crop_photos" if __name__ == "__main__": print("loading YOLO11 model...") - # model = YOLO("yolo11m-cls.pt") - model = torch.hub.load("yolov5", "custom", "yolov5/yolov5m.pt", source="local") + model = YOLO("yolo11m.pt") + # model = torch.hub.load("yolov5", "custom", "yolov5/yolov5m.pt", source="local") num_photos = 0 num_skipped_photos = 0 @@ -36,13 +38,20 @@ dest_file_path = os.path.join(dest_path, file_name) # 使用 YOLOv5 进行目标检测,结果为[{xmin, ymin, xmax, ymax, confidence, class, name}]格式 try: - results = model(src_file_path).pandas().xyxy[0].to_dict("records") + # results = model(src_file_path).pandas().xyxy[0].to_dict("records") + results = model(src_file_path) + # print(results[0].boxes.xyxy.tolist(), results[0].names[results[0].boxes.cls.tolist()[0]]) except OSError as err: # 发现有的图片有问题,会导致 PIL 抛出 OSError: image file is truncated num_skipped_photos += 1 continue # 过滤非cat目标 - cat_results = list(filter(lambda target: target["name"] == "cat", results)) + cat_results = [] + for result in results: + for box in result.boxes: + print(result.names[box.cls.tolist()[0]], box.xyxy.tolist()) + if result.names[box.cls.tolist()[0]] == "cat": + cat_results.append(box.xyxy.tolist()) # 跳过图片内检测不到cat或有多个cat的图片 if len(cat_results) == 0: num_skipped_photos += 1 @@ -50,10 +59,10 @@ # 裁剪出cat cat_result = cat_results[0] crop_box = ( - cat_result["xmin"], - cat_result["ymin"], - cat_result["xmax"], - cat_result["ymax"], + cat_result[0][0], + cat_result[0][1], + cat_result[0][2], + cat_result[0][3], ) Image.open(src_file_path).convert("RGB").crop(crop_box).save( dest_file_path, format="JPEG" diff --git a/process.sh b/process.sh index d9359e0..7023ad7 100644 --- a/process.sh +++ b/process.sh @@ -2,4 +2,9 @@ python3 data_preprocess.py python3 data_split.py python3 data_split.py --name fallback --source data/photos --size 512 python3 train.py -python3 train.py --data data/dataset-fallback --name fallback --size 512 \ No newline at end of file +python3 train.py --data data/dataset-fallback --name fallback --size 512 +rm -rf data/crop_photos +rm -rf data/dataset-fallback +rm -rf data/dataset-cat +# 遍历 data/photos 下所有文件夹,删除以 copy 开头的文件 +find data/photos -type f -name "copy*" -exec rm -f {} \; \ No newline at end of file