Skip to content

Commit

Permalink
feat: 将 app.py 和 preprocess 中识猫切割模型更换为 yolo11
Browse files Browse the repository at this point in the history
  • Loading branch information
Dusker233 committed Jan 20, 2025
1 parent 1c5a78d commit 6be474d
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 20 deletions.
29 changes: 18 additions & 11 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import time
from base64 import b64encode
from hashlib import sha256
from ultralytics import YOLO

load_dotenv("./env", override=True)

Expand All @@ -32,7 +33,8 @@
"export"
), "*** export directory not found! you should export the training checkpoint to ONNX model."

crop_model = torch.hub.load("yolov5", "custom", "yolov5/yolov5m.onnx", source="local")
# crop_model = torch.hub.load("yolov5", "custom", "yolov5/yolov5m.onnx", source="local")
crop_model = YOLO("yolo11m.pt")

with open("export/cat.json", "r") as fp:
cat_ids = json.load(fp)
Expand Down Expand Up @@ -82,9 +84,14 @@ def recognize_cat_photo():

src_img = Image.open(photo).convert("RGB")
# 使用 YOLOv5 进行目标检测,结果为[{xmin, ymin, xmax, ymax, confidence, class, name}]格式
results = crop_model(src_img).pandas().xyxy[0].to_dict("records")
results = crop_model(src_img)
# 过滤非cat目标
cat_results = list(filter(lambda target: target["name"] == "cat", results))
cat_results = []
for result in results:
for box in result.boxes:
# print(result.names[box.cls.tolist()[0]], box.xyxy.tolist())
if result.names[box.cls.tolist()[0]] == "cat":
cat_results.append(box.xyxy.tolist())

if len(cat_results) >= 1:
cat_idx = (
Expand All @@ -97,10 +104,10 @@ def recognize_cat_photo():
# 裁剪出(指定的)cat
cat_result = cat_results[cat_idx]
crop_box = (
cat_result["xmin"],
cat_result["ymin"],
cat_result["xmax"],
cat_result["ymax"],
cat_result[0][0],
cat_result[0][1],
cat_result[0][2],
cat_result[0][3],
)
# 裁剪后直接resize到正方形
src_img = src_img.crop(crop_box).resize((IMG_SIZE, IMG_SIZE))
Expand Down Expand Up @@ -143,10 +150,10 @@ def recognize_cat_photo():
{
"catBoxes": [
{
"xmin": item["xmin"],
"ymin": item["ymin"],
"xmax": item["xmax"],
"ymax": item["ymax"],
"xmin": item[0][0],
"ymin": item[0][1],
"xmax": item[0][2],
"ymax": item[0][3],
}
for item in cat_results
][:CAT_BOX_MAX_RET_NUM],
Expand Down
25 changes: 17 additions & 8 deletions data_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@

SRC = "data/photos"
DEST = "data/crop_photos"
# SRC = "test/photos"
# DEST = "test/crop_photos"

if __name__ == "__main__":
print("loading YOLO11 model...")
# model = YOLO("yolo11m-cls.pt")
model = torch.hub.load("yolov5", "custom", "yolov5/yolov5m.pt", source="local")
model = YOLO("yolo11m.pt")
# model = torch.hub.load("yolov5", "custom", "yolov5/yolov5m.pt", source="local")

num_photos = 0
num_skipped_photos = 0
Expand All @@ -36,24 +38,31 @@
dest_file_path = os.path.join(dest_path, file_name)
# 使用 YOLOv5 进行目标检测,结果为[{xmin, ymin, xmax, ymax, confidence, class, name}]格式
try:
results = model(src_file_path).pandas().xyxy[0].to_dict("records")
# results = model(src_file_path).pandas().xyxy[0].to_dict("records")
results = model(src_file_path)
# print(results[0].boxes.xyxy.tolist(), results[0].names[results[0].boxes.cls.tolist()[0]])
except OSError as err:
# 发现有的图片有问题,会导致 PIL 抛出 OSError: image file is truncated
num_skipped_photos += 1
continue
# 过滤非cat目标
cat_results = list(filter(lambda target: target["name"] == "cat", results))
cat_results = []
for result in results:
for box in result.boxes:
print(result.names[box.cls.tolist()[0]], box.xyxy.tolist())
if result.names[box.cls.tolist()[0]] == "cat":
cat_results.append(box.xyxy.tolist())
# 跳过图片内检测不到cat或有多个cat的图片
if len(cat_results) == 0:
num_skipped_photos += 1
continue
# 裁剪出cat
cat_result = cat_results[0]
crop_box = (
cat_result["xmin"],
cat_result["ymin"],
cat_result["xmax"],
cat_result["ymax"],
cat_result[0][0],
cat_result[0][1],
cat_result[0][2],
cat_result[0][3],
)
Image.open(src_file_path).convert("RGB").crop(crop_box).save(
dest_file_path, format="JPEG"
Expand Down
7 changes: 6 additions & 1 deletion process.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,9 @@ python3 data_preprocess.py
python3 data_split.py
python3 data_split.py --name fallback --source data/photos --size 512
python3 train.py
python3 train.py --data data/dataset-fallback --name fallback --size 512
python3 train.py --data data/dataset-fallback --name fallback --size 512
rm -rf data/crop_photos
rm -rf data/dataset-fallback
rm -rf data/dataset-cat
# 遍历 data/photos 下所有文件夹,删除以 copy 开头的文件
find data/photos -type f -name "copy*" -exec rm -f {} \;

0 comments on commit 6be474d

Please sign in to comment.