108 lines
4.1 KiB
Python
108 lines
4.1 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
-------------------------------------------------
|
||
Project Name: unet
|
||
File Name: test.py
|
||
Author: chenming
|
||
Create Date: 2022/2/7
|
||
Description:
|
||
-------------------------------------------------
|
||
"""
|
||
|
||
import os
|
||
from tqdm import tqdm
|
||
from utils.utils_metrics import compute_mIoU, show_results
|
||
import glob
|
||
import numpy as np
|
||
import torch
|
||
import os
|
||
import cv2
|
||
from model.unet_model import UNet
|
||
from torchmetrics import Dice
|
||
|
||
def cal_miou(test_dir=r"C:\Users\USER\Desktop\segment\test\Testimages_new",
|
||
pred_dir=r"C:\Users\USER\Desktop\segment\unet_42-master\results",
|
||
gt_dir=r"C:\Users\USER\Desktop\segment\test\Testlabels_new"):
|
||
# ---------------------------------------------------------------------------#
|
||
# miou_mode 用於指定該文件運行時計算的內容
|
||
# miou_mode 為 0 代表整個 miou 計算流程,包括獲得預測結果、計算 miou。
|
||
# miou_mode 為 1 代表僅僅獲得預測結果。
|
||
# miou_mode 為 2 代表僅僅計算 miou。
|
||
# ---------------------------------------------------------------------------#
|
||
miou_mode = 0
|
||
# ------------------------------#
|
||
# 分類個數 + 1,如 2 + 1
|
||
# ------------------------------#
|
||
num_classes = 2
|
||
# --------------------------------------------#
|
||
# 區分的種類,和 json_to_dataset 裡面的一樣
|
||
# --------------------------------------------#
|
||
name_classes = ["background", "hippocampus"]
|
||
# name_classes = ["_background_","cat","dog"]
|
||
# -------------------------------------------------------#
|
||
# 指向 VOC 數據集所在的文件夾
|
||
# 預設指向根目錄下的 VOC 數據集
|
||
# -------------------------------------------------------#
|
||
# 計算結果與 gt 的結果進行比對
|
||
|
||
# 載入模型
|
||
|
||
if miou_mode == 0 or miou_mode == 1:
|
||
if not os.path.exists(pred_dir):
|
||
os.makedirs(pred_dir)
|
||
|
||
print("載入模型中。")
|
||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||
# 載入網路,圖片為單通道,分類為 1。
|
||
net = UNet(n_channels=1, n_classes=1)
|
||
# 將網路拷貝至裝置中
|
||
net.to(device=device)
|
||
# 載入模型參數
|
||
net.load_state_dict(torch.load('best_model_30_1.pth', map_location=device)) # todo
|
||
# 測試模式
|
||
net.eval()
|
||
print("模型載入完成。")
|
||
|
||
img_names = os.listdir(test_dir)
|
||
image_ids = [image_name.split(".")[0] for image_name in img_names]
|
||
print("獲取預測結果中。")
|
||
for image_id in tqdm(image_ids):
|
||
image_path = os.path.join(test_dir, image_id + ".png")
|
||
img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
|
||
|
||
origin_shape = img.shape
|
||
img = cv2.resize(img, (512, 512))
|
||
|
||
# 轉換為批次大小為 1,通道數為 1,大小為 512x512 的陣列
|
||
img = img.reshape(1, 1, img.shape[0], img.shape[1])
|
||
|
||
img_tensor = torch.from_numpy(img)
|
||
|
||
img_tensor = img_tensor.to(device=device, dtype=torch.float32)
|
||
|
||
pred = net(img_tensor)
|
||
|
||
pred = np.array(pred.data.cpu()[0])[0]
|
||
pred[pred >= 0.5] = 255
|
||
pred[pred < 0.5] = 0
|
||
pred = cv2.resize(pred, (origin_shape[1], origin_shape[0]), interpolation=cv2.INTER_NEAREST)
|
||
cv2.imwrite(os.path.join(pred_dir, image_id + ".png"), pred)
|
||
|
||
print("預測結果獲取完成。")
|
||
|
||
|
||
if miou_mode == 0 or miou_mode == 2:
|
||
print("計算 mIoU 中。")
|
||
print(gt_dir)
|
||
print(pred_dir)
|
||
print(num_classes)
|
||
print(name_classes)
|
||
hist, IoUs, PA_Recall, Precision = compute_mIoU(gt_dir, pred_dir, image_ids, num_classes,
|
||
name_classes) # 執行計算 mIoU 的函數
|
||
print("mIoU 計算完成。")
|
||
miou_out_path = "results/"
|
||
show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes)
|
||
|
||
if __name__ == '__main__':
|
||
cal_miou()
|