From fc929aaa9c07e95d391643ec6b36bb8a9189a359 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Wed, 13 Nov 2024 11:11:15 +0000 Subject: [PATCH 1/3] add hard coded model names --- .../internal/inputs/indiadb/client.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/india_api/internal/inputs/indiadb/client.py b/src/india_api/internal/inputs/indiadb/client.py index a2ae167..0cff60a 100644 --- a/src/india_api/internal/inputs/indiadb/client.py +++ b/src/india_api/internal/inputs/indiadb/client.py @@ -56,6 +56,7 @@ def get_predicted_power_production_for_location( self, location: str, asset_type: SiteAssetType, + ml_model_name: str, forecast_horizon: ForecastHorizon = ForecastHorizon.latest, forecast_horizon_minutes: Optional[int] = None, smooth_flag: bool = True, @@ -65,6 +66,7 @@ def get_predicted_power_production_for_location( Args: location: the location to get the predicted power production for asset_type: The type of asset to get the forecast for + ml_model_name: The name of the model to get the forecast from forecast_horizon: The time horizon to get the data for. Can be latest or day ahead forecast_horizon_minutes: The number of minutes to get the forecast for. forecast_horizon must be 'horizon' smooth_flag: Flag to smooth the forecast @@ -109,6 +111,7 @@ def get_predicted_power_production_for_location( day_ahead_hours=day_ahead_hours, day_ahead_timezone_delta_hours=day_ahead_timezone_delta_hours, forecast_horizon_minutes=forecast_horizon_minutes, + model_name=ml_model_name ) forecast_values: [ForecastValueSQL] = values[site.site_uuid] @@ -183,12 +186,17 @@ def get_predicted_solar_power_production_for_location( smooth_flag: Flag to smooth the forecast """ + # set this to be hard coded for now + model_name = 'pvnet_india' + + return self.get_predicted_power_production_for_location( location=location, asset_type=SiteAssetType.pv, forecast_horizon=forecast_horizon, forecast_horizon_minutes=forecast_horizon_minutes, smooth_flag=smooth_flag, + model_name=model_name ) def get_predicted_wind_power_production_for_location( @@ -208,12 +216,16 @@ def get_predicted_wind_power_production_for_location( smooth_flag: Flag to smooth the forecast """ + # set this to be hard coded for now + model_name = 'windnet_india' + return self.get_predicted_power_production_for_location( location=location, asset_type=SiteAssetType.wind, forecast_horizon=forecast_horizon, forecast_horizon_minutes=forecast_horizon_minutes, - smooth_flag=smooth_flag + smooth_flag=smooth_flag, + model_name=model_name, ) def get_actual_solar_power_production_for_location( @@ -266,6 +278,9 @@ def get_site_forecast(self, site_uuid: str, email:str) -> list[internal.Predicte # TODO feels like there is some duplicated code here which could be refactored + # hard coded model name + ml_model_name = 'pvnet_ad_sites' + # Get the window start, _ = get_window() @@ -279,6 +294,7 @@ def get_site_forecast(self, site_uuid: str, email:str) -> list[internal.Predicte session, site_uuids=[site_uuid], start_utc=start, + model_name=ml_model_name ) forecast_values: [ForecastValueSQL] = values[site_uuid] From 3a5fed605ef99689925bf2bccbc1d44fb50f3046 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Wed, 13 Nov 2024 11:12:06 +0000 Subject: [PATCH 2/3] lint --- .../internal/inputs/indiadb/client.py | 24 +++++++++---------- .../internal/inputs/indiadb/conftest.py | 10 ++++---- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/india_api/internal/inputs/indiadb/client.py b/src/india_api/internal/inputs/indiadb/client.py index 0cff60a..36a77d3 100644 --- a/src/india_api/internal/inputs/indiadb/client.py +++ b/src/india_api/internal/inputs/indiadb/client.py @@ -111,7 +111,7 @@ def get_predicted_power_production_for_location( day_ahead_hours=day_ahead_hours, day_ahead_timezone_delta_hours=day_ahead_timezone_delta_hours, forecast_horizon_minutes=forecast_horizon_minutes, - model_name=ml_model_name + model_name=ml_model_name, ) forecast_values: [ForecastValueSQL] = values[site.site_uuid] @@ -187,8 +187,7 @@ def get_predicted_solar_power_production_for_location( """ # set this to be hard coded for now - model_name = 'pvnet_india' - + model_name = "pvnet_india" return self.get_predicted_power_production_for_location( location=location, @@ -196,7 +195,7 @@ def get_predicted_solar_power_production_for_location( forecast_horizon=forecast_horizon, forecast_horizon_minutes=forecast_horizon_minutes, smooth_flag=smooth_flag, - model_name=model_name + model_name=model_name, ) def get_predicted_wind_power_production_for_location( @@ -217,7 +216,7 @@ def get_predicted_wind_power_production_for_location( """ # set this to be hard coded for now - model_name = 'windnet_india' + model_name = "windnet_india" return self.get_predicted_power_production_for_location( location=location, @@ -273,13 +272,13 @@ def get_sites(self, email: str) -> list[internal.Site]: return sites - def get_site_forecast(self, site_uuid: str, email:str) -> list[internal.PredictedPower]: + def get_site_forecast(self, site_uuid: str, email: str) -> list[internal.PredictedPower]: """Get a forecast for a site, this is for a solar site""" # TODO feels like there is some duplicated code here which could be refactored # hard coded model name - ml_model_name = 'pvnet_ad_sites' + ml_model_name = "pvnet_ad_sites" # Get the window start, _ = get_window() @@ -291,10 +290,7 @@ def get_site_forecast(self, site_uuid: str, email:str) -> list[internal.Predicte site_uuid = UUID(site_uuid) values = get_latest_forecast_values_by_site( - session, - site_uuids=[site_uuid], - start_utc=start, - model_name=ml_model_name + session, site_uuids=[site_uuid], start_utc=start, model_name=ml_model_name ) forecast_values: [ForecastValueSQL] = values[site_uuid] @@ -312,7 +308,7 @@ def get_site_forecast(self, site_uuid: str, email:str) -> list[internal.Predicte return values - def get_site_generation(self, site_uuid: str, email:str) -> list[internal.ActualPower]: + def get_site_generation(self, site_uuid: str, email: str) -> list[internal.ActualPower]: """Get the generation for a site, this is for a solar site""" # TODO feels like there is some duplicated code here which could be refactored @@ -344,7 +340,9 @@ def get_site_generation(self, site_uuid: str, email:str) -> list[internal.Actual return values - def post_site_generation(self, site_uuid: str, generation: list[internal.ActualPower], email:str): + def post_site_generation( + self, site_uuid: str, generation: list[internal.ActualPower], email: str + ): """Post generation for a site""" with self._get_session() as session: diff --git a/src/india_api/internal/inputs/indiadb/conftest.py b/src/india_api/internal/inputs/indiadb/conftest.py index 715b451..cf294d8 100644 --- a/src/india_api/internal/inputs/indiadb/conftest.py +++ b/src/india_api/internal/inputs/indiadb/conftest.py @@ -65,8 +65,8 @@ def sites(db_session): ml_id=1, asset_type="pv", country="india", - region='testID', - client_site_name='ruvnl_pv_testID1' + region="testID", + client_site_name="ruvnl_pv_testID1", ) db_session.add(site) sites.append(site) @@ -80,8 +80,8 @@ def sites(db_session): ml_id=2, asset_type="wind", country="india", - region='testID', - client_site_name = 'ruvnl_wind_testID' + region="testID", + client_site_name="ruvnl_wind_testID", ) db_session.add(site) sites.append(site) @@ -89,7 +89,7 @@ def sites(db_session): db_session.commit() # create user - user = get_user_by_email(session=db_session, email='test@test.com') + user = get_user_by_email(session=db_session, email="test@test.com") user.site_group.sites = sites db_session.commit() From 5d706c7a3808b1c21868bed53d551ffbac7d5970 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Wed, 13 Nov 2024 11:36:06 +0000 Subject: [PATCH 3/3] fix tests --- .../internal/inputs/indiadb/client.py | 4 ++-- .../internal/inputs/indiadb/conftest.py | 22 +++++++++++++++++++ .../internal/inputs/indiadb/test_indiadb.py | 6 ++--- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/src/india_api/internal/inputs/indiadb/client.py b/src/india_api/internal/inputs/indiadb/client.py index 36a77d3..bea2bf2 100644 --- a/src/india_api/internal/inputs/indiadb/client.py +++ b/src/india_api/internal/inputs/indiadb/client.py @@ -195,7 +195,7 @@ def get_predicted_solar_power_production_for_location( forecast_horizon=forecast_horizon, forecast_horizon_minutes=forecast_horizon_minutes, smooth_flag=smooth_flag, - model_name=model_name, + ml_model_name=model_name, ) def get_predicted_wind_power_production_for_location( @@ -224,7 +224,7 @@ def get_predicted_wind_power_production_for_location( forecast_horizon=forecast_horizon, forecast_horizon_minutes=forecast_horizon_minutes, smooth_flag=smooth_flag, - model_name=model_name, + ml_model_name=model_name, ) def get_actual_solar_power_production_for_location( diff --git a/src/india_api/internal/inputs/indiadb/conftest.py b/src/india_api/internal/inputs/indiadb/conftest.py index cf294d8..aa22369 100644 --- a/src/india_api/internal/inputs/indiadb/conftest.py +++ b/src/india_api/internal/inputs/indiadb/conftest.py @@ -6,6 +6,7 @@ import pytest from pvsite_datamodel.sqlmodels import Base, ForecastSQL, ForecastValueSQL, GenerationSQL, SiteSQL from pvsite_datamodel.read.user import get_user_by_email +from pvsite_datamodel.read.model import get_or_create_model from sqlalchemy import create_engine from sqlalchemy.orm import Session from testcontainers.postgres import PostgresContainer @@ -123,6 +124,23 @@ def generations(db_session, sites): @pytest.fixture() def forecast_values(db_session, sites): """Create some fake forecast values""" + + make_fake_forecast_values(db_session, sites, "pvnet_india") + +@pytest.fixture() +def forecast_values_wind(db_session, sites): + """Create some fake forecast values""" + + make_fake_forecast_values(db_session, sites, "windnet_india") + +@pytest.fixture() +def forecast_values_site(db_session, sites): + """Create some fake forecast values""" + + make_fake_forecast_values(db_session, sites, "pvnet_ad_sites") + + +def make_fake_forecast_values(db_session, sites, model_name): forecast_values = [] forecast_version: str = "0.0.0" @@ -134,6 +152,9 @@ def forecast_values(db_session, sites): # To make things trickier we make a second forecast at the same for one of the timestamps. timestamps = timestamps + timestamps[-1:] + # get model + ml_model = get_or_create_model(db_session, model_name) + for site in sites: for timestamp in timestamps: forecast: ForecastSQL = ForecastSQL( @@ -154,6 +175,7 @@ def forecast_values(db_session, sites): end_utc=timestamp + timedelta(minutes=horizon + duration), horizon_minutes=horizon, ) + forecast_value.ml_model = ml_model forecast_values.append(forecast_value) diff --git a/src/india_api/internal/inputs/indiadb/test_indiadb.py b/src/india_api/internal/inputs/indiadb/test_indiadb.py index c9b0791..7c8e28b 100644 --- a/src/india_api/internal/inputs/indiadb/test_indiadb.py +++ b/src/india_api/internal/inputs/indiadb/test_indiadb.py @@ -23,7 +23,7 @@ def client(engine, db_session): class TestIndiaDBClient: def test_get_predicted_wind_power_production_for_location( - self, client, forecast_values + self, client, forecast_values_wind ) -> None: locID = "testID" result = client.get_predicted_wind_power_production_for_location(locID) @@ -33,7 +33,7 @@ def test_get_predicted_wind_power_production_for_location( assert isinstance(record, PredictedPower) def test_get_predicted_wind_power_production_for_location_raise_error( - self, client, forecast_values + self, client, forecast_values_wind ) -> None: with pytest.raises(Exception): @@ -83,7 +83,7 @@ def test_get_sites_no_sites(self, client, sites) -> None: sites_from_api = client.get_sites(email="test2@test.com") assert len(sites_from_api) == 0 - def test_get_site_forecast(self, client, sites, forecast_values) -> None: + def test_get_site_forecast(self, client, sites, forecast_values_site) -> None: out = client.get_site_forecast(site_uuid=str(sites[0].site_uuid), email="test@test.com") assert len(out) > 0