Skip to content

Commit

Permalink
fix type hints in Forecaster
Browse files Browse the repository at this point in the history
  • Loading branch information
maurerle committed Nov 21, 2024
1 parent 4e93615 commit fb2e71a
Showing 1 changed file with 22 additions and 20 deletions.
42 changes: 22 additions & 20 deletions assume/common/forecasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,32 +30,32 @@ class Forecaster:
"""

def __init__(self, index: pd.Series):
def __init__(self, index: FastIndex):
self.index = index

def __getitem__(self, column: str) -> pd.Series:
def __getitem__(self, column: str) -> FastSeries:
"""
Returns the forecast for a given column.
Args:
column (str): The column of the forecast.
Returns:
pd.Series: The forecast.
FastSeries: The forecast.
This method returns the forecast for a given column as a pandas Series based on the provided index.
"""
return FastSeries(value=0.0, index=self.index)

def get_availability(self, unit: str) -> pd.Series:
def get_availability(self, unit: str) -> FastSeries:
"""
Returns the availability of a given unit as a pandas Series based on the provided index.
Args:
unit (str): The unit.
Returns:
pd.Series: The availability of the unit.
FastSeries: The availability of the unit.
Example:
>>> forecaster = Forecaster(index=pd.Series([1, 2, 3]))
Expand All @@ -65,7 +65,7 @@ def get_availability(self, unit: str) -> pd.Series:

return self[f"availability_{unit}"]

def get_price(self, fuel_type: str) -> pd.Series:
def get_price(self, fuel_type: str) -> FastSeries:
"""
Returns the price for a given fuel type as a pandas Series or zeros if the type does
not exist.
Expand All @@ -74,7 +74,7 @@ def get_price(self, fuel_type: str) -> pd.Series:
fuel_type (str): The fuel type.
Returns:
pd.Series: The price of the fuel.
FastSeries: The price of the fuel.
Example:
>>> forecaster = Forecaster(index=pd.Series([1, 2, 3]))
Expand Down Expand Up @@ -125,9 +125,9 @@ class CsvForecaster(Forecaster):
def __init__(
self,
index: pd.Series,
powerplants_units: dict[str, pd.Series] = {},
demand_units: dict[str, pd.Series] = {},
market_configs: dict[str, pd.Series] = {},
powerplants_units: pd.DataFrame,
demand_units: pd.DataFrame,
market_configs: dict = {},
*args,
**kwargs,
):
Expand All @@ -138,7 +138,7 @@ def __init__(
self.market_configs = market_configs
self.forecasts = pd.DataFrame(index=index)

def __getitem__(self, column: str) -> pd.Series:
def __getitem__(self, column: str) -> FastSeries:
"""
Returns the forecast for a given column.
Expand All @@ -148,7 +148,7 @@ def __getitem__(self, column: str) -> pd.Series:
column (str): The column of the forecast.
Returns:
pd.Series: The forecast for the given column.
FastSeries: The forecast for the given column.
"""

Expand Down Expand Up @@ -441,12 +441,12 @@ class RandomForecaster(CsvForecaster):
Attributes:
index (pandas.Series): The index of the forecasts.
powerplants_units (dict[str, pandas.Series]): The power plants.
powerplants_units (pandas.DataFrame): The power plants.
sigma (float): The standard deviation of the noise.
Args:
index (pandas.Series): The index of the forecasts.
powerplants_units (dict[str, pandas.Series]): The power plants.
powerplants_units (pandas.DataFrame): The power plants.
sigma (float): The standard deviation of the noise.
Example:
Expand All @@ -459,17 +459,19 @@ class RandomForecaster(CsvForecaster):
def __init__(
self,
index: pd.Series,
powerplants_units: dict[str, pd.Series] = {},
powerplants_units: pd.DataFrame,
demand_units: pd.DataFrame,
market_configs: dict = {},
sigma: float = 0.02,
*args,
**kwargs,
):
super().__init__(index, powerplants_units, *args, **kwargs)
super().__init__(index, powerplants_units, demand_units, market_configs, *args, **kwargs)

self.index = FastIndex(start=index[0], end=index[-1], freq=pd.infer_freq(index))
self.sigma = sigma

def __getitem__(self, column: str) -> pd.Series:
def __getitem__(self, column: str) -> FastSeries:
"""
Retrieves forecasted values modified by random noise.
Expand All @@ -481,7 +483,7 @@ def __getitem__(self, column: str) -> pd.Series:
column (str): The column of the forecast.
Returns:
pd.Series: The forecast modified by random noise.
FastSeries: The forecast modified by random noise.
"""

Expand Down Expand Up @@ -549,7 +551,7 @@ def __init__(
self.demand = self._to_fast_series(demand, "demand")
self.price_forecast = self._to_fast_series(price_forecast, "price_forecast")

def __getitem__(self, column: str) -> pd.Series:
def __getitem__(self, column: str) -> FastSeries:
"""
Retrieves forecasted values.
Expand All @@ -562,7 +564,7 @@ def __getitem__(self, column: str) -> pd.Series:
column (str): The column for which forecasted values are requested.
Returns:
pd.Series: The forecasted values for the specified column.
FastSeries: The forecasted values for the specified column.
"""

Expand Down

0 comments on commit fb2e71a

Please sign in to comment.