Skip to content

Commit

Permalink
Merge pull request #1 from ashutoshsingh0223/ashutoshsingh0223-patch-1
Browse files Browse the repository at this point in the history
Enforcing dtype in `K.zeros_like` for mask
  • Loading branch information
ashutoshsingh0223 authored Sep 12, 2019
2 parents 5ffab17 + 0d4129f commit 239c8bb
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion keras_contrib/layers/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ def recursion(self, input_energy, mask=None, go_backwards=False,
constants = [chain_energy]

if mask is not None:
mask2 = K.cast(K.concatenate([mask, K.zeros_like(mask[:, :1])], axis=1),
mask2 = K.cast(K.concatenate([mask, K.cast(K.zeros_like(mask[:, :1]), mask.dtype)], axis=1),
K.floatx())
constants.append(mask2)

Expand Down

0 comments on commit 239c8bb

Please sign in to comment.