[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Dec 16, 2024
1 parent 0a09ec0 commit d00e207
Showing 7 changed files with 34 additions and 24 deletions.
2 changes: 1 addition & 1 deletion comps/__init__.py
@@ -38,7 +38,7 @@
PIIResponseDoc,
Audio2text,
DocSumDoc,
DocSumLLMParams
DocSumLLMParams,
)

# Constants
2 changes: 2 additions & 0 deletions comps/llms/summarization/tgi/langchain/README.md
@@ -51,6 +51,7 @@ export LLM_MODEL_ID=${your_hf_llm_model}
export MAX_INPUT_TOKENS=2048
export MAX_TOTAL_TOKENS=4096
```

Please make sure MAX_TOTAL_TOKENS is larger than (MAX_INPUT_TOKENS + max_new_tokens + 50); the 50 tokens are the reserved prompt length.
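
For example, with the settings above and an assumed `max_new_tokens` of 1024, the budget works out to 2048 + 1024 + 50 = 3122, which fits within `MAX_TOTAL_TOKENS=4096`. A minimal Python sketch of this sanity check (the 1024 value is only an assumption for illustration):

```
# Token budget check for the summarization service (illustrative values).
MAX_INPUT_TOKENS = 2048
MAX_TOTAL_TOKENS = 4096
max_new_tokens = 1024        # assumed request-side generation limit
reserved_prompt_tokens = 50  # reserved prompt length mentioned above

# MAX_TOTAL_TOKENS must cover the input, the generated tokens, and the reserved prompt.
assert MAX_TOTAL_TOKENS > MAX_INPUT_TOKENS + max_new_tokens + reserved_prompt_tokens, (
    "Increase MAX_TOTAL_TOKENS or reduce MAX_INPUT_TOKENS / max_new_tokens"
)
```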

### 2.2 Build Docker Image
@@ -124,6 +125,7 @@ curl http://${your_ip}:9000/v1/chat/docsum \
```

#### 3.2.2 Long context summarization with "summary_type"

"summary_type" is set to be "stuff" by default, which will let LLM generate summary based on complete input text. In this case please carefully set `MAX_INPUT_TOKENS` and `MAX_TOTAL_TOKENS` according to your model and device memory, otherwise it may exceed LLM context limit and raise error when meet long context.

When dealing with long context, you can set "summary_type" to one of "truncate", "map_reduce", or "refine" for better performance.
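
As an illustration, the request below sends the same payload as the curl examples in this README from Python (a minimal sketch using `requests`; replace the placeholder host with your service address, and note that `map_reduce` is just one of the possible choices):

```
import requests

# Placeholder endpoint; use the IP and port where the docsum microservice is running.
url = "http://your_ip:9000/v1/chat/docsum"
payload = {
    "query": "Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models.",
    "max_tokens": 32,
    "language": "en",
    "summary_type": "map_reduce",  # or "truncate" / "refine" for long inputs; default is "stuff"
    "chunk_size": 2000,
}
response = requests.post(url, json=payload, headers={"Content-Type": "application/json"})
print(response.text)
```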
22 changes: 13 additions & 9 deletions comps/llms/summarization/tgi/langchain/llm.py
@@ -10,6 +10,7 @@
from langchain_community.llms import HuggingFaceEndpoint
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter
from transformers import AutoTokenizer

from comps import CustomLogger, DocSumLLMParams, GeneratedDoc, ServiceType, opea_microservices, register_microservice
from comps.cores.mega.utils import get_access_token

@@ -100,18 +101,21 @@ async def llm_generate(input: DocSumLLMParams):
elif input.summary_type in ["truncate", "map_reduce", "refine"]:
if input.summary_type == "refine":
if MAX_TOTAL_TOKENS <= 2 * input.max_tokens + 128:
raise RuntimeError('In Refine mode, Please set MAX_TOTAL_TOKENS larger than (max_tokens * 2 + 128)')
max_input_tokens = min(MAX_TOTAL_TOKENS - 2 * input.max_tokens - 128, MAX_INPUT_TOKENS) # 128 is reserved token length for prompt
raise RuntimeError("In Refine mode, Please set MAX_TOTAL_TOKENS larger than (max_tokens * 2 + 128)")
max_input_tokens = min(
MAX_TOTAL_TOKENS - 2 * input.max_tokens - 128, MAX_INPUT_TOKENS
) # 128 is reserved token length for prompt
else:
if MAX_TOTAL_TOKENS <= input.max_tokens + 50:
raise RuntimeError('Please set MAX_TOTAL_TOKENS larger than max_tokens + 50)')
max_input_tokens = min(MAX_TOTAL_TOKENS - input.max_tokens - 50, MAX_INPUT_TOKENS) # 50 is reserved token length for prompt
chunk_size = min(input.chunk_size, max_input_tokens) if input.chunk_size > 0 else max_input_tokens
raise RuntimeError("Please set MAX_TOTAL_TOKENS larger than max_tokens + 50)")
max_input_tokens = min(
MAX_TOTAL_TOKENS - input.max_tokens - 50, MAX_INPUT_TOKENS
) # 50 is reserved token length for prompt
chunk_size = min(input.chunk_size, max_input_tokens) if input.chunk_size > 0 else max_input_tokens
chunk_overlap = input.chunk_overlap if input.chunk_overlap > 0 else int(0.1 * chunk_size)
text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
tokenizer=tokenizer,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap)
tokenizer=tokenizer, chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
if logflag:
logger.info(f"set chunk size to: {chunk_size}")
logger.info(f"set chunk overlap to: {chunk_overlap}")
@@ -197,7 +201,7 @@ async def stream_generator():
if logflag:
logger.info("\n\noutput_text:")
logger.info(output_text)

return GeneratedDoc(text=output_text, prompt=input.query)


2 changes: 1 addition & 1 deletion comps/llms/summarization/tgi/langchain/requirements.txt
@@ -1,5 +1,6 @@
docarray[full]
fastapi
httpx==0.27.2
huggingface_hub
langchain #==0.1.12
langchain-huggingface
@@ -13,4 +14,3 @@ prometheus-fastapi-instrumentator
shortuuid
transformers
uvicorn
httpx==0.27.2
2 changes: 1 addition & 1 deletion comps/llms/summarization/vllm/langchain/README.md
@@ -124,6 +124,7 @@ curl http://${your_ip}:9000/v1/chat/docsum \
```

#### 3.2.2 Long context summarization with "summary_type"

"summary_type" is set to be "stuff" by default, which will let LLM generate summary based on complete input text. In this case please carefully set `MAX_INPUT_TOKENS` and `MAX_TOTAL_TOKENS` according to your model and device memory, otherwise it may exceed LLM context limit and raise error when meet long context.

When dealing with long context, you can set "summary_type" to one of "truncate", "map_reduce", or "refine" for better performance.
@@ -164,4 +165,3 @@ curl http://${your_ip}:9000/v1/chat/docsum \
-d '{"query":"Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5.", "max_tokens":32, "language":"en", "summary_type": "refine", "chunk_size": 2000}' \
-H 'Content-Type: application/json'
```

26 changes: 15 additions & 11 deletions comps/llms/summarization/vllm/langchain/llm.py
@@ -11,6 +11,7 @@
from langchain_community.llms import VLLMOpenAI
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter
from transformers import AutoTokenizer

from comps import CustomLogger, DocSumLLMParams, GeneratedDoc, ServiceType, opea_microservices, register_microservice
from comps.cores.mega.utils import get_access_token

@@ -75,7 +76,7 @@
async def llm_generate(input: DocSumLLMParams):
if logflag:
logger.info(input)

if input.language in ["en", "auto"]:
templ = templ_en
templ_refine = templ_refine_en
@@ -101,18 +102,21 @@ async def llm_generate(input: DocSumLLMParams):
elif input.summary_type in ["truncate", "map_reduce", "refine"]:
if input.summary_type == "refine":
if MAX_TOTAL_TOKENS <= 2 * input.max_tokens + 128:
raise RuntimeError('In Refine mode, Please set MAX_TOTAL_TOKENS larger than (max_tokens * 2 + 128)')
max_input_tokens = min(MAX_TOTAL_TOKENS - 2 * input.max_tokens - 128, MAX_INPUT_TOKENS) # 128 is reserved token length for prompt
raise RuntimeError("In Refine mode, Please set MAX_TOTAL_TOKENS larger than (max_tokens * 2 + 128)")
max_input_tokens = min(
MAX_TOTAL_TOKENS - 2 * input.max_tokens - 128, MAX_INPUT_TOKENS
) # 128 is reserved token length for prompt
else:
if MAX_TOTAL_TOKENS <= input.max_tokens + 50:
raise RuntimeError('Please set MAX_TOTAL_TOKENS larger than max_tokens + 50)')
max_input_tokens = min(MAX_TOTAL_TOKENS - input.max_tokens - 50, MAX_INPUT_TOKENS) # 50 is reserved token length for prompt
chunk_size = min(input.chunk_size, max_input_tokens) if input.chunk_size > 0 else max_input_tokens
raise RuntimeError("Please set MAX_TOTAL_TOKENS larger than max_tokens + 50)")
max_input_tokens = min(
MAX_TOTAL_TOKENS - input.max_tokens - 50, MAX_INPUT_TOKENS
) # 50 is reserved token length for prompt
chunk_size = min(input.chunk_size, max_input_tokens) if input.chunk_size > 0 else max_input_tokens
chunk_overlap = input.chunk_overlap if input.chunk_overlap > 0 else int(0.1 * chunk_size)
text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
tokenizer=tokenizer,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap)
tokenizer=tokenizer, chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
if logflag:
logger.info(f"set chunk size to: {chunk_size}")
logger.info(f"set chunk overlap to: {chunk_overlap}")
@@ -171,7 +175,7 @@ async def llm_generate(input: DocSumLLMParams):
)
else:
raise NotImplementedError('Please specify the summary_type in "stuff", "truncate", "map_reduce", "refine"')

if input.streaming:

async def stream_generator():
@@ -199,7 +203,7 @@ async def stream_generator():
if logflag:
logger.info("\n\noutput_text:")
logger.info(output_text)

return GeneratedDoc(text=output_text, prompt=input.query)


2 changes: 1 addition & 1 deletion tests/llms/test_llms_summarization_tgi_langchain.sh
@@ -77,7 +77,7 @@ function validate_services() {
function validate_microservices() {
sum_port=5076
URL="http://${ip_address}:$sum_port/v1/chat/docsum"

validate_services \
"$URL" \
'text' \