screwdriver/test/move_image.py

110 lines
4.4 KiB
Python
Raw Permalink Normal View History

2025-02-06 16:10:58 +08:00
import cv2
from ultralytics import YOLO
import numpy as np
from PyQt5.QtWidgets import QApplication, QFileDialog, QListView, QTreeView, QAbstractItemView
import os
import shutil
from collections import defaultdict
class yolo_class:
    """Thin wrapper around an Ultralytics ``YOLO`` model loaded from a weights file."""

    def __init__(self, yolo_path):
        """Load the model stored at *yolo_path*.

        Args:
            yolo_path: Path to the model weights (the script passes a ``.pt`` file).
        """
        # Remember where the weights came from, next to the loaded model itself.
        self.yolo_path = yolo_path
        self.model = YOLO(self.yolo_path)

    def YoloDetect(self, imgs):
        """Run the wrapped model on *imgs* and hand back its output unchanged."""
        return self.model(imgs)
# Index -> label table for the 41-class model (indices 0-40). It is built once at
# import time rather than on every lookup, because get_class_name runs once per image.
# NOTE(review): presumably this must match the class order the weights were
# trained with (model.names) — verify whenever the model is retrained.
CLASS_NAMES = {
    0: "HexTamperproof_1", 1: "HexTamperproof_2", 2: "HexTamperproof_3", 3: "HexTamperproof_4",
    4: "Hex_1", 5: "Hex_2", 6: "Hex_3", 7: "Opposite_1", 8: "Pentalope_1", 9: "Pentalope_2",
    10: "Phillips_1", 11: "Phillips_2", 12: "Phillips_3", 13: "Phillips_4", 14: "Phillips_5",
    15: "Slotted_1", 16: "Slotted_2", 17: "Slotted_3", 18: "Slotted_4", 19: "Slotted_5",
    20: "Slotted_6", 21: "Spanner_1", 22: "Spanner_2", 23: "Square_1", 24: "Square_2",
    25: "Standoff_1", 26: "TORXTamperproof_1", 27: "TORXTamperproof_2", 28: "TORXTamperproof_3",
    29: "TORXTamperproof_4", 30: "TORXTamperproof_5", 31: "TORXTamperproof_6", 32: "TORXTamperproof_7",
    33: "TORX_1", 34: "TORX_2", 35: "TriWing_1", 36: "TriWing_2", 37: "TriWing_3",
    38: "Triangle_1", 39: "Triangle_2", 40: "Vacancy_1",
}


def get_class_name(index):
    """Map a classifier output index to its class-name string.

    Args:
        index: Class index, as produced by ``probs.top1``.

    Returns:
        The matching entry of ``CLASS_NAMES``, or ``"Unknown_<index>"`` when the
        index is not in the table.
    """
    return CLASS_NAMES.get(index, f"Unknown_{index}")
def _unique_destination(directory, filename):
    """Return a path for *filename* inside *directory* that does not exist yet.

    The plain ``directory/filename`` is returned when it is free. Otherwise a
    numeric suffix is inserted before the extension (``a.jpg`` -> ``a_1.jpg``,
    ``a_2.jpg``, ...), so a previously moved file is never clobbered.
    """
    candidate = os.path.join(directory, filename)
    stem, ext = os.path.splitext(filename)
    suffix = 1
    while os.path.exists(candidate):
        candidate = os.path.join(directory, f"{stem}_{suffix}{ext}")
        suffix += 1
    return candidate


def process_and_move_images(yolo, input_dir, error_base_dir, conf_threshold=0.4):
    """Classify every image in *input_dir* and move confident mismatches out.

    The folder's own name is taken as the expected class. Each image whose
    predicted class differs from it, with top-1 confidence above
    *conf_threshold*, is moved into ``error_base_dir/<predicted_class>/``.
    Per-image progress and a final per-class summary are printed.

    Args:
        yolo: Object exposing ``YoloDetect(image)``; the first result it returns
            must carry classification ``probs`` (``top1`` / ``top1conf``).
        input_dir: Folder of images; its basename is the expected class name.
        error_base_dir: Root under which per-predicted-class folders are created.
        conf_threshold: Confidence that must be strictly exceeded before an image
            is moved. Defaults to 0.4.

    Returns:
        defaultdict(int) mapping predicted class name -> number of images moved.
    """
    moved_count = defaultdict(int)
    # normpath drops a trailing separator; without it basename() returns "" and
    # every confidently classified image would be treated as a mismatch.
    expected_class = os.path.basename(os.path.normpath(input_dir))
    # Sorted so repeated runs process (and log) files in a stable order.
    for filename in sorted(os.listdir(input_dir)):
        if not filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')):
            continue
        image_path = os.path.join(input_dir, filename)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals unreadable/unsupported files by returning None
            # rather than raising. Never hand None to the model: whatever it
            # predicted would not describe this file.
            print(f"無法讀取圖片,略過: {image_path}")
            print("---")
            continue
        results = yolo.YoloDetect(image)
        result_max_index = results[0].probs.top1  # top-1 class index
        confidence = results[0].probs.top1conf.item()  # top-1 confidence as a Python number
        predicted_class = get_class_name(result_max_index)
        print(f"原始圖片: {image_path}")
        print(f"推論結果: {predicted_class}, 信心值: {confidence}")
        # Move only images predicted as a different class with enough confidence.
        if predicted_class != expected_class and confidence > conf_threshold:
            # Only the predicted-class folder is created under error_base_dir.
            error_class_dir = os.path.join(error_base_dir, predicted_class)
            os.makedirs(error_class_dir, exist_ok=True)
            # error_base_dir may be shared by several input folders, so equal
            # filenames can collide. shutil.move would then overwrite (POSIX)
            # or raise (Windows); choose a free name instead.
            new_image_path = _unique_destination(error_class_dir, filename)
            shutil.move(image_path, new_image_path)
            print(f"移動圖片至: {new_image_path}")
            moved_count[predicted_class] += 1
        else:
            print("圖片未移動")
        print("---")
    total_moved = sum(moved_count.values())
    print(f"總共移動了 {total_moved} 張圖片")
    print("各類型移動數量:")
    for class_name, count in moved_count.items():
        print(f"{class_name}: {count}")
    return moved_count
def process_multiple_folders(yolo, input_dirs):
    """Run ``process_and_move_images`` over each folder in *input_dirs*.

    Misclassified images from each folder go to an ``Error`` directory located
    in that folder's parent directory. Sibling input folders therefore share
    one ``Error`` tree.

    Args:
        yolo: Wrapper whose ``YoloDetect`` method performs the inference.
        input_dirs: Iterable of folder paths; each folder's name is its
            expected class.

    Returns:
        dict mapping each input folder path (as given) to the per-class move
        counts returned for it.
    """
    summary = {}
    for input_dir in input_dirs:
        print(f"處理資料夾: {input_dir}")
        # normpath drops a trailing separator; without it dirname() would return
        # the folder itself and "Error" would be created inside the input folder.
        parent_dir = os.path.dirname(os.path.normpath(input_dir))
        error_base_dir = os.path.join(parent_dir, "Error")
        summary[input_dir] = process_and_move_images(yolo, input_dir, error_base_dir)
        print(f"資料夾 {input_dir} 處理完成\n")
    return summary
if __name__ == '__main__':
    import sys

    # Default model weights; an optional first command-line argument overrides them.
    DEFAULT_WEIGHTS = r"D:\ScrewdriverFile\train_0811_Milwaukee_41\weights\best.pt"
    weights_path = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_WEIGHTS

    # A QApplication must exist before the dialog widget is created.
    app = QApplication([])

    # Let the user pick several input folders at once. findChild can only reach
    # the internal list/tree views of Qt's own (non-native) dialog, so native
    # dialogs are disabled and both views are switched to multi-selection.
    dialog = QFileDialog()
    # QFileDialog.DirectoryOnly is obsolete in Qt 5; Directory + ShowDirsOnly replaces it.
    dialog.setFileMode(QFileDialog.Directory)
    dialog.setOption(QFileDialog.ShowDirsOnly, True)
    dialog.setOption(QFileDialog.DontUseNativeDialog, True)
    for view_type, object_name in ((QListView, 'listView'), (QTreeView, 'treeView')):
        view = dialog.findChild(view_type, object_name)
        if view:
            view.setSelectionMode(QAbstractItemView.ExtendedSelection)

    if dialog.exec_():
        input_dirs = dialog.selectedFiles()
        if input_dirs:
            # Load the (heavy) model only once there is actually work to do.
            yolo = yolo_class(weights_path)
            process_multiple_folders(yolo, input_dirs)
            print("所有選定的資料夾處理完成")
        else:
            print("未選擇資料夾")
    else:
        print("取消選擇資料夾")