Skip to content

Commit

Permalink
arithmetic fixes to account for np.ndarray being a leaf array
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikcfd authored and inducer committed Sep 1, 2022
1 parent 6e30532 commit ce8ab7c
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions arraycontext/container/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,16 +492,17 @@ def {fname}(arg1):
bcast_actx_ary_types = ()

gen(f"""
if {bool(outer_bcast_type_names)}: # optimized away
if isinstance(arg2,
{tup_str(outer_bcast_type_names
+ bcast_actx_ary_types)}):
return cls({bcast_same_cls_init_args})
if {numpy_pred("arg2")}:
result = np.empty_like(arg2, dtype=object)
for i in np.ndindex(arg2.shape):
result[i] = {op_str.format("arg1", "arg2[i]")}
return result
if {bool(outer_bcast_type_names)}: # optimized away
if isinstance(arg2,
{tup_str(outer_bcast_type_names
+ bcast_actx_ary_types)}):
return cls({bcast_same_cls_init_args})
return NotImplemented
""")
gen(f"cls.__{dunder_name}__ = {fname}")
Expand Down Expand Up @@ -538,16 +539,16 @@ def {fname}(arg1):
def {fname}(arg2, arg1):
# assert other.__cls__ is not cls
if {bool(outer_bcast_type_names)}: # optimized away
if isinstance(arg1,
{tup_str(outer_bcast_type_names
+ bcast_actx_ary_types)}):
return cls({bcast_init_args})
if {numpy_pred("arg1")}:
result = np.empty_like(arg1, dtype=object)
for i in np.ndindex(arg1.shape):
result[i] = {op_str.format("arg1[i]", "arg2")}
return result
if {bool(outer_bcast_type_names)}: # optimized away
if isinstance(arg1,
{tup_str(outer_bcast_type_names
+ bcast_actx_ary_types)}):
return cls({bcast_init_args})
return NotImplemented
cls.__r{dunder_name}__ = {fname}""")
Expand Down

0 comments on commit ce8ab7c

Please sign in to comment.