Integrate AutoModelForSequenceClassification through PytorchModel #339
Conversation
raise ValueError(
    "When using HuggingFace pretrained models, please use Tokenizers output for `x`"
)
pred = self.model(**x, **model_predict_kwargs).logits
Should we also apply softmax here (after accessing the logits), so that we convert pred to probabilities if softmax=True? (We can add it as a class attribute above.)
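For illustration, the conversion could be factored roughly like this; a sketch under the assumption that a softmax flag is added, and postprocess_logits is a hypothetical helper, not part of the codebase:

import torch

def postprocess_logits(logits: torch.Tensor, softmax: bool) -> torch.Tensor:
    # Convert raw logits to class probabilities only when softmax is requested.
    return torch.nn.functional.softmax(logits, dim=-1) if softmax else logits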
I implemented it this way, please see if you agree: 9a67c4c
pyproject.toml (Outdated)
@@ -36,7 +36,9 @@ dependencies = [
     "scipy>=1.7.3",
     "tqdm>=4.62.3",
     "matplotlib>=3.3.4",
-    "typing_extensions; python_version <= '3.8'"
+    "typing_extensions; python_version <= '3.8'",
+    "transformers<=4.30.2; python_version == '3.7'",
I wonder if this should be part of the base dependencies. I think it is a better fit under 'torch' (see lines 78/80 and below)?
I'd say those should go into
[project.optional-dependencies]
transformers = [...]
tests/conftest.py (Outdated)
@pytest.fixture(scope="session", autouse=True)
def mock_hf_text():
contrary to the name, this is not a mock 🤷
Haha, good catch! Just renamed it
tests/conftest.py (Outdated)
CIFAR_IMAGE_SIZE = 32
MNIST_IMAGE_SIZE = 28
BATCH_SIZE = 124
MINI_BATCH_SIZE = 8
RANDOM_SEED = 42

set_seed(42)
I can't believe we have forgotten to set the PRNG seed 🤦
Thanks for noticing!
Maybe, to ensure each test runs with the same PRNG state, we could do something like:
import random
import numpy as np
import pytest
import tensorflow as tf
import torch

@pytest.fixture(scope="function", autouse=True)
def reset_prngs():
    # Reseed every framework's PRNG before each test.
    random.seed(42)
    np.random.seed(42)
    tf.random.set_seed(42)
    torch.manual_seed(42)
set_seed from huggingface ensures all of these (and some others as well), but using autouse is a clever idea :) just did it!
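For reference, the resulting autouse fixture could be as small as this; a sketch, the actual conftest.py may differ:

import pytest
from transformers import set_seed

@pytest.fixture(scope="function", autouse=True)
def reset_prngs():
    # transformers.set_seed seeds Python's random, NumPy, and PyTorch
    # (and TensorFlow, when it is installed) before every test.
    set_seed(42)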
tests/conftest.py (Outdated)
    return model


@pytest.fixture(scope="session", autouse=True)
Please remove autouse; autouse=True will force the model to be loaded into memory every time any test is executed, even if the test does not use it.
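For illustration, dropping autouse keeps the fixture lazy; something along these lines (the model name is a placeholder, not necessarily the one used in this PR):

import pytest
from transformers import AutoModelForSequenceClassification

@pytest.fixture(scope="session")
def hf_model():
    # Only tests that request this fixture pay the cost of loading the model.
    return AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased-finetuned-sst-2-english"
    )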
Sorry, just did it
elif isinstance(self.model, nn.Module):
    pred_model = self.get_softmax_arg_model()
    pred = pred_model(torch.Tensor(x).to(self.device), **model_predict_kwargs)
return pred
Let's try not to return None; either return a tensor, or raise an exception.
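Sketched out, the branch above could end with an explicit error instead of falling through to None. This is a fragment of the predict method mirroring the suggestion, not necessarily the final code in 179da1e:

if isinstance(self.model, PreTrainedModel):
    pred = self.model(**x, **model_predict_kwargs).logits
elif isinstance(self.model, nn.Module):
    pred_model = self.get_softmax_arg_model()
    pred = pred_model(torch.Tensor(x).to(self.device), **model_predict_kwargs)
else:
    # Fail loudly instead of silently returning None for unsupported model types.
    raise ValueError(f"Unsupported model type: {type(self.model)}")
return pred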
Done: 179da1e
raise ValueError(
    "When using HuggingFace pretrained models, please use Tokenizers output for `x`"
)
pred = self.model(**x, **model_predict_kwargs).logits
just return self.model(**x, **model_predict_kwargs).logits
I did it slightly differently (in 179da1e) to handle the softmax param and raise errors properly. Could you see if you agree? Thanks
I think that looks great :D @abarbosa94
You're right, it looks a bit different now. Can we also remove pred = None at the top?
pyproject.toml (Outdated)
@@ -52,6 +52,7 @@ dynamic = ["version"]
 #
 [project.optional-dependencies]
 tests = [
+    "cachetools>=5.3.3",
Why do we need cachetools for tests? If it is used by the library it must be in [project.dependencies], otherwise users can face issues after installation.
You're right, removing it
Done: 179da1e
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@           Coverage Diff           @@
##             main     #339   +/-   ##
=======================================
  Coverage   91.19%   91.20%
=======================================
  Files          66       66
  Lines        3906     3921    +15
=======================================
+ Hits         3562     3576    +14
- Misses        344      345     +1

☔ View full report in Codecov by Sentry.
…gration-v0-huggingface
improve type hint and raise error when predict is None
aba53a7 to 179da1e
Based on the testing, it looks like we need to add transformers also to the
Merged commit b0b6cda into understandable-machine-intelligence-lab:main
Please make sure Quantus is usable without transformers installed.
    ],
)
def test_huggingface_classifier_predict(hf_model, data, softmax, model_kwargs, expected):
    model = PyTorchModel(model=hf_model, softmax=softmax, model_predict_kwargs=model_kwargs)
I thought softmax must be a bool, or?
    return model


@pytest.fixture(scope="session", autouse=False)
autouse=False is the default.
import torch
from torch import nn
from functools import lru_cache
from transformers import PreTrainedModel
This will cause a ModuleNotFoundError when a user tries to import Quantus without transformers installed.
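One common pattern to keep transformers optional is to guard the import; a rough sketch, where the is_hf_model helper is illustrative and not part of the codebase:

try:
    from transformers import PreTrainedModel
except ImportError:
    # transformers is optional; fall back to a sentinel so the check below
    # simply evaluates to False when it is not installed.
    PreTrainedModel = None

def is_hf_model(model) -> bool:
    # True only when transformers is installed and the model is a HF model.
    return PreTrainedModel is not None and isinstance(model, PreTrainedModel)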
@@ -104,8 +81,39 @@ zennit = [
     "quantus[torch]",
     "zennit>=0.5.1"
 ]
+transformers = [
+    "quantus[torch, tensorflow]",
quantus[torch] should be enough.
@@ -85,7 +60,9 @@ torch = [
     "torchvision<=0.12.0; python_version == '3.7'",
     "torchvision>=0.15.1; sys_platform != 'linux' and python_version > '3.7'",
     "torchvision>=0.14.0, <0.15.1; sys_platform == 'linux' and python_version > '3.7' and python_version <= '3.10'",
-    "torchvision>=0.15.1; sys_platform == 'linux' and python_version >= '3.11'"
+    "torchvision>=0.15.1; sys_platform == 'linux' and python_version >= '3.11'",
+    "transformers<=4.30.2; python_version == '3.7'",
Please remove transformers from the torch = [...] section.
Description
This should be taken as an initial step toward: #238, #103, and #217
Implemented changes
- PyTorchModel prediction method extended to accept the HuggingFace model (usage sketched below)
- isort
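For illustration, end-to-end usage might look roughly like this; the model name, import path, and predict call are assumptions based on the diffs above, not taken verbatim from the PR:

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from quantus.helpers.model.pytorch_model import PyTorchModel  # import path assumed

name = "distilbert-base-uncased-finetuned-sst-2-english"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
hf_model = AutoModelForSequenceClassification.from_pretrained(name)

# Wrap the HuggingFace model; softmax=True returns probabilities instead of raw logits.
model = PyTorchModel(model=hf_model, softmax=True)

# Per the diff, `x` is expected to be the tokenizer output, not a raw tensor.
x = tokenizer(["a short example sentence"], return_tensors="pt")
probs = model.predict(x)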
Minimum acceptance criteria