# AP/split_data.py
# 2025-02-07 17:40:26 +08:00
#
# 82 lines
# 2.9 KiB
# Python

import os
import shutil
import random
import argparse
# Command-line options: source image/label folders, output root, and an
# optional RNG seed for producing a reproducible split.
parser = argparse.ArgumentParser(description='Split dataset into train, test, and val subsets')
parser.add_argument('--image-dir', type=str, default='./img0204/P/', help='Directory containing image files')
parser.add_argument('--label-dir', type=str, default='./img0204/P/', help='Directory containing label files')
parser.add_argument('--output-dir', type=str, default='./YOLODataset/', help='Output directory for the dataset')
parser.add_argument('--seed', type=int, default=None, help='Random seed for a reproducible split (default: unseeded)')
args = parser.parse_args()
# Seed only when explicitly requested so the default run keeps its original,
# unseeded (different split every run) behavior.
if args.seed is not None:
    random.seed(args.seed)
# Expose the parsed command-line options as the module-level names that the
# rest of the script reads.
image_dir, label_dir, output_dir = args.image_dir, args.label_dir, args.output_dir
# Output layout: <output_dir>/images/<subset> and <output_dir>/labels/<subset>.
image_output_dir = os.path.join(output_dir, "images")
label_output_dir = os.path.join(output_dir, "labels")
subsets = ["train", "test", "val"]

# Wipe any previous split so stale files cannot leak into the new one, then
# recreate each subset folder empty.
for subset in subsets:
    image_subset_dir = os.path.join(image_output_dir, subset)
    label_subset_dir = os.path.join(label_output_dir, subset)
    stale_dirs = [d for d in (image_subset_dir, label_subset_dir) if os.path.exists(d)]
    for stale in stale_dirs:
        shutil.rmtree(stale)
    for fresh in (image_subset_dir, label_subset_dir):
        os.makedirs(fresh, exist_ok=True)
# Split ratios. Only train_ratio and test_ratio are applied directly; the val
# subset receives whatever remains after integer truncation, so val_ratio is
# informational only.
train_ratio = 0.8
test_ratio = 0.1
val_ratio = 0.1

# Candidate images: regular files with a supported (case-sensitive) extension.
# The isfile check skips directories whose names happen to end in ".jpg" etc.
image_files = [
    f for f in os.listdir(image_dir)
    if f.endswith((".jpg", ".png", ".bmp"))
    and os.path.isfile(os.path.join(image_dir, f))
]
# Keep only images that have a matching label file. The label name replaces
# just the final extension with ".txt"; chained str.replace would also rewrite
# extension-like substrings mid-name (e.g. "frame.bmp_001.jpg").
image_files = [
    f for f in image_files
    if os.path.exists(os.path.join(label_dir, os.path.splitext(f)[0] + ".txt"))
]
if not image_files:
    print(f"Warning: no labelled images found in {image_dir}")

# os.listdir order is filesystem-dependent; sorting first makes the shuffle
# reproducible whenever the random module has been seeded.
image_files.sort()
random.shuffle(image_files)

# Subset sizes (val takes the remainder so every file is assigned exactly once).
total_files = len(image_files)
train_count = int(total_files * train_ratio)
test_count = int(total_files * test_ratio)
val_count = total_files - train_count - test_count

# Slice the shuffled list into the three subsets.
train_files = image_files[:train_count]
test_files = image_files[train_count:train_count + test_count]
val_files = image_files[train_count + test_count:]
# Helper that copies images and their labels into one subset's folders.
def copy_files(files, subset, *, src_image_dir=None, src_label_dir=None,
               dst_image_dir=None, dst_label_dir=None):
    """Copy each image in *files* plus its ".txt" label into the *subset* folders.

    Args:
        files: Image file names (not paths), relative to the source image dir.
        subset: Destination subfolder name, e.g. "train".
        src_image_dir: Folder holding the images; defaults to the module-level
            ``image_dir``.
        src_label_dir: Folder holding the labels; defaults to ``label_dir``.
        dst_image_dir: Root of the image output tree; defaults to
            ``image_output_dir``.
        dst_label_dir: Root of the label output tree; defaults to
            ``label_output_dir``.

    Raises:
        OSError: propagated from shutil.copy, e.g. FileNotFoundError when an
            image or its label does not exist.
    """
    # Fall back to the script-level configuration when no override is given.
    src_image_dir = image_dir if src_image_dir is None else src_image_dir
    src_label_dir = label_dir if src_label_dir is None else src_label_dir
    dst_image_dir = image_output_dir if dst_image_dir is None else dst_image_dir
    dst_label_dir = label_output_dir if dst_label_dir is None else dst_label_dir

    image_dest_dir = os.path.join(dst_image_dir, subset)
    label_dest_dir = os.path.join(dst_label_dir, subset)
    # Ensure the destination folders exist so the helper also works standalone.
    os.makedirs(image_dest_dir, exist_ok=True)
    os.makedirs(label_dest_dir, exist_ok=True)

    for file in files:
        shutil.copy(os.path.join(src_image_dir, file),
                    os.path.join(image_dest_dir, file))
        # Swap only the final extension for ".txt"; chained str.replace would
        # also rewrite ".jpg"/".png"/".bmp" occurring elsewhere in the name.
        label_name = os.path.splitext(file)[0] + ".txt"
        shutil.copy(os.path.join(src_label_dir, label_name),
                    os.path.join(label_dest_dir, label_name))
# Copy each split into its output subfolder (train, then test, then val),
# then report completion.
for split_name, split_files in (("train", train_files), ("test", test_files), ("val", val_files)):
    copy_files(split_files, split_name)
print("Dataset split completed! Existing files in the output directory were cleared.")