Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

After setting up a data source, show a progress bar until all models … #157

Merged
merged 1 commit into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
## Unreleased

- Add setup support for external data sources
- Show a progress bar until all models are ready
- Add Data Sources to the Packages tree
- Hover (tree item) action: remove Data Source
- Hover (tree item) action: setup Data Source
Expand Down
24 changes: 22 additions & 2 deletions sema4ai-python-ls-core/src/sema4ai_ls_core/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,11 @@ class ActionInfoTypedDict(TypedDict):
kind: str


class ModelStateTypedDict(TypedDict):
    """Lifecycle state of a model as reported by the data server."""

    # Current lifecycle phase of the model ("complete" and "error" appear
    # to be the terminal states; the others are in-progress phases).
    status: Literal["complete", "error", "generating", "training", "creating"]
    # Error message for the model, if any — presumably only set when
    # status == "error" (TODO confirm against the data server).
    error: str | None


class DatasourceInfoTypedDict(TypedDict):
python_variable_name: str | None
range: "RangeTypedDict"
Expand All @@ -1013,8 +1018,23 @@ class DatasourceInfoTypedDict(TypedDict):
file: str | None


class DatasourceInfoWithStatusTypedDict(DatasourceInfoTypedDict):
    """DatasourceInfoTypedDict extended with the computed Data Source state."""

    # --- Data Source State below ---
    # Whether the data source is actually available/configured in the data
    # server (if None, it means it wasn't set yet).
    configured: bool | None
    # State of the related model (if None, it means it wasn't set yet or
    # it's not a model data source).
    model_state: ModelStateTypedDict | None
    # Whether the data source definition passed validation
    # (if None, it means validation wasn't computed yet).
    configuration_valid: bool | None
    # Validation error messages when configuration_valid is False
    # (if None, it means validation wasn't computed yet).
    configuration_errors: list[str] | None


class DataSourceStateDict(TypedDict):
unconfigured_data_sources: list[DatasourceInfoTypedDict]
    # Unconfigured Data Sources (kind of redundant now that DatasourceInfoWithStatusTypedDict
    # has the 'configured' field, but it's kept for backward compatibility).
unconfigured_data_sources: list[DatasourceInfoWithStatusTypedDict]
# Error messages on Data Sources.
uri_to_error_messages: dict[str, list["DiagnosticsTypedDict"]]
required_data_sources: list[DatasourceInfoTypedDict]
# All the required Data Sources.
required_data_sources: list[DatasourceInfoWithStatusTypedDict]
# All Data Sources in the data server.
data_sources_in_data_server: list[str]
28 changes: 28 additions & 0 deletions sema4ai/src/sema4ai_code/data/data_source_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,31 @@ def _update_custom_sql(self, datasource: "DatasourceInfoTypedDict") -> None | st
sqls.append(txt)
self._custom_sql = tuple(sqls)
return None


def get_data_source_caption(data_source: "DatasourceInfoTypedDict") -> str:
    """Return a human-readable caption for a data source.

    Captions have the form "name (engine)", extended with the created table
    ("name.table (engine)") for 'files'/'custom' engines or the model name
    ("name.model (engine)") for 'prediction*'/'custom' engines.

    Inconsistent definitions (both created_table and model_name set, a
    created_table/model_name on an engine that doesn't expect it, or a
    missing required field) yield a "Bad datasource: ..." caption describing
    the problem.
    """
    # Hoist the repeated subscript lookups; created_table/model_name may be
    # absent or None, so use .get() for those.
    name = data_source["name"]
    engine = data_source["engine"]
    created_table = data_source.get("created_table")
    model_name = data_source.get("model_name")

    if created_table and model_name:
        # A data source may define a created table OR a model, never both.
        return (
            f"Bad datasource: {name} - created_table: {created_table} "
            f"and model_name: {model_name} ({engine})"
        )

    if created_table:
        if engine in ("files", "custom"):
            return f"{name}.{created_table} ({engine})"
        return (
            f"Bad datasource: {name} - created_table: {created_table} ({engine})"
            f" - created_table is only expected in 'files' and 'custom' engines"
        )

    if model_name:
        if engine.startswith("prediction") or engine == "custom":
            return f"{name}.{model_name} ({engine})"
        return (
            f"Bad datasource: {name} - model_name: {model_name} ({engine})"
            f" - model_name is only expected in 'prediction' and 'custom' engines"
        )

    # Created table is expected for files engine
    if engine == "files":
        return f"Bad datasource: {name} ({engine}) - expected created_table to be defined"

    # Model name is expected for prediction engines
    if engine.startswith("prediction"):
        return f"Bad datasource: {name} ({engine}) - expected model_name to be defined"

    return f"{name} ({engine})"
142 changes: 110 additions & 32 deletions sema4ai/src/sema4ai_code/robocorp_language_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
TextDocumentCodeActionTypedDict,
)
from sema4ai_ls_core.protocols import (
ActionInfoTypedDict,
DatasourceInfoTypedDict,
DatasourceInfoWithStatusTypedDict,
DataSourceStateDict,
IConfig,
IMonitor,
Expand Down Expand Up @@ -77,10 +80,6 @@
WorkspaceInfoDict,
)
from sema4ai_code.refresh_agent_spec_helper import update_agent_spec
from sema4ai_code.robo.collect_actions_ast import (
ActionInfoTypedDict,
DatasourceInfoTypedDict,
)
from sema4ai_code.vendored_deps.package_deps._deps_protocols import (
ICondaCloud,
IPyPiCloud,
Expand Down Expand Up @@ -2469,11 +2468,22 @@ def _setup_data_source_impl(
) -> ActionResultDict[DataSourceSetupResponse]:
from sema4ai_ls_core.progress_report import progress_context

from sema4ai_code.data.data_source_helper import DataSourceHelper
from sema4ai_code.data.data_source_helper import (
DataSourceHelper,
get_data_source_caption,
)

with progress_context(
self._endpoint, "Setting up data sources", self._dir_cache
):
if not isinstance(datasource, list):
datasources = [datasource]
else:
datasources = datasource

if len(datasources) == 1:
caption = f"Data Source: {get_data_source_caption(datasources[0])}"
else:
caption = f"{len(datasources)} Data Sources"

with progress_context(self._endpoint, f"Setting up {caption}", self._dir_cache):
root_path = Path(uris.to_fs_path(action_package_yaml_directory_uri))
if not root_path.exists():
return (
Expand All @@ -2484,11 +2494,6 @@ def _setup_data_source_impl(
.as_dict()
)

if not isinstance(datasource, list):
datasources = [datasource]
else:
datasources = datasource

connection = self._get_connection(data_server_info)
messages = []
for datasource in datasources:
Expand Down Expand Up @@ -2607,24 +2612,30 @@ def m_compute_data_source_state(
self,
action_package_yaml_directory_uri: str,
data_server_info: DataServerConfigTypedDict,
show_progress: bool = True,
) -> partial[ActionResultDict[DataSourceStateDict]]:
return require_monitor(
partial(
self._compute_data_source_state_impl,
action_package_yaml_directory_uri,
data_server_info,
show_progress,
)
)

def _compute_data_source_state_impl(
self,
action_package_yaml_directory_uri: str,
data_server_info: DataServerConfigTypedDict,
show_progress: bool,
monitor: IMonitor,
) -> ActionResultDict[DataSourceStateDict]:
try:
return self._impl_compute_data_source_state_impl(
action_package_yaml_directory_uri, data_server_info, monitor
action_package_yaml_directory_uri,
data_server_info,
show_progress,
monitor,
)
except Exception as e:
log.exception("Error computing data source state")
Expand All @@ -2638,16 +2649,25 @@ def _impl_compute_data_source_state_impl(
self,
action_package_yaml_directory_uri: str,
data_server_info: DataServerConfigTypedDict,
show_progress: bool,
monitor: IMonitor,
) -> ActionResultDict[DataSourceStateDict]:
from sema4ai_ls_core.constants import NULL
from sema4ai_ls_core.lsp import DiagnosticSeverity, DiagnosticsTypedDict
from sema4ai_ls_core.progress_report import progress_context
from sema4ai_ls_core.protocols import ModelStateTypedDict

from sema4ai_code.data.data_source_helper import DataSourceHelper

with progress_context(
self._endpoint, "Computing data sources state", self._dir_cache
) as progress_reporter:
ctx: Any
if show_progress:
ctx = progress_context(
self._endpoint, "Computing data sources state", self._dir_cache
)
else:
ctx = NULL

with ctx as progress_reporter:
progress_reporter.set_additional_info("Listing actions")
actions_and_datasources_result: ActionResultDict[
"list[ActionInfoTypedDict | DatasourceInfoTypedDict]"
Expand All @@ -2665,7 +2685,9 @@ def _impl_compute_data_source_state_impl(
monitor.check_cancelled()
progress_reporter.set_additional_info("Getting data sources")
connection = self._get_connection(data_server_info)
monitor.check_cancelled()
projects_as_dicts = connection.get_data_sources("WHERE type = 'project'")
monitor.check_cancelled()
not_projects_as_dicts = connection.get_data_sources(
"WHERE type != 'project'"
)
Expand Down Expand Up @@ -2694,15 +2716,25 @@ def _impl_compute_data_source_state_impl(
monitor.check_cancelled()
progress_reporter.set_additional_info("Getting models")

data_source_to_models = {}
data_source_to_model_name_to_model_state: dict[
str, dict[str, ModelStateTypedDict]
] = {}
for data_source in projects_data_source_names_in_data_server:
monitor.check_cancelled()
result_set_models = connection.query(
data_source, "SELECT * FROM models"
)
if result_set_models:
data_source_to_models[data_source] = [
x["name"] for x in result_set_models.iter_as_dicts()
]
name_to_status_and_error: dict[str, ModelStateTypedDict] = {}
for entry in result_set_models.iter_as_dicts():
name_to_status_and_error[entry["name"]] = {
"status": entry["status"],
"error": entry["error"],
}

data_source_to_model_name_to_model_state[data_source] = (
name_to_status_and_error
)

monitor.check_cancelled()
progress_reporter.set_additional_info("Getting files")
Expand All @@ -2719,14 +2751,38 @@ def _impl_compute_data_source_state_impl(
actions_and_datasources: "list[ActionInfoTypedDict | DatasourceInfoTypedDict]" = actions_and_datasources_result[
"result"
]
required_data_sources: list["DatasourceInfoTypedDict"] = [
typing.cast("DatasourceInfoTypedDict", d)
for d in actions_and_datasources
if d["kind"] == "datasource"
]

unconfigured_data_sources: list["DatasourceInfoTypedDict"] = []
required_data_sources: list["DatasourceInfoWithStatusTypedDict"] = []
for d in actions_and_datasources:
if d["kind"] == "datasource":
datasource_info = typing.cast("DatasourceInfoTypedDict", d)
required_data_sources.append(
{
"python_variable_name": datasource_info[
"python_variable_name"
],
"range": datasource_info["range"],
"name": datasource_info["name"],
"uri": datasource_info["uri"],
"kind": "datasource",
"engine": datasource_info["engine"],
"model_name": datasource_info["model_name"],
"created_table": datasource_info["created_table"],
"setup_sql": datasource_info["setup_sql"],
"setup_sql_files": datasource_info["setup_sql_files"],
"description": datasource_info["description"],
"file": datasource_info["file"],
# --- Data Source State below (we'll set these later) ---
"configured": None,
"model_state": None,
"configuration_valid": None,
"configuration_errors": None,
}
)

unconfigured_data_sources: list["DatasourceInfoWithStatusTypedDict"] = []
uri_to_error_messages: dict[str, list[DiagnosticsTypedDict]] = {}

ret: DataSourceStateDict = {
"unconfigured_data_sources": unconfigured_data_sources,
"uri_to_error_messages": uri_to_error_messages,
Expand All @@ -2736,13 +2792,15 @@ def _impl_compute_data_source_state_impl(

if required_data_sources:
root_path = Path(uris.to_fs_path(action_package_yaml_directory_uri))
datasource: "DatasourceInfoTypedDict"
datasource: "DatasourceInfoWithStatusTypedDict"
for datasource in required_data_sources:
uri = datasource.get("uri", "<uri-missing>")
datasource_helper = DataSourceHelper(
root_path, datasource, connection
)
validation_errors = datasource_helper.get_validation_errors()
validation_errors: tuple[str, ...] = (
datasource_helper.get_validation_errors()
)
if validation_errors:
for validation_error in validation_errors:
uri_to_error_messages.setdefault(uri, []).append(
Expand All @@ -2752,16 +2810,22 @@ def _impl_compute_data_source_state_impl(
"message": validation_error,
}
)
datasource["configuration_valid"] = False
datasource["configuration_errors"] = list(validation_errors)
continue # this one is invalid, so, we can't go forward.

datasource["configuration_valid"] = True

datasource_name = datasource["name"]
datasource_engine = datasource["engine"]

if datasource_helper.is_table_datasource:
created_table = datasource["created_table"]
if datasource_engine == "files":
if created_table not in files_table_names:
datasource["configured"] = False
unconfigured_data_sources.append(datasource)
continue
else:
# Custom datasource with created_table.
tables_result_set = connection.query("files", "SHOW TABLES")
Expand All @@ -2770,18 +2834,32 @@ def _impl_compute_data_source_state_impl(
for x in tables_result_set.iter_as_dicts()
)
if created_table not in custom_table_names:
datasource["configured"] = False
unconfigured_data_sources.append(datasource)
continue

datasource["configured"] = True
continue # Ok, handled use case.

if datasource_helper.is_model_datasource:
model_name = datasource["model_name"]
if model_name not in data_source_to_models.get(
datasource_name, []
):
found = data_source_to_model_name_to_model_state.get(
datasource_name, {}
)

if model_name not in found:
datasource["configured"] = False
unconfigured_data_sources.append(datasource)
continue

datasource["model_state"] = found[model_name]
datasource["configured"] = True
continue

if datasource_name.lower() not in data_source_names_in_data_server:
datasource["configured"] = False
unconfigured_data_sources.append(datasource)
else:
datasource["configured"] = True

return ActionResult[DataSourceStateDict].make_success(ret).as_dict()
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def wait_for_models_to_be_ready(
row = next(iter(found))
status = row.get("status", "").lower()

if status in ("generating", "training"):
if status in ("generating", "training", "creating"):
still_training.setdefault(project, []).append(model)
log.info(
f"Waiting for model {project}.{model} to complete. Current status: {status}"
Expand Down
Loading
Loading