forked from huggingface/transformers
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Diff converter v2 (huggingface#30868)
* current working example! * commit regex and result file * update * nit * push the conversion file * oups * roadmap and nits * attempt diffs for 3 files * persimmon * nit * add diff file that is the same as the modeling_llama.py * fix rope nits * updates * updates with converted versions * give some breathing space to the code * delete * update * update * push the actual result * update regex patterns * update regex patterns * fix some issues * fix some issues * fix some issues * updates * updates * updates * updates * updates * revert changes done to llama * updates * update gemma * updates * oups * current state * current state * update * ouiiii * nit * clear diffs * nit * fixup * update * doc 🚀 * 🔥 * for now use gemma * deal with comments * style * handle funtions * deal with assigns * todos * process inheritage * keep decorators? * 🤗 * deal with duplicates * fixup * correctly remove duplicate code * run ruff post script * ruff deals pretty well with imports, let's leave it to him * ah maybe not lol * for now remove all imports from child. * nit * conversion of llama * okay * convert starcoder2 * synch with main * update llama diff * updates * https://docs.astral.sh/ruff/rules/redefined-while-unused/ fixes the imports, bit needs later version of ruff * updates * okay actual state * non zero exit * update! * revert unrelated * remove other diff files * updates * cleanup * update * less diff! * stash * current updates * updates * No need for call * finished fining deps * update * current changes * current state * current state * new status * nit * finally * fixes * nits * order is now expected * use logger info instead of prints * fixup * up * nit * update * nits * update * correct merge * update * update * update * add warning * update caution message * update * better merging strategy * copy class statements :wink * fixups * nits * update * Apply suggestions from code review Co-authored-by: amyeroberts <[email protected]> * nits * smaller header * do cleanup some stuff * even simpler header? * fixup * updates * ruff * update examples * nit * TODO * state * OUUUUUUF * current state * nits * final state * add a readme * fixup * remove diff llama * fix * nit * dummy noy funny * ruff format tests src utils --check * everless diffs * less diffs and fix test * fixes * naming nit? * update converter and add supper example * nits * updated for function signatures * update * update * add converted dummies * autoformat * single target assign fix * fixup * fix some imports * fixes * don't push them * `# noqa: F841` --------- Co-authored-by: amyeroberts <[email protected]>
- Loading branch information
1 parent
372baec
commit 96eb062
Showing
13 changed files
with
1,315 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Using the `diff_converter` linter | ||
|
||
`pip install libcst` is a must! | ||
|
||
# `sh examples/diff-conversion/convert_examples.sh` to get the converted outputs | ||
|
||
The diff converter is a new `linter` specific to `transformers`. It allows us to unpack inheritance in python to convert a modular `diff` file like `diff_gemma.py` into a `single model single file`. | ||
|
||
Examples of possible usage are available in the `examples/diff-conversion`, or `diff_gemma` for a full model usage. | ||
|
||
`python utils/diff_model_converter.py --files_to_parse "/Users/arthurzucker/Work/transformers/examples/diff-conversion/diff_my_new_model2.py"` | ||
|
||
## How it works | ||
We use the `libcst` parser to produce an AST representation of the `diff_xxx.py` file. For any imports that are made from `transformers.models.modeling_xxxx` we parse the source code of that module, and build a class dependency mapping, which allows us to unpack the difference dependencies. | ||
|
||
The code from the `diff` file and the class dependency mapping are "merged" to produce the single model single file. | ||
We use ruff to automatically remove the potential duplicate imports. | ||
|
||
## Why we use libcst instead of the native AST? | ||
AST is super powerful, but it does not keep the `docstring`, `comment` or code formatting. Thus we decided to go with `libcst` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
#!/bin/bash | ||
|
||
# Iterate over each file in the current directory | ||
for file in examples/diff-conversion/diff_*; do | ||
# Check if it's a regular file | ||
if [ -f "$file" ]; then | ||
# Call the Python script with the file name as an argument | ||
python utils/diff_model_converter.py --files_to_parse "$file" | ||
fi | ||
done |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from math import log | ||
from typing import List, Optional, Tuple, Union | ||
|
||
import torch | ||
|
||
from transformers import Cache | ||
from transformers.modeling_outputs import CausalLMOutputWithPast | ||
from transformers.models.llama.modeling_llama import LlamaModel | ||
|
||
|
||
def _pre_process_input(input_ids): | ||
print(log(input_ids)) | ||
return input_ids | ||
|
||
|
||
# example where we need some deps and some functions | ||
class DummyModel(LlamaModel): | ||
def forward( | ||
self, | ||
input_ids: torch.LongTensor = None, | ||
attention_mask: Optional[torch.Tensor] = None, | ||
position_ids: Optional[torch.LongTensor] = None, | ||
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, | ||
inputs_embeds: Optional[torch.FloatTensor] = None, | ||
use_cache: Optional[bool] = None, | ||
output_attentions: Optional[bool] = None, | ||
output_hidden_states: Optional[bool] = None, | ||
return_dict: Optional[bool] = None, | ||
cache_position: Optional[torch.LongTensor] = None, | ||
) -> Union[Tuple, CausalLMOutputWithPast]: | ||
input_ids = _pre_process_input(input_ids) | ||
|
||
return super().forward( | ||
None, | ||
attention_mask, | ||
position_ids, | ||
past_key_values, | ||
inputs_embeds, | ||
use_cache, | ||
output_attentions, | ||
output_hidden_states, | ||
return_dict, | ||
cache_position, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from transformers.models.llama.configuration_llama import LlamaConfig | ||
|
||
|
||
# Example where we only want to only add a new config argument and new arg doc | ||
# here there is no `ARG` so we are gonna take parent doc | ||
class MyNewModelConfig(LlamaConfig): | ||
r""" | ||
mlp_bias (`bool`, *optional*, defaults to `False`) | ||
""" | ||
|
||
def __init__(self, mlp_bias=True, new_param=0, **super_kwargs): | ||
self.mlp_bias = mlp_bias | ||
self.new_param = new_param | ||
super().__init__(self, **super_kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from transformers.models.gemma.modeling_gemma import GemmaForSequenceClassification | ||
from transformers.models.llama.configuration_llama import LlamaConfig | ||
|
||
|
||
# Example where we only want to only modify the docstring | ||
class MyNewModel2Config(LlamaConfig): | ||
r""" | ||
This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma | ||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the | ||
defaults will yield a similar configuration to that of the Gemma-7B. | ||
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b) | ||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the | ||
documentation from [`PretrainedConfig`] for more information. | ||
Args: | ||
vocab_size (`int`, *optional*, defaults to 256000): | ||
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the | ||
`inputs_ids` passed when calling [`GemmaModel`] | ||
```python | ||
>>> from transformers import GemmaModel, GemmaConfig | ||
>>> # Initializing a Gemma gemma-7b style configuration | ||
>>> configuration = GemmaConfig() | ||
>>> # Initializing a model from the gemma-7b style configuration | ||
>>> model = GemmaModel(configuration) | ||
>>> # Accessing the model configuration | ||
>>> configuration = model.config | ||
```""" | ||
|
||
|
||
# Example where alllllll the dependencies are fetched to just copy the entire class | ||
class MyNewModel2ForSequenceClassification(GemmaForSequenceClassification): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Example where we only want to overwrite the defaults of an init | ||
|
||
from transformers.models.gemma.configuration_gemma import GemmaConfig | ||
|
||
|
||
class NewModelConfig(GemmaConfig): | ||
def __init__( | ||
self, | ||
vocab_size=256030, | ||
hidden_size=64, | ||
intermediate_size=90, | ||
num_hidden_layers=28, | ||
num_attention_heads=16, | ||
num_key_value_heads=16, | ||
head_dim=256, | ||
hidden_act="gelu_pytorch_tanh", | ||
hidden_activation=None, | ||
max_position_embeddings=1500, | ||
initializer_range=0.02, | ||
rms_norm_eps=1e-6, | ||
use_cache=True, | ||
pad_token_id=0, | ||
eos_token_id=1, | ||
bos_token_id=2, | ||
tie_word_embeddings=True, | ||
rope_theta=10000.0, | ||
attention_bias=False, | ||
attention_dropout=0.0, | ||
): | ||
super().__init__(self) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from typing import List, Optional, Tuple, Union | ||
|
||
import torch | ||
|
||
from transformers import Cache | ||
from transformers.modeling_outputs import CausalLMOutputWithPast | ||
from transformers.models.llama.modeling_llama import LlamaModel | ||
|
||
|
||
# example where we need some deps and some functions | ||
class SuperModel(LlamaModel): | ||
def forward( | ||
self, | ||
input_ids: torch.LongTensor = None, | ||
attention_mask: Optional[torch.Tensor] = None, | ||
position_ids: Optional[torch.LongTensor] = None, | ||
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, | ||
inputs_embeds: Optional[torch.FloatTensor] = None, | ||
use_cache: Optional[bool] = None, | ||
output_attentions: Optional[bool] = None, | ||
output_hidden_states: Optional[bool] = None, | ||
return_dict: Optional[bool] = None, | ||
cache_position: Optional[torch.LongTensor] = None, | ||
) -> Union[Tuple, CausalLMOutputWithPast]: | ||
out = super().forward( | ||
input_ids, | ||
attention_mask, | ||
position_ids, | ||
past_key_values, | ||
inputs_embeds, | ||
use_cache, | ||
output_attentions, | ||
output_hidden_states, | ||
return_dict, | ||
cache_position, | ||
) | ||
out.logits *= 2**4 | ||
return out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.