Skip to content

Commit

Permalink
Simplify query logic for AllRegions.get() (#28)
Browse files Browse the repository at this point in the history
* Simplify query logic for AllRegions.get()

* Refactor AllRegions.get()

* Fix get and search methods

* Add tests

* Clean up code and add real data tests
  • Loading branch information
dc-almeida authored Aug 29, 2024
1 parent aacb952 commit 18af2a9
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 47 deletions.
4 changes: 2 additions & 2 deletions pysquirrel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""pysquirrel"""

from . import core
from .core import AllRegions

# create database
nuts = core.AllRegions()
nuts = AllRegions()
64 changes: 19 additions & 45 deletions pysquirrel/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,10 @@ class SRRegion(Region):
class AllRegions:
"""Database that contains a list of all territorial regions."""

search_index: dict = {}
data: list[NUTSRegion | SRRegion] = []

def __init__(self) -> None:
self._load()
self._set_index()

def _load(self) -> None:
"""
Expand All @@ -137,63 +135,39 @@ def _load(self) -> None:
}
self.data.append(cls(**region))

def _set_index(self) -> None:
"""
Builds search index to be used to retrieve regions.
"""
for region in self.data:
for field in fields(region):
key = self.search_index.setdefault(field.name, {})
value = getattr(region, field.name)
if field.name == "code":
key[value] = region
else:
if value in key:
key[value].append(region)
else:
key[value] = [region]

def _search(
self,
param: str,
value: str | int,
) -> set[NUTSRegion | SRRegion]:
"""
Searches database index for one value of a parameter
Searches database for one value of a region field
and returns a set of all matching result(s).
:param param: field to be searched
:param value: value(s) to be searched in the field
"""
results = set(
flatten(
[
self.search_index[param][key]
for key in self.search_index[param]
if key == value
]
)
)

return results
return set(i for i in self.data if getattr(i, param) == value)

def get(self, **params) -> list[NUTSRegion | SRRegion, None]:
def get(
self, *, country_code: str | list[str] = None, level: int | list[int] = None
) -> list[NUTSRegion | SRRegion, None]:
"""
Searches NUTS 2024 classification database. Supports multiple fields/values
search.
Searches NUTS 2024 classification database by country code(s) and/or
NUTS level.
Returns all regions for the listed countries and levels.
:param **params: key-value pair, with key being a `Region` field to search,
and the value to search
:param country_code: country code(s) to search
:param level: NUTS level(s) to search
"""
results = []
if not params:
if not (country_code or level):
raise ValueError("no keyword argument(s) passed.")
else:
for param, value in params.items():
if isinstance(value, (int, str)):
if param in [field.name for field in fields(Region)]:
results.append(self._search(param, value))
else:
raise TypeError("only one value per keyword argument allowed.")
return list(set.intersection(*results))
for param, values in {"country_code": country_code, "level": level}.items():
if isinstance(values, (int, str)):
values = [values]
if values:
results.append(
set.union(*(self._search(param, value) for value in values))
)
return list(set.intersection(*results))
76 changes: 76 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import pytest

from pysquirrel.core import Level, NUTSRegion, AllRegions
from pydantic import ValidationError

# Sample dataset used by the mocked loader: three Austrian ("AT") regions
# followed by three Portuguese ("PT") regions, one per level (1, 2, 3) for
# each country. Tests below rely on this ordering when slicing/indexing.
MOCK_DATA = [
NUTSRegion(country_code="AT", code="AT1", label="Ostösterreich", level=1),
NUTSRegion(country_code="AT", code="AT12", label="Niederösterreich", level=2),
NUTSRegion(country_code="AT", code="AT127", label="Wiener Umland/Südteil", level=3),
NUTSRegion(country_code="PT", code="PT1", label="Continente", level=1),
NUTSRegion(country_code="PT", code="PT1C", label="Alentejo", level=2),
NUTSRegion(country_code="PT", code="PT1C1", label="Alentejo Litoral", level=3),
]


def test_region_creation():
    """Check that a constructed NUTSRegion exposes the expected field values."""
    sample = MOCK_DATA[2]
    actual = (sample.country_code, sample.code, sample.label, sample.level)
    # The level is compared against the Level enum member, not the raw int.
    assert actual == ("AT", "AT127", "Wiener Umland/Südteil", Level.LEVEL_3)


def test_invalid_country_code():
    """Building a NUTSRegion with country_code="at" must raise ValidationError."""
    bad_fields = {
        "country_code": "at",
        "code": "AT127",
        "label": "Wiener Umland/Südteil",
        "level": 3,
    }
    with pytest.raises(ValidationError):
        NUTSRegion(**bad_fields)


def test_invalid_region_code():
    """Building a NUTSRegion with code="A127" must raise ValidationError."""
    bad_fields = {
        "country_code": "AT",
        "code": "A127",
        "label": "Wiener Umland/Südteil",
        "level": 3,
    }
    with pytest.raises(ValidationError):
        NUTSRegion(**bad_fields)


def mock_load(self):
    """Mocked _load method that loads a sample dataset.

    Stores a shallow copy of ``MOCK_DATA`` on the instance rather than the
    module-level list itself, so that any code mutating ``self.data`` in
    place (e.g. appending to it) cannot alter the shared fixture that the
    assertions compare against.
    """
    self.data = list(MOCK_DATA)


def test_all_regions(monkeypatch):
    """Exercise AllRegions.get() on the bundled data, then on the mocked dataset."""
    regions = AllRegions()

    # Full data import: expected number of regions at each level.
    for level, expected_count in ((1, 160), (2, 361), (3, 1521)):
        assert len(regions.get(level=level)) == expected_count

    # Data fields: Luxembourg at level 1, leaving out the Extra-Regio entry
    # (identified by a "Z" in its code).
    luxembourg = [
        region
        for region in regions.get(country_code="LU", level=1)
        if "Z" not in region.code
    ]
    assert len(luxembourg) == 1
    assert luxembourg[0].label == "Luxembourg"
    assert luxembourg[0].code == "LU0"

    # Swap the loader for the mock and re-run it on the existing instance so
    # that its data is replaced by the sample dataset.
    monkeypatch.setattr(AllRegions, "_load", mock_load)
    regions._load()
    assert len(regions.data) == 6

    # Query logic against the mocked data.
    austrian, portuguese = MOCK_DATA[:3], MOCK_DATA[3:]
    assert set(regions.get(country_code="AT")) == set(austrian)
    assert set(regions.get(country_code="PT")) == set(portuguese)
    assert set(regions.get(level=2)) == {MOCK_DATA[1], MOCK_DATA[4]}
    assert set(regions.get(country_code="AT", level=1)) == {MOCK_DATA[0]}
    assert set(regions.get(country_code="PT", level=3)) == {MOCK_DATA[-1]}
    assert set(regions.get(country_code=["AT", "PT"])) == set(MOCK_DATA)

0 comments on commit 18af2a9

Please sign in to comment.