From 4247f357022452455770581468ca853bfa1c0dc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Mon, 16 Oct 2023 13:05:56 +0200 Subject: [PATCH] Fix implicit expansion when computing head logit mask The implicit expansion emitted a Torch warning in each call. --- syntaxdot-cli/src/subcommands/distill.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/syntaxdot-cli/src/subcommands/distill.rs b/syntaxdot-cli/src/subcommands/distill.rs index 8ca4841..6ac519a 100644 --- a/syntaxdot-cli/src/subcommands/distill.rs +++ b/syntaxdot-cli/src/subcommands/distill.rs @@ -191,6 +191,8 @@ impl DistillApp { student_logits: &BiaffineScoreLogits, token_mask: &TokenMask, ) -> Result { + let (_, seq_len) = token_mask.size2()?; + // Compute teacher probabilities. let teacher_head_probs = teacher_logits .head_score_logits @@ -199,6 +201,7 @@ impl DistillApp { let probs_mask = token_mask .with_root()? .unsqueeze(1) + .f_expand([-1, seq_len, -1], true)? .logical_and(&token_mask.unsqueeze(-1)); let teacher_head_probs = teacher_head_probs.masked_select(&probs_mask);