diff --git a/src/models/transformer.h b/src/models/transformer.h index 2d9ced33c..404b89254 100755 --- a/src/models/transformer.h +++ b/src/models/transformer.h @@ -281,7 +281,7 @@ class Transformer : public EncoderOrDecoderBase { // memoization propagation (short-term) if (cache // if caching && cache_.count(prefix + "_keys") > 0 // and the keys expression has been seen - && cache_[prefix + "_keys"]->shape().elements() == keys->shape().elements()) { // and the underlying element size did not change + && cache_[prefix + "_keys"]->shape() == keys->shape()) { // and the underlying shape did not change kh = cache_[prefix + "_keys"]; // then return cached tensor } else { @@ -296,7 +296,7 @@ class Transformer : public EncoderOrDecoderBase { Expr vh; if (cache && cache_.count(prefix + "_values") > 0 - && cache_[prefix + "_values"]->shape().elements() == values->shape().elements()) { + && cache_[prefix + "_values"]->shape() == values->shape()) { vh = cache_[prefix + "_values"]; } else { auto Wv = graph_->param(prefix + "_Wv", {dimModel, dimModel}, inits::glorotUniform());