From 138775e3aabee1bdfabaae7f8ecae0fd0e6e8668 Mon Sep 17 00:00:00 2001 From: Max Shkutnyk Date: Thu, 21 Nov 2024 08:57:13 +0200 Subject: [PATCH] auto-format more code snippets --- .../cookbooks/convfinqa-finetuning-wandb.mdx | 4 +- .../pages/cookbooks/finetune-on-sagemaker.mdx | 2 +- fern/pages/cookbooks/rag-cohere-mongodb.mdx | 82 +++++++++---------- 3 files changed, 44 insertions(+), 44 deletions(-) diff --git a/fern/pages/cookbooks/convfinqa-finetuning-wandb.mdx b/fern/pages/cookbooks/convfinqa-finetuning-wandb.mdx index e22bf438..d4e29412 100644 --- a/fern/pages/cookbooks/convfinqa-finetuning-wandb.mdx +++ b/fern/pages/cookbooks/convfinqa-finetuning-wandb.mdx @@ -55,10 +55,10 @@ from cohere.finetuning import ( ) # fill in your Cohere API key here -os.environ['COHERE_API_KEY'] = "" +os.environ["COHERE_API_KEY"] = "" # instantiate the Cohere client -co = cohere.Client(os.environ['COHERE_API_KEY']) +co = cohere.Client(os.environ["COHERE_API_KEY"]) ``` ## Dataset diff --git a/fern/pages/cookbooks/finetune-on-sagemaker.mdx b/fern/pages/cookbooks/finetune-on-sagemaker.mdx index 1a53a152..d9dda465 100644 --- a/fern/pages/cookbooks/finetune-on-sagemaker.mdx +++ b/fern/pages/cookbooks/finetune-on-sagemaker.mdx @@ -298,7 +298,7 @@ from tqdm import tqdm total = 0 correct = 0 for line in tqdm( - open('./sample_finetune_scienceQA_eval.jsonl').readlines() + open("./sample_finetune_scienceQA_eval.jsonl").readlines() ): total += 1 question_answer_json = json.loads(line) diff --git a/fern/pages/cookbooks/rag-cohere-mongodb.mdx b/fern/pages/cookbooks/rag-cohere-mongodb.mdx index 0bfa14a6..c4ff401a 100644 --- a/fern/pages/cookbooks/rag-cohere-mongodb.mdx +++ b/fern/pages/cookbooks/rag-cohere-mongodb.mdx @@ -183,11 +183,11 @@ def combine_attributes(row): combined = f"{row['company']} {row['sector']} " # Add reports information - for report in row['reports']: + for report in row["reports"]: combined += f"{report['year']} {report['title']} {report['author']} {report['content']} " # Add recent news information - for news in row['recent_news']: + for news in row["recent_news"]: combined += f"{news['headline']} {news['summary']} " return combined.strip() @@ -196,7 +196,7 @@ def combine_attributes(row): ```python # Add the new column 'combined_attributes' -dataset_df['combined_attributes'] = dataset_df.apply( +dataset_df["combined_attributes"] = dataset_df.apply( combine_attributes, axis=1 ) ``` @@ -204,7 +204,7 @@ dataset_df['combined_attributes'] = dataset_df.apply( ```python # Display the first few rows of the updated dataframe -dataset_df[['company', 'ticker', 'combined_attributes']].head() +dataset_df[["company", "ticker", "combined_attributes"]].head() ```
@@ -270,7 +270,7 @@ def get_embedding( texts=[text], model=model, input_type=input_type, # Used for embeddings of search queries run against a vector DB to find relevant documents - embedding_types=['float'], + embedding_types=["float"], ) return response.embeddings.float[0] @@ -279,7 +279,7 @@ def get_embedding( # Apply the embedding function with a progress bar tqdm.pandas(desc="Generating embeddings") dataset_df["embedding"] = dataset_df[ - 'combined_attributes' + "combined_attributes" ].progress_apply(get_embedding) print(f"We just computed {len(dataset_df['embedding'])} embeddings.") @@ -421,8 +421,8 @@ def get_mongo_client(mongo_uri): ) # Validate the connection - ping_result = client.admin.command('ping') - if ping_result.get('ok') == 1.0: + ping_result = client.admin.command("ping") + if ping_result.get("ok") == 1.0: # Connection successful print("Connection to MongoDB successful") return client @@ -478,7 +478,7 @@ MongoDB's Document model and its compatibility with Python dictionaries offer se ![](../../assets/images/rag-cohere-mongodb-4.png) ```python -documents = dataset_df.to_dict('records') +documents = dataset_df.to_dict("records") collection.insert_many(documents) print("Data ingestion into MongoDB completed") @@ -592,13 +592,13 @@ def rerank_documents(query: str, documents, top_n: int = 3): original_doc = documents[result.index] top_documents_after_rerank.append( { - 'company': original_doc['company'], - 'combined_attributes': original_doc[ - 'combined_attributes' + "company": original_doc["company"], + "combined_attributes": original_doc[ + "combined_attributes" ], - 'reports': original_doc['reports'], - 'vector_search_score': original_doc['score'], - 'relevance_score': result.relevance_score, + "reports": original_doc["reports"], + "vector_search_score": original_doc["score"], + "relevance_score": result.relevance_score, } ) @@ -724,9 +724,9 @@ pd.DataFrame(reranked_documents).head() def format_documents_for_chat(documents): return [ { - "company": doc['company'], + "company": doc["company"], # "reports": doc['reports'], - "combined_attributes": doc['combined_attributes'], + "combined_attributes": doc["combined_attributes"], } for doc in documents ] @@ -825,7 +825,7 @@ class CohereChat: # Use the connection string from history_params self.client = pymongo.MongoClient( self.history_params.get( - 'connection_string', 'mongodb://localhost:27017/' + "connection_string", "mongodb://localhost:27017/" ) ) @@ -838,34 +838,34 @@ class CohereChat: # Use the history_collection from history_params, or default to "chat_history" self.history_collection = self.db[ self.history_params.get( - 'history_collection', 'chat_history' + "history_collection", "chat_history" ) ] # Use the session_id from history_params, or default to "default_session" self.session_id = self.history_params.get( - 'session_id', 'default_session' + "session_id", "default_session" ) def add_to_history(self, message: str, prefix: str = ""): self.history_collection.insert_one( { - 'session_id': self.session_id, - 'message': message, - 'prefix': prefix, + "session_id": self.session_id, + "message": message, + "prefix": prefix, } ) def get_chat_history(self) -> List[Dict[str, str]]: history = self.history_collection.find( - {'session_id': self.session_id} - ).sort('_id', 1) + {"session_id": self.session_id} + ).sort("_id", 1) return [ { "role": ( - "user" if item['prefix'] == "USER" else "chatbot" + "user" if item["prefix"] == "USER" else "chatbot" ), - "message": item['message'], + "message": item["message"], } for item in history ] @@ -875,11 +875,11 @@ class CohereChat: ) -> List[Dict]: rerank_docs = [ { - 'company': doc['company'], - 'combined_attributes': doc['combined_attributes'], + "company": doc["company"], + "combined_attributes": doc["combined_attributes"], } for doc in documents - if doc['combined_attributes'].strip() + if doc["combined_attributes"].strip() ] if not rerank_docs: @@ -897,11 +897,11 @@ class CohereChat: top_documents_after_rerank = [ { - 'company': rerank_docs[result.index]['company'], - 'combined_attributes': rerank_docs[result.index][ - 'combined_attributes' + "company": rerank_docs[result.index]["company"], + "combined_attributes": rerank_docs[result.index][ + "combined_attributes" ], - 'relevance_score': result.relevance_score, + "relevance_score": result.relevance_score, } for result in response.results ] @@ -925,8 +925,8 @@ class CohereChat: ) -> List[Dict]: return [ { - "company": doc['company'], - "combined_attributes": doc['combined_attributes'], + "company": doc["company"], + "combined_attributes": doc["combined_attributes"], } for doc in documents ] @@ -972,8 +972,8 @@ class CohereChat: def show_history(self): history = self.history_collection.find( - {'session_id': self.session_id} - ).sort('_id', 1) + {"session_id": self.session_id} + ).sort("_id", 1) for item in history: print(f"{item['prefix']}: {item['message']}") print("-------------------------") @@ -988,9 +988,9 @@ chat = CohereChat( database=DB_NAME, main_collection=COLLECTION_NAME, history_params={ - 'connection_string': MONGO_URI, - 'history_collection': "chat_history", - 'session_id': 2, + "connection_string": MONGO_URI, + "history_collection": "chat_history", + "session_id": 2, }, )