diff --git a/ocf_datapipes/utils/consts.py b/ocf_datapipes/utils/consts.py index 2acb71447..e2887860f 100644 --- a/ocf_datapipes/utils/consts.py +++ b/ocf_datapipes/utils/consts.py @@ -5,7 +5,7 @@ import numpy as np import xarray as xr -from pydantic import BaseModel +from pydantic import BaseModel, validator PV_TIME_AXIS = 1 PV_SYSTEM_AXIS = 2 @@ -144,10 +144,51 @@ class Location(BaseModel): """Represent a spatial location.""" + coordinate_system: Optional[str] = "osgb" # ["osgb", "lat_lon"] x: float y: float id: Optional[int] + @validator("coordinate_system", pre=True, always=True) + def validate_coordinate_system(cls, v): + """Validate 'coordinate_system'""" + allowed_coordinate_systen = ["osgb", "lat_lon"] + if v not in allowed_coordinate_systen: + raise ValueError(f"coordinate_system = {v} is not in {allowed_coordinate_systen}") + return v + + @validator("x") + def validate_x(cls, v, values): + """Validate 'x'""" + min_x: float + max_x: float + if "coordinate_system" not in values: + raise ValueError("coordinate_system is incorrect") + co = values["coordinate_system"] + if co == "osgb": + min_x, max_x = -103976.3, 652897.98 + if co == "lat_lon": + min_x, max_x = -180, 180 + if v < min_x or v > max_x: + raise ValueError(f"x = {v} must be within {[min_x, max_x]} for {co} coordinate system") + return v + + @validator("y") + def validate_y(cls, v, values): + """Validate 'y'""" + min_y: float + max_y: float + if "coordinate_system" not in values: + raise ValueError("coordinate_system is incorrect") + co = values["coordinate_system"] + if co == "osgb": + min_y, max_y = -16703.87, 1199851.44 + if co == "lat_lon": + min_y, max_y = -90, 90 + if v < min_y or v > max_y: + raise ValueError(f"y = {v} must be within {[min_y, max_y]} for {co} coordinate system") + return v + class BatchKey(Enum): """The names of the different elements of each batch. diff --git a/tests/utils/test_constants.py b/tests/utils/test_constants.py new file mode 100644 index 000000000..189b3dc80 --- /dev/null +++ b/tests/utils/test_constants.py @@ -0,0 +1,67 @@ +from ocf_datapipes.utils.consts import Location +import pytest + + +def test_make_valid_location_object_with_default_coordinate_system(): + x, y = -1000.5, 50000 + location = Location(x=x, y=y) + assert location.x == x, "location.x value not set correctly" + assert location.y == y, "location.x value not set correctly" + assert ( + location.coordinate_system == "osgb" + ), "location.coordinate_system value not set correctly" + + +def test_make_valid_location_object_with_osgb_coordinate_system(): + x, y, coordinate_system = 1.2, 22.9, "osgb" + location = Location(x=x, y=y, coordinate_system=coordinate_system) + assert location.x == x, "location.x value not set correctly" + assert location.y == y, "location.x value not set correctly" + assert ( + location.coordinate_system == coordinate_system + ), "location.coordinate_system value not set correctly" + + +def test_make_valid_location_object_with_lat_lon_coordinate_system(): + x, y, coordinate_system = 1.2, 1.2, "lat_lon" + location = Location(x=x, y=y, coordinate_system=coordinate_system) + assert location.x == x, "location.x value not set correctly" + assert location.y == y, "location.x value not set correctly" + assert ( + location.coordinate_system == coordinate_system + ), "location.coordinate_system value not set correctly" + + +def test_make_invalid_location_object_with_invalid_osgb_x(): + x, y, coordinate_system = 10000000, 1.2, "osgb" + with pytest.raises(ValueError) as err: + location = Location(x=x, y=y, coordinate_system=coordinate_system) + assert err.typename == "ValidationError" + + +def test_make_invalid_location_object_with_invalid_osgb_y(): + x, y, coordinate_system = 2.5, 10000000, "osgb" + with pytest.raises(ValueError) as err: + location = Location(x=x, y=y, coordinate_system=coordinate_system) + assert err.typename == "ValidationError" + + +def test_make_invalid_location_object_with_invalid_lat_lon_x(): + x, y, coordinate_system = 200, 1.2, "lat_lon" + with pytest.raises(ValueError) as err: + location = Location(x=x, y=y, coordinate_system=coordinate_system) + assert err.typename == "ValidationError" + + +def test_make_invalid_location_object_with_invalid_lat_lon_y(): + x, y, coordinate_system = 2.5, -200, "lat_lon" + with pytest.raises(ValueError) as err: + location = Location(x=x, y=y, coordinate_system=coordinate_system) + assert err.typename == "ValidationError" + + +def test_make_invalid_location_object_with_invalid_coordinate_system(): + x, y, coordinate_system = 2.5, 1000, "abcd" + with pytest.raises(ValueError) as err: + location = Location(x=x, y=y, coordinate_system=coordinate_system) + assert err.typename == "ValidationError"