Support triu function for tvm.relay.expr.Call Inputs #2

Closed
wants to merge 2 commits into from


JushBJJ commented Mar 30, 2024

This PR addresses #3.

Properly supporting triu (upper triangular) when the inputs contain CallNodes brings us a step closer to implementing Qwen 1.5 (0.5B) (see tenstorrent/tt-buda-demos#20).

Explanation

def triu(self, inputs, input_types):
    x = inputs[0]
    x_shape = _infer_shape(x)

+    if isinstance(inputs[0], tvm.relay.expr.Call):
+        return self.trilu(inputs, input_types, mode="triu")
        
    mask = np.triu(np.ones(x_shape), inputs[1]).astype(np.bool)
    mask = tvm.nd.array(mask)
    mask = tvm.relay.Constant(mask)

    zeros = np.zeros(x_shape).astype(_convert_tvm_to_np_dtype(input_types[0]))
    zeros = tvm.nd.array(zeros)
    zeros = tvm.relay.Constant(zeros)
        
    return _op.where(mask, x, zeros)
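
For reference, the constant-mask path in the diff (build a boolean upper-triangular mask, then select `x` or zero) keeps element `(row, col)` only when `col - row >= k`, where `k` is the diagonal offset passed as `inputs[1]`. Below is a minimal pure-Python sketch of that semantics; `triu_reference` is a hypothetical helper written for illustration, not part of TVM or this PR.

```python
def triu_reference(matrix, k=0):
    """Zero every element below the k-th diagonal of a 2-D list.

    Hypothetical illustration helper: element (row, col) is kept
    when col - row >= k, which is the same rule np.triu uses.
    """
    return [
        [value if col - row >= k else 0 for col, value in enumerate(line)]
        for row, line in enumerate(matrix)
    ]


example = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
main_diagonal = triu_reference(example)         # k = 0 keeps the main diagonal
above_diagonal = triu_reference(example, k=1)   # k = 1 also zeroes the main diagonal
below_diagonal = triu_reference(example, k=-1)  # k = -1 keeps one sub-diagonal
```

Any replacement path, such as the `self.trilu` call added in this PR, should agree with this rule for the same `k`.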

When compiling Qwen 1.5 (0.5B) (tenstorrent/tt-buda-demos#37), one of the ops encountered is aten::triu, whose inputs contain nested op calls:

>> inputs[0]
CallNode(Op(add), [CallNode(Op(subtract), [CallNode(Op(subtract), [CallNode(Op(add), [Constant(256), Constant(0)], (nullptr), []), CallNode(Op(multiply), [Constant(1), Constant(256)], (nullptr), [])], (nullptr), []), CallNode(Op(multiply), [Constant(1), Constant(32768)], (nullptr), [])], (nullptr), []), Constant(1)], (nullptr), [])

>> type(inputs[0])
<tvm.relay.expr.Call>

self.trilu appears to handle these inputs correctly when called with mode="triu" (as in the diff above) to perform the upper-triangular operation.

Issues

1. Op(trilu) instead of Op(triu)

Calling self.trilu(inputs, input_types, mode="triu") produces the following output:

CallNode(Op(trilu), [CallNode(Op(cast), [CallNode(Op(ones_like), [CallNode(Op(add),...

Is this a genuine issue, or can it be ignored for now?
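
One way to reason about this question: if Op(trilu) is a combined operator that covers both tril and triu and picks the triangle with a flag (an assumption here, based on the operator's name and on the `mode` argument passed to `self.trilu`), then seeing Op(trilu) in the output is expected, and the result matches triu whenever the upper triangle is selected. The sketch below illustrates that idea in pure Python; `trilu_reference` and its `upper` parameter are hypothetical names for illustration, not TVM's implementation.

```python
def trilu_reference(matrix, k=0, upper=True):
    """Combined tril/triu on a 2-D list (hypothetical illustration).

    upper=True  keeps elements with col - row >= k (triu behaviour).
    upper=False keeps elements with col - row <= k (tril behaviour).
    """
    def keep(row, col):
        offset = col - row
        return offset >= k if upper else offset <= k

    return [
        [value if keep(row, col) else 0 for col, value in enumerate(line)]
        for row, line in enumerate(matrix)
    ]


example = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
as_triu = trilu_reference(example, k=0, upper=True)
as_tril = trilu_reference(example, k=0, upper=False)
```

Under that assumption, the operator name alone would not indicate a bug. What matters is that the upper-triangle flag is set, so it is worth checking the attributes of the emitted call, not just its name.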

2. NaN tensor values for Grayskull e75

When tested on @marty1885's e75, the run failed with NaN tensor values.
Oddly, on @JonathanALevine's e150 the model compiled and ran successfully, only hitting some errors later.

*This is a draft for now, since it is only a workaround and not yet a proper fix.

JushBJJ changed the title from "Support triu Function for tvm.relay.expr.Call Inputs" to "Support triu function for tvm.relay.expr.Call Inputs" Mar 30, 2024

JushBJJ commented Jul 25, 2024

Closed, as this is no longer an issue for Qwen specifically.

JushBJJ closed this Jul 25, 2024