128 lines
5.2 KiB
Python
128 lines
5.2 KiB
Python
|
import cv2
|
||
|
from ultralytics import YOLO
|
||
|
import numpy as np
|
||
|
from PyQt5.QtWidgets import QApplication, QFileDialog, QListView, QTreeView, QAbstractItemView
|
||
|
import os
|
||
|
import shutil
|
||
|
from collections import defaultdict
|
||
|
from datetime import datetime
|
||
|
|
||
|
|
||
|
class yolo_class:
    """Thin wrapper around an Ultralytics ``YOLO`` model loaded from a weights file."""

    def __init__(self, yolo_path):
        """Load the model whose weights live at *yolo_path*.

        Args:
            yolo_path: Path to a weights file accepted by ``ultralytics.YOLO``
                (a custom-trained model in this script).
        """
        # Keep the source path around for reference, then build the model from it.
        self.yolo_path = yolo_path
        self.model = YOLO(self.yolo_path)

    def YoloDetect(self, imgs):
        """Run the model on *imgs* and hand back its result list unchanged.

        Args:
            imgs: Anything the underlying model call accepts (here: a single
                image array read by OpenCV).

        Returns:
            Whatever ``self.model(imgs)`` returns.
        """
        return self.model(imgs)
|
||
|
|
||
|
|
||
|
def get_class_name(index):
|
||
|
# 更新類型名稱對應表,包含 41 個類別
|
||
|
class_names = {
|
||
|
0: "HexTamperproof_1", 1: "HexTamperproof_2", 2: "HexTamperproof_3", 3: "HexTamperproof_4",
|
||
|
4: "Hex_1", 5: "Hex_2", 6: "Hex_3", 7: "Opposite_1", 8: "Pentalope_1", 9: "Pentalope_2",
|
||
|
10: "Phillips_1", 11: "Phillips_2", 12: "Phillips_3", 13: "Phillips_4", 14: "Phillips_5",
|
||
|
15: "Slotted_1", 16: "Slotted_2", 17: "Slotted_3", 18: "Slotted_4", 19: "Slotted_5",
|
||
|
20: "Slotted_6", 21: "Spanner_1", 22: "Spanner_2", 23: "Square_1", 24: "Square_2",
|
||
|
25: "Standoff_1", 26: "TORXTamperproof_1", 27: "TORXTamperproof_2", 28: "TORXTamperproof_3",
|
||
|
29: "TORXTamperproof_4", 30: "TORXTamperproof_5", 31: "TORXTamperproof_6", 32: "TORXTamperproof_7",
|
||
|
33: "TORX_1", 34: "TORX_2", 35: "TriWing_1", 36: "TriWing_2", 37: "TriWing_3",
|
||
|
38: "Triangle_1", 39: "Triangle_2", 40: "Vacancy_1"
|
||
|
}
|
||
|
return class_names.get(index, f"Unknown_{index}")
|
||
|
|
||
|
|
||
|
def _read_image(image_path):
    """Load *image_path* as an OpenCV image, or return ``None`` if it cannot be decoded.

    ``cv2.imread`` signals failure by returning ``None`` instead of raising, and
    it commonly fails on Windows for paths containing non-ASCII characters; the
    ``np.fromfile`` + ``cv2.imdecode`` fallback reads the bytes through Python's
    own file I/O, which does handle such paths.
    """
    image = cv2.imread(image_path)
    if image is not None:
        return image
    try:
        data = np.fromfile(image_path, dtype=np.uint8)
    except OSError:
        return None
    if data.size == 0:
        # cv2.imdecode rejects an empty buffer, so bail out before calling it.
        return None
    return cv2.imdecode(data, cv2.IMREAD_COLOR)


def _unique_destination(directory, filename):
    """Return a path for *filename* inside *directory* that does not exist yet.

    On a name clash, ``_1``, ``_2``, ... is inserted before the extension so a
    previously moved image is never overwritten (``shutil.move`` may overwrite
    an existing destination file) and the move does not fail on Windows.
    """
    candidate = os.path.join(directory, filename)
    stem, ext = os.path.splitext(filename)
    counter = 1
    while os.path.exists(candidate):
        candidate = os.path.join(directory, f"{stem}_{counter}{ext}")
        counter += 1
    return candidate


def process_and_move_images(yolo, input_dir, error_base_dir, confidence_threshold=0.4):
    """Classify every image in *input_dir* and move the likely-mislabelled ones.

    The expected class is the name of *input_dir* itself. An image whose top-1
    prediction differs from that name, with confidence strictly greater than
    *confidence_threshold*, is moved to
    ``<error_base_dir>/<predicted_class>/<expected_class>/``.

    Args:
        yolo: Object exposing ``YoloDetect(image)`` whose first result has
            ``probs.top1`` and ``probs.top1conf`` (e.g. ``yolo_class``).
        input_dir: Folder of images; its basename is the expected class name.
        error_base_dir: Root folder that receives moved images.
        confidence_threshold: Exclusive lower bound on top-1 confidence needed
            before an image is moved. Defaults to 0.4, the previously
            hard-coded value.

    Returns:
        ``defaultdict(int)`` mapping predicted class name -> number of images moved.
    """
    moved_count = defaultdict(int)
    # normpath drops a trailing separator; otherwise basename() would return "".
    expected_class = os.path.basename(os.path.normpath(input_dir))
    image_exts = ('.png', '.jpg', '.jpeg', '.bmp')

    # listdir() order is arbitrary; sorting makes runs and logs reproducible.
    for filename in sorted(os.listdir(input_dir)):
        if not filename.lower().endswith(image_exts):
            continue

        image_path = os.path.join(input_dir, filename)
        image = _read_image(image_path)
        if image is None:
            # Feeding None to the model would fail; report the file and move on.
            print(f"無法讀取圖片, 略過: {image_path}")
            print("---")
            continue

        results = yolo.YoloDetect(image)
        probs = results[0].probs
        result_Max_Index = probs.top1  # top-1 class index
        confidence = probs.top1conf.item()  # confidence of the top-1 class

        predicted_class = get_class_name(result_Max_Index)
        print(f"原始圖片: {image_path}")
        print(f"推論結果: {predicted_class}, 信心值: {confidence}")

        # Move only when the prediction disagrees with the folder label and
        # the confidence exceeds the threshold.
        if predicted_class != expected_class and confidence > confidence_threshold:
            # Group by predicted class, then by the folder the image came from.
            error_class_dir = os.path.join(error_base_dir, predicted_class, expected_class)
            os.makedirs(error_class_dir, exist_ok=True)
            new_image_path = _unique_destination(error_class_dir, filename)
            shutil.move(image_path, new_image_path)
            print(f"移動圖片至: {new_image_path}")
            moved_count[predicted_class] += 1
        else:
            print("圖片未移動")
        print("---")

    total_moved = sum(moved_count.values())
    print(f"總共移動了 {total_moved} 張圖片")
    print("各類型移動數量:")
    for class_name, count in moved_count.items():
        print(f"{class_name}: {count} 張")

    return moved_count
|
||
|
|
||
|
|
||
|
def write_log(input_dir, moved_count):
    """Write a timestamped summary of moved images into *input_dir*.

    The file is named ``移動記錄_<YYYYmmdd_HHMMSS>.txt`` (UTF-8). It lists the
    run time, the source folder, the total number moved, and one line per
    entry of *moved_count* in its iteration order.

    Args:
        input_dir: Folder the log file is created in; also reported inside it.
        moved_count: Mapping of class name -> number of images moved.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_path = os.path.join(input_dir, f"移動記錄_{stamp}.txt")

    # Assemble the whole report first, then write it in one go.
    lines = [
        f"執行時間: {stamp}\n",
        f"原始資料夾: {input_dir}\n\n",
        "移動數量記錄:\n",
        f"總共移動: {sum(moved_count.values())} 張\n",
    ]
    lines.extend(f"{name}: {amount} 張\n" for name, amount in moved_count.items())

    with open(log_path, 'w', encoding='utf-8') as handle:
        handle.writelines(lines)

    print(f"記錄檔案已保存至: {log_path}")
|
||
|
|
||
|
|
||
|
def process_multiple_folders(yolo, input_dirs):
    """Run ``process_and_move_images`` and ``write_log`` on each folder in turn.

    Images moved out of ``<parent>/<folder>`` go under ``<parent>/Error``, so
    folders sharing a parent share one Error tree.

    Args:
        yolo: Model wrapper passed through to ``process_and_move_images``.
        input_dirs: Iterable of folder paths (e.g. ``QFileDialog.selectedFiles()``).
    """
    for raw_dir in input_dirs:
        # Strip trailing separators: dirname("a/b/") is "a/b", which would put
        # the Error folder inside the folder being scanned (and basename() would
        # yield an empty expected class).
        input_dir = os.path.normpath(raw_dir)
        if not os.path.isdir(input_dir):
            # Skip bad entries instead of aborting the whole batch in listdir().
            print(f"略過非資料夾路徑: {raw_dir}")
            continue

        print(f"處理資料夾: {input_dir}")
        error_base_dir = os.path.join(os.path.dirname(input_dir), "Error")
        moved_count = process_and_move_images(yolo, input_dir, error_base_dir)
        write_log(input_dir, moved_count)
        print(f"資料夾 {input_dir} 處理完成\n")
|
||
|
|
||
|
|
||
|
# NOTE(review): machine-specific weights path kept from the original script;
# pass main(model_path=...) to use a different model.
DEFAULT_MODEL_PATH = r"D:\ScrewdriverFile\train_0811_Milwaukee_41\weights\best.pt"


def main(model_path=DEFAULT_MODEL_PATH):
    """Let the user pick one or more folders and clean each one with the model.

    Args:
        model_path: Weights file passed to ``yolo_class``.
    """
    # A QApplication must exist before any widget (the dialog) is created;
    # keep the reference alive for the duration of main().
    app = QApplication([])  # noqa: F841
    yolo = yolo_class(model_path)

    # Select multiple input folders.
    dialog = QFileDialog()
    # QFileDialog.DirectoryOnly is obsolete in Qt 5; Directory + ShowDirsOnly
    # is the documented replacement.
    dialog.setFileMode(QFileDialog.Directory)
    dialog.setOption(QFileDialog.ShowDirsOnly, True)
    # Use Qt's own dialog so its child views can be looked up with findChild below.
    dialog.setOption(QFileDialog.DontUseNativeDialog, True)

    # Allow Ctrl/Shift multi-selection in both of the dialog's internal views.
    for view_type, object_name in ((QListView, 'listView'), (QTreeView, 'treeView')):
        view = dialog.findChild(view_type, object_name)
        if view:
            view.setSelectionMode(QAbstractItemView.ExtendedSelection)

    if not dialog.exec_():
        print("取消選擇資料夾")
        return

    input_dirs = dialog.selectedFiles()
    if not input_dirs:
        print("未選擇資料夾")
        return

    process_multiple_folders(yolo, input_dirs)
    print("所有選定的資料夾處理完成")


if __name__ == '__main__':
    main()
|