Skip to content

Commit

Permalink
fix invalid character
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjm97 committed Mar 31, 2024
1 parent 2dee5b7 commit b8525bc
Showing 1 changed file with 25 additions and 36 deletions.
61 changes: 25 additions & 36 deletions teg/derivs/edge/common.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,13 @@
from typing import Dict, Set, Tuple, Iterable

from teg import (
ITeg,
Const,
Var,
Add,
Mul,
IfElse,
Teg,
ITegBool,
Bool,
And,
Or,
TegVar
)
from teg import ITeg, Const, Var, Add, Mul, IfElse, Teg, ITegBool, Bool, And, Or, TegVar
from teg.lang.extended_utils import extract_vars
from functools import reduce
import operator


def extract_variables_from_affine(expr: ITeg) -> Dict[Tuple[str, int], ITeg]:
"""Extract all of the variables in an affine expression. """
"""Extract all of the variables in an affine expression."""

if isinstance(expr, Const):
return {}
Expand All @@ -35,10 +22,9 @@ def extract_variables_from_affine(expr: ITeg) -> Dict[Tuple[str, int], ITeg]:
raise ValueError(f'The expression of type "{type(expr)}" results in a computation that is not affine.')


def extract_moving_discontinuities(expr: ITeg,
var: Var,
not_ctx: Set[Tuple[str, int]],
banned_variables: Set[Tuple[(str, int)]]) -> Iterable[Tuple[ITeg, ITeg]]:
def extract_moving_discontinuities(
expr: ITeg, var: Var, not_ctx: Set[Tuple[str, int]], banned_variables: Set[Tuple[(str, int)]]
) -> Iterable[Tuple[ITeg, ITeg]]:
"""Yield all subexpressions producing a moving discontinuity.
A moving discontinuity is a branching statement that includes
Expand All @@ -52,32 +38,36 @@ def extract_moving_discontinuities(expr: ITeg,
elif isinstance(expr, Teg):
banned_variables.add((expr.dvar.name, expr.dvar.uid))

yield from (moving_cond for child in expr.children
for moving_cond in extract_moving_discontinuities(child, var, not_ctx, banned_variables))
yield from (
moving_cond
for child in expr.children
for moving_cond in extract_moving_discontinuities(child, var, not_ctx, banned_variables)
)


def moving_discontinuities_in_boolean(expr: ITegBool,
var: Var,
not_ctx: Set[Tuple[str, int]],
banned_variables: Set[Tuple[(str, int)]]) -> Iterable[ITeg]:
"""Yield all moving discontinuities in boolean expression (e.g., in d_t \int_x [x < t] ) """
def moving_discontinuities_in_boolean(
expr: ITegBool, var: Var, not_ctx: Set[Tuple[str, int]], banned_variables: Set[Tuple[(str, int)]]
) -> Iterable[ITeg]:
"""Yield all moving discontinuities in boolean expression (e.g., in d_t int_x [x < t] )"""
if isinstance(expr, Bool):
var_name_var_in_cond = extract_variables_from_affine(expr.left_expr - expr.right_expr)
moving_var_name_uids = var_name_var_in_cond.keys() - not_ctx - {(var.name, var.uid)}

# Check that the variable var is in the condition
# and another free variable (not in not_ctx) is in the condition
if ((var.name, var.uid) in var_name_var_in_cond
and len(moving_var_name_uids) > 0
and len(banned_variables & moving_var_name_uids) == 0):
if (
(var.name, var.uid) in var_name_var_in_cond
and len(moving_var_name_uids) > 0
and len(banned_variables & moving_var_name_uids) == 0
):
yield expr

elif isinstance(expr, (And, Or)):
yield from moving_discontinuities_in_boolean(expr.left_expr, var, not_ctx, banned_variables)
yield from moving_discontinuities_in_boolean(expr.right_expr, var, not_ctx, banned_variables)

else:
raise ValueError('Illegal expression in boolean.')
raise ValueError("Illegal expression in boolean.")


def extend_dependencies(var_list, deps_list):
Expand All @@ -90,9 +80,9 @@ def extend_dependencies(var_list, deps_list):
return var_list | extended_list


def primitive_booleans_in(expr: ITegBool,
not_ctx: Set[Tuple[str, int]],
deps: Dict[TegVar, Set[Var]]) -> Iterable[ITeg]:
def primitive_booleans_in(
expr: ITegBool, not_ctx: Set[Tuple[str, int]], deps: Dict[TegVar, Set[Var]]
) -> Iterable[ITeg]:

if isinstance(expr, Bool):
cond_variables = extract_vars(expr.left_expr - expr.right_expr)
Expand All @@ -101,13 +91,12 @@ def primitive_booleans_in(expr: ITegBool,

# Check that the variable var is in the condition
# and another variable not in not_ctx is in the condition
if (({(v.name, v.uid) for v in extended_cond_variables} & not_ctx)
and len(moving_var_name_uids) > 0):
if ({(v.name, v.uid) for v in extended_cond_variables} & not_ctx) and len(moving_var_name_uids) > 0:
yield expr

elif isinstance(expr, (And, Or)):
yield from primitive_booleans_in(expr.left_expr, not_ctx, deps)
yield from primitive_booleans_in(expr.right_expr, not_ctx, deps)

else:
raise ValueError('Illegal expression in boolean.')
raise ValueError("Illegal expression in boolean.")

0 comments on commit b8525bc

Please sign in to comment.