Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

save the model in a sensible place #227

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions quartz_solar_forecast/forecasts/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from xgboost.sklearn import XGBRegressor

from . import constants
import quartz_solar_forecast

logger = logging.getLogger(__name__)

Expand All @@ -36,6 +37,7 @@ class TryolabsSolarPowerPredictor:
Predicts solar power output for the given parameters.
"""
DATE_COLUMN = "date"
download_dir = os.path.dirname(quartz_solar_forecast.__file__) + "/models"

def _download_model(self, filename: str, repo_id: str, file_path: str) -> str:
"""
Expand All @@ -56,12 +58,11 @@ def _download_model(self, filename: str, repo_id: str, file_path: str) -> str:
The path to the locally saved model file.
"""
# Use the project directory instead of the user's home directory
download_dir = "/home/runner/work/Open-Source-Quartz-Solar-Forecast/Open-Source-Quartz-Solar-Forecast"
os.makedirs(download_dir, exist_ok=True)
os.makedirs(self.download_dir, exist_ok=True)

downloaded_file = hf_hub_download(repo_id=repo_id, filename=file_path, cache_dir=download_dir)
downloaded_file = hf_hub_download(repo_id=repo_id, filename=file_path, cache_dir=self.download_dir)

target_path = os.path.join(download_dir, filename)
target_path = os.path.join(self.download_dir, filename)

# copy file from downloaded_file to target_path
shutil.copyfile(downloaded_file, target_path)
Expand Down Expand Up @@ -111,14 +112,13 @@ def load_model(
The loaded XGBoost model ready for making predictions.
"""
# Use the project directory
download_dir = "/home/runner/work/Open-Source-Quartz-Solar-Forecast/Open-Source-Quartz-Solar-Forecast"
zipfile_model = os.path.join(download_dir, model_file + ".zip")
zipfile_model = os.path.join(self.download_dir, model_file + ".zip")

if not os.path.isfile(zipfile_model):
logger.info("Downloading model...")
zipfile_model = self._download_model(model_file + ".zip", repo_id, file_path)

model_path = os.path.join(download_dir, model_file)
model_path = os.path.join(self.download_dir, model_file)
if not os.path.isfile(model_path):
logger.info("Preparing model...")
self._decompress_zipfile(zipfile_model)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_forecast_no_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_run_forecast_no_ts():
current_hr = pd.Timestamp.now().round(freq='h')

# run gradient boosting model with no ts
predications_df = run_forecast(site=site, model="gb")
predications_df = run_forecast(site=site, model="gb", ts=current_ts)
# check current ts agrees with dataset
assert predications_df.index.min() == current_ts

Expand All @@ -20,7 +20,7 @@ def test_run_forecast_no_ts():
print(f"Max: {predications_df['power_kw'].max()}")

# run xgb model with no ts
predications_df = run_forecast(site=site, model="xgb")
predications_df = run_forecast(site=site, model="xgb", ts=current_ts)
# check current ts agrees with dataset
assert predications_df.index.min() == current_hr

Expand Down
Loading