Skip to content

Commit

Permalink
GH-162: harmonize DocumentMeanEmbeddings and DocumentLSTMEmbeddings
Browse files (browse the repository at this point in the history)
  • Loading branch information
aakbik committed Oct 19, 2018
1 parent a09837b commit 4ce911f
Showing 1 changed file with 6 additions and 14 deletions.
20 changes: 6 additions & 14 deletions flair/embeddings.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -556,15 +556,14 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

class DocumentMeanEmbeddings(DocumentEmbeddings):

def __init__(self, word_embeddings: List[TokenEmbeddings]):
def __init__(self, token_embeddings: List[TokenEmbeddings]):
"""The constructor takes a list of embeddings to be combined."""
super().__init__()

self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=word_embeddings)
self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=token_embeddings)
self.name: str = 'document_mean'

self.__embedding_length: int = 0
self.__embedding_length = self.embeddings.embedding_length
self.__embedding_length: int = self.embeddings.embedding_length

if torch.cuda.is_available():
self.cuda()
Expand Down Expand Up @@ -634,18 +633,12 @@ def __init__(self,
"""
super().__init__()

self.embeddings: List[TokenEmbeddings] = token_embeddings

# IMPORTANT: add embeddings as torch modules
for i, embedding in enumerate(self.embeddings):
self.add_module('token_embedding_{}'.format(i), embedding)
self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=token_embeddings)

self.reproject_words = reproject_words
self.bidirectional = bidirectional

self.length_of_all_token_embeddings = 0
for token_embedding in self.embeddings:
self.length_of_all_token_embeddings += token_embedding.embedding_length
self.length_of_all_token_embeddings: int = self.embeddings.embedding_length

self.name = 'document_lstm'
self.static_embeddings = False
Expand Down Expand Up @@ -694,8 +687,7 @@ def embed(self, sentences: Union[List[Sentence], Sentence]):

sentences.sort(key=lambda x: len(x), reverse=True)

for token_embedding in self.embeddings:
token_embedding.embed(sentences)
self.embeddings.embed(sentences)

# first, sort sentences by number of tokens
longest_token_sequence_in_batch: int = len(sentences[0])
Expand Down

0 comments on commit 4ce911f

Please sign in to comment.