Skip to content

Commit

Permalink
FunMC: Make maybe_broadcast_structure more general.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 560376560
  • Loading branch information
SiegeLordEx authored and tensorflower-gardener committed Aug 26, 2023
1 parent d1affbb commit 0ff4e3a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
17 changes: 9 additions & 8 deletions spinoffs/fun_mc/fun_mc/fun_mc_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,9 +732,9 @@ def maybe_broadcast_structure(from_structure: Any,
to_structure: Any) -> Any:
"""Maybe broadcasts `from_structure` to `to_structure`.
If `from_structure` is a singleton, it is tiled to match the structure of
`to_structure`. Note that the elements in `from_structure` are not copied if
this tiling occurs.
This assumes that `from_structure` is a shallow version of `to_structure`.
Subtrees of `to_structure` are set to the leaf values of `from_structure` that
those subtrees correspond to.
Args:
from_structure: A structure.
Expand All @@ -743,11 +743,12 @@ def maybe_broadcast_structure(from_structure: Any,
Returns:
new_from_structure: Same structure as `to_structure`.
"""
flat_from = util.flatten_tree(from_structure)
flat_to = util.flatten_tree(to_structure)
if len(flat_from) == 1:
flat_from *= len(flat_to)
return util.unflatten_tree(to_structure, flat_from)
def _broadcast_leaf(from_val, to_subtree):
return util.map_tree(lambda _: from_val, to_subtree)

return util.map_tree_up_to(
from_structure, _broadcast_leaf, from_structure, to_structure
)


def reparameterize_potential_fn(
Expand Down
3 changes: 3 additions & 0 deletions spinoffs/fun_mc/fun_mc/fun_mc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,9 @@ def testBroadcastStructure(self):
struct = fun_mc.maybe_broadcast_structure([3, 4], [1, 2])
self.assertEqual([3, 4], struct)

struct = fun_mc.maybe_broadcast_structure([1, 2], [[0, 0], [0, 0, 0]])
self.assertEqual([[1, 1], [2, 2, 2]], struct)

def testCallPotentialFn(self):

def potential(x):
Expand Down

0 comments on commit 0ff4e3a

Please sign in to comment.