Skip to content

Commit

Permalink
Merge pull request #163 from zalandoresearch/GH-162-memory-optimizations
Browse files Browse the repository at this point in the history
GH-162: integration tests
  • Loading branch information
tabergma authored Oct 19, 2018
2 parents b8dfc15 + 4ce911f commit 375e6cf
Show file tree
Hide file tree
Showing 2 changed files with 266 additions and 19 deletions.
33 changes: 14 additions & 19 deletions flair/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,13 +462,16 @@ def embedding_length(self) -> int:

def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

cache_path = '{}-tmp-cache.sqllite'.format(self.name) if self.cache_directory is None else os.path.join(
self.cache_directory, '{}-tmp-cache.sqllite'.format(os.path.basename(self.name)))

# by default, use_cache is false (for older pre-trained models TODO: remove in version 0.4)
if 'cache' not in self.__dict__ or 'cache_directory' not in self.__dict__ or not os.path.exists(cache_path):
# this whole block is for compatibility with older serialized models TODO: remove in version 0.4
if 'cache' not in self.__dict__ or 'cache_directory' not in self.__dict__:
self.use_cache = False
self.cache_directory = None
else:
cache_path = '{}-tmp-cache.sqllite'.format(self.name) if not self.cache_directory else os.path.join(
self.cache_directory, '{}-tmp-cache.sqllite'.format(os.path.basename(self.name)))
if not os.path.exists(cache_path):
self.use_cache = False
self.cache_directory = None

# if cache is used, try setting embeddings from cache first
if self.use_cache:
Expand Down Expand Up @@ -553,15 +556,14 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

class DocumentMeanEmbeddings(DocumentEmbeddings):

def __init__(self, word_embeddings: List[TokenEmbeddings]):
def __init__(self, token_embeddings: List[TokenEmbeddings]):
"""The constructor takes a list of embeddings to be combined."""
super().__init__()

self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=word_embeddings)
self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=token_embeddings)
self.name: str = 'document_mean'

self.__embedding_length: int = 0
self.__embedding_length = self.embeddings.embedding_length
self.__embedding_length: int = self.embeddings.embedding_length

if torch.cuda.is_available():
self.cuda()
Expand Down Expand Up @@ -631,18 +633,12 @@ def __init__(self,
"""
super().__init__()

self.embeddings: List[TokenEmbeddings] = token_embeddings

# IMPORTANT: add embeddings as torch modules
for i, embedding in enumerate(self.embeddings):
self.add_module('token_embedding_{}'.format(i), embedding)
self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=token_embeddings)

self.reproject_words = reproject_words
self.bidirectional = bidirectional

self.length_of_all_token_embeddings = 0
for token_embedding in self.embeddings:
self.length_of_all_token_embeddings += token_embedding.embedding_length
self.length_of_all_token_embeddings: int = self.embeddings.embedding_length

self.name = 'document_lstm'
self.static_embeddings = False
Expand Down Expand Up @@ -691,8 +687,7 @@ def embed(self, sentences: Union[List[Sentence], Sentence]):

sentences.sort(key=lambda x: len(x), reverse=True)

for token_embedding in self.embeddings:
token_embedding.embed(sentences)
self.embeddings.embed(sentences)

# first, sort sentences by number of tokens
longest_token_sequence_in_batch: int = len(sentences[0])
Expand Down
252 changes: 252 additions & 0 deletions tests/test_model_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
import os
import shutil

from flair.data import Sentence
from flair.data_fetcher import NLPTaskDataFetcher, NLPTask
from flair.embeddings import WordEmbeddings, CharLMEmbeddings, DocumentLSTMEmbeddings, TokenEmbeddings
from flair.models import SequenceTagger, TextClassifier
from flair.trainers import SequenceTaggerTrainer, TextClassifierTrainer


def test_train_load_use_tagger():
    """Train a small NER tagger on the FASHION corpus, reload it, and predict."""

    corpus = NLPTaskDataFetcher.fetch_data(NLPTask.FASHION)
    tag_dictionary = corpus.make_tag_dictionary('ner')

    glove = WordEmbeddings('glove')

    tagger: SequenceTagger = SequenceTagger(
        hidden_size=256,
        embeddings=glove,
        tag_dictionary=tag_dictionary,
        tag_type='ner',
        use_crf=False,
    )

    # test_mode keeps training deterministic for the integration test
    trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=True)
    trainer.train('./results', learning_rate=0.1, mini_batch_size=2, max_epochs=3)

    # reload the serialized model and exercise prediction paths:
    # single sentence, mixed batch, and an empty-sentence-only batch
    reloaded: SequenceTagger = SequenceTagger.load_from_file('./results/final-model.pt')

    sentence = Sentence('I love Berlin')
    sentence_empty = Sentence(' ')

    reloaded.predict(sentence)
    reloaded.predict([sentence, sentence_empty])
    reloaded.predict([sentence_empty])

    # clean up results directory
    shutil.rmtree('./results')


def test_train_charlm_load_use_tagger():
    """Train an NER tagger with CharLM embeddings, reload it, and predict."""

    corpus = NLPTaskDataFetcher.fetch_data(NLPTask.FASHION)
    tag_dictionary = corpus.make_tag_dictionary('ner')

    charlm = CharLMEmbeddings('news-forward-fast')

    tagger: SequenceTagger = SequenceTagger(
        hidden_size=256,
        embeddings=charlm,
        tag_dictionary=tag_dictionary,
        tag_type='ner',
        use_crf=False,
    )

    # test_mode keeps training deterministic for the integration test
    trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=True)
    trainer.train('./results', learning_rate=0.1, mini_batch_size=2, max_epochs=3)

    # reload the serialized model and exercise prediction paths:
    # single sentence, mixed batch, and an empty-sentence-only batch
    reloaded: SequenceTagger = SequenceTagger.load_from_file('./results/final-model.pt')

    sentence = Sentence('I love Berlin')
    sentence_empty = Sentence(' ')

    reloaded.predict(sentence)
    reloaded.predict([sentence, sentence_empty])
    reloaded.predict([sentence_empty])

    # clean up results directory
    shutil.rmtree('./results')


def test_train_charlm_changed_chache_load_use_tagger():
    """Train with a custom embedding cache dir, delete the cache, then reload.

    Verifies that a serialized model still loads and predicts after its
    embedding cache directory has disappeared.
    NOTE(review): 'chache' in the test name looks like a typo for 'cache';
    left as-is because the name is pytest's collection identifier.
    """

    corpus = NLPTaskDataFetcher.fetch_data(NLPTask.FASHION)
    tag_dictionary = corpus.make_tag_dictionary('ner')

    # make a temporary cache directory that we remove afterwards
    os.makedirs('./results/cache/', exist_ok=True)
    charlm = CharLMEmbeddings('news-forward-fast', cache_directory='./results/cache/')

    tagger: SequenceTagger = SequenceTagger(
        hidden_size=256,
        embeddings=charlm,
        tag_dictionary=tag_dictionary,
        tag_type='ner',
        use_crf=False,
    )

    trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=True)
    trainer.train('./results', learning_rate=0.1, mini_batch_size=2, max_epochs=3)

    # remove the cache directory before loading: the model must cope without it
    shutil.rmtree('./results/cache')

    reloaded: SequenceTagger = SequenceTagger.load_from_file('./results/final-model.pt')

    sentence = Sentence('I love Berlin')
    sentence_empty = Sentence(' ')

    reloaded.predict(sentence)
    reloaded.predict([sentence, sentence_empty])
    reloaded.predict([sentence_empty])

    # clean up results directory
    shutil.rmtree('./results')


def test_train_charlm_nochache_load_use_tagger():
    """Train an NER tagger with embedding caching disabled, reload, predict.

    NOTE(review): 'nochache' in the test name looks like a typo for 'nocache';
    left as-is because the name is pytest's collection identifier.
    """

    corpus = NLPTaskDataFetcher.fetch_data(NLPTask.FASHION)
    tag_dictionary = corpus.make_tag_dictionary('ner')

    # explicitly disable the embedding cache
    charlm = CharLMEmbeddings('news-forward-fast', use_cache=False)

    tagger: SequenceTagger = SequenceTagger(
        hidden_size=256,
        embeddings=charlm,
        tag_dictionary=tag_dictionary,
        tag_type='ner',
        use_crf=False,
    )

    trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=True)
    trainer.train('./results', learning_rate=0.1, mini_batch_size=2, max_epochs=3)

    reloaded: SequenceTagger = SequenceTagger.load_from_file('./results/final-model.pt')

    sentence = Sentence('I love Berlin')
    sentence_empty = Sentence(' ')

    reloaded.predict(sentence)
    reloaded.predict([sentence, sentence_empty])
    reloaded.predict([sentence_empty])

    # clean up results directory
    shutil.rmtree('./results')


def test_load_use_serialized_tagger():
    """Load the released 'ner' model and run prediction on trivial input."""

    tagger: SequenceTagger = SequenceTagger.load('ner')

    sentence = Sentence('I love Berlin')
    sentence_empty = Sentence(' ')

    # single sentence, mixed batch, and a batch containing only an empty sentence
    tagger.predict(sentence)
    tagger.predict([sentence, sentence_empty])
    tagger.predict([sentence_empty])


def test_train_load_use_classifier():
    """Train a DocumentLSTM text classifier on IMDB, reload it, and predict."""
    corpus = NLPTaskDataFetcher.fetch_data(NLPTask.IMDB)
    label_dict = corpus.make_label_dictionary()

    glove_embedding: WordEmbeddings = WordEmbeddings('en-glove')
    document_embeddings: DocumentLSTMEmbeddings = DocumentLSTMEmbeddings(
        [glove_embedding], 128, 1, False, 64, False, False)

    model = TextClassifier(document_embeddings, label_dict, False)

    trainer = TextClassifierTrainer(model, corpus, label_dict, False)
    trainer.train('./results', max_epochs=2)

    sentence = Sentence("Berlin is a really nice city.")

    # every predicted label must carry a value and a well-formed confidence score
    for result in model.predict(sentence):
        for label in result.labels:
            assert label.value is not None
            assert 0.0 <= label.score <= 1.0
            assert type(label.score) is float

    loaded_model = TextClassifier.load_from_file('./results/final-model.pt')

    sentence = Sentence('I love Berlin')
    sentence_empty = Sentence(' ')

    loaded_model.predict(sentence)
    loaded_model.predict([sentence, sentence_empty])
    loaded_model.predict([sentence_empty])

    # clean up results directory
    shutil.rmtree('./results')


def test_train_charlm_load_use_classifier():
    """Train a DocumentLSTM classifier over CharLM embeddings, reload, predict."""
    corpus = NLPTaskDataFetcher.fetch_data(NLPTask.IMDB)
    label_dict = corpus.make_label_dictionary()

    # renamed from 'glove_embedding': this is a CharLM embedding, not GloVe
    charlm_embedding: TokenEmbeddings = CharLMEmbeddings('news-forward-fast')
    document_embeddings: DocumentLSTMEmbeddings = DocumentLSTMEmbeddings(
        [charlm_embedding], 128, 1, False, 64, False, False)

    model = TextClassifier(document_embeddings, label_dict, False)

    trainer = TextClassifierTrainer(model, corpus, label_dict, False)
    trainer.train('./results', max_epochs=2)

    sentence = Sentence("Berlin is a really nice city.")

    # every predicted label must carry a value and a well-formed confidence score
    for s in model.predict(sentence):
        for l in s.labels:
            assert (l.value is not None)
            assert (0.0 <= l.score <= 1.0)
            assert (type(l.score) is float)

    loaded_model = TextClassifier.load_from_file('./results/final-model.pt')

    sentence = Sentence('I love Berlin')
    sentence_empty = Sentence(' ')

    loaded_model.predict(sentence)
    loaded_model.predict([sentence, sentence_empty])
    loaded_model.predict([sentence_empty])

    # clean up results directory
    shutil.rmtree('./results')


def test_train_charlm__nocache_load_use_classifier():
    """Train a DocumentLSTM classifier with uncached CharLM embeddings, reload, predict.

    NOTE(review): the double underscore in the test name looks accidental;
    left as-is because the name is pytest's collection identifier.
    """
    corpus = NLPTaskDataFetcher.fetch_data(NLPTask.IMDB)
    label_dict = corpus.make_label_dictionary()

    # renamed from 'glove_embedding': this is a CharLM embedding, not GloVe;
    # caching is explicitly disabled for this variant
    charlm_embedding: TokenEmbeddings = CharLMEmbeddings('news-forward-fast', use_cache=False)
    document_embeddings: DocumentLSTMEmbeddings = DocumentLSTMEmbeddings(
        [charlm_embedding], 128, 1, False, 64, False, False)

    model = TextClassifier(document_embeddings, label_dict, False)

    trainer = TextClassifierTrainer(model, corpus, label_dict, False)
    trainer.train('./results', max_epochs=2)

    sentence = Sentence("Berlin is a really nice city.")

    # every predicted label must carry a value and a well-formed confidence score
    for s in model.predict(sentence):
        for l in s.labels:
            assert (l.value is not None)
            assert (0.0 <= l.score <= 1.0)
            assert (type(l.score) is float)

    loaded_model = TextClassifier.load_from_file('./results/final-model.pt')

    sentence = Sentence('I love Berlin')
    sentence_empty = Sentence(' ')

    loaded_model.predict(sentence)
    loaded_model.predict([sentence, sentence_empty])
    loaded_model.predict([sentence_empty])

    # clean up results directory
    shutil.rmtree('./results')

0 comments on commit 375e6cf

Please sign in to comment.