From 4ce911feee13c2adadb53b6a6b4535fe6e676b6e Mon Sep 17 00:00:00 2001 From: aakbik Date: Fri, 19 Oct 2018 12:09:41 +0200 Subject: [PATCH] GH-162: harmonize DocumentMeanEmbeddings and DocumentLSTMEmbeddings --- flair/embeddings.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/flair/embeddings.py b/flair/embeddings.py index dda428cd0..5316c2975 100644 --- a/flair/embeddings.py +++ b/flair/embeddings.py @@ -556,15 +556,14 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: class DocumentMeanEmbeddings(DocumentEmbeddings): - def __init__(self, word_embeddings: List[TokenEmbeddings]): + def __init__(self, token_embeddings: List[TokenEmbeddings]): """The constructor takes a list of embeddings to be combined.""" super().__init__() - self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=word_embeddings) + self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=token_embeddings) self.name: str = 'document_mean' - self.__embedding_length: int = 0 - self.__embedding_length = self.embeddings.embedding_length + self.__embedding_length: int = self.embeddings.embedding_length if torch.cuda.is_available(): self.cuda() @@ -634,18 +633,12 @@ def __init__(self, """ super().__init__() - self.embeddings: List[TokenEmbeddings] = token_embeddings - - # IMPORTANT: add embeddings as torch modules - for i, embedding in enumerate(self.embeddings): - self.add_module('token_embedding_{}'.format(i), embedding) + self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=token_embeddings) self.reproject_words = reproject_words self.bidirectional = bidirectional - self.length_of_all_token_embeddings = 0 - for token_embedding in self.embeddings: - self.length_of_all_token_embeddings += token_embedding.embedding_length + self.length_of_all_token_embeddings: int = self.embeddings.embedding_length self.name = 'document_lstm' self.static_embeddings = False @@ -694,8 +687,7 @@ def embed(self, sentences: Union[List[Sentence], Sentence]): sentences.sort(key=lambda x: len(x), reverse=True) - for token_embedding in self.embeddings: - token_embedding.embed(sentences) + self.embeddings.embed(sentences) # first, sort sentences by number of tokens longest_token_sequence_in_batch: int = len(sentences[0])