Not able to export sentence-transformers model to PyTorch. #820
I have not yet worked with jit trace. The forward function only takes a dict as features, i.e. your input_features.
@nreimers Thank you for looking into it. torch.jit.trace takes a tuple or list of tensors. It throws errors regardless:
```
/Users/sivers/xformers/nlpenv/bin/python3 /Users/sivers/xformers/githubissue.py
```
According to https://pytorch.org/docs/stable/generated/torch.jit.trace.html, it also takes dictionaries. As mentioned, forward accepts only dictionaries, so you must pass a dict.
@sivers2021 You will want to pass a dictionary inside a tuple. This satisfies both interfaces.
See this line, which is what I believe @nreimers is referring to.
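A minimal sketch of that tuple-wrapping pattern, using only a toy `torch.nn.Module` (the class and tensor names here are illustrative, not from the issue):

```python
import torch

class DictForward(torch.nn.Module):
    # Mimics SentenceTransformer.forward: a single dict argument of tensors.
    def forward(self, features):
        return {"sum": features["a"] + features["b"]}

model = DictForward()
example = {"a": torch.ones(2), "b": torch.ones(2)}

# Wrapping the dict in a tuple makes it one positional argument for trace,
# while forward still receives the dict it expects. strict=False silences
# the warning about forward returning a dict.
traced = torch.jit.trace(model, example_inputs=(example,), strict=False)
print(traced(example)["sum"])  # tensor([2., 2.])
```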
We are tracing SentenceTransformers models using this save_as_pt method. But we recently found that our traced model won't accept a doc with a token length exceeding 512; we get an error. Can I get any direction on what I need to do to make sure the traced model also supports truncation? Thanks
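One possible direction (a hypothetical sketch, not the project's own fix): since the trace captures neither tokenization nor truncation, you can clip the tokenized features yourself before calling the traced model. The `truncate_features` helper and the 512 limit are illustrative assumptions:

```python
import torch

def truncate_features(features, max_len=512):
    # Clip every per-token tensor to the first max_len positions so the
    # traced model never sees sequences beyond its position-embedding limit.
    return {k: v[:, :max_len] for k, v in features.items()}

features = {
    "input_ids": torch.ones(1, 600, dtype=torch.long),
    "attention_mask": torch.ones(1, 600, dtype=torch.long),
}
truncated = truncate_features(features)
print(truncated["input_ids"].shape)  # torch.Size([1, 512])
```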
For anyone else who is trying this, I managed to get basic tracing to work like so:

```python
import numpy as np
import torch
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")

# Our sentences to encode
sentences = [
    "This framework generates embeddings for each input sentence",
    "Sentences are passed as a list of string.",
    "The quick brown fox jumps over the lazy dog."
]

# Sentences are encoded by calling model.encode()
embeddings = model.encode(sentences)

tokens = model.tokenize(sentences)
tokens = {k: v.to('mps') for (k, v) in tokens.items()}

# strict=False is necessary to avoid some warnings about returning a dict from forward.
traced_encode = torch.jit.trace_module(model, {'forward': tokens}, strict=False)
traced_embeddings = traced_encode(tokens)['sentence_embedding'].cpu().detach().numpy()

for sentence, embedding, traced_embedding in zip(sentences, embeddings, traced_embeddings):
    print("Sentence:", sentence)
    print("Max diff between embedding and traced_embedding:", np.max(embedding - traced_embedding))
    print("")
```

which outputs
Note that I'm on a Mac, so if you're not you might have to change the `'mps'` device. For the curious, here's what the trace looks like. Note that the traced forward does not do truncation or tokenization, which occur in `encode`.
I'm also getting the same warning (a FutureWarning) when calling encode. I've searched for a force_download option but it is nowhere to be found. This is the code causing the warning:

```python
embeddings1 = model.encode(sentences[0], convert_to_tensor=True, show_progress_bar=False)
embeddings2 = model.encode(sentences[1], convert_to_tensor=True, show_progress_bar=False)
```
Hello! This is caused by a recent change.
The tl;dr is that this warning will go away with the next release. In other words, your code looks good and it should work correctly. Apologies for the confusion.
Hi,
I would like to export a sentence-transformers model to PyTorch (TorchScript). However, I am not able to jit trace the stsb-distilbert-base model.
Any help is much appreciated.
Thanks,
-s
- sentence-transformers 0.4.1.2
- torch 1.8.0
- Python 3.6.7
```
Traceback (most recent call last):
  File "/Users/sivers/xformers/githubissue.py", line 16, in <module>
    traced_model = torch.jit.trace(model, example_inputs=(input_ids, input_type_ids, input_mask, input_features))
  File "/Users/sivers/xformers/nlpenv/lib/python3.6/site-packages/torch/jit/_trace.py", line 742, in trace
    _module_class,
  File "/Users/sivers/xformers/nlpenv/lib/python3.6/site-packages/torch/jit/_trace.py", line 940, in trace_module
    _force_outplace,
  File "/Users/sivers/xformers/nlpenv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 887, in _call_impl
    result = self._slow_forward(*input, **kwargs)
  File "/Users/sivers/xformers/nlpenv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 860, in _slow_forward
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 5 were given
```
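The TypeError is easy to reproduce without sentence-transformers: when forward takes a single dict, passing several separate tensors (as `example_inputs=(input_ids, input_type_ids, input_mask, input_features)` does) hands forward four positional arguments instead of one. A toy sketch, with an illustrative module name:

```python
import torch

class DictOnly(torch.nn.Module):
    def forward(self, features):  # one positional argument, expected to be a dict
        return features

m = DictOnly()
try:
    # Four separate arguments, like the trace call in the traceback above:
    m(torch.ones(1), torch.ones(1), torch.ones(1), torch.ones(1))
except TypeError as e:
    print(e)  # e.g. "forward() takes 2 positional arguments but 5 were given"
```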