From fb2e71a9aa071f64a760ff141ad1de2b680c3fbd Mon Sep 17 00:00:00 2001 From: Florian Maurer Date: Thu, 21 Nov 2024 13:32:01 +0100 Subject: [PATCH] fix type hints in Forecaster --- assume/common/forecasts.py | 42 ++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/assume/common/forecasts.py b/assume/common/forecasts.py index bfe9e577..d8a1d324 100644 --- a/assume/common/forecasts.py +++ b/assume/common/forecasts.py @@ -30,10 +30,10 @@ class Forecaster: """ - def __init__(self, index: pd.Series): + def __init__(self, index: FastIndex): self.index = index - def __getitem__(self, column: str) -> pd.Series: + def __getitem__(self, column: str) -> FastSeries: """ Returns the forecast for a given column. @@ -41,13 +41,13 @@ def __getitem__(self, column: str) -> pd.Series: column (str): The column of the forecast. Returns: - pd.Series: The forecast. + FastSeries: The forecast. This method returns the forecast for a given column as a pandas Series based on the provided index. """ return FastSeries(value=0.0, index=self.index) - def get_availability(self, unit: str) -> pd.Series: + def get_availability(self, unit: str) -> FastSeries: """ Returns the availability of a given unit as a pandas Series based on the provided index. @@ -55,7 +55,7 @@ def get_availability(self, unit: str) -> pd.Series: unit (str): The unit. Returns: - pd.Series: The availability of the unit. + FastSeries: The availability of the unit. Example: >>> forecaster = Forecaster(index=pd.Series([1, 2, 3])) @@ -65,7 +65,7 @@ def get_availability(self, unit: str) -> pd.Series: return self[f"availability_{unit}"] - def get_price(self, fuel_type: str) -> pd.Series: + def get_price(self, fuel_type: str) -> FastSeries: """ Returns the price for a given fuel type as a pandas Series or zeros if the type does not exist. @@ -74,7 +74,7 @@ def get_price(self, fuel_type: str) -> pd.Series: fuel_type (str): The fuel type. Returns: - pd.Series: The price of the fuel. + FastSeries: The price of the fuel. Example: >>> forecaster = Forecaster(index=pd.Series([1, 2, 3])) @@ -125,9 +125,9 @@ class CsvForecaster(Forecaster): def __init__( self, index: pd.Series, - powerplants_units: dict[str, pd.Series] = {}, - demand_units: dict[str, pd.Series] = {}, - market_configs: dict[str, pd.Series] = {}, + powerplants_units: pd.DataFrame, + demand_units: pd.DataFrame, + market_configs: dict = {}, *args, **kwargs, ): @@ -138,7 +138,7 @@ def __init__( self.market_configs = market_configs self.forecasts = pd.DataFrame(index=index) - def __getitem__(self, column: str) -> pd.Series: + def __getitem__(self, column: str) -> FastSeries: """ Returns the forecast for a given column. @@ -148,7 +148,7 @@ def __getitem__(self, column: str) -> pd.Series: column (str): The column of the forecast. Returns: - pd.Series: The forecast for the given column. + FastSeries: The forecast for the given column. """ @@ -441,12 +441,12 @@ class RandomForecaster(CsvForecaster): Attributes: index (pandas.Series): The index of the forecasts. - powerplants_units (dict[str, pandas.Series]): The power plants. + powerplants_units (pandas.DataFrame): The power plants. sigma (float): The standard deviation of the noise. Args: index (pandas.Series): The index of the forecasts. - powerplants_units (dict[str, pandas.Series]): The power plants. + powerplants_units (pandas.DataFrame): The power plants. sigma (float): The standard deviation of the noise. Example: @@ -459,17 +459,19 @@ class RandomForecaster(CsvForecaster): def __init__( self, index: pd.Series, - powerplants_units: dict[str, pd.Series] = {}, + powerplants_units: pd.DataFrame, + demand_units: pd.DataFrame, + market_configs: dict = {}, sigma: float = 0.02, *args, **kwargs, ): - super().__init__(index, powerplants_units, *args, **kwargs) + super().__init__(index, powerplants_units, demand_units, market_configs, *args, **kwargs) self.index = FastIndex(start=index[0], end=index[-1], freq=pd.infer_freq(index)) self.sigma = sigma - def __getitem__(self, column: str) -> pd.Series: + def __getitem__(self, column: str) -> FastSeries: """ Retrieves forecasted values modified by random noise. @@ -481,7 +483,7 @@ def __getitem__(self, column: str) -> pd.Series: column (str): The column of the forecast. Returns: - pd.Series: The forecast modified by random noise. + FastSeries: The forecast modified by random noise. """ @@ -549,7 +551,7 @@ def __init__( self.demand = self._to_fast_series(demand, "demand") self.price_forecast = self._to_fast_series(price_forecast, "price_forecast") - def __getitem__(self, column: str) -> pd.Series: + def __getitem__(self, column: str) -> FastSeries: """ Retrieves forecasted values. @@ -562,7 +564,7 @@ def __getitem__(self, column: str) -> pd.Series: column (str): The column for which forecasted values are requested. Returns: - pd.Series: The forecasted values for the specified column. + FastSeries: The forecasted values for the specified column. """