Skip to content

Commit

Permalink
iter
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre committed Dec 10, 2023
1 parent 2719553 commit be0373e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
16 changes: 10 additions & 6 deletions rag_based_llm/prompt/_agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import re
from textwrap import wrap

import tiktoken

Expand Down Expand Up @@ -40,14 +40,18 @@ def __call__(self, query, **prompt_kwargs):
).intersection(api["source"] for api in api_lexical_context)
context = "\n".join(
f"source: {api['source']} \n content: {api['text']}\n"
for api in api_semantic_context if api['source'] in api_common_sources
for api in api_semantic_context
if api["source"] in api_common_sources
)
prompt = (
"[INST] Answer to the query related to scikit-learn using the following "
"pair of content and source. Be succinct. Add a link to the source(s) used."
"pair of content and source. Be succinct. \n"
f"query: {query}\n"
f"context: {context}[/INST]."
f"context: {context} [/INST]."
)
response = self.llm(trim(prompt, max_tokens=max_tokens), **prompt_kwargs)
return response["choices"][0]["text"]

return (
"\n".join(wrap(response["choices"][0]["text"].strip(), width=80))
+ "\n\nSource(s):\n"
+ "\n".join(api_common_sources)
)
3 changes: 1 addition & 2 deletions scripts/exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@
response = agent(query, max_tokens=4096, temperature=0.1)

# %%
from textwrap import wrap
print("\n".join(wrap(response, width=80)))
print(response)

# %%

0 comments on commit be0373e

Please sign in to comment.