Skip to content

Commit

Permalink
Merge pull request #15 from dlr-eoc/add_four_band_model
Browse files Browse the repository at this point in the history
Add four band model
  • Loading branch information
MWieland authored May 12, 2023
2 parents 3570aa9 + 336ce32 commit 2470575
Show file tree
Hide file tree
Showing 28 changed files with 96 additions and 40 deletions.
Empty file modified .github/ISSUE_TEMPLATE/bug_report.md
100644 → 100755
Empty file.
Empty file modified .github/ISSUE_TEMPLATE/feature_request.md
100644 → 100755
Empty file.
Empty file modified .github/workflows/black.yml
100644 → 100755
Empty file.
Empty file modified .github/workflows/python-publish.yml
100644 → 100755
Empty file.
Empty file modified .github/workflows/pythonapp.yml
100644 → 100755
Empty file.
Empty file modified .gitignore
100644 → 100755
Empty file.
6 changes: 6 additions & 0 deletions CHANGELOG.rst
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
Changelog
=========

[0.1.9] (2023-05-12)

--------------------
Added
*******
- support for 4 band (R-G-B-NIR) images

[0.1.8] (2023-01-30)
--------------------
Expand Down
Empty file modified CITATION.cff
100644 → 100755
Empty file.
Empty file modified CODE_OF_CONDUCT.md
100644 → 100755
Empty file.
Empty file modified CONTRIBUTING.md
100644 → 100755
Empty file.
Empty file modified DLR_Individual_Contributor_License_Agreement_UKIS.pdf
100644 → 100755
Empty file.
Empty file modified LICENSE
100644 → 100755
Empty file.
Empty file modified MANIFEST.in
100644 → 100755
Empty file.
6 changes: 3 additions & 3 deletions README.md
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
[![Code Style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://black.readthedocs.io/en/stable/)
[![DOI](https://zenodo.org/badge/328616234.svg)](https://zenodo.org/badge/latestdoi/328616234)

UKIS Cloud Shadow MASK (ukis-csmask) package masks clouds and cloud shadows in Sentinel-2, Landsat-9, Landsat-8, Landsat-7 and Landsat-5 images. Masking is performed with a pre-trained convolution neural network. It is fast and works directly on Level-1C data (no atmospheric correction required). Images just need to be in Top Of Atmosphere (TOA) reflectance and include at least the "Blue", "Green", "Red", "NIR", "SWIR1" and "SWIR2" spectral bands. Best performance (in terms of accuracy and speed) is achieved when images are resampled to approximately 30 m spatial resolution.
UKIS Cloud Shadow MASK (ukis-csmask) package masks clouds and cloud shadows in Sentinel-2, Landsat-9, Landsat-8, Landsat-7 and Landsat-5 images. Masking is performed with a pre-trained convolution neural network. It is fast and works directly on Level-1C data (no atmospheric correction required). Images just need to be in Top Of Atmosphere (TOA) reflectance and include at least the "Blue", "Green", "Red" and "NIR" spectral bands. Best performance (in terms of accuracy and speed) is achieved when images also include "SWIR1" and "SWIR2" spectral bands and are resampled to approximately 30 m spatial resolution.

This [publication](https://doi.org/10.1016/j.rse.2019.05.022) provides further insight into the underlying algorithm and compares it to the widely used [Fmask](http://www.pythonfmask.org/en/latest/) algorithm across a heterogeneous test dataset.

Expand Down Expand Up @@ -41,8 +41,8 @@ img.warp(
)

# compute cloud and cloud shadow mask
# NOTE: band_order must match the order of bands in the input image. it does not have to be in this explicit order,
# but needs to include these six spectral bands.
# NOTE: band_order must match the order of bands in the input image. it does not have to be in this explicit order.
# make sure to use these six spectral bands to get best performance
csmask = CSmask(
img=img.arr,
band_order=["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"],
Expand Down
Empty file modified img/examples.png
100644 → 100755
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified img/ukis-logo.png
100644 → 100755
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified requirements.txt
100644 → 100755
Empty file.
Empty file modified setup.py
100644 → 100755
Empty file.
49 changes: 47 additions & 2 deletions tests/test_mask.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
(np.empty((256, 256, 6), np.float32), ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"], None),
(np.empty((256, 256, 8), np.float32), ["Green", "Red", "Blue", "NIR", "SWIR1", "SWIR2", "NIR2", "ETC"], None),
(np.empty((256, 256, 6), np.float32), ["Green", "Red", "Blue", "NIR", "SWIR1", "SWIR2"], -666),
(np.empty((256, 256, 4), np.float32), ["Red", "Green", "Blue", "NIR"], 0),
(np.empty((256, 256, 5), np.float32), ["Red", "Green", "Blue", "NIR", "SWIR2"], 0),
],
)
def test_csmask_init(img, band_order, nodata_value):
Expand All @@ -31,6 +33,7 @@ def test_csmask_init(img, band_order, nodata_value):
(np.empty((256, 256, 3), np.float32), ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"], None),
(np.empty((256, 256, 6), np.uint8), ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"], None),
(np.empty((256, 256, 6), np.float32), ["Blue", "Green", "Yellow", "NIR", "SWIR1", "SWIR2"], None),
(np.empty((256, 256, 6), np.float32), None, None),
],
)
def test_csmask_init_raises(img, band_order, nodata_value):
Expand Down Expand Up @@ -60,7 +63,7 @@ def test_csmask_init_warns(img, band_order, nodata_value):
np.load(r"tests/testfiles/landsat5.npz"),
],
)
def test_csmask_csm(data):
def test_csmask_csm_6band(data):
csmask = CSmask(img=data["img"], band_order=["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"])
y_pred = csmask.csm
y_true = reclassify(data["msk"], {"reclass_value_from": [0, 1, 2, 3, 4], "reclass_value_to": [2, 0, 0, 0, 1]})
Expand All @@ -80,7 +83,7 @@ def test_csmask_csm(data):
np.load(r"tests/testfiles/landsat5.npz"),
],
)
def test_csmask_valid(data):
def test_csmask_valid_6band(data):
csmask = CSmask(img=data["img"], band_order=["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"])
y_pred = csmask.valid
y_true = reclassify(data["msk"], {"reclass_value_from": [0, 1, 2, 3, 4], "reclass_value_to": [0, 1, 1, 1, 0]})
Expand All @@ -90,3 +93,45 @@ def test_csmask_valid(data):
y_pred = y_pred.ravel()
kappa = round(cohen_kappa_score(y_true, y_pred), 2)
assert kappa >= 0.75


@pytest.mark.filterwarnings("ignore::UserWarning")
@pytest.mark.parametrize(
"data",
[
np.load(r"tests/testfiles/sentinel2.npz"),
np.load(r"tests/testfiles/landsat8.npz"),
np.load(r"tests/testfiles/landsat7.npz"),
np.load(r"tests/testfiles/landsat5.npz"),
],
)
def test_csmask_csm_4band(data):
csmask = CSmask(img=data["img"], band_order=["Blue", "Green", "Red", "NIR"])
y_pred = csmask.csm
y_true = reclassify(data["msk"], {"reclass_value_from": [0, 1, 2, 3, 4], "reclass_value_to": [2, 0, 0, 0, 1]})
y_true = y_true.ravel()
y_pred = y_pred.ravel()
kappa = round(cohen_kappa_score(y_true, y_pred), 2)
assert kappa >= 0.50


@pytest.mark.filterwarnings("ignore::UserWarning")
@pytest.mark.parametrize(
"data",
[
np.load(r"tests/testfiles/sentinel2.npz"),
np.load(r"tests/testfiles/landsat8.npz"),
np.load(r"tests/testfiles/landsat7.npz"),
np.load(r"tests/testfiles/landsat5.npz"),
],
)
def test_csmask_valid_4band(data):
csmask = CSmask(img=data["img"], band_order=["Blue", "Green", "Red", "NIR"])
y_pred = csmask.valid
y_true = reclassify(data["msk"], {"reclass_value_from": [0, 1, 2, 3, 4], "reclass_value_to": [0, 1, 1, 1, 0]})
y_true_inverted = ~y_true.astype(bool)
y_true = (~ndimage.binary_dilation(y_true_inverted, iterations=4).astype(bool)).astype(np.uint8)
y_true = y_true.ravel()
y_pred = y_pred.ravel()
kappa = round(cohen_kappa_score(y_true, y_pred), 2)
assert kappa >= 0.50
Empty file modified tests/testfiles/landsat5.npz
100644 → 100755
Empty file.
Empty file modified tests/testfiles/landsat7.npz
100644 → 100755
Empty file.
Empty file modified tests/testfiles/landsat8.npz
100644 → 100755
Empty file.
Empty file modified tests/testfiles/sentinel2.npz
100644 → 100755
Empty file.
2 changes: 1 addition & 1 deletion ukis_csmask/__init__.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.8"
__version__ = "0.1.9"
73 changes: 39 additions & 34 deletions ukis_csmask/mask.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -24,32 +24,30 @@ class CSmask:
def __init__(
self,
img,
band_order=None,
band_order,
nodata_value=None,
invalid_buffer=4,
):
"""
:param img: Input satellite image of shape (rows, cols, bands). (ndarray).
Requires images of Sentinel-2, Landsat-8, -7 or -5 in Top of Atmosphere reflectance [0, 1].
Requires image bands to include at least "Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2".
Requires image bands to be in approximately 30 m resolution.
:param band_order: Image band order. (dict).
>>> band_order = {0: "Blue", 1: "Green", 2: "Red", 3: "NIR", 4: "SWIR1", 5: "SWIR2"}
Requires satellite images in Top of Atmosphere reflectance [0, 1].
Requires image bands to include at least "Blue", "Green", "Red", "NIR" (uses 4 band model).
For better performance requires image bands to include "Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2" (runs 6 band model).
For better performance requires image bands to be in approximately 30 m resolution.
:param band_order: Image band order. (list of string).
>>> band_order = ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"]
:param nodata_value: Additional nodata value that will be added to valid mask. (num).
:param invalid_buffer: Number of pixel that should be buffered around invalid areas.
:param invalid_buffer: Number of pixels that should be buffered around invalid areas.
"""
# consistency checks on input image
if band_order is None:
band_order = ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"]

if isinstance(img, np.ndarray) is False:
raise TypeError("img must be of type np.ndarray")

if img.ndim != 3:
raise TypeError("img must be of shape (rows, cols, bands)")

if img.shape[2] < 6:
raise TypeError("img must contain at least 6 spectral bands")
if img.shape[2] < 4:
raise TypeError("img must contain at least 4 spectral bands")

if img.dtype != np.float32:
raise TypeError("img must be in top of atmosphere reflectance with dtype float32")
Expand All @@ -62,28 +60,37 @@ def __init__(
)

# consistency checks on band_order
target_band_order = ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"]
if band_order != target_band_order:
if all(elem in band_order for elem in target_band_order):
# rearrange image bands to match target_band_order
idx = np.array(
[
np.where(band == np.array(band_order, dtype="S"))[0][0]
for band in np.array(target_band_order, dtype="S")
]
)
img = np.stack(np.asarray([img[:, :, i] for i in range(img.shape[2])])[idx], axis=2)
else:
raise TypeError(
"img must contain at least ['Blue', 'Green', 'Red', 'NIR', 'SWIR1', 'SWIR2'] spectral bands"
)
if band_order is None:
raise TypeError("band_order cannot be None")

if all(elem in band_order for elem in ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"]):
target_band_order = ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"]
band_mean = [0.19312, 0.18659, 0.18899, 0.30362, 0.23085, 0.16216]
band_std = [0.16431, 0.16762, 0.18230, 0.17409, 0.16020, 0.14164]
model_file = str(Path(__file__).parent) + "/model_6b.onnx"
elif all(elem in band_order for elem in ["Blue", "Green", "Red", "NIR"]):
target_band_order = ["Blue", "Green", "Red", "NIR"]
band_mean = [0.19312, 0.18659, 0.18899, 0.30362]
band_std = [0.16431, 0.16762, 0.18230, 0.17409]
model_file = str(Path(__file__).parent) + "/model_4b.onnx"
else:
# use image bands as are
img = np.stack(np.asarray([img[:, :, i] for i in range(img.shape[2])]), axis=2)
raise TypeError(
f"band_order must contain at least 'Blue', 'Green', 'Red', 'NIR' "
f"and for better performance also 'SWIR1' and 'SWIR2'"
)

# rearrange image bands to match target_band_order
idx = np.array(
[np.where(band == np.array(band_order, dtype="S"))[0][0] for band in np.array(target_band_order, dtype="S")]
)
img = img[:, :, idx]

self.img = img
self.band_order = band_order
self.band_mean = band_mean
self.band_std = band_std
self.nodata_value = nodata_value
self.model_file = model_file
self.csm = self._csm()
self.valid = self._valid(invalid_buffer)

Expand All @@ -96,13 +103,11 @@ def _csm(self):
x = tile_array(self.img, xsize=256, ysize=256, overlap=0.2)

# standardize feature space
x -= [0.19312, 0.18659, 0.18899, 0.30362, 0.23085, 0.16216]
x /= [0.16431, 0.16762, 0.18230, 0.17409, 0.16020, 0.14164]
x -= self.band_mean
x /= self.band_std

# start onnx inference session and load model
sess = onnxruntime.InferenceSession(
str(Path(__file__).parent) + "/model.onnx", providers=onnxruntime.get_available_providers()
)
sess = onnxruntime.InferenceSession(self.model_file, providers=onnxruntime.get_available_providers())

# predict on array tiles
y_prob = [sess.run(None, {"input_1": tile[np.newaxis, :]}) for n, tile in enumerate(list(x))]
Expand Down
Binary file added ukis_csmask/model_4b.onnx
Binary file not shown.
File renamed without changes.
Empty file modified ukis_csmask/utils.py
100644 → 100755
Empty file.

0 comments on commit 2470575

Please sign in to comment.