Skip to content

Commit

Permalink
auto-format more code snippets
Browse files Browse the repository at this point in the history
  • Loading branch information
Max Shkutnyk committed Nov 21, 2024
1 parent d4d4b94 commit 138775e
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 44 deletions.
4 changes: 2 additions & 2 deletions fern/pages/cookbooks/convfinqa-finetuning-wandb.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ from cohere.finetuning import (
)

# fill in your Cohere API key here
os.environ['COHERE_API_KEY'] = "<COHERE_API_KEY>"
os.environ["COHERE_API_KEY"] = "<COHERE_API_KEY>"

# instantiate the Cohere client
co = cohere.Client(os.environ['COHERE_API_KEY'])
co = cohere.Client(os.environ["COHERE_API_KEY"])
```

## Dataset
Expand Down
2 changes: 1 addition & 1 deletion fern/pages/cookbooks/finetune-on-sagemaker.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ from tqdm import tqdm
total = 0
correct = 0
for line in tqdm(
open('./sample_finetune_scienceQA_eval.jsonl').readlines()
open("./sample_finetune_scienceQA_eval.jsonl").readlines()
):
total += 1
question_answer_json = json.loads(line)
Expand Down
82 changes: 41 additions & 41 deletions fern/pages/cookbooks/rag-cohere-mongodb.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,11 @@ def combine_attributes(row):
combined = f"{row['company']} {row['sector']} "

# Add reports information
for report in row['reports']:
for report in row["reports"]:
combined += f"{report['year']} {report['title']} {report['author']} {report['content']} "

# Add recent news information
for news in row['recent_news']:
for news in row["recent_news"]:
combined += f"{news['headline']} {news['summary']} "

return combined.strip()
Expand All @@ -196,15 +196,15 @@ def combine_attributes(row):

```python
# Add the new column 'combined_attributes'
dataset_df['combined_attributes'] = dataset_df.apply(
dataset_df["combined_attributes"] = dataset_df.apply(
combine_attributes, axis=1
)
```


```python
# Display the first few rows of the updated dataframe
dataset_df[['company', 'ticker', 'combined_attributes']].head()
dataset_df[["company", "ticker", "combined_attributes"]].head()
```

<div>
Expand Down Expand Up @@ -270,7 +270,7 @@ def get_embedding(
texts=[text],
model=model,
input_type=input_type, # Used for embeddings of search queries run against a vector DB to find relevant documents
embedding_types=['float'],
embedding_types=["float"],
)

return response.embeddings.float[0]
Expand All @@ -279,7 +279,7 @@ def get_embedding(
# Apply the embedding function with a progress bar
tqdm.pandas(desc="Generating embeddings")
dataset_df["embedding"] = dataset_df[
'combined_attributes'
"combined_attributes"
].progress_apply(get_embedding)

print(f"We just computed {len(dataset_df['embedding'])} embeddings.")
Expand Down Expand Up @@ -421,8 +421,8 @@ def get_mongo_client(mongo_uri):
)

# Validate the connection
ping_result = client.admin.command('ping')
if ping_result.get('ok') == 1.0:
ping_result = client.admin.command("ping")
if ping_result.get("ok") == 1.0:
# Connection successful
print("Connection to MongoDB successful")
return client
Expand Down Expand Up @@ -478,7 +478,7 @@ MongoDB's Document model and its compatibility with Python dictionaries offer se
![](../../assets/images/rag-cohere-mongodb-4.png)

```python
documents = dataset_df.to_dict('records')
documents = dataset_df.to_dict("records")
collection.insert_many(documents)

print("Data ingestion into MongoDB completed")
Expand Down Expand Up @@ -592,13 +592,13 @@ def rerank_documents(query: str, documents, top_n: int = 3):
original_doc = documents[result.index]
top_documents_after_rerank.append(
{
'company': original_doc['company'],
'combined_attributes': original_doc[
'combined_attributes'
"company": original_doc["company"],
"combined_attributes": original_doc[
"combined_attributes"
],
'reports': original_doc['reports'],
'vector_search_score': original_doc['score'],
'relevance_score': result.relevance_score,
"reports": original_doc["reports"],
"vector_search_score": original_doc["score"],
"relevance_score": result.relevance_score,
}
)

Expand Down Expand Up @@ -724,9 +724,9 @@ pd.DataFrame(reranked_documents).head()
def format_documents_for_chat(documents):
return [
{
"company": doc['company'],
"company": doc["company"],
# "reports": doc['reports'],
"combined_attributes": doc['combined_attributes'],
"combined_attributes": doc["combined_attributes"],
}
for doc in documents
]
Expand Down Expand Up @@ -825,7 +825,7 @@ class CohereChat:
# Use the connection string from history_params
self.client = pymongo.MongoClient(
self.history_params.get(
'connection_string', 'mongodb://localhost:27017/'
"connection_string", "mongodb://localhost:27017/"
)
)

Expand All @@ -838,34 +838,34 @@ class CohereChat:
# Use the history_collection from history_params, or default to "chat_history"
self.history_collection = self.db[
self.history_params.get(
'history_collection', 'chat_history'
"history_collection", "chat_history"
)
]

# Use the session_id from history_params, or default to "default_session"
self.session_id = self.history_params.get(
'session_id', 'default_session'
"session_id", "default_session"
)

def add_to_history(self, message: str, prefix: str = ""):
self.history_collection.insert_one(
{
'session_id': self.session_id,
'message': message,
'prefix': prefix,
"session_id": self.session_id,
"message": message,
"prefix": prefix,
}
)

def get_chat_history(self) -> List[Dict[str, str]]:
history = self.history_collection.find(
{'session_id': self.session_id}
).sort('_id', 1)
{"session_id": self.session_id}
).sort("_id", 1)
return [
{
"role": (
"user" if item['prefix'] == "USER" else "chatbot"
"user" if item["prefix"] == "USER" else "chatbot"
),
"message": item['message'],
"message": item["message"],
}
for item in history
]
Expand All @@ -875,11 +875,11 @@ class CohereChat:
) -> List[Dict]:
rerank_docs = [
{
'company': doc['company'],
'combined_attributes': doc['combined_attributes'],
"company": doc["company"],
"combined_attributes": doc["combined_attributes"],
}
for doc in documents
if doc['combined_attributes'].strip()
if doc["combined_attributes"].strip()
]

if not rerank_docs:
Expand All @@ -897,11 +897,11 @@ class CohereChat:

top_documents_after_rerank = [
{
'company': rerank_docs[result.index]['company'],
'combined_attributes': rerank_docs[result.index][
'combined_attributes'
"company": rerank_docs[result.index]["company"],
"combined_attributes": rerank_docs[result.index][
"combined_attributes"
],
'relevance_score': result.relevance_score,
"relevance_score": result.relevance_score,
}
for result in response.results
]
Expand All @@ -925,8 +925,8 @@ class CohereChat:
) -> List[Dict]:
return [
{
"company": doc['company'],
"combined_attributes": doc['combined_attributes'],
"company": doc["company"],
"combined_attributes": doc["combined_attributes"],
}
for doc in documents
]
Expand Down Expand Up @@ -972,8 +972,8 @@ class CohereChat:

def show_history(self):
history = self.history_collection.find(
{'session_id': self.session_id}
).sort('_id', 1)
{"session_id": self.session_id}
).sort("_id", 1)
for item in history:
print(f"{item['prefix']}: {item['message']}")
print("-------------------------")
Expand All @@ -988,9 +988,9 @@ chat = CohereChat(
database=DB_NAME,
main_collection=COLLECTION_NAME,
history_params={
'connection_string': MONGO_URI,
'history_collection': "chat_history",
'session_id': 2,
"connection_string": MONGO_URI,
"history_collection": "chat_history",
"session_id": 2,
},
)

Expand Down

0 comments on commit 138775e

Please sign in to comment.