Skip to content

Commit

Permalink
Add support for ESM models (#447)
Browse files Browse the repository at this point in the history
* Add support for ESM models

* Add ESM tokenizer conversion methods

* Add special test cases for ESM tokenizer

* add special tokens in conversion script

* Do not save decoder

* Add special tokens tokenizer test

* Join tokens with space if decoder is null

* Treat all tokens as added tokens

* Use `WhitespaceSplit` pretokenizer

* `<eos>` and `<bos>` are not special tokens

* Update more supported ESM models

* Add `--tokenizer_id` to conversion script

* Add supported models comments
  • Loading branch information
xenova authored Dec 13, 2023
1 parent 0d2f05d commit 80d22da
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 14 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **[Donut](https://huggingface.co/docs/transformers/model_doc/donut)** (from NAVER), released together with the paper [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) by Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park.
1. **[DPT](https://huggingface.co/docs/transformers/master/model_doc/dpt)** (from Intel Labs) released with the paper [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) by René Ranftl, Alexey Bochkovskiy, Vladlen Koltun.
1. **[ELECTRA](https://huggingface.co/docs/transformers/model_doc/electra)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning.
1. **[ESM](https://huggingface.co/docs/transformers/model_doc/esm)** (from Meta AI) are transformer protein language models. **ESM-1b** was released with the paper [Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences](https://www.pnas.org/content/118/15/e2016239118) by Alexander Rives, Joshua Meier, Tom Sercu, Siddharth Goyal, Zeming Lin, Jason Liu, Demi Guo, Myle Ott, C. Lawrence Zitnick, Jerry Ma, and Rob Fergus. **ESM-1v** was released with the paper [Language models enable zero-shot prediction of the effects of mutations on protein function](https://doi.org/10.1101/2021.07.09.450648) by Joshua Meier, Roshan Rao, Robert Verkuil, Jason Liu, Tom Sercu and Alexander Rives. **ESM-2 and ESMFold** were released with the paper [Language models of protein sequences at the scale of evolution enable accurate structure prediction](https://doi.org/10.1101/2022.07.20.500902) by Zeming Lin, Halil Akin, Roshan Rao, Brian Hie, Zhongkai Zhu, Wenting Lu, Allan dos Santos Costa, Maryam Fazel-Zarandi, Tom Sercu, Sal Candido, Alexander Rives.
1. **[Falcon](https://huggingface.co/docs/transformers/model_doc/falcon)** (from Technology Innovation Institute) by Almazrouei, Ebtesam and Alobeidli, Hamza and Alshamsi, Abdulaziz and Cappelli, Alessandro and Cojocaru, Ruxandra and Debbah, Merouane and Goffinet, Etienne and Heslow, Daniel and Launay, Julien and Malartic, Quentin and Noune, Badreddine and Pannier, Baptiste and Penedo, Guilherme.
1. **[FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)** (from Google AI) released in the repository [google-research/t5x](https://github.com/google-research/t5x/blob/main/docs/models.md#flan-t5-checkpoints) by Hyung Won Chung, Le Hou, Shayne Longpre, Barret Zoph, Yi Tay, William Fedus, Eric Li, Xuezhi Wang, Mostafa Dehghani, Siddhartha Brahma, Albert Webson, Shixiang Shane Gu, Zhuyun Dai, Mirac Suzgun, Xinyun Chen, Aakanksha Chowdhery, Sharan Narang, Gaurav Mishra, Adams Yu, Vincent Zhao, Yanping Huang, Andrew Dai, Hongkun Yu, Slav Petrov, Ed H. Chi, Jeff Dean, Jacob Devlin, Adam Roberts, Denny Zhou, Quoc V. Le, and Jason Wei
1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim.
Expand Down
1 change: 1 addition & 0 deletions docs/snippets/6_supported-models.snippet
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
1. **[Donut](https://huggingface.co/docs/transformers/model_doc/donut)** (from NAVER), released together with the paper [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) by Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park.
1. **[DPT](https://huggingface.co/docs/transformers/master/model_doc/dpt)** (from Intel Labs) released with the paper [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) by René Ranftl, Alexey Bochkovskiy, Vladlen Koltun.
1. **[ELECTRA](https://huggingface.co/docs/transformers/model_doc/electra)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning.
1. **[ESM](https://huggingface.co/docs/transformers/model_doc/esm)** (from Meta AI) are transformer protein language models. **ESM-1b** was released with the paper [Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences](https://www.pnas.org/content/118/15/e2016239118) by Alexander Rives, Joshua Meier, Tom Sercu, Siddharth Goyal, Zeming Lin, Jason Liu, Demi Guo, Myle Ott, C. Lawrence Zitnick, Jerry Ma, and Rob Fergus. **ESM-1v** was released with the paper [Language models enable zero-shot prediction of the effects of mutations on protein function](https://doi.org/10.1101/2021.07.09.450648) by Joshua Meier, Roshan Rao, Robert Verkuil, Jason Liu, Tom Sercu and Alexander Rives. **ESM-2 and ESMFold** were released with the paper [Language models of protein sequences at the scale of evolution enable accurate structure prediction](https://doi.org/10.1101/2022.07.20.500902) by Zeming Lin, Halil Akin, Roshan Rao, Brian Hie, Zhongkai Zhu, Wenting Lu, Allan dos Santos Costa, Maryam Fazel-Zarandi, Tom Sercu, Sal Candido, Alexander Rives.
1. **[Falcon](https://huggingface.co/docs/transformers/model_doc/falcon)** (from Technology Innovation Institute) by Almazrouei, Ebtesam and Alobeidli, Hamza and Alshamsi, Abdulaziz and Cappelli, Alessandro and Cojocaru, Ruxandra and Debbah, Merouane and Goffinet, Etienne and Heslow, Daniel and Launay, Julien and Malartic, Quentin and Noune, Badreddine and Pannier, Baptiste and Penedo, Guilherme.
1. **[FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)** (from Google AI) released in the repository [google-research/t5x](https://github.com/google-research/t5x/blob/main/docs/models.md#flan-t5-checkpoints) by Hyung Won Chung, Le Hou, Shayne Longpre, Barret Zoph, Yi Tay, William Fedus, Eric Li, Xuezhi Wang, Mostafa Dehghani, Siddhartha Brahma, Albert Webson, Shixiang Shane Gu, Zhuyun Dai, Mirac Suzgun, Xinyun Chen, Aakanksha Chowdhery, Sharan Narang, Gaurav Mishra, Adams Yu, Vincent Zhao, Yanping Huang, Andrew Dai, Hongkun Yu, Slav Petrov, Ed H. Chi, Jeff Dean, Jacob Devlin, Adam Roberts, Denny Zhou, Quoc V. Le, and Jason Wei
1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim.
Expand Down
14 changes: 13 additions & 1 deletion scripts/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ class ConversionArguments:
"help": "Model identifier"
}
)
tokenizer_id: str = field(
default=None,
metadata={
"help": "Tokenizer identifier (if different to `model_id`)"
}
)
quantize: bool = field(
default=False,
metadata={
Expand Down Expand Up @@ -262,6 +268,7 @@ def main():
conv_args, = parser.parse_args_into_dataclasses()

model_id = conv_args.model_id
tokenizer_id = conv_args.tokenizer_id or model_id

output_model_folder = os.path.join(conv_args.output_parent_dir, model_id)

Expand All @@ -274,7 +281,7 @@ def main():
tokenizer = None
try:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

except KeyError:
pass # No Tokenizer
Expand All @@ -300,6 +307,11 @@ def main():
with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
json.dump(tokenizer_json, fp, indent=4)

elif config.model_type == 'esm':
from .extra.esm import generate_fast_tokenizer
fast_tokenizer = generate_fast_tokenizer(tokenizer)
fast_tokenizer.save(os.path.join(output_model_folder, 'tokenizer.json'))

elif config.model_type == 'whisper':
if conv_args.output_attentions:
from .extra.whisper import get_main_export_kwargs
Expand Down
49 changes: 49 additions & 0 deletions scripts/extra/esm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from transformers.convert_slow_tokenizer import Converter
from tokenizers import Tokenizer, pre_tokenizers, processors
from tokenizers.models import WordPiece

class EsmConverter(Converter):
def converted(self) -> Tokenizer:
vocab = self.original_tokenizer.vocab
tokenizer = Tokenizer(WordPiece(vocab, continuing_subword_prefix='', max_input_chars_per_word=int(1e10), unk_token=str(self.original_tokenizer.unk_token)))

tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit()

cls = str(self.original_tokenizer.cls_token)
cls_token_id = self.original_tokenizer.cls_token_id
sep = str(self.original_tokenizer.eos_token) # No sep token in ESM vocabulary
sep_token_id = self.original_tokenizer.eos_token_id

if sep_token_id is None:
tokenizer.post_processor = processors.TemplateProcessing(
single=f"{cls}:0 $A:0",
special_tokens=[
(cls, cls_token_id),
],
)
else:
tokenizer.post_processor = processors.TemplateProcessing(
single=f"{cls}:0 $A:0 {sep}:0",
special_tokens=[
(cls, cls_token_id),
(sep, sep_token_id),
],
)

# For some reason, all tokens are added: none of them are special, but they all need special splitting.
# See https://github.com/huggingface/transformers/blob/df5c5c62ae253055336f5bb0828ca8e3e15ab6bd/src/transformers/models/esm/tokenization_esm.py#L79-L80
special_tokens = []
other_tokens = []
for token, token_id in vocab.items():
if token[0] == '<' and token[-1] == '>' and token_id <= 3:
special_tokens.append(token)
else:
other_tokens.append(token)

tokenizer.add_special_tokens(special_tokens)
tokenizer.add_tokens(other_tokens)
return tokenizer

def generate_fast_tokenizer(tokenizer):
tokenizer.vocab = tokenizer._token_to_id
return EsmConverter(tokenizer).converted()
24 changes: 24 additions & 0 deletions scripts/supported_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,30 @@
'google/electra-base-discriminator',
],
},
'esm': {
# Masked language modelling
'fill-mask': [
# with and without --task feature-extraction
'InstaDeepAI/nucleotide-transformer-500m-human-ref',
'InstaDeepAI/nucleotide-transformer-500m-1000g',

# NOTE: requires --opset 12
'facebook/esm2_t6_8M_UR50D',
'facebook/esm2_t12_35M_UR50D',
'facebook/esm2_t30_150M_UR50D',
'facebook/esm2_t33_650M_UR50D',
],

# Token classification
'token-classification': [
'AmelieSchreiber/esm2_t6_8M_UR50D_rna_binding_site_predictor',
],

# Zero-shot classification
'zero-shot-classification': [
'AmelieSchreiber/esm2_t6_8M_UR50D_sequence_classifier_v1',
],
},
'falcon': {
# Text generation
'text-generation': [
Expand Down
61 changes: 61 additions & 0 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -1892,6 +1892,63 @@ export class DistilBertForMaskedLM extends DistilBertPreTrainedModel {
//////////////////////////////////////////////////


//////////////////////////////////////////////////
// ESM models
export class EsmPreTrainedModel extends PreTrainedModel { }

/**
* The bare ESM Model transformer outputting raw hidden-states without any specific head on top.
*/
export class EsmModel extends EsmPreTrainedModel { }

/**
* ESM Model with a `language modeling` head on top.
*/
export class EsmForMaskedLM extends EsmPreTrainedModel {
/**
* Calls the model on new inputs.
*
* @param {Object} model_inputs The inputs to the model.
* @returns {Promise<MaskedLMOutput>} An object containing the model's output logits for masked language modeling.
*/
async _call(model_inputs) {
return new MaskedLMOutput(await super._call(model_inputs));
}
}

/**
* ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled output)
*/
export class EsmForSequenceClassification extends EsmPreTrainedModel {
/**
* Calls the model on new inputs.
*
* @param {Object} model_inputs The inputs to the model.
* @returns {Promise<SequenceClassifierOutput>} An object containing the model's output logits for sequence classification.
*/
async _call(model_inputs) {
return new SequenceClassifierOutput(await super._call(model_inputs));
}
}

/**
* ESM Model with a token classification head on top (a linear layer on top of the hidden-states output)
* e.g. for Named-Entity-Recognition (NER) tasks.
*/
export class EsmForTokenClassification extends EsmPreTrainedModel {
/**
* Calls the model on new inputs.
*
* @param {Object} model_inputs The inputs to the model.
* @returns {Promise<TokenClassifierOutput>} An object containing the model's output logits for token classification.
*/
async _call(model_inputs) {
return new TokenClassifierOutput(await super._call(model_inputs));
}
}
//////////////////////////////////////////////////


//////////////////////////////////////////////////
// MobileBert models
export class MobileBertPreTrainedModel extends PreTrainedModel { }
Expand Down Expand Up @@ -4539,6 +4596,7 @@ export class PretrainedMixin {
const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
['bert', ['BertModel', BertModel]],
['electra', ['ElectraModel', ElectraModel]],
['esm', ['EsmModel', EsmModel]],
['convbert', ['ConvBertModel', ConvBertModel]],
['camembert', ['CamembertModel', CamembertModel]],
['deberta', ['DebertaModel', DebertaModel]],
Expand Down Expand Up @@ -4622,6 +4680,7 @@ const MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = new Map([
const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([
['bert', ['BertForSequenceClassification', BertForSequenceClassification]],
['electra', ['ElectraForSequenceClassification', ElectraForSequenceClassification]],
['esm', ['EsmForSequenceClassification', EsmForSequenceClassification]],
['convbert', ['ConvBertForSequenceClassification', ConvBertForSequenceClassification]],
['camembert', ['CamembertForSequenceClassification', CamembertForSequenceClassification]],
['deberta', ['DebertaForSequenceClassification', DebertaForSequenceClassification]],
Expand All @@ -4641,6 +4700,7 @@ const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([
const MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = new Map([
['bert', ['BertForTokenClassification', BertForTokenClassification]],
['electra', ['ElectraForTokenClassification', ElectraForTokenClassification]],
['esm', ['EsmForTokenClassification', EsmForTokenClassification]],
['convbert', ['ConvBertForTokenClassification', ConvBertForTokenClassification]],
['camembert', ['CamembertForTokenClassification', CamembertForTokenClassification]],
['deberta', ['DebertaForTokenClassification', DebertaForTokenClassification]],
Expand Down Expand Up @@ -4685,6 +4745,7 @@ const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([
const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([
['bert', ['BertForMaskedLM', BertForMaskedLM]],
['electra', ['ElectraForMaskedLM', ElectraForMaskedLM]],
['esm', ['EsmForMaskedLM', EsmForMaskedLM]],
['convbert', ['ConvBertForMaskedLM', ConvBertForMaskedLM]],
['camembert', ['CamembertForMaskedLM', CamembertForMaskedLM]],
['deberta', ['DebertaForMaskedLM', DebertaForMaskedLM]],
Expand Down
30 changes: 18 additions & 12 deletions src/tokenizers.js
Original file line number Diff line number Diff line change
Expand Up @@ -1617,6 +1617,7 @@ class Decoder extends Callable {
* @throws {Error} If an unknown decoder type is provided.
*/
static fromConfig(config) {
if (config === null) return null;
switch (config.type) {
case 'WordPiece':
return new WordPieceDecoder(config);
Expand Down Expand Up @@ -2216,13 +2217,6 @@ export class PreTrainedTokenizer extends Callable {
// TODO: maybe, allow this to be null; in which case, we use model as decoder too?
this.decoder = Decoder.fromConfig(tokenizerJSON.decoder);


// Another slight hack to add `end_of_word_suffix` (if present) to the decoder
// This is needed for cases where BPE model and ByteLevel decoder are used
// For more information, see https://github.com/xenova/transformers.js/issues/74
// TODO: save this to the decoder when exporting?
this.decoder.end_of_word_suffix = this.model.end_of_word_suffix;

// Add added_tokens to model
this.special_tokens = [];
this.all_special_ids = [];
Expand All @@ -2246,8 +2240,17 @@ export class PreTrainedTokenizer extends Callable {
this.special_tokens.push(...(tokenizerConfig.additional_special_tokens ?? []));
this.special_tokens = [...new Set(this.special_tokens)]; // Remove duplicates

// Slight hack, but it prevents code duplication:
this.decoder.added_tokens = this.added_tokens;
if (this.decoder) {
// Slight hack, but it prevents code duplication:
this.decoder.added_tokens = this.added_tokens;

// Another slight hack to add `end_of_word_suffix` (if present) to the decoder
// This is needed for cases where BPE model and ByteLevel decoder are used
// For more information, see https://github.com/xenova/transformers.js/issues/74
// TODO: save this to the decoder when exporting?
this.decoder.end_of_word_suffix = this.model.end_of_word_suffix;
}


this.added_tokens_regex = this.added_tokens.length > 0 ? new RegExp(
'(' + this.added_tokens.map(escapeRegExp).join('|') + ')'
Expand Down Expand Up @@ -2634,13 +2637,14 @@ export class PreTrainedTokenizer extends Callable {
tokens = tokens.filter(x => !this.special_tokens.includes(x));
}

// If `this.decoder` is null, we just join tokens with a space:
// https://github.com/huggingface/tokenizers/blob/8edec536a737cb04494b454805be16c020abb14f/tokenizers/src/tokenizer/mod.rs#L835
/** @type {string} */
let decoded = this.decoder(tokens);

let decoded = this.decoder ? this.decoder(tokens) : tokens.join(' ');

// Slight hack, but prevents having to pass `skip_special_tokens` to
// each call to `decode`, which would lead to code duplication.
if (this.decoder.end_of_word_suffix) {
if (this.decoder && this.decoder.end_of_word_suffix) {
decoded = decoded.replaceAll(this.decoder.end_of_word_suffix, ' ');
if (skip_special_tokens) {
decoded = decoded.trim();
Expand Down Expand Up @@ -2811,6 +2815,7 @@ export class FalconTokenizer extends PreTrainedTokenizer { }

export class GPTNeoXTokenizer extends PreTrainedTokenizer { }

export class EsmTokenizer extends PreTrainedTokenizer { }

/**
* Helper function to build translation inputs for an `NllbTokenizer` or `M2M100Tokenizer`.
Expand Down Expand Up @@ -3896,6 +3901,7 @@ export class AutoTokenizer {
MPNetTokenizer,
FalconTokenizer,
GPTNeoXTokenizer,
EsmTokenizer,
Wav2Vec2CTCTokenizer,
BlenderbotTokenizer,
BlenderbotSmallTokenizer,
Expand Down
10 changes: 9 additions & 1 deletion tests/generate_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,15 @@
"The Heavenly Llama is said to drink water from the ocean and urinates as it rains.[6] According to " \
"Aymara eschatology, llamas will return to the water springs and lagoons where they come from at the " \
"end of time.[6]"
]
],
"InstaDeepAI/nucleotide-transformer-500m-human-ref": [
# Actual protein sequences
"ATTCCGATTCCGATTCCG",
"ATTTCTCTCTCTCTCTGAGATCGATCGATCGAT",

# Special tokens
"<unk><pad><mask><cls><eos><bos>",
],
},
}

Expand Down

0 comments on commit 80d22da

Please sign in to comment.