# (File-viewer header captured with the original paste: 74 lines, 2.5 KiB, Python.)
import math
import os
import random
import shutil

# Input folders: the source images and their matching .txt label files.
image_dir = r"E:\code\image_segmentation\pic"
label_dir = r"E:\code\image_segmentation\labelme_txt_dir"

# Root folder that receives the split dataset.
output_dir = r"E:\code\image_segmentation\YOLODataset"

# Output layout: <output_dir>/images/<subset> and <output_dir>/labels/<subset>.
image_output_dir, label_output_dir = (
    os.path.join(output_dir, part) for part in ("images", "labels")
)
subsets = ["train", "test", "val"]

# Refuse to touch the output before confirming the inputs exist. Without this
# check, a typo in image_dir wipes the previous dataset and only afterwards
# crashes when listing images, and a missing label_dir silently produces an
# empty dataset.
for source_dir in (image_dir, label_dir):
    if not os.path.isdir(source_dir):
        raise FileNotFoundError(f"Source directory not found: {source_dir}")

# Empty the per-subset output folders so files from a previous run cannot leak
# into the new split. Only images/<subset> and labels/<subset> are removed;
# anything else inside output_dir is left untouched.
for subset in subsets:
    image_subset_dir = os.path.join(image_output_dir, subset)
    label_subset_dir = os.path.join(label_output_dir, subset)

    # Delete the folders if they already exist...
    if os.path.exists(image_subset_dir):
        shutil.rmtree(image_subset_dir)
    if os.path.exists(label_subset_dir):
        shutil.rmtree(label_subset_dir)

    # ...then recreate them empty.
    os.makedirs(image_subset_dir, exist_ok=True)
    os.makedirs(label_subset_dir, exist_ok=True)

# Fraction of the image/label pairs assigned to each subset.
# The val subset receives whatever remains after the train and test counts are
# truncated to whole files, so val_ratio does not enter the count arithmetic.
# It is validated here so that editing one ratio cannot silently leave the
# three values inconsistent.
train_ratio = 0.8
test_ratio = 0.1
val_ratio = 0.1

if min(train_ratio, test_ratio, val_ratio) < 0 or not math.isclose(
    train_ratio + test_ratio + val_ratio, 1.0
):
    raise ValueError(
        "Split ratios must be non-negative and sum to 1, got "
        f"train={train_ratio}, test={test_ratio}, val={val_ratio}"
    )

# Image types taken from image_dir. Matching is case-insensitive so that names
# such as "IMG_0001.JPG" are not silently skipped.
image_extensions = (".jpg", ".jpeg", ".png", ".bmp")

# Seed for the shuffle. Set it to an int for a reproducible split; None gives a
# different random split on every run.
split_seed = None


def _label_name(image_name):
    """Return the label file name for *image_name*: same stem, ``.txt`` suffix.

    Only the final extension is replaced, so "a.png.jpg" maps to "a.png.txt".
    Images sharing a stem (e.g. "a.jpg" and "a.png") map to the same label.
    """
    return os.path.splitext(image_name)[0] + ".txt"


def _is_image_file(name):
    """Return True if *name* is a regular file in image_dir with an image extension."""
    return (
        os.path.splitext(name)[1].lower() in image_extensions
        and os.path.isfile(os.path.join(image_dir, name))
    )


# Collect image files, keeping only those that have a matching label file.
# os.listdir returns names in arbitrary order; sorting makes the pre-shuffle
# order deterministic so that a fixed split_seed reproduces the same split.
image_files = sorted(f for f in os.listdir(image_dir) if _is_image_file(f))
image_files = [
    f for f in image_files
    if os.path.exists(os.path.join(label_dir, _label_name(f)))
]

# Shuffle so each subset is a random sample of the pairs.
random.Random(split_seed).shuffle(image_files)

# Files per subset; int() truncates, and the leftovers go to val.
total_files = len(image_files)
train_count = int(total_files * train_ratio)
test_count = int(total_files * test_ratio)
val_count = total_files - train_count - test_count  # remainder goes to val

# Consecutive slices of the shuffled list.
train_files = image_files[:train_count]
test_files = image_files[train_count:train_count + test_count]
val_files = image_files[train_count + test_count:]


def copy_files(files, subset):
    """Copy each image in *files* and its label file into the *subset* folders.

    Args:
        files: Image file names, relative to image_dir.
        subset: Name of the subset folder ("train", "test" or "val") under
            image_output_dir and label_output_dir.

    Raises:
        OSError: propagated from shutil.copy, e.g. if a source file is missing.
    """
    for file in files:
        # Image: image_dir/<file> -> <image_output_dir>/<subset>/<file>
        image_src = os.path.join(image_dir, file)
        image_dest = os.path.join(image_output_dir, subset, file)
        shutil.copy(image_src, image_dest)

        # Label: label_dir/<stem>.txt -> <label_output_dir>/<subset>/<stem>.txt
        label_file = _label_name(file)
        label_src = os.path.join(label_dir, label_file)
        label_dest = os.path.join(label_output_dir, subset, label_file)
        shutil.copy(label_src, label_dest)

# Populate each subset's images/ and labels/ folders.
copy_files(train_files, "train")
copy_files(test_files, "test")
copy_files(val_files, "val")

# Report per-subset counts so an empty or lopsided split is noticed right away.
# Only the train/test/val subfolders of the output directory were cleared.
print(
    f"Dataset split completed: {len(train_files)} train, "
    f"{len(test_files)} test, {len(val_files)} val images. "
    "Existing train/test/val subfolders in the output directory were cleared."
)
if not (train_files or test_files or val_files):
    print("Warning: no image/label pairs were found; the output subsets are empty.")