-
Notifications
You must be signed in to change notification settings - Fork 107
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
cb7946e
commit 836044d
Showing
1 changed file
with
59 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from torchvision.datasets import VisionDataset | ||
|
||
from PIL import Image | ||
|
||
import os | ||
import os.path | ||
import sys | ||
|
||
|
||
def pil_loader(path): | ||
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) | ||
with open(path, 'rb') as f: | ||
img = Image.open(f) | ||
return img.convert('RGB') | ||
|
||
|
||
class Caltech(VisionDataset): | ||
def __init__(self, root, split='train', transform=None, target_transform=None): | ||
super(Caltech, self).__init__(root, transform=transform, target_transform=target_transform) | ||
|
||
self.split = split # This defines the split you are going to use | ||
# (split files are called 'train.txt' and 'test.txt') | ||
|
||
''' | ||
- Here you should implement the logic for reading the splits files and accessing elements | ||
- If the RAM size allows it, it is faster to store all data in memory | ||
- PyTorch Dataset classes use indexes to read elements | ||
- You should provide a way for the __getitem__ method to access the image-label pair | ||
through the index | ||
- Labels should start from 0, so for Caltech you will have lables 0...100 (excluding the background class) | ||
''' | ||
|
||
def __getitem__(self, index): | ||
''' | ||
__getitem__ should access an element through its index | ||
Args: | ||
index (int): Index | ||
Returns: | ||
tuple: (sample, target) where target is class_index of the target class. | ||
''' | ||
|
||
image, label = ... # Provide a way to access image and label via index | ||
# Image should be a PIL Image | ||
# label can be int | ||
|
||
# Applies preprocessing when accessing the image | ||
if self.transform is not None: | ||
image = self.transform(image) | ||
|
||
return image, label | ||
|
||
def __len__(self): | ||
''' | ||
The __len__ method returns the length of the dataset | ||
It is mandatory, as this is used by several other components | ||
''' | ||
length = ... # Provide a way to get the length (number of elements) of the dataset | ||
return length |