image_segmentation/sample code/Image Segmentation/unet_42-master/test.py
2025-01-20 16:21:14 +08:00

108 lines
4.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
-------------------------------------------------
Project Name: unet
File Name: test.py
Author: chenming
Create Date: 2022/2/7
Description
-------------------------------------------------
"""
import os
from tqdm import tqdm
from utils.utils_metrics import compute_mIoU, show_results
import glob
import numpy as np
import torch
import os
import cv2
from model.unet_model import UNet
from torchmetrics import Dice
def _list_test_images(test_dir):
    """Return sorted ``(file_name, image_id)`` pairs for the images in *test_dir*.

    ``image_id`` is the file name with only its final extension removed, so
    names containing extra dots (e.g. ``case.01.png``) keep their full stem.
    Sub-directories and files without a common image extension are skipped,
    so stray files such as ``Thumbs.db`` do not break the run.
    """
    extensions = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")
    images = []
    for file_name in sorted(os.listdir(test_dir)):  # sorted: deterministic order
        stem, ext = os.path.splitext(file_name)
        if ext.lower() not in extensions:
            continue
        if not os.path.isfile(os.path.join(test_dir, file_name)):
            continue
        images.append((file_name, stem))
    return images


def _predict_masks(net, device, test_dir, pred_dir, images, input_size, threshold):
    """Run *net* on each test image and write a 0/255 uint8 mask per image.

    Each mask is saved as ``<image_id>.png`` in *pred_dir*, resized back to the
    original image size with nearest-neighbour interpolation so it stays binary.

    Raises:
        ValueError: if a listed test image cannot be decoded by OpenCV.
    """
    os.makedirs(pred_dir, exist_ok=True)
    # Inference only: disabling autograd avoids building a graph per image.
    with torch.no_grad():
        for file_name, image_id in tqdm(images):
            image_path = os.path.join(test_dir, file_name)
            img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                raise ValueError(f"Could not read test image: {image_path}")
            origin_h, origin_w = img.shape
            img = cv2.resize(img, input_size)
            # (H, W) -> (batch=1, channels=1, H, W), as float32 on the model's device.
            img_tensor = torch.from_numpy(img[np.newaxis, np.newaxis]).to(
                device=device, dtype=torch.float32)
            output = net(img_tensor)[0, 0].cpu().numpy()
            # NOTE(review): the raw network output is thresholded directly; if
            # UNet returns logits rather than probabilities, 0.5 here is not a
            # 50% probability cut-off — confirm against model.unet_model.
            mask = np.where(output >= threshold, 255, 0).astype(np.uint8)
            mask = cv2.resize(mask, (origin_w, origin_h),
                              interpolation=cv2.INTER_NEAREST)
            cv2.imwrite(os.path.join(pred_dir, image_id + ".png"), mask)


def cal_miou(test_dir=r"C:\Users\USER\Desktop\segment\test\Testimages_new",
             pred_dir=r"C:\Users\USER\Desktop\segment\unet_42-master\results",
             gt_dir=r"C:\Users\USER\Desktop\segment\test\Testlabels_new",
             *,
             miou_mode=0,
             num_classes=2,
             name_classes=None,
             model_path='best_model_30_1.pth',
             miou_out_path="results/",
             input_size=(512, 512),
             threshold=0.5):
    """Predict masks for a test set with a trained UNet and/or evaluate mIoU.

    Args:
        test_dir: Folder containing the test images (read as grayscale).
        pred_dir: Folder where predicted masks are written as
            ``<image_id>.png``; created if missing.
        gt_dir: Folder of ground-truth labels, passed to ``compute_mIoU``.
        miou_mode: 0 = predict and then compute mIoU; 1 = predict only;
            2 = compute mIoU only (presumably against masks already present
            in ``pred_dir``).
        num_classes: Number of classes including background
            (background + hippocampus = 2 by default).
        name_classes: Class names in label order; defaults to
            ``["background", "hippocampus"]``.
        model_path: Checkpoint holding the UNet ``state_dict``.
        miou_out_path: Output folder passed to ``show_results``.
        input_size: ``(width, height)`` images are resized to before inference
            (``cv2.resize`` argument order).
        threshold: Network outputs ``>= threshold`` become foreground (255).

    Image ids are taken from the image files in ``test_dir`` for every mode.

    Raises:
        ValueError: if ``miou_mode`` is not 0, 1 or 2, or a test image
            cannot be decoded.
    """
    if miou_mode not in (0, 1, 2):
        raise ValueError(f"miou_mode must be 0, 1 or 2, got {miou_mode!r}")
    if name_classes is None:
        name_classes = ["background", "hippocampus"]

    # Both stages need the ids, so compute them before either stage runs;
    # this is what lets evaluation-only mode (2) work on its own.
    images = _list_test_images(test_dir)
    image_ids = [image_id for _, image_id in images]

    if miou_mode in (0, 1):
        print("載入模型中。")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Single-channel (grayscale) input, one output channel.
        net = UNet(n_channels=1, n_classes=1)
        net.to(device=device)
        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        net.load_state_dict(torch.load(model_path, map_location=device))
        net.eval()  # inference mode (affects dropout / batch-norm layers)
        print("模型載入完成。")
        print("獲取預測結果中。")
        _predict_masks(net, device, test_dir, pred_dir, images,
                       input_size, threshold)
        print("預測結果獲取完成。")

    if miou_mode in (0, 2):
        print("計算 mIoU 中。")
        print(gt_dir)
        print(pred_dir)
        print(num_classes)
        print(name_classes)
        # Compare the predicted masks in pred_dir with the labels in gt_dir.
        hist, IoUs, PA_Recall, Precision = compute_mIoU(
            gt_dir, pred_dir, image_ids, num_classes, name_classes)
        print("mIoU 計算完成。")
        show_results(miou_out_path, hist, IoUs, PA_Recall, Precision,
                     name_classes)
if __name__ == '__main__':
    # Script entry point: run prediction and mIoU evaluation with the
    # default paths and settings of cal_miou().
    cal_miou()