Adapted megatronlm server implementation for interacting with lm eval… #8

Open

wants to merge 30 commits into base: main

Changes from 4 commits

Commits (30)
3527e6e
Adapted megatronlm server implementation for interacting with lm eval…
KlaudiaTH Aug 3, 2023
62c0def
Merge branch 'main' into megatron_lmeval_server
KlaudiaTH Aug 3, 2023
630a38f
Removed some comments from text_generation_server.py
KlaudiaTH Aug 3, 2023
0d7f8fd
Minor correction
KlaudiaTH Aug 3, 2023
7a536b8
integrated first methods of hf tokenizer
Aug 4, 2023
db0c036
added tokenizer
Aug 6, 2023
685be55
bugfix
Aug 7, 2023
12948f7
retrieve eod id from tokenizer
Aug 7, 2023
44d1bbb
bugfix 2
Aug 7, 2023
e6e6b75
bugfix 3
Aug 7, 2023
31ebade
bugfix 4
Aug 7, 2023
5773be8
bugfix 4
Aug 7, 2023
1cfe037
_HFTokenizer typo
Aug 7, 2023
2dd938d
added functions
Aug 8, 2023
abe9d7a
integrated pretrained hf tokenizer
Aug 11, 2023
c14fefe
Add metadata query
janEbert Aug 11, 2023
fa6c3fb
bugfix PretrainedHFTokenizer
Aug 11, 2023
0530610
bugfix
Aug 11, 2023
0c8461d
Merge remote-tracking branch 'origin/add-gptx-tokenizers' into megatr…
KlaudiaTH Aug 12, 2023
f42ded1
MegatronLM server API adaption. Example sh files.
KlaudiaTH Aug 12, 2023
341f53a
Adaptations for greedy until generation; minor fixes
KlaudiaTH Aug 16, 2023
5bba337
API and SP tokenizer adaptions for handling continuations
KlaudiaTH Aug 18, 2023
10d7fe8
Server: Don't return padding tokens
KlaudiaTH Aug 24, 2023
ac09fe4
Corrected is_max_logprobs slicing
KlaudiaTH Aug 29, 2023
526ec2a
Added option for padding to seq_len during tokenization and generation
KlaudiaTH Aug 29, 2023
22aa758
Minor fix
KlaudiaTH Aug 29, 2023
d704b30
Corrected monolingual bpe sp 32k example
KlaudiaTH Sep 5, 2023
ab63e91
Server: Add argument for specifying HTTP port
KlaudiaTH Oct 16, 2023
c0cb866
Merge branch 'main' into megatron_lmeval_server
KlaudiaTH Nov 3, 2023
5b214f9
training.py: Import vision modules only when needed
KlaudiaTH Nov 3, 2023
2 changes: 2 additions & 0 deletions .gitignore
@@ -6,3 +6,5 @@ build
*~
slurm*
logs
.vscode/
apex/
48 changes: 48 additions & 0 deletions examples/run_text_generation_server_2_6B.sh
@@ -0,0 +1,48 @@
#!/bin/bash

set -x -e

export CUDA_DEVICE_MAX_CONNECTIONS=1
export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json
export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_IB_TIMEOUT=50
export UCX_RC_TIMEOUT=4s
export NCCL_IB_RETRY_CNT=10
export NCCL_SOCKET_IFNAME=ib0
export NCCL_DEBUG=INFO
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=60234

export CMD=" \
tools/run_text_generation_server.py \
--load /p/scratch/opengptx-elm/ali5/opengpt/megatron-lm/2023-07-27_17-52-51/output_dir/2_6B_monolingual_eng-bpe_sp_32768_10_nope.sbatch/checkpoints \
--tokenizer-model /p/scratch/opengptx-elm/data/datasources_opgptx/data_quality_experiments_datasets/ablations_studies/monolingual_en/70B_10/tokenizer_training/bpe/sp/32768_10/bpe_tokenizer.model \
--tokenizer-type OpenGPTX-SPTokenizer \
--pipeline-model-parallel-size 1 \
--tensor-model-parallel-size 2 \
--num-layers 32 \
--hidden-size 2560 \
--num-attention-heads 32 \
--max-position-embeddings 2048 \
--bf16 \
--micro-batch-size 1 \
--seq-length 2048 \
--out-seq-length 2048 \
--temperature 0.8 \
--top_p 0.5 \
--seed 42 \
--position-embedding-type none \
--no-position-embedding \
--use-flash-attn \
--reset-attention-mask \
--reset-position-ids"


export LAUNCHER="python -u -m torch.distributed.run \
--nproc_per_node 2 \
--nnodes 1 \
--node_rank 0 \
--master_addr localhost \
--master_port 6000"

bash -c "$LAUNCHER $CMD"
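
Once the server launched by the script above is running, it can be queried over HTTP. Below is a minimal client sketch, assuming the upstream Megatron-LM text generation server conventions (a flask-restful PUT endpoint at /api that accepts "prompts" and "tokens_to_generate"); the exact request fields this PR adds for lm-eval-harness scoring, and the HTTP port argument added in a later commit, may differ.

# query_server.py -- hypothetical client sketch, not part of this PR.
import requests

SERVER_URL = "http://localhost:5000/api"  # Flask default port; adjust if the server listens elsewhere

payload = {
    "prompts": ["The quick brown fox"],   # batch of prompts to complete or score
    "tokens_to_generate": 16,             # 0 would request logprob-only (loglikelihood) scoring
}

# The upstream server expects a PUT request with a JSON body.
response = requests.put(SERVER_URL, json=payload, timeout=60)
response.raise_for_status()
print(response.json())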
110 changes: 93 additions & 17 deletions examples/run_text_generation_server_345M.sh
@@ -6,29 +6,105 @@ DISTRIBUTED_ARGS="--nproc_per_node 1 \
--master_addr localhost \
--master_port 6000"

CHECKPOINT=<Path to checkpoint (e.g /345m)>
VOCAB_FILE=<Path to vocab.json (e.g. /gpt2-vocab.json)>
MERGE_FILE=<Path to merges.txt (e.g. /gpt2-merges.txt)>

export CUDA_DEVICE_MAX_CONNECTIONS=1
export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json
export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_IB_TIMEOUT=50
export UCX_RC_TIMEOUT=4s
export NCCL_IB_RETRY_CNT=10
export NCCL_SOCKET_IFNAME=ib0
export NCCL_DEBUG=INFO
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=60234

pip install flask-restful
# pip install flask-restful

torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
--tensor-model-parallel-size 1 \
python tools/run_text_generation_server.py \
--load /p/project/opengptx-elm/thellmann1/workdir/checkpoint_conversion_meglm_test/meglm/2023-07-27_18-00-00/output_dir/340M_meglm_8105626.sbatch/checkpoints \
--tokenizer-model /p/project/opengptx-elm/thellmann1/workdir/checkpoint_conversion_meglm_test/meglm/2023-07-27_18-00-00/output_dir/340M_meglm_8105626.sbatch/converted_checkpoints/iter_0015000/tokenizer.model \
--tokenizer-type OpenGPTX-SPTokenizer \
--pipeline-model-parallel-size 1 \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--load ${CHECKPOINT} \
--num-attention-heads 16 \
--max-position-embeddings 1024 \
--tokenizer-type GPT2BPETokenizer \
--fp16 \
--max-position-embeddings 2048 \
--bf16 \
--micro-batch-size 1 \
--seq-length 1024 \
--out-seq-length 1024 \
--temperature 1.0 \
--vocab-file $VOCAB_FILE \
--merge-file $MERGE_FILE \
--top_p 0.9 \
--seed 42
--seq-length 2048 \
--out-seq-length 2048 \
--temperature 0.8 \
--top_p 0.5 \
--seed 42 \
--no-position-embedding \
--position-embedding-type rotary \
--use-flash-attn \
--reset-attention-mask \
--reset-position-ids



# --load /p/project/opengptx-elm/thellmann1/workdir/checkpoint_conversion_meglm_test/meglm/2023-07-17_16-38-00/output_dir/340M_monolingual_en_sp_bpe_32768_10.sbatch/checkpoints \
# --tokenizer-model /p/project/opengptx-elm/thellmann1/workdir/checkpoint_conversion_meglm_test/meglm/2023-07-17_16-38-00/output_dir/340M_monolingual_en_sp_bpe_32768_10.sbatch/converted_checkpoints/iter_0001525/tokenizer.model \
# --tokenizer-type SentencePieceTokenizer \
# --pipeline-model-parallel-size 1 \
# --tensor-model-parallel-size 1 \
# --num-layers 24 \
# --hidden-size 1024 \
# --num-attention-heads 16 \
# --max-position-embeddings 2048 \
# --position-embedding-type alibi \
# --no-position-embedding \
# --bf16 \
# --micro-batch-size 1 \
# --seq-length 2048 \
# --out-seq-length 2048 \
# --temperature 0.8 \
# --top_p 0.5 \
# --seed 42

# --load /p/project/opengptx-elm/thellmann1/workdir/checkpoint_conversion_meglm_test/meglm/2023-07-13_13-00-00/output_dir/2_6B_multilingual-unigram_sp_32768_10.sbatch/checkpoints \
# --vocab-file='/p/project/opengptx-elm/thellmann1/workdir/checkpoint_conversion_meglm_test/meglm/2023-07-13_13-00-00/output_dir/2_6B_multilingual-unigram_sp_32768_10.sbatch/converted_checkpoints/iter_0009537/vocab.json' \
# --merge-file='/p/project/opengptx-elm/thellmann1/workdir/checkpoint_conversion_meglm_test/meglm/2023-07-13_13-00-00/output_dir/2_6B_multilingual-unigram_sp_32768_10.sbatch/converted_checkpoints/iter_0009537/merges.txt' \
# --tokenizer-type GPT2BPETokenizer \
# --pipeline-model-parallel-size 2 \
# --tensor-model-parallel-size 2 \
# --num-layers 12 \
# --hidden-size 768 \
# --num-attention-heads 12 \
# --max-position-embeddings 2048 \
# --position-embedding-type alibi \
# --bf16 \
# --micro-batch-size 32 \
# --seq-length 2048 \
# --out-seq-length 2048 \
# --temperature 0.8 \
# --top_p 0.5 \
# --seed 42



# --load /p/project/opengptx-elm/thellmann1/workdir/checkpoint_conversion_meglm_test/meglm/2023-07-27_18-00-00/output_dir/340M_meglm_8105626.sbatch/checkpoints/iter_0015000 \
# --tokenizer-model /p/project/opengptx-elm/thellmann1/workdir/checkpoint_conversion_meglm_test/meglm/2023-07-27_18-00-00/output_dir/340M_meglm_8105626.sbatch/converted_checkpoints/iter_0015000/tokenizer.mode
# --tokenizer-type OpenGPTX-SPTokenizer \
# --pipeline-model-parallel-size 1 \
# --tensor-model-parallel-size 1 \
# --num-layers 24 \
# --hidden-size 1024 \
# --num-attention-heads 16 \
# --max-position-embeddings 2048 \
# --bf16 \
# --micro-batch-size 1 \
# --seq-length 2048 \
# --out-seq-length 2048 \
# --temperature 0.8 \
# --top_p 0.5 \
# --seed 42 \
# --distributed-backend nccl
# --position-embedding-type rotary \
# --use-flash-attn \
# --reset-attention-mask \
# --reset-position-ids \
# --no-position-embedding \

49 changes: 49 additions & 0 deletions examples/run_text_generation_server_iter_0001525.sh
@@ -0,0 +1,49 @@
#!/bin/bash

set -x -e

export CUDA_DEVICE_MAX_CONNECTIONS=1
export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json
export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_IB_TIMEOUT=50
export UCX_RC_TIMEOUT=4s
export NCCL_IB_RETRY_CNT=10
export NCCL_SOCKET_IFNAME=ib0
export NCCL_DEBUG=INFO
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=60234

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]:-${(%):-%x}}" )" &> /dev/null && pwd )


export CMD=" \
$SCRIPT_DIR/../tools/run_text_generation_server.py \
--load /p/project/opengptx-elm/thellmann1/workdir/checkpoint_conversion_meglm_test/meglm/2023-07-17_16-38-00/output_dir/340M_monolingual_en_sp_bpe_32768_10.sbatch/checkpoints \
--tokenizer-model /p/scratch/opengptx-elm/data/datasources_opgptx/data_quality_experiments_datasets/ablations_studies/monolingual_en/70B_10/tokenizer_training/bpe/sp/32768_10/bpe_tokenizer.model \
--tokenizer-type SentencePieceTokenizer \
--pipeline-model-parallel-size 1 \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--max-position-embeddings 2048 \
--bf16 \
--micro-batch-size 5 \
--seq-length 2048 \
--out-seq-length 2048 \
--temperature 0.8 \
--top_p 0.5 \
--seed 42 \
--position-embedding-type alibi \
--no-position-embedding \
"


export LAUNCHER="python -u -m torch.distributed.run \
--nproc_per_node 1 \
--nnodes 1 \
--node_rank 0 \
--master_addr localhost \
--master_port 6000"

bash -c "$LAUNCHER $CMD"
37 changes: 28 additions & 9 deletions megatron/text_generation/api.py
@@ -15,6 +15,7 @@
tokenize_prompts,
detokenize_generations)


def generate_and_post_process(model,
prompts=None,
tokens_to_generate=0,
@@ -29,12 +30,14 @@ def generate_and_post_process(model,
stop_on_double_eol=False,
stop_on_eol=False,
prevent_newline_after_colon=False,
random_seed=-1):
random_seed=-1,
return_is_max_logprobs=False,
):
"""Run inference and post-process outputs, i.e., detokenize,
move to cpu and convert to list."""

# Main inference.
tokens, lengths, output_log_probs = generate(
tokens, lengths, output_log_probs, is_max_logprobs = generate(
model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
@@ -49,7 +52,9 @@ def generate_and_post_process(model,
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol,
prevent_newline_after_colon=prevent_newline_after_colon,
random_seed=random_seed)
random_seed=random_seed,
return_is_max_logprobs=True,
)

# Only post-process on first stage.
if mpu.is_pipeline_first_stage():
@@ -61,11 +66,15 @@ def generate_and_post_process(model,
for i, (prob, seg) in enumerate(zip(output_log_probs, prompts_plus_generations_segments)):
output_log_probs[i] = prob[:len(seg)-1]

return prompts_plus_generations, prompts_plus_generations_segments, \
output_log_probs, tokens
# Append is_max_logprobs to the result when return_is_max_logprobs is True
result = prompts_plus_generations, prompts_plus_generations_segments, output_log_probs, tokens
if return_is_max_logprobs:
result = result + (is_max_logprobs,)
return result

return None


def generate(model,
prompts=None,
tokens_to_generate=0,
@@ -80,7 +89,9 @@ def generate(model,
stop_on_double_eol=False,
stop_on_eol=False,
prevent_newline_after_colon=False,
random_seed=-1):
random_seed=-1,
return_is_max_logprobs=False,
):
"""Given prompts and input parameters, run inference and return:
tokens: prompts plus the generated tokens.
lengths: length of the prompt + generations. Note that we can
@@ -124,12 +135,17 @@ def generate(model,
context_tokens_tensor, context_length_tensor = tokenize_prompts(
prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)

# If nothing is generated, i.e. for loglikelihood requests (only the logprobs are of interest)
if tokens_to_generate == 0:
return score_and_return_on_first_stage(
model, context_tokens_tensor, context_length_tensor)
model, context_tokens_tensor, context_length_tensor, return_is_max_logprobs=return_is_max_logprobs)

if return_is_max_logprobs:
raise NotImplementedError("return_is_max_logprobs only implemented for tokens_to_generate == 0")

# Main inference function.
# Note that the outputs are available on the first stage.
# In addition to the logprobs, this also returns what was generated
return generate_tokens_probs_and_return_on_first_stage(
model, context_tokens_tensor, context_length_tensor,
return_output_log_probs=return_output_log_probs,
@@ -141,7 +157,9 @@ def generate(
use_eod_token_for_early_termination=use_eod_token_for_early_termination,
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol,
prevent_newline_after_colon=prevent_newline_after_colon)
prevent_newline_after_colon=prevent_newline_after_colon,
)


def beam_search_and_post_process(model,
prompts=None,
@@ -170,10 +188,11 @@ def beam_search_and_post_process(model,
lengths = tokens.size(1)*torch.ones(beam_size, dtype=torch.int64, device=torch.cuda.current_device())
tokens, prompts_plus_generations, prompts_plus_generations_segments = detokenize_generations(tokens, lengths, True)
scores = scores.cpu().numpy().tolist()
return prompts_plus_generations, prompts_plus_generations_segments, scores
return prompts_plus_generations, prompts_plus_generations_segments, scores, tokens

return None


def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False, stop_token=50256, num_return_gen=1, length_penalty=1, prevent_newline_after_colon=False):
# Make sure input params are available to all ranks.
values = [tokens_to_generate,
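
For reference, a hypothetical usage sketch (not part of the diff) of the extended generate_and_post_process signature shown above: with tokens_to_generate=0 the call only scores the prompts, and with return_is_max_logprobs=True the returned tuple gains a fifth element, is_max_logprobs. The model object is a placeholder for an already-initialized Megatron model.

from megatron.text_generation.api import generate_and_post_process

# model is assumed to be an initialized Megatron model with parallel state set up.
out = generate_and_post_process(
    model,
    prompts=["The capital of France is"],
    tokens_to_generate=0,           # scoring only; return_is_max_logprobs requires this
    return_output_log_probs=True,
    return_is_max_logprobs=True,
)

# Only the first pipeline stage returns results; other stages return None.
if out is not None:
    texts, segments, log_probs, tokens, is_max_logprobs = out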