Important remark. Potential inconsistency between source code and paper: DLow's diversity loss. #38

PFery4 commented Nov 22, 2023

Good day to everyone involved in developing or studying the AgentFormer model,

I would like to make a small remark regarding DLow's diversity sampling loss: there may be a discrepancy between the loss as defined in the paper and the loss as implemented in the original source code.

Indeed, the paper defines the diversity component of the loss as:
$$\mathcal{L}_{\mathrm{div}} = \frac{1}{K(K-1)} \sum_{k_1=1}^{K} \sum_{k_2 \neq k_1} \exp\left( -\frac{1}{\sigma_d} \sum_{t=1}^{T_{\mathrm{pred}}} \left\lVert \hat{Y}_t^{(k_1)} - \hat{Y}_t^{(k_2)} \right\rVert_2 \right)$$
Here we can see that each pair of predictions among the K modes is compared by summing, over the prediction horizon, the Euclidean distance between the corresponding predicted points at each timestep.

However, in the source code, the implementation collapses the x and y components of the prediction into a single dimension. As a result, F.pdist() does not compute a sum of per-timestep distances, but rather the L2 distance between two points in a high-dimensional space of shape [T_pred * 2] (a distance which the code then squares).

Here's a minimal code snippet that highlights the difference between the loss as implemented in the source code and the loss as defined in the paper:

    import torch
    import torch.nn.functional as F

    # scaling factor (sigma_d in the loss above)
    d_scale = 10

    # example predictions
    pred_1 = torch.Tensor([[1, 1],
                           [2, 1],
                           [3, 1],
                           [3, 2],
                           [3, 3],
                           [3, 4]])
    pred_2 = torch.Tensor([[1, 1],
                           [1, 1.5],
                           [1, 2],
                           [2, 2],
                           [3, 3],
                           [4, 3]])
    pred_3 = torch.zeros_like(pred_1)

    # predictions are of shape [N agents, K samples, P prediction length, 2]
    preds = torch.stack([pred_1, pred_2, pred_3]).unsqueeze(0)

    # the repo's diversity_loss reshapes the predictions tensor, collapsing the x and y components
    reshaped_preds = preds.view(*preds.shape[:2], -1)       # [N agents, K samples, P prediction length * 2]

    code_loss = 0
    for motion in reshaped_preds:
        # motion: [K, P * 2]
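        # squared pairwise L2 distances between flattened [P * 2]-dim trajectories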
        their_dist = F.pdist(motion, 2) ** 2
        code_loss += (-their_dist / d_scale).exp().mean()
    print(f"{code_loss=}")

    paper_dists = []
    for motion in preds:
        for k1, sample_1 in enumerate(motion):
            for sample_2 in motion[k1 + 1:]:
                # sample_1, sample_2 --> [P, 2]

                # difference between any two non-identical predictions
                diff = sample_1 - sample_2

                # sum of euclidean distance between points of diff over each timestep
                se = diff.pow(2).sum(-1).sqrt().sum()

                paper_dists.append(se)

    paper_dists = torch.stack(paper_dists)

    paper_loss = (-paper_dists / d_scale).exp().mean()

    print(f"{paper_loss=}")

Note that the loss value as defined in the paper differs from the value produced by the source-code implementation.

I would also like to mention, however, that both versions of the loss do encourage diversity among predictions; they differ in how they 'push' predictions away from each other. The code version penalizes the squared L2 distance between flattened trajectories, while the paper's version uses an unsquared sum of per-timestep distances, so the gradients they produce are not the same.

It might be nice (for whoever is interested in studying this further) to implement an efficient version of the computation used to obtain paper_loss in the snippet above, and to check whether it ends up altering the behaviour of DLow; a possible starting point is sketched below.
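
For whoever wants that starting point, here is a minimal vectorized sketch of the computation (the name paper_diversity_loss is mine; it assumes a predictions tensor shaped [N agents, K samples, P prediction length, 2], as in the snippet above):

    import torch

    def paper_diversity_loss(preds, d_scale=10.0):
        # preds: [N agents, K samples, P prediction length, 2]
        K = preds.shape[1]

        # pairwise differences between all pairs of samples: [N, K, K, P, 2]
        diff = preds.unsqueeze(2) - preds.unsqueeze(1)

        # Euclidean distance at each timestep, summed over the horizon: [N, K, K]
        dist = diff.norm(dim=-1).sum(dim=-1)

        # keep only the K * (K - 1) / 2 unordered pairs per agent
        iu = torch.triu_indices(K, K, offset=1)
        pair_dists = dist[:, iu[0], iu[1]]

        return (-pair_dists / d_scale).exp().mean()

On the toy example above, paper_diversity_loss(preds) should reproduce the paper_loss value. One caveat: .norm() has an undefined gradient at exactly zero distance, so if two samples can coincide during training, a small epsilon may be needed in practice.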

I do not expect a major change in the way DLow operates, but I leave this remark here for whoever might want to study the DLow module in more detail.
