Skip to content

Commit

Permalink
update some type hints
Browse files Browse the repository at this point in the history
improve some functions to utilize new speed up
  • Loading branch information
nick-harder committed Nov 22, 2024
1 parent 028b551 commit 13afb66
Show file tree
Hide file tree
Showing 16 changed files with 266 additions and 403 deletions.
69 changes: 15 additions & 54 deletions assume/common/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import TypedDict

import numpy as np
import pandas as pd

from assume.common.fast_pandas import FastSeries, TensorFastSeries
from assume.common.forecasts import Forecaster
Expand All @@ -22,22 +21,12 @@ class BaseUnit:
"""
A base class for a unit. This class is used as a foundation for all units.
Attributes:
id (str): The ID of the unit.
unit_operator (str): The operator of the unit.
technology (str): The technology of the unit.
bidding_strategies (dict[str, BaseStrategy]): The bidding strategies of the unit.
index (pandas.DatetimeIndex): The index of the unit.
node (str, optional): The node of the unit. Defaults to "".
forecaster (Forecaster, optional): The forecast of the unit. Defaults to None.
**kwargs: Additional keyword arguments.
Args:
id (str): The ID of the unit.
unit_operator (str): The operator of the unit.
technology (str): The technology of the unit.
bidding_strategies (dict[str, BaseStrategy]): The bidding strategies of the unit.
index (pandas.DatetimeIndex): The index of the unit.
index (FastIndex): The index of the unit.
node (str, optional): The node of the unit. Defaults to "".
forecaster (Forecaster, optional): The forecast of the unit. Defaults to None.
location (tuple[float, float], optional): The location of the unit. Defaults to (0.0, 0.0).
Expand Down Expand Up @@ -129,10 +118,10 @@ def calculate_bids(

def calculate_marginal_cost(self, start: datetime, power: float) -> float:
"""
Calculates the marginal cost for the given power.
Calculates the marginal cost for the given power.`
Args:
start (pandas.Timestamp): The start time of the dispatch.
start (datetime.datetime): The start time of the dispatch.
power (float): The power output of the unit.
Returns:
Expand Down Expand Up @@ -281,7 +270,7 @@ def calculate_cashflow(self, product_type: str, orderbook: Orderbook):
cashflow = float(
order.get("accepted_price", 0) * order.get("accepted_volume", 0)
)
elapsed_intervals = (end - start) / pd.Timedelta(self.index.freq)
elapsed_intervals = (end - start) / self.index.freq
self.outputs[f"{product_type}_cashflow"].loc[start:end_excl] += (
cashflow * elapsed_intervals
)
Expand Down Expand Up @@ -331,12 +320,12 @@ def calculate_min_max_power(
Calculates the min and max power for the given time period.
Args:
start (pandas.Timestamp): The start time of the dispatch.
end (pandas.Timestamp): The end time of the dispatch.
start (datetime.datetime): The start time of the dispatch.
end (datetime.datetime): The end time of the dispatch.
product_type (str): The product type of the unit.
Returns:
tuple[pandas.Series, pandas.Series]: The min and max power for the given time period.
tuple[np.array, np.array]: The min and max power for the given time period.
"""

def calculate_ramp(
Expand Down Expand Up @@ -385,26 +374,12 @@ def calculate_ramp(
)
return power

def get_clean_spread(self, prices: pd.DataFrame) -> float:
"""
Returns the clean spread for the given prices.
Args:
prices (pandas.DataFrame): The prices.
Returns:
float: The clean spread for the given prices.
"""
emission_cost = self.emission_factor * prices["co"].mean()
fuel_cost = prices[self.technology.replace("_combined", "")].mean()
return (fuel_cost + emission_cost) / self.efficiency

def get_operation_time(self, start: datetime) -> int:
"""
Returns the time the unit is operating (positive) or shut down (negative).
Args:
start (datetime): The start time.
start (datetime.datetime): The start time.
Returns:
int: The operation time as a positive integer if operating, or negative if shut down.
Expand Down Expand Up @@ -553,27 +528,27 @@ def calculate_min_max_charge(
Calculates the min and max charging power for the given time period.
Args:
start (pandas.Timestamp): The start time of the dispatch.
end (pandas.Timestamp): The end time of the dispatch.
start (datetime.datetime): The start time of the dispatch.
end (datetime.datetime): The end time of the dispatch.
product_type (str, optional): The product type of the unit. Defaults to "energy".
Returns:
tuple[pandas.Series, pandas.Series]: The min and max charging power for the given time period.
tuple[np.array, np.array]: The min and max charging power for the given time period.
"""

def calculate_min_max_discharge(
self, start: datetime, end: datetime, product_type="energy"
) -> tuple[FastSeries, FastSeries]:
) -> tuple[np.array, np.array]:
"""
Calculates the min and max discharging power for the given time period.
Args:
start (pandas.Timestamp): The start time of the dispatch.
end (pandas.Timestamp): The end time of the dispatch.
start (datetime.datetime): The start time of the dispatch.
end (datetime.datetime): The end time of the dispatch.
product_type (str, optional): The product type of the unit. Defaults to "energy".
Returns:
tuple[pandas.Series, pandas.Series]: The min and max discharging power for the given time period.
tuple[np.array, np.array]: The min and max discharging power for the given time period.
"""

def get_soc_before(self, dt: datetime) -> float:
Expand All @@ -593,20 +568,6 @@ def get_soc_before(self, dt: datetime) -> float:
else:
return self.outputs["soc"].at[dt - self.index.freq]

def get_clean_spread(self, prices: pd.DataFrame) -> float:
"""
Returns the clean spread for the given prices.
Args:
prices (pandas.DataFrame): The prices.
Returns:
float: The clean spread for the given prices.
"""
emission_cost = self.emission_factor * prices["co"].mean()
fuel_cost = prices[self.technology.replace("_combined", "")].mean()
return (fuel_cost + emission_cost) / self.efficiency_charge

def calculate_ramp_discharge(
self,
previous_power: float,
Expand Down
10 changes: 7 additions & 3 deletions assume/common/fast_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def __contains__(self, date: datetime) -> bool:
Check if a datetime is within the index range and aligned with the frequency.
Parameters:
date (datetime): The datetime to check.
date (datetime.datetime): The datetime to check.
Returns:
bool: True if the datetime is in the index range and aligned; False otherwise.
Expand Down Expand Up @@ -224,7 +224,7 @@ def _get_idx_from_date(self, date: datetime) -> int:
Convert a datetime to its corresponding index in the range.
Parameters:
date (datetime): The datetime to convert.
date (datetime.datetime): The datetime to convert.
Returns:
int: The index of the datetime in the index range.
Expand Down Expand Up @@ -306,6 +306,10 @@ def __init__(
value (float | np.ndarray, optional): Initial value(s) for the data. Defaults to 0.0.
name (str, optional): Name of the series. Defaults to an empty string.
"""
# check that the index is a FastIndex
if not isinstance(index, FastIndex):
raise TypeError("In FastSeries, index must be a FastIndex object.")

Check warning on line 311 in assume/common/fast_pandas.py

View check run for this annotation

Codecov / codecov/patch

assume/common/fast_pandas.py#L311

Added line #L311 was not covered by tests

self._index = index
self._name = name
self.loc = self # Allow adjusting loc as well
Expand Down Expand Up @@ -835,7 +839,7 @@ def as_pd_series(
return pd.Series(data_slice, index=index, name=name if name else self.name)

@staticmethod
def from_series(series: pd.Series):
def from_pandas_series(series: pd.Series):
"""
Create a FastSeries from a pandas Series.
Expand Down
33 changes: 12 additions & 21 deletions assume/common/forecasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,6 @@ def get_price(self, fuel_type: str) -> FastSeries:

return self[f"fuel_price_{fuel_type}"]

def _to_fast_series(self, value, name: str) -> FastSeries:
"""
Converts a value to a FastSeries based on self.index.
Args:
value (float | list | pd.Series): The value to convert.
name (str): Name of the series.
Returns:
FastSeries: The converted FastSeries.
"""
if isinstance(value, pd.Series):
value = value.values # Use the values as an array for consistency
return FastSeries(index=self.index, value=value, name=name)


class CsvForecaster(Forecaster):
"""
Expand Down Expand Up @@ -427,7 +412,7 @@ def convert_forecasts_to_fast_series(self):
for column_name in self.forecasts.columns:
# Convert each column in self.forecasts to FastSeries
forecast_series = self.forecasts[column_name]
fast_forecasts[column_name] = FastSeries.from_series(forecast_series)
fast_forecasts[column_name] = FastSeries.from_pandas_series(forecast_series)

# Replace the DataFrame with the dictionary of FastSeries
self.forecasts = fast_forecasts
Expand Down Expand Up @@ -547,11 +532,17 @@ def __init__(
self.index = FastIndex(start=index[0], end=index[-1], freq=pd.infer_freq(index))

# Convert attributes to FastSeries if they are not already Series
self.fuel_price = self._to_fast_series(fuel_price, "fuel_price")
self.availability = self._to_fast_series(availability, "availability")
self.co2_price = self._to_fast_series(co2_price, "co2_price")
self.demand = self._to_fast_series(demand, "demand")
self.price_forecast = self._to_fast_series(price_forecast, "price_forecast")
self.fuel_price = FastSeries(
index=self.index, value=fuel_price, name="fuel_price"
)
self.availability = FastSeries(
index=self.index, value=availability, name="availability"
)
self.co2_price = FastSeries(index=self.index, value=co2_price, name="co2_price")
self.demand = FastSeries(index=self.index, value=demand, name="demand")
self.price_forecast = FastSeries(
index=self.index, value=price_forecast, name="price_forecast"
)

def __getitem__(self, column: str) -> FastSeries:
"""
Expand Down
3 changes: 2 additions & 1 deletion assume/common/units_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def get_actual_dispatch(
Args:
product_type (str): The product type for which this is done
last (datetime): the last date until which the dispatch was already sent
last (datetime.datetime): the last date until which the dispatch was already sent
Returns:
tuple[list[tuple[datetime, float, str, str]], list[dict]]: market_dispatch and unit_dispatch dataframes
Expand Down Expand Up @@ -336,6 +336,7 @@ def get_actual_dispatch(
dispatch["time"] = unit.index.get_date_list(start, end)
dispatch["unit"] = unit_id
unit_dispatch.append(dispatch)

return market_dispatch, unit_dispatch

def write_actual_dispatch(self, product_type: str) -> None:
Expand Down
43 changes: 18 additions & 25 deletions assume/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,31 +333,6 @@ def aggregate_step_amount(orderbook: Orderbook, begin=None, end=None, groupby=No
return [j for sub in list(aggregation.values()) for j in sub]


def get_test_demand_orders(power: np.ndarray):
"""
Get test demand orders.
Args:
power (numpy.ndarray): Power array.
Returns:
pandas.DataFrame: DataFrame of demand orders.
Examples:
>>> power = np.array([100, 200, 150])
>>> get_test_demand_orders(power)
"""

order_book = {}
for t in range(len(power)):
order_book[t] = dict(
type="demand", hour=t, block_id=t, name="DEM", price=3, volume=-power[t]
)
demand_order = pd.DataFrame.from_dict(order_book, orient="index")
demand_order = demand_order.set_index(["block_id", "hour", "name"])
return demand_order


def separate_orders(orderbook: Orderbook):
"""
Separate orders with several hours into single hour orders.
Expand Down Expand Up @@ -674,3 +649,21 @@ def suppress_output():
os.close(saved_stdout_fd)
os.close(saved_stderr_fd)
os.close(devnull)


# Function to parse the duration string
def parse_duration(duration_str):
if duration_str.endswith("d"):
days = int(duration_str[:-1])
return timedelta(days=days)

Check warning on line 658 in assume/common/utils.py

View check run for this annotation

Codecov / codecov/patch

assume/common/utils.py#L657-L658

Added lines #L657 - L658 were not covered by tests
elif duration_str.endswith("h"):
hours = int(duration_str[:-1])
return timedelta(hours=hours)
elif duration_str.endswith("m"):
minutes = int(duration_str[:-1])
return timedelta(minutes=minutes)
elif duration_str.endswith("s"):
seconds = int(duration_str[:-1])
return timedelta(seconds=seconds)

Check warning on line 667 in assume/common/utils.py

View check run for this annotation

Codecov / codecov/patch

assume/common/utils.py#L662-L667

Added lines #L662 - L667 were not covered by tests
else:
raise ValueError(f"Unsupported duration format: {duration_str}")

Check warning on line 669 in assume/common/utils.py

View check run for this annotation

Codecov / codecov/patch

assume/common/utils.py#L669

Added line #L669 was not covered by tests
6 changes: 3 additions & 3 deletions assume/strategies/advanced_orders.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
#
# SPDX-License-Identifier: AGPL-3.0-or-later

import pandas as pd

from assume.common.base import BaseStrategy, SupportsMinMax
from assume.common.market_objects import MarketConfig, Orderbook, Product
from assume.common.utils import parse_duration
from assume.strategies.flexable import (
calculate_EOM_price_if_off,
calculate_EOM_price_if_on,
Expand All @@ -29,7 +29,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# check if kwargs contains eom_foresight argument
self.foresight = pd.Timedelta(kwargs.get("eom_foresight", "12h"))
self.foresight = parse_duration(kwargs.get("eom_foresight", "12h"))

def calculate_bids(
self,
Expand Down Expand Up @@ -219,7 +219,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# check if kwargs contains eom_foresight argument
self.foresight = pd.Timedelta(kwargs.get("eom_foresight", "12h"))
self.foresight = parse_duration(kwargs.get("eom_foresight", "12h"))

def calculate_bids(
self,
Expand Down
Loading

0 comments on commit 13afb66

Please sign in to comment.