108 lines
4.1 KiB
Python
108 lines
4.1 KiB
Python
|
# -*- coding: utf-8 -*-
|
|||
|
"""
|
|||
|
-------------------------------------------------
|
|||
|
Project Name: unet
|
|||
|
File Name: test.py
|
|||
|
Author: chenming
|
|||
|
Create Date: 2022/2/7
|
|||
|
Description:
|
|||
|
-------------------------------------------------
|
|||
|
"""
|
|||
|
|
|||
|
import os
|
|||
|
from tqdm import tqdm
|
|||
|
from utils.utils_metrics import compute_mIoU, show_results
|
|||
|
import glob
|
|||
|
import numpy as np
|
|||
|
import torch
|
|||
|
import os
|
|||
|
import cv2
|
|||
|
from model.unet_model import UNet
|
|||
|
from torchmetrics import Dice
|
|||
|
|
|||
|
def cal_miou(test_dir=r"C:\Users\USER\Desktop\segment\test\Testimages_new",
|
|||
|
pred_dir=r"C:\Users\USER\Desktop\segment\unet_42-master\results",
|
|||
|
gt_dir=r"C:\Users\USER\Desktop\segment\test\Testlabels_new"):
|
|||
|
# ---------------------------------------------------------------------------#
|
|||
|
# miou_mode 用於指定該文件運行時計算的內容
|
|||
|
# miou_mode 為 0 代表整個 miou 計算流程,包括獲得預測結果、計算 miou。
|
|||
|
# miou_mode 為 1 代表僅僅獲得預測結果。
|
|||
|
# miou_mode 為 2 代表僅僅計算 miou。
|
|||
|
# ---------------------------------------------------------------------------#
|
|||
|
miou_mode = 0
|
|||
|
# ------------------------------#
|
|||
|
# 分類個數 + 1,如 2 + 1
|
|||
|
# ------------------------------#
|
|||
|
num_classes = 2
|
|||
|
# --------------------------------------------#
|
|||
|
# 區分的種類,和 json_to_dataset 裡面的一樣
|
|||
|
# --------------------------------------------#
|
|||
|
name_classes = ["background", "hippocampus"]
|
|||
|
# name_classes = ["_background_","cat","dog"]
|
|||
|
# -------------------------------------------------------#
|
|||
|
# 指向 VOC 數據集所在的文件夾
|
|||
|
# 預設指向根目錄下的 VOC 數據集
|
|||
|
# -------------------------------------------------------#
|
|||
|
# 計算結果與 gt 的結果進行比對
|
|||
|
|
|||
|
# 載入模型
|
|||
|
|
|||
|
if miou_mode == 0 or miou_mode == 1:
|
|||
|
if not os.path.exists(pred_dir):
|
|||
|
os.makedirs(pred_dir)
|
|||
|
|
|||
|
print("載入模型中。")
|
|||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|||
|
# 載入網路,圖片為單通道,分類為 1。
|
|||
|
net = UNet(n_channels=1, n_classes=1)
|
|||
|
# 將網路拷貝至裝置中
|
|||
|
net.to(device=device)
|
|||
|
# 載入模型參數
|
|||
|
net.load_state_dict(torch.load('best_model_30_1.pth', map_location=device)) # todo
|
|||
|
# 測試模式
|
|||
|
net.eval()
|
|||
|
print("模型載入完成。")
|
|||
|
|
|||
|
img_names = os.listdir(test_dir)
|
|||
|
image_ids = [image_name.split(".")[0] for image_name in img_names]
|
|||
|
print("獲取預測結果中。")
|
|||
|
for image_id in tqdm(image_ids):
|
|||
|
image_path = os.path.join(test_dir, image_id + ".png")
|
|||
|
img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
|
|||
|
|
|||
|
origin_shape = img.shape
|
|||
|
img = cv2.resize(img, (512, 512))
|
|||
|
|
|||
|
# 轉換為批次大小為 1,通道數為 1,大小為 512x512 的陣列
|
|||
|
img = img.reshape(1, 1, img.shape[0], img.shape[1])
|
|||
|
|
|||
|
img_tensor = torch.from_numpy(img)
|
|||
|
|
|||
|
img_tensor = img_tensor.to(device=device, dtype=torch.float32)
|
|||
|
|
|||
|
pred = net(img_tensor)
|
|||
|
|
|||
|
pred = np.array(pred.data.cpu()[0])[0]
|
|||
|
pred[pred >= 0.5] = 255
|
|||
|
pred[pred < 0.5] = 0
|
|||
|
pred = cv2.resize(pred, (origin_shape[1], origin_shape[0]), interpolation=cv2.INTER_NEAREST)
|
|||
|
cv2.imwrite(os.path.join(pred_dir, image_id + ".png"), pred)
|
|||
|
|
|||
|
print("預測結果獲取完成。")
|
|||
|
|
|||
|
|
|||
|
if miou_mode == 0 or miou_mode == 2:
|
|||
|
print("計算 mIoU 中。")
|
|||
|
print(gt_dir)
|
|||
|
print(pred_dir)
|
|||
|
print(num_classes)
|
|||
|
print(name_classes)
|
|||
|
hist, IoUs, PA_Recall, Precision = compute_mIoU(gt_dir, pred_dir, image_ids, num_classes,
|
|||
|
name_classes) # 執行計算 mIoU 的函數
|
|||
|
print("mIoU 計算完成。")
|
|||
|
miou_out_path = "results/"
|
|||
|
show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes)
|
|||
|
|
|||
|
if __name__ == '__main__':
|
|||
|
cal_miou()
|