Skip to content

Commit

Permalink
Fix implicit expansion when computing head logit mask
Browse files Browse the repository at this point in the history
The implicit expansion emitted a Torch warning in each call.
  • Loading branch information
danieldk committed Oct 16, 2023
1 parent 5dcd024 commit 4247f35
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions syntaxdot-cli/src/subcommands/distill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ impl DistillApp {
student_logits: &BiaffineScoreLogits,
token_mask: &TokenMask,
) -> Result<Tensor> {
let (_, seq_len) = token_mask.size2()?;

// Compute teacher probabilities.
let teacher_head_probs = teacher_logits
.head_score_logits
Expand All @@ -199,6 +201,7 @@ impl DistillApp {
let probs_mask = token_mask
.with_root()?
.unsqueeze(1)
.f_expand([-1, seq_len, -1], true)?
.logical_and(&token_mask.unsqueeze(-1));
let teacher_head_probs = teacher_head_probs.masked_select(&probs_mask);

Expand Down

0 comments on commit 4247f35

Please sign in to comment.