Skip to content

Commit

Permalink
Merge branch 'pchandrasekaran/simplyreshape' into 'main'
Browse files Browse the repository at this point in the history
Update SimplifyReshape Pattern Callback

See merge request tenstorrent/tvm!59
  • Loading branch information
chandrasekaranpradeep committed Apr 22, 2024
2 parents 1955a63 + ffc8e5a commit 5a6455b
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/tvm/relay/op/contrib/buda/buda_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3010,9 +3010,9 @@ def callback(self, pre, post, node_map):
reshape_1 = node_map[self.reshape_1][0]
final_shape = list(reshape_1.attrs.newshape)

if input_shape[0] * input_shape[1] == final_shape[-1]:
reshape_0 = tvm.relay.reshape(node_map[self.reshape_0][0], newshape=[1, 1, input_shape[-2] * input_shape[-3], input_shape[-1]])
final_transpose = tvm.relay.transpose(reshape_0, axes=[0,1,3,2])
if input_shape == final_shape and len(input_shape) >= 3:
final_transpose_axes = np.arange(int(len(input_shape) - 2)).tolist() + np.flip(np.arange(int(len(input_shape) - 2), len(input_shape))).tolist()
final_transpose = tvm.relay.transpose(node_map[self.act][0], axes=final_transpose_axes)
return final_transpose

return post
Expand Down

0 comments on commit 5a6455b

Please sign in to comment.