Error while inferencing model trained on own dataset!!! #15
Comments
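Hi, I think this problem is probably because the training script saves a raw PyTorch checkpoint rather than using Hugging Face-style saving, so I should probably update the training script. In the meantime, you could load the weights manually: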
model = T5ForConditionalGeneration.from_pretrained("iarfmoose/t5-base-question-generator")
state_dict = torch.load("t5-question-generator.pt")
model.load_state_dict(state_dict)
Thanks for your response, @AMontgomerie. Could you please make the required changes to smooth out the workflow?
Thank you for raising the issue, @sabhi27. When I load my trained weights using the code below, I get an error:
import torch
RuntimeError: Error(s) in loading state_dict for T5ForConditionalGeneration:
Can you help me resolve this, @AMontgomerie? Thank you.
Oh, looks like the save function I wrote also saves the optimizer state and some other variables. That's why it's complaining about unexpected keys, although I'm not sure why it's also complaining about missing keys... Can you try this instead?
import torch
from transformers import T5ForConditionalGeneration
model = T5ForConditionalGeneration.from_pretrained("iarfmoose/t5-base-question-generator")
state_dict = torch.load("/content/drive/MyDrive/question_BIOasq/new_type/t5-question-generator.pt")
model.load_state_dict(state_dict["model_state_dict"])  # <-- try changing this line
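One plausible reading of the error, sketched below with plain dictionaries instead of real tensors: if the checkpoint wraps the weights under a "model_state_dict" key (as the snippet above assumes), then loading the wrapper directly would report every model parameter as missing and every wrapper key as unexpected. The helper names, the toy parameter names, and the "optimizer_state_dict" key are hypothetical and used only for illustration; this is not the library's actual checking code.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, similar in spirit to a strict key check."""
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


def extract_model_state_dict(checkpoint, key="model_state_dict"):
    """Return the nested weights if the checkpoint wraps them, else the checkpoint itself."""
    if isinstance(checkpoint, dict) and key in checkpoint:
        return checkpoint[key]
    return checkpoint


# Toy stand-ins: a real state dict maps parameter names to tensors.
model_keys = ["shared.weight", "lm_head.weight"]
checkpoint = {
    "model_state_dict": {"shared.weight": 0, "lm_head.weight": 0},
    "optimizer_state_dict": {},  # hypothetical key name for the optimizer state
}

# Loading the wrapper directly: all parameters are missing, all wrapper keys are unexpected.
missing, unexpected = diff_state_dict_keys(model_keys, checkpoint.keys())

# Loading the nested weights: nothing is missing or unexpected.
nested = extract_model_state_dict(checkpoint)
ok_missing, ok_unexpected = diff_state_dict_keys(model_keys, nested.keys())

print(missing, unexpected)
print(ok_missing, ok_unexpected)
```

Printing the top-level keys of the loaded checkpoint is a quick way to confirm which layout a given .pt file actually has before calling load_state_dict.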
OK, I've replaced the old save function. The new one uses Hugging Face-style saving instead. Now when you train the model, it should create a model directory (./t5-base-question-generator in the example below). Then you can load your saved model like:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("./t5-base-question-generator")
model = AutoModelForSeq2SeqLM.from_pretrained("./t5-base-question-generator")
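The OSError in the original report is consistent with from_pretrained being pointed at a single .pt file rather than a saved model directory: a Hugging Face-style save writes a config.json next to the weights, and the loader looks for that file. Below is a minimal pre-flight check, assuming only that layout. The helper name looks_like_hf_model_dir is hypothetical, and the files it inspects are dummies created in a temporary directory.

```python
import json
import tempfile
from pathlib import Path


def looks_like_hf_model_dir(path):
    """Return True if `path` is a directory containing a parseable config.json."""
    path = Path(path)
    config_file = path / "config.json"
    if not path.is_dir() or not config_file.is_file():
        return False
    try:
        json.loads(config_file.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError):
        return False
    return True


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)

    # Dummy stand-in for a raw checkpoint file (not a real model).
    checkpoint_file = root / "t5-question-generator.pt"
    checkpoint_file.write_bytes(b"\x80\x02not json")

    # Dummy stand-in for a Hugging Face-style model directory.
    model_dir = root / "t5-base-question-generator"
    model_dir.mkdir()
    (model_dir / "config.json").write_text('{"model_type": "t5"}', encoding="utf-8")

    file_ok = looks_like_hf_model_dir(checkpoint_file)
    dir_ok = looks_like_hf_model_dir(model_dir)

print(file_ok, dir_ok)
```

If the check fails for your path, the model was likely saved with the old save function and needs to be loaded through torch.load instead, or retrained with the updated script.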
Hi @AMontgomerie
I have successfully trained the model on my own dataset, and one file, "t5-question-generator.pt", was saved as the model file in the question-generator folder.
During inference, when I run qg = QuestionGeneration(), I get the error below.
OSError: Couldn't reach server at '/content/question_generator/t5-question-generator.pt' to download configuration file or configuration file is not a valid JSON file. Please check network or file content here: /content/question_generator/t5-question-generator.pt.
Can you help me get this resolved?