From 0d4129f9abc636685e934ac4b34236e3a9324371 Mon Sep 17 00:00:00 2001 From: Ashutosh Singh <40604544+ashutoshsingh0223@users.noreply.github.com> Date: Thu, 12 Sep 2019 22:00:21 +0530 Subject: [PATCH] Enforcing dtype in `K.zeros_like` for mask Enforcing dtype in `K.zeros_like` for mask in `recursion` function. `K.zeros_like` failed to infer dtype of `mask` when a Masking layer is used. --- keras_contrib/layers/crf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_contrib/layers/crf.py b/keras_contrib/layers/crf.py index 88a64ac69..e48db76a1 100644 --- a/keras_contrib/layers/crf.py +++ b/keras_contrib/layers/crf.py @@ -513,7 +513,7 @@ def recursion(self, input_energy, mask=None, go_backwards=False, constants = [chain_energy] if mask is not None: - mask2 = K.cast(K.concatenate([mask, K.zeros_like(mask[:, :1])], axis=1), + mask2 = K.cast(K.concatenate([mask, K.cast(K.zeros_like(mask[:, :1]), mask.dtype)], axis=1), K.floatx()) constants.append(mask2)