Skip to content

Commit

Permalink
add 'right' option for 'truncation_strategy'
Browse files Browse the repository at this point in the history
  • Loading branch information
zsxm1998 committed Dec 24, 2024
1 parent 00c2eaa commit 7a5be2a
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 10 deletions.
2 changes: 1 addition & 1 deletion docs/source/Instruction/命令行参数.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
- 🔥template: 对话模板类型,默认使用model对应的template类型。`swift pt`会将对话模版转为生成模板使用
- 🔥system: 自定义system字段,默认为None,使用template的默认system
- 🔥max_length: 单样本的tokens最大长度,默认为None,不做限制
- truncation_strategy: 如果超长如何处理,支持`delete``left`,代表删除和左侧裁剪,默认为'delete'
- truncation_strategy: 如果超长如何处理,支持`delete`, `left``right`,代表删除、左侧裁剪和右侧裁剪,默认为'delete'
- 🔥max_pixels: 多模态模型图片前处理的最大像素数(H\*W),默认不缩放。
- tools_prompt: 智能体训练时的工具列表转为system的格式,请参考[智能体训练](./智能体的支持.md),默认为'react_en'
- loss_scale: 如何针对训练添加token的loss权重。默认为`'default'`,代表所有response(含history)以1计算交叉熵损失。具体可以查看[插件化](../Customization/插件化.md)[智能体训练](./智能体的支持.md)
Expand Down
2 changes: 1 addition & 1 deletion docs/source_en/Instruction/Command-line-parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ The introduction to command line parameters will cover base arguments, atomic ar
- 🔥template: Type of dialogue template, which defaults to the template type corresponding to the model. `swift pt` will convert the dialogue template into a generation template for use.
- 🔥system: Custom system field, default is None, uses the default system of the template.
- 🔥max_length: Maximum length of tokens for a single sample, default is None (no limit).
- truncation_strategy: How to handle overly long tokens, supports `delete` and `left`, representing deletion and left trimming, default is 'delete'.
- truncation_strategy: How to handle overly long tokens, supports `delete`, `left`, `right`, representing deletion, left trimming, and right trimming, default is 'delete'.
- 🔥max_pixels: Maximum pixel count for pre-processing images in multimodal models (H*W), default is no scaling.
- tools_prompt: The list of tools for agent training converted to system format, refer to [Agent Training](./Agent-support.md), default is 'react_en'.
- loss_scale: How to add token loss weight during training. Default is `'default'`, meaning all responses (including history) are treated as 1 for cross-entropy loss. For specifics, see [Pluginization](../Customization/Pluginization.md) and [Agent Training](./Agent-support.md).
Expand Down
2 changes: 1 addition & 1 deletion swift/llm/argument/base_args/template_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class TemplateArguments:
system: Optional[str] = None # Override the default_system in the template.
max_length: Optional[int] = None

truncation_strategy: Literal['delete', 'left'] = 'delete'
truncation_strategy: Literal['delete', 'left', 'right'] = 'delete'
max_pixels: Optional[int] = None
tools_prompt: str = 'react_en' # Override the default_tools_prompt in the template.
# train
Expand Down
19 changes: 13 additions & 6 deletions swift/llm/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
*,
use_chat_template: bool = True,
template_backend: Literal['swift', 'jinja'] = 'swift',
truncation_strategy: Literal['raise', 'left'] = 'raise',
truncation_strategy: Literal['raise', 'left', 'right'] = 'raise',
max_pixels: Optional[int] = None,
tools_prompt: Optional[str] = None,
# only for train
Expand Down Expand Up @@ -630,11 +630,18 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
if self.truncation_strategy == 'raise' and len(input_ids) > self.max_length:
raise MaxLengthError(f'Current length of row({len(input_ids)}) is larger'
f' than the max_length({self.max_length}).')
input_ids = input_ids[-self.max_length:]
if labels is not None:
labels = labels[-self.max_length:]
if loss_scale is not None:
loss_scale = loss_scale[-self.max_length:]
elif self.truncation_strategy == 'right':
input_ids = input_ids[:self.max_length]
if labels is not None:
labels = labels[:self.max_length]
if loss_scale is not None:
loss_scale = loss_scale[:self.max_length]
else:
input_ids = input_ids[-self.max_length:]
if labels is not None:
labels = labels[-self.max_length:]
if loss_scale is not None:
loss_scale = loss_scale[-self.max_length:]

encoded['input_ids'] = input_ids
encoded['labels'] = labels
Expand Down
2 changes: 1 addition & 1 deletion swift/llm/template/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_template(
*,
use_chat_template: bool = True,
template_backend: Literal['swift', 'jinja'] = 'swift',
truncation_strategy: Literal['raise', 'left'] = 'raise',
truncation_strategy: Literal['raise', 'left', 'right'] = 'raise',
max_pixels: Optional[int] = None, # h * w
tools_prompt: str = 'react_en',
# train
Expand Down

0 comments on commit 7a5be2a

Please sign in to comment.