diff --git a/MANIFEST.in b/MANIFEST.in index 009b5940..a7bc07d2 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -11,6 +11,7 @@ include deepforest/data/2019_YELL_2_528000_4978000_image_crop2.xml include deepforest/data/2019_YELL_2_528000_4978000_image_crop2.png include deepforest/data/AWPE* include deepforest/data/example.csv +include deepforest/data/test_tiled.tif include LICENSE include dev_requirements.txt diff --git a/deepforest/data/test_tiled.tif b/deepforest/data/test_tiled.tif new file mode 100644 index 00000000..33ad487c Binary files /dev/null and b/deepforest/data/test_tiled.tif differ diff --git a/deepforest/dataset.py b/deepforest/dataset.py index a7cbf511..4d662e9d 100644 --- a/deepforest/dataset.py +++ b/deepforest/dataset.py @@ -195,6 +195,8 @@ class RasterDataset: raster_path (str): Path to raster file patch_size (int): Size of windows to predict on patch_overlap (float): Overlap between windows as fraction (0-1) + Returns: + A dataset of raster windows """ def __init__(self, raster_path, patch_size, patch_overlap): self.raster_path = raster_path @@ -206,6 +208,16 @@ def __init__(self, raster_path, patch_size, patch_overlap): width = src.shape[0] height = src.shape[1] + # Check is tiled + if not src.is_tiled: + raise ValueError( + "Out-of-memory dataset is selected, but raster is not tiled, " + "leading to entire raster being read into memory and defeating " + "the purpose of an out-of-memory dataset. " + "\nPlease run: " + "\ngdal_translate -of GTiff -co TILED=YES " + "to create a tiled raster" + ) # Generate sliding windows self.windows = slidingwindow.generateForSize( height, diff --git a/deepforest/main.py b/deepforest/main.py index 4c77ca75..9b598c1f 100644 --- a/deepforest/main.py +++ b/deepforest/main.py @@ -492,7 +492,7 @@ def predict_tile(self, iou_threshold: Minimum iou overlap among predictions between windows to be suppressed. Lower values suppress more boxes at edges. - in_memory: If true, the entire dataset is loaded into memory. 
This is useful for small datasets, but not recommended for large datasets since both the tile and the crops are stored in memory. + in_memory: If true, the entire dataset is loaded into memory, which increases speed. This is useful for small datasets, but not recommended for very large datasets. mosaic: Return a single prediction dataframe (True) or a tuple of image crops and predictions (False) sigma: variance of Gaussian function used in Gaussian Soft NMS thresh: the score thresh used to filter bboxes after soft-nms performed @@ -509,6 +509,10 @@ def predict_tile(self, - color: Deprecated bounding box color for visualizations. - thickness: Deprecated bounding box thickness for visualizations. + Raises: + - ValueError: If `raster_path` is None when `in_memory=False`. + - ValueError: If `workers` is greater than 0 when `in_memory=False`. Multiprocessing is not supported when using out-of-memory datasets because rasterio is not thread-safe. + Returns: - If `return_plot` is True, returns an image with predictions overlaid (deprecated).
- If `mosaic` is True, returns a Pandas DataFrame containing the predicted diff --git a/tests/test_dataset.py b/tests/test_dataset.py index eb9e8747..a047db35 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -185,17 +185,16 @@ def test_bounding_box_dataset(): # Check the shape of the RGB tensor assert item.shape == (3, 224, 224) -def test_raster_dataset(raster_path): +def test_raster_dataset(): """Test the RasterDataset class""" from deepforest.dataset import RasterDataset import torch from torch.utils.data import DataLoader # Test initialization and context manager - ds = RasterDataset(raster_path, patch_size=256, patch_overlap=0.1) + ds = RasterDataset(get_data("test_tiled.tif"), patch_size=256, patch_overlap=0.1) # Test basic properties - assert hasattr(ds, 'raster') assert hasattr(ds, 'windows') # Test first window @@ -209,8 +208,4 @@ def test_raster_dataset(raster_path): dataloader = DataLoader(ds, batch_size=2, num_workers=0) batch = next(iter(dataloader)) assert batch.shape[0] == 2 # Batch size - assert batch.shape[1] == 3 # Channels first - - # Test that context manager closed the raster - ds.close() - assert ds.raster.closed + assert batch.shape[1] == 3 # Channels first \ No newline at end of file