Skip to content

Commit

Permalink
Update SimplifyReshape Pattern Callback
Browse files Browse the repository at this point in the history
  • Loading branch information
chandrasekaranpradeep committed Apr 18, 2024
1 parent 1955a63 commit ffc8e5a
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/tvm/relay/op/contrib/buda/buda_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3010,9 +3010,9 @@ def callback(self, pre, post, node_map):
reshape_1 = node_map[self.reshape_1][0]
final_shape = list(reshape_1.attrs.newshape)

if input_shape[0] * input_shape[1] == final_shape[-1]:
reshape_0 = tvm.relay.reshape(node_map[self.reshape_0][0], newshape=[1, 1, input_shape[-2] * input_shape[-3], input_shape[-1]])
final_transpose = tvm.relay.transpose(reshape_0, axes=[0,1,3,2])
if input_shape == final_shape and len(input_shape) >= 3:
final_transpose_axes = np.arange(int(len(input_shape) - 2)).tolist() + np.flip(np.arange(int(len(input_shape) - 2), len(input_shape))).tolist()
final_transpose = tvm.relay.transpose(node_map[self.act][0], axes=final_transpose_axes)
return final_transpose

return post
Expand Down

0 comments on commit ffc8e5a

Please sign in to comment.