Packing in DPOTrainer #2469

Open
zhc7 opened this issue Dec 13, 2024 · 2 comments
Labels: 🏋 DPO (Related to DPO), ✨ enhancement (New feature or request)

Comments

zhc7 (Contributor) commented Dec 13, 2024

Feature request

Packing could be supported in `DPOTrainer`.

Motivation

Padding sequences of highly variable length into the same batch wastes a lot of compute. Issue #1274 mentioned this, but I don't think its conclusion is correct: packing does not conflict with pairwise datasets. You just need to unpack the sequences after the forward pass. The boundaries are easy to find: each sequence starts where `position_ids` resets to 0 and ends where the next one starts (or at the end of the packed row).
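The unpacking step described above can be sketched in plain Python. This is a minimal illustration, not TRL code: `segment_bounds` and `per_sequence_sums` are hypothetical helper names, and the sketch assumes a single packed row whose `position_ids` start at 0 and restart at 0 for every sequence.

```python
def segment_bounds(position_ids):
    """Return (start, end) pairs, one per sequence in a single packed row.

    A sequence starts wherever position_ids == 0 and ends where the next
    sequence starts, or at the end of the row for the last one.
    """
    starts = [i for i, p in enumerate(position_ids) if p == 0]
    ends = starts[1:] + [len(position_ids)]
    return list(zip(starts, ends))


def per_sequence_sums(values, position_ids):
    """Sum per-token values (e.g. masked log-probs) within each packed sequence."""
    return [sum(values[s:e]) for s, e in segment_bounds(position_ids)]


# Three sequences of lengths 3, 2 and 4 packed into one row.
position_ids = [0, 1, 2, 0, 1, 0, 1, 2, 3]
print(segment_bounds(position_ids))  # [(0, 3), (3, 5), (5, 9)]

# Per-token log-probs, already zeroed where the loss mask is off.
logps = [-1.0, -0.5, 0.0, -2.0, 0.0, -0.25, -0.25, -0.5, 0.0]
print(per_sequence_sums(logps, position_ids))  # [-1.5, -2.0, -1.0]
```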

Your contribution

Actually, I already have a version of the DPO trainer that handles packing:

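    import torch
    from torch import nn


    # Trainer methods. batch["chosen"] and batch["rejected"] are each ONE packed row
    # (batch size 1). `targets` are assumed to be next-token labels already aligned
    # with `logits`; `loss_masks` is assumed to be 0 on prompt tokens and at each
    # sequence's final position (whose target would belong to the next sequence).

    def _forward_one(self, model: nn.Module, batch: dict, name: str):
        batch = dict(batch)
        targets = batch.pop("targets")
        loss_masks = batch.pop("loss_masks").to(torch.bool)
        out = model(**batch)
        logits = out["logits"]
        targets = targets.clone()
        targets[~loss_masks] = 0  # dummy token; we'll ignore the losses on these tokens later
        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=targets.unsqueeze(2)).squeeze(2)
        per_token_logps[~loss_masks] = 0
        # unpack logp: every position where position_ids == 0 starts a new sequence
        assert batch["position_ids"].shape[0] == 1
        starts = (batch["position_ids"] == 0).nonzero()[:, 1]
        seqs = []
        for i in range(len(starts) - 1):
            seqs.append(per_token_logps[0, starts[i]: starts[i + 1]].sum())
        seqs.append(per_token_logps[0, starts[-1]:].sum())
        all_logps = torch.stack(seqs)
        return {
            f"{name}_logps": all_logps,
            f"mean_{name}_logits": logits[loss_masks].mean(),
        }

    def concatenated_forward(self, model: nn.Module, batch: dict):
        # chosen and rejected are packed into separate rows, so index i of
        # chosen_logps and rejected_logps refers to the same preference pair
        ret = {}
        for name in ("chosen", "rejected"):
            ret = {**ret, **self._forward_one(model, batch[name], name)}
        return ret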

This does not fit directly into the more complex logic of the current DPO trainer, but it shows that packing is possible.
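Once `chosen_logps` and `rejected_logps` have been unpacked, they only need to line up by index: element i of each belongs to preference pair i. As an illustration, assuming TRL's default `sigmoid` loss type and reference log-probs computed the same way with a frozen reference model, the loss could be sketched in plain Python as follows (`dpo_sigmoid_loss` is a hypothetical name, not a TRL function):

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def dpo_sigmoid_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Mean sigmoid DPO loss; index i of every list refers to the same preference pair."""
    n = len(policy_chosen)
    if n == 0 or not (n == len(policy_rejected) == len(ref_chosen) == len(ref_rejected)):
        raise ValueError("need the same, non-zero number of chosen and rejected sequences")
    total = 0.0
    for pc, pr, rc, rr in zip(policy_chosen, policy_rejected, ref_chosen, ref_rejected):
        # Chosen log-ratio minus rejected log-ratio (policy vs. reference).
        margin = (pc - rc) - (pr - rr)
        # -log(sigmoid(beta * margin)) == softplus(-beta * margin)
        total += softplus(-beta * margin)
    return total / n


# When the policy equals the reference, every margin is 0 and the loss is log(2).
print(dpo_sigmoid_loss([-5.0, -7.0], [-6.0, -8.0], [-5.0, -7.0], [-6.0, -8.0]))
```

The length check matters for packing: if the unpacked chosen and rejected rows ever held different numbers of sequences, the pairs would silently shift out of alignment.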

qgallouedec (Member) commented Dec 13, 2024

What you're describing sounds closer to padding-free than packing. We have a (currently draft) PR for this: #2437.
Can you confirm that this is what you're describing?


At this point, I'm not even sure that packing makes sense for DPO. How do you ensure that there are as many chosen as rejected sequences? How do you ensure they match up as pairs? How do you handle partial sequences?
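These concerns are specific to fixed-length packing. To illustrate, here is a naive scheme (the hypothetical `pack_fixed_length`, not TRL's implementation) that concatenates every token into one stream and cuts it into chunks of `max_length`; each token is labeled with the sequence it came from:

```python
def pack_fixed_length(sequences, max_length):
    """Concatenate all sequences into one stream and cut it into max_length chunks.

    No tokens are dropped, but sequence boundaries are ignored, so a sequence
    may be split across chunks and a chunk may mix chosen and rejected tokens.
    """
    stream = [tok for seq in sequences for tok in seq]
    return [stream[i:i + max_length] for i in range(0, len(stream), max_length)]


# Each token is labeled with its source sequence (c = chosen, r = rejected).
chosen = [["c0"] * 3, ["c1"] * 4]
rejected = [["r0"] * 5, ["r1"] * 2]
for chunk in pack_fixed_length(chosen + rejected, max_length=4):
    print(chunk)
# ['c0', 'c0', 'c0', 'c1']
# ['c1', 'c1', 'c1', 'r0']
# ['r0', 'r0', 'r0', 'r0']
# ['r1', 'r1']
```

Here `c1` and `r0` are each split across two chunks, and the second chunk holds part of a chosen and part of a rejected sequence, so per-chunk pairing breaks down.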

qgallouedec added the ✨ enhancement (New feature or request) and 🏋 DPO (Related to DPO) labels on Dec 13, 2024
zhc7 (Contributor, Author) commented Dec 13, 2024

Hi, thank you for your response. I looked into the link you provided, and I think we are talking about the same thing. I took the word "packing" from https://huggingface.co/blog/packing-with-FA2. "Packing" there means concatenating a fixed number of samples (one batch) into a single sequence and using position_ids to mark the boundaries, rather than packing to a fixed length, so the problems you mentioned don't arise. I've also briefly read https://huggingface.co/blog/mayank-mishra/padding-free-transformer, and I think the ideas are the same, although I'm not sure how the latter is implemented. Maybe they are the same thing under different names :)

I briefly went through the PR, and I see it adds position_ids throughout the whole pipeline, so I believe we are talking about the same thing.
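The padding-free variant described in this comment can be sketched as follows. This is a minimal illustration, not the code from the linked blog post or PR: the hypothetical `pack_padding_free` concatenates one whole batch into a single row and restarts `position_ids` at 0 for each sample. Chosen and rejected are packed into separate rows, so sample i in one row still pairs with sample i in the other, and no sequence is ever split. It assumes the attention implementation uses these `position_ids` to stop samples from attending to each other, as the FA2 blog post describes.

```python
def pack_padding_free(sequences):
    """Concatenate a batch of token-id lists into one row, without padding.

    position_ids restart at 0 for every sample, which marks the boundaries.
    """
    input_ids, position_ids = [], []
    for seq in sequences:
        input_ids.extend(seq)
        position_ids.extend(range(len(seq)))
    # Outer list is a batch dimension of 1: shape [1, total_tokens].
    return {"input_ids": [input_ids], "position_ids": [position_ids]}


chosen = [[11, 12, 13], [21, 22]]
rejected = [[14, 15], [23, 24, 25, 26]]
print(pack_padding_free(chosen)["position_ids"])    # [[0, 1, 2, 0, 1]]
print(pack_padding_free(rejected)["position_ids"])  # [[0, 1, 0, 1, 2, 3]]
```

Padding each group to its longest sample would take 6 and 8 token slots here, versus 5 and 6 when packed; the saving grows with the spread of sequence lengths.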
