-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprepare_data.py
More file actions
95 lines (77 loc) · 2.82 KB
/
prepare_data.py
File metadata and controls
95 lines (77 loc) · 2.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import shutil
import zipfile
from pathlib import Path
from tqdm import tqdm
# Target classes keyed by their original class id in the source label files.
# Insertion order matters: main() enumerates these keys to assign the new,
# contiguous class indices (first entry -> 0, second -> 1, ...), so do not
# reorder entries without intending to renumber the output classes.
COCO_MAP = {
    47: "apple",
    46: "banana",
    49: "orange",
    50: "broccoli",
    51: "carrot",
    48: "sandwich",
    52: "hot dog",
    53: "pizza",
    54: "donut",
    55: "cake",
}
def _remap_label_lines(lines, id_remap):
    """Keep only annotations of target classes, rewritten with new class ids.

    Args:
        lines: Iterable of label lines whose first whitespace-separated
            field is an integer class id.
        id_remap: Mapping from original class id to new class id.

    Returns:
        A list of rewritten lines (fields joined by single spaces, no
        trailing newline) for the classes present in ``id_remap``.

    Raises:
        ValueError: If the first field of a non-blank line is not an integer.
    """
    kept = []
    for line in lines:
        parts = line.split()
        if not parts:
            # Blank or whitespace-only lines carry no annotation; indexing
            # parts[0] on them would raise IndexError.
            continue
        new_id = id_remap.get(int(parts[0]))
        if new_id is None:
            continue
        parts[0] = str(new_id)
        kept.append(" ".join(parts))
    return kept


def main() -> None:
    """Build the filtered ``fruit_veg_v1`` dataset next to this script.

    Steps:
      1. Require ``downloads/val2017.zip`` and ``downloads/coco2017labels.zip``;
         print an error and return if either is missing.
      2. Delete and recreate ``dataset/fruit_veg_v1`` with ``images/val2017``
         and ``labels/val2017`` subdirectories.
      3. Extract both archives into a scratch ``tmp_extract`` directory.
      4. For every label file, keep only the annotations whose class id is a
         key of ``COCO_MAP`` and renumber them 0..N-1 in ``COCO_MAP`` order.
         A label file is written, and its image copied, only when at least
         one annotation survives and the matching ``.jpg`` exists.
      5. Write ``data.yaml`` describing the dataset.
      6. Remove the scratch directory, even if an earlier step raised.
    """
    root = Path(__file__).resolve().parent
    dl_dir = root / "downloads"
    output_root = root / "dataset" / "fruit_veg_v1"

    # 1) Validate source archives
    img_zip = dl_dir / "val2017.zip"
    lbl_zip = dl_dir / "coco2017labels.zip"
    if not img_zip.exists() or not lbl_zip.exists():
        print("Error: place val2017.zip and coco2017labels.zip under the downloads directory.")
        return

    # 2) Reset output structure
    if output_root.exists():
        shutil.rmtree(output_root)
    img_out = output_root / "images" / "val2017"
    lbl_out = output_root / "labels" / "val2017"
    img_out.mkdir(parents=True)
    lbl_out.mkdir(parents=True)

    # 3) Temporary extraction directory
    tmp_dir = root / "tmp_extract"
    if tmp_dir.exists():
        shutil.rmtree(tmp_dir)
    tmp_dir.mkdir()

    try:
        # 4) Extract archives
        print("Extracting images and labels...")
        for archive in (img_zip, lbl_zip):
            with zipfile.ZipFile(archive, 'r') as z:
                z.extractall(tmp_dir)

        # Locate extracted paths
        src_lbl_dir = tmp_dir / "coco" / "labels" / "val2017"
        src_img_dir = tmp_dir / "val2017"
        if not src_lbl_dir.is_dir() or not src_img_dir.is_dir():
            # Without this check a wrong archive layout would silently
            # produce an empty dataset and still report success.
            print(
                f"Error: expected {src_lbl_dir} and {src_img_dir} after "
                "extraction; check the archive contents."
            )
            return

        # 5) Filter the 10 target classes
        print("Filtering 10 fruit/vegetable classes...")
        # New class index = position of the original id in COCO_MAP.
        id_remap = {old_id: i for i, old_id in enumerate(COCO_MAP)}
        new_names = list(COCO_MAP.values())
        kept_count = 0

        # Materialize the listing so tqdm can show a total.
        for lbl_file in tqdm(sorted(src_lbl_dir.glob("*.txt"))):
            with open(lbl_file, 'r', encoding="utf-8") as f:
                new_lines = _remap_label_lines(f, id_remap)
            if not new_lines:
                continue

            img_name = lbl_file.stem + ".jpg"
            src_img = src_img_dir / img_name
            if not src_img.exists():
                # Never emit a label without its image.
                continue

            shutil.copy2(src_img, img_out / img_name)
            with open(lbl_out / lbl_file.name, 'w', encoding="utf-8") as f:
                f.write("\n".join(new_lines))
            kept_count += 1

        # 6) Write data.yaml with relative paths
        yaml_content = (
            "path: .\n"
            "train: images/val2017\n"
            "val: images/val2017\n"
            f"nc: {len(new_names)}\n"
            f"names: {new_names}\n"
        )
        (output_root / "data.yaml").write_text(yaml_content, encoding="utf-8")
    finally:
        # 7) Clean up temporary files, including on failure, so a crashed
        # run does not leave the extracted archives on disk.
        shutil.rmtree(tmp_dir, ignore_errors=True)

    print(f"Done. Kept {kept_count} images. Dataset ready at {output_root}.")
# Run the dataset preparation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()