
Implement initial ReLU/sign function via polynomial approximation #658

Open
j2kun opened this issue Apr 29, 2024 · 15 comments
Assignees
Labels
dialect: polynomial Issues concerning the polynomial dialect

Comments


j2kun commented Apr 29, 2024

Given that ReLU(x) = x (0.5 + 0.5 sgn(x)), this reduces to approximating the sign function, and this paper appears to have the state of the art: https://eprint.iacr.org/2020/834

Also note

  • max(u, v) = ((u+v) + (u-v)sign(u-v)) / 2
  • min(u, v) = -max(-u, -v) = ((u+v) - (v - u)sign(v - u)) / 2
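As a sanity check, these identities can be verified numerically, with the exact `np.sign` standing in for whatever polynomial approximation is eventually chosen (a quick sketch, not HEIR code):

```python
import numpy as np

# Verify the identities above, using np.sign as a stand-in for the
# polynomial approximation of the sign function.
rng = np.random.default_rng(0)
u = rng.uniform(-1, 1, 1000)
v = rng.uniform(-1, 1, 1000)

relu = u * (0.5 + 0.5 * np.sign(u))                # ReLU(u) = u(0.5 + 0.5 sgn(u))
max_uv = ((u + v) + (u - v) * np.sign(u - v)) / 2  # max(u, v)
min_uv = ((u + v) - (v - u) * np.sign(v - u)) / 2  # min(u, v) = -max(-u, -v)

assert np.allclose(relu, np.maximum(u, 0))
assert np.allclose(max_uv, np.maximum(u, v))
assert np.allclose(min_uv, np.minimum(u, v))
```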

j2kun commented Apr 29, 2024

Also cf. https://openreview.net/pdf?id=Hq16Jk2bVlp and https://eprint.iacr.org/2021/1688 which use these approximations.


j2kun commented Apr 29, 2024

The paper linked above (2020/834) does not release source code, but there is a reference implementation in Lattigo: https://github.com/tuneinsight/lattigo/blob/4cce9a48c1daaa2dd122921822f5ad70cd444156/he/hefloat/minimax_composite_polynomial.go#L124


j2kun commented Apr 30, 2024

The paper https://eprint.iacr.org/2019/1234 is a precursor to https://eprint.iacr.org/2020/834, but also seems to explain more of the motivation behind the composite polynomials.


j2kun commented May 1, 2024

An example of generating a well-fitting polynomial using lolremez: samhocevar/lolremez#28 (comment)

Another tool: https://github.com/pychebfun/pychebfun
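For comparison, a rough least-squares Chebyshev fit of sign(x) can be produced with numpy alone; the degree and the excluded gap around zero below are illustrative choices, and lolremez's minimax fit will do better at the same degree:

```python
import numpy as np
from numpy.polynomial import chebyshev as C

# Fit sign(x) on [-1, -eps] ∪ [eps, 1]; the gap around 0 is what makes
# a good polynomial fit possible at all.
eps, deg = 0.1, 15
x = np.concatenate([np.linspace(-1, -eps, 500), np.linspace(eps, 1, 500)])
coeffs = C.chebfit(x, np.sign(x), deg)
max_err = np.max(np.abs(C.chebval(x, coeffs) - np.sign(x)))
```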


j2kun commented May 1, 2024

Outline sketch:

Various improvements based on more recent research that would be worth splitting into separate tickets.

@j2kun j2kun changed the title Implement ReLU/sign function via polynomial approximation Implement initial ReLU/sign function via polynomial approximation May 1, 2024

j2kun commented May 2, 2024

Thanks to Seonhong Min for sending me https://eprint.iacr.org/2018/462, which shows that BFV achieves polynomial approximations via a fixed-point encoding, not a floating-point one. There is some added complexity there: evaluating a fixed-point polynomial approximation also requires a rounding step, though not always; see Sec. 2.5.

[screenshot of Sec. 2.5 of 2018/462]
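A toy sketch of why fixed-point polynomial evaluation needs that rounding step (my own illustration, assuming a power-of-two scaling factor Delta, not the paper's algorithm): each multiplication doubles the scale, so a rescale-and-round is required to keep the encoding consistent.

```python
DELTA = 2**10  # fixed-point scaling factor (illustrative choice)

def encode(x):
    return round(x * DELTA)

def decode(x_enc):
    return x_enc / DELTA

def fixed_point_horner(coeffs, x_enc):
    """Evaluate sum(coeffs[i] * x^i) on a fixed-point input via Horner's
    rule, rounding after each multiplication to restore the scale."""
    acc = encode(coeffs[-1])
    for c in reversed(coeffs[:-1]):
        acc = round(acc * x_enc / DELTA) + encode(c)  # rescale + round
    return acc

# p(x) = 0.5 + 1.5x - 0.5x^3 evaluated at x = 0.3
approx = decode(fixed_point_horner([0.5, 1.5, 0.0, -0.5], encode(0.3)))
```

The `round(... / DELTA)` step is the piece that is cheap in the clear but nontrivial homomorphically, per the discussion above.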


Maokami commented May 2, 2024

I was also interested in this issue a while back, so I know a paper worth sharing: https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=10155408
It's a follow-up paper by the authors of https://eprint.iacr.org/2020/834, and there's an implementation: https://github.com/snu-ccl/approxCNN/tree/main.
In the repo, you'll find that they've hardcoded the coefficients of polynomial approximations of the sign function for alpha = 4 to 14!


Maokami commented May 2, 2024

> The paper linked above (2020/834) does not release source code, but there is a reference implementation in Lattigo: https://github.com/tuneinsight/lattigo/blob/4cce9a48c1daaa2dd122921822f5ad70cd444156/he/hefloat/minimax_composite_polynomial.go#L124

Oh, I didn't know there was an implementation in Lattigo! In that case, these hardcoded coefficients might not be necessary.


j2kun commented May 2, 2024

After starting an implementation in #665 (with a fixed approximation polynomial) and discussing in the HEIR meeting today, I have a few kinks to work out:

  1. I want to separate the task of choosing a polynomial approximation from the optimizations around evaluating it. This implies:
    i. I need a floating-point representation of a polynomial in the IR, but PolynomialAttr currently only supports integer coefficients.
    ii. I need a new op, say poly_ext.eval, whose operands are the polynomial to apply and its input.
    iii. The approximation itself is for sign, but these operations are actually applied to max(0, x) = (x + x * sign(x)) / 2, which means we should support some internal polynomial arithmetic to construct these from the checked-in approximation. We meant to do this to support constant folding in the polynomial dialect, but never got around to it.
  2. (1) has a twist: many more advanced polynomial approximations are not represented literally, but implicitly as a composition of smaller-degree polynomials. This implies I will need a polynomial.compose op, or else an attribute that supports composite sub-polynomials, and cannot limit (1.ii) above to a single static polynomial. I will start with a single static polynomial but try to avoid making it difficult to upgrade to a composite polynomial.
  3. The approximation polynomial itself has a few quirks, because its coefficients further need to be encoded in a scheme-specific fashion. For CKKS this is relatively straightforward but introduces additional error. For BGV/BFV it seems much harder, in part because the encodings are fixed-point and hence require rounding during polynomial evaluation, and rounding itself is hard (see above). There is also a question about which basis the polynomial is expressed in; cf. https://discord.com/channels/901152454077452399/1235349479482196049 for more on this.
  4. The above points expose a problem with "lowering a ReLU": at the tosa level we don't yet know which scheme will be chosen, so the choice of polynomial approximation can't be scheme-aware or encoding-aware. I think the right solution is to include some extra metadata on the polynomial expressing what function is being approximated, so that we can re-approximate it at lower levels if necessary.
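To illustrate point 2, here is a minimal sketch of evaluating a composite sign approximation. The degree-3 basic polynomial f₁(x) = (3x − x³)/2 is from https://eprint.iacr.org/2019/1234; composing it with itself drives inputs in (−1, 1) toward ±1. The iteration count here is illustrative, not a recommendation.

```python
def compose_eval(polys, x):
    """Evaluate the composition polys[-1] ∘ ... ∘ polys[0] at x."""
    for p in polys:
        x = p(x)
    return x

# Degree-3 basic polynomial f_1 from 2019/1234. Its fixed points are 0 and
# ±1, with ±1 superattracting, so iterating it approximates sign on (-1, 1).
f1 = lambda x: (3 * x - x ** 3) / 2
sign_approx = lambda x: compose_eval([f1] * 8, x)
```

Storing the sub-polynomials as a list like this, rather than expanding the (degree 3^8) product, is exactly what makes the composite representation cheap to evaluate.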


j2kun commented May 8, 2024

These folks do something slightly different, which is more holistic with respect to NN training: https://github.com/EfficientFHE/SmartPAF

They pick small degree polynomial approximations and then do a variety of network fine-tuning to adapt the network to the replaced operations. This seems out of scope of the compiler, since it would require training data to be included.


j2kun commented May 31, 2024

I added an upstream RFC for the polynomial approximation pass https://discourse.llvm.org/t/rfc-a-polynomial-approximation-pass/79301

@j2kun j2kun self-assigned this Jun 30, 2024
@j2kun j2kun added the dialect: polynomial Issues concerning the polynomial dialect label Jun 30, 2024

asraa commented Sep 17, 2024

@JianmingTONG has also worked on https://github.com/EfficientFHE/SmartPAF, a system for approximating activation functions.

Pro7ech commented Nov 8, 2024

> The paper linked above (2020/834) does not release source code, but there is a reference implementation in Lattigo: https://github.com/tuneinsight/lattigo/blob/4cce9a48c1daaa2dd122921822f5ad70cd444156/he/hefloat/minimax_composite_polynomial.go#L124
>
> Oh, I didn't know there was an implementation in Lattigo! In that case, these hardcoded coefficients might not be necessary.

The one in Tune Insight's Lattigo is unstable and highly likely to fail when approximating over multiple intervals. The issue comes from the method used to find the roots of the error function, which can miss roots. I fixed it this summer because I needed it for the iDASH challenge; the fix is in my personal fork of Lattigo.


j2kun commented Nov 8, 2024

> The paper linked above (2020/834) does not release source code, but there is a reference implementation in Lattigo: https://github.com/tuneinsight/lattigo/blob/4cce9a48c1daaa2dd122921822f5ad70cd444156/he/hefloat/minimax_composite_polynomial.go#L124
>
> Oh, I didn't know there was an implementation in Lattigo! In that case, these hardcoded coefficients might not be necessary.
>
> The one in Tune Insight's Lattigo is unstable and highly likely to fail when approximating over multiple intervals. The issue comes from the method used to find the roots of the error function, which can miss roots. I fixed it this summer because I needed it for the iDASH challenge; the fix is in my personal fork of Lattigo.

@Pro7ech Since I wrote that comment I've been doing a lot of studying on this topic, including writing a basic Remez solver that had issues with finding roots. I believe that using the so-called barycentric form of the Remez algorithm would avoid these issues (though I have not yet followed through to confirm it). Did you come to the same conclusion, or did you resolve it via a different method?

Pro7ech commented Nov 8, 2024

> The paper linked above (2020/834) does not release source code, but there is a reference implementation in Lattigo: https://github.com/tuneinsight/lattigo/blob/4cce9a48c1daaa2dd122921822f5ad70cd444156/he/hefloat/minimax_composite_polynomial.go#L124
>
> Oh, I didn't know there was an implementation in Lattigo! In that case, these hardcoded coefficients might not be necessary.
>
> The one in Tune Insight's Lattigo is unstable and highly likely to fail when approximating over multiple intervals. The issue comes from the method used to find the roots of the error function, which can miss roots. I fixed it this summer because I needed it for the iDASH challenge; the fix is in my personal fork of Lattigo.
>
> @Pro7ech Since I wrote that comment I've been doing a lot of studying on this topic, including writing a basic Remez solver that had issues with finding roots. I believe that using the so-called barycentric form of the Remez algorithm would avoid these issues (though I have not yet followed through to confirm it). Did you come to the same conclusion, or did you resolve it via a different method?

I haven't looked into barycentric Remez, but I've managed to make it work stably (although quite slowly for large degrees) even when the number of discrete intervals is in the dozens. I directly look for the extrema of the error function, which is easier and more stable than looking for the roots of its derivative. Because of the alternation condition, the minimum and maximum number of extrema can be bounded, and knowing how many points to look for helps a lot. The code is commented if you want to take a look.
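A minimal sketch of the extrema-scan idea (my own illustration, not the Lattigo code): sample the error function on a fine grid and detect interior local extrema as sign changes of the discrete differences. Here the Chebyshev polynomial T₄ stands in for an equioscillating minimax error curve; the grid density is an illustrative choice.

```python
import numpy as np

def local_extrema(err, xs):
    """Return grid points where the discrete difference of err changes
    sign, i.e. interior local extrema of the sampled error function."""
    d = np.diff(err)
    idx = np.where(np.sign(d[:-1]) != np.sign(d[1:]))[0] + 1
    return xs[idx]

xs = np.linspace(-1, 1, 10001)
err = np.cos(4 * np.arccos(xs))  # T_4(x): 3 interior extrema at 0, ±cos(pi/4)
extrema = local_extrema(err, xs)
```

Bounding the expected number of extrema via the alternation condition, as described above, then gives a cheap check that none were missed.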
