-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Working streamlit.chat with Google Vertex AI
- Loading branch information
Showing
16 changed files
with
338 additions
and
30 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import streamlit as st | ||
|
||
import vertexai | ||
from vertexai.generative_models import GenerativeModel, Part | ||
import vertexai.preview.generative_models as generative_models | ||
from google.oauth2.service_account import Credentials | ||
|
||
from datetime import datetime | ||
|
||
|
||
st.title("Stub of DS chat with Google Vertex AI") | ||
|
||
######################################################################## | ||
# init chat model | ||
|
||
chat = None | ||
|
||
if "vertex_ai_model" not in st.session_state or chat is None: | ||
st.session_state["vertex_ai_model"] = "gemini-1.5-flash-001" | ||
|
||
credentials = Credentials.from_service_account_info(st.secrets["gcs_connections"]) | ||
vertexai.init( | ||
project="pivotal-cable-428219-c5", | ||
location="us-central1", | ||
credentials=credentials, | ||
) | ||
model = GenerativeModel( | ||
st.session_state[ | ||
"vertex_ai_model" | ||
], # by default it will be "gemini-1.5-flash-001", | ||
system_instruction=[ | ||
"""You are an expert python engineer with data scientist background.""" | ||
], | ||
) | ||
chat = model.start_chat() | ||
|
||
generation_config = { | ||
"max_output_tokens": 8192, | ||
"temperature": 1, | ||
"top_p": 0.95, | ||
} | ||
|
||
safety_settings = { | ||
generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, | ||
generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, | ||
generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, | ||
generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, | ||
} | ||
|
||
##################################### | ||
|
||
|
||
if "messages" not in st.session_state: | ||
st.session_state.messages = [] | ||
|
||
for message in st.session_state.messages: | ||
with st.chat_message(message["role"]): | ||
st.markdown(message["content"]) | ||
|
||
# streaming is not working with streamlit, exceptions inside `vertexai.generative_models import GenerativeModel` | ||
USE_STREAMING = False | ||
|
||
if prompt := st.chat_input("What is up?"): | ||
st.session_state.messages.append({"role": "user", "content": prompt}) | ||
with st.chat_message("user"): | ||
st.markdown(prompt) | ||
|
||
with st.chat_message("assistant"): | ||
if USE_STREAMING: | ||
api_response_stream = chat.send_message( | ||
[prompt], | ||
generation_config=generation_config, | ||
safety_settings=safety_settings, | ||
stream=True, | ||
) | ||
|
||
def stream_data(): | ||
for api_response in api_response_stream: | ||
chunk = api_response.candidates[0].content.parts[0]._raw_part.text | ||
print(f"{datetime.now()}: {chunk}") | ||
yield chunk | ||
|
||
response = st.write_stream(stream_data) | ||
else: | ||
with st.spinner( | ||
"Wait for the whole response (streaming not working with Streamlit)..." | ||
): | ||
api_response = chat.send_message( | ||
[prompt], | ||
generation_config=generation_config, | ||
safety_settings=safety_settings, | ||
stream=False, | ||
) | ||
response = api_response.candidates[0].content.parts[0]._raw_part.text | ||
st.write(response) | ||
|
||
print(("response:", response)) | ||
st.session_state.messages.append({"role": "assistant", "content": response}) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
14 changes: 14 additions & 0 deletions
14
apps/streamlit_ds_chat/experiments_standalone/conversational.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from transformers import pipeline | ||
|
||
generator = pipeline(model="HuggingFaceH4/zephyr-7b-beta") | ||
# Zephyr-beta is a conversational model, so let's pass it a chat instead of a single string | ||
result = generator( | ||
[{"role": "user", "content": "What is the capital of France? Answer in one word."}], | ||
do_sample=False, | ||
max_new_tokens=2, | ||
) | ||
|
||
# [{'generated_text': [{'role': 'user', 'content': 'What is the capital of France? Answer in one word.'}, | ||
# {'role': 'assistant', 'content': 'Paris'}]}] | ||
|
||
print(result) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from pprint import pprint | ||
|
||
# Use a pipeline as a high-level helper | ||
from transformers import pipeline | ||
|
||
oracle = pipeline("text2text-generation", model="describeai/gemini") | ||
# - `"question-answering"`: will return a [`QuestionAnsweringPipeline`]. | ||
# - `"text2text-generation"`: will return a [`Text2TextGenerationPipeline`]. | ||
|
||
# QuestionAnsweringPipeline | ||
|
||
# result = oracle(question="Write a short snippet of python code, which use pandas to read csv file into dataframe.", context="I am an expert Python engineer.") | ||
result = oracle( | ||
inputs=[ | ||
"Write a short snippet of python code, which use pandas to read csv file into dataframe.", | ||
"I am an expert Python engineer.", | ||
], | ||
max_lenght=1000, | ||
max_new_tokens=1000, | ||
) | ||
|
||
pprint(result) | ||
|
||
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | ||
# tokenizer = AutoTokenizer.from_pretrained("describeai/gemini") | ||
# model = AutoModelForSeq2SeqLM.from_pretrained("describeai/gemini") |
131 changes: 131 additions & 0 deletions
131
apps/streamlit_ds_chat/experiments_standalone/google_vertex_ai.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
""" | ||
pip install --upgrade google-cloud-aiplatform | ||
gcloud auth application-default login | ||
Command 'gcloud' not found, but can be installed with: | ||
sudo snap install google-cloud-cli # version 483.0.0, or | ||
sudo snap install google-cloud-sdk # version 483.0.0 | ||
#################333 | ||
Credentials saved to file: [/home/s-nechuiviter/.config/gcloud/application_default_credentials.json] | ||
These credentials will be used by any library that requests Application Default Credentials (ADC). | ||
WARNING: | ||
Cannot find a quota project to add to ADC. You might receive a "quota exceeded" or "API not enabled" error. Run $ gcloud auth application-default set-quota-project to add a quota project. | ||
""" | ||
|
||
import base64 | ||
import vertexai | ||
from vertexai.generative_models import GenerativeModel, Part | ||
import vertexai.preview.generative_models as generative_models | ||
|
||
# import streamlit as st | ||
|
||
from google.oauth2.service_account import Credentials | ||
|
||
creds = { | ||
"type": "service_account", | ||
"project_id": "pivotal-cable-428219-c5", | ||
"private_key_id": "be8874b7dea5c6d02590fb0d0df13798777e5e08", | ||
"private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDGqmtXg3FIQ1Ee\na7fjOy/sFQYDBCE32QKMCwGT9HTFVeXb2fUJmqlXQAIvQPoprOE3CQn6UOW1+mMU\n0bQKF3C4NMd//zQAr2TT/LIc/I968Ex6y+kkXAubsgnpd8mUQJEsFXDAasJzAhLI\n8NdeAiYkrXEBzv1glsnTvkE8WqijC22tCDe1n8F1N3VUZhucmGPrWLu++Pxwpox5\n5HxcTYImo2qs/MXMekor0FWOKIQqNZLyVYs5/K3XBC0UGePo0V4LEbraZIlBaPdb\nWdry1jxOpG3cCt9HEQj3g88PZZXNV1go54vKzoEmwy862UN2oIV0o6spNzWV3MJF\n5B0svOOvAgMBAAECggEAEX3t0kMY1VOqr8gCOo1P4d8pVq36vW/BC/2SwlPdeDE2\nT444R1kzxyKJTrR6PPMjfUoGyTXC6Vabeg4gu7FmqQFp/g4apmN1uD3hR4YWnVd3\nvxQhz06/bZJWfGpMowiwVOD/uswWRN643V0UnrVdEGZTszoQuybBAdYYPlfTzfnK\n6OdvvMEPBnzCpXwe0cC8bLFTpZ4Q6N0Ss7fMDZ4wOLhA+fNk24xedQhRSLHm8REJ\nxNQ6ha2cp7xK7PXYjjpPsMGVLxdY7Qn2hOA0/ZFPJCoG8ToSxaU5t27dOBsbvw6N\nrEo7Dlx93lzizPNc9XhaKTLxfv2B4a9UAWiNbd6JoQKBgQDj8r5A7G0YT3KSTtlj\ncWRGceN/ig34/Lo7086IMHODWFZyahnZtKu1pJAfsteFT3Xb3Hp1fYxl7rqt++tS\nUY/tHIWu/zQwcYE3Nx03Kn5lafFTcH7bHTtsfvVYcUP0wKASYnbv1p9hxytj3Hqi\nqe2gBadULE69jolXZAum2VOTcQKBgQDfHSmuq3e7V6ydAMiN2a5CLLuJ1TiRx0VJ\nLNgTMYjJApm8BP7symUf62/I+FKfolUdHZrfleT14En4TYTC82wQie4jwPXfj4Wi\nlkB9m/cIj516XrmMqC+F0M1Qb4ii1OvTy0Pfwu34TyjFgUNCT1Mg/be0w20Fni/G\n1Sl6ybgZHwKBgFRaM8VauFRSshcqTo/aGj1nT8SWle0ZuOEC1F7ZbyWfvv2//aju\njsw9BYh1agPPD9I4mKh5uUbPPQ29N6vSuuwHrgDAN9PlbOe94XXUp8lnlwJFkuwK\nuT7BDJGZ+IfN8G5dOZ4vUfOg/JGLuWYQc/rPnMgtTUYgRPqt7xHjQmZBAoGBAJWM\nXpwNoruYEMNL+yHZfswsX6gLm1dbUj2yKUL0ONNDQvicAKOHJjE3Bj6W9Aq8LIDP\njze+qTGFnQ8qJorlztFnIpAkjqnC8bgBLkkDeZnraYrUY1q8gN4ZDwWTPOqn/UqB\nPIWHiyqdJJ79/a88rGO4rKIlO7ZASZXk22DKRPPRAoGAUy8iMXIWOOKnIwRhKtFb\nkraZRAvPS7b9409wdRnaEwa7h2FMsIXNvuvob1B9r1N/sWZr3lb8Q7gvJAqB4Zf2\nl2YfVsbqPkNby1bDdQ/atRjK7NlKTN3bXvktYb4pcGmmAZ/RuQuFl2rx1rptJN6D\nha5obgOGZyiUbJoHXDalvBg=\n-----END PRIVATE KEY-----\n", | ||
"client_email": "[email protected]", | ||
"client_id": "112274147955673066299", | ||
"auth_uri": "https://accounts.google.com/o/oauth2/auth", | ||
"token_uri": "https://oauth2.googleapis.com/token", | ||
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", | ||
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/ds-chat%40pivotal-cable-428219-c5.iam.gserviceaccount.com", | ||
"universe_domain": "googleapis.com", | ||
} | ||
|
||
credentials = Credentials.from_service_account_info( | ||
creds | ||
) # st.secrets["gcs_connections"]) | ||
|
||
|
||
def multiturn_generate_content(): | ||
vertexai.init( | ||
project="pivotal-cable-428219-c5", | ||
location="us-central1", | ||
credentials=credentials, | ||
) | ||
model = GenerativeModel( | ||
"gemini-1.5-flash-001", | ||
system_instruction=[ | ||
"""You are an expert python engineer with data scientist background.""" | ||
], | ||
) | ||
chat = model.start_chat() | ||
print( | ||
chat.send_message( | ||
[text1_1], | ||
generation_config=generation_config, | ||
safety_settings=safety_settings, | ||
) | ||
) | ||
|
||
|
||
text1_1 = """Write a short snippet of python code, which use pandas to read csv file into dataframe. Return only code, nothing else. So I could copy the response, so I could copy it into python module.""" | ||
|
||
generation_config = { | ||
"max_output_tokens": 8192, | ||
"temperature": 1, | ||
"top_p": 0.95, | ||
} | ||
|
||
safety_settings = { | ||
generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, | ||
generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, | ||
generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, | ||
generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, | ||
} | ||
|
||
multiturn_generate_content() | ||
|
||
""" | ||
Example results: | ||
candidates { | ||
content { | ||
role: "model" | ||
parts { | ||
text: "```python\nimport pandas as pd\ndf = pd.read_csv(\'your_file.csv\')\n```" | ||
} | ||
} | ||
finish_reason: STOP | ||
safety_ratings { | ||
category: HARM_CATEGORY_HATE_SPEECH | ||
probability: NEGLIGIBLE | ||
probability_score: 0.06632687151432037 | ||
severity: HARM_SEVERITY_NEGLIGIBLE | ||
severity_score: 0.10017222911119461 | ||
} | ||
safety_ratings { | ||
category: HARM_CATEGORY_DANGEROUS_CONTENT | ||
probability: NEGLIGIBLE | ||
probability_score: 0.20134170353412628 | ||
severity: HARM_SEVERITY_NEGLIGIBLE | ||
severity_score: 0.07599521428346634 | ||
} | ||
safety_ratings { | ||
category: HARM_CATEGORY_HARASSMENT | ||
probability: NEGLIGIBLE | ||
probability_score: 0.193451926112175 | ||
severity: HARM_SEVERITY_NEGLIGIBLE | ||
severity_score: 0.09334687888622284 | ||
} | ||
safety_ratings { | ||
category: HARM_CATEGORY_SEXUALLY_EXPLICIT | ||
probability: NEGLIGIBLE | ||
probability_score: 0.06816437095403671 | ||
severity: HARM_SEVERITY_NEGLIGIBLE | ||
severity_score: 0.05155818909406662 | ||
} | ||
} | ||
usage_metadata { | ||
prompt_token_count: 52 | ||
candidates_token_count: 24 | ||
total_token_count: 76 | ||
} | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
12 changes: 12 additions & 0 deletions
12
apps/streamlit_ds_chat/experiments_standalone/vertex_code_bison.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import vertexai | ||
from vertexai.language_models import CodeGenerationModel | ||
|
||
vertexai.init(project="pivotal-cable-428219-c5", location="us-central1") | ||
parameters = {"candidate_count": 1, "max_output_tokens": 1024, "temperature": 0.9} | ||
model = CodeGenerationModel.from_pretrained("code-bison@002") | ||
response = model.predict( | ||
prefix="""You are an expert python engineer with data scientist background. | ||
Write a short snippet of python code, which use pandas to read csv file into dataframe.""", | ||
**parameters, | ||
) | ||
print(f"Response from Model: {response.text}") |
Empty file.
File renamed without changes.
File renamed without changes.
Oops, something went wrong.