Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve fast pandas and flexABLE strategies for better performance #502

Closed
wants to merge 13 commits into from
106 changes: 43 additions & 63 deletions assume/common/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@
self.outputs = defaultdict(lambda: FastSeries(value=0.0, index=self.index))
# series does not like to convert from tensor to float otherwise

self.avg_op_time = 0

# some data is stored as series to allow to store it in the outputs
# check if any bidding strategy is using the RL strategy
if any(
Expand All @@ -70,6 +72,7 @@
index=self.index,
)
self.outputs["reward"] = FastSeries(value=0.0, index=self.index)
self.outputs["regret"] = FastSeries(value=0.0, index=self.index)

# RL data stored as lists to simplify storing to the buffer
self.outputs["rl_observations"] = []
Expand Down Expand Up @@ -139,25 +142,51 @@
Iterates through the orderbook, adding the accepted volumes to the corresponding time slots
in the dispatch plan. It then calculates the cashflow and the reward for the bidding strategies.

Additionally, updates the average operation and downtime dynamically.

Args:
marketconfig (MarketConfig): The market configuration.
orderbook (Orderbook): The orderbook.

"""

product_type = marketconfig.product_type

# Initialize counters for operation and downtime updates
total_op_time = self.avg_op_time * len(self.outputs[product_type])
total_periods = len(self.outputs[product_type])

for order in orderbook:
start = order["start_time"]
end = order["end_time"]
# end includes the end of the last product, to get the last products' start time we deduct the frequency once
end_excl = end - self.index.freq

# Determine the added volume
if isinstance(order["accepted_volume"], dict):
added_volume = list(order["accepted_volume"].values())
else:
added_volume = order["accepted_volume"]

# Update outputs and track changes
current_slice = self.outputs[product_type].loc[start:end_excl]
self.outputs[product_type].loc[start:end_excl] += added_volume
self.calculate_cashflow(product_type, orderbook)

# Detect changes in operation/downtime
for idx, volume in enumerate(
self.outputs[product_type].loc[start:end_excl]
):
was_operating = current_slice[idx] > 0
now_operating = volume > 0

if was_operating and not now_operating: # Transition to downtime
total_op_time -= 1

Check warning on line 181 in assume/common/base.py

View check run for this annotation

Codecov / codecov/patch

assume/common/base.py#L181

Added line #L181 was not covered by tests
elif not was_operating and now_operating: # Transition to operating
total_op_time += 1

Check warning on line 183 in assume/common/base.py

View check run for this annotation

Codecov / codecov/patch

assume/common/base.py#L183

Added line #L183 was not covered by tests

# Recalculate averages
self.avg_op_time = total_op_time / total_periods

# Calculate cashflow and reward
self.calculate_cashflow(product_type, orderbook)
self.bidding_strategies[marketconfig.market_id].calculate_reward(
unit=self,
marketconfig=marketconfig,
Expand Down Expand Up @@ -411,56 +440,6 @@
# Return positive time if operating, negative if shut down
return -run if is_off else run

def get_average_operation_times(self, start: datetime) -> tuple[float, float]:
    """
    Compute the average uninterrupted operation time and down time.

    Args:
        start (datetime.datetime): The current time.

    Returns:
        tuple[float, float]: Tuple of the average operation time avg_op_time
        and the average down time avg_down_time.

    Note:
        Down time is in general indicated with negative values.
    """
    # Look only at dispatch history strictly before `start`, newest first.
    before = start - self.index.freq
    operating_flags = self.outputs["energy"].loc[self.index[0] : before][::-1] > 0

    if len(operating_flags) < 1:
        # No history before the start of the index: fall back to the
        # technical minimum operation/down times (at least one period each).
        return max(self.min_operating_time, 1), min(-self.min_down_time, -1)

    # Collapse the boolean history into signed run lengths:
    # positive runs are operation periods, negative runs are down periods.
    runs = []
    current = operating_flags[0]
    length = 0
    for flag in operating_flags:
        if flag == current:
            length += 1
        else:
            runs.append(length if current else -length)
            length = 1
            current = flag
    runs.append(length if current else -length)

    up_runs = [r for r in runs if r > 0]
    down_runs = [r for r in runs if r < 0]

    # Fall back to the technical minimums when one kind of run is absent.
    avg_up = sum(up_runs) / len(up_runs) if up_runs else self.min_operating_time
    avg_down = sum(down_runs) / len(down_runs) if down_runs else self.min_down_time

    return max(1, avg_up, self.min_operating_time), min(
        -1, avg_down, -self.min_down_time
    )

def get_starting_costs(self, op_time: int) -> float:
"""
Returns the start-up cost for the given operation time.
Expand All @@ -475,19 +454,20 @@
float: The start-up costs depending on the down time.
"""
if op_time > 0:
# unit is running
# The unit is running, no start-up cost is needed
return 0

if self.downtime_hot_start is not None and self.hot_start_cost is not None:
if -op_time <= self.downtime_hot_start:
return self.hot_start_cost
if self.downtime_warm_start is not None and self.warm_start_cost is not None:
if -op_time <= self.downtime_warm_start:
return self.warm_start_cost
if self.cold_start_cost is not None:
return self.cold_start_cost
downtime = abs(op_time)

return 0
# Check and return the appropriate start-up cost
if downtime <= self.downtime_hot_start:
return self.hot_start_cost

Check warning on line 464 in assume/common/base.py

View check run for this annotation

Codecov / codecov/patch

assume/common/base.py#L464

Added line #L464 was not covered by tests

if downtime <= self.downtime_warm_start:
return self.warm_start_cost

Check warning on line 467 in assume/common/base.py

View check run for this annotation

Codecov / codecov/patch

assume/common/base.py#L467

Added line #L467 was not covered by tests

# If it exceeds warm start threshold, return cold start cost
return self.cold_start_cost


class SupportsMinMaxCharge(BaseUnit):
Expand Down
126 changes: 93 additions & 33 deletions assume/common/fast_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# SPDX-License-Identifier: AGPL-3.0-or-later

from datetime import datetime, timedelta
from functools import lru_cache

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -56,7 +55,11 @@
total_seconds = (self._end - self._start).total_seconds()
self._count = int(np.floor(total_seconds / self._freq_seconds)) + 1

self._tolerance_seconds = 1
# Precompute the mapping
self._date_to_index = {
self._start + i * self._freq: i for i in range(self._count)
}

self._date_list = None # Lazy-loaded

@property
Expand All @@ -79,11 +82,6 @@
"""Get the frequency of the index in total seconds."""
return self._freq_seconds

@property
def tolerance_seconds(self) -> int:
"""Get the tolerance in seconds for date alignment."""
return self._tolerance_seconds

def __getitem__(self, item: int | slice):
"""
Retrieve datetime(s) based on the specified index or slice.
Expand Down Expand Up @@ -131,9 +129,7 @@
if not sliced_dates:
return []

return FastIndex(
start=sliced_dates[0], end=sliced_dates[-1], freq=self._freq
)
return sliced_dates

else:
raise TypeError("Index must be an integer or a slice")
Expand Down Expand Up @@ -184,7 +180,6 @@
"""Return an informal string representation of the FastIndex."""
return self.__repr__()

@lru_cache(maxsize=100)
def get_date_list(
self, start: datetime | None = None, end: datetime | None = None
) -> list[datetime]:
Expand Down Expand Up @@ -218,7 +213,6 @@
# Convert to pandas DatetimeIndex
return pd.DatetimeIndex(pd.to_datetime(datetimes), name="FastIndex")

@lru_cache(maxsize=1000)
def _get_idx_from_date(self, date: datetime) -> int:
"""
Convert a datetime to its corresponding index in the range.
Expand All @@ -233,21 +227,11 @@
KeyError: If the input `date` is None.
ValueError: If the `date` is not aligned with the frequency within tolerance.
"""
if date is None:
raise KeyError("Date cannot be None. Please provide a valid datetime.")

delta_seconds = (date - self.start).total_seconds()
remainder = delta_seconds % self.freq_seconds

if remainder > self.tolerance_seconds and remainder < (
self.freq_seconds - self.tolerance_seconds
):
if date not in self._date_to_index:
raise ValueError(
f"Date {date} is not aligned with frequency {self.freq_seconds} seconds. "
f"Allowed tolerance: {self.tolerance_seconds} seconds."
f"Date {date} is not aligned with the frequency or out of range."
)

return round(delta_seconds / self.freq_seconds)
return self._date_to_index[date]

@staticmethod
def _convert_to_datetime(value: datetime | str) -> datetime:
Expand Down Expand Up @@ -312,8 +296,6 @@

self._index = index
self._name = name
self.loc = self # Allow adjusting loc as well
self.at = self

count = len(self.index) # Use index length directly
self._data = (
Expand Down Expand Up @@ -384,6 +366,16 @@
"""
return self.data.dtype

@property
def loc(self):
"""
Label-based indexing property.

Returns:
FastSeriesLocIndexer: Indexer for label-based access.
"""
return FastSeriesLocIndexer(self)

@property
def iloc(self):
"""
Expand All @@ -394,6 +386,16 @@
"""
return FastSeriesILocIndexer(self)

@property
def at(self):
"""
Label-based single-item access property.

Returns:
FastSeriesAtIndexer: Indexer for label-based single-element access.
"""
return FastSeriesAtIndexer(self)

@property
def iat(self):
"""
Expand Down Expand Up @@ -444,12 +446,6 @@
[(d - self.index.start).total_seconds() for d in dates]
)
indices = (delta_seconds / self.index.freq_seconds).round().astype(int)
remainders = delta_seconds % self.index.freq_seconds

if not np.all(remainders <= self.index.tolerance_seconds):
raise ValueError(
"One or more dates are not aligned with the index frequency."
)
return self.data[indices]

elif isinstance(item, str):
Expand Down Expand Up @@ -976,6 +972,39 @@
return result


class FastSeriesLocIndexer:
    """Label-based indexer that forwards all access to its FastSeries."""

    def __init__(self, series: FastSeries):
        # Series this indexer reads from and writes to.
        self._series = series

    def __getitem__(
        self, item: datetime | slice | list | pd.Index | pd.Series | np.ndarray | str
    ):
        """
        Retrieve item(s) using label-based indexing.

        Parameters:
            item (datetime | slice | list | pd.Index | pd.Series | np.ndarray | str): The label(s) to retrieve.

        Returns:
            float | np.ndarray: The retrieved value(s).
        """
        # Delegate straight to the series' own label-based lookup.
        return self._series[item]

    def __setitem__(
        self,
        item: datetime | slice | list | pd.Index | pd.Series | np.ndarray | str,
        value: float | np.ndarray,
    ):
        """
        Assign value(s) using label-based indexing.

        Parameters:
            item (datetime | slice | list | pd.Index | pd.Series | np.ndarray | str): The label(s) to set.
            value (float | np.ndarray): The value(s) to assign.
        """
        # Delegate straight to the series' own label-based assignment.
        self._series[item] = value


class FastSeriesILocIndexer:
def __init__(self, series: FastSeries):
self._series = series
Expand Down Expand Up @@ -1070,6 +1099,37 @@
)


class FastSeriesAtIndexer:
    """Label-based single-item accessor that forwards to its FastSeries."""

    def __init__(self, series: FastSeries):
        # Series this indexer reads from and writes to.
        self._series = series

    def __getitem__(self, item: datetime | str):
        """
        Retrieve a single item using label-based indexing.

        Parameters:
            item (datetime | str): The label; strings are parsed to datetime.

        Returns:
            float: The retrieved value.
        """
        # Normalize string labels to plain datetimes before the lookup.
        key = pd.to_datetime(item).to_pydatetime() if isinstance(item, str) else item
        return self._series[key]

    def __setitem__(self, item: datetime | str, value: float):
        """
        Assign a value using label-based indexing.

        Parameters:
            item (datetime | str): The label; strings are parsed to datetime.
            value (float): The value to assign.
        """
        # Normalize string labels to plain datetimes before the assignment.
        key = pd.to_datetime(item).to_pydatetime() if isinstance(item, str) else item
        self._series[key] = value


class FastSeriesIatIndexer:
def __init__(self, series: FastSeries):
self._series = series
Expand Down
9 changes: 7 additions & 2 deletions assume/common/forecasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,13 @@

if column not in self.forecasts.keys():
if "availability" in column:
return FastSeries(value=1.0, index=self.index)
return FastSeries(value=0.0, index=self.index)
self.forecasts[column] = FastSeries(

Check warning on line 144 in assume/common/forecasts.py

View check run for this annotation

Codecov / codecov/patch

assume/common/forecasts.py#L144

Added line #L144 was not covered by tests
value=1.0, index=self.index, name=column
)
else:
self.forecasts[column] = FastSeries(

Check warning on line 148 in assume/common/forecasts.py

View check run for this annotation

Codecov / codecov/patch

assume/common/forecasts.py#L148

Added line #L148 was not covered by tests
value=0.0, index=self.index, name=column
)

return self.forecasts[column]

Expand Down
Loading