# -*- coding: utf-8 -*- """ ------------------------------------------------- Project Name: unet File Name: test.py Author: chenming Create Date: 2022/2/7 Description: ------------------------------------------------- """ import os from tqdm import tqdm from utils.utils_metrics import compute_mIoU, show_results import glob import numpy as np import torch import os import cv2 from model.unet_model import UNet from torchmetrics import Dice def cal_miou(test_dir=r"C:\Users\USER\Desktop\segment\test\Testimages_new", pred_dir=r"C:\Users\USER\Desktop\segment\unet_42-master\results", gt_dir=r"C:\Users\USER\Desktop\segment\test\Testlabels_new"): # ---------------------------------------------------------------------------# # miou_mode 用於指定該文件運行時計算的內容 # miou_mode 為 0 代表整個 miou 計算流程,包括獲得預測結果、計算 miou。 # miou_mode 為 1 代表僅僅獲得預測結果。 # miou_mode 為 2 代表僅僅計算 miou。 # ---------------------------------------------------------------------------# miou_mode = 0 # ------------------------------# # 分類個數 + 1,如 2 + 1 # ------------------------------# num_classes = 2 # --------------------------------------------# # 區分的種類,和 json_to_dataset 裡面的一樣 # --------------------------------------------# name_classes = ["background", "hippocampus"] # name_classes = ["_background_","cat","dog"] # -------------------------------------------------------# # 指向 VOC 數據集所在的文件夾 # 預設指向根目錄下的 VOC 數據集 # -------------------------------------------------------# # 計算結果與 gt 的結果進行比對 # 載入模型 if miou_mode == 0 or miou_mode == 1: if not os.path.exists(pred_dir): os.makedirs(pred_dir) print("載入模型中。") device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 載入網路,圖片為單通道,分類為 1。 net = UNet(n_channels=1, n_classes=1) # 將網路拷貝至裝置中 net.to(device=device) # 載入模型參數 net.load_state_dict(torch.load('best_model_30_1.pth', map_location=device)) # todo # 測試模式 net.eval() print("模型載入完成。") img_names = os.listdir(test_dir) image_ids = [image_name.split(".")[0] for image_name in img_names] print("獲取預測結果中。") for image_id in tqdm(image_ids): image_path = os.path.join(test_dir, image_id + ".png") img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE) origin_shape = img.shape img = cv2.resize(img, (512, 512)) # 轉換為批次大小為 1,通道數為 1,大小為 512x512 的陣列 img = img.reshape(1, 1, img.shape[0], img.shape[1]) img_tensor = torch.from_numpy(img) img_tensor = img_tensor.to(device=device, dtype=torch.float32) pred = net(img_tensor) pred = np.array(pred.data.cpu()[0])[0] pred[pred >= 0.5] = 255 pred[pred < 0.5] = 0 pred = cv2.resize(pred, (origin_shape[1], origin_shape[0]), interpolation=cv2.INTER_NEAREST) cv2.imwrite(os.path.join(pred_dir, image_id + ".png"), pred) print("預測結果獲取完成。") if miou_mode == 0 or miou_mode == 2: print("計算 mIoU 中。") print(gt_dir) print(pred_dir) print(num_classes) print(name_classes) hist, IoUs, PA_Recall, Precision = compute_mIoU(gt_dir, pred_dir, image_ids, num_classes, name_classes) # 執行計算 mIoU 的函數 print("mIoU 計算完成。") miou_out_path = "results/" show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes) if __name__ == '__main__': cal_miou()