Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize sequence labeler and allow re-use embeddings for labeling #798

Merged
merged 6 commits into from
Mar 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions neuralmonkey/config/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ def create(self) -> Any:
if exc.name == module_name: # type: ignore
raise Exception(
"Cannot import module {}.".format(module_name))
else:
raise
raise

try:
clazz = getattr(module, class_name)
Expand Down
13 changes: 10 additions & 3 deletions neuralmonkey/decoders/autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,16 +288,21 @@ def train_xents(self) -> tf.Tensor:
tf.one_hot(labels, len(self.vocabulary)),
logits, label_smoothing=self.label_smoothing))

# Return losses of shape (batch, time). Losses on invalid positions
# are zero.
return tf.contrib.seq2seq.sequence_loss(
tf.transpose(self.train_logits, perm=[1, 0, 2]),
train_targets,
tf.transpose(self.train_mask),
average_across_batch=False,
average_across_timesteps=False,
softmax_loss_function=softmax_function)

@tensor
def train_loss(self) -> tf.Tensor:
return tf.reduce_mean(self.train_xents)
# Cross entropy mean over all words in the batch
# (could also be done as a mean over sentences)
return tf.reduce_sum(self.train_xents) / tf.reduce_sum(self.train_mask)

@property
def cost(self) -> tf.Tensor:
Expand Down Expand Up @@ -344,11 +349,13 @@ def runtime_xents(self) -> tf.Tensor:
logits=batch_major_logits[:, :min_time],
targets=train_targets[:, :min_time],
weights=tf.transpose(self.train_mask)[:, :min_time],
average_across_batch=False)
average_across_batch=False,
average_across_timesteps=False)

@tensor
def runtime_loss(self) -> tf.Tensor:
return tf.reduce_mean(self.runtime_xents)
return (tf.reduce_sum(self.runtime_xents)
/ tf.reduce_sum(tf.to_float(self.runtime_mask)))

@tensor
def runtime_logprobs(self) -> tf.Tensor:
Expand Down
182 changes: 123 additions & 59 deletions neuralmonkey/decoders/sequence_labeler.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,34 @@
from typing import Dict, Union
from typing import List, Dict, Callable

import tensorflow as tf
from typeguard import check_argument_types

from neuralmonkey.dataset import Dataset
from neuralmonkey.decorators import tensor
from neuralmonkey.encoders.recurrent import RecurrentEncoder
from neuralmonkey.encoders.facebook_conv import SentenceEncoder
from neuralmonkey.model.stateful import TemporalStateful
from neuralmonkey.model.feedable import FeedDict
from neuralmonkey.model.parameterized import InitializerSpecs
from neuralmonkey.model.model_part import ModelPart
from neuralmonkey.tf_utils import get_variable
from neuralmonkey.model.sequence import EmbeddedSequence
from neuralmonkey.nn.utils import dropout
from neuralmonkey.vocabulary import Vocabulary, pad_batch, sentence_mask


class SequenceLabeler(ModelPart):
"""Classifier assing a label to each encoder's state."""

# pylint: disable=too-many-arguments
# pylint: disable=too-many-arguments,too-many-locals
def __init__(self,
name: str,
encoder: Union[RecurrentEncoder, SentenceEncoder],
encoders: List[TemporalStateful],
vocabulary: Vocabulary,
data_id: str,
max_output_len: int = None,
hidden_dim: int = None,
activation: Callable = tf.nn.relu,
dropout_keep_prob: float = 1.0,
add_start_symbol: bool = False,
add_end_symbol: bool = False,
reuse: ModelPart = None,
save_checkpoint: str = None,
load_checkpoint: str = None,
Expand All @@ -32,11 +37,16 @@ def __init__(self,
ModelPart.__init__(self, name, reuse, save_checkpoint, load_checkpoint,
initializers)

self.encoder = encoder
self.encoders = encoders
self.vocabulary = vocabulary
self.data_id = data_id
self.max_output_len = max_output_len
self.hidden_dim = hidden_dim
self.activation = activation
self.dropout_keep_prob = dropout_keep_prob
# pylint: enable=too-many-arguments
self.add_start_symbol = add_start_symbol
self.add_end_symbol = add_end_symbol
# pylint: enable=too-many-arguments,too-many-locals

@property
def input_types(self) -> Dict[str, tf.DType]:
Expand All @@ -46,70 +56,58 @@ def input_types(self) -> Dict[str, tf.DType]:
def input_shapes(self) -> Dict[str, tf.TensorShape]:
return {self.data_id: tf.TensorShape([None, None])}

@tensor
def input_mask(self) -> tf.Tensor:
mask_main = self.encoders[0].temporal_mask

asserts = [
tf.assert_equal(
mask_main, enc.temporal_mask,
message=("Encoders '{}' and '{}' does not have equal temporal "
"masks.".format(str(self.encoders[0]), str(enc))))
for enc in self.encoders[1:]]

with tf.control_dependencies(asserts):
return mask_main

@tensor
def target_tokens(self) -> tf.Tensor:
return self.dataset[self.data_id]

@tensor
def train_targets(self) -> tf.Tensor:
return self.vocabulary.strings_to_indices(self.target_tokens)
return self.vocabulary.strings_to_indices(
self.dataset[self.data_id])

@tensor
def train_mask(self) -> tf.Tensor:
return sentence_mask(self.train_targets)

@property
def rnn_size(self) -> int:
return int(self.encoder.temporal_states.get_shape()[-1])

@tensor
def decoding_w(self) -> tf.Variable:
return get_variable(
name="state_to_word_W",
shape=[self.rnn_size, len(self.vocabulary)])

@tensor
def decoding_b(self) -> tf.Variable:
return get_variable(
name="state_to_word_b",
shape=[len(self.vocabulary)],
initializer=tf.zeros_initializer())
def concatenated_inputs(self) -> tf.Tensor:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kdyz budu mit vice enkoderu, co pracuji nad ruzne dlouhymi sekvencemi, tak to na tom concatu spadne, ne?
Nemela by se takova situace resit spise pres FactoredSequence/FactoredEncoder?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to je pravda

# Validate shapes first
with tf.control_dependencies([self.input_mask]):
return tf.concat(
[inp.temporal_states for inp in self.encoders], axis=2)

@tensor
def decoding_residual_w(self) -> tf.Variable:
input_dim = self.encoder.input_sequence.dimension
return get_variable(
name="emb_to_word_W",
shape=[input_dim, len(self.vocabulary)])
def states(self) -> tf.Tensor:
states = dropout(
self.concatenated_inputs, self.dropout_keep_prob, self.train_mode)

jindrahelcl marked this conversation as resolved.
Show resolved Hide resolved
if self.hidden_dim is not None:
hidden = tf.layers.dense(
states, self.hidden_dim, self.activation,
name="hidden_layer")
# pylint: disable=redefined-variable-type
states = dropout(hidden, self.dropout_keep_prob, self.train_mode)
# pylint: enable=redefined-variable-type
return states

@tensor
def logits(self) -> tf.Tensor:
# To multiply 3-D matrix (encoder hidden states) by a 2-D matrix
# (weights), we use 1-by-1 convolution (similar trick can be found in
# attention computation)

# TODO dropout needs to be revisited

encoder_states = tf.expand_dims(self.encoder.temporal_states, 2)
weights_4d = tf.expand_dims(tf.expand_dims(self.decoding_w, 0), 0)

multiplication = tf.nn.conv2d(
encoder_states, weights_4d, [1, 1, 1, 1], "SAME")
multiplication_3d = tf.squeeze(multiplication, axis=[2])

biases_3d = tf.expand_dims(tf.expand_dims(self.decoding_b, 0), 0)

embedded_inputs = tf.expand_dims(
self.encoder.input_sequence.temporal_states, 2)
dweights_4d = tf.expand_dims(
tf.expand_dims(self.decoding_residual_w, 0), 0)

dmultiplication = tf.nn.conv2d(
embedded_inputs, dweights_4d, [1, 1, 1, 1], "SAME")
dmultiplication_3d = tf.squeeze(dmultiplication, axis=[2])

logits = multiplication_3d + dmultiplication_3d + biases_3d
return logits
return tf.layers.dense(
jindrahelcl marked this conversation as resolved.
Show resolved Hide resolved
self.states, len(self.vocabulary), name="logits")

@tensor
def logprobs(self) -> tf.Tensor:
Expand All @@ -120,14 +118,19 @@ def decoded(self) -> tf.Tensor:
return tf.argmax(self.logits, 2)

@tensor
def cost(self) -> tf.Tensor:
def train_xents(self) -> tf.Tensor:
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=self.train_targets, logits=self.logits)

# loss is now of shape [batch, time]. Need to mask it now by
# element-wise multiplication with weights placeholder
weighted_loss = loss * self.train_mask
return tf.reduce_sum(weighted_loss)
return loss * self.train_mask
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Jo, jeste se mi tady nezda format train_xents.
Tady to vraci shape(batch, time), ale v autoregressivnich dekoderech to vraci shape(batch).

Bylo by teda fajn se dohodnout, co se bude vracet (bud tady vracet prumer-per-sequence; nebo v autoregressive by melo stacit vypnout average_across_timesteps v self.train_xents)

In the long run by samozrejme mel byt jeden spolecny predek "dekoder" s abstraktnima metodama, jako loss, xents apod., ale to bych klidne nechal do jineho PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

V mariánovi to jde nastavovat ("ce-mean" vs "ce-mean-words"), s tím, že v nových modelech se používá loss zprůměrovanej ze všech slov v batchi, ale default má průměr po větách.

U labeleru dává víc smysl mít ten loss pro každej label zvlášť, kdežto u dekodéru asi spíš po větách, ale souhlasim, že se to má sjednotit. Je tim pádem asi lepší vracet (batch, time) kterej si pak zprůměruješ jak chceš.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Taky jsem za druhou moznost, to ale znamena tedka jeste vypnout switch v AutoregressiveDecoder. Nevim, jak to ovlivni zabehle trenovani (predpokladam, ze minimalne; ale nemeril jsem to). Pokud udelas porovnani, tak to klidne muzes zamergovat timhle zpusobem.

Na druhou stranu uz mam vyzkousene ze udelat prumer pres vety a pak pres batch v seq. labeleru funguje, takze bych v tomto PR radej udelal tohle (hodil issue na poradne doreseni).

Kazdopadne by bylo fajn to uz ted mit v masteru sjednocene, nez to nekam zapadne. Klicove je, ze to vyrazne snizi uroven odrbavani v jinych komponentach, kde pak musis mit divne workaroundy tipu kontroly shapu, supported decoder apod.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

autoregressive dekodér má nějakej switch? Vidim jen, že to dělá reduce_mean na {train,runtime}_xents (místo aby dělal mean jen přes validní pozice, ale jinak je to stejný)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mam na mysli toto:
https://github.com/ufal/neuralmonkey/blob/master/neuralmonkey/decoders/autoregressive.py#L291

seq2seq.sequence_loss ma prepinac average_across_timesteps, ktery je tedka True. Kdyby se vypnul, tak by to odpovidalo formatu train_xents v labeleru

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

už to vidim...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hotovo. musel se ještě změnit perplexity runner, kterej počítá s průměrama přes čas.


@tensor
def cost(self) -> tf.Tensor:
# Cross entropy mean over all words in the batch
# (could also be done as a mean over sentences)
return tf.reduce_sum(self.train_xents) / tf.reduce_sum(self.train_mask)

@property
def train_loss(self) -> tf.Tensor:
Expand All @@ -142,6 +145,67 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict:

sentences = dataset.maybe_get_series(self.data_id)
if sentences is not None:
fd[self.target_tokens] = pad_batch(list(sentences))
fd[self.target_tokens] = pad_batch(
list(sentences), self.max_output_len, self.add_start_symbol,
self.add_end_symbol)

return fd


class EmbeddingsLabeler(SequenceLabeler):
"""SequenceLabeler that uses an embedding matrix for output projection."""

# pylint: disable=too-many-arguments,too-many-locals
def __init__(self,
name: str,
encoders: List[TemporalStateful],
embedded_sequence: EmbeddedSequence,
data_id: str,
max_output_len: int = None,
hidden_dim: int = None,
activation: Callable = tf.nn.relu,
train_embeddings: bool = True,
dropout_keep_prob: float = 1.0,
add_start_symbol: bool = False,
add_end_symbol: bool = False,
reuse: ModelPart = None,
save_checkpoint: str = None,
load_checkpoint: str = None,
initializers: InitializerSpecs = None) -> None:

check_argument_types()
SequenceLabeler.__init__(
self, name, encoders, embedded_sequence.vocabulary, data_id,
max_output_len, hidden_dim=hidden_dim, activation=activation,
dropout_keep_prob=dropout_keep_prob,
add_start_symbol=add_start_symbol, add_end_symbol=add_end_symbol,
reuse=reuse, save_checkpoint=save_checkpoint,
load_checkpoint=load_checkpoint, initializers=initializers)

self.embedded_sequence = embedded_sequence
self.train_embeddings = train_embeddings
# pylint: enable=too-many-arguments,too-many-locals

@tensor
def logits(self) -> tf.Tensor:
embeddings = self.embedded_sequence.embedding_matrix
if not self.train_embeddings:
embeddings = tf.stop_gradient(embeddings)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fakt muzes tohle udelat? Vzhledem k tomu, ze se jedna o fakticky posledni vrstvu, nestane se to, ze ti pri backpropu neprotece zadna informace do zbytku site (a tedy se nic nenaucis)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

neproteče tam přes tuhle lokální proměnnou embeddings, ale proteče tam skrz states.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Jasne, uz to vidim


states = self.states
# pylint: disable=no-member
states_dim = self.states.get_shape()[-1].value
# pylint: enable=no-member
embedding_dim = self.embedded_sequence.embedding_sizes[0]
# pylint: disable=redefined-variable-type
if states_dim != embedding_dim:
states = tf.layers.dense(
jindrahelcl marked this conversation as resolved.
Show resolved Hide resolved
states, embedding_dim, name="project_for_embeddings")
states = dropout(states, self.dropout_keep_prob, self.train_mode)
# pylint: enable=redefined-variable-type

reshaped_states = tf.reshape(states, [-1, embedding_dim])
reshaped_logits = tf.matmul(
reshaped_states, embeddings, transpose_b=True, name="logits")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Biasy necheme?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

neděláme je ani jinde při tie_embeddings, viz autoregressive.py:225

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Jo, nevsiml jsem si, ze EmbeddingsLabeler vzdycky vaze embeddingy.

return tf.reshape(
reshaped_logits, [self.batch_size, -1, len(self.vocabulary)])
6 changes: 2 additions & 4 deletions neuralmonkey/learning_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,15 +260,13 @@ def _check_series_collisions(runners: List[BaseRunner],
if series in runners_outputs:
raise Exception(("Output series '{}' is multiple times among the "
"runners' outputs.").format(series))
else:
runners_outputs.add(series)
runners_outputs.add(series)
if postprocess is not None:
for series, _ in postprocess:
if series in runners_outputs:
raise Exception(("Postprocess output series '{}' "
"already exists.").format(series))
else:
runners_outputs.add(series)
runners_outputs.add(series)


def run_on_dataset(tf_manager: TensorFlowManager,
Expand Down
2 changes: 1 addition & 1 deletion neuralmonkey/readers/string_vector_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def process_line(line: str, lineno: int, path: str) -> np.ndarray:

return np.array(numbers, dtype=dtype)

def reader(files: List[str])-> Iterable[List[np.ndarray]]:
def reader(files: List[str]) -> Iterable[List[np.ndarray]]:
for path in files:
current_line = 0

Expand Down
2 changes: 1 addition & 1 deletion neuralmonkey/runners/label_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(self,
def fetches(self) -> Dict[str, tf.Tensor]:
return {
"label_logprobs": self.decoder.logprobs,
"input_mask": self.decoder.encoder.input_sequence.temporal_mask,
"input_mask": self.decoder.input_mask,
"loss": self.decoder.cost}

@property
Expand Down
6 changes: 5 additions & 1 deletion neuralmonkey/runners/perplexity_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ def __init__(self,

@tensor
def fetches(self) -> Dict[str, tf.Tensor]:
return {"xents": self.decoder.train_xents}
# decoder.train_xents are (batch, time)
# we average xents over the time dimension
return {"xents": tf.reduce_sum(
self.decoder.train_xents, axis=1) / tf.reduce_sum(
tf.transpose(self.decoder.train_mask), axis=1)}

@property
def loss_names(self) -> List[str]:
Expand Down
2 changes: 1 addition & 1 deletion tests/labeler.ini
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ vocabulary=<source_vocabulary>
[decoder]
class=decoders.sequence_labeler.SequenceLabeler
name="tagger"
encoder=<encoder>
encoders=[<encoder>]
data_id="tags"
dropout_keep_prob=0.5
vocabulary=<tags_vocabulary>
Expand Down