This repository has been archived by the owner on Sep 25, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 13
/
data_loader.py
111 lines (93 loc) · 3.54 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from config import torch
from config import device
from config import seed_everything
import pathlib
import plotext as tplt
from torchvision import transforms
import matplotlib.pyplot as plt
import fnmatch
import os
import cv2
from albumentations import HorizontalFlip, VerticalFlip, RandomRotate90, Compose
# Standard ImageNet per-channel statistics, stored as [mean, std].
imagenet_stats = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]]

# Dataset root and its train / validation split directories.
data_dir = pathlib.Path("./data/colorEnhanced/")
TRAIN_DIR = data_dir.joinpath("train")
VALID_DIR = data_dir.joinpath("val")


def _make_pipeline(augment):
    """Build a torchvision pipeline for 224x224 ImageNet-normalized tensors.

    Args:
        augment: When true, random horizontal/vertical flips and a random
            rotation (RandomRotation(45)) are inserted after the resize.

    Returns:
        A ``transforms.Compose`` of the selected steps.
    """
    steps = [transforms.Resize((224, 224))]
    if augment:
        steps.extend(
            [
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.RandomRotation(45),
            ]
        )
    steps.extend([transforms.ToTensor(), transforms.Normalize(*imagenet_stats)])
    return transforms.Compose(steps)


# "train" adds random geometric augmentation; "valid" is resize + normalize only.
img_transforms = {
    "train": _make_pipeline(augment=True),
    "valid": _make_pipeline(augment=False),
}
def augment_and_save(path, target_number=1000):
    """Augment under-populated class folders and save the results to disk.

    Each immediate sub-directory of ``path`` is treated as one class. If a
    class folder holds fewer than ``target_number`` ``*.png`` images, every
    image in it is passed ``target_number // n_images`` times through a random
    flip / rotate-90 pipeline. Each result is written next to the original as
    ``<stem>_<i>.png``.

    Folders that already contain at least ``target_number`` images are left
    untouched, and so are folders without any ``*.png`` file.

    Because the written files also match ``*.png``, an augmented folder always
    ends up with more than ``target_number`` images. It is therefore skipped on
    later runs instead of being augmented again.

    Args:
        path: Directory whose sub-directories are the class folders.
        target_number: Per-class image count below which a folder is augmented.
    """
    import warnings

    subfolders = [entry.path for entry in os.scandir(path) if entry.is_dir()]
    for subfolder in subfolders:
        images = fnmatch.filter(os.listdir(subfolder), "*.png")
        # An empty folder used to raise ZeroDivisionError below. A folder that
        # was already large enough used to be augmented anyway, contradicting
        # the documented "less than target_number" contract.
        if not images or len(images) >= target_number:
            continue
        # len(images) < target_number here, so this is always >= 1.
        augmentations_per_image = target_number // len(images)
        augmentations = Compose(
            [
                HorizontalFlip(),
                VerticalFlip(),
                RandomRotate90(),
            ]
        )
        for image in images:
            image_path = os.path.join(subfolder, image)
            img = cv2.imread(image_path)
            if img is None:
                # cv2.imread signals an unreadable/corrupt file by returning
                # None, which would otherwise crash the augmentation call.
                warnings.warn(f"Skipping unreadable image: {image_path}")
                continue
            stem = os.path.splitext(image)[0]
            for i in range(augmentations_per_image):
                augmented = augmentations(image=img)
                cv2.imwrite(
                    os.path.join(subfolder, f"{stem}_{i}.png"),
                    augmented["image"],
                )
def _denormalize(images, imagenet_stats):
"""De-normalize dataset using imagenet std and mean to show images."""
mean = torch.tensor(imagenet_stats[0]).reshape(1, 3, 1, 1)
std = torch.tensor(imagenet_stats[1]).reshape(1, 3, 1, 1)
return images * std + mean
def show_data(dataloader, imagenet_stats=imagenet_stats, num_data=2):
    """Show up to ``num_data`` images and labels from the first batch.

    The function takes one batch from ``dataloader``. Under matplotlib's "agg"
    backend it only prints the labels. Otherwise it de-normalizes the images,
    clamps them to [0, 1] and plots them side by side, each titled with its
    label.

    Args:
        dataloader: Iterable yielding batches indexable as
            ``batch[0]`` (images) and ``batch[1]`` (label tensor). The images
            are presumably normalized ``(B, C, H, W)`` tensors; confirm this
            against the dataset.
        imagenet_stats: ``[mean, std]`` used to undo the normalization.
        num_data: Number of samples to show; capped at the batch size.
    """
    batch = next(iter(dataloader))
    # Showing more samples than the batch holds used to raise IndexError.
    num_data = min(num_data, len(batch[0]))
    # The images are only displayed, so they stay where the loader put them.
    # Moving them to `device` caused a CPU/CUDA mismatch in _denormalize.
    imgs = batch[0][:num_data]
    labels = batch[1][:num_data].tolist()
    if plt.get_backend().lower() == "agg":
        print(f"Labels for {num_data} images: {labels}")
        return
    # squeeze=False keeps `axes` a 2-D array even when num_data == 1.
    # Without it, plt.subplots returns a bare Axes that cannot be indexed.
    _, axes = plt.subplots(1, num_data, figsize=(10, 6), squeeze=False)
    # De-normalize the whole slice at once. This no longer writes back into
    # the batch tensor.
    imgs = _denormalize(imgs.cpu(), imagenet_stats)
    for ax, img, label in zip(axes[0], imgs, labels):
        ax.set_title(label)
        ax.imshow(torch.clamp(img, 0, 1).permute(1, 2, 0))
    plt.show()
def data_distribution(dataset, path: str) -> dict:
    """Count the ``*.png`` entries in each class folder.

    Args:
        dataset: Object exposing a ``class_to_idx`` mapping; only its keys
            (the class names) are used.
        path: Root directory containing one sub-directory per class.

    Returns:
        Dict mapping each class name to the number of ``*.png`` entries in
        ``{path}/{class}``, in ``class_to_idx`` key order.
    """
    counts = {}
    for class_name in dataset.class_to_idx.keys():
        entries = os.listdir(f"{path}/{class_name}")
        counts[class_name] = len(fnmatch.filter(entries, "*.png"))
    return counts
def plot_data_distribution(data_dist: dict, title: str = ""):
    """Plot per-class counts as a bar chart.

    Under matplotlib's "agg" backend the chart is drawn in the terminal with
    plotext. Otherwise it is drawn as a seaborn bar plot and shown through
    matplotlib.

    Args:
        data_dist: Mapping of class name to count, for example the output of
            ``data_distribution``.
        title: Title for the chart.
    """
    # Local import: seaborn is only loaded when this function is called.
    import seaborn as sns

    labels = [*data_dist.keys()]
    values = [*data_dist.values()]
    if plt.get_backend() != "agg":
        sns.barplot(x=labels, y=values).set_title(title)
        plt.show()
        return
    tplt.simple_bar(labels, values, width=100, title=title)
    tplt.show()