Skip to content

Commit

Permalink
Merge pull request #32 from mpelchat04/bbox
Browse files Browse the repository at this point in the history
Fix support for STAC and replace bands_requested and bbox type for list instead of string.
  • Loading branch information
mpelchat04 authored Nov 13, 2024
2 parents 2bbe63d + 3d362b8 commit 9e02167
Show file tree
Hide file tree
Showing 10 changed files with 325 additions and 31 deletions.
33 changes: 19 additions & 14 deletions geo_inference/geo_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
import threading
import numpy as np
import xarray as xr
import rioxarray
import ttach as tta
from typing import Dict
from dask import config
import dask.array as da
from pathlib import Path
from omegaconf import ListConfig
from rasterio.windows import from_bounds
from rasterio.transform import from_origin
from typing import Union, Sequence, List
from dask.diagnostics import ProgressBar
from multiprocessing.pool import ThreadPool
Expand Down Expand Up @@ -197,7 +199,7 @@ async def async_run_inference(self,
raise TypeError(
f"Invalid raster type.\nGot {inference_input} of type {type(inference_input)}"
)
if not isinstance(bands_requested, (Sequence, ListConfig)):
if not isinstance(bands_requested, (List, ListConfig)):
raise ValueError(
f"Requested bands should be a list."
f"\nGot {bands_requested} of type {type(bands_requested)}"
Expand Down Expand Up @@ -238,9 +240,9 @@ async def async_run_inference(self,
raster_stac_item = True
except Exception:
raster_stac_item = False
self.json = None
if not raster_stac_item:
inference_input_path = Path(inference_input)
self.json = None
if os.path.splitext(inference_input_path)[1].lower() == ".zarr":
aoi_dask_array = da.from_zarr(inference_input, chunks=(1, stride_patch_size, stride_patch_size))
meta_data_json = re.sub(r'\.zarr$', '', inference_input)
Expand All @@ -249,32 +251,33 @@ async def async_run_inference(self,
with rasterio.open(inference_input, "r") as src:
self.raster_meta = src.meta
self.raster = src
import rioxarray

aoi_dask_array = rioxarray.open_rasterio(inference_input, chunks=(1, stride_patch_size, stride_patch_size))
try:
if bands_requested:
raster_bands_request = [int(b) for b in bands_requested.split(",")]
if (
len(raster_bands_request) != 0
and len(raster_bands_request) != aoi_dask_array.shape[0]
len(bands_requested) != 0
and len(bands_requested) != aoi_dask_array.shape[0]
):
if self.json is None:
aoi_dask_array = xr.concat(
[aoi_dask_array[i - 1, :, :] for i in raster_bands_request],
[aoi_dask_array[i - 1, :, :] for i in bands_requested],
dim="band"
)
else:
aoi_dask_array = da.stack(
[aoi_dask_array[i - 1, :, :] for i in raster_bands_request],
[aoi_dask_array[i - 1, :, :] for i in bands_requested],
axis =0,
)
except Exception as e:
raise e
else:
assets = asset_by_common_name(inference_input)
bands_requested = {
band: assets[band] for band in bands_requested.split(",")
}
try:
bands_requested = {band: assets[band.lower()] for band in bands_requested}
except KeyError:
raise KeyError(f"Common names of the STAC assets ({assets.keys()}) do not match provided bands_requested keys ({bands_requested}).")

rio_gdal_options = {
"GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR",
"CPL_VSIL_CURL_ALLOWED_EXTENSIONS": ".tif",
Expand All @@ -290,8 +293,10 @@ async def async_run_inference(self,
del all_bands_requested

if bbox is not None:
bbox = tuple(map(float, bbox.split(", ")))
roi_window = from_bounds(
if not isinstance(bbox, (List, ListConfig)):
raise TypeError("bbox should be a list.")
bbox = tuple(map(float, bbox))
self.roi_window = from_bounds(
left=bbox[0],
bottom=bbox[1],
right=bbox[2],
Expand Down Expand Up @@ -361,7 +366,7 @@ async def async_run_inference(self,

with ProgressBar() as pbar:
pbar.register()
import rioxarray
# import rioxarray
logger.info("Inference is running:")
aoi_dask_array = xr.DataArray(aoi_dask_array[: self.original_shape[1], : self.original_shape[2]], dims=("y", "x"), attrs= self.json if self.json is not None else xarray_profile_info(self.raster_meta))
aoi_dask_array.rio.to_raster(mask_path, tiled=True, lock=threading.Lock())
Expand Down
4 changes: 2 additions & 2 deletions geo_inference/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def cmd_interface(argv=None):
"-br",
"--bands_requested",
nargs=1,
help="bands_requested in this format'R,G,B'",
help="bands_requested in this format['Red','Green','Blue'] or [1,2,3]",
)

parser.add_argument(
Expand Down Expand Up @@ -490,7 +490,7 @@ def cmd_interface(argv=None):
config = read_yaml(args.args[0])
image = config["arguments"]["image"]
model = config["arguments"]["model"]
bbox = None if config["arguments"]["bbox"].lower() == "none" else config["arguments"]["bbox"]
bbox = None if config["arguments"]["bbox"] == "None" else config["arguments"]["bbox"]
work_dir = config["arguments"]["work_dir"]
bands_requested = config["arguments"]["bands_requested"]
workers = config["arguments"]["workers"]
Expand Down
Binary file removed tests/data/inference/test_model/test_model.pt
Binary file not shown.
7 changes: 3 additions & 4 deletions tests/data/sample.yaml
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
arguments:
image: "./data/areial.tiff" # Path to Geotiff
bbox: None # "minx, miny, maxx, maxy"
image: "./data/0.tif" # Path to Geotiff
bbox: None # [minx, miny, maxx, maxy]
model: "rgb-4class-segformer" # Name of Extraction Model: str
work_dir: None # Working Directory: str
vec: False # Vector Conversion: bool
yolo: False # YOLO Conversion: bool
coco: False # COCO Conversion: bool
device: "gpu" # cpu or gpu: str
gpu_id: 0 # GPU ID: int
bands_requested: '1,2,3' # requested Bands
bands_requested: [1,2,3] # requested Bands
workers: 0
mgpu: False
classes : 5
n_workers: 20
prediction_thr : 0.3
transformers: False
transformer_flip : False
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
250 changes: 250 additions & 0 deletions tests/data/stac/SpaceNet_AOI_2_Las_Vegas.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
{
"type": "Feature",
"stac_version": "1.0.0",
"stac_extensions": [
"https://stac-extensions.github.io/projection/v1.0.0/schema.json",
"https://stac-extensions.github.io/eo/v1.0.0/schema.json"
],
"id": "SpaceNet_AOI_2_Las_Vegas",
"collection": "test",
"geometry": {
"type": "Polygon",
"coordinates": [
[
[
-53.1659631578073,
47.42365019337817
],
[
-53.167569143274676,
47.462743080832595
],
[
-53.1691785188727,
47.50183565900671
],
[
-53.17079129363732,
47.540927927817044
],
[
-53.17240747663812,
47.58001988718016
],
[
-53.17402707697844,
47.61911153701266
],
[
-53.17565010379554,
47.6582028772313
],
[
-53.17727656626078,
47.697293907752794
],
[
-53.17890647357976,
47.736384628494
],
[
-53.18053983499246,
47.77547503937181
],
[
-53.18217665977342,
47.81456514030315
],
[
-53.18274315341038,
47.82807460710128
],
[
-53.22072950009213,
47.827345731985154
],
[
-53.278774162009874,
47.826207664203224
],
[
-53.336815195757985,
47.82504027130873
],
[
-53.39276323165151,
47.823887148354494
],
[
-53.39269853547722,
47.82248011926179
],
[
-53.39090339059741,
47.783395812804095
],
[
-53.38911204395473,
47.74431118789802
],
[
-53.38732448537806,
47.70522624465101
],
[
-53.38554070473415,
47.66614098317041
],
[
-53.38376069192746,
47.627055403563595
],
[
-53.38198443690002,
47.58796950593788
],
[
-53.3802119296312,
47.54888329040058
],
[
-53.37844316013759,
47.50979675705901
],
[
-53.37667811847282,
47.470709906020375
],
[
-53.37491679472736,
47.43162273739196
],
[
-53.37437224107312,
47.41952128561759
],
[
-53.33461086840071,
47.42033831586188
],
[
-53.277016005703906,
47.42149729651587
],
[
-53.1659631578073,
47.42365019337817
]
]
]
},
"bbox": [
-53.39276323165151,
47.41952128561759,
-53.1659631578073,
47.82807460710128
],
"properties": {
"created": "2024-08-23T02:41:12Z",
"updated": "2024-08-23T02:41:14.149265Z",
"datetime": "2023-10-30T14:25:29Z",
"proj:epsg": 32622,
"collection": "test",
"proj:shape": [
103456,
36184
],
"proj:geometry": {
"type": "Polygon",
"coordinates": [
[
[
336640.77805401257,
5254518.836411794
],
[
336640.77805401257,
5299498.415712714
],
[
320909.0554141993,
5299498.415712714
],
[
320909.0554141993,
5254518.836411794
],
[
336640.77805401257,
5254518.836411794
]
]
]
},
"eo:cloud_cover": 0,
"proj:transform": [
320909.0554141993,
0.434770137072,
0,
5299498.415712714,
0,
-0.434770137072
]
},
"assets": {
"B": {
"href": "./tests/data/stac/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03-B.tif",
"type": "image/tiff; application=geotiff; profile=cloud-optimized",
"roles": [
"data"
],
"title": "Blue band",
"eo:bands": [
{
"name": "b",
"common_name": "blue",
"description": "sensor:GE01, min_wavelength_nm: 450, max_wavelength_nm: 510, orthorectified, pansharpened, downsampled to 8 bit",
"center_wavelength": 0.48,
"full_width_half_max": 0.06
}
],
"description": "COG - Blue Single spectral band / COG - Bande spectrale unique Bleu"
},
"G": {
"href": "./tests/data/stac/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03-G.tif",
"type": "image/tiff; application=geotiff; profile=cloud-optimized",
"roles": [
"data"
],
"title": "Green band",
"eo:bands": [
{
"name": "g",
"common_name": "green",
"description": "sensor:GE01, min_wavelength_nm: 520, max_wavelength_nm: 580, orthorectified, pansharpened, downsampled to 8 bit",
"center_wavelength": 0.55,
"full_width_half_max": 0.05999999999999994
}
],
"description": "COG - Green Single spectral band / COG - Bande spectrale unique Vert"
},
"R": {
"href": "./tests/data/stac/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03-R.tif",
"type": "image/tiff; application=geotiff; profile=cloud-optimized",
"roles": [
"data"
],
"title": "Red band",
"eo:bands": [
{
"name": "r",
"common_name": "red",
"description": "sensor:GE01, min_wavelength_nm: 655, max_wavelength_nm: 690, orthorectified, pansharpened, downsampled to 8 bit",
"center_wavelength": 0.6725,
"full_width_half_max": 0.03499999999999992
}
],
"description": "COG - Red Single spectral band / COG - Bande spectrale unique Rouge"
}
}
}
Loading

0 comments on commit 9e02167

Please sign in to comment.