diff --git a/rust/moshi-core/src/seanet.rs b/rust/moshi-core/src/seanet.rs index e947002..6dae77e 100644 --- a/rust/moshi-core/src/seanet.rs +++ b/rust/moshi-core/src/seanet.rs @@ -117,6 +117,7 @@ impl Module for SeaNetResnetBlock { impl StreamingModule for SeaNetResnetBlock { fn reset_state(&mut self) { + // TODO(laurent): self.skip_op should probably be resetted here. for block in self.block.iter_mut() { block.reset_state() } @@ -133,7 +134,10 @@ impl StreamingModule for SeaNetResnetBlock { } match self.shortcut.as_ref() { None => self.skip_op.step(&ys, xs), - Some(shortcut) => self.skip_op.step(&ys, &xs.apply(shortcut)?), + Some(shortcut) => { + // TODO(laurent): shouldn't this use shortcut.step(xs) instead? + self.skip_op.step(&ys, &xs.apply(shortcut)?) + } } } }