Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cohere api #250

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
wandb:
run_name: "cohere/command-r-plus-v1:0" # use run_name defined above
run_name: "cohere/command-r-08-2024" # use run_name defined above

# if you don't use api, please set "api" as "false"
# if you use api, please select from "openai", "anthropic", "google", "cohere", "vllm", "mistral", "bedrock"
api: "amazon_bedrock"
api: "cohere"
batch_size: 1 # vllmは256, apiは32を推奨
# inference_interval

generator:
preamble: "" # Put a custom preamble here if needed

model:
use_wandb_artifacts: false
pretrained_model_name_or_path: "cohere.command-r-plus-v1:0" #if you use openai api, put the name of model
pretrained_model_name_or_path: "command-r-08-2024" # put the name of Cohere model here
size_category: "api"
size: null
release_date: "4/4/2024"
release_date: "8/30/2024"
18 changes: 18 additions & 0 deletions configs/config-cohere-command-r-plus-08-2024.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
wandb:
run_name: "cohere/command-r-plus-08-2024" # use run_name defined above

# if you don't use api, please set "api" as "false"
# if you use api, please select from "openai", "anthropic", "google", "cohere", "vllm", "mistral", "bedrock"
api: "cohere"
batch_size: 1 # vllmは256, apiは32を推奨
# inference_interval

generator:
preamble: "" # Put a custom preamble here if needed

model:
use_wandb_artifacts: false
pretrained_model_name_or_path: "command-r-plus-08-2024" # put the name of Cohere model here
size_category: "api"
size: null
release_date: "8/30/2024"
15 changes: 7 additions & 8 deletions scripts/llm_inference_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@
)
# from langchain_aws import ChatBedrock
from langchain_anthropic import ChatAnthropic
from langchain_cohere import ChatCohere
from botocore.exceptions import ClientError
import boto3
from botocore.config import Config

# from langchain_cohere import Cohere


import json
import boto3
Expand Down Expand Up @@ -188,12 +187,12 @@ def get_llm_inference_engine():
**cfg.generator,
)

# elif api_type == "cohere":
# llm = Cohere(
# model=cfg.model.pretrained_model_name_or_path,
# cohere_api_key=os.environ["COHERE_API_KEY"],
# **cfg.generator,
# )
elif api_type == "cohere":
llm = ChatCohere(
model=cfg.model.pretrained_model_name_or_path,
cohere_api_key=os.environ["COHERE_API_KEY"],
**cfg.generator,
)

else:
raise ValueError(f"Unsupported API type: {api_type}")
Expand Down