Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Testing Dask Arrays #119

Merged
merged 3 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion notebooks/Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
"version": "3.8.10"
}
},
"nbformat": 4,
Expand Down
1 change: 1 addition & 0 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pandera==0.14.5
shapely==2.0.1
geocube>=0.3.3
pandas==1.5.3
odc-geo @ git+https://github.com/opendatacube/odc-geo.git

###################################
# Potential future packages
Expand Down
30 changes: 22 additions & 8 deletions src/gval/comparison/tabulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
import xarray as xr
import pandera as pa
from pandera.typing import DataFrame
import dask

from gval.utils.schemas import Xrspatial_crosstab_df, Crosstab_df
from gval.homogenize.spatial_alignment import _check_dask_array


@pa.check_types
Expand All @@ -45,6 +47,8 @@ def _convert_crosstab_to_contigency_table(
DataFrame[Crosstab_df]
Crosstab DataFrame using candidate and benchmark conventions.
"""
if isinstance(crosstab_df, dask.dataframe.core.DataFrame):
crosstab_df = crosstab_df.compute()

# renames zone, renames column index, melts dataframe, then resets the index.
crosstab_df = (
Expand All @@ -63,7 +67,7 @@ def _convert_crosstab_to_contigency_table(
@pa.check_types
def _compute_agreement_values(
crosstab_df: DataFrame[Crosstab_df],
comparison_function: Callable[[float, float], float],
comparison_function: Callable[..., float],
) -> DataFrame[Crosstab_df]:
"""
Computes agreement values from Crosstab DataFrame.
Expand All @@ -81,12 +85,12 @@ def _compute_agreement_values(
Crosstab DataFrame with agreement values.
"""

# copy crosstab_df
crosstab_df = crosstab_df.copy()

def apply_pairing_function(row):
return comparison_function(row["candidate_values"], row["benchmark_values"])

# copy crosstab_df
crosstab_df = crosstab_df.copy()

agreement_values = crosstab_df.apply(apply_pairing_function, axis=1)

crosstab_df.insert(3, "agreement_values", agreement_values)
Expand Down Expand Up @@ -162,10 +166,14 @@ def _crosstab_2d_DataArrays(
allow_candidate_values: Optional[Iterable[Number]] = None,
allow_benchmark_values: Optional[Iterable[Number]] = None,
exclude_value: Optional[Number] = None,
comparison_function: Optional[Callable[[float, float], float]] = None,
comparison_function: Optional[Callable[..., float]] = None,
) -> DataFrame[Crosstab_df]:
"""Please see `_crosstab_docstring` function decorator for docstring"""

if _check_dask_array(candidate_map):
candidate_map = candidate_map.drop("spatial_ref")
benchmark_map = benchmark_map.drop("spatial_ref")

crosstab_df = crosstab(
zones=candidate_map,
values=benchmark_map,
Expand Down Expand Up @@ -194,7 +202,7 @@ def _crosstab_3d_DataArrays(
allow_candidate_values: Optional[Iterable[Number]] = None,
allow_benchmark_values: Optional[Iterable[Number]] = None,
exclude_value: Optional[Number] = None,
comparison_function: Optional[Callable[[float, float], float]] = None,
comparison_function: Optional[Callable[..., float]] = None,
) -> DataFrame[Crosstab_df]:
"""Please see `_crosstab_docstring` function decorator for docstring"""

Expand Down Expand Up @@ -270,7 +278,7 @@ def _crosstab_DataArrays(
allow_candidate_values: Optional[Iterable[Number]] = None,
allow_benchmark_values: Optional[Iterable[Number]] = None,
exclude_value: Optional[Number] = None,
comparison_function: Optional[Callable[[float, float], float]] = None,
comparison_function: Optional[Callable[..., float]] = None,
) -> DataFrame[Crosstab_df]:
"""Please see `_crosstab_docstring` function decorator for docstring"""

Expand Down Expand Up @@ -304,10 +312,16 @@ def _crosstab_Datasets(
allow_candidate_values: Optional[Iterable[Number]] = None,
allow_benchmark_values: Optional[Iterable[Number]] = None,
exclude_value: Optional[Number] = None,
comparison_function: Optional[Callable[[float, float], float]] = None,
comparison_function: Optional[Callable[..., float]] = None,
) -> DataFrame[Crosstab_df]:
"""Please see `_crosstab_docstring` function decorator for docstring"""

if _check_dask_array(candidate_map):
    # TODO: There is currently an open issue on xarray-spatial regarding dask dataset usage in crosstab
# https://github.com/makepath/xarray-spatial/issues/777
candidate_map = candidate_map.compute()
benchmark_map = benchmark_map.compute()

# gets variable names
candidate_variable_names = list(candidate_map.data_vars)
benchmark_variable_names = list(benchmark_map.data_vars)
Expand Down
23 changes: 23 additions & 0 deletions src/gval/homogenize/numeric_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,26 @@ def _align_numeric_data_type(
return _align_numeric_dtype(candidate_map, benchmark_map)
else:
return _align_datasets_dtype(candidate_map, benchmark_map)


def _check_dask_array(original_map: Union[xr.DataArray, xr.Dataset]) -> bool:
    """
    Check whether a map is backed by dask (lazily chunked) data.

    Parameters
    ----------
    original_map: Union[xr.DataArray, xr.Dataset]
        Map to be reprojected

    Returns
    -------
    bool
        Whether the data is a dask array
    """

    if isinstance(original_map, xr.Dataset):
        # Dataset.chunks is a mapping of dimension name -> chunk sizes that is
        # empty when no variable is dask-backed.  Checking the dataset as a
        # whole avoids hard-coding a "band_1" variable name, which raises a
        # KeyError for datasets whose variables are named differently.
        return bool(original_map.chunks)

    # DataArray.chunks is None for in-memory (numpy-backed) data.
    return original_map.chunks is not None
63 changes: 54 additions & 9 deletions src/gval/homogenize/spatial_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@
# __all__ = ['*']
__author__ = "Fernando Aristizabal"

from typing import (
Optional,
Tuple,
Union,
)
from typing import Optional, Tuple, Union

import xarray as xr
from shapely.geometry import box
from rasterio.enums import Resampling

from gval.utils.exceptions import RasterMisalignment, RastersDontIntersect
from gval.homogenize.numeric_alignment import _check_dask_array
from odc.geo.xr import ODCExtensionDa

ODCExtensionDa


def _matching_crs(
Expand Down Expand Up @@ -139,6 +139,51 @@ def _rasters_intersect(
return rasters_intersect_bool


def _reproject_map(
    original_map: Union[xr.DataArray, xr.Dataset],
    target_map: Union[xr.DataArray, xr.Dataset],
    resampling: str,
) -> Union[xr.DataArray, xr.Dataset]:
    """
    Reproject a map to match a target map's extent, resolution, and CRS.

    Parameters
    ----------
    original_map: Union[xr.DataArray, xr.Dataset]
        Map to be reprojected
    target_map: Union[xr.DataArray, xr.Dataset]
        Map to use for extent, resolution, and spatial reference
    resampling: str
        Method to resample changing resolutions

    Returns
    -------
    Union[xr.DataArray, xr.Dataset]
        Reprojected map
    """

    is_dst = isinstance(original_map, xr.Dataset)

    if not _check_dask_array(original_map):
        # In-memory data: rioxarray reprojects directly.
        return original_map.rio.reproject_match(target_map, resampling)

    # Dask-backed data: use odc-geo, which supports lazy reprojection.
    nodata = target_map["band_1"].rio.nodata if is_dst else target_map.rio.nodata

    # odc-geo expects a resampling name/int rather than a rasterio enum.
    resampling_arg = (
        resampling.name if isinstance(resampling, Resampling) else resampling
    )

    # BUG FIX: resampling was previously dropped on this branch, so dask
    # inputs were always resampled with odc's default method.
    reproj = original_map.odc.reproject(
        target_map.odc.geobox,
        tight=True,
        dst_nodata=nodata,
        resampling=resampling_arg,
    )

    # odc may name the output dims longitude/latitude (geographic CRS);
    # rename back to x/y to match the rioxarray convention used elsewhere.
    # Rename only when present so projected-CRS output does not raise.
    if "longitude" in reproj.dims:
        reproj = reproj.rename({"longitude": "x", "latitude": "y"})

    # Coordinates are virtually the same as the target's but can differ by
    # ~1e-8 from rounding; snap them so downstream exact-alignment checks pass.
    return reproj.assign_coords(
        {"x": target_map.coords["x"], "y": target_map.coords["y"]}
    )


def _align_rasters(
candidate_map: Union[xr.DataArray, xr.Dataset],
benchmark_map: Union[xr.DataArray, xr.Dataset],
Expand Down Expand Up @@ -205,16 +250,16 @@ def ensure_nodata_value_is_set(dataset_or_dataarray):

# align benchmark and candidate to target
elif isinstance(target_map, (xr.DataArray, xr.Dataset)):
candidate_map = candidate_map.rio.reproject_match(target_map, resampling)
benchmark_map = benchmark_map.rio.reproject_match(target_map, resampling)
candidate_map = _reproject_map(candidate_map, target_map, resampling)
benchmark_map = _reproject_map(benchmark_map, target_map, resampling)

# match candidate to benchmark
elif target_map == "benchmark":
candidate_map = candidate_map.rio.reproject_match(benchmark_map, resampling)
candidate_map = _reproject_map(candidate_map, benchmark_map, resampling)

# match benchmark to candidate
elif target_map == "candidate":
benchmark_map = benchmark_map.rio.reproject_match(candidate_map, resampling)
benchmark_map = _reproject_map(benchmark_map, candidate_map, resampling)

else:
raise ValueError(
Expand Down
1 change: 1 addition & 0 deletions src/gval/utils/loading_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def load_raster_as_xarray(
default_name=default_name,
band_as_variable=band_as_variable,
masked=masked,
chunks=chunks,
mask_and_scale=mask_and_scale,
**open_kwargs,
)
26 changes: 20 additions & 6 deletions tests/cases_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
band=1, drop=True
),
_load_xarray("candidate_map_0_accessor.tif", mask_and_scale=True),
_load_xarray("candidate_map_0_accessor.tif", mask_and_scale=True, chunks="auto"),
]
benchmark_maps = [
_load_xarray("benchmark_map_0_accessor.tif", mask_and_scale=True),
_load_xarray("benchmark_map_0_accessor.tif", mask_and_scale=True).sel(
band=1, drop=True
),
_load_gpkg("polygons_two_class_categorical.gpkg"),
_load_xarray("benchmark_map_0_accessor.tif", mask_and_scale=True, chunks="auto"),
]
candidate_datasets = [
_load_xarray(
Expand All @@ -31,17 +33,29 @@
_load_xarray(
"candidate_map_0_accessor.tif", mask_and_scale=True, band_as_variable=True
),
_load_xarray(
"candidate_map_0_accessor.tif",
mask_and_scale=True,
band_as_variable=True,
chunks="auto",
),
]
benchmark_datasets = [
_load_xarray(
"benchmark_map_0_accessor.tif", mask_and_scale=True, band_as_variable=True
),
_load_gpkg("polygons_two_class_categorical.gpkg"),
_load_xarray(
"benchmark_map_0_accessor.tif",
mask_and_scale=True,
band_as_variable=True,
chunks="auto",
),
]

positive_cat = np.array([2, 2, 2])
negative_cat = np.array([[0, 1], [0, 1], [0, 1]])
rasterize_attrs = [None, None, ["category"]]
positive_cat = np.array([2, 2, 2, 2])
negative_cat = np.array([[0, 1], [0, 1], [0, 1], [0, 1]])
rasterize_attrs = [None, None, ["category"], None]


@parametrize(
Expand Down Expand Up @@ -142,9 +156,9 @@ def case_data_array_accessor_crosstab_table_fail(
zip(
candidate_datasets,
benchmark_datasets,
positive_cat[0:2],
negative_cat[0:2],
[rasterize_attrs[0], rasterize_attrs[2]],
positive_cat[0:3],
negative_cat[0:3],
[rasterize_attrs[0], rasterize_attrs[2], rasterize_attrs[0]],
)
),
)
Expand Down
20 changes: 13 additions & 7 deletions tests/cases_homogenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,26 +104,32 @@ def case_rasters_intersect_exception(


# Each zipped tuple is one test case: (candidate file, benchmark file,
# resampling kwargs, target_map argument, load-as-dataset flag, chunks).
# The last two cases load with chunks="auto" to exercise dask-backed maps;
# the final one additionally loads the rasters as Datasets.
@parametrize(
    "candidate_map_fn, benchmark_map_fn, resampling, target_map, dataset, chunks",
    list(
        zip(
            # NumPy fancy indexing repeats file 1 for all but the first case.
            candidate_map_fns[[0, 1, 1, 1, 1, 1, 1]],
            benchmark_map_fns[[0, 1, 1, 1, 1, 1, 1]],
            # resampling: extra kwargs (e.g. an explicit Resampling method)
            [{}, {}, {}, {}, {"resampling": Resampling.bilinear}, {}, {}],
            # target_map: which map (or explicit raster) to align against
            [
                "candidate",
                "benchmark",
                _load_xarray("target_map_0.tif"),
                "candidate",
                "candidate",
                "candidate",
                "candidate",
            ],
            # dataset: when True, load with band_as_variable (xr.Dataset)
            [False, False, False, False, False, False, True],
            # chunks: "auto" loads the rasters lazily as dask arrays
            [None, None, None, None, None, "auto", "auto"],
        )
    ),
)
def case_align_rasters(
    candidate_map_fn, benchmark_map_fn, target_map, resampling, dataset, chunks
):
    # Returns (candidate map, benchmark map, target_map, resampling kwargs)
    # for the raster-alignment tests.
    return (
        _load_xarray(candidate_map_fn, chunks=chunks, band_as_variable=dataset),
        _load_xarray(benchmark_map_fn, chunks=chunks, band_as_variable=dataset),
        target_map,
        resampling,
    )
Expand Down