Skip to content

Commit

Permalink
Add an out of memory dataset (#831)
Browse files Browse the repository at this point in the history
* need to independently measure memory, close dataset each iteration
  and added test object
  • Loading branch information
bw4sz authored Nov 16, 2024
1 parent c86f96d commit def50b9
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 14 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ include deepforest/data/2019_YELL_2_528000_4978000_image_crop2.xml
include deepforest/data/2019_YELL_2_528000_4978000_image_crop2.png
include deepforest/data/AWPE*
include deepforest/data/example.csv
include deepforest/data/test_tiled.tif

include LICENSE
include dev_requirements.txt
Expand Down
2 changes: 1 addition & 1 deletion deepforest/data/deepforest_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Cpu workers for data loaders
# Dataloaders
workers: 1
workers: 0
devices: auto
accelerator: auto
batch_size: 1
Expand Down
Binary file added deepforest/data/test_tiled.tif
Binary file not shown.
67 changes: 67 additions & 0 deletions deepforest/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from deepforest import preprocess
from rasterio.windows import Window
from torchvision import transforms
import slidingwindow
import warnings


def get_transform(augment):
Expand Down Expand Up @@ -153,6 +155,7 @@ def __init__(self,
tile: an in memory numpy array.
patch_size (int): The size for the crops used to cut the input raster into smaller pieces. This is given in pixels, not any geographic unit.
patch_overlap (float): The horizontal and vertical overlap among patches
preload_images (bool): If true, the entire dataset is loaded into memory. This is useful for small datasets, but not recommended for large datasets since both the tile and the crops are stored in memory.
Returns:
ds: a pytorch dataset
Expand Down Expand Up @@ -187,6 +190,70 @@ def __getitem__(self, idx):
return crop


class RasterDataset:
    """Dataset for predicting on raster windows read lazily from disk.

    Only the sliding-window layout is computed at construction time; pixel
    data is read from ``raster_path`` one window at a time in
    ``__getitem__``, so the full raster is never held by this object.

    Args:
        raster_path (str): Path to raster file
        patch_size (int): Size of windows to predict on
        patch_overlap (float): Overlap between windows as fraction (0-1)

    Raises:
        ValueError: If the raster at ``raster_path`` is not tiled.

    Returns:
        A dataset of raster windows
    """

    def __init__(self, raster_path, patch_size, patch_overlap):
        self.raster_path = raster_path
        self.patch_size = patch_size
        self.patch_overlap = patch_overlap

        # Inspect raster metadata without keeping the file open
        with rio.open(raster_path) as src:
            # rasterio reports shape as (rows, cols), i.e. (height, width)
            height, width = src.shape

            # An untiled raster cannot be read window-by-window efficiently
            if not src.is_tiled:
                raise ValueError(
                    "Out-of-memory dataset is selected, but raster is not tiled, "
                    "leading to entire raster being read into memory and defeating "
                    "the purpose of an out-of-memory dataset. "
                    "\nPlease run: "
                    "\ngdal_translate -of GTiff -co TILED=YES <input> <output> "
                    "to create a tiled raster")

        # Generate sliding windows; generateForSize expects width first, then height
        self.windows = slidingwindow.generateForSize(
            width,
            height,
            dimOrder=slidingwindow.DimOrder.ChannelHeightWidth,
            maxWindowSize=patch_size,
            overlapPercent=patch_overlap)
        self.n_windows = len(self.windows)

    def __len__(self):
        """Return the number of sliding windows covering the raster."""
        return self.n_windows

    def __getitem__(self, idx):
        """Get a window of the raster.

        Args:
            idx (int): Index of window to get

        Returns:
            crop (torch.Tensor): A float tensor of shape
                (bands, height, width) holding every band of the raster,
                e.g. (3, height, width) for an RGB image, divided by 255.
        """
        window = self.windows[idx]

        # Open, read window, and close for each operation so that no file
        # handle outlives a single item fetch
        with rio.open(self.raster_path) as src:
            window_data = src.read(
                window=Window(window.x, window.y, window.w, window.h))

        # Convert to a float torch tensor
        window_data = torch.from_numpy(window_data).float()
        # Normalize; presumably assumes 8-bit pixel values (0-255) -- TODO confirm
        window_data = window_data / 255.0

        return window_data  # Already in (C, H, W) format from rasterio


def bounding_box_transform(augment=False):
data_transforms = []
data_transforms.append(transforms.ToTensor())
Expand Down
47 changes: 37 additions & 10 deletions deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,7 @@ def predict_tile(self,
patch_size=400,
patch_overlap=0.05,
iou_threshold=0.15,
in_memory=True,
return_plot=False,
mosaic=True,
sigma=0.5,
Expand All @@ -490,6 +491,7 @@ def predict_tile(self,
iou_threshold: Minimum iou overlap among predictions between
windows to be suppressed.
Lower values suppress more boxes at edges.
in_memory: If true, the entire dataset is loaded into memory, which increases speed. This is useful for small datasets, but not recommended for very large datasets.
mosaic: Return a single prediction dataframe (True) or a tuple of image crops and predictions (False)
sigma: variance of Gaussian function used in Gaussian Soft NMS
thresh: the score thresh used to filter bboxes after soft-nms performed
Expand All @@ -506,6 +508,10 @@ def predict_tile(self,
- color: Deprecated bounding box color for visualizations.
- thickness: Deprecated bounding box thickness for visualizations.
Raises:
- ValueError: If `raster_path` is None when `in_memory=False`.
- ValueError: If `workers` is greater than 0 when `in_memory=False`. Multiprocessing is not supported when using out-of-memory datasets because rasterio is not thread-safe.
Returns:
- If `return_plot` is True, returns an image with predictions overlaid (deprecated).
- If `mosaic` is True, returns a Pandas DataFrame containing the predicted
Expand All @@ -530,15 +536,30 @@ def predict_tile(self,
"Both tile and tile_path are None. Either supply a path to a tile on disk, or read one into memory!"
)

if raster_path is None:
self.image = image
if in_memory:
if raster_path is None:
image = image
else:
image = rio.open(raster_path).read()
image = np.moveaxis(image, 0, 2)

ds = dataset.TileDataset(tile=image,
patch_overlap=patch_overlap,
patch_size=patch_size)
else:
self.image = rio.open(raster_path).read()
self.image = np.moveaxis(self.image, 0, 2)
if raster_path is None:
raise ValueError("raster_path is required if in_memory is False")

# Check for workers config when using out of memory dataset
if self.config["workers"] > 0:
raise ValueError(
"workers must be 0 when using out-of-memory dataset (in_memory=False). Set config['workers']=0 and recreate trainer self.create_trainer()."
)

ds = dataset.RasterDataset(raster_path=raster_path,
patch_overlap=patch_overlap,
patch_size=patch_size)

ds = dataset.TileDataset(tile=self.image,
patch_overlap=patch_overlap,
patch_size=patch_size)
batched_results = self.trainer.predict(self, self.predict_dataloader(ds))

# Flatten list from batched prediction
Expand All @@ -565,7 +586,7 @@ def predict_tile(self,
if raster_path:
tile = rio.open(raster_path).read()
else:
tile = self.image
tile = image
drawn_plot = tile[:, :, ::-1]
drawn_plot = visualize.plot_predictions(tile,
results,
Expand All @@ -576,10 +597,16 @@ def predict_tile(self,
for df in results:
df["label"] = df.label.apply(lambda x: self.numeric_to_label_dict[x])

# TODO this is the 2nd time the crops are generated? Could be more efficient.
# TODO this is the 2nd time the crops are generated? Could be more efficient, but memory intensive
self.crops = []
if raster_path is None:
image = image
else:
image = rio.open(raster_path).read()
image = np.moveaxis(image, 0, 2)

for window in ds.windows:
crop = self.image[window.indices()]
crop = image[window.indices()]
self.crops.append(crop)

return list(zip(results, self.crops))
Expand Down
2 changes: 1 addition & 1 deletion deepforest_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Cpu workers for data loaders
# Dataloaders
workers: 1
workers: 0
devices: auto
accelerator: auto
batch_size: 1
Expand Down
28 changes: 28 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ def multi_class():

return csv_file

@pytest.fixture()
def raster_path():
    """Fixture: the path that ``get_data`` resolves for ``OSBS_029.tif``."""
    sample_raster = get_data(path='OSBS_029.tif')
    return sample_raster

@pytest.mark.parametrize("csv_file,label_dict", [(single_class(), {"Tree": 0}), (multi_class(), {"Alive": 0, "Dead": 1})])
def test_tree_dataset(csv_file, label_dict):
Expand Down Expand Up @@ -181,3 +184,28 @@ def test_bounding_box_dataset():

# Check the shape of the RGB tensor
assert item.shape == (3, 224, 224)

def test_raster_dataset():
    """Check RasterDataset indexing and batching on the tiled test raster."""
    from deepforest.dataset import RasterDataset
    import torch
    from torch.utils.data import DataLoader

    # Build the dataset from the tiled sample raster
    tiled_raster = get_data("test_tiled.tif")
    raster_ds = RasterDataset(tiled_raster, patch_size=256, patch_overlap=0.1)

    # The sliding-window layout is exposed on the dataset
    assert hasattr(raster_ds, 'windows')

    # Indexing yields a float32 tensor, channels first, scaled into [0, 1]
    crop = raster_ds[0]
    assert isinstance(crop, torch.Tensor)
    assert crop.dtype == torch.float32
    assert crop.shape[0] == 3
    assert 0 <= crop.min() <= crop.max() <= 1.0

    # Batching through a DataLoader keeps the channels-first layout
    loader = DataLoader(raster_ds, batch_size=2, num_workers=0)
    first_batch = next(iter(loader))
    assert first_batch.shape[0] == 2
    assert first_batch.shape[1] == 3
18 changes: 16 additions & 2 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,12 +269,20 @@ def test_predict_tile_empty(raster_path):
predictions = m.predict_tile(raster_path=raster_path, patch_size=300, patch_overlap=0)
assert predictions is None

def test_predict_tile(m, raster_path):
@pytest.mark.parametrize("in_memory", [True, False])
def test_predict_tile(m, raster_path, in_memory):
m.create_model()
m.config["train"]["fast_dev_run"] = False
m.create_trainer()

if in_memory:
raster_path = raster_path
else:
raster_path = get_data("test_tiled.tif")

prediction = m.predict_tile(raster_path=raster_path,
patch_size=300,
in_memory=in_memory,
patch_overlap=0.1)

assert isinstance(prediction, pd.DataFrame)
Expand All @@ -283,6 +291,12 @@ def test_predict_tile(m, raster_path):
}
assert not prediction.empty

# test equivalence for in_memory=True and False
def test_predict_tile_equivalence(m):
    """Predictions on the same tiled raster must match for both in_memory settings."""
    tiled_raster = get_data("test_tiled.tif")
    # Everything except the in_memory flag is shared between the two runs
    shared_kwargs = dict(raster_path=tiled_raster, patch_size=300, patch_overlap=0)

    in_memory_prediction = m.predict_tile(in_memory=True, **shared_kwargs)
    not_in_memory_prediction = m.predict_tile(in_memory=False, **shared_kwargs)

    assert in_memory_prediction.equals(not_in_memory_prediction)

@pytest.mark.parametrize("patch_overlap", [0.1, 0])
def test_predict_tile_from_array(m, patch_overlap, raster_path):
Expand Down Expand Up @@ -634,4 +648,4 @@ def test_predict_tile_with_crop_model(m, config):
assert set(result.columns) == {
"xmin", "ymin", "xmax", "ymax", "label", "score", "cropmodel_label", "geometry",
"cropmodel_score", "image_path"
}
}

0 comments on commit def50b9

Please sign in to comment.