Skip to content

Commit

Permalink
All test working with new AxisDict
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaeljolivet committed Feb 27, 2024
1 parent 25c0799 commit 1689b11
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 98 deletions.
67 changes: 67 additions & 0 deletions lca_algebraic/axis_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from sympy import symbols
from sympy.core.containers import Dict as SympyDict

NO_AXIS = symbols("_other_")


class AxisDict(SympyDict):
"""This class acts like a dict with arithmetic operations. It is useful to process 'axes' LCA computations"""

def _apply_op(self, other, fop, null_val):
# None is the key for non flagged values
if not isinstance(other, AxisDict):
dic = dict()
dic[NO_AXIS] = other
other = AxisDict(dic)

all_keys = set(other._dict.keys()) | set(self._dict.keys())
return AxisDict({key: fop(self._dict.get(key, null_val), other._dict.get(key, null_val)) for key in all_keys})

def __repr__(self):
"""Custom representation that returns string as key instead of symbols"""
return "{%s}" % ",".join("'%s': %s" % (k.__repr__(), v.__repr__()) for k, v in self._dict.items())

def __str__(self):
return self.__repr__()

def _apply_self(self, fop):
return AxisDict({key: fop(val) for key, val in self._dict.items()})

def __add__(self, other):
return self._apply_op(other, lambda a, b: a + b, 0)

def __radd__(self, other):
return self._apply_op(other, lambda a, b: b + a, 0)

def __mul__(self, other):
return self._apply_self(lambda a: a * other)

def __rmul__(self, other):
return self._apply_self(lambda a: other * a)

def __truediv__(self, other):
return self._apply_self(lambda a: a / other)

def __rtruediv__(self, other):
return NotImplemented

def _defer(self, funcname, args, kwargs):
return AxisDict(
{
key: val if not hasattr(val, funcname) else getattr(val, funcname)(*args, **kwargs)
for key, val in self._dict.items()
}
)

def str_keys(self):
# REturn a list to ensure the order is kept
return list(str(key) for key in self._dict.keys())

@property
def free_symbols(self):
"""Only return free symbol for values (not keys)"""
res = set()
for key, val in self._dict.items():
if hasattr(val, "free_symbols"):
res |= val.free_symbols
return res
60 changes: 0 additions & 60 deletions lca_algebraic/base_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,66 +142,6 @@ def _snake2camel(val):
return "".join(word.title() for word in val.split("_"))


class SymDict:
"""This class acts like a dict with arithmetic operations. It is useful to process 'axes' LCA computations"""

def __init__(self, values):
self.dict = values

def _apply_op(self, other, fop, null_val):
# None is the key for non flagged values
if not isinstance(other, SymDict):
dic = dict()
dic[None] = other
other = SymDict(dic)

all_keys = set(other.dict.keys()) | set(self.dict.keys())
return SymDict({key: fop(self.dict.get(key, null_val), other.dict.get(key, null_val)) for key in all_keys})

def _apply_self(self, fop):
return SymDict({key: fop(val) for key, val in self.dict.items()})

def __add__(self, other):
return self._apply_op(other, lambda a, b: a + b, 0)

def __radd__(self, other):
return self._apply_op(other, lambda a, b: b + a, 0)

def __mul__(self, other):
return self._apply_self(lambda a: a * other)

def __rmul__(self, other):
return self._apply_self(lambda a: other * a)

def __truediv__(self, other):
return self._apply_self(lambda a: a / other)

def __rtruediv__(self, other):
return NotImplemented

def __repr__(self):
return "{" + "; ".join("%s: %s" % (key, str(val)) for key, val in self.dict.items()) + "}"

def _defer(self, funcname, args, kwargs):
return SymDict(
{
key: val if not hasattr(val, funcname) else getattr(val, funcname)(*args, **kwargs)
for key, val in self.dict.items()
}
)

def xreplace(self, *args, **kwargs):
return self._defer("xreplace", args, kwargs)

@property
def free_symbols(self):
res = set()
for key, val in self.dict.items():
if hasattr(val, "free_symbols"):
res |= val.free_symbols
return list(res)


class TabbedDataframe:
"""This class holds a dictionnary of dataframes and can display and saved them awith 'tabs'/'sheets'"""

Expand Down
69 changes: 35 additions & 34 deletions lca_algebraic/lca.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from peewee import DoesNotExist
from sympy import Symbol, lambdify, parse_expr

from . import SymDict, TabbedDataframe
from . import TabbedDataframe
from .axis_dict import AxisDict
from .base_utils import _actName, _getDb, _method_unit
from .cache import ExprCache, LCIACache
from .helpers import (
Expand All @@ -20,8 +21,6 @@
Basic,
DbContext,
Dict,
Expr,
Union,
_actDesc,
_getAmountOrFormula,
_isForeground,
Expand Down Expand Up @@ -202,34 +201,31 @@ def _filter_param_values(params, expanded_param_names):
return {key: val for key, val in params.items() if key in expanded_param_names}


def _free_symbols(expr: Union[SymDict, Expr]):
if isinstance(expr, SymDict):
# SymDict => sum of vars of members
return set.union(*[_free_symbols(ex) for ex in expr.dict.values()])
elif isinstance(expr, Expr):
def _free_symbols(expr: Basic):
if isinstance(expr, Basic):
return set([str(symb) for symb in expr.free_symbols])
else:
# Static value
return set()


def _lambdify(expr: Union[SymDict, Expr], expanded_params):
def _lambdify(expr: Basic, expanded_params):
"""Lambdify, handling manually the case of SymDict (for impacts by axis)"""
if isinstance(expr, SymDict):
lambd_dict = dict()
for key, val in expr.dict.items():
lambd_dict[key] = _lambdify(val, expanded_params)

# Dynamic function calling lambda function with same params
def dict_func(*args, **kwargs):
return SymDict({key: func(*args, **kwargs) for key, func in lambd_dict.items()})
if isinstance(expr, Basic):
lambd = lambdify(expanded_params, expr, "numpy")

def func(*arg, **kwargs):
res = lambd(*arg, **kwargs)
if isinstance(res, dict):
# Transform key symbols into Str
return {str(k): v for k, v in res.items()}
else:
return res

return dict_func
return func

elif isinstance(expr, Expr):
return lambdify(expanded_params, expr, "numpy")
else:
# Not an expression : return statis func
# Not an expression : return static func
def static_func(*args, **kargs):
return expr

Expand Down Expand Up @@ -296,12 +292,12 @@ def __init__(self, exprOrDict, expanded_params=None, params=None, sobols=None):

@property
def has_axis(self):
return isinstance(self.expr, SymDict)
return isinstance(self.expr, AxisDict)

@property
def axis_keys(self):
if self.has_axis:
return list(self.expr.dict.keys())
return self.expr.str_keys()
else:
return None

Expand All @@ -322,11 +318,7 @@ def compute(self, **params) -> ValueContext:
return ValueContext(value=value, context=completed_params)

def serialize(self):
if isinstance(self.expr, SymDict):
expr = {key: str(sym) for key, sym in self.expr.dict.items()}
else:
expr = str(self.expr)

expr = str(self.expr)
return dict(params=self.params, expr=expr, sobols=self.sobols)

def __repr__(self):
Expand Down Expand Up @@ -410,8 +402,8 @@ def process(args):
value = value_context.value

# Expand axis values as a list, to fit into the result numpy array
if lambd.has_axis:
value = list(float(val) for val in value.dict.values())
if isinstance(value, dict):
value = list(float(val) for val in value.values())

return (imethod, value)

Expand Down Expand Up @@ -669,17 +661,26 @@ def _replace_fixed_params(expr, fixed_params, fixed_mode=FixedParamMode.DEFAULT)
return expr.xreplace(sub)


def _safe_axis(axis_name: str):
if axis_name.isalnum():
return axis_name
else:
return re.sub("[^0-9a-zA-Z]+", "*", axis_name)


def _tag_expr(expr, act, axis):
"""Tag expression for one axe. Check the child expression is not already tagged with different values"""
axis_tag = act.get(axis, None)

if axis_tag is None:
return expr

if isinstance(expr, SymDict):
axis_tag = _safe_axis(axis_tag)

if isinstance(expr, AxisDict):
res = 0
for key, val in expr.dict.items():
if key is not None and key != axis_tag:
for key, val in expr._dict.items():
if key is not None and str(key) != axis_tag:
raise ValueError(
"Inconsistent axis for one change of '%s' : attempt to tag as '%s'. "
"Already tagged as '%s'. Value of the exchange : %s" % (act["name"], axis_tag, key, str(val))
Expand All @@ -688,7 +689,7 @@ def _tag_expr(expr, act, axis):
else:
res = expr

return SymDict({axis_tag: res})
return AxisDict({axis_tag: res})


@with_db_context(arg="act")
Expand Down
3 changes: 1 addition & 2 deletions lca_algebraic/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@
from matplotlib.lines import Line2D
from SALib.analyze import sobol as analyse_sobol
from SALib.sample import sobol, sobol_sequence
from sympy import Abs, Add, AtomicExpr, Eq, Float, Mul, Number, Piecewise, Sum
from sympy import Abs, Add, AtomicExpr, Eq, Expr, Float, Mul, Number, Piecewise, Sum
from sympy.core.operations import AssocOp

from .base_utils import _display_tabs, _method_unit, displayWithExportButton, r_squared
from .lca import (
Activity,
DbContext,
Dict,
Expr,
LambdaWithParamNames,
Symbol,
_expanded_names_to_names,
Expand Down
40 changes: 40 additions & 0 deletions test/test_axis_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from sympy import symbols, lambdify, simplify

from lca_algebraic.axis_dict import AxisDict, NO_AXIS


def test_sum():
a, b = symbols("a b")

a1 = AxisDict({a: 1})
a2 = AxisDict({a: 2})
b2 = AxisDict({b: 2})

assert a1 + b2 == AxisDict({a: 1, b: 2})
assert a1 + a2 == AxisDict({a: 3})
assert a1 + 1 == AxisDict({a: 1, NO_AXIS: 1})


def test_mul():
a = symbols("a")

a1 = AxisDict({a: 2})
assert a1 * 2 == AxisDict({a: 4})
assert simplify(a1 / 2) == AxisDict({a: 1})


def test_free_symbols():
dic = AxisDict({"a": "b"})
assert dic.free_symbols == set([symbols("b")])


def test_lambdify():
a, b = symbols("a b")

a1 = AxisDict({a: b * 2})

lambd = lambdify([b], a1)

res = lambd(2)

assert res == {a: 4}
3 changes: 1 addition & 2 deletions test/unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,8 +455,7 @@ def test_axis(data):

res = {key: val for key, val in zip(res.index.values, res[res.columns[0]].values)}

expected = dict(a=2.0, b=4.0)
expected["*other*"] = 6.0
expected = dict(a=2.0, b=4.0, _other_=6.0)
expected["*sum*"] = 12.0

assert res == expected
Expand Down

0 comments on commit 1689b11

Please sign in to comment.