Feature/add sciphi llm (#105)
* add sciphi llm provider

* modify rag

* update

* up

* up

* up

* up

* Add sci-phi llm

* add sciphi llm

---------

Co-authored-by: Owen Colegrove <[email protected]>
emrgnt-cmplxty and emergentagi123 authored Oct 27, 2023
1 parent a4f9e67 commit cccb59b
Showing 6 changed files with 132 additions and 16 deletions.
1 change: 1 addition & 0 deletions sciphi/core/base.py
@@ -12,6 +12,7 @@ class LLMProviderName(Enum):
     VLLM = "vllm"
     LLAMACPP = "llamacpp"
     LITE_LLM = "lite-llm"
+    SCIPHI = "sciphi"


 class RAGProviderName(Enum):
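Note: resolving the new provider from its string value is plain Enum behavior; a minimal sketch, assuming nothing beyond the value added above:

from sciphi.core import LLMProviderName

# Look up the provider enum by the string value introduced in this commit.
provider = LLMProviderName("sciphi")
assert provider is LLMProviderName.SCIPHI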
19 changes: 7 additions & 12 deletions sciphi/interface/rag/sciphi_wiki.py
@@ -7,8 +7,6 @@
 from sciphi.interface.base import RAGInterface, RAGProviderConfig
 from sciphi.interface.rag_interface_manager import rag_config, rag_provider
 
-""
-
 
 @dataclass
 @rag_config
@@ -44,16 +42,13 @@ def get_contexts(self, prompts: list[str]) -> list[str]:
             for raw_context in raw_contexts
         ]
 
-    def _format_wiki_context(self, context: str) -> str:
+    def _format_wiki_context(self, context: list) -> str:
         """Format the context for a prompt."""
-        truncated_context = context[0 : self.config.max_context]
-        wiki_context = dedent(truncated_context)
+        joined_context = [f"{ele['title']}\n{ele['text']}" for ele in context]
         return "\n".join(
-            [
-                f"{SciPhiWikiRAGInterface.FORMAT_INDENT}{line}"
-                for line in wiki_context.split("\n")
-            ]
-        )
+            f"{SciPhiWikiRAGInterface.FORMAT_INDENT}{dedent(entry)}"
+            for entry in joined_context
+        )[: self.config.max_context]


def wiki_search_api(
@@ -69,12 +64,12 @@
     # Make the GET request with basic authentication and the query parameter
     response = requests.get(
         rag_api_base,
-        params={"queries": queries, "k": top_k},
+        params={"queries": queries, "top_k": top_k},
         headers={"Authorization": f"Bearer {rag_api_key}"},
     )
 
     if response.status_code == 200:
-        return response.json()["match"]  # Return the JSON response
+        return response.json()  # Return the JSON response
     if "detail" in response.json():
         raise ValueError(
             f'Unexpected response from API - {response.json()["detail"]}'
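Note: a hedged usage sketch of the updated helper. The full signature of wiki_search_api is not shown above, so the keyword-argument call, the localhost endpoint, the API key, and the sample payload are illustrative assumptions; the list-of-dicts shape with "title" and "text" keys is what the new _format_wiki_context expects.

from sciphi.interface.rag.sciphi_wiki import wiki_search_api

raw_contexts = wiki_search_api(
    queries=["Who proposed general relativity?"],
    rag_api_base="http://localhost:8000",  # illustrative endpoint
    rag_api_key="my-rag-api-key",          # illustrative key
    top_k=5,                               # the query parameter is now "top_k" rather than "k"
)
# With this commit the raw JSON payload is returned directly (no ["match"] indexing);
# each entry is expected to look roughly like:
#   {"title": "General relativity", "text": "General relativity is a theory of gravitation ..."}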
2 changes: 1 addition & 1 deletion sciphi/llm/embedding_helpers.py
@@ -111,6 +111,6 @@ def sentencize(
"offset": abs_offsets,
}
document_sentences.append(row)
-except:
+except Exception:
continue
return pd.DataFrame(document_sentences)
121 changes: 121 additions & 0 deletions sciphi/llm/models/sciphi_llm.py
@@ -0,0 +1,121 @@
"""A module for managing local vLLM models."""

import logging
from dataclasses import dataclass
from typing import Optional

from sciphi.core import LLMProviderName, RAGProviderName
from sciphi.interface.rag_interface_manager import RAGInterfaceManager
from sciphi.llm.config_manager import model_config
from sciphi.llm.models.vllm_llm import vLLM, vLLMConfig

logging.basicConfig(level=logging.INFO)


class SciPhiFormatter:
    """Formatter for SciPhi."""

    INSTRUCTION_PREFIX = "### Instruction:\n"
    INSTRUCTION_SUFFIX = "### Response:\n"
    INIT_PARAGRAPH_TOKEN = "<paragraph>"
    END_PARAGRAPH_TOKEN = "</paragraph>"

    RETRIEVAL_TOKEN = "[Retrieval]"
    NO_RETRIEVAL_TOKEN = "[No Retrieval]"
    EVIDENCE_TOKEN = "[Continue to Use Evidence]"
    RELEVANT_TOKEN = "[Relevant]"
    PARTIALLY_SUPPORTED_TOKEN = "[Partially supported]"
    SUFFIX_CRUFT = "[Utility:5]</s>"

    @staticmethod
    def format_prompt(input: str) -> str:
        """Format the prompt for the model."""
        return f"{SciPhiFormatter.INSTRUCTION_PREFIX}\n{input}\n\n{SciPhiFormatter.INSTRUCTION_SUFFIX}"

    @staticmethod
    def remove_cruft(result: str) -> str:
        return (
            result.replace(SciPhiFormatter.RETRIEVAL_TOKEN, " ")
            .replace(SciPhiFormatter.NO_RETRIEVAL_TOKEN, "")
            .replace(SciPhiFormatter.EVIDENCE_TOKEN, " ")
            .replace(SciPhiFormatter.SUFFIX_CRUFT, "")
            .replace(SciPhiFormatter.RELEVANT_TOKEN, "")
            .replace(SciPhiFormatter.PARTIALLY_SUPPORTED_TOKEN, "")
        )

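Note: a small sketch of the two helpers above; the raw completion string is invented to show which special tokens get stripped.

prompt = SciPhiFormatter.format_prompt("Summarize the theory of relativity.")
# -> "### Instruction:\n\nSummarize the theory of relativity.\n\n### Response:\n"

raw = "[Relevant]Relativity relates space and time.[Utility:5]</s>"
clean = SciPhiFormatter.remove_cruft(raw)
# -> "Relativity relates space and time."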

@model_config
@dataclass
class SciPhiConfig(vLLMConfig):
    """Configuration for SciPhi models."""

    # Base
    provider_name: LLMProviderName = LLMProviderName.SCIPHI
    model_name: str = "selfrag/selfrag_llama2_7b"
    temperature: float = 0.1
    top_p: float = 1.0
    top_k: int = 100
    max_tokens_to_sample: int = 256
    server_base: Optional[str] = None

    # RAG Parameters
    rag_provider_name: RAGProviderName = RAGProviderName.SCIPHI_WIKI
    rag_provider_base: Optional[str] = None
    rag_provider_token: Optional[str] = None
    rag_top_k: int = 100

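Note: a hedged configuration sketch. The endpoints and token are placeholders, and any extra fields inherited from vLLMConfig are assumed to carry defaults.

config = SciPhiConfig(
    temperature=0.0,
    max_tokens_to_sample=512,
    server_base="http://localhost:8001/v1",     # placeholder OpenAI-compatible vLLM endpoint
    rag_provider_base="http://localhost:8000",  # placeholder SciPhi wiki RAG API
    rag_provider_token="my-rag-token",          # placeholder token
)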

class SciPhiLLM(vLLM):
    """A class to provide SciPhi models."""

    def __init__(
        self,
        config: SciPhiConfig,
    ) -> None:
        super().__init__(config)
        from vllm import SamplingParams

        self.config: SciPhiConfig = config
        self.sampling_params = SamplingParams(
            temperature=config.temperature,
            top_p=config.top_p,
            top_k=config.top_k,
            max_tokens=config.max_tokens_to_sample,
            skip_special_tokens=False,  # RAG Fine Tune includes special tokens
            stop=SciPhiFormatter.INIT_PARAGRAPH_TOKEN,  # Stops on Retrieval
        )

        self.rag_provider = RAGInterfaceManager.get_interface_from_args(
            provider_name=config.rag_provider_name,
            base=config.rag_provider_base or "http://localhost:8000",
            token=config.rag_provider_token or "",
            top_k=config.rag_top_k,
        )

    def get_chat_completion(self, messages: list[dict[str, str]]) -> str:
        """Get a completion from the SciPhi API based on the provided messages."""
        raise NotImplementedError(
            "Chat completion not yet implemented for SciPhi."
        )

    def get_instruct_completion(self, prompt: str) -> str:
        """Get an instruction completion from the local SciPhi API."""
        import openai

        openai.api_base = self.config.server_base or ""
        return openai.Completion.create(
            model=self.config.model_name,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            top_k=self.config.top_k,
            max_tokens=self.config.max_tokens_to_sample,
            prompt=prompt,
            skip_special_tokens=False,
            stop=SciPhiFormatter.INIT_PARAGRAPH_TOKEN,
        )

    def get_batch_instruct_completion(self, prompts: list[str]) -> list[str]:
        """Get batch instruction completion from local vLLM."""
        raise NotImplementedError(
            "Batch instruction completion not yet implemented for SciPhi."
        )
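Note: a hedged end-to-end sketch, assuming a vLLM server that exposes an OpenAI-compatible completions endpoint at server_base and a reachable RAG API. get_instruct_completion returns the raw openai Completion response, so the text is pulled from choices here under the legacy (pre-1.0) openai client's response shape.

config = SciPhiConfig(server_base="http://localhost:8001/v1")  # placeholder endpoint
llm = SciPhiLLM(config)

completion = llm.get_instruct_completion(
    SciPhiFormatter.format_prompt("What causes ocean tides?")
)
text = completion["choices"][0]["text"]  # assumes the legacy openai<1.0 response object
print(SciPhiFormatter.remove_cruft(text))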
2 changes: 1 addition & 1 deletion sciphi/llm/models/vllm_llm.py
@@ -95,7 +95,7 @@ def __init__(
         self.config: vLLMConfig = config
 
     def get_chat_completion(self, messages: list[dict[str, str]]) -> str:
-        """Get a completion from the OpenAI API based on the provided messages."""
+        """Get a completion from the vLLM API based on the provided messages."""
         raise NotImplementedError(
             "Chat completion not yet implemented for vLLM."
         )
3 changes: 1 addition & 2 deletions sciphi/scripts/make_embeddings.py
@@ -48,13 +48,12 @@ def initialize_memmap(
         shape_file.write(f"{initial_estimate},{embedding_dim}")
 
     # Create initial memmap
-    memmap_array = np.memmap(
+    np.memmap(
         file_name,
         dtype="float32",
         mode="w+",
         shape=(initial_estimate, embedding_dim),
     )
-    del memmap_array


 def reconstitute_sentences_into_chunks(
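Note: the initializer above now discards its memmap handle right after creating the file, so here is a hedged sketch of reading the array back later via the companion shape file; both file names are illustrative, not the paths the script actually derives.

import numpy as np

file_name = "embeddings.memmap"          # illustrative path
with open("embeddings_shape.txt") as f:  # illustrative companion shape file
    n_rows, embedding_dim = (int(x) for x in f.read().split(","))

embeddings = np.memmap(
    file_name,
    dtype="float32",
    mode="r",                            # read-only view of the on-disk array
    shape=(n_rows, embedding_dim),
)
print(embeddings.shape)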
