model --> model_name_or_path (huggingface#1452)
* `model` --> `model_name_or_path`

* fix style
lvwerra authored and Andrew Lapp committed May 10, 2024
1 parent a49a076 commit 90eb89d
Showing 3 changed files with 13 additions and 12 deletions.
docs/source/clis.mdx (6 changes: 3 additions & 3 deletions)
@@ -56,7 +56,7 @@ You can pass any of these arguments either to the CLI or the YAML file.
 Follow the basic instructions above and run `trl sft --output_dir <output_dir> <*args>`:
 
 ```bash
-trl sft --config config.yaml --output_dir your-output-dir
+trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir opt-sft-imdb
 ```
 
 The SFT CLI is based on the `examples/scripts/sft.py` script.
@@ -82,7 +82,7 @@ python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
 Once your dataset is pushed, run the dpo CLI as follows:
 
 ```bash
-trl dpo --config config.yaml --output_dir your-output-dir
+trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/Anthropic-hh-rlhf-processed --output_dir opt-sft-hh-rlhf
 ```
 
 The DPO CLI is based on the `examples/scripts/dpo.py` script.
@@ -92,7 +92,7 @@ The DPO CLI is based on the `examples/scripts/dpo.py` script.
 The chat CLI lets you quickly load the model and talk to it. Simply run the following:
 
 ```bash
-trl chat --model Qwen/Qwen1.5-0.5B-Chat
+trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
 ```
 
 Note that the chat interface relies on the chat template of the tokenizer to format the inputs for the model. Make sure your tokenizer has a chat template defined.
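The docs note above that the chat interface depends on the tokenizer's chat template. A minimal sketch to check that outside the CLI (not part of this commit; it assumes only `transformers` is installed and reuses the Qwen checkpoint from the docs):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")

# If the tokenizer shipped no chat template, formatting the conversation
# would fail; this checkpoint defines one, so we get a ready-to-use prompt.
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello!"}],
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)
```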
examples/scripts/chat.py (17 changes: 9 additions & 8 deletions)
@@ -156,7 +156,7 @@ def save_chat(chat, args, filename):
     folder = args.save_folder
 
     if filename is None:
-        filename = create_default_filename(args.model)
+        filename = create_default_filename(args.model_name_or_path)
         filename = os.path.join(folder, filename)
     os.makedirs(os.path.dirname(filename), exist_ok=True)
 
@@ -210,7 +210,9 @@ def parse_settings(user_input, current_args, interface):
     return current_args, True
 
 
-def load_model(args):
+def load_model_and_tokenizer(args):
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
+
     torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype)
     quantization_config = get_quantization_config(args)
     model_kwargs = dict(
@@ -221,12 +223,12 @@ def load_model(args):
         device_map=get_kbit_device_map() if quantization_config is not None else None,
         quantization_config=quantization_config,
     )
-    model = AutoModelForCausalLM.from_pretrained(args.model, **model_kwargs)
+    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs)
 
     if getattr(model, "hf_device_map", None) is None:
         model = model.to(args.device)
-    print(model.device)
-    return model
+
+    return model, tokenizer
 
 
 def chat_cli():
@@ -247,11 +249,10 @@ def chat_cli():
     else:
         user = args.user
 
-    tokenizer = AutoTokenizer.from_pretrained(args.model)
-    model = load_model(args)
+    model, tokenizer = load_model_and_tokenizer(args)
     generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
 
-    interface = RichInterface(model_name=args.model, user_name=user)
+    interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
     interface.clear()
     chat = clear_chat_history(current_args.system_prompt)
     while True:
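The refactor above folds tokenizer loading into the model loader, so both objects are built from the same `model_name_or_path`. A condensed, self-contained sketch of the new pattern (assuming plain `transformers` only, without TRL's quantization and device-map helpers):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model_and_tokenizer(model_name_or_path, device="cpu"):
    # Tokenizer and model come from one checkpoint identifier, mirroring
    # the single source of truth this commit introduces.
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype="auto")
    # Move the model manually only when accelerate has not placed it already.
    if getattr(model, "hf_device_map", None) is None:
        model = model.to(device)
    return model, tokenizer

model, tokenizer = load_model_and_tokenizer("Qwen/Qwen1.5-0.5B-Chat")
```

Returning the pair keeps call sites like `chat_cli` to a single line and removes the duplicated `args.model` lookup the old code had.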
trl/commands/cli_utils.py (2 changes: 1 addition & 1 deletion)
@@ -174,7 +174,7 @@ class DpoScriptArguments:
 @dataclass
 class ChatArguments:
     # general settings
-    model: str = field(metadata={"help": "Name of the pre-trained model"})
+    model_name_or_path: str = field(metadata={"help": "Name of the pre-trained model"})
     user: str = field(default=None, metadata={"help": "Username to display in chat interface"})
     system_prompt: str = field(default=None, metadata={"help": "System prompt"})
     save_folder: str = field(default="./chat_history/", metadata={"help": "Folder to save chat history"})
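Because TRL's CLI arguments are dataclasses parsed in the `HfArgumentParser` style, renaming the field is what renames the user-facing flag. A small sketch of that mechanism (a trimmed-down, hypothetical `ChatArguments`, not TRL's full class):

```python
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class ChatArguments:
    model_name_or_path: str = field(metadata={"help": "Name of the pre-trained model"})
    user: str = field(default=None, metadata={"help": "Username to display in chat interface"})

parser = HfArgumentParser(ChatArguments)
# The field name becomes the flag: `--model` no longer exists and
# `--model_name_or_path` takes its place.
(args,) = parser.parse_args_into_dataclasses(["--model_name_or_path", "Qwen/Qwen1.5-0.5B-Chat"])
print(args.model_name_or_path)  # Qwen/Qwen1.5-0.5B-Chat
```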
