diff --git a/syntaxdot-cli/src/subcommands/distill.rs b/syntaxdot-cli/src/subcommands/distill.rs index 8ca4841..6ac519a 100644 --- a/syntaxdot-cli/src/subcommands/distill.rs +++ b/syntaxdot-cli/src/subcommands/distill.rs @@ -191,6 +191,8 @@ impl DistillApp { student_logits: &BiaffineScoreLogits, token_mask: &TokenMask, ) -> Result { + let (_, seq_len) = token_mask.size2()?; + // Compute teacher probabilities. let teacher_head_probs = teacher_logits .head_score_logits @@ -199,6 +201,7 @@ impl DistillApp { let probs_mask = token_mask .with_root()? .unsqueeze(1) + .f_expand([-1, seq_len, -1], true)? .logical_and(&token_mask.unsqueeze(-1)); let teacher_head_probs = teacher_head_probs.masked_select(&probs_mask);