Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield committed Nov 13, 2024
1 parent 3a5fed6 commit 5d706c7
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/india_api/internal/inputs/indiadb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def get_predicted_solar_power_production_for_location(
forecast_horizon=forecast_horizon,
forecast_horizon_minutes=forecast_horizon_minutes,
smooth_flag=smooth_flag,
model_name=model_name,
ml_model_name=model_name,
)

def get_predicted_wind_power_production_for_location(
Expand Down Expand Up @@ -224,7 +224,7 @@ def get_predicted_wind_power_production_for_location(
forecast_horizon=forecast_horizon,
forecast_horizon_minutes=forecast_horizon_minutes,
smooth_flag=smooth_flag,
model_name=model_name,
ml_model_name=model_name,
)

def get_actual_solar_power_production_for_location(
Expand Down
22 changes: 22 additions & 0 deletions src/india_api/internal/inputs/indiadb/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from pvsite_datamodel.sqlmodels import Base, ForecastSQL, ForecastValueSQL, GenerationSQL, SiteSQL
from pvsite_datamodel.read.user import get_user_by_email
from pvsite_datamodel.read.model import get_or_create_model
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from testcontainers.postgres import PostgresContainer
Expand Down Expand Up @@ -123,6 +124,23 @@ def generations(db_session, sites):
@pytest.fixture()
def forecast_values(db_session, sites):
"""Create some fake forecast values"""

make_fake_forecast_values(db_session, sites, "pvnet_india")

@pytest.fixture()
def forecast_values_wind(db_session, sites):
"""Create some fake forecast values"""

make_fake_forecast_values(db_session, sites, "windnet_india")

@pytest.fixture()
def forecast_values_site(db_session, sites):
"""Create some fake forecast values"""

make_fake_forecast_values(db_session, sites, "pvnet_ad_sites")


def make_fake_forecast_values(db_session, sites, model_name):
forecast_values = []
forecast_version: str = "0.0.0"

Expand All @@ -134,6 +152,9 @@ def forecast_values(db_session, sites):
# To make things trickier we make a second forecast at the same for one of the timestamps.
timestamps = timestamps + timestamps[-1:]

# get model
ml_model = get_or_create_model(db_session, model_name)

for site in sites:
for timestamp in timestamps:
forecast: ForecastSQL = ForecastSQL(
Expand All @@ -154,6 +175,7 @@ def forecast_values(db_session, sites):
end_utc=timestamp + timedelta(minutes=horizon + duration),
horizon_minutes=horizon,
)
forecast_value.ml_model = ml_model

forecast_values.append(forecast_value)

Expand Down
6 changes: 3 additions & 3 deletions src/india_api/internal/inputs/indiadb/test_indiadb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def client(engine, db_session):

class TestIndiaDBClient:
def test_get_predicted_wind_power_production_for_location(
self, client, forecast_values
self, client, forecast_values_wind
) -> None:
locID = "testID"
result = client.get_predicted_wind_power_production_for_location(locID)
Expand All @@ -33,7 +33,7 @@ def test_get_predicted_wind_power_production_for_location(
assert isinstance(record, PredictedPower)

def test_get_predicted_wind_power_production_for_location_raise_error(
self, client, forecast_values
self, client, forecast_values_wind
) -> None:

with pytest.raises(Exception):
Expand Down Expand Up @@ -83,7 +83,7 @@ def test_get_sites_no_sites(self, client, sites) -> None:
sites_from_api = client.get_sites(email="[email protected]")
assert len(sites_from_api) == 0

def test_get_site_forecast(self, client, sites, forecast_values) -> None:
def test_get_site_forecast(self, client, sites, forecast_values_site) -> None:
out = client.get_site_forecast(site_uuid=str(sites[0].site_uuid), email="[email protected]")
assert len(out) > 0

Expand Down

0 comments on commit 5d706c7

Please sign in to comment.