diff --git a/lca_algebraic/axis_dict.py b/lca_algebraic/axis_dict.py new file mode 100644 index 0000000..63188d1 --- /dev/null +++ b/lca_algebraic/axis_dict.py @@ -0,0 +1,67 @@ +from sympy import symbols +from sympy.core.containers import Dict as SympyDict + +NO_AXIS = symbols("_other_") + + +class AxisDict(SympyDict): + """This class acts like a dict with arithmetic operations. It is useful to process 'axes' LCA computations""" + + def _apply_op(self, other, fop, null_val): + # None is the key for non flagged values + if not isinstance(other, AxisDict): + dic = dict() + dic[NO_AXIS] = other + other = AxisDict(dic) + + all_keys = set(other._dict.keys()) | set(self._dict.keys()) + return AxisDict({key: fop(self._dict.get(key, null_val), other._dict.get(key, null_val)) for key in all_keys}) + + def __repr__(self): + """Custom representation that returns string as key instead of symbols""" + return "{%s}" % ",".join("'%s': %s" % (k.__repr__(), v.__repr__()) for k, v in self._dict.items()) + + def __str__(self): + return self.__repr__() + + def _apply_self(self, fop): + return AxisDict({key: fop(val) for key, val in self._dict.items()}) + + def __add__(self, other): + return self._apply_op(other, lambda a, b: a + b, 0) + + def __radd__(self, other): + return self._apply_op(other, lambda a, b: b + a, 0) + + def __mul__(self, other): + return self._apply_self(lambda a: a * other) + + def __rmul__(self, other): + return self._apply_self(lambda a: other * a) + + def __truediv__(self, other): + return self._apply_self(lambda a: a / other) + + def __rtruediv__(self, other): + return NotImplemented + + def _defer(self, funcname, args, kwargs): + return AxisDict( + { + key: val if not hasattr(val, funcname) else getattr(val, funcname)(*args, **kwargs) + for key, val in self._dict.items() + } + ) + + def str_keys(self): + # REturn a list to ensure the order is kept + return list(str(key) for key in self._dict.keys()) + + @property + def free_symbols(self): + """Only return free symbol for values (not keys)""" + res = set() + for key, val in self._dict.items(): + if hasattr(val, "free_symbols"): + res |= val.free_symbols + return res diff --git a/lca_algebraic/base_utils.py b/lca_algebraic/base_utils.py index de3b477..40e3060 100644 --- a/lca_algebraic/base_utils.py +++ b/lca_algebraic/base_utils.py @@ -142,66 +142,6 @@ def _snake2camel(val): return "".join(word.title() for word in val.split("_")) -class SymDict: - """This class acts like a dict with arithmetic operations. It is useful to process 'axes' LCA computations""" - - def __init__(self, values): - self.dict = values - - def _apply_op(self, other, fop, null_val): - # None is the key for non flagged values - if not isinstance(other, SymDict): - dic = dict() - dic[None] = other - other = SymDict(dic) - - all_keys = set(other.dict.keys()) | set(self.dict.keys()) - return SymDict({key: fop(self.dict.get(key, null_val), other.dict.get(key, null_val)) for key in all_keys}) - - def _apply_self(self, fop): - return SymDict({key: fop(val) for key, val in self.dict.items()}) - - def __add__(self, other): - return self._apply_op(other, lambda a, b: a + b, 0) - - def __radd__(self, other): - return self._apply_op(other, lambda a, b: b + a, 0) - - def __mul__(self, other): - return self._apply_self(lambda a: a * other) - - def __rmul__(self, other): - return self._apply_self(lambda a: other * a) - - def __truediv__(self, other): - return self._apply_self(lambda a: a / other) - - def __rtruediv__(self, other): - return NotImplemented - - def __repr__(self): - return "{" + "; ".join("%s: %s" % (key, str(val)) for key, val in self.dict.items()) + "}" - - def _defer(self, funcname, args, kwargs): - return SymDict( - { - key: val if not hasattr(val, funcname) else getattr(val, funcname)(*args, **kwargs) - for key, val in self.dict.items() - } - ) - - def xreplace(self, *args, **kwargs): - return self._defer("xreplace", args, kwargs) - - @property - def free_symbols(self): - res = set() - for key, val in self.dict.items(): - if hasattr(val, "free_symbols"): - res |= val.free_symbols - return list(res) - - class TabbedDataframe: """This class holds a dictionnary of dataframes and can display and saved them awith 'tabs'/'sheets'""" diff --git a/lca_algebraic/lca.py b/lca_algebraic/lca.py index 838486c..885cab9 100644 --- a/lca_algebraic/lca.py +++ b/lca_algebraic/lca.py @@ -10,7 +10,8 @@ from peewee import DoesNotExist from sympy import Symbol, lambdify, parse_expr -from . import SymDict, TabbedDataframe +from . import TabbedDataframe +from .axis_dict import AxisDict from .base_utils import _actName, _getDb, _method_unit from .cache import ExprCache, LCIACache from .helpers import ( @@ -20,8 +21,6 @@ Basic, DbContext, Dict, - Expr, - Union, _actDesc, _getAmountOrFormula, _isForeground, @@ -202,34 +201,31 @@ def _filter_param_values(params, expanded_param_names): return {key: val for key, val in params.items() if key in expanded_param_names} -def _free_symbols(expr: Union[SymDict, Expr]): - if isinstance(expr, SymDict): - # SymDict => sum of vars of members - return set.union(*[_free_symbols(ex) for ex in expr.dict.values()]) - elif isinstance(expr, Expr): +def _free_symbols(expr: Basic): + if isinstance(expr, Basic): return set([str(symb) for symb in expr.free_symbols]) else: # Static value return set() -def _lambdify(expr: Union[SymDict, Expr], expanded_params): +def _lambdify(expr: Basic, expanded_params): """Lambdify, handling manually the case of SymDict (for impacts by axis)""" - if isinstance(expr, SymDict): - lambd_dict = dict() - for key, val in expr.dict.items(): - lambd_dict[key] = _lambdify(val, expanded_params) - - # Dynamic function calling lambda function with same params - def dict_func(*args, **kwargs): - return SymDict({key: func(*args, **kwargs) for key, func in lambd_dict.items()}) + if isinstance(expr, Basic): + lambd = lambdify(expanded_params, expr, "numpy") + + def func(*arg, **kwargs): + res = lambd(*arg, **kwargs) + if isinstance(res, dict): + # Transform key symbols into Str + return {str(k): v for k, v in res.items()} + else: + return res - return dict_func + return func - elif isinstance(expr, Expr): - return lambdify(expanded_params, expr, "numpy") else: - # Not an expression : return statis func + # Not an expression : return static func def static_func(*args, **kargs): return expr @@ -296,12 +292,12 @@ def __init__(self, exprOrDict, expanded_params=None, params=None, sobols=None): @property def has_axis(self): - return isinstance(self.expr, SymDict) + return isinstance(self.expr, AxisDict) @property def axis_keys(self): if self.has_axis: - return list(self.expr.dict.keys()) + return self.expr.str_keys() else: return None @@ -322,11 +318,7 @@ def compute(self, **params) -> ValueContext: return ValueContext(value=value, context=completed_params) def serialize(self): - if isinstance(self.expr, SymDict): - expr = {key: str(sym) for key, sym in self.expr.dict.items()} - else: - expr = str(self.expr) - + expr = str(self.expr) return dict(params=self.params, expr=expr, sobols=self.sobols) def __repr__(self): @@ -410,8 +402,8 @@ def process(args): value = value_context.value # Expand axis values as a list, to fit into the result numpy array - if lambd.has_axis: - value = list(float(val) for val in value.dict.values()) + if isinstance(value, dict): + value = list(float(val) for val in value.values()) return (imethod, value) @@ -669,6 +661,13 @@ def _replace_fixed_params(expr, fixed_params, fixed_mode=FixedParamMode.DEFAULT) return expr.xreplace(sub) +def _safe_axis(axis_name: str): + if axis_name.isalnum(): + return axis_name + else: + return re.sub("[^0-9a-zA-Z]+", "*", axis_name) + + def _tag_expr(expr, act, axis): """Tag expression for one axe. Check the child expression is not already tagged with different values""" axis_tag = act.get(axis, None) @@ -676,10 +675,12 @@ def _tag_expr(expr, act, axis): if axis_tag is None: return expr - if isinstance(expr, SymDict): + axis_tag = _safe_axis(axis_tag) + + if isinstance(expr, AxisDict): res = 0 - for key, val in expr.dict.items(): - if key is not None and key != axis_tag: + for key, val in expr._dict.items(): + if key is not None and str(key) != axis_tag: raise ValueError( "Inconsistent axis for one change of '%s' : attempt to tag as '%s'. " "Already tagged as '%s'. Value of the exchange : %s" % (act["name"], axis_tag, key, str(val)) @@ -688,7 +689,7 @@ def _tag_expr(expr, act, axis): else: res = expr - return SymDict({axis_tag: res}) + return AxisDict({axis_tag: res}) @with_db_context(arg="act") diff --git a/lca_algebraic/stats.py b/lca_algebraic/stats.py index b47cce2..9179578 100644 --- a/lca_algebraic/stats.py +++ b/lca_algebraic/stats.py @@ -13,7 +13,7 @@ from matplotlib.lines import Line2D from SALib.analyze import sobol as analyse_sobol from SALib.sample import sobol, sobol_sequence -from sympy import Abs, Add, AtomicExpr, Eq, Float, Mul, Number, Piecewise, Sum +from sympy import Abs, Add, AtomicExpr, Eq, Expr, Float, Mul, Number, Piecewise, Sum from sympy.core.operations import AssocOp from .base_utils import _display_tabs, _method_unit, displayWithExportButton, r_squared @@ -21,7 +21,6 @@ Activity, DbContext, Dict, - Expr, LambdaWithParamNames, Symbol, _expanded_names_to_names, diff --git a/test/test_axis_dict.py b/test/test_axis_dict.py new file mode 100644 index 0000000..7ec5727 --- /dev/null +++ b/test/test_axis_dict.py @@ -0,0 +1,40 @@ +from sympy import symbols, lambdify, simplify + +from lca_algebraic.axis_dict import AxisDict, NO_AXIS + + +def test_sum(): + a, b = symbols("a b") + + a1 = AxisDict({a: 1}) + a2 = AxisDict({a: 2}) + b2 = AxisDict({b: 2}) + + assert a1 + b2 == AxisDict({a: 1, b: 2}) + assert a1 + a2 == AxisDict({a: 3}) + assert a1 + 1 == AxisDict({a: 1, NO_AXIS: 1}) + + +def test_mul(): + a = symbols("a") + + a1 = AxisDict({a: 2}) + assert a1 * 2 == AxisDict({a: 4}) + assert simplify(a1 / 2) == AxisDict({a: 1}) + + +def test_free_symbols(): + dic = AxisDict({"a": "b"}) + assert dic.free_symbols == set([symbols("b")]) + + +def test_lambdify(): + a, b = symbols("a b") + + a1 = AxisDict({a: b * 2}) + + lambd = lambdify([b], a1) + + res = lambd(2) + + assert res == {a: 4} diff --git a/test/unit_test.py b/test/unit_test.py index ee97b73..0ed9a7c 100644 --- a/test/unit_test.py +++ b/test/unit_test.py @@ -455,8 +455,7 @@ def test_axis(data): res = {key: val for key, val in zip(res.index.values, res[res.columns[0]].values)} - expected = dict(a=2.0, b=4.0) - expected["*other*"] = 6.0 + expected = dict(a=2.0, b=4.0, _other_=6.0) expected["*sum*"] = 12.0 assert res == expected