82 lines
2.9 KiB
Python
82 lines
2.9 KiB
Python
import os
|
|
import shutil
|
|
import random
|
|
import argparse
|
|
|
|
# Define the command-line options accepted by this script.
parser = argparse.ArgumentParser(
    description='Split dataset into train, test, and val subsets',
)
parser.add_argument(
    '--image-dir',
    type=str,
    default='./img0204/P/',
    help='Directory containing image files',
)
parser.add_argument(
    '--label-dir',
    type=str,
    default='./img0204/P/',
    help='Directory containing label files',
)
parser.add_argument(
    '--output-dir',
    type=str,
    default='./YOLODataset/',
    help='Output directory for the dataset',
)

args = parser.parse_args()

# Expose the parsed options under the module-level names the rest of the
# script reads.
image_dir, label_dir, output_dir = (
    args.image_dir,
    args.label_dir,
    args.output_dir,
)
|
|
|
|
# Output layout: <output_dir>/images/<subset> and <output_dir>/labels/<subset>.
subsets = ["train", "test", "val"]
image_output_dir, label_output_dir = (
    os.path.join(output_dir, kind) for kind in ("images", "labels")
)

# Reset every subset folder so each run starts from an empty directory.
for subset in subsets:
    image_subset_dir = os.path.join(image_output_dir, subset)
    label_subset_dir = os.path.join(label_output_dir, subset)
    subset_dirs = (image_subset_dir, label_subset_dir)

    # Remove whatever a previous run left behind (only if the folder exists).
    for stale_dir in filter(os.path.exists, subset_dirs):
        shutil.rmtree(stale_dir)

    # Recreate the now-missing folders, including any missing parents.
    for fresh_dir in subset_dirs:
        os.makedirs(fresh_dir, exist_ok=True)
|
|
|
|
# Split ratios. Only train_ratio and test_ratio enter the computation below;
# the val subset receives whatever remains after those two are carved out.
train_ratio = 0.8
test_ratio = 0.1
val_ratio = 0.1


def _label_name(image_name):
    """Return the label filename that pairs with *image_name*.

    Only the final extension is swapped for ``.txt``. The previous chained
    ``str.replace`` calls rewrote every occurrence of ``.bmp``/``.jpg``/
    ``.png`` anywhere in the name (``shot.jpg_crop.png`` became
    ``shot.txt_crop.txt``), so such images never matched their label file.
    """
    return os.path.splitext(image_name)[0] + ".txt"


# Collect the names of all image files in the image directory.
image_files = [
    f for f in os.listdir(image_dir) if f.endswith((".jpg", ".png", ".bmp"))
]

# Keep only the images that have a corresponding label file.
image_files = [
    f for f in image_files
    if os.path.exists(os.path.join(label_dir, _label_name(f)))
]

# Shuffle in place so the subsets are a random partition of the images.
random.shuffle(image_files)

# Work out how many files go to each subset; int() truncates, and the
# remainder after train and test is assigned to val.
total_files = len(image_files)
train_count = int(total_files * train_ratio)
test_count = int(total_files * test_ratio)
val_count = total_files - train_count - test_count

# Slice the shuffled list into three consecutive, non-overlapping subsets.
train_files = image_files[:train_count]
test_files = image_files[train_count:train_count + test_count]
val_files = image_files[train_count + test_count:]


def copy_files(files, subset):
    """Copy each image in *files* and its label into the *subset* folders.

    Args:
        files: Image filenames (not paths) located in ``image_dir``.
        subset: Name of the subset folder under ``image_output_dir`` and
            ``label_output_dir`` that receives the copies.

    Sources are read from the module-level ``image_dir`` and ``label_dir``.
    Exceptions raised by ``shutil.copy`` are not caught here and propagate
    to the caller.
    """
    for file in files:
        # Copy the image file.
        image_src = os.path.join(image_dir, file)
        image_dest = os.path.join(image_output_dir, subset, file)
        shutil.copy(image_src, image_dest)

        # Copy the matching label file, using the same name derivation as
        # the filter above so both always agree on which label is meant.
        label_file = _label_name(file)
        label_src = os.path.join(label_dir, label_file)
        label_dest = os.path.join(label_output_dir, subset, label_file)
        shutil.copy(label_src, label_dest)
|
|
|
|
# Copy every subset into place, in train -> test -> val order.
for subset_name, subset_files in (
    ("train", train_files),
    ("test", test_files),
    ("val", val_files),
):
    copy_files(subset_files, subset_name)

print("Dataset split completed! Existing files in the output directory were cleared.")