From be90200e3a448caff71e73cf9b3c81c5dbc8c026 Mon Sep 17 00:00:00 2001 From: Peter Dudfield <34686298+peterdudfield@users.noreply.github.com> Date: Tue, 26 Nov 2024 11:37:33 +0000 Subject: [PATCH] save the model in a sensible place (#227) --- quartz_solar_forecast/forecasts/v2.py | 14 +++++++------- tests/test_forecast_no_ts.py | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/quartz_solar_forecast/forecasts/v2.py b/quartz_solar_forecast/forecasts/v2.py index 2d5c3660..0cbe18cc 100644 --- a/quartz_solar_forecast/forecasts/v2.py +++ b/quartz_solar_forecast/forecasts/v2.py @@ -11,6 +11,7 @@ from xgboost.sklearn import XGBRegressor from . import constants +import quartz_solar_forecast logger = logging.getLogger(__name__) @@ -36,6 +37,7 @@ class TryolabsSolarPowerPredictor: Predicts solar power output for the given parameters. """ DATE_COLUMN = "date" + download_dir = os.path.dirname(quartz_solar_forecast.__file__) + "/models" def _download_model(self, filename: str, repo_id: str, file_path: str) -> str: """ @@ -56,12 +58,11 @@ def _download_model(self, filename: str, repo_id: str, file_path: str) -> str: The path to the locally saved model file. """ # Use the project directory instead of the user's home directory - download_dir = "/home/runner/work/Open-Source-Quartz-Solar-Forecast/Open-Source-Quartz-Solar-Forecast" - os.makedirs(download_dir, exist_ok=True) + os.makedirs(self.download_dir, exist_ok=True) - downloaded_file = hf_hub_download(repo_id=repo_id, filename=file_path, cache_dir=download_dir) + downloaded_file = hf_hub_download(repo_id=repo_id, filename=file_path, cache_dir=self.download_dir) - target_path = os.path.join(download_dir, filename) + target_path = os.path.join(self.download_dir, filename) # copy file from downloaded_file to target_path shutil.copyfile(downloaded_file, target_path) @@ -111,14 +112,13 @@ def load_model( The loaded XGBoost model ready for making predictions. """ # Use the project directory - download_dir = "/home/runner/work/Open-Source-Quartz-Solar-Forecast/Open-Source-Quartz-Solar-Forecast" - zipfile_model = os.path.join(download_dir, model_file + ".zip") + zipfile_model = os.path.join(self.download_dir, model_file + ".zip") if not os.path.isfile(zipfile_model): logger.info("Downloading model...") zipfile_model = self._download_model(model_file + ".zip", repo_id, file_path) - model_path = os.path.join(download_dir, model_file) + model_path = os.path.join(self.download_dir, model_file) if not os.path.isfile(model_path): logger.info("Preparing model...") self._decompress_zipfile(zipfile_model) diff --git a/tests/test_forecast_no_ts.py b/tests/test_forecast_no_ts.py index 668f8582..0450379d 100644 --- a/tests/test_forecast_no_ts.py +++ b/tests/test_forecast_no_ts.py @@ -11,7 +11,7 @@ def test_run_forecast_no_ts(): current_hr = pd.Timestamp.now().round(freq='h') # run gradient boosting model with no ts - predications_df = run_forecast(site=site, model="gb") + predications_df = run_forecast(site=site, model="gb", ts=current_ts) # check current ts agrees with dataset assert predications_df.index.min() == current_ts @@ -20,7 +20,7 @@ def test_run_forecast_no_ts(): print(f"Max: {predications_df['power_kw'].max()}") # run xgb model with no ts - predications_df = run_forecast(site=site, model="xgb") + predications_df = run_forecast(site=site, model="xgb", ts=current_ts) # check current ts agrees with dataset assert predications_df.index.min() == current_hr