Skip to content

Commit

Permalink
added test object
Browse files Browse the repository at this point in the history
  • Loading branch information
bw4sz committed Nov 5, 2024
1 parent 4dd9914 commit 0d13175
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 9 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ include deepforest/data/2019_YELL_2_528000_4978000_image_crop2.xml
include deepforest/data/2019_YELL_2_528000_4978000_image_crop2.png
include deepforest/data/AWPE*
include deepforest/data/example.csv
include deepforest/data/test_tiled.tif

include LICENSE
include dev_requirements.txt
Expand Down
Binary file added deepforest/data/test_tiled.tif
Binary file not shown.
12 changes: 12 additions & 0 deletions deepforest/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ class RasterDataset:
raster_path (str): Path to raster file
patch_size (int): Size of windows to predict on
patch_overlap (float): Overlap between windows as fraction (0-1)
Returns:
A dataset of raster windows
"""
def __init__(self, raster_path, patch_size, patch_overlap):
self.raster_path = raster_path
Expand All @@ -206,6 +208,16 @@ def __init__(self, raster_path, patch_size, patch_overlap):
width = src.shape[0]
height = src.shape[1]

# Check that the raster is tiled
if not src.is_tiled:
raise ValueError(
"Out-of-memory dataset is selected, but raster is not tiled, "
"leading to entire raster being read into memory and defeating "
"the purpose of an out-of-memory dataset. "
"\nPlease run: "
"\ngdal_translate -of GTiff -co TILED=YES <input> <output> "
"to create a tiled raster"
)
# Generate sliding windows
self.windows = slidingwindow.generateForSize(
height,
Expand Down
6 changes: 5 additions & 1 deletion deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def predict_tile(self,
iou_threshold: Minimum iou overlap among predictions between
windows to be suppressed.
Lower values suppress more boxes at edges.
in_memory: If true, the entire dataset is loaded into memory. This is useful for small datasets, but not recommended for large datasets since both the tile and the crops are stored in memory.
in_memory: If true, the entire dataset is loaded into memory, which increases speed. This is useful for small datasets, but not recommended for very large datasets.
mosaic: Return a single prediction dataframe (True) or a tuple of image crops and predictions (False)
sigma: variance of Gaussian function used in Gaussian Soft NMS
thresh: the score thresh used to filter bboxes after soft-nms performed
Expand All @@ -509,6 +509,10 @@ def predict_tile(self,
- color: Deprecated bounding box color for visualizations.
- thickness: Deprecated bounding box thickness for visualizations.
Raises:
- ValueError: If `raster_path` is None when `in_memory=False`.
- ValueError: If `workers` is greater than 0 when `in_memory=False`. Multiprocessing is not supported when using out-of-memory datasets because rasterio is not thread-safe.
Returns:
- If `return_plot` is True, returns an image with predictions overlaid (deprecated).
- If `mosaic` is True, returns a Pandas DataFrame containing the predicted
Expand Down
11 changes: 3 additions & 8 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,17 +185,16 @@ def test_bounding_box_dataset():
# Check the shape of the RGB tensor
assert item.shape == (3, 224, 224)

def test_raster_dataset(raster_path):
def test_raster_dataset():
"""Test the RasterDataset class"""
from deepforest.dataset import RasterDataset
import torch
from torch.utils.data import DataLoader

# Test initialization and context manager
ds = RasterDataset(raster_path, patch_size=256, patch_overlap=0.1)
ds = RasterDataset(get_data("test_tiled.tif"), patch_size=256, patch_overlap=0.1)

# Test basic properties
assert hasattr(ds, 'raster')
assert hasattr(ds, 'windows')

# Test first window
Expand All @@ -209,8 +208,4 @@ def test_raster_dataset(raster_path):
dataloader = DataLoader(ds, batch_size=2, num_workers=0)
batch = next(iter(dataloader))
assert batch.shape[0] == 2 # Batch size
assert batch.shape[1] == 3 # Channels first

# Test that context manager closed the raster
ds.close()
assert ds.raster.closed
assert batch.shape[1] == 3 # Channels first

0 comments on commit 0d13175

Please sign in to comment.