-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make the code compatible with other keras backends #2137
Conversation
5721909
to
2120355
Compare
f2484d3
to
4448ba4
Compare
502ac27
to
2f150e6
Compare
45fe562
to
b3c5b27
Compare
This makes the code:
*I don't think we want to run with pytorch, but having the possibility is nice. If we start using it for whatever reason I can debug that. |
Greetings from your nice fit 🤖 !
Check the report carefully, and please buy me a ☕ , or better, a GPU 😉! |
6b94412
to
b859c07
Compare
59464e9
to
ee35538
Compare
Greetings from your nice fit 🤖 !
Check the report carefully, and please buy me a ☕ , or better, a GPU 😉! |
ee35538
to
bd05e0f
Compare
3ca453d
to
15e9f34
Compare
Greetings from your nice fit 🤖 !
Check the report carefully, and please buy me a ☕ , or better, a GPU 😉! |
Ok, this is ready for review. Some comparisons in the first post. I've heard you are completely free this weekend @RoyStegeman in case you want to have another look :P |
15e9f34
to
23e7620
Compare
23e7620
to
da1d6bd
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, looks good. Main point is that it should be documented
f97a928
to
af832de
Compare
make the code compatible with pytorch make it more keras only, working also with sum rules update keras limits test with newer tf fix for normalized distributions
test change workflow test in 3.12 in conda
remove other instances of kops Update n3fit/src/n3fit/backends/keras_backend/MetaModel.py Co-authored-by: Roy Stegeman <[email protected]>
300c865
to
59ce25a
Compare
This deals with #2134
It exchanges tf imports to keras (with small changes when the functions signatures change). It was necessary to add a few explicit output shapes here and there so that Keras knows how to compile the graph (and then we can also use the
jit_compile
flag, at least with the tensorflow backend).Now, for some comparisons:
Turns out that the code is mostly general enough to work with pytorch ootb. Very few changes are needed* (the second commit, and not even, because some of the changes are
tf.
that I missed).However, I've run a few quick benchmarks and turns out that it is about 4 times slower than the tensorflow backend. This is not surprising since we have been using tensorflow for a long time so we might be doing thing that just happen to be better for tensorflow. Also, while jit_compile is set to auto, I think the default for pytorch is False.