Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try to migrate the existing database, by adding missing columns #191

Merged
merged 1 commit into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions assume/common/forecasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,15 @@ def set_forecast(self, data: pd.DataFrame | pd.Series | None, prefix=""):
return
elif isinstance(data, pd.DataFrame):
if prefix:
# set prefix for columns to set
columns = [prefix + column for column in data.columns]
data.columns = columns
if len(data.index) == 1:
# if we have a single value which should be set for the whole series
for column in data.columns:
self.forecasts[column] = data[column].item()
else:
# if some columns already exist, just add the new columns
new_columns = set(data.columns) - set(self.forecasts.columns)
self.forecasts = pd.concat(
[self.forecasts, data[list(new_columns)]], axis=1
Expand Down
46 changes: 39 additions & 7 deletions assume/common/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas as pd
from dateutil import rrule as rr
from mango import Role
from pandas.api.types import is_numeric_dtype
from sqlalchemy import inspect, text
from sqlalchemy.exc import ProgrammingError

Expand Down Expand Up @@ -84,10 +85,10 @@ def delete_db_scenario(self, simulation_id):

# Loop through all database tables
# Get list of table names in database
table_names = inspect(self.db.bind).get_table_names()
table_names = inspect(self.db).get_table_names()
# Iterate through each table
for table_name in table_names:
with self.db() as db:
with self.db.begin() as db:
# Read table into Pandas DataFrame
query = text(
f"delete from {table_name} where simulation = '{simulation_id}'"
Expand All @@ -101,7 +102,7 @@ def del_similar_runs(self):
query = text("select distinct simulation from market_meta")

try:
with self.db() as db:
with self.db.begin() as db:
simulations = db.execute(query).fetchall()
except Exception:
simulations = []
Expand Down Expand Up @@ -212,9 +213,39 @@ async def store_dfs(self):
df.to_csv(data_path, mode="a", header=not data_path.exists())

if self.db is not None:
df.to_sql(table, self.db.bind, if_exists="append")
try:
with self.db.begin() as db:
df.to_sql(table, db, if_exists="append")
except ProgrammingError:
self.check_columns(table, df)
# now try again
with self.db.begin() as db:
df.to_sql(table, db, if_exists="append")

self.write_dfs[table] = []

def check_columns(self, table: str, df: pd.DataFrame):
"""
If the table was created by an earlier simulation that did not include a field
which is present now, we try to add the missing column.
For now, this only works for float and text.
An alternative which finds the correct types would be to use
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What other types might there be? Anyway, we cannot write dicts or anything else.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we do not detect bool or datetime values. But for most use cases it will be enough

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so maybe we can for now add a raise if the type is not detected? so we see if the problem ever occurs. Or it won't work anyway?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the field is bool (0 or 1) or int, we will add it as float numeric.
Datetime might be added as unix timestamp numeric or raise an error.

This will probably be enough to visualize the values, even if it is not the proper datatype.
And in other cases it will raise anyway. So I don't think we need additional error raising

"""
with self.db.begin() as db:
# Read table into Pandas DataFrame
query = f"select * from {table} where 1=0"
db_columns = pd.read_sql(query, db).columns
for column in df.columns:
if column not in db_columns:
try:
# TODO this only works for float and text
column_type = "float" if is_numeric_dtype(df[column]) else "text"
query = f"ALTER TABLE {table} ADD COLUMN {column} {column_type}"
maurerle marked this conversation as resolved.
Show resolved Hide resolved
with self.db.begin() as db:
db.execute(text(query))
except Exception:
logger.exception("Error converting column")

def check_for_tensors(self, data):
"""
Checks if the data contains tensors and converts them to floats.
Expand Down Expand Up @@ -321,7 +352,7 @@ async def on_stop(self):

try:
for query in queries:
df = pd.read_sql(query, self.db.bind)
df = pd.read_sql(query, self.db)
dfs.append(df)
df = pd.concat(dfs)
df.reset_index()
Expand All @@ -335,7 +366,8 @@ async def on_stop(self):
index=None,
)
if self.db is not None and not df.empty:
df.to_sql("kpis", self.db.bind, if_exists="append", index=None)
with self.db.begin() as db:
df.to_sql("kpis", self.db, if_exists="append", index=None)
except ProgrammingError as e:
self.db.rollback()
logger.error(f"No scenario run Yet {e}")
Expand All @@ -357,7 +389,7 @@ def get_sum_reward(self):
)

try:
with self.db() as db:
with self.db.begin() as db:
avg_reward = db.execute(query).fetchall()[0]
except Exception:
avg_reward = 0
Expand Down
8 changes: 4 additions & 4 deletions assume/world.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ def __init__(
self.export_csv_path = export_csv_path
# initialize db connection at beginning of simulation
if database_uri:
self.db = scoped_session(sessionmaker(create_engine(database_uri)))
self.db = create_engine(database_uri)
connected = False
while not connected:
try:
self.db.connection()
connected = True
self.logger.info("connected to db")
with self.db.connect():
connected = True
self.logger.info("connected to db")
except OperationalError as e:
self.logger.error(
f"could not connect to {database_uri}, trying again"
Expand Down
8 changes: 4 additions & 4 deletions tests/test_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def test_output_market_orders():
engine = scoped_session(sessionmaker(create_engine(DB_URI)))
engine = create_engine(DB_URI)
start = datetime(2020, 1, 1)
end = datetime(2020, 1, 2)
output_writer = WriteOutput("test_sim", start, end, engine)
Expand Down Expand Up @@ -73,7 +73,7 @@ def test_output_market_orders():


def test_output_market_results():
engine = scoped_session(sessionmaker(create_engine(DB_URI)))
engine = create_engine(DB_URI)
start = datetime(2020, 1, 1)
end = datetime(2020, 1, 2)
output_writer = WriteOutput("test_sim", start, end, engine)
Expand Down Expand Up @@ -106,7 +106,7 @@ def test_output_market_results():


def test_output_market_dispatch():
engine = scoped_session(sessionmaker(create_engine(DB_URI)))
engine = create_engine(DB_URI)
start = datetime(2020, 1, 1)
end = datetime(2020, 1, 2)
output_writer = WriteOutput("test_sim", start, end, engine)
Expand All @@ -118,7 +118,7 @@ def test_output_market_dispatch():


def test_output_unit_dispatch():
engine = scoped_session(sessionmaker(create_engine(DB_URI)))
engine = create_engine(DB_URI)
start = datetime(2020, 1, 1)
end = datetime(2020, 1, 2)
output_writer = WriteOutput("test_sim", start, end, engine)
Expand Down