In script/main.py:

    class DataCollatorForLMDataset(object):
        def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
            input_ids, labels = tuple([instance[key].unsqueeze(0) for instance in instances] for key in ("input_ids", "labels"))
            input_ids = torch.cat(input_ids, dim=0)
            labels = torch.cat(labels, dim=0)
            eos_indices = input_ids.argmin(dim=1) - 1
            max_position = eos_indices.max()
            if max_position < 0:
                return dict(
                    input_ids=input_ids,
                    labels=labels
                )
            return dict(
                input_ids=input_ids[:, :max_position+1],
                labels=labels[:, :max_position+1]
            )
Why does this code use "eos_indices = input_ids.argmin(dim=1) - 1",
while sort_and_group.py uses "eos_indice = (input_id == EOS_ID).int().argmax().item()"?
These two lines serve the same purpose and can be unified as:
eos_indice = (input_id == EOS_ID).int().argmax().item()
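As a rough illustration of why the two expressions can agree, here is a minimal sketch on a toy padded batch. The values `PAD_ID = 0` and `EOS_ID = 2`, the helper names `eos_by_argmin` and `eos_by_match`, and the sample batch are all assumptions made for this example; the repository's real token ids are not shown in this thread. The sketch assumes the padding id is the smallest id in each row and that padding starts right after EOS; with other id assignments or unpadded rows the two results may not coincide.

```python
import torch

# Hypothetical token ids, chosen for illustration only; the repository's
# actual values are not shown in this thread.
PAD_ID = 0
EOS_ID = 2


def eos_by_argmin(input_ids: torch.Tensor) -> torch.Tensor:
    """Index just before the first smallest token id in each row.

    This equals the EOS index only when PAD_ID is the smallest id in the
    row and padding starts immediately after EOS (assumption).
    """
    return input_ids.argmin(dim=1) - 1


def eos_by_match(input_ids: torch.Tensor) -> torch.Tensor:
    """Index of the first EOS_ID in each row (0 if the row has no EOS)."""
    return (input_ids == EOS_ID).int().argmax(dim=1)


# Toy batch: each row ends with EOS followed by padding.
batch = torch.tensor([
    [5, 6, 7, EOS_ID, PAD_ID, PAD_ID],
    [8, 9, EOS_ID, PAD_ID, PAD_ID, PAD_ID],
])

padded_argmin = eos_by_argmin(batch)
padded_match = eos_by_match(batch)
print(padded_argmin.tolist(), padded_match.tolist())
```

The explicit comparison against `EOS_ID` states its intent directly and does not depend on how padding ids are ordered, which is one reason to prefer it when unifying the two code paths.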
Thanks for pointing this out. I will update the code in the next couple of days.