# image_segmentation/split_data.py
# 2025-01-20 17:09:20 +08:00
#
# 74 lines
# 2.5 KiB
# Python

import os
import shutil
import random
# Input folders: the source images and the matching label .txt files
# (raw strings so the Windows backslashes are kept literally).
image_dir = r"E:\code\image_segmentation\pic"
label_dir = r"E:\code\image_segmentation\labelme_txt_dir"
# Root of the generated dataset.
output_dir = r"E:\code\image_segmentation\YOLODataset"
# The dataset root holds parallel "images" and "labels" trees; each tree
# gets one sub-folder per split name listed in `subsets`.
image_output_dir, label_output_dir = (
    os.path.join(output_dir, leaf) for leaf in ("images", "labels")
)
subsets = ["train", "test", "val"]
# Start from a clean slate: delete any split folders left over from a
# previous run, so stale images/labels cannot leak into the new split.
# Then recreate the folders empty. Only images/<subset> and labels/<subset>
# are touched; anything else under output_dir is left alone.
for subset in subsets:
    image_subset_dir = os.path.join(image_output_dir, subset)
    label_subset_dir = os.path.join(label_output_dir, subset)
    # If the folder already exists, remove it together with its contents.
    if os.path.exists(image_subset_dir):
        shutil.rmtree(image_subset_dir)
    if os.path.exists(label_subset_dir):
        shutil.rmtree(label_subset_dir)
    # Recreate the folders empty (makedirs also creates missing parents).
    os.makedirs(image_subset_dir, exist_ok=True)
    os.makedirs(label_subset_dir, exist_ok=True)
# Split ratios. The validation subset receives whatever remains after the
# train and test counts are truncated to integers, so val_ratio only takes
# part in the sanity check below.
train_ratio = 0.8
test_ratio = 0.1
val_ratio = 0.1
# Image extensions accepted as samples (compared case-insensitively, so
# e.g. "IMG_001.JPG" is picked up too).
image_extensions = (".jpg", ".jpeg", ".png", ".bmp")
# Seed for the shuffle. None draws fresh randomness on every run; set an
# int to make the split reproducible.
random_seed = None

if abs(train_ratio + test_ratio + val_ratio - 1.0) > 1e-9:
    raise ValueError(
        "train_ratio + test_ratio + val_ratio must sum to 1, got "
        f"{train_ratio + test_ratio + val_ratio}"
    )


def label_filename(image_file):
    """Return the label filename (``<stem>.txt``) paired with *image_file*.

    Only the final extension is replaced, so a name such as
    ``frame.jpg_crop.png`` maps to ``frame.jpg_crop.txt`` rather than having
    every embedded ".jpg"/".png" substring rewritten.
    """
    return os.path.splitext(image_file)[0] + ".txt"


# Collect candidate image files. Sorting removes the arbitrary os.listdir
# order, so a fixed random_seed yields the same split on every machine.
image_files = sorted(
    f
    for f in os.listdir(image_dir)
    if f.lower().endswith(image_extensions)
    and os.path.isfile(os.path.join(image_dir, f))
)
# Keep only images that have a matching label file.
labelled_files = [
    f for f in image_files
    if os.path.exists(os.path.join(label_dir, label_filename(f)))
]
skipped_count = len(image_files) - len(labelled_files)
if skipped_count:
    print(f"Warning: skipped {skipped_count} image(s) with no matching label file in {label_dir}")
image_files = labelled_files

# Shuffle so each split is a random sample of the data.
random.Random(random_seed).shuffle(image_files)

# Compute the size of each split.
total_files = len(image_files)
train_count = int(total_files * train_ratio)
test_count = int(total_files * test_ratio)
val_count = total_files - train_count - test_count  # the remainder goes to val

# Slice the shuffled list into consecutive, non-overlapping splits.
train_files = image_files[:train_count]
test_files = image_files[train_count:train_count + test_count]
val_files = image_files[train_count + test_count:]


def copy_files(files, subset):
    """Copy each image in *files* and its label file into the *subset* split.

    Images are copied from ``image_dir`` to ``image_output_dir/<subset>``.
    Labels, named by :func:`label_filename`, are copied from ``label_dir``
    to ``label_output_dir/<subset>``. Both destination folders must already
    exist; ``shutil.copy`` raises if a source file is missing.
    """
    image_dest_dir = os.path.join(image_output_dir, subset)
    label_dest_dir = os.path.join(label_output_dir, subset)
    for file in files:
        shutil.copy(os.path.join(image_dir, file), os.path.join(image_dest_dir, file))
        label_file = label_filename(file)
        shutil.copy(os.path.join(label_dir, label_file), os.path.join(label_dest_dir, label_file))
# Populate each split folder in order (train, test, val), then report
# completion.
for split_name, split_files in (
    ("train", train_files),
    ("test", test_files),
    ("val", val_files),
):
    copy_files(split_files, split_name)
print("Dataset split completed! Existing files in the output directory were cleared.")