diff --git a/pysquirrel/__init__.py b/pysquirrel/__init__.py
index e87b28c..9115f59 100644
--- a/pysquirrel/__init__.py
+++ b/pysquirrel/__init__.py
@@ -1,6 +1,6 @@
 """pysquirrel"""
 
-from . import core
+from .core import AllRegions
 
 # create database
-nuts = core.AllRegions()
+nuts = AllRegions()
diff --git a/pysquirrel/core.py b/pysquirrel/core.py
index e62cc3e..7ec2c1d 100644
--- a/pysquirrel/core.py
+++ b/pysquirrel/core.py
@@ -105,12 +105,13 @@ class SRRegion(Region):
 class AllRegions:
-    """Database that contains list of all territorial region."""
+    """Database that contains a list of all territorial regions."""
 
-    search_index: dict = {}
     data: list[NUTSRegion | SRRegion] = []
 
     def __init__(self) -> None:
+        # Bind a fresh list per instance: `_load` appends to `self.data`, and
+        # appending to the class-level default would share it across instances.
+        self.data = []
         self._load()
-        self._set_index()
 
     def _load(self) -> None:
         """
@@ -137,63 +138,43 @@ def _load(self) -> None:
                 }
                 self.data.append(cls(**region))
 
-    def _set_index(self) -> None:
-        """
-        Builds search index to be used to retrieve regions.
-        """
-        for region in self.data:
-            for field in fields(region):
-                key = self.search_index.setdefault(field.name, {})
-                value = getattr(region, field.name)
-                if field.name == "code":
-                    key[value] = region
-                else:
-                    if value in key:
-                        key[value].append(region)
-                    else:
-                        key[value] = [region]
-
     def _search(
         self,
         param: str,
         value: str | int,
     ) -> set[NUTSRegion | SRRegion]:
         """
-        Searches database index for one value of a parameter
+        Searches database for one value of a region field
         and returns a set of all matching result(s).
 
         :param param: field to be searched
         :param value: value(s) to be searched in the field
-
         """
-        results = set(
-            flatten(
-                [
-                    self.search_index[param][key]
-                    for key in self.search_index[param]
-                    if key == value
-                ]
-            )
-        )
-
-        return results
+        return {region for region in self.data if getattr(region, param) == value}
 
-    def get(self, **params) -> list[NUTSRegion | SRRegion, None]:
+    def get(
+        self,
+        *,
+        country_code: str | list[str] | None = None,
+        level: int | list[int] | None = None,
+    ) -> list[NUTSRegion | SRRegion]:
         """
-        Searches NUTS 2024 classification database. Supports multiple fields/values
-        search.
+        Searches NUTS 2024 classification database by country code(s) and/or
+        NUTS level.
+        Returns all regions for the listed countries and levels.
 
-        :param **params: key-value pair, with key being a `Region` field to search,
-        and the value to search
+        :param country_code: country code(s) to search
+        :param level: NUTS level(s) to search
+        :raises ValueError: if neither a country code nor a level is given
         """
         results = []
-        if not params:
+        if not (country_code or level):
             raise ValueError("no keyword argument(s) passed.")
-        else:
-            for param, value in params.items():
-                if isinstance(value, (int, str)):
-                    if param in [field.name for field in fields(Region)]:
-                        results.append(self._search(param, value))
-                else:
-                    raise TypeError("only one value per keyword argument allowed.")
-        return list(set.intersection(*results))
+        for param, values in {"country_code": country_code, "level": level}.items():
+            if isinstance(values, (int, str)):
+                values = [values]
+            if values:
+                results.append(
+                    set.union(*(self._search(param, value) for value in values))
+                )
+        return list(set.intersection(*results))
diff --git a/tests/test_core.py b/tests/test_core.py
new file mode 100644
index 0000000..aaee436
--- /dev/null
+++ b/tests/test_core.py
@@ -0,0 +1,76 @@
+import pytest
+from pydantic import ValidationError
+
+from pysquirrel.core import AllRegions, Level, NUTSRegion
+
+MOCK_DATA = [
+    NUTSRegion(country_code="AT", code="AT1", label="Ostösterreich", level=1),
+    NUTSRegion(country_code="AT", code="AT12", label="Niederösterreich", level=2),
+    NUTSRegion(country_code="AT", code="AT127", label="Wiener Umland/Südteil", level=3),
+    NUTSRegion(country_code="PT", code="PT1", label="Continente", level=1),
+    NUTSRegion(country_code="PT", code="PT1C", label="Alentejo", level=2),
+    NUTSRegion(country_code="PT", code="PT1C1", label="Alentejo Litoral", level=3),
+]
+
+
+def test_region_creation():
+    region = MOCK_DATA[2]
+    assert region.country_code == "AT"
+    assert region.code == "AT127"
+    assert region.label == "Wiener Umland/Südteil"
+    assert region.level == Level.LEVEL_3
+
+
+def test_invalid_country_code():
+    with pytest.raises(ValidationError):
+        NUTSRegion(
+            country_code="at", code="AT127", label="Wiener Umland/Südteil", level=3
+        )
+
+
+def test_invalid_region_code():
+    with pytest.raises(ValidationError):
+        NUTSRegion(
+            country_code="AT", code="A127", label="Wiener Umland/Südteil", level=3
+        )
+
+
+def mock_load(self):
+    """Mocked _load method that loads a sample dataset."""
+    self.data = MOCK_DATA
+
+
+def test_all_regions(monkeypatch):
+    # Create an instance of AllRegions
+    all_regions = AllRegions()
+
+    # Test full data import
+    assert len(all_regions.get(level=1)) == 160
+    assert len(all_regions.get(level=2)) == 361
+    assert len(all_regions.get(level=3)) == 1521
+
+    # Test data fields
+    lux = [
+        nuts
+        for nuts in all_regions.get(country_code="LU", level=1)
+        if "Z" not in nuts.code  # exclude Extra-Regio NUTS 1
+    ]
+    assert len(lux) == 1
+    assert lux[0].label == "Luxembourg" and lux[0].code == "LU0"
+
+    # Replace the _load method with the mock method
+    monkeypatch.setattr(AllRegions, "_load", mock_load)
+
+    # Call _load to apply the mock
+    all_regions._load()
+
+    # Check if the data is loaded correctly
+    assert len(all_regions.data) == 6
+
+    # Test query logic with mock data
+    assert set(all_regions.get(country_code="AT")) == set(MOCK_DATA[:3])
+    assert set(all_regions.get(country_code="PT")) == set(MOCK_DATA[3:])
+    assert set(all_regions.get(level=2)) == {MOCK_DATA[1], MOCK_DATA[4]}
+    assert set(all_regions.get(country_code="AT", level=1)) == {MOCK_DATA[0]}
+    assert set(all_regions.get(country_code="PT", level=3)) == {MOCK_DATA[-1]}
+    assert set(all_regions.get(country_code=["AT", "PT"])) == set(MOCK_DATA)