image_segmentation/sample code/Image Segmentation/unet_42-master/test.py

108 lines
4.1 KiB
Python
Raw Normal View History

2025-01-20 16:21:14 +08:00
# -*- coding: utf-8 -*-
"""
-------------------------------------------------
Project Name: unet
File Name: test.py
Author: chenming
Create Date: 2022/2/7
Description
-------------------------------------------------
"""
import os
from tqdm import tqdm
from utils.utils_metrics import compute_mIoU, show_results
import glob
import numpy as np
import torch
import os
import cv2
from model.unet_model import UNet
from torchmetrics import Dice
def cal_miou(test_dir=r"C:\Users\USER\Desktop\segment\test\Testimages_new",
pred_dir=r"C:\Users\USER\Desktop\segment\unet_42-master\results",
gt_dir=r"C:\Users\USER\Desktop\segment\test\Testlabels_new"):
# ---------------------------------------------------------------------------#
# miou_mode 用於指定該文件運行時計算的內容
# miou_mode 為 0 代表整個 miou 計算流程,包括獲得預測結果、計算 miou。
# miou_mode 為 1 代表僅僅獲得預測結果。
# miou_mode 為 2 代表僅僅計算 miou。
# ---------------------------------------------------------------------------#
miou_mode = 0
# ------------------------------#
# 分類個數 + 1如 2 + 1
# ------------------------------#
num_classes = 2
# --------------------------------------------#
# 區分的種類,和 json_to_dataset 裡面的一樣
# --------------------------------------------#
name_classes = ["background", "hippocampus"]
# name_classes = ["_background_","cat","dog"]
# -------------------------------------------------------#
# 指向 VOC 數據集所在的文件夾
# 預設指向根目錄下的 VOC 數據集
# -------------------------------------------------------#
# 計算結果與 gt 的結果進行比對
# 載入模型
if miou_mode == 0 or miou_mode == 1:
if not os.path.exists(pred_dir):
os.makedirs(pred_dir)
print("載入模型中。")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 載入網路,圖片為單通道,分類為 1。
net = UNet(n_channels=1, n_classes=1)
# 將網路拷貝至裝置中
net.to(device=device)
# 載入模型參數
net.load_state_dict(torch.load('best_model_30_1.pth', map_location=device)) # todo
# 測試模式
net.eval()
print("模型載入完成。")
img_names = os.listdir(test_dir)
image_ids = [image_name.split(".")[0] for image_name in img_names]
print("獲取預測結果中。")
for image_id in tqdm(image_ids):
image_path = os.path.join(test_dir, image_id + ".png")
img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
origin_shape = img.shape
img = cv2.resize(img, (512, 512))
# 轉換為批次大小為 1通道數為 1大小為 512x512 的陣列
img = img.reshape(1, 1, img.shape[0], img.shape[1])
img_tensor = torch.from_numpy(img)
img_tensor = img_tensor.to(device=device, dtype=torch.float32)
pred = net(img_tensor)
pred = np.array(pred.data.cpu()[0])[0]
pred[pred >= 0.5] = 255
pred[pred < 0.5] = 0
pred = cv2.resize(pred, (origin_shape[1], origin_shape[0]), interpolation=cv2.INTER_NEAREST)
cv2.imwrite(os.path.join(pred_dir, image_id + ".png"), pred)
print("預測結果獲取完成。")
if miou_mode == 0 or miou_mode == 2:
print("計算 mIoU 中。")
print(gt_dir)
print(pred_dir)
print(num_classes)
print(name_classes)
hist, IoUs, PA_Recall, Precision = compute_mIoU(gt_dir, pred_dir, image_ids, num_classes,
name_classes) # 執行計算 mIoU 的函數
print("mIoU 計算完成。")
miou_out_path = "results/"
show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes)
if __name__ == '__main__':
cal_miou()