
[Question]: How to fine-tune a pre-trained flair model with a dataset containing new entities (NER task)? #3318

Closed
mariasierro opened this issue Sep 18, 2023 · 2 comments
Labels
question Further information is requested


@mariasierro

Question

Hello!

I am working on an NER model for French, but I have run into an issue and cannot find a solution anywhere :S

I want to fine-tune the pre-trained "flair/ner-french" model, which, as published on Hugging Face (https://huggingface.co/flair/ner-french), recognizes the labels ORG, LOC, PER and MISC.

However, the dataset I want to use for fine-tuning contains those labels plus four others: CODE, DATETIME, DEM and QUANTITY.

The problem is that I do not know how to make the pre-trained model recognize these new labels.

I am working in Google Colab using Python. So far, I have loaded the model:
```python
from flair.models import SequenceTagger

tagger = SequenceTagger.load("flair/ner-french")
```

Then I tried adding new tags to the tagger:
```python
tagger.label_dictionary.add_item('B-DATETIME')
tagger.label_dictionary.add_item('I-DATETIME')
...
```

Then I tried training it:
```python
from flair.trainers import ModelTrainer

trainer = ModelTrainer(tagger, corpus)
trainer.train(path,
              learning_rate=0.1,
              mini_batch_size=32,
              max_epochs=15,
              write_weights=True)
```

This raises the following error:
```
    transitions_to_stop = transitions[
     53     np.repeat(self.stop_tag, features.shape[0]),
     54     [target[length - 1] for target, length in zip(targets, lengths)],

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
```
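The error most likely comes from the fact that `add_item()` only extends the tag dictionary: the loaded model's output layer and CRF transition matrix keep the size they had for the original tags, so the index of a newly added tag points past the end of those tensors. On the GPU such an out-of-range index surfaces as the opaque "device-side assert"; on the CPU the same mistake gives a readable `IndexError`. A minimal sketch in plain PyTorch, assuming a hypothetical 4-tag transition matrix and a hypothetical `lookup_transition` helper (this is not flair code):

```python
import torch

# Hypothetical size: the CRF transition matrix of the loaded model was built
# for the original tag dictionary and is not resized when tags are added later.
old_num_tags = 4
transitions = torch.zeros(old_num_tags, old_num_tags)


def lookup_transition(from_idx: int, to_idx: int) -> float:
    # Plain tensor indexing with tag indices, similar in spirit to the
    # transitions[...] lookup in the traceback above.
    return transitions[from_idx, to_idx].item()


print(lookup_transition(1, 2))  # fine: both indices are < old_num_tags

# A tag added afterwards receives the next free index, which is out of range.
new_tag_idx = old_num_tags
try:
    lookup_transition(new_tag_idx, 0)
except IndexError as err:
    # On CPU this is a readable IndexError; on CUDA the same out-of-range
    # index is reported as "device-side assert triggered".
    print("IndexError:", err)
```

Running the failing step on the CPU, or with `CUDA_LAUNCH_BLOCKING=1` as the message suggests, usually makes this kind of index error much easier to locate.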

I also found a similar question (#1540), where someone suggested the following code as a solution:
```python
import torch
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer

# `corpus` is the fine-tuning Corpus, defined elsewhere
tagger = SequenceTagger.load('ner')
state = tagger._get_state_dict()
tag_type = 'ner'
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
state['tag_dictionary'] = tag_dictionary
START_TAG: str = "<START>"
STOP_TAG: str = "<STOP>"
# Fresh CRF transitions sized to the new tag dictionary
state['state_dict']['transitions'] = torch.nn.Parameter(torch.randn(len(tag_dictionary), len(tag_dictionary)))
state['state_dict']['transitions'].detach()[tag_dictionary.get_idx_for_item(START_TAG), :] = -10000
state['state_dict']['transitions'].detach()[:, tag_dictionary.get_idx_for_item(STOP_TAG)] = -10000
# Fresh output layer sized to the new tag dictionary
num_directions = 2 if tagger.bidirectional else 1
linear_layer = torch.nn.Linear(tagger.hidden_size * num_directions, len(tag_dictionary))
state['state_dict']['linear.weight'] = linear_layer.weight
state['state_dict']['linear.bias'] = linear_layer.bias
model = SequenceTagger._init_model_with_state_dict(state)
trainer: ModelTrainer = ModelTrainer(model, corpus)

trainer.train('finetuned_model',
              learning_rate=0.001,
              mini_batch_size=64,
              max_epochs=10)
```

I tried this code: training on the new dataset runs without errors, but the accuracy stays at 0, so the model is not learning anything at all.

If someone could give me a hint on how to add these new labels when fine-tuning the model, it would be much appreciated :) Thanks!

@mariasierro mariasierro added the question Further information is requested label Sep 18, 2023
@helpmefindaname
Collaborator

Hi @msierrofer,
you can simply create a new tagger with a new tag dictionary that reuses the old embeddings:
```python
new_tagger = SequenceTagger(old_tagger.embeddings, new_tagdict, label_type)
```
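With this approach only the embeddings carry over from the pre-trained model; the new tagger's other layers (in flair, by default, the RNN, the linear output layer and the CRF) start from a fresh initialization sized to the new tag dictionary. The sketch below shows that structure in plain PyTorch, using a hypothetical `ToyTagger` class and made-up sizes rather than flair's real `SequenceTagger` internals:

```python
import torch
from torch import nn


class ToyTagger(nn.Module):
    """Hypothetical stand-in for a sequence tagger: embeddings + tag scorer."""

    def __init__(self, embeddings: nn.Module, embedding_dim: int, tags: list):
        super().__init__()
        self.embeddings = embeddings  # shared, possibly pre-trained
        self.tags = list(tags)        # the tag dictionary
        # Output layer sized to the tag dictionary it was built for
        self.linear = nn.Linear(embedding_dim, len(self.tags))

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len) -> (batch, seq_len, num_tags)
        return self.linear(self.embeddings(token_ids))


old_tags = ["O", "PER", "LOC", "ORG", "MISC"]
new_tags = old_tags + ["CODE", "DATETIME", "DEM", "QUANTITY"]

embedding_dim = 16
old_tagger = ToyTagger(nn.Embedding(100, embedding_dim), embedding_dim, old_tags)

# Analogue of SequenceTagger(old_tagger.embeddings, new_tagdict, label_type):
# reuse the same embedding module, but build a fresh output layer whose
# size matches the enlarged tag dictionary.
new_tagger = ToyTagger(old_tagger.embeddings, embedding_dim, new_tags)

scores = new_tagger(torch.randint(0, 100, (2, 7)))
print(scores.shape)  # torch.Size([2, 7, 9])
```

Because the embedding module is shared rather than copied, training `new_tagger` also updates the embeddings that `old_tagger` still references.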

@helpmefindaname helpmefindaname added the Awaiting Response Waiting for new input from the author label Sep 18, 2023
@mariasierro
Author

It works! Thank you very much :)

@github-actions github-actions bot removed the Awaiting Response Waiting for new input from the author label Sep 18, 2023