Tutorial to export a SPLADE model to ONNX #47
Hi @ntnq4, not that I am aware of. I am not super familiar with ONNX - did you manage to make it work?
I didn't manage to make it work, unfortunately. I tried this tutorial, but it didn't work for my SPLADE model. I also found this recent paper that mentioned this conversion.
Hi @ntnq4,

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer


class TransformerRep(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = AutoModelForMaskedLM.from_pretrained(
            "naver/splade-cocondenser-ensembledistil", torchscript=True
        )
        self.model.eval()
        self.fp16 = True

    def encode(self, input_ids, token_type_ids, attention_mask):
        # Pass the tensors by keyword: BertForMaskedLM's positional order is
        # (input_ids, attention_mask, token_type_ids), so positional passing would swap them.
        return self.model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )[0]


class SpladeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = TransformerRep()
        self.agg = "max"
        self.model.eval()

    def forward(self, input_ids, token_type_ids, attention_mask):
        with torch.cuda.amp.autocast():
            with torch.no_grad():
                # MLM logits of the first sequence in the batch: (seq_len, vocab_size)
                lm_logits = self.model.encode(input_ids, token_type_ids, attention_mask)[0]
                # SPLADE aggregation: log(1 + ReLU(logits)), masked, then max-pooled over tokens
                vec, _ = torch.max(
                    torch.log(1 + torch.relu(lm_logits)) * attention_mask.unsqueeze(-1), dim=1
                )
                # Keep only the non-zero vocabulary dimensions and their weights
                # (this indexing assumes a batch size of 1)
                indices = vec.nonzero().squeeze()
                weights = vec.squeeze()[indices]
                return indices[:, 1], weights[:, 1]


# Convert the model to TorchScript by tracing it on a sample input
model = SpladeModel()
tokenizer = AutoTokenizer.from_pretrained("naver/splade-cocondenser-ensembledistil")
sample = "the capital of france is paris"
inputs = tokenizer(sample, return_tensors="pt")
traced_model = torch.jit.trace(
    model, (inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"])
)
```
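The traced module then has to be saved to disk so it can be reloaded in the export step below; a minimal sketch (the file name `splade_traced.pt` is an example, not from the original comment):

```python
# Save the traced module; the export step below reloads it with torch.jit.load
traced_model.save("splade_traced.pt")

# Quick sanity check: the traced module returns (vocab indices, weights) for the sample
indices, weights = traced_model(
    inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"]
)
print(indices.shape, weights.shape)
```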
```python
import torch

# Dynamic axes so the exported graph accepts any batch size and sequence length
dyn_axis = {
    'input_ids': {0: 'batch_size', 1: 'sequence'},
    'attention_mask': {0: 'batch_size', 1: 'sequence'},
    'token_type_ids': {0: 'batch_size', 1: 'sequence'},
    'indices': {0: 'batch_size', 1: 'sequence'},
    'weights': {0: 'batch_size', 1: 'sequence'}
}

# Example paths/inputs: the saved TorchScript file, the desired ONNX output file,
# and the same input tuple that was used for tracing
model_file = "splade_traced.pt"
model_onnx_file = "splade.onnx"
dummy_input = (inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"])

model = torch.jit.load(model_file)
torch.onnx.export(
    model,
    dummy_input,
    f=model_onnx_file,
    input_names=['input_ids', 'token_type_ids', 'attention_mask'],
    output_names=['indices', 'weights'],
    dynamic_axes=dyn_axis,
    do_constant_folding=True,
    opset_version=15,
    verbose=False,
)
```
```python
model_names = [
    "naver/splade_v2_max",
    "naver/splade_v2_distil",
    "naver/splade-cocondenser-ensembledistil",
    "naver/efficient-splade-VI-BT-large-query",
    "naver/efficient-splade-VI-BT-large-doc",
]
```

Requirements:
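If all of the checkpoints above are needed, the trace-and-export steps can be wrapped in a loop. The sketch below is illustrative only: the parameterized `GenericSpladeModel` class and the output file names are not part of the original comment, and it assumes BERT-style checkpoints whose tokenizers return `token_type_ids`.

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer


class GenericSpladeModel(torch.nn.Module):
    """Same idea as SpladeModel above, but parameterized by checkpoint name (illustrative)."""

    def __init__(self, model_name):
        super().__init__()
        self.model = AutoModelForMaskedLM.from_pretrained(model_name, torchscript=True)
        self.model.eval()

    def forward(self, input_ids, token_type_ids, attention_mask):
        with torch.no_grad():
            logits = self.model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
            )[0]
            vec, _ = torch.max(
                torch.log(1 + torch.relu(logits)) * attention_mask.unsqueeze(-1), dim=1
            )
            indices = vec.nonzero().squeeze()
            weights = vec.squeeze()[indices]
            return indices[:, 1], weights[:, 1]


for name in model_names:
    tokenizer = AutoTokenizer.from_pretrained(name)
    enc = tokenizer("the capital of france is paris", return_tensors="pt")
    dummy = (enc["input_ids"], enc["token_type_ids"], enc["attention_mask"])
    traced = torch.jit.trace(GenericSpladeModel(name), dummy)
    torch.onnx.export(
        traced,
        dummy,
        f=name.split("/")[-1] + ".onnx",  # example output file name
        input_names=["input_ids", "token_type_ids", "attention_mask"],
        output_names=["indices", "weights"],
        dynamic_axes=dyn_axis,
        do_constant_folding=True,
        opset_version=15,
    )
```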
Hope this helps! :)
Hi @risan-raja, thank you for your help! :)
if an ONNX conversion was added to HuggingFace in a folder called
Hello,
I recently trained a SPLADE model on my own. To reduce the inference time, I tried to export my model to ONNX with `torch.onnx.export()`, but I encountered a few errors. Is there a tutorial somewhere for this conversion?
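For context, a direct export of the underlying masked-LM checkpoint with `torch.onnx.export()` looks roughly like the sketch below (an illustration with example names, not the exact code from this issue). It only exposes the raw vocabulary logits, so the SPLADE log(1 + ReLU) max-pooling would still have to run outside the graph, which is what the wrapper module in the comments above avoids.

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Illustrative direct export of the masked-LM model (checkpoint and file names are examples)
model = AutoModelForMaskedLM.from_pretrained(
    "naver/splade-cocondenser-ensembledistil", torchscript=True
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained("naver/splade-cocondenser-ensembledistil")
enc = tokenizer("the capital of france is paris", return_tensors="pt")

torch.onnx.export(
    model,
    (enc["input_ids"], enc["attention_mask"], enc["token_type_ids"]),  # BertForMaskedLM positional order
    f="splade_mlm.onnx",
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
        "token_type_ids": {0: "batch", 1: "sequence"},
        "logits": {0: "batch", 1: "sequence"},
    },
    opset_version=15,
)
```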