Skip to content

Commit

Permalink
Add missing columns to table if not exists (#191)
Browse files Browse the repository at this point in the history
support adding additional float and text columns
simplifies database engine usage as pandas now fully supports sqlalchemy
v2

fixes #188
  • Loading branch information
maurerle authored Sep 15, 2023
1 parent fd574b0 commit dd9c0e9
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 15 deletions.
3 changes: 3 additions & 0 deletions assume/common/forecasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,15 @@ def set_forecast(self, data: pd.DataFrame | pd.Series | None, prefix=""):
return
elif isinstance(data, pd.DataFrame):
if prefix:
# set prefix for columns to set
columns = [prefix + column for column in data.columns]
data.columns = columns
if len(data.index) == 1:
# if we have a single value which should be set for the whole series
for column in data.columns:
self.forecasts[column] = data[column].item()
else:
# if some columns already exist, just add the new columns
new_columns = set(data.columns) - set(self.forecasts.columns)
self.forecasts = pd.concat(
[self.forecasts, data[list(new_columns)]], axis=1
Expand Down
46 changes: 39 additions & 7 deletions assume/common/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas as pd
from dateutil import rrule as rr
from mango import Role
from pandas.api.types import is_numeric_dtype
from sqlalchemy import inspect, text
from sqlalchemy.exc import ProgrammingError

Expand Down Expand Up @@ -84,10 +85,10 @@ def delete_db_scenario(self, simulation_id):

# Loop throuph all database tables
# Get list of table names in database
table_names = inspect(self.db.bind).get_table_names()
table_names = inspect(self.db).get_table_names()
# Iterate through each table
for table_name in table_names:
with self.db() as db:
with self.db.begin() as db:
# Read table into Pandas DataFrame
query = text(
f"delete from {table_name} where simulation = '{simulation_id}'"
Expand All @@ -101,7 +102,7 @@ def del_similar_runs(self):
query = text("select distinct simulation from market_meta")

try:
with self.db() as db:
with self.db.begin() as db:
simulations = db.execute(query).fetchall()
except Exception:
simulations = []
Expand Down Expand Up @@ -212,9 +213,39 @@ async def store_dfs(self):
df.to_csv(data_path, mode="a", header=not data_path.exists())

if self.db is not None:
df.to_sql(table, self.db.bind, if_exists="append")
try:
with self.db.begin() as db:
df.to_sql(table, db, if_exists="append")
except ProgrammingError:
self.check_columns(table, df)
# now try again
with self.db.begin() as db:
df.to_sql(table, db, if_exists="append")

self.write_dfs[table] = []

def check_columns(self, table: str, df: pd.DataFrame):
"""
If a simulation before has been started which does not include an additional field
we try to add the field.
For now, this only works for float and text.
An alternative which finds the correct types would be to use
"""
with self.db.begin() as db:
# Read table into Pandas DataFrame
query = f"select * from {table} where 1=0"
db_columns = pd.read_sql(query, db).columns
for column in df.columns:
if column not in db_columns:
try:
# TODO this only works for float and text
column_type = "float" if is_numeric_dtype(df[column]) else "text"
query = f"ALTER TABLE {table} ADD COLUMN {column} {column_type}"
with self.db.begin() as db:
db.execute(text(query))
except Exception:
logger.exception("Error converting column")

def check_for_tensors(self, data):
"""
Checks if the data contains tensors and converts them to floats.
Expand Down Expand Up @@ -321,7 +352,7 @@ async def on_stop(self):

try:
for query in queries:
df = pd.read_sql(query, self.db.bind)
df = pd.read_sql(query, self.db)
dfs.append(df)
df = pd.concat(dfs)
df.reset_index()
Expand All @@ -335,7 +366,8 @@ async def on_stop(self):
index=None,
)
if self.db is not None and not df.empty:
df.to_sql("kpis", self.db.bind, if_exists="append", index=None)
with self.db.begin() as db:
df.to_sql("kpis", self.db, if_exists="append", index=None)
except ProgrammingError as e:
self.db.rollback()
logger.error(f"No scenario run Yet {e}")
Expand All @@ -357,7 +389,7 @@ def get_sum_reward(self):
)

try:
with self.db() as db:
with self.db.begin() as db:
avg_reward = db.execute(query).fetchall()[0]
except Exception:
avg_reward = 0
Expand Down
8 changes: 4 additions & 4 deletions assume/world.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ def __init__(
self.export_csv_path = export_csv_path
# intialize db connection at beginning of simulation
if database_uri:
self.db = scoped_session(sessionmaker(create_engine(database_uri)))
self.db = create_engine(database_uri)
connected = False
while not connected:
try:
self.db.connection()
connected = True
self.logger.info("connected to db")
with self.db.connect():
connected = True
self.logger.info("connected to db")
except OperationalError as e:
self.logger.error(
f"could not connect to {database_uri}, trying again"
Expand Down
8 changes: 4 additions & 4 deletions tests/test_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def test_output_market_orders():
engine = scoped_session(sessionmaker(create_engine(DB_URI)))
engine = create_engine(DB_URI)
start = datetime(2020, 1, 1)
end = datetime(2020, 1, 2)
output_writer = WriteOutput("test_sim", start, end, engine)
Expand Down Expand Up @@ -73,7 +73,7 @@ def test_output_market_orders():


def test_output_market_results():
engine = scoped_session(sessionmaker(create_engine(DB_URI)))
engine = create_engine(DB_URI)
start = datetime(2020, 1, 1)
end = datetime(2020, 1, 2)
output_writer = WriteOutput("test_sim", start, end, engine)
Expand Down Expand Up @@ -106,7 +106,7 @@ def test_output_market_results():


def test_output_market_dispatch():
engine = scoped_session(sessionmaker(create_engine(DB_URI)))
engine = create_engine(DB_URI)
start = datetime(2020, 1, 1)
end = datetime(2020, 1, 2)
output_writer = WriteOutput("test_sim", start, end, engine)
Expand All @@ -118,7 +118,7 @@ def test_output_market_dispatch():


def test_output_unit_dispatch():
engine = scoped_session(sessionmaker(create_engine(DB_URI)))
engine = create_engine(DB_URI)
start = datetime(2020, 1, 1)
end = datetime(2020, 1, 2)
output_writer = WriteOutput("test_sim", start, end, engine)
Expand Down

0 comments on commit dd9c0e9

Please sign in to comment.