Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue/model names #98

Merged
merged 3 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions src/india_api/internal/inputs/indiadb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def get_predicted_power_production_for_location(
self,
location: str,
asset_type: SiteAssetType,
ml_model_name: str,
forecast_horizon: ForecastHorizon = ForecastHorizon.latest,
forecast_horizon_minutes: Optional[int] = None,
smooth_flag: bool = True,
Expand All @@ -65,6 +66,7 @@ def get_predicted_power_production_for_location(
Args:
location: the location to get the predicted power production for
asset_type: The type of asset to get the forecast for
ml_model_name: The name of the model to get the forecast from
forecast_horizon: The time horizon to get the data for. Can be latest or day ahead
forecast_horizon_minutes: The number of minutes to get the forecast for. forecast_horizon must be 'horizon'
smooth_flag: Flag to smooth the forecast
Expand Down Expand Up @@ -109,6 +111,7 @@ def get_predicted_power_production_for_location(
day_ahead_hours=day_ahead_hours,
day_ahead_timezone_delta_hours=day_ahead_timezone_delta_hours,
forecast_horizon_minutes=forecast_horizon_minutes,
model_name=ml_model_name,
)
forecast_values: [ForecastValueSQL] = values[site.site_uuid]

Expand Down Expand Up @@ -183,12 +186,16 @@ def get_predicted_solar_power_production_for_location(
smooth_flag: Flag to smooth the forecast
"""

# set this to be hard coded for now
model_name = "pvnet_india"

return self.get_predicted_power_production_for_location(
location=location,
asset_type=SiteAssetType.pv,
forecast_horizon=forecast_horizon,
forecast_horizon_minutes=forecast_horizon_minutes,
smooth_flag=smooth_flag,
ml_model_name=model_name,
)

def get_predicted_wind_power_production_for_location(
Expand All @@ -208,12 +215,16 @@ def get_predicted_wind_power_production_for_location(
smooth_flag: Flag to smooth the forecast
"""

# set this to be hard coded for now
model_name = "windnet_india"

return self.get_predicted_power_production_for_location(
location=location,
asset_type=SiteAssetType.wind,
forecast_horizon=forecast_horizon,
forecast_horizon_minutes=forecast_horizon_minutes,
smooth_flag=smooth_flag
smooth_flag=smooth_flag,
ml_model_name=model_name,
)

def get_actual_solar_power_production_for_location(
Expand Down Expand Up @@ -261,11 +272,14 @@ def get_sites(self, email: str) -> list[internal.Site]:

return sites

def get_site_forecast(self, site_uuid: str, email:str) -> list[internal.PredictedPower]:
def get_site_forecast(self, site_uuid: str, email: str) -> list[internal.PredictedPower]:
"""Get a forecast for a site, this is for a solar site"""

# TODO feels like there is some duplicated code here which could be refactored

# hard coded model name
ml_model_name = "pvnet_ad_sites"

# Get the window
start, _ = get_window()

Expand All @@ -276,9 +290,7 @@ def get_site_forecast(self, site_uuid: str, email:str) -> list[internal.Predicte
site_uuid = UUID(site_uuid)

values = get_latest_forecast_values_by_site(
session,
site_uuids=[site_uuid],
start_utc=start,
session, site_uuids=[site_uuid], start_utc=start, model_name=ml_model_name
)
forecast_values: [ForecastValueSQL] = values[site_uuid]

Expand All @@ -296,7 +308,7 @@ def get_site_forecast(self, site_uuid: str, email:str) -> list[internal.Predicte

return values

def get_site_generation(self, site_uuid: str, email:str) -> list[internal.ActualPower]:
def get_site_generation(self, site_uuid: str, email: str) -> list[internal.ActualPower]:
"""Get the generation for a site, this is for a solar site"""

# TODO feels like there is some duplicated code here which could be refactored
Expand Down Expand Up @@ -328,7 +340,9 @@ def get_site_generation(self, site_uuid: str, email:str) -> list[internal.Actual

return values

def post_site_generation(self, site_uuid: str, generation: list[internal.ActualPower], email:str):
def post_site_generation(
self, site_uuid: str, generation: list[internal.ActualPower], email: str
):
"""Post generation for a site"""

with self._get_session() as session:
Expand Down
32 changes: 27 additions & 5 deletions src/india_api/internal/inputs/indiadb/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from pvsite_datamodel.sqlmodels import Base, ForecastSQL, ForecastValueSQL, GenerationSQL, SiteSQL
from pvsite_datamodel.read.user import get_user_by_email
from pvsite_datamodel.read.model import get_or_create_model
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from testcontainers.postgres import PostgresContainer
Expand Down Expand Up @@ -65,8 +66,8 @@ def sites(db_session):
ml_id=1,
asset_type="pv",
country="india",
region='testID',
client_site_name='ruvnl_pv_testID1'
region="testID",
client_site_name="ruvnl_pv_testID1",
)
db_session.add(site)
sites.append(site)
Expand All @@ -80,16 +81,16 @@ def sites(db_session):
ml_id=2,
asset_type="wind",
country="india",
region='testID',
client_site_name = 'ruvnl_wind_testID'
region="testID",
client_site_name="ruvnl_wind_testID",
)
db_session.add(site)
sites.append(site)

db_session.commit()

# create user
user = get_user_by_email(session=db_session, email='[email protected]')
user = get_user_by_email(session=db_session, email="[email protected]")
user.site_group.sites = sites

db_session.commit()
Expand Down Expand Up @@ -123,6 +124,23 @@ def generations(db_session, sites):
@pytest.fixture()
def forecast_values(db_session, sites):
"""Create some fake forecast values"""

make_fake_forecast_values(db_session, sites, "pvnet_india")

@pytest.fixture()
def forecast_values_wind(db_session, sites):
"""Create some fake forecast values"""

make_fake_forecast_values(db_session, sites, "windnet_india")

@pytest.fixture()
def forecast_values_site(db_session, sites):
"""Create some fake forecast values"""

make_fake_forecast_values(db_session, sites, "pvnet_ad_sites")


def make_fake_forecast_values(db_session, sites, model_name):
forecast_values = []
forecast_version: str = "0.0.0"

Expand All @@ -134,6 +152,9 @@ def forecast_values(db_session, sites):
# To make things trickier we make a second forecast at the same for one of the timestamps.
timestamps = timestamps + timestamps[-1:]

# get model
ml_model = get_or_create_model(db_session, model_name)

for site in sites:
for timestamp in timestamps:
forecast: ForecastSQL = ForecastSQL(
Expand All @@ -154,6 +175,7 @@ def forecast_values(db_session, sites):
end_utc=timestamp + timedelta(minutes=horizon + duration),
horizon_minutes=horizon,
)
forecast_value.ml_model = ml_model

forecast_values.append(forecast_value)

Expand Down
6 changes: 3 additions & 3 deletions src/india_api/internal/inputs/indiadb/test_indiadb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def client(engine, db_session):

class TestIndiaDBClient:
def test_get_predicted_wind_power_production_for_location(
self, client, forecast_values
self, client, forecast_values_wind
) -> None:
locID = "testID"
result = client.get_predicted_wind_power_production_for_location(locID)
Expand All @@ -33,7 +33,7 @@ def test_get_predicted_wind_power_production_for_location(
assert isinstance(record, PredictedPower)

def test_get_predicted_wind_power_production_for_location_raise_error(
self, client, forecast_values
self, client, forecast_values_wind
) -> None:

with pytest.raises(Exception):
Expand Down Expand Up @@ -83,7 +83,7 @@ def test_get_sites_no_sites(self, client, sites) -> None:
sites_from_api = client.get_sites(email="[email protected]")
assert len(sites_from_api) == 0

def test_get_site_forecast(self, client, sites, forecast_values) -> None:
def test_get_site_forecast(self, client, sites, forecast_values_site) -> None:
out = client.get_site_forecast(site_uuid=str(sites[0].site_uuid), email="[email protected]")
assert len(out) > 0

Expand Down
Loading