From 836044dd000296b2bb74d0aa4b4e462e459b01d7 Mon Sep 17 00:00:00 2001 From: MachineLearning2020 Date: Mon, 25 Nov 2019 11:56:28 +0100 Subject: [PATCH] Uploaded dataset template --- caltech_dataset.py | 59 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 caltech_dataset.py diff --git a/caltech_dataset.py b/caltech_dataset.py new file mode 100644 index 00000000..ebaf223f --- /dev/null +++ b/caltech_dataset.py @@ -0,0 +1,59 @@ +from torchvision.datasets import VisionDataset + +from PIL import Image + +import os +import os.path +import sys + + +def pil_loader(path): + # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) + with open(path, 'rb') as f: + img = Image.open(f) + return img.convert('RGB') + + +class Caltech(VisionDataset): + def __init__(self, root, split='train', transform=None, target_transform=None): + super(Caltech, self).__init__(root, transform=transform, target_transform=target_transform) + + self.split = split # This defines the split you are going to use + # (split files are called 'train.txt' and 'test.txt') + + ''' + - Here you should implement the logic for reading the splits files and accessing elements + - If the RAM size allows it, it is faster to store all data in memory + - PyTorch Dataset classes use indexes to read elements + - You should provide a way for the __getitem__ method to access the image-label pair + through the index + - Labels should start from 0, so for Caltech you will have lables 0...100 (excluding the background class) + ''' + + def __getitem__(self, index): + ''' + __getitem__ should access an element through its index + Args: + index (int): Index + + Returns: + tuple: (sample, target) where target is class_index of the target class. + ''' + + image, label = ... # Provide a way to access image and label via index + # Image should be a PIL Image + # label can be int + + # Applies preprocessing when accessing the image + if self.transform is not None: + image = self.transform(image) + + return image, label + + def __len__(self): + ''' + The __len__ method returns the length of the dataset + It is mandatory, as this is used by several other components + ''' + length = ... # Provide a way to get the length (number of elements) of the dataset + return length