Skip to content

Commit

Permalink
Merge pull request #201 from rjmcoder/main
Browse files Browse the repository at this point in the history
PR for issue #56 Add validation on Location object
  • Loading branch information
peterdudfield authored Jun 5, 2023
2 parents 63b0c33 + cd4c52f commit b664c93
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 1 deletion.
43 changes: 42 additions & 1 deletion ocf_datapipes/utils/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np
import xarray as xr
from pydantic import BaseModel
from pydantic import BaseModel, validator

PV_TIME_AXIS = 1
PV_SYSTEM_AXIS = 2
Expand Down Expand Up @@ -144,10 +144,51 @@
class Location(BaseModel):
"""Represent a spatial location."""

coordinate_system: Optional[str] = "osgb" # ["osgb", "lat_lon"]
x: float
y: float
id: Optional[int]

@validator("coordinate_system", pre=True, always=True)
def validate_coordinate_system(cls, v):
"""Validate 'coordinate_system'"""
allowed_coordinate_systen = ["osgb", "lat_lon"]
if v not in allowed_coordinate_systen:
raise ValueError(f"coordinate_system = {v} is not in {allowed_coordinate_systen}")
return v

@validator("x")
def validate_x(cls, v, values):
"""Validate 'x'"""
min_x: float
max_x: float
if "coordinate_system" not in values:
raise ValueError("coordinate_system is incorrect")
co = values["coordinate_system"]
if co == "osgb":
min_x, max_x = -103976.3, 652897.98
if co == "lat_lon":
min_x, max_x = -180, 180
if v < min_x or v > max_x:
raise ValueError(f"x = {v} must be within {[min_x, max_x]} for {co} coordinate system")
return v

@validator("y")
def validate_y(cls, v, values):
"""Validate 'y'"""
min_y: float
max_y: float
if "coordinate_system" not in values:
raise ValueError("coordinate_system is incorrect")
co = values["coordinate_system"]
if co == "osgb":
min_y, max_y = -16703.87, 1199851.44
if co == "lat_lon":
min_y, max_y = -90, 90
if v < min_y or v > max_y:
raise ValueError(f"y = {v} must be within {[min_y, max_y]} for {co} coordinate system")
return v


class BatchKey(Enum):
"""The names of the different elements of each batch.
Expand Down
67 changes: 67 additions & 0 deletions tests/utils/test_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from ocf_datapipes.utils.consts import Location
import pytest


def test_make_valid_location_object_with_default_coordinate_system():
x, y = -1000.5, 50000
location = Location(x=x, y=y)
assert location.x == x, "location.x value not set correctly"
assert location.y == y, "location.x value not set correctly"
assert (
location.coordinate_system == "osgb"
), "location.coordinate_system value not set correctly"


def test_make_valid_location_object_with_osgb_coordinate_system():
x, y, coordinate_system = 1.2, 22.9, "osgb"
location = Location(x=x, y=y, coordinate_system=coordinate_system)
assert location.x == x, "location.x value not set correctly"
assert location.y == y, "location.x value not set correctly"
assert (
location.coordinate_system == coordinate_system
), "location.coordinate_system value not set correctly"


def test_make_valid_location_object_with_lat_lon_coordinate_system():
x, y, coordinate_system = 1.2, 1.2, "lat_lon"
location = Location(x=x, y=y, coordinate_system=coordinate_system)
assert location.x == x, "location.x value not set correctly"
assert location.y == y, "location.x value not set correctly"
assert (
location.coordinate_system == coordinate_system
), "location.coordinate_system value not set correctly"


def test_make_invalid_location_object_with_invalid_osgb_x():
x, y, coordinate_system = 10000000, 1.2, "osgb"
with pytest.raises(ValueError) as err:
location = Location(x=x, y=y, coordinate_system=coordinate_system)
assert err.typename == "ValidationError"


def test_make_invalid_location_object_with_invalid_osgb_y():
x, y, coordinate_system = 2.5, 10000000, "osgb"
with pytest.raises(ValueError) as err:
location = Location(x=x, y=y, coordinate_system=coordinate_system)
assert err.typename == "ValidationError"


def test_make_invalid_location_object_with_invalid_lat_lon_x():
x, y, coordinate_system = 200, 1.2, "lat_lon"
with pytest.raises(ValueError) as err:
location = Location(x=x, y=y, coordinate_system=coordinate_system)
assert err.typename == "ValidationError"


def test_make_invalid_location_object_with_invalid_lat_lon_y():
x, y, coordinate_system = 2.5, -200, "lat_lon"
with pytest.raises(ValueError) as err:
location = Location(x=x, y=y, coordinate_system=coordinate_system)
assert err.typename == "ValidationError"


def test_make_invalid_location_object_with_invalid_coordinate_system():
x, y, coordinate_system = 2.5, 1000, "abcd"
with pytest.raises(ValueError) as err:
location = Location(x=x, y=y, coordinate_system=coordinate_system)
assert err.typename == "ValidationError"

0 comments on commit b664c93

Please sign in to comment.