Skip to content

Commit

Permalink
After setting up a data source, show a progress bar until all models are ready.
Browse files Browse the repository at this point in the history
  • Loading branch information
fabioz committed Dec 17, 2024
1 parent dbe3dda commit ab1b326
Show file tree
Hide file tree
Showing 13 changed files with 400 additions and 69 deletions.
3 changes: 1 addition & 2 deletions docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
## Unreleased

- Add setup support for external data sources
- Show a progress bar until all models are ready
- Add Data Sources to the Packages tree
- Hover (tree item) action: remove Data Source
- Hover (tree item) action: setup Data Source
Expand Down
24 changes: 22 additions & 2 deletions sema4ai-python-ls-core/src/sema4ai_ls_core/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,11 @@ class ActionInfoTypedDict(TypedDict):
kind: str


class ModelStateTypedDict(TypedDict):
    # Lifecycle status of a model as reported by the data server's `models` table.
    status: Literal["complete", "error", "generating", "training", "creating"]
    # Error details reported for the model (None when no error was reported).
    error: str | None


class DatasourceInfoTypedDict(TypedDict):
python_variable_name: str | None
range: "RangeTypedDict"
Expand All @@ -1013,8 +1018,23 @@ class DatasourceInfoTypedDict(TypedDict):
file: str | None


class DatasourceInfoWithStatusTypedDict(DatasourceInfoTypedDict):
    # --- Data Source State below ---
    # if None, it means it wasn't set yet
    configured: bool | None
    # if None, it means it wasn't set yet or it's not a model
    model_state: ModelStateTypedDict | None
    # Whether the data source configuration passed validation (if None, it wasn't checked yet).
    configuration_valid: bool | None
    # Validation error messages when configuration_valid is False (if None, it wasn't checked yet).
    configuration_errors: list[str] | None


class DataSourceStateDict(TypedDict):
    """Aggregated state of the Data Sources required by an action package."""

    # Unconfigured Data Sources (kind of redundant now that
    # DatasourceInfoWithStatusTypedDict has the 'configured' field,
    # but it's kept for backward compatibility).
    unconfigured_data_sources: list[DatasourceInfoWithStatusTypedDict]
    # Error messages on Data Sources.
    uri_to_error_messages: dict[str, list["DiagnosticsTypedDict"]]
    # All the required Data Sources.
    required_data_sources: list[DatasourceInfoWithStatusTypedDict]
    # All Data Sources in the data server.
    data_sources_in_data_server: list[str]
28 changes: 28 additions & 0 deletions sema4ai/src/sema4ai_code/data/data_source_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,31 @@ def _update_custom_sql(self, datasource: "DatasourceInfoTypedDict") -> None | st
sqls.append(txt)
self._custom_sql = tuple(sqls)
return None


def get_data_source_caption(data_source: "DatasourceInfoTypedDict") -> str:
    """Return a user-facing caption for the given data source.

    Consistent combinations yield "<name>.<table-or-model> (<engine>)" or
    plain "<name> (<engine>)"; inconsistent ones yield a "Bad datasource"
    message explaining what is wrong.
    """
    name = data_source["name"]
    engine = data_source["engine"]
    created_table = data_source.get("created_table")
    model_name = data_source.get("model_name")

    if created_table and model_name:
        # Contradictory: a data source is either a table or a model, not both.
        return f"Bad datasource: {name} - created_table: {created_table} and model_name: {model_name} ({engine})"

    if created_table:
        if engine in ("files", "custom"):
            return f"{name}.{created_table} ({engine})"
        return f"Bad datasource: {name} - created_table: {created_table} ({engine}) - created_table is only expected in 'files' and 'custom' engines"

    if model_name:
        if engine == "custom" or engine.startswith("prediction"):
            return f"{name}.{model_name} ({engine})"
        return f"Bad datasource: {name} - model_name: {model_name} ({engine}) - model_name is only expected in 'prediction' and 'custom' engines"

    # No created_table, but the files engine requires one.
    if engine == "files":
        return f"Bad datasource: {name} ({engine}) - expected created_table to be defined"

    # No model_name, but prediction engines require one.
    if engine.startswith("prediction"):
        return f"Bad datasource: {name} ({engine}) - expected model_name to be defined"

    return f"{name} ({engine})"
142 changes: 110 additions & 32 deletions sema4ai/src/sema4ai_code/robocorp_language_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
TextDocumentCodeActionTypedDict,
)
from sema4ai_ls_core.protocols import (
ActionInfoTypedDict,
DatasourceInfoTypedDict,
DatasourceInfoWithStatusTypedDict,
DataSourceStateDict,
IConfig,
IMonitor,
Expand Down Expand Up @@ -77,10 +80,6 @@
WorkspaceInfoDict,
)
from sema4ai_code.refresh_agent_spec_helper import update_agent_spec
from sema4ai_code.robo.collect_actions_ast import (
ActionInfoTypedDict,
DatasourceInfoTypedDict,
)
from sema4ai_code.vendored_deps.package_deps._deps_protocols import (
ICondaCloud,
IPyPiCloud,
Expand Down Expand Up @@ -2469,11 +2468,22 @@ def _setup_data_source_impl(
) -> ActionResultDict[DataSourceSetupResponse]:
from sema4ai_ls_core.progress_report import progress_context

from sema4ai_code.data.data_source_helper import DataSourceHelper
from sema4ai_code.data.data_source_helper import (
DataSourceHelper,
get_data_source_caption,
)

with progress_context(
self._endpoint, "Setting up data sources", self._dir_cache
):
if not isinstance(datasource, list):
datasources = [datasource]
else:
datasources = datasource

if len(datasources) == 1:
caption = f"Data Source: {get_data_source_caption(datasources[0])}"
else:
caption = f"{len(datasources)} Data Sources"

with progress_context(self._endpoint, f"Setting up {caption}", self._dir_cache):
root_path = Path(uris.to_fs_path(action_package_yaml_directory_uri))
if not root_path.exists():
return (
Expand All @@ -2484,11 +2494,6 @@ def _setup_data_source_impl(
.as_dict()
)

if not isinstance(datasource, list):
datasources = [datasource]
else:
datasources = datasource

connection = self._get_connection(data_server_info)
messages = []
for datasource in datasources:
Expand Down Expand Up @@ -2607,24 +2612,30 @@ def m_compute_data_source_state(
self,
action_package_yaml_directory_uri: str,
data_server_info: DataServerConfigTypedDict,
show_progress: bool = True,
) -> partial[ActionResultDict[DataSourceStateDict]]:
return require_monitor(
partial(
self._compute_data_source_state_impl,
action_package_yaml_directory_uri,
data_server_info,
show_progress,
)
)

def _compute_data_source_state_impl(
self,
action_package_yaml_directory_uri: str,
data_server_info: DataServerConfigTypedDict,
show_progress: bool,
monitor: IMonitor,
) -> ActionResultDict[DataSourceStateDict]:
try:
return self._impl_compute_data_source_state_impl(
action_package_yaml_directory_uri, data_server_info, monitor
action_package_yaml_directory_uri,
data_server_info,
show_progress,
monitor,
)
except Exception as e:
log.exception("Error computing data source state")
Expand All @@ -2638,16 +2649,25 @@ def _impl_compute_data_source_state_impl(
self,
action_package_yaml_directory_uri: str,
data_server_info: DataServerConfigTypedDict,
show_progress: bool,
monitor: IMonitor,
) -> ActionResultDict[DataSourceStateDict]:
from sema4ai_ls_core.constants import NULL
from sema4ai_ls_core.lsp import DiagnosticSeverity, DiagnosticsTypedDict
from sema4ai_ls_core.progress_report import progress_context
from sema4ai_ls_core.protocols import ModelStateTypedDict

from sema4ai_code.data.data_source_helper import DataSourceHelper

with progress_context(
self._endpoint, "Computing data sources state", self._dir_cache
) as progress_reporter:
ctx: Any
if show_progress:
ctx = progress_context(
self._endpoint, "Computing data sources state", self._dir_cache
)
else:
ctx = NULL

with ctx as progress_reporter:
progress_reporter.set_additional_info("Listing actions")
actions_and_datasources_result: ActionResultDict[
"list[ActionInfoTypedDict | DatasourceInfoTypedDict]"
Expand All @@ -2665,7 +2685,9 @@ def _impl_compute_data_source_state_impl(
monitor.check_cancelled()
progress_reporter.set_additional_info("Getting data sources")
connection = self._get_connection(data_server_info)
monitor.check_cancelled()
projects_as_dicts = connection.get_data_sources("WHERE type = 'project'")
monitor.check_cancelled()
not_projects_as_dicts = connection.get_data_sources(
"WHERE type != 'project'"
)
Expand Down Expand Up @@ -2694,15 +2716,25 @@ def _impl_compute_data_source_state_impl(
monitor.check_cancelled()
progress_reporter.set_additional_info("Getting models")

data_source_to_models = {}
data_source_to_model_name_to_model_state: dict[
str, dict[str, ModelStateTypedDict]
] = {}
for data_source in projects_data_source_names_in_data_server:
monitor.check_cancelled()
result_set_models = connection.query(
data_source, "SELECT * FROM models"
)
if result_set_models:
data_source_to_models[data_source] = [
x["name"] for x in result_set_models.iter_as_dicts()
]
name_to_status_and_error: dict[str, ModelStateTypedDict] = {}
for entry in result_set_models.iter_as_dicts():
name_to_status_and_error[entry["name"]] = {
"status": entry["status"],
"error": entry["error"],
}

data_source_to_model_name_to_model_state[data_source] = (
name_to_status_and_error
)

monitor.check_cancelled()
progress_reporter.set_additional_info("Getting files")
Expand All @@ -2719,14 +2751,38 @@ def _impl_compute_data_source_state_impl(
actions_and_datasources: "list[ActionInfoTypedDict | DatasourceInfoTypedDict]" = actions_and_datasources_result[
"result"
]
required_data_sources: list["DatasourceInfoTypedDict"] = [
typing.cast("DatasourceInfoTypedDict", d)
for d in actions_and_datasources
if d["kind"] == "datasource"
]

unconfigured_data_sources: list["DatasourceInfoTypedDict"] = []
required_data_sources: list["DatasourceInfoWithStatusTypedDict"] = []
for d in actions_and_datasources:
if d["kind"] == "datasource":
datasource_info = typing.cast("DatasourceInfoTypedDict", d)
required_data_sources.append(
{
"python_variable_name": datasource_info[
"python_variable_name"
],
"range": datasource_info["range"],
"name": datasource_info["name"],
"uri": datasource_info["uri"],
"kind": "datasource",
"engine": datasource_info["engine"],
"model_name": datasource_info["model_name"],
"created_table": datasource_info["created_table"],
"setup_sql": datasource_info["setup_sql"],
"setup_sql_files": datasource_info["setup_sql_files"],
"description": datasource_info["description"],
"file": datasource_info["file"],
# --- Data Source State below (we'll set these later) ---
"configured": None,
"model_state": None,
"configuration_valid": None,
"configuration_errors": None,
}
)

unconfigured_data_sources: list["DatasourceInfoWithStatusTypedDict"] = []
uri_to_error_messages: dict[str, list[DiagnosticsTypedDict]] = {}

ret: DataSourceStateDict = {
"unconfigured_data_sources": unconfigured_data_sources,
"uri_to_error_messages": uri_to_error_messages,
Expand All @@ -2736,13 +2792,15 @@ def _impl_compute_data_source_state_impl(

if required_data_sources:
root_path = Path(uris.to_fs_path(action_package_yaml_directory_uri))
datasource: "DatasourceInfoTypedDict"
datasource: "DatasourceInfoWithStatusTypedDict"
for datasource in required_data_sources:
uri = datasource.get("uri", "<uri-missing>")
datasource_helper = DataSourceHelper(
root_path, datasource, connection
)
validation_errors = datasource_helper.get_validation_errors()
validation_errors: tuple[str, ...] = (
datasource_helper.get_validation_errors()
)
if validation_errors:
for validation_error in validation_errors:
uri_to_error_messages.setdefault(uri, []).append(
Expand All @@ -2752,16 +2810,22 @@ def _impl_compute_data_source_state_impl(
"message": validation_error,
}
)
datasource["configuration_valid"] = False
datasource["configuration_errors"] = list(validation_errors)
continue # this one is invalid, so, we can't go forward.

datasource["configuration_valid"] = True

datasource_name = datasource["name"]
datasource_engine = datasource["engine"]

if datasource_helper.is_table_datasource:
created_table = datasource["created_table"]
if datasource_engine == "files":
if created_table not in files_table_names:
datasource["configured"] = False
unconfigured_data_sources.append(datasource)
continue
else:
# Custom datasource with created_table.
tables_result_set = connection.query("files", "SHOW TABLES")
Expand All @@ -2770,18 +2834,32 @@ def _impl_compute_data_source_state_impl(
for x in tables_result_set.iter_as_dicts()
)
if created_table not in custom_table_names:
datasource["configured"] = False
unconfigured_data_sources.append(datasource)
continue

datasource["configured"] = True
continue # Ok, handled use case.

if datasource_helper.is_model_datasource:
model_name = datasource["model_name"]
if model_name not in data_source_to_models.get(
datasource_name, []
):
found = data_source_to_model_name_to_model_state.get(
datasource_name, {}
)

if model_name not in found:
datasource["configured"] = False
unconfigured_data_sources.append(datasource)
continue

datasource["model_state"] = found[model_name]
datasource["configured"] = True
continue

if datasource_name.lower() not in data_source_names_in_data_server:
datasource["configured"] = False
unconfigured_data_sources.append(datasource)
else:
datasource["configured"] = True

return ActionResult[DataSourceStateDict].make_success(ret).as_dict()
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def wait_for_models_to_be_ready(
row = next(iter(found))
status = row.get("status", "").lower()

if status in ("generating", "training"):
if status in ("generating", "training", "creating"):
still_training.setdefault(project, []).append(model)
log.info(
f"Waiting for model {project}.{model} to complete. Current status: {status}"
Expand Down
Loading

0 comments on commit ab1b326

Please sign in to comment.