Skip to content

Commit

Permalink
Enforcing dtype in K.zeros_like for mask
Browse files Browse the repository at this point in the history
Enforcing dtype in `K.zeros_like` for mask in `recursion` function. `K.zeros_like` failed to infer dtype of `mask` when a Masking layer is used.
  • Loading branch information
ashutoshsingh0223 authored Sep 12, 2019
1 parent 5ffab17 commit 0d4129f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion keras_contrib/layers/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ def recursion(self, input_energy, mask=None, go_backwards=False,
constants = [chain_energy]

if mask is not None:
mask2 = K.cast(K.concatenate([mask, K.zeros_like(mask[:, :1])], axis=1),
mask2 = K.cast(K.concatenate([mask, K.cast(K.zeros_like(mask[:, :1]), mask.dtype)], axis=1),
K.floatx())
constants.append(mask2)

Expand Down

0 comments on commit 0d4129f

Please sign in to comment.