Skip to content

Commit

Permalink
make stats and histogram optional #466 (#467)
Browse files Browse the repository at this point in the history
* make stats and histogram optional #466

* update docstring and changelog

* update comment

* fix(docs): use PR for changelog

* Update tests/core/test_add_raster.py

* Update tests/core/test_add_raster.py

---------

Co-authored-by: Pete Gadomski <[email protected]>
  • Loading branch information
thomas-maschler and gadomski authored Oct 2, 2023
1 parent 53b99ad commit 149f3c8
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 11 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

- Make computation of statistics and histogram optional for `core.add_raster.add_raster_to_item` ([#467](https://github.com/stac-utils/stactools/pull/467))

## [0.5.2] - 2023-09-20

### Fixed
Expand Down
37 changes: 26 additions & 11 deletions src/stactools/core/add_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@
BINS = 256


def add_raster_to_item(item: Item) -> Item:
def add_raster_to_item(
item: Item, statistics: bool = True, histogram: bool = True
) -> Item:
"""Adds the raster extension to an item.
Args:
item (Item): The PySTAC Item to extend.
statistics (bool): Compute band statistics (min/max). Defaults to True
histogram (bool): Compute band histogram. Defaults to True
Returns:
Item:
Expand All @@ -34,27 +38,38 @@ def add_raster_to_item(item: Item) -> Item:
if asset.roles and "data" in asset.roles:
raster = RasterExtension.ext(asset)
href = make_absolute_href(asset.href, item.get_self_href())
bands = _read_bands(href)
bands = _read_bands(href, statistics, histogram)
if bands:
raster.apply(bands)
return item


def _read_bands(href: str) -> List[RasterBand]:
def _read_bands(href: str, statistics: bool, histogram: bool) -> List[RasterBand]:
bands = []
with rasterio.open(href) as dataset:
for i, index in enumerate(dataset.indexes):
data = dataset.read(index, masked=True)
band = RasterBand.create()
band.nodata = dataset.nodatavals[i]
band.spatial_resolution = dataset.transform[0]
band.data_type = DataType(dataset.dtypes[i])
minimum = float(numpy.min(data))
maximum = float(numpy.max(data))
band.statistics = Statistics.create(minimum=minimum, maximum=maximum)
hist_data, _ = numpy.histogram(data, range=(minimum, maximum), bins=BINS)
band.histogram = Histogram.create(
BINS, minimum, maximum, hist_data.tolist()
)

if statistics or histogram:
data = dataset.read(index, masked=True)
minimum = float(numpy.nanmin(data))
maximum = float(numpy.nanmax(data))
if statistics:
band.statistics = Statistics.create(minimum=minimum, maximum=maximum)
if histogram:
# the entire array is masked, or all values are NAN.
# won't be able to compute histogram and will return empty array.
if numpy.isnan(minimum):
band.histogram = Histogram.create(0, minimum, maximum, [])
else:
hist_data, _ = numpy.histogram(
data, range=(minimum, maximum), bins=BINS
)
band.histogram = Histogram.create(
BINS, minimum, maximum, hist_data.tolist()
)
bands.append(band)
return bands
176 changes: 176 additions & 0 deletions tests/core/test_add_raster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import tempfile
from typing import Callable, List, Optional

import numpy as np
import pystac
import pytest
import rasterio
from rasterio.crs import CRS
from rasterio.transform import Affine
from stactools.core import create
from stactools.core.add_raster import add_raster_to_item


def random_data(count: int) -> np.ndarray:
return np.random.rand(count, 10, 10) * 10


def nan_data(count: int) -> np.ndarray:
data = np.empty((count, 10, 10))
data[:] = np.nan
return data


def data_with_nan(count: int) -> np.ndarray:
data = np.random.rand(count, 10, 10) * 10
data[0][1][1] = np.nan
return data


def zero_data(count: int) -> np.ndarray:
return np.zeros((count, 10, 10))


def test_add_raster(tmp_asset_path) -> None:
item = create.item(tmp_asset_path)
add_raster_to_item(item)

asset: pystac.Asset = item.assets["data"]

_assert_asset(
asset,
expected_count=1,
expected_nodata=None,
expected_spatial_resolution=60.0,
expected_dtype=np.dtype("uint8"),
expected_min=[74.0],
expected_max=[255.0],
)


@pytest.mark.parametrize(
"count,nodata,dtype,datafunc,hist_count",
[
(1, 0, np.dtype("int8"), random_data, 256),
(1, None, np.dtype("float64"), random_data, 256),
(1, np.nan, np.dtype("float64"), random_data, 256),
(2, 0, np.dtype("int8"), random_data, 256),
(2, None, np.dtype("float64"), random_data, 256),
(2, np.nan, np.dtype("float64"), random_data, 256),
(1, 0, np.dtype("uint8"), zero_data, 0),
(1, None, np.dtype("uint8"), zero_data, 256),
(1, None, np.dtype("float64"), nan_data, 0),
(1, np.nan, np.dtype("float64"), nan_data, 0),
(1, None, np.dtype("float64"), data_with_nan, 256),
(1, np.nan, np.dtype("float64"), data_with_nan, 256),
],
)
def test_add_raster_with_nodata(
count: int, nodata: float, dtype: np.dtype, datafunc: Callable, hist_count: int
) -> None:
with tempfile.NamedTemporaryFile(suffix=".tif") as tmpfile:
with rasterio.open(
tmpfile.name,
mode="w",
driver="GTiff",
count=count,
nodata=nodata,
dtype=dtype,
transform=Affine(0.1, 0.0, 1.0, 0.0, -0.1, 1.0),
width=10,
height=10,
crs=CRS.from_epsg(4326),
) as dst:
data = datafunc(count)
data.astype(dtype)
dst.write(data)

with rasterio.open(tmpfile.name) as src:
data = src.read(masked=True)
minimum = []
maximum = []
for i, _ in enumerate(src.indexes):
minimum.append(float(np.nanmin(data[i])))
maximum.append(float(np.nanmax(data[i])))

item = create.item(tmpfile.name)

add_raster_to_item(item)

asset: pystac.Asset = item.assets["data"]
_assert_asset(
asset,
expected_count=count,
expected_nodata=nodata,
expected_spatial_resolution=0.1,
expected_dtype=dtype,
expected_min=minimum,
expected_max=maximum,
expected_hist_count=hist_count,
)


def test_add_raster_without_stats(tmp_asset_path) -> None:
item = create.item(tmp_asset_path)
add_raster_to_item(item, statistics=False)

asset: pystac.Asset = item.assets["data"]
bands = asset.extra_fields.get("raster:bands")

assert bands[0].get("statistics") is None
assert bands[0].get("histogram")


def test_add_raster_without_histogram(tmp_asset_path) -> None:
item = create.item(tmp_asset_path)
add_raster_to_item(item, histogram=False)

asset: pystac.Asset = item.assets["data"]
bands = asset.extra_fields.get("raster:bands")

assert bands[0].get("statistics")
assert bands[0].get("histogram") is None


def _assert_asset(
asset: pystac.Asset,
expected_count: int,
expected_nodata: Optional[float],
expected_dtype: np.dtype,
expected_spatial_resolution: float,
expected_min: List[float],
expected_max: List[float],
expected_hist_count=256,
) -> None:
bands = asset.extra_fields.get("raster:bands")
assert bands
assert len(bands) == expected_count

for i, band in enumerate(bands):
nodata = band.get("nodata")
dtype = band["data_type"].value
spatial_resolution = band["spatial_resolution"]
statistics = band["statistics"]
histogram = band["histogram"]
assert nodata == expected_nodata or (
np.isnan(nodata) and np.isnan(expected_nodata)
)
assert dtype == expected_dtype.name
assert spatial_resolution == expected_spatial_resolution
assert statistics == {
"minimum": expected_min[i],
"maximum": expected_max[i],
} or (
np.isnan(statistics["maximum"])
and np.isnan(expected_max[i])
and np.isnan(statistics["minimum"])
and np.isnan(expected_min[i])
)
assert histogram["count"] == expected_hist_count
assert histogram["max"] == band["statistics"]["maximum"] or (
np.isnan(histogram["max"]) and np.isnan(statistics["maximum"])
)
assert histogram["min"] == band["statistics"]["minimum"] or (
np.isnan(histogram["min"]) and np.isnan(statistics["minimum"])
)
assert len(histogram["buckets"]) == histogram["count"]

0 comments on commit 149f3c8

Please sign in to comment.