feat: support add tokens to tokenizer. #498

Open · wants to merge 6 commits into main
12 changes: 12 additions & 0 deletions trlx/models/modeling_ppo.py
@@ -543,6 +543,18 @@ def __init__(
for parameter in self.parameters():
parameter.requires_grad_(False)

def set_input_embeddings(self, value):
self.embed_tokens = value

def get_input_embeddings(self):
return self.embed_tokens

def get_output_embeddings(self):
return self.lm_head

def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings


class GPTModelBranch(ModelBranch):
def forward( # noqa: max-complexity
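These accessors are what make the frozen branch resizable: `resize_token_embeddings` in transformers works through `get_input_embeddings`/`set_input_embeddings` and `get_output_embeddings`/`set_output_embeddings`. A minimal sketch of that contract, with a toy resize helper standing in for the real transformers implementation (illustrative only, not part of this PR):

```python
import torch.nn as nn


def toy_resize_token_embeddings(model, new_num_tokens: int):
    """Simplified stand-in for transformers' resize logic, for illustration only."""
    old_embed = model.get_input_embeddings()
    new_embed = nn.Embedding(new_num_tokens, old_embed.embedding_dim)
    num_to_copy = min(old_embed.num_embeddings, new_num_tokens)
    # existing token ids keep their vectors; rows for new tokens stay randomly initialized
    new_embed.weight.data[:num_to_copy] = old_embed.weight.data[:num_to_copy]
    model.set_input_embeddings(new_embed)

    old_head = model.get_output_embeddings()
    if old_head is not None:
        new_head = nn.Linear(old_head.in_features, new_num_tokens, bias=False)
        new_head.weight.data[:num_to_copy] = old_head.weight.data[:num_to_copy]
        model.set_output_embeddings(new_head)
    return model
```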
2 changes: 2 additions & 0 deletions trlx/trainer/__init__.py
@@ -41,6 +41,7 @@ def __init__(
logit_mask=None,
stop_sequences=None,
train_mode=False,
additional_tokens=None,
):
self.store: BaseRolloutStore = None
self.config = config
@@ -49,6 +50,7 @@ def __init__(
self.train_mode = train_mode
self.logit_mask = logit_mask
self.stop_sequences = stop_sequences
self.additional_tokens = additional_tokens

def push_to_store(self, data):
self.store.push(data)
4 changes: 4 additions & 0 deletions trlx/trainer/accelerate_base_trainer.py
@@ -68,6 +68,10 @@ def __init__(self, config, **kwargs): # noqa: C901
self.scheduler = self.setup_scheduler()

self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.tokenizer_path)
self.tokenizer.add_tokens(self.additional_tokens)
# resize the model embeddings by default to match the new vocabulary size
self.model.base_model.resize_token_embeddings(len(self.tokenizer))

self.tokenizer.padding_side = config.tokenizer.padding_side
self.tokenizer.truncation_side = config.tokenizer.truncation_side
self.tokenizer.sep_token = "<sep>"
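For context, the add-then-resize sequence above follows the standard transformers recipe for extending a vocabulary; a standalone sketch of that recipe (the model name and token strings are placeholders, not taken from this PR):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# add_tokens returns the number of tokens that were actually new to the vocabulary
num_added = tokenizer.add_tokens(["<sep>", "<critique>"])
if num_added > 0:
    # grow the embedding matrix (and tied LM head) so the new ids have rows
    model.resize_token_embeddings(len(tokenizer))
```

In the trainer the resize is applied to `self.model.base_model`, since the trlx wrappers keep the underlying transformers model under that attribute.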
6 changes: 6 additions & 0 deletions trlx/trainer/accelerate_ppo_trainer.py
@@ -71,9 +71,15 @@ def __init__(self, config: TRLConfig, **kwargs):

# Set up a reference model when hydra heads are not used
if not hasattr(self.model, "frozen_head") and not self.model.peft_type:
# Full Reference Copy
self.ref_model = self.get_arch(self.config)
self.ref_model.base_model.resize_token_embeddings(len(self.tokenizer))
self.ref_model.to(self.accelerator.device)
self.ref_model.eval()
elif hasattr(self.model, "frozen_head") and self.model.frozen_head is not None:
# Hydra Reference: Use the frozen base layers and head as the reference model, resize hydra heads
self.model.frozen_head.resize_token_embeddings(len(self.tokenizer))
# else PEFT Reference

# Set up the KL controller
# This helps prevent large divergences in the controller (policy)
7 changes: 6 additions & 1 deletion trlx/trlx.py
@@ -1,6 +1,6 @@
import os
import warnings
-from typing import Callable, Dict, Iterable, List, Optional, Tuple
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

from trlx.data.configs import TRLConfig
from trlx.data.default_configs import (
@@ -23,6 +23,7 @@ def train( # noqa: C901
metric_fn: Optional[Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]] = None,
config: Optional[TRLConfig] = None,
stop_sequences: Optional[List[str]] = [],
additional_tokens: Optional[Union[str, List[str]]] = None,
):
"""
Dispatches online, offline reinforcement training or supervised finetuning
@@ -54,6 +55,9 @@ def train( # noqa: C901
stop_sequences (Optional[List[str]]):
String sequences to trim generations (both during experience generation and evaluation) up to their
first occurrence. Generations will not contain them and will also be right-stripped
additional_tokens (Optional[Union[str, List[str]]]):
A token or list of tokens to add to the tokenizer. Each token is added only if it does not already
exist in the vocabulary, and is then assigned a new token id
"""
if config is None:
warnings.warn(
@@ -81,6 +85,7 @@ def train( # noqa: C901
reward_fn=reward_fn,
metric_fn=metric_fn,
stop_sequences=stop_sequences,
additional_tokens=additional_tokens,
**config.train.trainer_kwargs,
)

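Taken together, a caller only has to pass the new keyword through `trlx.train`; a usage sketch under default configs (the model path, prompts, reward function, and token strings below are illustrative placeholders):

```python
import trlx


def reward_fn(samples, prompts, outputs, **kwargs):
    # toy reward: prefer longer completions
    return [float(len(output)) for output in outputs]


trainer = trlx.train(
    "gpt2",
    reward_fn=reward_fn,
    prompts=["Tell me a story about", "Summarize the following:"],
    additional_tokens=["<sep>", "<user>", "<assistant>"],
)
```

Tokens already present in the vocabulary are skipped, and the policy, the full reference copy, and the hydra frozen head all end up with embedding matrices sized to `len(tokenizer)`.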