From 7a23e37e00e3c594baa6d744f2e85171d7326d08 Mon Sep 17 00:00:00 2001 From: Adrian Stepniak Date: Thu, 17 Oct 2024 17:09:05 +0200 Subject: [PATCH 1/6] add some debug logs --- admin_apps/journeys/iteration.py | 9 +++++---- semantic_model_generator/data_processing/cte_utils.py | 11 +++++++++-- .../snowflake_utils/snowflake_connector.py | 3 ++- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/admin_apps/journeys/iteration.py b/admin_apps/journeys/iteration.py index b1c01f10..bb9a8dd1 100644 --- a/admin_apps/journeys/iteration.py +++ b/admin_apps/journeys/iteration.py @@ -6,6 +6,7 @@ import requests import sqlglot import streamlit as st +from loguru import logger from snowflake.connector import ProgrammingError, SnowflakeConnection from streamlit.delta_generator import DeltaGenerator from streamlit_extras.row import row @@ -93,11 +94,10 @@ def send_message( "messages": messages, "semantic_model": proto_to_yaml(st.session_state.semantic_model), } - host = st.session_state.host_name + api_endpoint = API_ENDPOINT.format(HOST=st.session_state.host_name) + logger.debug(f"Sending request to Analyst API at {api_endpoint}: {request_body}") resp = requests.post( - API_ENDPOINT.format( - HOST=host, - ), + api_endpoint, json=request_body, headers={ "Authorization": f'Snowflake Token="{_conn.rest.token}"', # type: ignore[union-attr] @@ -106,6 +106,7 @@ def send_message( ) if resp.status_code < 400: json_resp: Dict[str, Any] = resp.json() + logger.debug(f"Received response from Analyst API: {json_resp}") return json_resp else: raise Exception(f"Failed request with status {resp.status_code}: {resp.text}") diff --git a/semantic_model_generator/data_processing/cte_utils.py b/semantic_model_generator/data_processing/cte_utils.py index 92857c54..8300a54f 100644 --- a/semantic_model_generator/data_processing/cte_utils.py +++ b/semantic_model_generator/data_processing/cte_utils.py @@ -142,6 +142,7 @@ def _generate_cte_for( cte += ",\n".join(expr_columns) + "\n" cte += f"FROM {fully_qualified_table_name(table.base_table)}" cte += ")" + return cte @@ -266,7 +267,10 @@ def generate_select( non_agg_cte + f"SELECT * FROM {logical_table_name(table_in_column_format)} LIMIT {limit}" ) - sqls_to_return.append(_convert_to_snowflake_sql(non_agg_sql)) + # sqls_to_return.append(_convert_to_snowflake_sql(non_agg_sql)) + sqls_to_return.append( + non_agg_sql + ) # do not convert to snowflake sql for now, as sqlglot make mistakes sometimes, e.g. with TO_DATE() # Generate select query for columns with aggregation exprs. agg_cols = [ @@ -280,7 +284,10 @@ def generate_select( agg_cte + f"SELECT * FROM {logical_table_name(table_in_column_format)} LIMIT {limit}" ) - sqls_to_return.append(_convert_to_snowflake_sql(agg_sql)) + # sqls_to_return.append(_convert_to_snowflake_sql(agg_sql)) + sqls_to_return.append( + agg_sql + ) # do not convert to snowflake sql for now, as sqlglot make mistakes sometimes, e.g. with TO_DATE() return sqls_to_return diff --git a/semantic_model_generator/snowflake_utils/snowflake_connector.py b/semantic_model_generator/snowflake_utils/snowflake_connector.py index ea00f4bf..15966a86 100644 --- a/semantic_model_generator/snowflake_utils/snowflake_connector.py +++ b/semantic_model_generator/snowflake_utils/snowflake_connector.py @@ -16,7 +16,7 @@ ConnectionType = TypeVar("ConnectionType") # Append this to the end of the auto-generated comments to indicate that the comment was auto-generated. AUTOGEN_TOKEN = "__" -_autogen_model = "llama3-8b" +_autogen_model = "llama3.1-70b" # This is the raw column name from snowflake information schema or desc table _COMMENT_COL = "COMMENT" @@ -115,6 +115,7 @@ def _get_column_comment( values: {';'.join(column_values) if column_values else ""}; Please provide a business description for the column. Only return the description without any other text.""" complete_sql = f"select SNOWFLAKE.CORTEX.COMPLETE('{_autogen_model}', '{comment_prompt}')" + logger.debug(f"Complete_sql: {complete_sql}") cmt = conn.cursor().execute(complete_sql).fetchall()[0][0] # type: ignore[union-attr] return str(cmt + AUTOGEN_TOKEN) except Exception as e: From a581d904d521859f1f0d8ff6754057a8a92abb31 Mon Sep 17 00:00:00 2001 From: Adrian Stepniak Date: Mon, 21 Oct 2024 15:41:48 +0200 Subject: [PATCH 2/6] Add two models comparator app --- admin_apps/app.py | 12 +++- admin_apps/journeys/comparator.py | 115 ++++++++++++++++++++++++++++++ admin_apps/shared_utils.py | 23 ++++++ 3 files changed, 149 insertions(+), 1 deletion(-) create mode 100644 admin_apps/journeys/comparator.py diff --git a/admin_apps/app.py b/admin_apps/app.py index ad1f2b79..a6a292ee 100644 --- a/admin_apps/app.py +++ b/admin_apps/app.py @@ -62,7 +62,7 @@ def verify_environment_setup() -> None: if __name__ == "__main__": - from admin_apps.journeys import builder, iteration, partner + from admin_apps.journeys import builder, iteration, partner, comparator def onboarding_dialog() -> None: """ @@ -114,6 +114,14 @@ def onboarding_dialog() -> None: action="start", ) partner.show() + st.markdown("") + if st.button( + "**📝 Compare two semantic models**", + use_container_width=True, + type="primary", + ): + comparator.init_dialog() + st.markdown("") verify_environment_setup() @@ -130,5 +138,7 @@ def onboarding_dialog() -> None: # The builder flow is simply an intermediate dialog before the iteration flow. if st.session_state["page"] == GeneratorAppScreen.ITERATION: iteration.show() + elif st.session_state["page"] == GeneratorAppScreen.COMPARATOR: + comparator.show() else: onboarding_dialog() diff --git a/admin_apps/journeys/comparator.py b/admin_apps/journeys/comparator.py new file mode 100644 index 00000000..cf6c8a35 --- /dev/null +++ b/admin_apps/journeys/comparator.py @@ -0,0 +1,115 @@ +import streamlit as st +from admin_apps.shared_utils import GeneratorAppScreen +from admin_apps.shared_utils import return_home_button +from admin_apps.shared_utils import download_yaml_fqn +from semantic_model_generator.data_processing.proto_utils import yaml_to_semantic_model, proto_to_yaml +from streamlit_monaco import st_monaco + +MODEL1_PATH = "model1_path" +MODEL1_YAML = "model1_yaml" +MODEL2_PATH = "model2_path" +MODEL2_YAML = "model2_yaml" + + +def init_session_states() -> None: + st.session_state["page"] = GeneratorAppScreen.COMPARATOR + + +def comparator_app() -> None: + st.write( + """ + ## Compare two semantic models + """ + ) + col1, col2 = st.columns(2) + with col1: + st.write(f"Model 1 from: `{st.session_state[MODEL1_PATH]}`") + content1 = st_monaco( + value=st.session_state[MODEL1_YAML], + height="400px", + language="yaml", + ) + + with col2: + st.write(f"Model 2 from: `{st.session_state[MODEL2_PATH]}`") + content2 = st_monaco( + value=st.session_state[MODEL2_YAML], + height="400px", + language="yaml", + ) + + # TODO: + # - Compare the two models + # - Show the differences + # - Validation of the models + # - Check if both models are pointing at the same table + # - dialog to ask questions + # - Results of the cortex analyst with both models + + return_home_button() + + +def is_session_state_initialized() -> bool: + return all([ + MODEL1_YAML in st.session_state, + MODEL2_YAML in st.session_state, + MODEL1_PATH in st.session_state, + MODEL2_PATH in st.session_state, + ]) + + +def show() -> None: + init_session_states() + if is_session_state_initialized(): + comparator_app() + else: + init_dialog() + + +def read_semantic_model(model_path: str) -> str: + """Read the semantic model from the given path (local or snowflake stage). + + Args: + model_path (str): The path to the semantic model. + + Returns: + str: The semantic model as a string. + + Raises: + FileNotFoundError: If the model is not found. + """ + if model_path.startswith('@'): + return download_yaml_fqn(model_path) + else: + with open(model_path, "r") as f: + return f.read() + + +@st.dialog("Welcome to the Cortex Analyst Annotation Workspace! 📝", width="large") +def init_dialog() -> None: + init_session_states() + + st.write("Please enter the paths (local or stage) of the two models you would like to compare.") + + model_1_path = st.text_input("Model 1", placeholder="e.g. /local/path/to/model1.yaml") + model_2_path = st.text_input("Model 2", placeholder="e.g. @DATABASE.SCHEMA.STAGE_NAME/path/to/model2.yaml") + + if st.button("Compare"): + model_1_yaml = model_2_yaml = None + try: + model_1_yaml = read_semantic_model(model_1_path) + except FileNotFoundError as e: + st.error(f"Model 1 not found: {e}") + try: + model_2_yaml = read_semantic_model(model_2_path) + except FileNotFoundError as e: + st.error(f"Model 2 not found: {e}") + + if model_1_yaml and model_2_yaml: + st.session_state[MODEL1_PATH] = model_1_path + st.session_state[MODEL1_YAML] = model_1_yaml + st.session_state[MODEL2_PATH] = model_2_path + st.session_state[MODEL2_YAML] = model_2_yaml + st.rerun() + + return_home_button() diff --git a/admin_apps/shared_utils.py b/admin_apps/shared_utils.py index 1e77c3f3..a85d2329 100644 --- a/admin_apps/shared_utils.py +++ b/admin_apps/shared_utils.py @@ -120,6 +120,7 @@ class GeneratorAppScreen(str, Enum): ONBOARDING = "onboarding" ITERATION = "iteration" + COMPARATOR = "comparator" def return_home_button() -> None: @@ -889,6 +890,28 @@ def download_yaml(file_name: str, conn: SnowflakeConnection) -> str: return yaml_str +def download_yaml_fqn(file_name: str, conn: SnowflakeConnection) -> str: + """util to download a semantic YAML from a stage.""" + import os + import tempfile + + if not file_name.endswith(".yaml") and not file_name.startswith("@"): + raise ValueError( + "file_name should be a valid, fully qualified stage name starting with @ with .yaml suffix." + ) + + with tempfile.TemporaryDirectory() as temp_dir: + # Downloads the YAML to {temp_dir}/{file_name}. + download_yaml_sql = f"GET {file_name} file://{temp_dir}" + conn.cursor().execute(download_yaml_sql) + + tmp_file_path = os.path.join(temp_dir, f"{file_name}") + with open(tmp_file_path, "r") as temp_file: + # Read the raw contents from {temp_dir}/{file_name} and return it as a string. + yaml_str = temp_file.read() + return yaml_str + + def get_sit_query_tag( vendor: Optional[str] = None, action: Optional[str] = None ) -> str: From ae57d65c9914ac9ee21c475a731e9d374232e110 Mon Sep 17 00:00:00 2001 From: Adrian Stepniak Date: Tue, 22 Oct 2024 14:01:30 +0200 Subject: [PATCH 3/6] added functionality to check semantic models and answers side-by-side --- admin_apps/app.py | 2 +- admin_apps/journeys/comparator.py | 236 ++++++++++++++++++++++-------- admin_apps/journeys/iteration.py | 4 +- admin_apps/shared_utils.py | 55 ++++--- 4 files changed, 210 insertions(+), 87 deletions(-) diff --git a/admin_apps/app.py b/admin_apps/app.py index a6a292ee..f380ea4c 100644 --- a/admin_apps/app.py +++ b/admin_apps/app.py @@ -62,7 +62,7 @@ def verify_environment_setup() -> None: if __name__ == "__main__": - from admin_apps.journeys import builder, iteration, partner, comparator + from admin_apps.journeys import builder, comparator, iteration, partner def onboarding_dialog() -> None: """ diff --git a/admin_apps/journeys/comparator.py b/admin_apps/journeys/comparator.py index cf6c8a35..787e77d9 100644 --- a/admin_apps/journeys/comparator.py +++ b/admin_apps/journeys/comparator.py @@ -1,10 +1,20 @@ +import json +from typing import Any + +import pandas as pd +import sqlglot import streamlit as st -from admin_apps.shared_utils import GeneratorAppScreen -from admin_apps.shared_utils import return_home_button -from admin_apps.shared_utils import download_yaml_fqn -from semantic_model_generator.data_processing.proto_utils import yaml_to_semantic_model, proto_to_yaml +from loguru import logger +from snowflake.connector import SnowflakeConnection from streamlit_monaco import st_monaco +from admin_apps.shared_utils import GeneratorAppScreen, return_home_button, send_message +from semantic_model_generator.data_processing.proto_utils import yaml_to_semantic_model +from semantic_model_generator.snowflake_utils.snowflake_connector import ( + SnowflakeConnector, +) +from semantic_model_generator.validate_model import validate + MODEL1_PATH = "model1_path" MODEL1_YAML = "model1_yaml" MODEL2_PATH = "model2_path" @@ -16,13 +26,10 @@ def init_session_states() -> None: def comparator_app() -> None: - st.write( - """ - ## Compare two semantic models - """ - ) + return_home_button() + st.write("## Compare two semantic models") col1, col2 = st.columns(2) - with col1: + with col1, st.container(border=True): st.write(f"Model 1 from: `{st.session_state[MODEL1_PATH]}`") content1 = st_monaco( value=st.session_state[MODEL1_YAML], @@ -30,7 +37,7 @@ def comparator_app() -> None: language="yaml", ) - with col2: + with col2, st.container(border=True): st.write(f"Model 2 from: `{st.session_state[MODEL2_PATH]}`") content2 = st_monaco( value=st.session_state[MODEL2_YAML], @@ -38,78 +45,179 @@ def comparator_app() -> None: language="yaml", ) + if st.button("Validate models"): + with st.spinner(f"validating {st.session_state[MODEL1_PATH]}..."): + try: + validate(content1, st.session_state.account_name) + st.session_state["model1_valid"] = True + st.session_state[MODEL1_YAML] = content1 + except Exception as e: + st.error(f"Validation failed on the first model with error: {e}") + st.session_state["model1_valid"] = False + + with st.spinner(f"validating {st.session_state[MODEL2_PATH]}..."): + try: + validate(content2, st.session_state.account_name) + st.session_state["model2_valid"] = True + st.session_state[MODEL2_YAML] = content2 + except Exception as e: + st.error(f"Validation failed on the second model with error: {e}") + st.session_state["model2_valid"] = False + + if st.session_state.get("model1_valid", False) and st.session_state.get( + "model2_valid", False + ): + st.success("Both models are correct.") + st.session_state["validated"] = True + else: + st.error("Please fix the models and try again.") + st.session_state["validated"] = False + + if ( + content1 != st.session_state[MODEL1_YAML] + or content2 != st.session_state[MODEL2_YAML] + ): + st.info("Please validate the models again after making changes.") + st.session_state["validated"] = False + + if not st.session_state.get("validated", False): + st.info("Please validate the models first.") + else: + prompt = st.text_input( + "What question would you like to ask the Cortex Analyst?" + ) + if prompt: + st.write(f"Asking both models question: {prompt}") + user_message = [ + {"role": "user", "content": [{"type": "text", "text": prompt}]} + ] + connector = SnowflakeConnector( + account_name=st.session_state.account_name, + max_workers=1, + ) + conn = connector.open_connection(db_name="") + col1, col2 = st.columns(2) + with col1, st.container(border=True), st.spinner("Model 1 is thinking..."): + semantic_model = st.session_state[MODEL1_YAML] + json_resp = send_message( + conn, user_message, yaml_to_semantic_model(semantic_model) + ) + display_content(conn, json_resp["message"]["content"]) + st.json(json_resp, expanded=False) + + with col2, st.container(border=True), st.spinner("Model 2 is thinking..."): + semantic_model = st.session_state[MODEL2_YAML] + json_resp = send_message( + conn, user_message, yaml_to_semantic_model(semantic_model) + ) + display_content(conn, json_resp["message"]["content"]) + st.json(json_resp, expanded=False) + # TODO: - # - Compare the two models # - Show the differences - # - Validation of the models # - Check if both models are pointing at the same table - # - dialog to ask questions - # - Results of the cortex analyst with both models - - return_home_button() -def is_session_state_initialized() -> bool: - return all([ - MODEL1_YAML in st.session_state, - MODEL2_YAML in st.session_state, - MODEL1_PATH in st.session_state, - MODEL2_PATH in st.session_state, - ]) - - -def show() -> None: - init_session_states() - if is_session_state_initialized(): - comparator_app() - else: - init_dialog() - +@st.cache_data(show_spinner=False) +def prettify_sql(sql: str) -> str: + """ + Prettify SQL using SQLGlot with an option to use the Snowflake dialect for syntax checks. -def read_semantic_model(model_path: str) -> str: - """Read the semantic model from the given path (local or snowflake stage). - Args: - model_path (str): The path to the semantic model. + sql (str): SQL query string to be formatted. Returns: - str: The semantic model as a string. - - Raises: - FileNotFoundError: If the model is not found. + str: Formatted SQL string or input SQL if sqlglot failed to parse. """ - if model_path.startswith('@'): - return download_yaml_fqn(model_path) - else: - with open(model_path, "r") as f: - return f.read() + try: + # Parse the SQL using SQLGlot + expression = sqlglot.parse_one(sql, dialect="snowflake") + + # Generate formatted SQL, specifying the dialect if necessary for specific syntax transformations + formatted_sql: str = expression.sql(dialect="snowflake", pretty=True) + return formatted_sql + except Exception as e: + logger.debug(f"Failed to prettify SQL: {e}") + return sql + + +def display_content( + conn: SnowflakeConnection, + content: list[dict[str, Any]], +) -> None: + """Displays a content item for a message. For generated SQL, allow user to add to verified queries directly or edit then add.""" + for item in content: + if item["type"] == "text": + # If API rejects to answer directly and provided disambiguate suggestions, we'll return text with as prefix. + if "" in item["text"]: + suggestion_response = json.loads(item["text"][12:])[0] + st.markdown(suggestion_response["explanation"]) + with st.expander("Suggestions", expanded=True): + for suggestion in suggestion_response["suggestions"]: + st.markdown(f"- {suggestion}") + else: + st.markdown(item["text"]) + elif item["type"] == "suggestions": + with st.expander("Suggestions", expanded=True): + for suggestion in item["suggestions"]: + st.markdown(f"- {suggestion}") + elif item["type"] == "sql": + with st.container(height=500, border=False): + sql = item["statement"] + sql = prettify_sql(sql) + with st.container(height=250, border=False): + st.code(item["statement"], language="sql") + + df = pd.read_sql(sql, conn) + st.dataframe(df, hide_index=True) + + +def is_session_state_initialized() -> bool: + return all( + [ + MODEL1_YAML in st.session_state, + MODEL2_YAML in st.session_state, + MODEL1_PATH in st.session_state, + MODEL2_PATH in st.session_state, + ] + ) @st.dialog("Welcome to the Cortex Analyst Annotation Workspace! 📝", width="large") def init_dialog() -> None: init_session_states() - st.write("Please enter the paths (local or stage) of the two models you would like to compare.") + st.write( + "Please choose the two semantic model files that you would like to compare." + ) - model_1_path = st.text_input("Model 1", placeholder="e.g. /local/path/to/model1.yaml") - model_2_path = st.text_input("Model 2", placeholder="e.g. @DATABASE.SCHEMA.STAGE_NAME/path/to/model2.yaml") + model_1_file = st.file_uploader( + "Choose first semantic model file", + type=["yaml"], + help="Choose a local YAML file that contains semantic model.", + ) + model_2_file = st.file_uploader( + "Choose second semantic model file", + type=["yaml"], + help="Choose a local YAML file that contains semantic model.", + ) if st.button("Compare"): - model_1_yaml = model_2_yaml = None - try: - model_1_yaml = read_semantic_model(model_1_path) - except FileNotFoundError as e: - st.error(f"Model 1 not found: {e}") - try: - model_2_yaml = read_semantic_model(model_2_path) - except FileNotFoundError as e: - st.error(f"Model 2 not found: {e}") - - if model_1_yaml and model_2_yaml: - st.session_state[MODEL1_PATH] = model_1_path - st.session_state[MODEL1_YAML] = model_1_yaml - st.session_state[MODEL2_PATH] = model_2_path - st.session_state[MODEL2_YAML] = model_2_yaml + if model_1_file is None or model_2_file is None: + st.error("Please upload the both models first.") + else: + st.session_state[MODEL1_PATH] = model_1_file.name + st.session_state[MODEL1_YAML] = model_1_file.getvalue().decode("utf-8") + st.session_state[MODEL2_PATH] = model_2_file.name + st.session_state[MODEL2_YAML] = model_2_file.getvalue().decode("utf-8") st.rerun() return_home_button() + + +def show() -> None: + init_session_states() + if is_session_state_initialized(): + comparator_app() + else: + init_dialog() diff --git a/admin_apps/journeys/iteration.py b/admin_apps/journeys/iteration.py index bb9a8dd1..a5eaf590 100644 --- a/admin_apps/journeys/iteration.py +++ b/admin_apps/journeys/iteration.py @@ -14,6 +14,7 @@ from admin_apps.journeys.joins import joins_dialog from admin_apps.shared_utils import ( + API_ENDPOINT, GeneratorAppScreen, SnowflakeStage, changed_from_last_validated_model, @@ -75,9 +76,6 @@ def pretty_print_sql(sql: str) -> str: return formatted_sql -API_ENDPOINT = "https://{HOST}/api/v2/cortex/analyst/message" - - @st.cache_data(ttl=60, show_spinner=False) def send_message( _conn: SnowflakeConnection, messages: list[dict[str, str]] diff --git a/admin_apps/shared_utils.py b/admin_apps/shared_utils.py index a85d2329..18d48b8c 100644 --- a/admin_apps/shared_utils.py +++ b/admin_apps/shared_utils.py @@ -10,7 +10,9 @@ from typing import Any, Optional import pandas as pd +import requests import streamlit as st +from loguru import logger from PIL import Image from snowflake.connector import SnowflakeConnection @@ -43,6 +45,8 @@ "https://logos-world.net/wp-content/uploads/2022/11/Snowflake-Symbol.png" ) +API_ENDPOINT = "https://{HOST}/api/v2/cortex/analyst/message" + @st.cache_resource def get_connector() -> SnowflakeConnector: @@ -890,26 +894,39 @@ def download_yaml(file_name: str, conn: SnowflakeConnection) -> str: return yaml_str -def download_yaml_fqn(file_name: str, conn: SnowflakeConnection) -> str: - """util to download a semantic YAML from a stage.""" - import os - import tempfile - - if not file_name.endswith(".yaml") and not file_name.startswith("@"): - raise ValueError( - "file_name should be a valid, fully qualified stage name starting with @ with .yaml suffix." - ) - - with tempfile.TemporaryDirectory() as temp_dir: - # Downloads the YAML to {temp_dir}/{file_name}. - download_yaml_sql = f"GET {file_name} file://{temp_dir}" - conn.cursor().execute(download_yaml_sql) +def send_message( + _conn: SnowflakeConnection, + messages: list[dict[str, str]], + semantic_model: semantic_model_pb2.SemanticModel, +) -> dict[str, Any]: + """ + Calls the REST API with a list of messages and returns the response. + Args: + _conn: SnowflakeConnection, used to grab the token for auth. + messages: list of chat messages to pass to the Analyst API. - tmp_file_path = os.path.join(temp_dir, f"{file_name}") - with open(tmp_file_path, "r") as temp_file: - # Read the raw contents from {temp_dir}/{file_name} and return it as a string. - yaml_str = temp_file.read() - return yaml_str + Returns: The raw ChatMessage response from Analyst. + """ + request_body = { + "messages": messages, + "semantic_model": proto_to_yaml(semantic_model), + } + api_endpoint = API_ENDPOINT.format(HOST=st.session_state.host_name) + logger.debug(api_endpoint) + logger.debug(request_body) + resp = requests.post( + api_endpoint, + json=request_body, + headers={ + "Authorization": f'Snowflake Token="{_conn.rest.token}"', # type: ignore[union-attr] + "Content-Type": "application/json", + }, + ) + if resp.status_code < 400: + json_resp: dict[str, Any] = resp.json() + return json_resp + else: + raise Exception(f"Failed request with status {resp.status_code}: {resp.text}") def get_sit_query_tag( From 6ca027ebb126bd0fec304772301c855ee82023ae Mon Sep 17 00:00:00 2001 From: Adrian Stepniak Date: Tue, 22 Oct 2024 15:12:38 +0200 Subject: [PATCH 4/6] clean up --- admin_apps/journeys/iteration.py | 9 ++++----- admin_apps/shared_utils.py | 3 --- semantic_model_generator/data_processing/cte_utils.py | 1 - .../snowflake_utils/snowflake_connector.py | 1 - 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/admin_apps/journeys/iteration.py b/admin_apps/journeys/iteration.py index a5eaf590..43e1bf2e 100644 --- a/admin_apps/journeys/iteration.py +++ b/admin_apps/journeys/iteration.py @@ -6,7 +6,6 @@ import requests import sqlglot import streamlit as st -from loguru import logger from snowflake.connector import ProgrammingError, SnowflakeConnection from streamlit.delta_generator import DeltaGenerator from streamlit_extras.row import row @@ -92,10 +91,11 @@ def send_message( "messages": messages, "semantic_model": proto_to_yaml(st.session_state.semantic_model), } - api_endpoint = API_ENDPOINT.format(HOST=st.session_state.host_name) - logger.debug(f"Sending request to Analyst API at {api_endpoint}: {request_body}") + host = st.session_state.host_name resp = requests.post( - api_endpoint, + API_ENDPOINT.format( + HOST=host, + ), json=request_body, headers={ "Authorization": f'Snowflake Token="{_conn.rest.token}"', # type: ignore[union-attr] @@ -104,7 +104,6 @@ def send_message( ) if resp.status_code < 400: json_resp: Dict[str, Any] = resp.json() - logger.debug(f"Received response from Analyst API: {json_resp}") return json_resp else: raise Exception(f"Failed request with status {resp.status_code}: {resp.text}") diff --git a/admin_apps/shared_utils.py b/admin_apps/shared_utils.py index 18d48b8c..1026477c 100644 --- a/admin_apps/shared_utils.py +++ b/admin_apps/shared_utils.py @@ -12,7 +12,6 @@ import pandas as pd import requests import streamlit as st -from loguru import logger from PIL import Image from snowflake.connector import SnowflakeConnection @@ -912,8 +911,6 @@ def send_message( "semantic_model": proto_to_yaml(semantic_model), } api_endpoint = API_ENDPOINT.format(HOST=st.session_state.host_name) - logger.debug(api_endpoint) - logger.debug(request_body) resp = requests.post( api_endpoint, json=request_body, diff --git a/semantic_model_generator/data_processing/cte_utils.py b/semantic_model_generator/data_processing/cte_utils.py index 8300a54f..84f16a1b 100644 --- a/semantic_model_generator/data_processing/cte_utils.py +++ b/semantic_model_generator/data_processing/cte_utils.py @@ -142,7 +142,6 @@ def _generate_cte_for( cte += ",\n".join(expr_columns) + "\n" cte += f"FROM {fully_qualified_table_name(table.base_table)}" cte += ")" - return cte diff --git a/semantic_model_generator/snowflake_utils/snowflake_connector.py b/semantic_model_generator/snowflake_utils/snowflake_connector.py index 15966a86..ad957ef4 100644 --- a/semantic_model_generator/snowflake_utils/snowflake_connector.py +++ b/semantic_model_generator/snowflake_utils/snowflake_connector.py @@ -115,7 +115,6 @@ def _get_column_comment( values: {';'.join(column_values) if column_values else ""}; Please provide a business description for the column. Only return the description without any other text.""" complete_sql = f"select SNOWFLAKE.CORTEX.COMPLETE('{_autogen_model}', '{comment_prompt}')" - logger.debug(f"Complete_sql: {complete_sql}") cmt = conn.cursor().execute(complete_sql).fetchall()[0][0] # type: ignore[union-attr] return str(cmt + AUTOGEN_TOKEN) except Exception as e: From bcedfd0d908bc341cfdddba06fd7dd471047b2cd Mon Sep 17 00:00:00 2001 From: Adrian Stepniak Date: Tue, 22 Oct 2024 15:13:07 +0200 Subject: [PATCH 5/6] clean up --- admin_apps/journeys/comparator.py | 77 +++++++++++++++++++------------ 1 file changed, 48 insertions(+), 29 deletions(-) diff --git a/admin_apps/journeys/comparator.py b/admin_apps/journeys/comparator.py index 787e77d9..666f7ee6 100644 --- a/admin_apps/journeys/comparator.py +++ b/admin_apps/journeys/comparator.py @@ -1,4 +1,3 @@ -import json from typing import Any import pandas as pd @@ -91,33 +90,58 @@ def comparator_app() -> None: user_message = [ {"role": "user", "content": [{"type": "text", "text": prompt}]} ] - connector = SnowflakeConnector( + connection = SnowflakeConnector( account_name=st.session_state.account_name, max_workers=1, - ) - conn = connector.open_connection(db_name="") + ).open_connection(db_name="") + col1, col2 = st.columns(2) - with col1, st.container(border=True), st.spinner("Model 1 is thinking..."): - semantic_model = st.session_state[MODEL1_YAML] - json_resp = send_message( - conn, user_message, yaml_to_semantic_model(semantic_model) - ) - display_content(conn, json_resp["message"]["content"]) - st.json(json_resp, expanded=False) - - with col2, st.container(border=True), st.spinner("Model 2 is thinking..."): - semantic_model = st.session_state[MODEL2_YAML] - json_resp = send_message( - conn, user_message, yaml_to_semantic_model(semantic_model) - ) - display_content(conn, json_resp["message"]["content"]) - st.json(json_resp, expanded=False) + ask_cortex_analyst( + user_message, + st.session_state[MODEL1_YAML], + connection, + col1, + "Model 1 is thinking...", + ) + ask_cortex_analyst( + user_message, + st.session_state[MODEL2_YAML], + connection, + col2, + "Model 2 is thinking...", + ) # TODO: # - Show the differences # - Check if both models are pointing at the same table +def ask_cortex_analyst( + prompt: str, + semantic_model: str, + conn: SnowflakeConnection, + container: Any, + spinner_text: str, +) -> None: + """Ask the Cortex Analyst a question and display the response. + + Args: + prompt (str): The question to ask the Cortex Analyst. + semantic_model (str): The semantic model to use for the question. + conn (SnowflakeConnection): The Snowflake connection to use for the question. + container (st.DeltaGenerator): The streamlit container to display the response (e.g. st.columns()). + spinner_text (str): The text to display in the waiting spinner + + Returns: + None + + """ + with container, st.container(border=True), st.spinner(spinner_text): + json_resp = send_message(conn, prompt, yaml_to_semantic_model(semantic_model)) + display_content(conn, json_resp["message"]["content"]) + st.json(json_resp, expanded=False) + + @st.cache_data(show_spinner=False) def prettify_sql(sql: str) -> str: """ @@ -145,18 +169,10 @@ def display_content( conn: SnowflakeConnection, content: list[dict[str, Any]], ) -> None: - """Displays a content item for a message. For generated SQL, allow user to add to verified queries directly or edit then add.""" + """Displays a content item for a message from the Cortex Analyst.""" for item in content: if item["type"] == "text": - # If API rejects to answer directly and provided disambiguate suggestions, we'll return text with as prefix. - if "" in item["text"]: - suggestion_response = json.loads(item["text"][12:])[0] - st.markdown(suggestion_response["explanation"]) - with st.expander("Suggestions", expanded=True): - for suggestion in suggestion_response["suggestions"]: - st.markdown(f"- {suggestion}") - else: - st.markdown(item["text"]) + st.markdown(item["text"]) elif item["type"] == "suggestions": with st.expander("Suggestions", expanded=True): for suggestion in item["suggestions"]: @@ -170,6 +186,9 @@ def display_content( df = pd.read_sql(sql, conn) st.dataframe(df, hide_index=True) + else: + logger.warning(f"Unknown content type: {item['type']}") + st.write(item) def is_session_state_initialized() -> bool: From 822627d5dabe1e2bcbcbcf618f5e6813a1b54c77 Mon Sep 17 00:00:00 2001 From: Adrian Stepniak Date: Wed, 6 Nov 2024 12:26:54 +0100 Subject: [PATCH 6/6] add try-except when reading from db --- admin_apps/journeys/comparator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/admin_apps/journeys/comparator.py b/admin_apps/journeys/comparator.py index 666f7ee6..16bce3e4 100644 --- a/admin_apps/journeys/comparator.py +++ b/admin_apps/journeys/comparator.py @@ -183,9 +183,11 @@ def display_content( sql = prettify_sql(sql) with st.container(height=250, border=False): st.code(item["statement"], language="sql") - - df = pd.read_sql(sql, conn) - st.dataframe(df, hide_index=True) + try: + df = pd.read_sql(sql, conn) + st.dataframe(df, hide_index=True) + except Exception as e: + st.error(f"Failed to execute SQL: {e}") else: logger.warning(f"Unknown content type: {item['type']}") st.write(item)