Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
0 parents
commit 9857eb7
Showing 13 changed files with 752 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
PROJECT_NAME=nlp_completer |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
__pycache__/ | ||
.ipynb_checkpoints/ | ||
|
||
# MacOS | ||
.DS_Store | ||
._* | ||
|
||
# Windows | ||
Thumbs.db | ||
ehthumbs.db | ||
$RECYCLE.BIN/ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Infer masked token with BERT | ||
|
||
## Introduction | ||
|
||
BERT, or Bidirectional Encoder Representations from Transformers, is a new method of pre-training language representations which obtains state-of-the-art results on a wide array of Natural Language Processing (NLP) tasks. | ||
|
||
To train a deep bidirectional representation, the BERT authors randomly mask some of the input tokens and then train the model to predict them. This procedure is known as a **masked LM** (MLM). | ||
|
||
In this repository, I have implemented **a basic script to infer the masked token** in a sentence. | ||
|
||
Read the BERT paper: https://arxiv.org/pdf/1810.04805.pdf | ||
|
||
## Code | ||
|
||
This repository includes Bash commands that run the Python code in a Docker container. The Python scripts themselves are in [./code](./code). | ||
|
||
The notebooks in [./notebooks](./notebooks) walk through the code step by step. They are also available on Kaggle: | ||
- https://www.kaggle.com/dimasmunoz/infer-masked-token-with-bert | ||
|
||
## Commands | ||
|
||
> Note: These commands have been tested on macOS and in Git Bash (Windows). | ||
You can start and stop the Docker container with these two commands: | ||
```sh | ||
sh manager.sh docker:run | ||
sh manager.sh docker:down | ||
``` | ||
|
||
Once the Docker container is running, execute the BERT script as follows: | ||
```sh | ||
sh manager.sh bert | ||
``` | ||
|
||
The script then displays a prompt where you can type a sentence containing a `[MASK]` token. For example: | ||
``` | ||
$ sh manager.sh bert | ||
Downloading: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 996k/996k [00:02<00:00, 485kB/s] | ||
Downloading: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [00:00<00:00, 176kB/s] | ||
Downloading: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 714M/714M [01:07<00:00, 10.7MB/s] | ||
Input text: This is a [MASK] model | ||
---------------------------------------- | ||
mathematical 0.04100513085722923 | ||
model 0.03972144424915314 | ||
single 0.01666860282421112 | ||
similar 0.014901496469974518 | ||
common 0.01419056300073862 | ||
======================================== | ||
... | ||
``` | ||
|
||
Have fun! ᕙ (° ~ ° ~) |
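
The masked LM procedure mentioned in the introduction can be illustrated with a small sketch. It follows the rule described in the BERT paper (select about 15% of the tokens; replace 80% of those with `[MASK]`, 10% with a random token, and leave 10% unchanged). This is not code from this repository: the function name `mask_tokens`, the toy vocabulary, and the use of whitespace tokens instead of WordPiece tokens are assumptions made for illustration.

```python
import random

MASK_TOKEN = "[MASK]"


def mask_tokens(tokens, vocab, mask_prob=0.15, seed=None):
    """Return (masked_tokens, labels) using the 80/10/10 rule from the BERT paper.

    labels[i] is the original token at each selected position and None elsewhere.
    Hypothetical helper for illustration; not part of this repository.
    """
    rng = random.Random(seed)
    masked, labels = [], []
    for token in tokens:
        # Leave most tokens untouched; only about mask_prob of them are selected.
        if rng.random() >= mask_prob:
            masked.append(token)
            labels.append(None)
            continue
        labels.append(token)
        roll = rng.random()
        if roll < 0.8:
            masked.append(MASK_TOKEN)         # 80%: replace with [MASK]
        elif roll < 0.9:
            masked.append(rng.choice(vocab))  # 10%: replace with a random token
        else:
            masked.append(token)              # 10%: keep the original token
    return masked, labels


if __name__ == "__main__":
    sentence = "this is a simple language model".split()
    toy_vocab = ["cat", "dog", "tree"]  # stand-in for a real WordPiece vocabulary
    print(mask_tokens(sentence, toy_vocab, mask_prob=0.5, seed=1))
```

During pre-training the model is asked to recover `labels` at the selected positions. The script in this repository only performs the inference side: you place `[MASK]` yourself and the model ranks candidate tokens.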
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import numpy as np | ||
import torch | ||
from transformers import BertTokenizer, BertForMaskedLM | ||
|
||
MODEL_NAME = 'bert-base-multilingual-cased' | ||
|
||
|
||
def get_topk_predictions(model, tokenizer, text, topk=5): | ||
    encoded_input = tokenizer(text, return_tensors='pt') | ||
    # Pass tensors by keyword: positionally, attention_mask precedes token_type_ids | ||
    logits = model(input_ids=encoded_input['input_ids'], | ||
                   token_type_ids=encoded_input['token_type_ids'], | ||
                   attention_mask=encoded_input['attention_mask'])[0] | ||
|
||
    logits = logits.squeeze(0) | ||
    probs = torch.softmax(logits, dim=-1) | ||
|
||
    mask_cnt = 0 | ||
    token_ids = encoded_input['input_ids'][0] | ||
|
||
    top_preds = [] | ||
|
||
    for idx, _ in enumerate(token_ids): | ||
        if token_ids[idx] == tokenizer.mask_token_id: | ||
            mask_cnt += 1 | ||
|
||
            topk_prob, topk_indices = torch.topk(probs[idx, :], topk) | ||
            topk_indices = topk_indices.cpu().numpy() | ||
            topk_tokens = tokenizer.convert_ids_to_tokens(topk_indices) | ||
            for prob, tok_str, tok_id in zip(topk_prob, topk_tokens, topk_indices): | ||
                top_preds.append({'token_str': tok_str, | ||
                                  'token_id': tok_id, | ||
                                  'probability': float(prob)}) | ||
|
||
    return top_preds | ||
|
||
def display_topk_predictions(model, tokenizer, text): | ||
    top_preds = get_topk_predictions(model, tokenizer, text) | ||
    for item in top_preds: | ||
        print('{} {}'.format(item['token_str'], item['probability'])) | ||
|
||
if __name__ == '__main__': | ||
    tokenizer = BertTokenizer.from_pretrained(MODEL_NAME) | ||
    model = BertForMaskedLM.from_pretrained(MODEL_NAME) | ||
|
||
    while True: | ||
        print('') | ||
        text = input('Input text: ').strip() | ||
        print('-' * 40) | ||
        display_topk_predictions(model, tokenizer, text) | ||
        print('=' * 40) |
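
The ranking step in `get_topk_predictions` (a softmax over the vocabulary logits at the `[MASK]` position, followed by a top-k selection) can be sketched without PyTorch. The version below works on plain Python lists and a toy four-word vocabulary; the helper names `softmax` and `top_k` and the example logits are assumptions for illustration, not part of the script.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs, vocab, k=5):
    """Return the k most probable (token, probability) pairs, best first."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(vocab[i], probs[i]) for i in ranked[:k]]


if __name__ == "__main__":
    toy_vocab = ["model", "single", "common", "rare"]  # hypothetical vocabulary
    toy_logits = [2.0, 1.0, 0.1, -1.0]                 # hypothetical scores at the [MASK] position
    for token, prob in top_k(softmax(toy_logits), toy_vocab, k=2):
        print(token, prob)
```

In the real script the same two steps are performed by `torch.softmax` and `torch.topk` over the full vocabulary of the tokenizer, and `tokenizer.convert_ids_to_tokens` maps the winning indices back to strings.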
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
version: '3.6' | ||
|
||
services: | ||
  python: | ||
    container_name: python_${PROJECT_USER} | ||
    build: | ||
      context: ./misc/dockerfiles/python | ||
      dockerfile: Dockerfile | ||
    env_file: | ||
      - .env | ||
    volumes: | ||
      - ./code:/App/code | ||
    networks: | ||
      - backend | ||
|
||
networks: | ||
  backend: |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#!/bin/bash | ||
|
||
MANAGER_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" | ||
|
||
cd $MANAGER_DIR | ||
case $1 in | ||
docker:run) | ||
${MANAGER_DIR}/misc/bin/run.sh | ||
;; | ||
docker:down) | ||
${MANAGER_DIR}/misc/bin/down.sh | ||
;; | ||
bert) | ||
${MANAGER_DIR}/misc/bin/python-container.sh bert | ||
;; | ||
*) | ||
echo "Error: The command does not exist!!" | ||
exit 1 | ||
;; | ||
esac |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
#!/bin/bash | ||
|
||
SCRIPT_PATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" | ||
source ${SCRIPT_PATH}/environment-vars.sh | ||
|
||
cd $SCRIPT_PATH/../../ | ||
docker-compose -f docker-compose.yml -p $PROJECT_USER down | ||
cd - |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
#!/usr/bin/env bash | ||
|
||
SCRIPT_PATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" | ||
|
||
# Including .env file | ||
set -o allexport | ||
source $SCRIPT_PATH/../../.env | ||
set +o allexport | ||
|
||
# Fix $USER variable if needed (Git Bash Windows) | ||
if [[ "$OSTYPE" == "msys" ]]; then | ||
export USER="$(whoami)" | ||
fi | ||
|
||
# Add project variable | ||
export PROJECT_USER="${PROJECT_NAME}_${USER}" | ||
|
||
# Function to execute Docker commands | ||
function idocker() { | ||
# $1: container name | ||
# $2: command, e.g. "sh" or "python" | ||
# $3: parameter of the previous command | ||
|
||
params=() | ||
i=0 | ||
for c in "$@"; do | ||
if [[ $i -lt 4 ]]; then | ||
params+=("${c}") | ||
else | ||
params+=("\"${c}\"") | ||
fi | ||
|
||
let i=i+1 | ||
done | ||
|
||
if [[ "$OSTYPE" == "msys" ]]; then | ||
# Add winpty if Windows Git Bash | ||
command="winpty docker exec -it ${params[@]}" | ||
else | ||
command="docker exec -it ${params[@]}" | ||
fi | ||
|
||
eval $command | ||
} |
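
For readers more comfortable with Python than Bash, the job of `idocker` (assemble a `docker exec -it` command and prefix it with `winpty` under Git Bash) can be sketched as follows. This is an illustrative equivalent, not a replacement shipped with the repository: the function names, the `ostype` parameter, and the user name `alice` in the container name are hypothetical.

```python
import shlex


def build_docker_exec(container, command, args=(), ostype="linux-gnu"):
    """Build the argument list for `docker exec -it <container> <command> <args...>`."""
    argv = ["docker", "exec", "-it", container, command, *args]
    if ostype == "msys":
        # Git Bash on Windows needs winpty for an interactive TTY
        argv = ["winpty", *argv]
    return argv


def to_shell(argv):
    """Render the argument list as a single, safely quoted shell string."""
    return " ".join(shlex.quote(arg) for arg in argv)


if __name__ == "__main__":
    argv = build_docker_exec("python_nlp_completer_alice", "python",
                             ["code/bert_infer.py"])
    print(to_shell(argv))
```

Keeping the command as a list and passing it to `subprocess.run(argv)` would avoid both `eval` and manual quoting, because each element reaches the program as exactly one argument.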
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#!/bin/bash | ||
|
||
SCRIPT_PATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" | ||
source ${SCRIPT_PATH}/environment-vars.sh | ||
|
||
if [ "$1" = 'bert' ] | ||
then | ||
idocker python_${PROJECT_USER} python code/bert_infer.py | ||
else | ||
idocker python_${PROJECT_USER} bash | ||
fi |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#!/bin/bash | ||
|
||
SCRIPT_PATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" | ||
source ${SCRIPT_PATH}/environment-vars.sh | ||
|
||
if [ "${no_down}" != true ] | ||
then | ||
$SCRIPT_PATH/down.sh | ||
fi | ||
|
||
cd $SCRIPT_PATH/../../ | ||
docker-compose -f docker-compose.yml -p $PROJECT_USER up -d --build | ||
cd - |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
FROM python:3.7.5-buster | ||
|
||
# Workdir | ||
ENV WORKDIR /App | ||
WORKDIR $WORKDIR | ||
|
||
# Install/update dependencies | ||
RUN apt-get update && \ | ||
apt-get install -y apt-utils autoconf build-essential curl git libssl-dev unzip vim zip gnupg wget && \ | ||
apt-get -y install build-essential | ||
RUN python -m pip install --upgrade pip | ||
|
||
# Install Python libraries | ||
COPY requirements.txt requirements.txt | ||
RUN pip install -r requirements.txt | ||
|
||
# Do not stop the container | ||
ENTRYPOINT ["tail", "-f", "/dev/null"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
numpy==1.18.4 | ||
torchvision==0.5.0 | ||
transformers==3.0.2 |