From 0d1f28f8afaabdd02f3aa1fe07ff02f06ef34828 Mon Sep 17 00:00:00 2001 From: David Rodriguez Date: Tue, 30 Apr 2024 16:21:10 -0400 Subject: [PATCH] Adding some schema validation and implementing tests for them (#495) --- simple/schema.py | 35 +++++++++++++++++++++++-------- tests/test_schema.py | 49 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 8 deletions(-) create mode 100644 tests/test_schema.py diff --git a/simple/schema.py b/simple/schema.py index eedc2c59c..3e6585286 100644 --- a/simple/schema.py +++ b/simple/schema.py @@ -3,21 +3,22 @@ """ import enum + import sqlalchemy as sa -from sqlalchemy.orm import validates +from astrodbkit2.astrodb import Base +from astrodbkit2.views import view +from astropy.io.votable.ucd import check_ucd from sqlalchemy import ( Boolean, Column, + DateTime, + Enum, Float, ForeignKey, - String, - Enum, - DateTime, ForeignKeyConstraint, + String, ) -from astrodbkit2.astrodb import Base -from astrodbkit2.views import view -from astropy.io.votable.ucd import check_ucd +from sqlalchemy.orm import validates # ------------------------------------------------------------------------------------------------------------------- # Reference tables @@ -44,6 +45,12 @@ class Publications(Base): doi = Column(String(100)) description = Column(String(1000)) + @validates("reference") + def validate_reference(self, key, value): + if value is None or len(value) > 30: + raise ValueError(f"Provided reference is invalid; too long or None: {value}") + return value + class Telescopes(Base): __tablename__ = 'Telescopes' @@ -102,7 +109,7 @@ def validate_ucd(self, key, value): raise ValueError(f"UCD {value} not in controlled vocabulary") return value - @validates("effective_wavelength_angstroms") + @validates("effective_wavelength") def validate_wavelength(self, key, value): if value is None or value < 0: raise ValueError(f"Invalid effective wavelength received: {value}") @@ -169,6 +176,18 @@ class Sources(Base): other_references = Column(String(100)) comments = Column(String(1000)) + @validates("ra") + def validate_ra(self, key, value): + if value > 360 or value < 0: + raise ValueError("RA not in allowed range (0..360)") + return value + + @validates("dec") + def validate_dec(self, key, value): + if value > 90 or value < -90: + raise ValueError("Dec not in allowed range (-90..90)") + return value + class Names(Base): __tablename__ = 'Names' diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 000000000..e3f188572 --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,49 @@ +"""Tests for the schema itself and any validating functions""" + +import pytest + +from simple.schema import PhotometryFilters, Publications, Sources + + +def schema_tester(table, values, error_state): + """Helper function to handle the basic testing of the schema classes""" + if error_state is None: + _ = table(**values) + else: + with pytest.raises(error_state): + _ = table(**values) + +@pytest.mark.parametrize("values, error_state", + [ + ({"band": "2MASS.J", "effective_wavelength": 1.2, "ucd": "phot;em.IR.J"}, None), + ({"band": "2MASS.J", "effective_wavelength": 1.2, "ucd": "bad"}, ValueError), + ({"band": "bad", "effective_wavelength": 1.2, "ucd": "phot;em.IR.J"}, ValueError), + ({"band": "2MASS.J", "effective_wavelength": -99, "ucd": "phot;em.IR.J"}, ValueError), + ]) +def test_photometryfilters(values, error_state): + """Validating PhotometryFilters""" + schema_tester(PhotometryFilters, values, error_state) + + +@pytest.mark.parametrize("values, error_state", + [ + ({"source": "FAKE", "ra": 1.2, "dec": 3.4, "reference": "Ref1"}, None), + ({"source": "FAKE", "ra": 999, "dec": 3.4, "reference": "Ref1"}, ValueError), + ({"source": "FAKE", "ra": -999, "dec": 3.4, "reference": "Ref1"}, ValueError), + ({"source": "FAKE", "ra": 1.2, "dec": 999, "reference": "Ref1"}, ValueError), + ({"source": "FAKE", "ra": 1.2, "dec": -999, "reference": "Ref1"}, ValueError), + ]) +def test_sources(values, error_state): + """Validating Sources""" + schema_tester(Sources, values, error_state) + + +@pytest.mark.parametrize("values, error_state", + [ + ({"reference": "Ref1"}, None), + ({"reference": None}, ValueError), + ({"reference": "THIS-REFERENCE-IS-REALLY-REALLY-LONG"}, ValueError), + ]) +def test_publications(values, error_state): + """Validating Publications""" + schema_tester(Publications, values, error_state) \ No newline at end of file