diff --git a/README.md b/README.md index 52e1dbd34..ee578cd27 100644 --- a/README.md +++ b/README.md @@ -2,28 +2,23 @@ ![alt text](resources/docs/flair_logo.svg) -- a very simple framework for **state-of-the-art NLP**. Developed by [Zalando Research](https://research.zalando.com/). +A very simple framework for **state-of-the-art NLP**. Developed by [Zalando Research](https://research.zalando.com/). --- -Flair uses **hyper-powerful word embeddings** to achieve state-of-the-art accuracies - on a range of natural language processing (NLP) tasks. - Flair is: * **A powerful syntactic / semantic tagger.** Flair allows you to apply our state-of-the-art models for named entity recognition (NER), part-of-speech tagging (PoS) and chunking to your text. -* **A word embedding library.** There are many different types of word embeddings out there, with wildly different properties. -Flair packages many of them behind a simple interface, so you can mix and match embeddings for your experiments. +* **A text embedding library.** Flair has simple interfaces that allow you to use and combine different word embeddings. In particular, you can try out our proposed -*[contextual string embeddings](https://drive.google.com/file/d/17yVpFA7MmXaQFTe-HDpZuqw9fJlmzg56/view?usp=sharing)*, +**[contextual string embeddings](https://drive.google.com/file/d/17yVpFA7MmXaQFTe-HDpZuqw9fJlmzg56/view?usp=sharing)** to build your own state-of-the-art NLP methods. * **A Pytorch NLP framework.** Our framework builds directly on [Pytorch](https://pytorch.org/), making it easy to train your own models and experiment with new approaches using Flair embeddings and classes. -Embedding your text for state-of-the-art NLP has never been easier. ## Comparison with State-of-the-Art @@ -73,13 +68,13 @@ a pre-trained model and use it to predict tags for the sentence: ```python from flair.data import Sentence -from flair.tagging_model import SequenceTaggerLSTM +from flair.tagging_model import SequenceTagger # make a sentence sentence = Sentence('I love Berlin .') # load the NER tagger -tagger = SequenceTaggerLSTM.load('ner') +tagger = SequenceTagger.load('ner') # run NER over sentence tagger.predict(sentence) @@ -88,15 +83,15 @@ tagger.predict(sentence) Done! The `Sentence` now has entity annotations. Print the sentence to see what the tagger found. ```python -print('Analysing %s' % sentence) +print(sentence) print('The following NER tags are found:') -print(sentence.to_tag_string()) +print(sentence.to_tagged_string()) ``` This should print: ```console -Analysing Sentence: "I love Berlin ." - 4 Tokens +Sentence: "I love Berlin ." 
- 4 Tokens The following NER tags are found: diff --git a/flair/data.py b/flair/data.py index 0b0f955fc..625530cfc 100644 --- a/flair/data.py +++ b/flair/data.py @@ -126,7 +126,7 @@ def __init__(self, text: str = None, use_tokenizer: bool = False, labels: List[s self.labels: List[str] = labels - self.embeddings: Dict = {} + self._embeddings: Dict = {} # optionally, directly instantiate with sentence tokens if text is not None: @@ -164,15 +164,24 @@ def add_token(self, token: Token): token.idx = len(self.tokens) def set_embedding(self, name: str, vector): - self.embeddings[name] = vector + self._embeddings[name] = vector - def clear_embeddings(self): - self.embeddings: Dict = {} + def clear_embeddings(self, also_clear_word_embeddings: bool = True): + + self._embeddings: Dict = {} + + if also_clear_word_embeddings: + for token in self: + token.clear_embeddings() + + def cpu_embeddings(self): + for name, vector in self._embeddings.items(): + self._embeddings[name] = vector.cpu() def get_embedding(self) -> torch.autograd.Variable: embeddings = [] - for embed in sorted(self.embeddings.keys()): - embedding = self.embeddings[embed] + for embed in sorted(self._embeddings.keys()): + embedding = self._embeddings[embed] embeddings.append(embedding) return torch.cat(embeddings, dim=0) @@ -181,24 +190,41 @@ def get_embedding(self) -> torch.autograd.Variable: def embedding(self): return self.get_embedding() - def to_tag_string(self, tag_type: str = 'tag') -> str: + def to_tagged_string(self) -> str: + list = [] for token in self.tokens: list.append(token.text) - if token.get_tag(tag_type) == '' or token.get_tag(tag_type) == 'O': continue - list.append('<' + token.get_tag(tag_type) + '>') - return ' '.join(list) - def to_ner_string(self) -> str: - list = [] - for token in self.tokens: - if token.get_tag('ner') == 'O' or token.get_tag('ner') == '': - list.append(token.text) - else: - list.append(token.text) - list.append('<' + token.get_tag('ner') + '>') + tags = [] + for tag_type in token.tags.keys(): + + if token.get_tag(tag_type) == '' or token.get_tag(tag_type) == 'O': continue + tags.append(token.get_tag(tag_type)) + all_tags = '<' + '/'.join(tags) + '>' + if all_tags != '<>': + list.append(all_tags) return ' '.join(list) + # def to_tag_string(self, tag_type: str = 'tag') -> str: + # + # list = [] + # for token in self.tokens: + # list.append(token.text) + # if token.get_tag(tag_type) == '' or token.get_tag(tag_type) == 'O': continue + # list.append('<' + token.get_tag(tag_type) + '>') + # return ' '.join(list) + # + # def to_ner_string(self) -> str: + # list = [] + # for token in self.tokens: + # if token.get_tag('ner') == 'O' or token.get_tag('ner') == '': + # list.append(token.text) + # else: + # list.append(token.text) + # list.append('<' + token.get_tag('ner') + '>') + # return ' '.join(list) + def convert_tag_scheme(self, tag_type: str = 'ner', target_scheme: str = 'iob'): tags: List[str] = [] @@ -217,7 +243,7 @@ def convert_tag_scheme(self, tag_type: str = 'ner', target_scheme: str = 'iob'): self.tokens[index].add_tag(tag_type, tag) def __repr__(self): - return ' '.join([x.text for x in self.tokens]) + return 'Sentence: "' + ' '.join([t.text for t in self.tokens]) + '" - %d Tokens' % len(self) def __copy__(self): s = Sentence() diff --git a/flair/embeddings.py b/flair/embeddings.py index e23bb400b..67628a924 100644 --- a/flair/embeddings.py +++ b/flair/embeddings.py @@ -1,5 +1,6 @@ import pickle import re +import os from abc import ABC, abstractmethod from typing import List, Dict, Tuple @@ 
-15,11 +16,16 @@ class TextEmbeddings(torch.nn.Module):
- """Abstract base class for all embeddings. Ever new type of embedding must implement these methods."""
+ """Abstract base class for all embeddings. Every new type of embedding must implement these methods."""
+ @property
 @abstractmethod
 def embedding_length(self) -> int:
 """Returns the length of the embedding vector."""
 pass
+ @property
+ def embedding_type(self) -> str:
+ return 'word-level'
+
 def embed(self, sentences: List[Sentence]) -> List[Sentence]:
 """Add embeddings to all words in a list of sentences. If embeddings are already added, updates only if embeddings
 are non-static."""
@@ -29,13 +35,16 @@ def embed(self, sentences: List[Sentence]) -> List[Sentence]:
 sentences = [sentences]
 everything_embedded: bool = True
- for sentence in sentences:
- for token in sentence.tokens:
- if self.name not in token._embeddings.keys(): everything_embedded = False
- # print(everything_embedded)
+ if self.embedding_type == 'word-level':
+ for sentence in sentences:
+ for token in sentence.tokens:
+ if self.name not in token._embeddings.keys(): everything_embedded = False
+ else:
+ for sentence in sentences:
+ if self.name not in sentence._embeddings.keys(): everything_embedded = False
+
 if not everything_embedded or not self.static_embeddings:
- # print('retrieving embeddings %s' + self.name)
 self._add_embeddings_internal(sentences)
 return sentences
@@ -63,17 +72,24 @@ def __init__(self, embeddings: List[TextEmbeddings], detach: bool = True):
 self.name = 'Stack'
 self.static_embeddings = True
- self.embedding_length: int = 0
+ self.__embedding_type: str = embeddings[0].embedding_type
+
+ self.__embedding_length: int = 0
 for embedding in embeddings:
- self.embedding_length += embedding.embedding_length
+ self.__embedding_length += embedding.embedding_length
 def embed(self, sentences: List[Sentence], static_embeddings: bool = True):
 for embedding in self.embeddings:
 embedding.embed(sentences)
+ @property
+ def embedding_type(self):
+ return self.__embedding_type
+
+ @property
 def embedding_length(self) -> int:
- return self.embedding_length
+ return self.__embedding_length
 def _add_embeddings_internal(self, sentences: List[Sentence]):
@@ -94,34 +110,43 @@ def __init__(self, embeddings):
 # GLOVE embeddings
 if embeddings.lower() == 'glove' or embeddings.lower() == 'en-glove':
- cached_path('%sglove.gensim.vectors.npy' % base_path)
- embeddings = cached_path('%sglove.gensim' % base_path)
+ cached_path(os.path.join(base_path, 'glove.gensim.vectors.npy'), cache_dir='embeddings')
+ embeddings = cached_path(os.path.join(base_path, 'glove.gensim'), cache_dir='embeddings')
+
+ # KOMNINOS embeddings
+ if embeddings.lower() == 'extvec' or embeddings.lower() == 'en-extvec':
+ cached_path(os.path.join(base_path, 'extvec.gensim.vectors.npy'), cache_dir='embeddings')
+ embeddings = cached_path(os.path.join(base_path, 'extvec.gensim'), cache_dir='embeddings')
 # NUMBERBATCH embeddings
 if embeddings.lower() == 'numberbatch' or embeddings.lower() == 'en-numberbatch':
- cached_path('%snumberbatch-en.vectors.npy' % base_path)
- embeddings = cached_path('%snumberbatch-en' % base_path)
+ cached_path(os.path.join(base_path, 'numberbatch-en.vectors.npy'), cache_dir='embeddings')
+ embeddings = cached_path(os.path.join(base_path, 'numberbatch-en'), cache_dir='embeddings')
- # KOMNIOS embeddings
- if embeddings.lower() == 'extvec' or embeddings.lower() == 'en-extvec':
- cached_path('%sextvec.gensim.vectors.npy' % base_path)
- embeddings = cached_path('%sextvec.gensim' % base_path)
+ # FT-CRAWL embeddings
+ if embeddings.lower() == 'crawl' or embeddings.lower() == 'en-crawl':
+ 
cached_path(os.path.join(base_path, 'ft-crawl.gensim.vectors.npy'), cache_dir='embeddings')
+ embeddings = cached_path(os.path.join(base_path, 'ft-crawl.gensim'), cache_dir='embeddings')
- # FT-CRAWL embeddings
- if embeddings.lower() == 'ft-crawl' or embeddings.lower() == 'en-crawl' or embeddings.lower() == 'crawl':
- cached_path('%sft-crawl.gensim.vectors.npy' % base_path)
- embeddings = cached_path('%sft-crawl.gensim' % base_path)
+ # FT-NEWS embeddings
+ if embeddings.lower() == 'news' or embeddings.lower() == 'en-news':
+ cached_path(os.path.join(base_path, 'ft-news.gensim.vectors.npy'), cache_dir='embeddings')
+ embeddings = cached_path(os.path.join(base_path, 'ft-news.gensim'), cache_dir='embeddings')
 # GERMAN FASTTEXT embeddings
- if embeddings.lower() == 'ft-german' or embeddings.lower() == 'de-fasttext':
- cached_path('%sft-wiki-de.gensim.vectors.npy' % base_path)
- embeddings = cached_path('%sft-wiki-de.gensim' % base_path)
+ if embeddings.lower() == 'de-fasttext':
+ cached_path(os.path.join(base_path, 'ft-wiki-de.gensim.vectors.npy'), cache_dir='embeddings')
+ embeddings = cached_path(os.path.join(base_path, 'ft-wiki-de.gensim'), cache_dir='embeddings')
+
+ # GERMAN NUMBERBATCH embeddings
+ if embeddings.lower() == 'de-numberbatch':
+ cached_path(os.path.join(base_path, 'de-numberbatch.vectors.npy'), cache_dir='embeddings')
+ embeddings = cached_path(os.path.join(base_path, 'de-numberbatch'), cache_dir='embeddings')
- # SWEDISCH FASTTEXT embeddings
+ # SWEDISH FASTTEXT embeddings
 if embeddings.lower() == 'sv-fasttext':
- cached_path('%scc.sv.300.vectors.npy' % base_path)
- embeddings = cached_path('%scc.sv.300' % base_path)
-
+ cached_path(os.path.join(base_path, 'cc.sv.300.vectors.npy'), cache_dir='embeddings')
+ embeddings = cached_path(os.path.join(base_path, 'cc.sv.300'), cache_dir='embeddings')
 self.name = embeddings
 self.static_embeddings = True
@@ -130,11 +155,12 @@ def __init__(self, embeddings):
 self.known_words = set(self.precomputed_word_embeddings.index2word)
- self.embedding_length: int = self.precomputed_word_embeddings.vector_size
+ self.__embedding_length: int = self.precomputed_word_embeddings.vector_size
 super().__init__()
+ @property
 def embedding_length(self) -> int:
- return self.embedding_length
+ return self.__embedding_length
 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
@@ -147,6 +173,8 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
 word_embedding = self.precomputed_word_embeddings[token.text]
 elif token.text.lower() in self.known_words:
 word_embedding = self.precomputed_word_embeddings[token.text.lower()]
+ elif re.sub('\d', '#', token.text.lower()) in self.known_words:
+ word_embedding = self.precomputed_word_embeddings[re.sub('\d', '#', token.text.lower())]
 elif re.sub('\d', '0', token.text.lower()) in self.known_words:
 word_embedding = self.precomputed_word_embeddings[re.sub('\d', '0', token.text.lower())]
 else:
@@ -171,7 +199,7 @@ def __init__(self, path_to_char_dict: str = None):
 # get list of common characters if none provided
 if path_to_char_dict is None:
 base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/models/common_characters'
- char_dict = cached_path(base_path)
+ char_dict = cached_path(base_path, cache_dir='datasets')
 # load dictionary
 self.char_dictionary: Dictionary = Dictionary()
@@ -189,10 +217,11 @@ def __init__(self, path_to_char_dict: str = None):
 self.char_rnn = torch.nn.LSTM(self.char_embedding_dim, self.hidden_size_char, num_layers=1,
 bidirectional=True)
- self.embedding_length = self.char_embedding_dim 
* 2
+ self.__embedding_length = self.char_embedding_dim * 2
+ @property
 def embedding_length(self) -> int:
- return self.embedding_length
+ return self.__embedding_length
 def _add_embeddings_internal(self, sentences: List[Sentence]):
@@ -272,32 +301,32 @@ def __init__(self, model, detach: bool = True):
 # news-english-forward
 if model.lower() == 'news-forward':
 base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-news-english-forward.pt'
- model = cached_path(base_path)
+ model = cached_path(base_path, cache_dir='embeddings')
 # news-english-backward
 if model.lower() == 'news-backward':
 base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-news-english-backward.pt'
- model = cached_path(base_path)
+ model = cached_path(base_path, cache_dir='embeddings')
 # mix-english-forward
 if model.lower() == 'mix-forward':
 base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-mix-english-forward.pt'
- model = cached_path(base_path)
+ model = cached_path(base_path, cache_dir='embeddings')
 # mix-english-backward
 if model.lower() == 'mix-backward':
 base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-mix-english-backward.pt'
- model = cached_path(base_path)
+ model = cached_path(base_path, cache_dir='embeddings')
- # mix-english-forward
+ # mix-german-forward
 if model.lower() == 'german-forward':
 base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-mix-german-forward.pt'
- model = cached_path(base_path)
+ model = cached_path(base_path, cache_dir='embeddings')
- # mix-english-backward
+ # mix-german-backward
 if model.lower() == 'german-backward':
 base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-mix-german-backward.pt'
- model = cached_path(base_path)
+ model = cached_path(base_path, cache_dir='embeddings')
 self.name = model
 self.static_embeddings = detach
@@ -321,11 +350,11 @@ def __init__(self, model, detach: bool = True):
 dummy_sentence: Sentence = Sentence()
 dummy_sentence.add_token(Token('hello'))
 embedded_dummy = self.embed([dummy_sentence])
- self.embedding_length: int = len(embedded_dummy[0].get_token(1).get_embedding())
-
+ self.__embedding_length: int = len(embedded_dummy[0].get_token(1).get_embedding())
+
+ @property
 def embedding_length(self) -> int:
- return self.embedding_length
+ return self.__embedding_length
 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
@@ -378,7 +407,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
 offset_backward -= len(token.text)
 token.set_embedding(self.name, torch.autograd.Variable(embedding))
- self.embedding_length = len(embedding)
+ self.__embedding_length = len(embedding)
 return sentences
@@ -392,7 +421,7 @@ def __init__(self, embedding_stack: StackedEmbeddings, corpus: TaggedCorpus, det
 self.name = 'Stack'
 self.static_embeddings = True
- self.embedding_length: int = embedding_stack.embedding_length
+ self.__embedding_length: int = embedding_stack.embedding_length
 print(self.embedding_length)
 sentences = corpus.get_all_sentences()
@@ -461,42 +490,49 @@ def embed(self, sentences: List[Sentence], static_embeddings: bool = True):
 word_embedding = torch.autograd.Variable(torch.FloatTensor(word_embedding))
 token.set_embedding(self.name, word_embedding)
+ @property
 def embedding_length(self) -> int:
- return self.embedding_length
+ return self.__embedding_length
 def _add_embeddings_internal(self, sentences: List[Sentence]):
 return sentences
-class TextMeanEmbedder():
- 
- def __init__(self, word_embeddings: List[TextEmbeddings], detach: bool = True): +class TextMeanEmbedder(TextEmbeddings): + def __init__(self, word_embeddings: List[TextEmbeddings], reproject_words: bool = True): """The constructor takes a list of embeddings to be combined.""" super().__init__() self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=word_embeddings) - self.detach = detach - self.name = 'word_mean' - self.static_embeddings = True + self.name: str = 'word_mean' + self.reproject_words: bool = reproject_words + self.static_embeddings: bool = not reproject_words + + self.__embedding_length: int = 0 + self.__embedding_length = self.embeddings.embedding_length - self.embedding_length: int = 0 - self.embedding_length = self.embeddings.embedding_length + self.word_reprojection_map = torch.nn.Linear(self.__embedding_length, self.__embedding_length) + @property + def embedding_type(self): + return 'sentence-level' + + @property def embedding_length(self) -> int: - return self.embedding_length + return self.__embedding_length - def embed(self, paragraphs: List[Sentence]) -> List[Sentence]: + def embed(self, paragraphs: List[Sentence]): """Add embeddings to all words in a list of sentences. If embeddings are already added, updates only if embeddings are non-static.""" + everything_embedded: bool = True + # if only one sentence is passed, convert to list of sentence if type(paragraphs) is Sentence: paragraphs = [paragraphs] - everything_embedded: bool = True - for paragraph in paragraphs: - if self.name not in paragraph.embeddings.keys(): everything_embedded = False + if self.name not in paragraph._embeddings.keys(): everything_embedded = False if not everything_embedded or not self.static_embeddings: @@ -512,6 +548,153 @@ def embed(self, paragraphs: List[Sentence]) -> List[Sentence]: if torch.cuda.is_available(): word_embeddings = word_embeddings.cuda() - paragraph.set_embedding(self.name, torch.mean(word_embeddings, 0)) + if self.reproject_words: + word_embeddings = self.word_reprojection_map(word_embeddings) + + mean_embedding = torch.mean(word_embeddings, 0) + + # mean_embedding /= len(paragraph.tokens) + paragraph.set_embedding(self.name, mean_embedding) + + +class TextLSTMEmbedder(TextEmbeddings): + def __init__(self, word_embeddings: List[TextEmbeddings], hidden_states=128, num_layers=1, + reproject_words: bool = True): + """The constructor takes a list of embeddings to be combined.""" + super().__init__() + + # self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=word_embeddings) + self.embeddings: List[TextEmbeddings] = word_embeddings + + self.reproject_words = reproject_words + + self.length_of_all_word_embeddings = 0 + for word_embedding in self.embeddings: + self.length_of_all_word_embeddings += word_embedding.embedding_length + + self.name = 'text_lstm' + self.static_embeddings = False + + # self.__embedding_length: int = hidden_states + self.__embedding_length: int = hidden_states * 2 + + # bidirectional LSTM on top of embedding layer + self.word_reprojection_map = torch.nn.Linear(self.length_of_all_word_embeddings, + self.length_of_all_word_embeddings) + self.rnn = torch.nn.LSTM(self.length_of_all_word_embeddings, hidden_states, num_layers=num_layers, + bidirectional=True) + self.dropout = torch.nn.Dropout(0.5) + + @property + def embedding_type(self): + return 'sentence-level' + + @property + def embedding_length(self) -> int: + return self.__embedding_length + + def embed(self, sentences: List[Sentence]): + """Add embeddings to all words in a 
list of sentences. If embeddings are already added, updates only if embeddings + are non-static.""" + + self.rnn.zero_grad() + + sentences.sort(key=lambda x: len(x), reverse=True) + + for word_embedding in self.embeddings: + word_embedding.embed(sentences) + + # first, sort sentences by number of tokens + longest_token_sequence_in_batch: int = len(sentences[0]) + + all_sentence_tensors = [] + lengths: List[int] = [] + + # go through each sentence in batch + for i, sentence in enumerate(sentences): + + lengths.append(len(sentence.tokens)) + + word_embeddings = [] + + for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))): + token: Token = token + word_embeddings.append(token.get_embedding().unsqueeze(0)) + + # PADDING: pad shorter sentences out + for add in range(longest_token_sequence_in_batch - len(sentence.tokens)): + word_embeddings.append( + torch.autograd.Variable( + torch.FloatTensor(np.zeros(self.length_of_all_word_embeddings, dtype='float')).unsqueeze(0))) + + word_embeddings_tensor = torch.cat(word_embeddings, 0) + + sentence_states = word_embeddings_tensor + + # ADD TO SENTENCE LIST: add the representation + all_sentence_tensors.append(sentence_states.unsqueeze(1)) + + # -------------------------------------------------------------------- + # GET REPRESENTATION FOR ENTIRE BATCH + # -------------------------------------------------------------------- + sentence_tensor = torch.cat(all_sentence_tensors, 1) + if torch.cuda.is_available(): + sentence_tensor = sentence_tensor.cuda() + + # -------------------------------------------------------------------- + # FF PART + # -------------------------------------------------------------------- + if self.reproject_words: + sentence_tensor = self.word_reprojection_map(sentence_tensor) + + sentence_tensor = self.dropout(sentence_tensor) + + packed = torch.nn.utils.rnn.pack_padded_sequence(sentence_tensor, lengths) + + lstm_out, hidden = self.rnn(packed) + outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(lstm_out) + + outputs = self.dropout(outputs) + + for i, sentence in enumerate(sentences): + embedding = outputs[output_lengths[i].item() - 1, i] + sentence.set_embedding(self.name, embedding) + + +class TextLMEmbedder(TextEmbeddings): + def __init__(self, charlm_embeddings: List[CharLMEmbeddings], detach: bool = True): + super().__init__() + + self.embeddings = charlm_embeddings + + self.static_embeddings = detach + self.detach = detach + + dummy: Sentence = Sentence('jo') + self.embed([dummy]) + self._embedding_length: int = len(dummy.embedding) + + @property + def embedding_length(self) -> int: + return self._embedding_length + + @property + def embedding_type(self): + return 'sentence-level' + + def embed(self, sentences: List[Sentence]): + + for embedding in self.embeddings: + embedding.embed(sentences) + + # iterate over sentences + for sentence in sentences: + + # if its a forward LM, take last state + if embedding.is_forward_lm: + sentence.set_embedding(embedding.name, sentence[len(sentence)]._embeddings[embedding.name]) + else: + sentence.set_embedding(embedding.name, sentence[1]._embeddings[embedding.name]) + + return sentences - return paragraphs \ No newline at end of file diff --git a/flair/file_utils.py b/flair/file_utils.py index 37344dda5..4cf687a35 100644 --- a/flair/file_utils.py +++ b/flair/file_utils.py @@ -18,7 +18,6 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name CACHE_ROOT = os.path.expanduser(os.path.join('~', '.flair')) -DATASET_CACHE = os.path.join(CACHE_ROOT, 
"datasets") def url_to_filename(url: str, etag: str = None) -> str: """ @@ -54,21 +53,20 @@ def filename_to_url(filename: str) -> Tuple[str, str]: url_bytes = base64.b64decode(filename_bytes) return url_bytes.decode('utf-8'), etag -def cached_path(url_or_filename: str, cache_dir: str = None) -> str: +def cached_path(url_or_filename: str, cache_dir: str) -> str: """ Given something that might be a URL (or might be a local path), determine which. If it's a URL, download the file and cache it, and return the path to the cached file. If it's already a local path, make sure the file exists and then return the path. """ - if cache_dir is None: - cache_dir = DATASET_CACHE + dataset_cache = os.path.join(CACHE_ROOT, cache_dir) parsed = urlparse(url_or_filename) if parsed.scheme in ('http', 'https'): # URL, so get it from the cache (downloading if necessary) - return get_from_cache(url_or_filename, cache_dir) + return get_from_cache(url_or_filename, dataset_cache) elif parsed.scheme == '' and os.path.exists(url_or_filename): # File, and it exists. return url_or_filename @@ -86,21 +84,22 @@ def get_from_cache(url: str, cache_dir: str = None) -> str: Given a URL, look for the corresponding dataset in the local cache. If it's not there, download it. Then return the path to the cached file. """ - if cache_dir is None: - cache_dir = DATASET_CACHE os.makedirs(cache_dir, exist_ok=True) + filename = re.sub(r'.+/', '', url) + # get cache path to put the file + cache_path = os.path.join(cache_dir, filename) + if os.path.exists(cache_path): + return cache_path + # make HEAD request to check ETag response = requests.head(url) if response.status_code != 200: raise IOError("HEAD request failed for url {}".format(url)) # add ETag to filename if it exists - etag = response.headers.get("ETag") - filename = re.sub(r'.+/', '', url) - # get cache path to put the file - cache_path = os.path.join(cache_dir, filename) + # etag = response.headers.get("ETag") if not os.path.exists(cache_path): # Download to temporary file, then copy to cache dir once finished. 
diff --git a/flair/language_model.py b/flair/language_model.py
index 0b986f101..06d93d78d 100644
--- a/flair/language_model.py
+++ b/flair/language_model.py
@@ -16,6 +16,8 @@ def __init__(self, rnn_type, ntoken, ninp, nhid, nout, nlayers, dropout=0.5):
 self.dictionary = Dictionary()
 self.is_forward_lm: bool = True
+ self.dropout = dropout
+
 self.drop = nn.Dropout(dropout)
 self.encoder = nn.Embedding(ntoken, ninp)
@@ -30,6 +32,8 @@ def __init__(self, rnn_type, ntoken, ninp, nhid, nout, nlayers, dropout=0.5):
 self.rnn_type = rnn_type
 self.nhid = nhid
+ self.ninp = ninp
+ self.nout = nout
 self.nlayers = nlayers
 self.hidden = None
@@ -57,7 +61,7 @@ def forward(self, input, hidden, ordered_sequence_lengths=None):
 output, hidden = self.rnn(emb, hidden)
 if self.proj is not None:
- output = self.proj(output)
+ output = self.proj(output)
 output = self.drop(output)
@@ -108,8 +112,23 @@ def initialize(self, matrix):
 def load_language_model(cls, model_file):
 state = torch.load(model_file)
 model = RNNModel(state['rnn_type'], state['ntoken'], state['ninp'], state['nhid'], state['nout'],
- state['nlayers'], state['dropout'])
+ state['nlayers'], state['dropout'])
 model.load_state_dict(state['state_dict'])
 model.is_forward_lm = state['is_forward_lm']
 model.dictionary = state['char_dictionary_forward']
- return model
\ No newline at end of file
+ return model
+
+ def save(self, file):
+ model_state = {
+ 'state_dict': self.state_dict(),
+ 'is_forward_lm': self.is_forward_lm,
+ 'char_dictionary_forward': self.dictionary,
+ 'rnn_type': self.rnn_type,
+ 'ntoken': len(self.dictionary),
+ 'ninp': self.ninp,
+ 'nhid': self.nhid,
+ 'nout': self.nout,
+ 'nlayers': self.nlayers,
+ 'dropout': self.dropout
+ }
+ torch.save(model_state, file, pickle_protocol=4)
diff --git a/flair/tagging_model.py b/flair/tagging_model.py
index e329e519a..73361a8dc 100644
--- a/flair/tagging_model.py
+++ b/flair/tagging_model.py
@@ -1,6 +1,9 @@
+import warnings
+
 import torch.autograd as autograd
 import torch.nn as nn
 import torch
+import os
 import numpy as np
 from flair.file_utils import cached_path
@@ -30,22 +33,23 @@ def log_sum_exp(vec):
 torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))
-class SequenceTaggerLSTM(nn.Module):
-
+class SequenceTagger(nn.Module):
 def __init__(self,
 hidden_size: int,
 embeddings,
 tag_dictionary: Dictionary,
+ tag_type: str,
 use_crf: bool = True,
 use_rnn: bool = True,
 rnn_layers: int = 1
 ):
- super(SequenceTaggerLSTM, self).__init__()
+ super(SequenceTagger, self).__init__()
- self.use_RNN = use_rnn
+ self.use_rnn = use_rnn
 self.hidden_size = hidden_size
 self.use_crf: bool = use_crf
+ self.rnn_layers: int = rnn_layers
 self.trained_epochs: int = 0
@@ -53,6 +57,7 @@ def __init__(self,
 # set the dictionaries
 self.tag_dictionary: Dictionary = tag_dictionary
+ self.tag_type: str = tag_type
 self.tagset_size: int = len(tag_dictionary)
 # initialize the network architecture
@@ -86,7 +91,7 @@ def __init__(self,
 self.relu = nn.ReLU()
 # final linear map to tag space
- if self.use_RNN:
+ if self.use_rnn:
 self.linear = nn.Linear(hidden_size * 2, len(tag_dictionary))
 else:
 self.linear = nn.Linear(self.embeddings.embedding_length, len(tag_dictionary))
@@ -98,30 +103,44 @@ def __init__(self,
 self.transitions.data[self.tag_dictionary.get_idx_for_item(START_TAG), :] = -10000
 self.transitions.data[:, self.tag_dictionary.get_idx_for_item(STOP_TAG)] = -10000
- @staticmethod
- def load(model: str):
- model_file = None
-
- if model.lower() == 'ner':
- base_path = 
'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/models/ner-conll03.pt' - model_file = cached_path(base_path) - - if model.lower() == 'chunk': - base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/models/chunk-conll2000.pt' - model_file = cached_path(base_path) - - if model.lower() == 'pos': - base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/models/pos-ontonotes-small.pt' - model_file = cached_path(base_path) - - if model_file is not None: - tagger: SequenceTaggerLSTM = torch.load(model_file, map_location={'cuda:0': 'cpu'}) - tagger.eval() - if torch.cuda.is_available(): - tagger = tagger.cuda() - return tagger + def save(self, model_file: str): + model_state = { + 'state_dict': self.state_dict(), + 'embeddings': self.embeddings, + 'hidden_size': self.hidden_size, + 'tag_dictionary': self.tag_dictionary, + 'tag_type': self.tag_type, + 'use_crf': self.use_crf, + 'use_rnn': self.use_rnn, + 'rnn_layers': self.rnn_layers, + } + torch.save(model_state, model_file, pickle_protocol=4) + + @classmethod + def load_from_file(cls, model_file): + + # ACHTUNG: suppressing torch serialization warnings. This needs to be taken out once we sort out recursive + # serialization of torch objects + warnings.filterwarnings("ignore") + state = torch.load(model_file, map_location={'cuda:0': 'cpu'}) + warnings.filterwarnings("default") + + model = SequenceTagger( + hidden_size=state['hidden_size'], + embeddings=state['embeddings'], + tag_dictionary=state['tag_dictionary'], + tag_type=state['tag_type'], + use_crf=state['use_crf'], + use_rnn=state['use_rnn'], + rnn_layers=state['rnn_layers']) + + model.load_state_dict(state['state_dict']) + model.eval() + if torch.cuda.is_available(): + model = model.cuda() + return model - def forward(self, sentences: List[Sentence], tag_type: str) -> Tuple[List, List]: + def forward(self, sentences: List[Sentence]) -> Tuple[List, List]: self.zero_grad() @@ -152,7 +171,7 @@ def forward(self, sentences: List[Sentence], tag_type: str) -> Tuple[List, List] token: Token = token # get the tag - tag_idx.append(self.tag_dictionary.get_idx_for_item(token.get_tag(tag_type))) + tag_idx.append(self.tag_dictionary.get_idx_for_item(token.get_tag(self.tag_type))) word_embeddings.append(token.get_embedding().unsqueeze(0)) @@ -190,7 +209,7 @@ def forward(self, sentences: List[Sentence], tag_type: str) -> Tuple[List, List] if self.relearn_embeddings: tagger_states = self.embedding2nn(tagger_states) - if self.use_RNN: + if self.use_rnn: packed = torch.nn.utils.rnn.pack_padded_sequence(tagger_states, lengths) rnn_output, hidden = self.rnn(packed) @@ -267,7 +286,7 @@ def neg_log_likelihood(self, sentences: List[Sentence], tag_type: str): # features is a 2D tensor, len(sentence) * self.tagset_size # for sentence in sentences: # print(sentence) - feats, tags = self.forward(sentences, tag_type) + feats, tags = self.forward(sentences) if self.use_crf: @@ -322,8 +341,8 @@ def _forward_alg(self, feats): # Z(x) return alpha - def predict_scores(self, sentence: Sentence, tag_type: str): - feats, tags = self.forward([sentence], tag_type) + def predict_scores(self, sentence: Sentence): + feats, tags = self.forward([sentence]) feats = feats[0] tags = tags[0] # viterbi to get tag_seq @@ -335,20 +354,76 @@ def predict_scores(self, sentence: Sentence, tag_type: str): return score, tag_seq - def predict(self, sentence: Sentence, tag_type: str = 'tag') -> Sentence: + def predict(self, sentence: Sentence) -> Sentence: - score, tag_seq = self.predict_scores(sentence, 
tag_type) - # sentences_out = copy.deepcopy(sentence) + score, tag_seq = self.predict_scores(sentence) predicted_id = tag_seq for (token, pred_id) in zip(sentence.tokens, predicted_id): token: Token = token # get the predicted tag predicted_tag = self.tag_dictionary.get_item_for_index(pred_id) - token.add_tag(tag_type, predicted_tag) + token.add_tag(self.tag_type, predicted_tag) return sentence + @staticmethod + def load(model: str): + model_file = None + aws_resource_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/models' + + if model.lower() == 'ner': + base_path = '/'.join([aws_resource_path, + 'NER-conll03--h256-l1-b32-%2Bglove%2Bnews-forward%2Bnews-backward--anneal', + 'en-ner-conll03-v0.1.pt']) + model_file = cached_path(base_path, cache_dir='models') + + if model.lower() == 'ner-ontonotes': + base_path = '/'.join([aws_resource_path, + 'NER-ontoner--h256-l1-b32-%2Bft-crawl%2Bnews-forward%2Bnews-backward--anneal', + 'en-ner-ontonotes-v0.1.pt']) + model_file = cached_path(base_path, cache_dir='models') + + if model.lower() == 'chunk': + base_path = '/'.join([aws_resource_path, + 'NP-conll2000--h256-l1-b32-%2Bnews-forward%2Bnews-backward--anneal', + 'en-chunk-conll2000-v0.1.pt']) + model_file = cached_path(base_path, cache_dir='models') + + if model.lower() == 'pos': + base_path = '/'.join([aws_resource_path, + 'POS-ontonotes--h256-l1-b32-%2Bmix-forward%2Bmix-backward--anneal', + 'en-pos-ontonotes-v0.1.pt']) + model_file = cached_path(base_path, cache_dir='models') + + if model.lower() == 'frame': + base_path = '/'.join([aws_resource_path, + 'FRAME-conll12--h256-l1-b8-%2Bnews%2Bnews-forward%2Bnews-backward--anneal', + 'en-frame-ontonotes-v0.1.pt']) + model_file = cached_path(base_path, cache_dir='models') + + if model.lower() == 'de-pos': + base_path = '/'.join([aws_resource_path, + 'UPOS-udgerman--h256-l1-b8-%2Bgerman-forward%2Bgerman-backward--anneal', + 'de-pos-ud-v0.1.pt']) + model_file = cached_path(base_path, cache_dir='models') + + if model.lower() == 'de-ner': + base_path = '/'.join([aws_resource_path, + 'NER-conll03ger--h256-l1-b32-%2Bde-fasttext%2Bgerman-forward%2Bgerman-backward--anneal', + 'de-ner-conll03-v0.1.pt']) + model_file = cached_path(base_path, cache_dir='models') + + if model.lower() == 'de-ner-germeval': + base_path = '/'.join([aws_resource_path, + 'NER-germeval--h256-l1-b32-%2Bde-fasttext%2Bgerman-forward%2Bgerman-backward--anneal', + 'de-ner-germeval-v0.1.pt']) + model_file = cached_path(base_path, cache_dir='models') + + if model_file is not None: + tagger: SequenceTagger = SequenceTagger.load_from_file(model_file) + return tagger + class LockedDropout(nn.Module): def __init__(self, dropout_rate=0.5): super(LockedDropout, self).__init__() @@ -362,37 +437,3 @@ def forward(self, x): mask = torch.autograd.Variable(m, requires_grad=False) / (1 - self.dropout_rate) mask = mask.expand_as(x) return mask * x - - -class Fokus(nn.Module): - def __init__(self, dropout_rate=0.5): - super(Fokus, self).__init__() - self.dropout_rate = dropout_rate - - def forward(self, x): - if not self.training or not self.dropout_rate: - return x - - states = len(x.data[0, 0, :]) - # print(states) - - import random - mu, sigma = 0.5, 0.2 # mean and standard deviation - s = sorted(np.random.normal(mu, sigma, states)) - # print(s) - mask = [0 if i < random.random() else 1 for i in s] - - if torch.cuda.is_available(): - mask = torch.autograd.Variable(torch.cuda.FloatTensor(mask), requires_grad=False) - else: - mask = torch.autograd.Variable(torch.FloatTensor(mask), 
requires_grad=False) - # print(mask) - # print(mask * x) - - # m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - self.dropout_rate) - # # print(x.data.new(1, x.size(1), x.size(2))) - # # print(m) - # mask = torch.autograd.Variable(m, requires_grad=False) / (1 - self.dropout_rate) - # mask = mask.expand_as(x) - # asd - return mask * x diff --git a/flair/trainer.py b/flair/trainer.py index b5986c1b5..2392b5c23 100644 --- a/flair/trainer.py +++ b/flair/trainer.py @@ -1,5 +1,5 @@ from .data import Sentence, Token, TaggedCorpus, Dictionary -from .tagging_model import SequenceTaggerLSTM +from .tagging_model import SequenceTagger from typing import List, Dict, Tuple @@ -9,12 +9,9 @@ class TagTrain: - - def __init__(self, model: SequenceTaggerLSTM, corpus: TaggedCorpus, tag_type: str, - test_mode: bool = False) -> None: - self.model: SequenceTaggerLSTM = model + def __init__(self, model: SequenceTagger, corpus: TaggedCorpus, test_mode: bool = False) -> None: + self.model: SequenceTagger = model self.corpus: TaggedCorpus = corpus - self.tag_type: str = tag_type self.test_mode: bool = test_mode def train(self, @@ -27,8 +24,10 @@ def train(self, train_with_dev: bool = False, anneal_mode: bool = False): + checkpoint: bool = False + evaluate_with_fscore: bool = True - if self.tag_type not in ['ner', 'np']: evaluate_with_fscore = False + if self.model.tag_type not in ['ner', 'np', 'srl']: evaluate_with_fscore = False self.base_path = base_path os.makedirs(self.base_path, exist_ok=True) @@ -48,8 +47,10 @@ def train(self, try: # record overall best dev scores and best loss - best_dev_score = 0 - best_loss: float = 10000 + best_score = 0 + if train_with_dev: best_score = 10000 + # best_dev_score = 0 + # best_loss: float = 10000 # this variable is used for annealing schemes epochs_without_improvement: int = 0 @@ -76,7 +77,7 @@ def train(self, optimizer.zero_grad() # Step 4. 
Compute the loss, gradients, and update the parameters by calling optimizer.step() - loss = self.model.neg_log_likelihood(batch, self.tag_type) + loss = self.model.neg_log_likelihood(batch, self.model.tag_type) current_loss += loss.item() @@ -111,19 +112,9 @@ def train(self, evaluate_with_fscore=evaluate_with_fscore, embeddings_in_memory=embeddings_in_memory) - summary = '%d' % epoch + '\t({:%H:%M:%S})'.format(datetime.datetime.now()) \ - + '\t%f\t%f\tDEV %d\t' % (current_loss, learning_rate, dev_fp) + dev_result - summary = summary.replace('\n', '') - summary += '\tTEST \t%d\t' % test_fp + test_result - # IMPORTANT: Switch back to train mode self.model.train() - print(summary) - with open(loss_txt, "a") as loss_file: - loss_file.write('%s\n' % summary) - loss_file.close() - # checkpoint model self.model.trained_epochs = epoch @@ -131,42 +122,61 @@ def train(self, is_best_model_so_far: bool = False # if dev data is used for model selection, use dev F1 score to determine best model - if not train_with_dev and dev_score > best_dev_score: - best_dev_score = dev_score + if not train_with_dev and dev_score > best_score: + best_score = dev_score is_best_model_so_far = True - print('new best dev F1: %f' % best_dev_score) # if dev data is used for training, use training loss to determine best model - if train_with_dev and current_loss < best_loss: - best_loss = current_loss - epochs_without_improvement = 0 + if train_with_dev and current_loss < best_score: + best_score = current_loss is_best_model_so_far = True - print('after %d - new best loss: %f' % (epochs_without_improvement, best_loss)) if is_best_model_so_far: + print('after %d - new best score: %f' % (epochs_without_improvement, best_score)) + epochs_without_improvement = 0 - if save_model or anneal_mode: - with open(base_path + "/model.pt", 'wb') as model_save_file: - torch.save(self.model, model_save_file, pickle_protocol=4) - model_save_file.close() - print(model_save_file.closed) - print('.. model saved ... ') + # save model + if save_model or (anneal_mode and checkpoint): + self.model.save(base_path + "/model.pt") + print('.. model saved ... 
') else: epochs_without_improvement += 1 - if epochs_without_improvement == 5 and anneal_mode: - epochs_without_improvement = 0 + # anneal after 3 epochs of no improvement if anneal mode + if epochs_without_improvement == 3 and anneal_mode: + best_score = current_loss learning_rate /= 2 - self.model = torch.load(base_path + '/model.pt') + if checkpoint: + self.model = SequenceTagger.load_from_file(base_path + '/model.pt') + optimizer = torch.optim.SGD(self.model.parameters(), lr=learning_rate) + # print info + summary = '%d' % epoch + '\t({:%H:%M:%S})'.format(datetime.datetime.now()) \ + + '\t%f\t%d\t%f\tDEV %d\t' % (current_loss, epochs_without_improvement, learning_rate, dev_fp) + dev_result + summary = summary.replace('\n', '') + summary += '\tTEST \t%d\t' % test_fp + test_result + + print(summary) + with open(loss_txt, "a") as loss_file: + loss_file.write('%s\n' % summary) + loss_file.close() + + self.model.save(base_path + "/final-model.pt") + except KeyboardInterrupt: print('-' * 89) print('Exiting from training early') + print('saving model') + with open(base_path + "/final-model.pt", 'wb') as model_save_file: + torch.save(self.model, model_save_file, pickle_protocol=4) + model_save_file.close() + print('done') + def evaluate(self, evaluation: List[Sentence], evaluate_with_fscore: bool = True, embeddings_in_memory: bool = True): @@ -190,7 +200,7 @@ def evaluate(self, evaluation: List[Sentence], evaluate_with_fscore: bool = True sentence: Sentence = sentence # Step 3. Run our forward pass. - score, tag_seq = self.model.predict_scores(sentence, self.tag_type) + score, tag_seq = self.model.predict_scores(sentence) # Step 5. Compute predictions predicted_id = tag_seq @@ -202,7 +212,7 @@ def evaluate(self, evaluation: List[Sentence], evaluate_with_fscore: bool = True token.add_tag('predicted', predicted_tag) # get the gold tag - gold_tag = token.get_tag(self.tag_type) + gold_tag = token.get_tag(self.model.tag_type) # append both to file for evaluation eval_line = token.text + ' ' + gold_tag + ' ' + predicted_tag + "\n" diff --git a/predict.py b/predict.py index 6ca4723f8..bed14bd06 100644 --- a/predict.py +++ b/predict.py @@ -1,11 +1,11 @@ from flair.data import Sentence -from flair.tagging_model import SequenceTaggerLSTM +from flair.tagging_model import SequenceTagger -tagger: SequenceTaggerLSTM = SequenceTaggerLSTM.load('ner') +tagger: SequenceTagger = SequenceTagger.load('ner') sentence: Sentence = Sentence('George Washington went to Washington .') -tagger.predict(sentence, tag_type='ner') +tagger.predict(sentence) print('Analysing %s' % sentence) print('\nThe following NER tags are found: \n') -print(sentence.to_ner_string()) \ No newline at end of file +print(sentence.to_tagged_string()) \ No newline at end of file diff --git a/resources/docs/EXPERIMENTS.md b/resources/docs/EXPERIMENTS.md index f653e1e43..7ab40c828 100644 --- a/resources/docs/EXPERIMENTS.md +++ b/resources/docs/EXPERIMENTS.md @@ -25,7 +25,7 @@ This allows the `NLPTaskDataFetcher` class to read the data into our data struct the dataset, as follows: ```python -corpus: TaggedCorpus = task_data_fetcher.fetch_data(NLPTask.CONLL_03) +corpus: TaggedCorpus = NLPTaskDataFetcher.fetch_data(NLPTask.CONLL_03) ``` This gives you a `TaggedCorpus` object that contains the data. 
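The hunks below construct the new `SequenceTagger` with a `tag_type` and a `tag_dictionary` that are assumed to be in scope. A minimal sketch of deriving them from the corpus fetched above (`make_tag_dictionary` is an assumption here, not shown in this diff; substitute however your version builds the tag dictionary):

```python
from flair.data import NLPTaskDataFetcher, NLPTask, TaggedCorpus

# load the corpus as shown above
corpus: TaggedCorpus = NLPTaskDataFetcher.fetch_data(NLPTask.CONLL_03)

# the tag we want to predict, and the dictionary of all tags in the corpus
tag_type = 'ner'
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)  # assumed helper
print(len(tag_dictionary))
```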
@@ -67,17 +67,21 @@ embedding_types: List[TextEmbeddings] = [ embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types) # initialize sequence tagger -from flair.tagging_model import SequenceTaggerLSTM - -tagger: SequenceTaggerLSTM = SequenceTaggerLSTM(hidden_size=256, embeddings=embeddings, tag_dictionary=tag_dictionary, - use_crf=True) +from flair.tagging_model import SequenceTagger + +tagger: SequenceTagger = SequenceTagger(hidden_size=256, + embeddings=embeddings, + tag_dictionary=tag_dictionary, + tag_type=tag_type, + use_crf=True) + if torch.cuda.is_available(): tagger = tagger.cuda() # initialize trainer from flair.trainer import TagTrain -trainer: TagTrain = TagTrain(tagger, corpus, tag_type=tag_type, test_mode=False) +trainer: TagTrain = TagTrain(tagger, corpus, test_mode=False) trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=150, save_model=True, train_with_dev=True, anneal_mode=True) @@ -133,17 +137,22 @@ embedding_types: List[TextEmbeddings] = [ embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types) # initialize sequence tagger -from flair.tagging_model import SequenceTaggerLSTM +from flair.tagging_model import SequenceTagger -tagger: SequenceTaggerLSTM = SequenceTaggerLSTM(hidden_size=256, embeddings=embeddings, tag_dictionary=tag_dictionary, - use_crf=True) + +tagger: SequenceTagger = SequenceTagger(hidden_size=256, + embeddings=embeddings, + tag_dictionary=tag_dictionary, + tag_type=tag_type, + use_crf=True) + if torch.cuda.is_available(): tagger = tagger.cuda() # initialize trainer from flair.trainer import TagTrain -trainer: TagTrain = TagTrain(tagger, corpus, tag_type=tag_type, test_mode=False) +trainer: TagTrain = TagTrain(tagger, corpus, test_mode=False) trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=150, save_model=True, train_with_dev=True, anneal_mode=True) @@ -196,17 +205,21 @@ embedding_types: List[TextEmbeddings] = [ embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types) # initialize sequence tagger -from flair.tagging_model import SequenceTaggerLSTM - -tagger: SequenceTaggerLSTM = SequenceTaggerLSTM(hidden_size=256, embeddings=embeddings, tag_dictionary=tag_dictionary, - use_crf=True) +from flair.tagging_model import SequenceTagger + +tagger: SequenceTagger = SequenceTagger(hidden_size=256, + embeddings=embeddings, + tag_dictionary=tag_dictionary, + tag_type=tag_type, + use_crf=True) + if torch.cuda.is_available(): tagger = tagger.cuda() # initialize trainer from flair.trainer import TagTrain -trainer: TagTrain = TagTrain(tagger, corpus, tag_type=tag_type, test_mode=False) +trainer: TagTrain = TagTrain(tagger, corpus, test_mode=False) trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=150, save_model=True, train_with_dev=True, anneal_mode=True) @@ -258,17 +271,21 @@ embedding_types: List[TextEmbeddings] = [ embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types) # initialize sequence tagger -from flair.tagging_model import SequenceTaggerLSTM - -tagger: SequenceTaggerLSTM = SequenceTaggerLSTM(hidden_size=256, embeddings=embeddings, tag_dictionary=tag_dictionary, - use_crf=True) +from flair.tagging_model import SequenceTagger + +tagger: SequenceTagger = SequenceTagger(hidden_size=256, + embeddings=embeddings, + tag_dictionary=tag_dictionary, + tag_type=tag_type, + use_crf=True) + if torch.cuda.is_available(): tagger = tagger.cuda() # initialize trainer from flair.trainer import TagTrain 
-trainer: TagTrain = TagTrain(tagger, corpus, tag_type=tag_type, test_mode=False) +trainer: TagTrain = TagTrain(tagger, corpus, test_mode=False) trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=150, save_model=True, train_with_dev=True, anneal_mode=True) @@ -323,17 +340,21 @@ embedding_types: List[TextEmbeddings] = [ embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types) # initialize sequence tagger -from flair.tagging_model import SequenceTaggerLSTM - -tagger: SequenceTaggerLSTM = SequenceTaggerLSTM(hidden_size=256, embeddings=embeddings, tag_dictionary=tag_dictionary, - use_crf=True) +from flair.tagging_model import SequenceTagger + +tagger: SequenceTagger = SequenceTagger(hidden_size=256, + embeddings=embeddings, + tag_dictionary=tag_dictionary, + tag_type=tag_type, + use_crf=True) + if torch.cuda.is_available(): tagger = tagger.cuda() # initialize trainer from flair.trainer import TagTrain -trainer: TagTrain = TagTrain(tagger, corpus, tag_type=tag_type, test_mode=False) +trainer: TagTrain = TagTrain(tagger, corpus, test_mode=False) trainer.train('resources/taggers/example-pos', mini_batch_size=32, max_epochs=150, save_model=True, train_with_dev=True, anneal_mode=True) @@ -386,17 +407,21 @@ embedding_types: List[TextEmbeddings] = [ embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types) # initialize sequence tagger -from flair.tagging_model import SequenceTaggerLSTM - -tagger: SequenceTaggerLSTM = SequenceTaggerLSTM(hidden_size=256, embeddings=embeddings, tag_dictionary=tag_dictionary, - use_crf=True) +from flair.tagging_model import SequenceTagger + +tagger: SequenceTagger = SequenceTagger(hidden_size=256, + embeddings=embeddings, + tag_dictionary=tag_dictionary, + tag_type=tag_type, + use_crf=True) + if torch.cuda.is_available(): tagger = tagger.cuda() # initialize trainer from flair.trainer import TagTrain -trainer: TagTrain = TagTrain(tagger, corpus, tag_type=tag_type, test_mode=False) +trainer: TagTrain = TagTrain(tagger, corpus, test_mode=False) trainer.train('resources/taggers/example-pos', mini_batch_size=32, max_epochs=150, save_model=True, train_with_dev=True, anneal_mode=True) diff --git a/resources/docs/TUTORIAL.md b/resources/docs/TUTORIAL.md deleted file mode 100644 index ab3ebf63c..000000000 --- a/resources/docs/TUTORIAL.md +++ /dev/null @@ -1,419 +0,0 @@ -# Tutorial - -Let's look into some core functionality to understand the library better. A [Jupyter notebook](/tutorial.ipynb) version of this tutorial -is available, too. - -## NLP base types - -There are two types of objects that are central to this library, namely the `Sentence` and `Token` objects. A `Sentence` -holds a textual sentence and is essentially a list of `Token`. - -Let's start by making a `Sentence` object for an example sentence. - -```python -# The sentence objects holds a sentence that we may want to embed -from flair.data import Sentence - -# Make a sentence object by passing a whitespace tokenized string -sentence = Sentence('The grass is green .') - -# Print the object to see what's in there -print(sentence) -``` - -This should print: - -```console -Sentence: "The grass is green ." - 5 Tokens -``` - -The print-out tells us that the sentence consists of 5 tokens. -You can access the tokens of a sentence via their token id: - -```python -print(sentence[4]) -``` - -which should print - -```console -Token: 4 green -``` - -This print-out includes the token id (4) and the lexical value of the token ("green"). 
You can also iterate over all -tokens in a sentence. - -```python -for token in sentence: - print(token) -``` - -This should print: - -```console -Token: 1 The -Token: 2 grass -Token: 3 is -Token: 4 green -Token: 5 . -``` - -A Token has fields for linguistic annotation, such as lemmas, part-of-speech tags or named entity tags. You can -add a tag by specifying the tag type and the tag value. In this example, we're adding an NER tag of type 'color' to -the word 'green'. This means that we've tagged this word as an entity of type color. - -```python -# add a tag to a word in the sentence -sentence[4].add_tag('ner', 'color') - -# print the sentence with all tags of this type -print(sentence.to_ner_string()) -``` - -This should print: - -```console -The grass is green . -``` - - -## Tagging with Pre-Trained Models - -Now, lets use a pre-trained model for named entity recognition (NER). -This model was trained over the English CoNLL-03 task and can recognize 4 different entity -types. - -```python -from flair.tagging_model import SequenceTaggerLSTM - -tagger = SequenceTaggerLSTM.load('ner') -``` -All you need to do is use the `predict()` method of the tagger on a sentence. This will add predicted tags to the tokens -in the sentence. Lets use a sentence with two named -entities: - -```python -sentence = Sentence('George Washington went to Washington .') - -# predict NER tags -tagger.predict(sentence) - -# print sentence with predicted tags -print(sentence.to_tag_string()) -``` - -This should print: -```console -George Washington went to Washington . -``` - -You chose which pre-trained model you load by passing the appropriate -string you pass to the `load()` method of the `SequenceTaggerLSTM` class. Currently, the following pre-trained models -are provided (more coming): - -| ID | Task + Training Dataset | Accuracy | -| ------------- | ------------- | ------------- | -| 'ner' | Conll-03 Named Entity Recognition (English) | **93.17** (F1) | -| 'chunk' | Conll-2000 Syntactic Chunking (English) | **96.74** (F1) | -| 'pos' | Ontonotes Part-of-Speech Tagging (English) | **98.06** (Accuracy) | - -So, if you want to use a `SequenceTaggerLSTM` that performs PoS tagging, instantiate the tagger as follows: - -```python -tagger = SequenceTaggerLSTM.load('pos') -``` - - -## Embeddings - -We provide a set of classes with which you can embed the words in sentences in various ways. Note that all embedding -classes inherit from the `TextEmbeddings` class and implement the `embed()` method which you need to call -to embed your text. This means that for most users of Flair, the complexity of different embeddings remains hidden -behind this interface. Simply instantiate the embedding class you require and call `embed()` to embed your text. - -All embeddings produced with our methods are pytorch vectors, so they can be immediately used for training and -fine-tuning. - -### Classic Word Embeddings - -Classic word embeddings are static and word-level, meaning that each distinc word gets exactly one pre-computed -embedding. Most embeddings fall under this class, including the popular GloVe or Komnios embeddings. - -Simply instantiate the WordEmbeddings class and pass a string identifier of the embedding you wish to load. So, if -you want to use GloVe embeddings, pass the string 'glove' to the constructor: - -```python -# all embeddings inherit from the TextEmbeddings class. Init a simple glove embedding. 
-from flair.embeddings import WordEmbeddings -glove_embedding = WordEmbeddings('glove') -``` -Now, create an example sentence and call the embedding's `embed()` method. You always pass a list of sentences to -this method since some embedding types make use of batching to increase speed. So if you only have one sentence, -pass a list containing only one sentence: - -```python -# embed a sentence using glove. -from flair.data import Sentence -sentence = Sentence('The grass is green .') -glove_embedding.embed(sentences=[sentence]) - -# now check out the embedded tokens. -for token in sentence: - print(token) - print(token.embedding) -``` - -This prints out the tokens and their embeddings. GloVe embeddings are pytorch vectors of dimensionality 100. - -You choose which pre-trained embeddings you load by passing the appropriate -string you pass to the constructor of the `WordEmbeddings` class. Currently, the following static embeddings -are provided (more coming): - -| ID | Embedding | -| ------------- | ------------- | -| 'glove' | GloVe embeddings | -| 'extvec' | Komnios embeddings | -| 'ft-crawl' | FastText embeddings | -| 'ft-german' | German FastText embeddings | - -So, if you want to load German FastText embeddings, instantiate the method as follows: - -```python -german_embedding = WordEmbeddings('ft-german') -``` - -### Contextual String Embeddings - - -Contextual string embeddings are [powerful embeddings](https://drive.google.com/file/d/17yVpFA7MmXaQFTe-HDpZuqw9fJlmzg56/view?usp=sharing) - that capture latent syntactic-semantic information that goes beyond -standard word embeddings. Key differences are: (1) they are trained without any explicit notion of words and -thus fundamentally model words as sequences of characters. And (2) they are **contextualized** by their -surrounding text, meaning that the *same word will have different embeddings depending on its -contextual use*. - -With Flair, you can use these embeddings simply by instantiating the appropriate embedding class, same as before: - -```python - -# the CharLMEmbedding also inherits from the TextEmbeddings class -from flair.embeddings import CharLMEmbeddings -charlm_embedding_forward = CharLMEmbeddings('news-forward') - -# embed a sentence using CharLM. -from flair.data import Sentence -sentence = Sentence('The grass is green .') -charlm_embedding_forward.embed(sentences=[sentence]) -``` - -You choose which embeddings you load by passing the appropriate -string you pass to the constructor of the `CharLMEmbeddings` class. 
-
-You choose which embeddings to load by passing the appropriate
-string to the constructor of the `CharLMEmbeddings` class. Currently, the following contextual string embeddings
-are provided (more coming):
-
-| ID | Language | Embedding |
-| ------------- | ------------- | ------------- |
-| 'news-forward' | English | Forward LM embeddings over 1 billion word corpus |
-| 'news-backward' | English | Backward LM embeddings over 1 billion word corpus |
-| 'mix-forward' | English | Forward LM embeddings over mixed corpus (Web, Wikipedia, Subtitles) |
-| 'mix-backward' | English | Backward LM embeddings over mixed corpus (Web, Wikipedia, Subtitles) |
-| 'german-forward' | German | Forward LM embeddings over mixed corpus (Web, Wikipedia, Subtitles) |
-| 'german-backward' | German | Backward LM embeddings over mixed corpus (Web, Wikipedia, Subtitles) |
-
-So, if you want to load embeddings from the English news backward LM model, instantiate the class as follows:
-
-```python
-charlm_embedding_backward = CharLMEmbeddings('news-backward')
-```
-
-
-### Character Embeddings
-
-Some embeddings - such as character features - are not pre-trained but rather trained on the downstream task. Normally
-this requires you to implement a [hierarchical embedding architecture](http://neuroner.com/NeuroNERengine_with_caption_no_figure.png).
-
-With Flair, you need not worry about such things. Just choose the appropriate
-embedding class and character features will then automatically train during downstream task training.
-
-```python
-# the CharacterEmbeddings class also inherits from the TextEmbeddings class
-from flair.embeddings import CharacterEmbeddings
-embedder = CharacterEmbeddings()
-
-# embed a sentence.
-from flair.data import Sentence
-sentence = Sentence('The grass is green .')
-embedder.embed(sentences=[sentence])
-```
-
-## Stacked Embeddings
-
-Stacked embeddings are one of the most important concepts of this library. You can use them to combine different embeddings,
-for instance if you want to use traditional embeddings together with contextual string embeddings.
-Stacked embeddings allow you to mix and match. We find that a combination of embeddings often gives the best results.
-
-All you need to do is use the `StackedEmbeddings` class and instantiate it by passing a list of embeddings that you wish
-to combine. For instance, let's combine classic GloVe embeddings with embeddings from a forward and backward
-character language model.
-
-First, instantiate the three embeddings you wish to combine:
-
-```python
-# all embeddings inherit from the TextEmbeddings class
-from flair.embeddings import WordEmbeddings, CharLMEmbeddings
-
-# init GloVe embedding
-glove_embedding = WordEmbeddings('glove')
-
-# init CharLM embeddings
-charlm_embedding_forward = CharLMEmbeddings('news-forward')
-charlm_embedding_backward = CharLMEmbeddings('news-backward')
-```
-
-Now instantiate the `StackedEmbeddings` class and pass it a list containing these three embeddings.
-
-```python
-# now create the StackedEmbedding object that combines all embeddings
-from flair.embeddings import StackedEmbeddings
-stacked_embeddings = StackedEmbeddings(embeddings=[glove_embedding, charlm_embedding_forward, charlm_embedding_backward])
-```
-
-That's it! Now just use this embedding like all the other embeddings, i.e. call the `embed()` method over your sentences.
-
-```python
-# just embed a sentence using the StackedEmbedding as you would with any single embedding.
-from flair.data import Sentence
-sentence = Sentence('The grass is green .')
-stacked_embeddings.embed(sentences=[sentence])
-
-# now check out the embedded tokens.
-for token in sentence:
-    print(token)
-    print(token.embedding)
-```
-
-Words are now embedded using a concatenation of three different embeddings. This means that the resulting embedding
-vector is still a single Pytorch vector.
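A quick way to see the concatenation is to look at the vector sizes. In this sketch, the 100 dimensions for GloVe come from the note above; the sizes contributed by the two CharLM embeddings depend on the underlying language models:

```python
from flair.data import Sentence

# embed a fresh sentence with the stacked embedding from above
sentence = Sentence('The grass is green .')
stacked_embeddings.embed(sentences=[sentence])

# each token's vector length is the sum of the lengths of the three
# stacked embeddings, e.g. 100 (GloVe) plus the two CharLM sizes
for token in sentence:
    print(token.text, token.embedding.size())
```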
-
-## Reading an Evaluation Dataset
-
-Flair provides helper
-methods to read common NLP datasets, such as the CoNLL-03 and CoNLL-2000 evaluation datasets, and the
-CoNLL-U format. These might be interesting to you if you want to train your own sequence labelers.
-
-All helper methods for reading data are bundled in the `NLPTaskDataFetcher` class. One option is to follow
-the instructions for putting the training data in the appropriate folder structure, and use the prepared functions.
-For instance, if you want to use the CoNLL-03 data, get it from the task Web site
-and place train, test and dev data in `/resources/tasks/conll_03/` as follows:
-
-```
-/resources/tasks/conll_03/eng.testa
-/resources/tasks/conll_03/eng.testb
-/resources/tasks/conll_03/eng.train
-```
-
-This allows the `NLPTaskDataFetcher` class to read the data into our data structures. Use the `NLPTask` enum to select
-the dataset, as follows:
-
-```python
-corpus: TaggedCorpus = NLPTaskDataFetcher.fetch_data(NLPTask.CONLL_03)
-```
-
-This gives you a `TaggedCorpus` object that contains the data.
-
-However, this only works if the relative folder structure perfectly matches the presets. If not, or if you are using
-a different dataset, you can still use the built-in functions to read different CoNLL formats:
-
-```python
-# use your own data path
-data_folder = 'path/to/your/data'
-
-# get training, test and dev data
-sentences_train: List[Sentence] = NLPTaskDataFetcher.read_conll_sequence_labeling_data(data_folder + '/eng.train')
-sentences_dev: List[Sentence] = NLPTaskDataFetcher.read_conll_sequence_labeling_data(data_folder + '/eng.testa')
-sentences_test: List[Sentence] = NLPTaskDataFetcher.read_conll_sequence_labeling_data(data_folder + '/eng.testb')
-
-# create the corpus
-corpus = TaggedCorpus(sentences_train, sentences_dev, sentences_test)
-```
-
-The `TaggedCorpus` provides a number of useful helper functions. For instance, you can downsample the data by calling
-`downsample()` and passing a ratio. So, if you normally get a corpus like this:
-
-```python
-original_corpus: TaggedCorpus = NLPTaskDataFetcher.fetch_data(NLPTask.CONLL_03)
-```
-
-you can downsample the corpus like this:
-
-```python
-downsampled_corpus: TaggedCorpus = NLPTaskDataFetcher.fetch_data(NLPTask.CONLL_03).downsample(0.1)
-```
-
-If you print both corpora, you see that the second one has been downsampled to 10% of the data.
-
-```python
-print("--- 1 Original ---")
-print(original_corpus)
-
-print("--- 2 Downsampled ---")
-print(downsampled_corpus)
-```
-
-This should print:
-
-```console
---- 1 Original ---
-TaggedCorpus: 14987 train + 3466 dev + 3684 test sentences
-
---- 2 Downsampled ---
-TaggedCorpus: 1499 train + 347 dev + 369 test sentences
-```
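To sanity-check what was loaded, you can also peek at individual sentences. The sketch below assumes that a `TaggedCorpus` exposes its splits under the attribute names suggested by its constructor arguments; the `train` attribute name is an assumption:

```python
from flair.data import NLPTaskDataFetcher, NLPTask

# fetch and downsample as above
corpus = NLPTaskDataFetcher.fetch_data(NLPTask.CONLL_03).downsample(0.1)

# inspect the first training sentence and its gold NER tags
# (the 'train' attribute name is an assumption)
first_sentence = corpus.train[0]
print(first_sentence)
print(first_sentence.to_tag_string('ner'))
```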
-
-
-## Training a Model
-
-Here is example code for a small NER model trained over CoNLL-03 data, using simple GloVe embeddings.
-In this example, we downsample the data to 10% of the original data.
-
-```python
-from flair.data import NLPTaskDataFetcher, TaggedCorpus, NLPTask
-from flair.embeddings import WordEmbeddings
-import torch
-
-# 1. get the corpus
-corpus: TaggedCorpus = NLPTaskDataFetcher.fetch_data(NLPTask.CONLL_03).downsample(0.1)  # remove the downsample() call to train on the full data
-print(corpus)
-
-# 2. what tag do we want to predict?
-tag_type = 'ner'
-
-# 3. make the tag dictionary from the corpus
-tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
-print(tag_dictionary.idx2item)
-
-# initialize embeddings. In this case, simple GloVe embeddings
-embeddings = WordEmbeddings('glove')
-
-# initialize sequence tagger
-from flair.tagging_model import SequenceTaggerLSTM
-
-tagger = SequenceTaggerLSTM(hidden_size=256, embeddings=embeddings, tag_dictionary=tag_dictionary,
-                            use_crf=True)
-
-# put model on cuda if GPU is available (i.e. much faster training)
-if torch.cuda.is_available():
-    tagger = tagger.cuda()
-
-# initialize trainer
-from flair.trainer import TagTrain
-trainer = TagTrain(tagger, corpus, tag_type=tag_type, test_mode=False)
-
-# run training for 5 epochs
-trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=5, save_model=True,
-              train_with_dev=True, anneal_mode=True)
-```
-
-Alternatively, try using a stacked embedding with charLM and glove, over the full data, for 150 epochs.
-This will give you the state-of-the-art accuracy we report in the paper. To see the full code to reproduce experiments,
-check [here](/resources/docs/EXPERIMENTS.md).
\ No newline at end of file
diff --git a/resources/docs/TUTORIAL_BASICS.md b/resources/docs/TUTORIAL_BASICS.md
index a32cae5bf..c68862411 100644
--- a/resources/docs/TUTORIAL_BASICS.md
+++ b/resources/docs/TUTORIAL_BASICS.md
@@ -92,7 +92,7 @@ the word 'green'. This means that we've tagged this word as an entity of type color.
 sentence[4].add_tag('ner', 'color')
 
 # print the sentence with all tags of this type
-print(sentence.to_ner_string())
+print(sentence.to_tagged_string())
 ```
 
 This should print:
diff --git a/resources/docs/TUTORIAL_TAGGING.md b/resources/docs/TUTORIAL_TAGGING.md
index eb3288bf2..54719f2aa 100644
--- a/resources/docs/TUTORIAL_TAGGING.md
+++ b/resources/docs/TUTORIAL_TAGGING.md
@@ -9,9 +9,9 @@ This model was trained over the English CoNLL-03 task and can recognize 4 different entity
 types.
 
 ```python
-from flair.tagging_model import SequenceTaggerLSTM
+from flair.tagging_model import SequenceTagger
 
-tagger = SequenceTaggerLSTM.load('ner')
+tagger = SequenceTagger.load('ner')
 ```
 All you need to do is use the `predict()` method of the tagger on a sentence. This will add predicted tags to the tokens
 in the sentence. Let's use a sentence with two named
@@ -24,29 +24,99 @@ sentence = Sentence('George Washington went to Washington .')
 
 tagger.predict(sentence)
 
 # print sentence with predicted tags
-print(sentence.to_tag_string())
+print(sentence.to_tagged_string())
 ```
 
 This should print:
 
 ```console
-George <B-PER> Washington <E-PER> went to Washington <S-LOC> .
+George <B-PER> Washington <E-PER> went to Washington <S-LOC> .
 ```
 
 You choose which pre-trained model to load by passing the appropriate
-string to the `load()` method of the `SequenceTaggerLSTM` class. Currently, the following pre-trained models
+string to the `load()` method of the `SequenceTagger` class. Currently, the following pre-trained models
 are provided (more coming):
 
-| ID | Task + Training Dataset | Accuracy |
-| ------------- | ------------- | ------------- |
-| 'ner' | Conll-03 Named Entity Recognition (English) | **93.17** (F1) |
-| 'chunk' | Conll-2000 Syntactic Chunking (English) | **96.74** (F1) |
-| 'pos' | Ontonotes Part-of-Speech Tagging (English) | **98.06** (Accuracy) |
+| ID | Task | Language | Training Dataset | Accuracy |
+| ------------- | ------------- | ------------- | ------------- | ------------- |
+| 'ner' | 4-class Named Entity Recognition | English | Conll-03 | **93.18** (F1) |
+| 'ner-ontonotes' | 12-class Named Entity Recognition | English | Ontonotes | **89.62** (F1) |
+| 'chunk' | Syntactic Chunking | English | Conll-2000 | **96.68** (F1) |
+| 'pos' | Part-of-Speech Tagging | English | Ontonotes | **98.06** (Accuracy) |
+| 'frame' | Semantic Frame Detection (***Experimental***) | English | Propbank 3.0 | **98.00** (Accuracy) |
+| | | German | | |
+| 'de-ner' | 4-class Named Entity Recognition | German | Conll-03 | **88.29** (F1) |
+| 'de-ner-germeval' | 4+4-class Named Entity Recognition | German | Germeval | **84.53** (F1) |
+| 'de-pos' | Part-of-Speech Tagging | German | Universal Dependency Treebank | **94.67** (Accuracy) |
 
-So, if you want to use a `SequenceTaggerLSTM` that performs PoS tagging, instantiate the tagger as follows:
+
+So, if you want to use a `SequenceTagger` that performs PoS tagging, instantiate the tagger as follows:
 
 ```python
-tagger = SequenceTaggerLSTM.load('pos')
+tagger = SequenceTagger.load('pos')
+```
+
+## Tagging a German sentence
+
+As indicated in the list above, we also provide pre-trained models for languages other than English. Currently, we
+support German, and other languages are forthcoming. To tag a German sentence, just load the appropriate model:
+
+```python
+
+# load model
+tagger = SequenceTagger.load('de-ner')
+
+# make a German sentence
+sentence = Sentence('George Washington ging nach Washington .')
+
+# predict NER tags
+tagger.predict(sentence)
+
+# print sentence with predicted tags
+print(sentence.to_tagged_string())
+```
+This should print:
+```console
+George <B-PER> Washington <E-PER> ging nach Washington <S-LOC> .
+```
+
+## Experimental: Semantic Frame Detection
+
+For English, we now provide a pre-trained model that detects semantic frames in text, trained using Propbank 3.0 frames.
+This provides a sort of word sense disambiguation for frame evoking words, and we are curious what researchers might
+do with this.
+
+Here's an example:
+
+```python
+# load model
+tagger = SequenceTagger.load('frame')
+
+# make two English sentences
+sentence_1 = Sentence('George returned to Berlin to return his hat .')
+sentence_2 = Sentence('He had a look at different hats .')
+
+# predict frame tags
+tagger.predict(sentence_1)
+tagger.predict(sentence_2)
+
+# print sentences with predicted tags
+print(sentence_1.to_tagged_string())
+print(sentence_2.to_tagged_string())
+```
+This should print:
+
+```console
+George returned <return.01> to Berlin to return <return.02> his hat .
+
+He had <have.LV> a look <look.01> at different hats .
+```
+
+As we can see, the frame detector makes a distinction in sentence 1 between two different meanings of the word 'return'.
+'return.01' means returning to a location, while 'return.02' means giving something back.
+
+Similarly, in sentence 2 the frame detector finds a light verb construction in which 'have' is the light verb and
+'look' is a frame evoking word.
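Because each `predict()` call writes its tags onto the tokens under its own tag type, you should also be able to run several taggers over the same sentence and print all layers at once. A sketch; the choice of models is illustrative:

```python
from flair.data import Sentence
from flair.tagging_model import SequenceTagger

# load two different taggers
ner_tagger = SequenceTagger.load('ner')
frame_tagger = SequenceTagger.load('frame')

sentence = Sentence('George returned to Berlin to return his hat .')

# each tagger adds its own layer of annotation to the same tokens
ner_tagger.predict(sentence)
frame_tagger.predict(sentence)

# tokens carrying several tags show all of them in the tagged string
print(sentence.to_tagged_string())
```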
+
 
 ## Tagging a List of Sentences
 
diff --git a/resources/docs/TUTORIAL_TRAINING_A_MODEL.md b/resources/docs/TUTORIAL_TRAINING_A_MODEL.md
index ed4a2ee82..1136c2c8b 100644
--- a/resources/docs/TUTORIAL_TRAINING_A_MODEL.md
+++ b/resources/docs/TUTORIAL_TRAINING_A_MODEL.md
@@ -87,11 +87,12 @@ In this example, we downsample the data to 10% of the original data.
 
 ```python
 from flair.data import NLPTaskDataFetcher, TaggedCorpus, NLPTask
-from flair.embeddings import WordEmbeddings
+from flair.embeddings import TextEmbeddings, WordEmbeddings, StackedEmbeddings, CharLMEmbeddings, CharacterEmbeddings
+from typing import List
 import torch
 
 # 1. get the corpus
-corpus: TaggedCorpus = NLPTaskDataFetcher.fetch_data(NLPTask.CONLL_03).downsample(0.1)  # remove the last bit to not downsample
+corpus: TaggedCorpus = NLPTaskDataFetcher.fetch_data(NLPTask.CONLL_03).downsample(0.1)
 print(corpus)
 
 # 2. what tag do we want to predict?
@@ -101,26 +102,41 @@ tag_type = 'ner'
 tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
 print(tag_dictionary.idx2item)
 
-# initialize embeddings. In this case, simple GloVe embeddings
-embeddings = WordEmbeddings('glove')
+# initialize embeddings
+embedding_types: List[TextEmbeddings] = [
+
+    WordEmbeddings('glove')
+
+    # comment in this line to use character embeddings
+    # , CharacterEmbeddings()
+
+    # comment in these lines to use contextual string embeddings
+    # ,
+    # CharLMEmbeddings('news-forward')
+    # ,
+    # CharLMEmbeddings('news-backward')
+]
+
+embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types)
 
 # initialize sequence tagger
-from flair.tagging_model import SequenceTaggerLSTM
+from flair.tagging_model import SequenceTagger
 
-tagger = SequenceTaggerLSTM(hidden_size=256, embeddings=embeddings, tag_dictionary=tag_dictionary,
-                            use_crf=True)
-
-# put model on cuda if GPU is available (i.e. much faster training)
+tagger: SequenceTagger = SequenceTagger(hidden_size=256,
+                                        embeddings=embeddings,
+                                        tag_dictionary=tag_dictionary,
+                                        tag_type=tag_type,
+                                        use_crf=True)
 if torch.cuda.is_available():
     tagger = tagger.cuda()
 
 # initialize trainer
 from flair.trainer import TagTrain
-trainer = TagTrain(tagger, corpus, tag_type=tag_type, test_mode=False)
-# run training for 5 epochs
-trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=5, save_model=True,
-              train_with_dev=True, anneal_mode=True)
+trainer: TagTrain = TagTrain(tagger, corpus, test_mode=True)
+
+trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=150, save_model=False,
+              train_with_dev=False, anneal_mode=False)
 ```
 
 Alternatively, try using a stacked embedding with charLM and glove, over the full data, for 150 epochs.
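For reference, the embedding list with every option commented in corresponds to what `train.py` below uses; as a self-contained sketch:

```python
from typing import List
from flair.embeddings import TextEmbeddings, WordEmbeddings, StackedEmbeddings, CharLMEmbeddings, CharacterEmbeddings

# stack GloVe, task-trained character features and both CharLM directions
embedding_types: List[TextEmbeddings] = [
    WordEmbeddings('glove'),
    CharacterEmbeddings(),
    CharLMEmbeddings('news-forward'),
    CharLMEmbeddings('news-backward'),
]

embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types)
```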
diff --git a/resources/docs/TUTORIAL_WORD_EMBEDDING.md b/resources/docs/TUTORIAL_WORD_EMBEDDING.md
index 8823a41dc..0c9adfe2a 100644
--- a/resources/docs/TUTORIAL_WORD_EMBEDDING.md
+++ b/resources/docs/TUTORIAL_WORD_EMBEDDING.md
@@ -49,14 +49,16 @@ You choose which pre-trained embeddings you load by passing the appropriate
 string to the constructor of the `WordEmbeddings` class. Currently, the following static embeddings
 are provided (more coming):
 
-| ID | Embedding |
-| ------------- | ------------- |
-| 'en-glove' (or 'glove') | GloVe embeddings |
-| 'en-numberbatch' (or 'numberbatch') | [Numberbatch](https://github.com/commonsense/conceptnet-numberbatch) embeddings |
-| 'en-extvec' (or 'extvec') | Komninos embeddings |
-| 'en-crawl' (or 'crawl') | FastText embeddings over Web crawls |
-| 'de-fasttext' | German FastText embeddings |
-| 'sv-fasttext' | Swedish FastText embeddings |
+| ID | Language | Embedding |
+| ------------- | ------------- | ------------- |
+| 'en-glove' (or 'glove') | English | GloVe embeddings |
+| 'en-numberbatch' (or 'numberbatch') | English | [Numberbatch](https://github.com/commonsense/conceptnet-numberbatch) embeddings |
+| 'en-extvec' (or 'extvec') | English | Komninos embeddings |
+| 'en-crawl' (or 'crawl') | English | FastText embeddings over Web crawls |
+| 'en-news' (or 'news') | English | FastText embeddings over news and wikipedia data |
+| 'de-fasttext' | German | German FastText embeddings |
+| 'de-numberbatch' | German | German Numberbatch embeddings |
+| 'sv-fasttext' | Swedish | Swedish FastText embeddings |
 
 So, if you want to load German FastText embeddings, instantiate the class as follows:
@@ -85,7 +87,7 @@ charlm_embedding_forward = CharLMEmbeddings('news-forward')
 
 # embed a sentence using CharLM.
 from flair.data import Sentence
 sentence = Sentence('The grass is green .')
-charlm_embedding_forward.embed(sentences=[sentence])
+charlm_embedding_forward.embed(sentence)
 ```
 
 You choose which embeddings you load by passing the appropriate
@@ -125,7 +127,7 @@ embedder = CharacterEmbeddings()
 
 # embed a sentence.
 from flair.data import Sentence
 sentence = Sentence('The grass is green .')
-embedder.embed(sentences=[sentence])
+embedder.embed(sentence)
 ```
 
 # Stacked Embeddings
@@ -166,7 +168,7 @@ That's it! Now just use this embedding like all the other embeddings, i.e. call
 # just embed a sentence using the StackedEmbedding as you would with any single embedding.
 from flair.data import Sentence
 sentence = Sentence('The grass is green .')
-stacked_embeddings.embed(sentences=[sentence])
+stacked_embeddings.embed(sentence)
 
 # now check out the embedded tokens.
 for token in sentence:
diff --git a/train.py b/train.py
index a7a06660b..8913c41be 100644
--- a/train.py
+++ b/train.py
@@ -4,7 +4,7 @@ import torch
 
 # 1. get the corpus
-corpus: TaggedCorpus = NLPTaskDataFetcher.fetch_data(NLPTask.CONLL_03).downsample(0.1)
+corpus: TaggedCorpus = NLPTaskDataFetcher.fetch_data(NLPTask.CONLL_03).downsample(0.01)
 print(corpus)
 
 # 2. what tag do we want to predict?
@@ -20,29 +20,32 @@ WordEmbeddings('glove') # comment in this line to use character embeddings - # , CharacterEmbeddings() + , CharacterEmbeddings() # comment in these lines to use contextual string embeddings - # , - # CharLMEmbeddings('news-forward') - # , - # CharLMEmbeddings('news-backward') + , + CharLMEmbeddings('news-forward') + , + CharLMEmbeddings('news-backward') ] embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types) # initialize sequence tagger -from flair.tagging_model import SequenceTaggerLSTM +from flair.tagging_model import SequenceTagger -tagger: SequenceTaggerLSTM = SequenceTaggerLSTM(hidden_size=256, embeddings=embeddings, tag_dictionary=tag_dictionary, - use_crf=True) +tagger: SequenceTagger = SequenceTagger(hidden_size=256, + embeddings=embeddings, + tag_dictionary=tag_dictionary, + tag_type=tag_type, + use_crf=True) if torch.cuda.is_available(): tagger = tagger.cuda() # initialize trainer from flair.trainer import TagTrain -trainer: TagTrain = TagTrain(tagger, corpus, tag_type=tag_type, test_mode=True) +trainer: TagTrain = TagTrain(tagger, corpus, test_mode=True) -trainer.train('resources/taggers/example-pos', mini_batch_size=32, max_epochs=150, save_model=False, +trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=150, save_model=False, train_with_dev=False, anneal_mode=False) diff --git a/tutorial.ipynb b/tutorial.ipynb deleted file mode 100644 index 2d8f5493b..000000000 --- a/tutorial.ipynb +++ /dev/null @@ -1,619 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "collapsed": true - }, - "source": [ - "# Tutorial\n", - "\n", - "This tutorial takes you through the Flair library. \n", - "\n", - "## NLP base types\n", - "\n", - "The Sentence object is the central object to our library. It holds a Sentence, consisting of Tokens. To this object, various layers of linguistic annotation may be added. This is also the central object for embedding your text.\n", - "\n", - "Let's illustrate this with an example sentence." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Sentence: \"The grass is green .\" - 5 Tokens\n" - ] - } - ], - "source": [ - "# The sentence objects holds a sentence that we may want to embed\n", - "from flair.data import Sentence\n", - "\n", - "# Make a sentence object by passing a whitespace tokenized string\n", - "sentence = Sentence('The grass is green .')\n", - "\n", - "# Print the object to see what's in there\n", - "print(sentence)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Each word in a sentence is a Token object. You can directly access a token using the token_id. Each token has attributes, such as an id and a text." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Token: 4 green\n" - ] - } - ], - "source": [ - "print(sentence[4])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You can also iterate over all tokens in a sentence." 
- ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Token: 1 The\nToken: 2 grass\nToken: 3 is\nToken: 4 green\nToken: 5 .\n" - ] - } - ], - "source": [ - "for token in sentence:\n", - " print(token)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Tokens can also have tags, such as a named entity tag. In this example, we're adding an NER tag of type 'color' to \n", - "the word 'green' in the example sentence.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The grass is green .\n" - ] - } - ], - "source": [ - "# add a tag to a word in the sentence\n", - "sentence[4].add_tag('ner', 'color')\n", - "\n", - "# print the sentence with all tags of this type\n", - "print(sentence.to_ner_string())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Tagging with Pre-Trained Models\n", - "\n", - "Now, lets use a pre-trained model for named entity recognition (NER). \n", - "This model was trained over the English CoNLL-03 task and can recognize 4 different entity\n", - "types.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [], - "source": [ - "from flair.tagging_model import SequenceTaggerLSTM\n", - "\n", - "tagger = SequenceTaggerLSTM.load('ner')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "All you need to do is use the `predict()` method of the tagger on a sentence. This will add predicted tags to the tokens\n", - "in the sentence. Lets use a sentence with two named\n", - "entities: " - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "George Washington went to Washington .\n" - ] - } - ], - "source": [ - "sentence = Sentence('George Washington went to Washington .')\n", - "\n", - "# predict NER tags\n", - "tagger.predict(sentence)\n", - "\n", - "# print sentence with predicted tags\n", - "print(sentence.to_tag_string())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You chose which pre-trained model you load by passing the appropriate \n", - "string you pass to the `load()` method of the `SequenceTaggerLSTM` class. Currently, the following pre-trained models\n", - "are provided (more coming): \n", - " \n", - "\n", - " 'ner' : Conll-03 Named Entity Recognition (English) \n", - "\n", - " 'chunk' : Conll-2000 Syntactic Chunking (English) \n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Embeddings\n", - "\n", - "We provide a set of classes with which you can embed the words in sentences in various ways. Note that all embedding \n", - "classes inherit from the `TextEmbeddings` class and implement the `embed()` method which you need to call \n", - "to embed your text. This means that for most users of Flair, the complexity of different embeddings remains hidden \n", - "behind this interface. 
Simply instantiate the embedding class you require and call `embed()` to embed your text.\n", - "\n", - "All embeddings produced with our methods are pytorch vectors, so they can be immediately used for training and \n", - "fine-tuning.\n", - "\n", - "### Classic Word Embeddings\n", - "\n", - "Classic word embeddings are static and word-level, meaning that each distinc word gets exactly one pre-computed \n", - "embedding. Most embeddings fall under this class, including the popular GloVe or Komnios embeddings. \n", - "\n", - "Simply instantiate the WordEmbeddings class and pass a string identifier of the embedding you wish to load. So, if \n", - "you want to use GloVe embeddings, pass the string 'glove' to the constructor: " - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "# all embeddings inherit from the TextEmbeddings class. Init a simple glove embedding.\n", - "from flair.embeddings import WordEmbeddings\n", - "glove_embedding = WordEmbeddings('glove')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, create an example sentence and call the embedding's `embed()` method. You always pass a list of sentences to \n", - "this method since some embedding types make use of batching to increase speed. So if you only have one sentence, \n", - "pass a list containing only one sentence:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Token: 1 The\ntorch.Size([100])\nToken: 2 grass\ntorch.Size([100])\nToken: 3 is\ntorch.Size([100])\nToken: 4 green\ntorch.Size([100])\nToken: 5 .\ntorch.Size([100])\n" - ] - } - ], - "source": [ - "# embed a sentence using glove.\n", - "from flair.data import Sentence\n", - "sentence = Sentence('The grass is green .')\n", - "glove_embedding.embed(sentences=[sentence])\n", - "\n", - "# now check out the embedded tokens.\n", - "for token in sentence:\n", - " print(token)\n", - " print(token.embedding.size())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This prints out the tokens and their embeddings. GloVe embeddings are pytorch vectors of dimensionality 100.\n", - "\n", - "You choose which pre-trained embeddings you load by passing the appropriate \n", - "string you pass to the constructor of the `WordEmbeddings` class. Currently, the following static embeddings\n", - "are provided (more coming): \n", - " \n", - "'glove' : GloVe embeddings \n", - "\n", - "'extvec' : Komnios embeddings \n", - "\n", - "'ft-crawl' : FastText embeddings \n", - "\n", - "'ft-german' : German FastText embeddings \n", - "\n", - "So, if you want to load German FastText embeddings, instantiate the method as follows:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "german_embedding = WordEmbeddings('ft-german')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Contextual String Embeddings\n", - "\n", - "\n", - "Contextual string embeddings are [powerful embeddings](https://drive.google.com/file/d/17yVpFA7MmXaQFTe-HDpZuqw9fJlmzg56/view?usp=sharing)\n", - " that capture latent syntactic-semantic information that goes beyond\n", - "standard word embeddings. Key differences are: (1) they are trained without any explicit notion of words and\n", - "thus fundamentally model words as sequences of characters. 
And (2) they are **contextualized** by their\n", - "surrounding text, meaning that the *same word will have different embeddings depending on its\n", - "contextual use*.\n", - "\n", - "With Flair, you can use these embeddings simply by instantiating the appropriate embedding class, same as before:" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "FORWARD language mode loaded\non cuda:\nFalse\n" - ] - }, - { - "data": { - "text/plain": [ - "[]" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# the CharLMEmbedding also inherits from the TextEmbeddings class\n", - "from flair.embeddings import CharLMEmbeddings\n", - "charlm_embedding_forward = CharLMEmbeddings('news-forward')\n", - "\n", - "# embed a sentence using CharLM.\n", - "from flair.data import Sentence\n", - "sentence = Sentence('The grass is green .')\n", - "charlm_embedding_forward.embed(sentences=[sentence])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You choose which embeddings you load by passing the appropriate \n", - "string you pass to the constructor of the `CharLMEmbeddings` class. Currently, the following contextual string\n", - " embeddings\n", - "are provided (more coming): \n", - " \n", - "| ID | Language | Embedding | \n", - "| ------------- | ------------- | ------------- |\n", - "| 'news-forward' | English | Forward LM embeddings over 1 billion word corpus |\n", - "| 'news-backward' | English | Backward LM embeddings over 1 billion word corpus |\n", - "| 'mix-forward' | English | Forward LM embeddings over mixed corpus (Web, Wikipedia, Subtitles) |\n", - "| 'mix-backward' | English | Backward LM embeddings over mixed corpus (Web, Wikipedia, Subtitles) |\n", - "| 'german-forward' | German | Forward LM embeddings over mixed corpus (Web, Wikipedia, Subtitles) |\n", - "| 'german-backward' | German | Backward LM embeddings over mixed corpus (Web, Wikipedia, Subtitles) |\n", - "\n", - "So, if you want to load embeddings from the English news backward LM model, instantiate the method as follows:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "BACKWARD language mode loaded\non cuda:\nFalse\n" - ] - } - ], - "source": [ - "charlm_embedding_backward = CharLMEmbeddings('news-backward')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Character Embeddings\n", - "\n", - "Some embeddings - such as character-features - are not pre-trained but rather trained on the downstream task. Normally\n", - "this requires you to implement a [hierarchical embedding architecture](http://neuroner.com/NeuroNERengine_with_caption_no_figure.png). \n", - "\n", - "With Flair, you need not worry about such things. Just choose the appropriate\n", - "embedding class and character features will then automatically train during downstream task training. 
\n" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[]" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# the CharLMEmbedding also inherits from the TextEmbeddings class\n", - "from flair.embeddings import CharacterEmbeddings\n", - "embedder = CharacterEmbeddings()\n", - "\n", - "# embed a sentence using CharLM.\n", - "from flair.data import Sentence\n", - "sentence = Sentence('The grass is green .')\n", - "embedder.embed(sentences=[sentence])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n", - "## Stacked Embeddings\n", - "\n", - "Stacked embeddings are one of the most important concepts of this library. You can use them to combine different embeddings\n", - "together, for instance if you want to use both traditional embeddings together with contextual sting embeddings. \n", - "Stacked embeddings allow you to mix and match. We find that a combination of embeddings often gives best results. \n", - "\n", - "All you need to do is use the `StackedEmbeddings` class and instantiate it by passing a list of embeddings that you wish \n", - "to combine. For instance, lets combine classic GloVe embeddings with embeddings from a forward and backward \n", - "character language model.\n", - "\n", - "First, instantiate the three embeddings you wish to combine: " - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "FORWARD language mode loaded\non cuda:\nFalse\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "BACKWARD language mode loaded\non cuda:\nFalse\n" - ] - } - ], - "source": [ - "# the CharLMEmbedding also inherits from the TextEmbeddings class\n", - "from flair.embeddings import WordEmbeddings, CharLMEmbeddings\n", - "\n", - "# init GloVe embedding\n", - "glove_embedding = WordEmbeddings('glove')\n", - "\n", - "# init CharLM embedding\n", - "charlm_embedding_forward = CharLMEmbeddings('news-forward')\n", - "charlm_embedding_backward = CharLMEmbeddings('news-backward')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now instantiate the `StackedEmbeddings` class and pass it a list containing these three embeddings.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "# now create the StackedEmbedding object that combines all embeddings\n", - "from flair.embeddings import StackedEmbeddings\n", - "stacked_embeddings = StackedEmbeddings(embeddings=[glove_embedding, charlm_embedding_forward, charlm_embedding_backward])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "That's it! Now just use this embedding like all the other embeddings, i.e. 
call the `embed()` method over your sentences.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Token: 1 The\ntensor([-3.8194e-02, -2.4487e-01, 7.2812e-01, ..., -2.5692e-05,\n -5.9604e-03, -2.5547e-03])\nToken: 2 grass\ntensor([-8.1353e-01, 9.4042e-01, -2.4048e-01, ..., -6.7730e-05,\n -3.0360e-03, -1.3282e-02])\nToken: 3 is\ntensor([-5.4264e-01, 4.1476e-01, 1.0322e+00, ..., -6.5714e-03,\n -3.5937e-03, -1.4478e-03])\nToken: 4 green\ntensor([-6.7907e-01, 3.4908e-01, -2.3984e-01, ..., -2.2562e-05,\n -1.0895e-04, -4.3916e-03])\nToken: 5 .\ntensor([-3.3979e-01, 2.0941e-01, 4.6348e-01, ..., 4.1382e-05,\n -4.4364e-04, -2.5425e-02])\n" - ] - } - ], - "source": [ - "# just embed a sentence using the StackedEmbedding as you would with any single embedding.\n", - "from flair.data import Sentence\n", - "sentence = Sentence('The grass is green .')\n", - "stacked_embeddings.embed(sentences=[sentence])\n", - "\n", - "# now check out the embedded tokens.\n", - "for token in sentence:\n", - " print(token)\n", - " print(token.embedding)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Words are now embedding using a concatenation of three different embeddings. This means that the resulting embedding\n", - "vector is still a single Pytorch vector. \n", - "\n", - "\n", - "## Training a Model\n", - "\n", - "Here is example code for a small NER model trained over CoNLL-03 data, using simple GloVe embeddings.\n", - "In this example, we downsample the data to 10% of the original data. \n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3.0 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.4" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file