"""Comparator journey: side-by-side comparison of two semantic model YAML files.

Lets the user upload two semantic models, edit and validate them, then send the
same question to the Cortex Analyst against each model and compare the answers.
"""
from typing import Any

import pandas as pd
import sqlglot
import streamlit as st
from loguru import logger
from snowflake.connector import SnowflakeConnection
from streamlit_monaco import st_monaco

from admin_apps.shared_utils import GeneratorAppScreen, return_home_button, send_message
from semantic_model_generator.data_processing.proto_utils import yaml_to_semantic_model
from semantic_model_generator.snowflake_utils.snowflake_connector import (
    SnowflakeConnector,
)
from semantic_model_generator.validate_model import validate

# Session-state keys for the two uploaded models: the file path shown in the UI
# and the raw YAML text backing each editor.
MODEL1_PATH = "model1_path"
MODEL1_YAML = "model1_yaml"
MODEL2_PATH = "model2_path"
MODEL2_YAML = "model2_yaml"


def init_session_states() -> None:
    """Route the app to the comparator page on the next Streamlit rerun."""
    st.session_state["page"] = GeneratorAppScreen.COMPARATOR
def _validate_model(
    content: str, path: str, yaml_key: str, valid_key: str, ordinal: str
) -> None:
    """Validate one model's YAML; record the outcome flag and keep the YAML on success.

    Args:
        content: the YAML text currently in the editor.
        path: the model's file path (used only in the spinner label).
        yaml_key: session-state key that stores the last-validated YAML.
        valid_key: session-state key that stores the boolean validation result.
        ordinal: "first" or "second", used in the error message.
    """
    with st.spinner(f"validating {path}..."):
        try:
            validate(content, st.session_state.account_name)
            st.session_state[valid_key] = True
            st.session_state[yaml_key] = content
        except Exception as e:
            st.error(f"Validation failed on the {ordinal} model with error: {e}")
            st.session_state[valid_key] = False


def comparator_app() -> None:
    """Render the comparator page: two YAML editors, validation, and querying.

    Once both models validate, the same question is sent to the Cortex Analyst
    against each model and the responses are displayed side by side.
    """
    return_home_button()
    st.write("## Compare two semantic models")

    col1, col2 = st.columns(2)
    with col1, st.container(border=True):
        st.write(f"Model 1 from: `{st.session_state[MODEL1_PATH]}`")
        content1 = st_monaco(
            value=st.session_state[MODEL1_YAML],
            height="400px",
            language="yaml",
        )

    with col2, st.container(border=True):
        st.write(f"Model 2 from: `{st.session_state[MODEL2_PATH]}`")
        content2 = st_monaco(
            value=st.session_state[MODEL2_YAML],
            height="400px",
            language="yaml",
        )

    if st.button("Validate models"):
        # The two editors share identical validation logic, factored into one helper
        # (the original duplicated this block verbatim for each model).
        _validate_model(
            content1, st.session_state[MODEL1_PATH], MODEL1_YAML, "model1_valid", "first"
        )
        _validate_model(
            content2, st.session_state[MODEL2_PATH], MODEL2_YAML, "model2_valid", "second"
        )

        if st.session_state.get("model1_valid", False) and st.session_state.get(
            "model2_valid", False
        ):
            st.success("Both models are correct.")
            st.session_state["validated"] = True
        else:
            st.error("Please fix the models and try again.")
            st.session_state["validated"] = False

    # Any edit made after the last successful validation invalidates the comparison.
    if (
        content1 != st.session_state[MODEL1_YAML]
        or content2 != st.session_state[MODEL2_YAML]
    ):
        st.info("Please validate the models again after making changes.")
        st.session_state["validated"] = False

    if not st.session_state.get("validated", False):
        st.info("Please validate the models first.")
    else:
        prompt = st.text_input(
            "What question would you like to ask the Cortex Analyst?"
        )
        if prompt:
            st.write(f"Asking both models question: {prompt}")
            user_message = [
                {"role": "user", "content": [{"type": "text", "text": prompt}]}
            ]
            connection = SnowflakeConnector(
                account_name=st.session_state.account_name,
                max_workers=1,
            ).open_connection(db_name="")

            col1, col2 = st.columns(2)
            ask_cortex_analyst(
                user_message,
                st.session_state[MODEL1_YAML],
                connection,
                col1,
                "Model 1 is thinking...",
            )
            ask_cortex_analyst(
                user_message,
                st.session_state[MODEL2_YAML],
                connection,
                col2,
                "Model 2 is thinking...",
            )

    # TODO:
    # - Show the differences
    # - Check if both models are pointing at the same table
def ask_cortex_analyst(
    prompt: list[dict[str, Any]],
    semantic_model: str,
    conn: SnowflakeConnection,
    container: Any,
    spinner_text: str,
) -> None:
    """Ask the Cortex Analyst a question and display the response.

    Args:
        prompt (list[dict[str, Any]]): The Analyst chat messages to send (both
            call sites pass a single-element user-message list, not a plain
            string — the original `str` annotation was wrong).
        semantic_model (str): The semantic model YAML to use for the question.
        conn (SnowflakeConnection): The Snowflake connection to use for the question.
        container (st.DeltaGenerator): The streamlit container to display the response (e.g. st.columns()).
        spinner_text (str): The text to display in the waiting spinner.

    Returns:
        None
    """
    with container, st.container(border=True), st.spinner(spinner_text):
        json_resp = send_message(conn, prompt, yaml_to_semantic_model(semantic_model))
        display_content(conn, json_resp["message"]["content"])
        # Raw response is kept available (collapsed) for debugging.
        st.json(json_resp, expanded=False)
+ """ + try: + # Parse the SQL using SQLGlot + expression = sqlglot.parse_one(sql, dialect="snowflake") + + # Generate formatted SQL, specifying the dialect if necessary for specific syntax transformations + formatted_sql: str = expression.sql(dialect="snowflake", pretty=True) + return formatted_sql + except Exception as e: + logger.debug(f"Failed to prettify SQL: {e}") + return sql + + +def display_content( + conn: SnowflakeConnection, + content: list[dict[str, Any]], +) -> None: + """Displays a content item for a message from the Cortex Analyst.""" + for item in content: + if item["type"] == "text": + st.markdown(item["text"]) + elif item["type"] == "suggestions": + with st.expander("Suggestions", expanded=True): + for suggestion in item["suggestions"]: + st.markdown(f"- {suggestion}") + elif item["type"] == "sql": + with st.container(height=500, border=False): + sql = item["statement"] + sql = prettify_sql(sql) + with st.container(height=250, border=False): + st.code(item["statement"], language="sql") + try: + df = pd.read_sql(sql, conn) + st.dataframe(df, hide_index=True) + except Exception as e: + st.error(f"Failed to execute SQL: {e}") + else: + logger.warning(f"Unknown content type: {item['type']}") + st.write(item) + + +def is_session_state_initialized() -> bool: + return all( + [ + MODEL1_YAML in st.session_state, + MODEL2_YAML in st.session_state, + MODEL1_PATH in st.session_state, + MODEL2_PATH in st.session_state, + ] + ) + + +@st.dialog("Welcome to the Cortex Analyst Annotation Workspace! 📝", width="large") +def init_dialog() -> None: + init_session_states() + + st.write( + "Please choose the two semantic model files that you would like to compare." 
@st.dialog("Welcome to the Cortex Analyst Annotation Workspace! 📝", width="large")
def init_dialog() -> None:
    """Dialog asking the user to upload the two semantic model YAML files to compare.

    On "Compare", stores each file's name and decoded YAML content in session
    state and reruns so show() routes to the comparator page.
    """
    # NOTE(review): the dialog title says "Annotation Workspace" — looks like a
    # copy-paste from another journey; confirm the intended wording for the
    # comparator before release.
    init_session_states()

    st.write(
        "Please choose the two semantic model files that you would like to compare."
    )

    model_1_file = st.file_uploader(
        "Choose first semantic model file",
        type=["yaml"],
        help="Choose a local YAML file that contains semantic model.",
    )
    model_2_file = st.file_uploader(
        "Choose second semantic model file",
        type=["yaml"],
        help="Choose a local YAML file that contains semantic model.",
    )

    if st.button("Compare"):
        if model_1_file is None or model_2_file is None:
            # Fixed grammar of the user-facing error ("the both" -> "both").
            st.error("Please upload both models first.")
        else:
            st.session_state[MODEL1_PATH] = model_1_file.name
            st.session_state[MODEL1_YAML] = model_1_file.getvalue().decode("utf-8")
            st.session_state[MODEL2_PATH] = model_2_file.name
            st.session_state[MODEL2_YAML] = model_2_file.getvalue().decode("utf-8")
            st.rerun()

    return_home_button()


def show() -> None:
    """Entry point for the comparator journey.

    Routes to the comparison page once both models are loaded, otherwise opens
    the upload dialog.
    """
    init_session_states()
    if is_session_state_initialized():
        comparator_app()
    else:
        init_dialog()
"https://{HOST}/api/v2/cortex/analyst/message" + @st.cache_resource def get_connector() -> SnowflakeConnector: @@ -120,6 +123,7 @@ class GeneratorAppScreen(str, Enum): ONBOARDING = "onboarding" ITERATION = "iteration" + COMPARATOR = "comparator" def return_home_button() -> None: @@ -889,6 +893,39 @@ def download_yaml(file_name: str, conn: SnowflakeConnection) -> str: return yaml_str +def send_message( + _conn: SnowflakeConnection, + messages: list[dict[str, str]], + semantic_model: semantic_model_pb2.SemanticModel, +) -> dict[str, Any]: + """ + Calls the REST API with a list of messages and returns the response. + Args: + _conn: SnowflakeConnection, used to grab the token for auth. + messages: list of chat messages to pass to the Analyst API. + + Returns: The raw ChatMessage response from Analyst. + """ + request_body = { + "messages": messages, + "semantic_model": proto_to_yaml(semantic_model), + } + api_endpoint = API_ENDPOINT.format(HOST=st.session_state.host_name) + resp = requests.post( + api_endpoint, + json=request_body, + headers={ + "Authorization": f'Snowflake Token="{_conn.rest.token}"', # type: ignore[union-attr] + "Content-Type": "application/json", + }, + ) + if resp.status_code < 400: + json_resp: dict[str, Any] = resp.json() + return json_resp + else: + raise Exception(f"Failed request with status {resp.status_code}: {resp.text}") + + def get_sit_query_tag( vendor: Optional[str] = None, action: Optional[str] = None ) -> str: diff --git a/semantic_model_generator/data_processing/cte_utils.py b/semantic_model_generator/data_processing/cte_utils.py index 92857c54..84f16a1b 100644 --- a/semantic_model_generator/data_processing/cte_utils.py +++ b/semantic_model_generator/data_processing/cte_utils.py @@ -266,7 +266,10 @@ def generate_select( non_agg_cte + f"SELECT * FROM {logical_table_name(table_in_column_format)} LIMIT {limit}" ) - sqls_to_return.append(_convert_to_snowflake_sql(non_agg_sql)) + # 
sqls_to_return.append(_convert_to_snowflake_sql(non_agg_sql)) +    sqls_to_return.append( +        non_agg_sql +    )  # do not convert to snowflake sql for now, as sqlglot makes mistakes sometimes, e.g. with TO_DATE()      # Generate select query for columns with aggregation exprs.     agg_cols = [ @@ -280,7 +283,10 @@ def generate_select(             agg_cte             + f"SELECT * FROM {logical_table_name(table_in_column_format)} LIMIT {limit}"         ) -        sqls_to_return.append(_convert_to_snowflake_sql(agg_sql)) +        # sqls_to_return.append(_convert_to_snowflake_sql(agg_sql)) +        sqls_to_return.append( +            agg_sql +        )  # do not convert to snowflake sql for now, as sqlglot makes mistakes sometimes, e.g. with TO_DATE()      return sqls_to_return diff --git a/semantic_model_generator/snowflake_utils/snowflake_connector.py b/semantic_model_generator/snowflake_utils/snowflake_connector.py index ea00f4bf..ad957ef4 100644 --- a/semantic_model_generator/snowflake_utils/snowflake_connector.py +++ b/semantic_model_generator/snowflake_utils/snowflake_connector.py @@ -16,7 +16,7 @@ ConnectionType = TypeVar("ConnectionType") # Append this to the end of the auto-generated comments to indicate that the comment was auto-generated. AUTOGEN_TOKEN = "__" -_autogen_model = "llama3-8b" +_autogen_model = "llama3.1-70b" # This is the raw column name from snowflake information schema or desc table _COMMENT_COL = "COMMENT"
# NOTE(review): the _convert_to_snowflake_sql() calls above are deliberately disabled (sqlglot mis-handles e.g. TO_DATE()); delete the commented-out lines rather than keeping dead code once the upstream issue is resolved — TODO confirm with the author.
# NOTE(review): _autogen_model bumped from llama3-8b to llama3.1-70b — presumably for better autogenerated comments; confirm the larger model's latency/cost is acceptable for this path.