Skip to content

Commit

Permalink
save the model in a sensible place (#227)
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield authored Nov 26, 2024
1 parent 810750b commit be90200
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
14 changes: 7 additions & 7 deletions quartz_solar_forecast/forecasts/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from xgboost.sklearn import XGBRegressor

from . import constants
import quartz_solar_forecast

logger = logging.getLogger(__name__)

Expand All @@ -36,6 +37,7 @@ class TryolabsSolarPowerPredictor:
Predicts solar power output for the given parameters.
"""
DATE_COLUMN = "date"
download_dir = os.path.dirname(quartz_solar_forecast.__file__) + "/models"

def _download_model(self, filename: str, repo_id: str, file_path: str) -> str:
"""
Expand All @@ -56,12 +58,11 @@ def _download_model(self, filename: str, repo_id: str, file_path: str) -> str:
The path to the locally saved model file.
"""
# Use the project directory instead of the user's home directory
download_dir = "/home/runner/work/Open-Source-Quartz-Solar-Forecast/Open-Source-Quartz-Solar-Forecast"
os.makedirs(download_dir, exist_ok=True)
os.makedirs(self.download_dir, exist_ok=True)

downloaded_file = hf_hub_download(repo_id=repo_id, filename=file_path, cache_dir=download_dir)
downloaded_file = hf_hub_download(repo_id=repo_id, filename=file_path, cache_dir=self.download_dir)

target_path = os.path.join(download_dir, filename)
target_path = os.path.join(self.download_dir, filename)

# copy file from downloaded_file to target_path
shutil.copyfile(downloaded_file, target_path)
Expand Down Expand Up @@ -111,14 +112,13 @@ def load_model(
The loaded XGBoost model ready for making predictions.
"""
# Use the project directory
download_dir = "/home/runner/work/Open-Source-Quartz-Solar-Forecast/Open-Source-Quartz-Solar-Forecast"
zipfile_model = os.path.join(download_dir, model_file + ".zip")
zipfile_model = os.path.join(self.download_dir, model_file + ".zip")

if not os.path.isfile(zipfile_model):
logger.info("Downloading model...")
zipfile_model = self._download_model(model_file + ".zip", repo_id, file_path)

model_path = os.path.join(download_dir, model_file)
model_path = os.path.join(self.download_dir, model_file)
if not os.path.isfile(model_path):
logger.info("Preparing model...")
self._decompress_zipfile(zipfile_model)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_forecast_no_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_run_forecast_no_ts():
current_hr = pd.Timestamp.now().round(freq='h')

# run gradient boosting model with no ts
predications_df = run_forecast(site=site, model="gb")
predications_df = run_forecast(site=site, model="gb", ts=current_ts)
# check current ts agrees with dataset
assert predications_df.index.min() == current_ts

Expand All @@ -20,7 +20,7 @@ def test_run_forecast_no_ts():
print(f"Max: {predications_df['power_kw'].max()}")

# run xgb model with no ts
predications_df = run_forecast(site=site, model="xgb")
predications_df = run_forecast(site=site, model="xgb", ts=current_ts)
# check current ts agrees with dataset
assert predications_df.index.min() == current_hr

Expand Down

0 comments on commit be90200

Please sign in to comment.