diff --git a/docs/changelog.md b/docs/changelog.md index 4edf09e4..1a17a6b9 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,6 +1,5 @@ -## Unreleased - - Add setup support for external data sources + - Show a progress bar until all models are ready - Add Data Sources to the Packages tree - Hover (tree item) action: remove Data Source - Hover (tree item) action: setup Data Source diff --git a/sema4ai-python-ls-core/src/sema4ai_ls_core/protocols.py b/sema4ai-python-ls-core/src/sema4ai_ls_core/protocols.py index 44802c41..13301646 100644 --- a/sema4ai-python-ls-core/src/sema4ai_ls_core/protocols.py +++ b/sema4ai-python-ls-core/src/sema4ai_ls_core/protocols.py @@ -998,6 +998,11 @@ class ActionInfoTypedDict(TypedDict): kind: str +class ModelStateTypedDict(TypedDict): + status: Literal["complete", "error", "generating", "training", "creating"] + error: str | None + + class DatasourceInfoTypedDict(TypedDict): python_variable_name: str | None range: "RangeTypedDict" @@ -1013,8 +1018,23 @@ class DatasourceInfoTypedDict(TypedDict): file: str | None +class DatasourceInfoWithStatusTypedDict(DatasourceInfoTypedDict): + # --- Data Source State below --- + # if None, it means it wasn't set yet + configured: bool | None + # if None, it means it wasn't set yet or it's not a model + model_state: ModelStateTypedDict | None + configuration_valid: bool | None + configuration_errors: list[str] | None + + class DataSourceStateDict(TypedDict): - unconfigured_data_sources: list[DatasourceInfoTypedDict] + # Unconfigured Data Sources (kind of redundant now that DataSourceInfoTypedDict has the 'configured' field + # but it's kept for backward compatibility). + unconfigured_data_sources: list[DatasourceInfoWithStatusTypedDict] + # Error messages on Data Sources. uri_to_error_messages: dict[str, list["DiagnosticsTypedDict"]] - required_data_sources: list[DatasourceInfoTypedDict] + # All the required Data Sources. + required_data_sources: list[DatasourceInfoWithStatusTypedDict] + # All Data Sources in the data server. data_sources_in_data_server: list[str] diff --git a/sema4ai/src/sema4ai_code/data/data_source_helper.py b/sema4ai/src/sema4ai_code/data/data_source_helper.py index 3566a674..1c132abe 100644 --- a/sema4ai/src/sema4ai_code/data/data_source_helper.py +++ b/sema4ai/src/sema4ai_code/data/data_source_helper.py @@ -142,3 +142,31 @@ def _update_custom_sql(self, datasource: "DatasourceInfoTypedDict") -> None | st sqls.append(txt) self._custom_sql = tuple(sqls) return None + + +def get_data_source_caption(data_source: "DatasourceInfoTypedDict") -> str: + if data_source.get("created_table") and data_source.get("model_name"): + return f"Bad datasource: {data_source['name']} - created_table: {data_source['created_table']} and model_name: {data_source['model_name']} ({data_source['engine']})" + + if data_source.get("created_table"): + if data_source["engine"] == "files" or data_source["engine"] == "custom": + return f"{data_source['name']}.{data_source['created_table']} ({data_source['engine']})" + return f"Bad datasource: {data_source['name']} - created_table: {data_source['created_table']} ({data_source['engine']}) - created_table is only expected in 'files' and 'custom' engines" + + if data_source.get("model_name"): + if ( + data_source["engine"].startswith("prediction") + or data_source["engine"] == "custom" + ): + return f"{data_source['name']}.{data_source['model_name']} ({data_source['engine']})" + return f"Bad datasource: {data_source['name']} - model_name: {data_source['model_name']} ({data_source['engine']}) - model_name is only expected in 'prediction' and 'custom' engines" + + # Created table is expected for files engine + if data_source["engine"] == "files": + return f"Bad datasource: {data_source['name']} ({data_source['engine']}) - expected created_table to be defined" + + # Model name is expected for prediction engines + if data_source["engine"].startswith("prediction"): + return f"Bad datasource: {data_source['name']} ({data_source['engine']}) - expected model_name to be defined" + + return f"{data_source['name']} ({data_source['engine']})" diff --git a/sema4ai/src/sema4ai_code/robocorp_language_server.py b/sema4ai/src/sema4ai_code/robocorp_language_server.py index 3dcfa25b..e98fe90b 100644 --- a/sema4ai/src/sema4ai_code/robocorp_language_server.py +++ b/sema4ai/src/sema4ai_code/robocorp_language_server.py @@ -21,6 +21,9 @@ TextDocumentCodeActionTypedDict, ) from sema4ai_ls_core.protocols import ( + ActionInfoTypedDict, + DatasourceInfoTypedDict, + DatasourceInfoWithStatusTypedDict, DataSourceStateDict, IConfig, IMonitor, @@ -77,10 +80,6 @@ WorkspaceInfoDict, ) from sema4ai_code.refresh_agent_spec_helper import update_agent_spec -from sema4ai_code.robo.collect_actions_ast import ( - ActionInfoTypedDict, - DatasourceInfoTypedDict, -) from sema4ai_code.vendored_deps.package_deps._deps_protocols import ( ICondaCloud, IPyPiCloud, @@ -2469,11 +2468,22 @@ def _setup_data_source_impl( ) -> ActionResultDict[DataSourceSetupResponse]: from sema4ai_ls_core.progress_report import progress_context - from sema4ai_code.data.data_source_helper import DataSourceHelper + from sema4ai_code.data.data_source_helper import ( + DataSourceHelper, + get_data_source_caption, + ) - with progress_context( - self._endpoint, "Setting up data sources", self._dir_cache - ): + if not isinstance(datasource, list): + datasources = [datasource] + else: + datasources = datasource + + if len(datasources) == 1: + caption = f"Data Source: {get_data_source_caption(datasources[0])}" + else: + caption = f"{len(datasources)} Data Sources" + + with progress_context(self._endpoint, f"Setting up {caption}", self._dir_cache): root_path = Path(uris.to_fs_path(action_package_yaml_directory_uri)) if not root_path.exists(): return ( @@ -2484,11 +2494,6 @@ def _setup_data_source_impl( .as_dict() ) - if not isinstance(datasource, list): - datasources = [datasource] - else: - datasources = datasource - connection = self._get_connection(data_server_info) messages = [] for datasource in datasources: @@ -2607,12 +2612,14 @@ def m_compute_data_source_state( self, action_package_yaml_directory_uri: str, data_server_info: DataServerConfigTypedDict, + show_progress: bool = True, ) -> partial[ActionResultDict[DataSourceStateDict]]: return require_monitor( partial( self._compute_data_source_state_impl, action_package_yaml_directory_uri, data_server_info, + show_progress, ) ) @@ -2620,11 +2627,15 @@ def _compute_data_source_state_impl( self, action_package_yaml_directory_uri: str, data_server_info: DataServerConfigTypedDict, + show_progress: bool, monitor: IMonitor, ) -> ActionResultDict[DataSourceStateDict]: try: return self._impl_compute_data_source_state_impl( - action_package_yaml_directory_uri, data_server_info, monitor + action_package_yaml_directory_uri, + data_server_info, + show_progress, + monitor, ) except Exception as e: log.exception("Error computing data source state") @@ -2638,16 +2649,25 @@ def _impl_compute_data_source_state_impl( self, action_package_yaml_directory_uri: str, data_server_info: DataServerConfigTypedDict, + show_progress: bool, monitor: IMonitor, ) -> ActionResultDict[DataSourceStateDict]: + from sema4ai_ls_core.constants import NULL from sema4ai_ls_core.lsp import DiagnosticSeverity, DiagnosticsTypedDict from sema4ai_ls_core.progress_report import progress_context + from sema4ai_ls_core.protocols import ModelStateTypedDict from sema4ai_code.data.data_source_helper import DataSourceHelper - with progress_context( - self._endpoint, "Computing data sources state", self._dir_cache - ) as progress_reporter: + ctx: Any + if show_progress: + ctx = progress_context( + self._endpoint, "Computing data sources state", self._dir_cache + ) + else: + ctx = NULL + + with ctx as progress_reporter: progress_reporter.set_additional_info("Listing actions") actions_and_datasources_result: ActionResultDict[ "list[ActionInfoTypedDict | DatasourceInfoTypedDict]" @@ -2665,7 +2685,9 @@ def _impl_compute_data_source_state_impl( monitor.check_cancelled() progress_reporter.set_additional_info("Getting data sources") connection = self._get_connection(data_server_info) + monitor.check_cancelled() projects_as_dicts = connection.get_data_sources("WHERE type = 'project'") + monitor.check_cancelled() not_projects_as_dicts = connection.get_data_sources( "WHERE type != 'project'" ) @@ -2694,15 +2716,25 @@ def _impl_compute_data_source_state_impl( monitor.check_cancelled() progress_reporter.set_additional_info("Getting models") - data_source_to_models = {} + data_source_to_model_name_to_model_state: dict[ + str, dict[str, ModelStateTypedDict] + ] = {} for data_source in projects_data_source_names_in_data_server: + monitor.check_cancelled() result_set_models = connection.query( data_source, "SELECT * FROM models" ) if result_set_models: - data_source_to_models[data_source] = [ - x["name"] for x in result_set_models.iter_as_dicts() - ] + name_to_status_and_error: dict[str, ModelStateTypedDict] = {} + for entry in result_set_models.iter_as_dicts(): + name_to_status_and_error[entry["name"]] = { + "status": entry["status"], + "error": entry["error"], + } + + data_source_to_model_name_to_model_state[data_source] = ( + name_to_status_and_error + ) monitor.check_cancelled() progress_reporter.set_additional_info("Getting files") @@ -2719,14 +2751,38 @@ def _impl_compute_data_source_state_impl( actions_and_datasources: "list[ActionInfoTypedDict | DatasourceInfoTypedDict]" = actions_and_datasources_result[ "result" ] - required_data_sources: list["DatasourceInfoTypedDict"] = [ - typing.cast("DatasourceInfoTypedDict", d) - for d in actions_and_datasources - if d["kind"] == "datasource" - ] - unconfigured_data_sources: list["DatasourceInfoTypedDict"] = [] + required_data_sources: list["DatasourceInfoWithStatusTypedDict"] = [] + for d in actions_and_datasources: + if d["kind"] == "datasource": + datasource_info = typing.cast("DatasourceInfoTypedDict", d) + required_data_sources.append( + { + "python_variable_name": datasource_info[ + "python_variable_name" + ], + "range": datasource_info["range"], + "name": datasource_info["name"], + "uri": datasource_info["uri"], + "kind": "datasource", + "engine": datasource_info["engine"], + "model_name": datasource_info["model_name"], + "created_table": datasource_info["created_table"], + "setup_sql": datasource_info["setup_sql"], + "setup_sql_files": datasource_info["setup_sql_files"], + "description": datasource_info["description"], + "file": datasource_info["file"], + # --- Data Source State below (we'll set these later) --- + "configured": None, + "model_state": None, + "configuration_valid": None, + "configuration_errors": None, + } + ) + + unconfigured_data_sources: list["DatasourceInfoWithStatusTypedDict"] = [] uri_to_error_messages: dict[str, list[DiagnosticsTypedDict]] = {} + ret: DataSourceStateDict = { "unconfigured_data_sources": unconfigured_data_sources, "uri_to_error_messages": uri_to_error_messages, @@ -2736,13 +2792,15 @@ def _impl_compute_data_source_state_impl( if required_data_sources: root_path = Path(uris.to_fs_path(action_package_yaml_directory_uri)) - datasource: "DatasourceInfoTypedDict" + datasource: "DatasourceInfoWithStatusTypedDict" for datasource in required_data_sources: uri = datasource.get("uri", "") datasource_helper = DataSourceHelper( root_path, datasource, connection ) - validation_errors = datasource_helper.get_validation_errors() + validation_errors: tuple[str, ...] = ( + datasource_helper.get_validation_errors() + ) if validation_errors: for validation_error in validation_errors: uri_to_error_messages.setdefault(uri, []).append( @@ -2752,8 +2810,12 @@ def _impl_compute_data_source_state_impl( "message": validation_error, } ) + datasource["configuration_valid"] = False + datasource["configuration_errors"] = list(validation_errors) continue # this one is invalid, so, we can't go forward. + datasource["configuration_valid"] = True + datasource_name = datasource["name"] datasource_engine = datasource["engine"] @@ -2761,7 +2823,9 @@ def _impl_compute_data_source_state_impl( created_table = datasource["created_table"] if datasource_engine == "files": if created_table not in files_table_names: + datasource["configured"] = False unconfigured_data_sources.append(datasource) + continue else: # Custom datasource with created_table. tables_result_set = connection.query("files", "SHOW TABLES") @@ -2770,18 +2834,32 @@ def _impl_compute_data_source_state_impl( for x in tables_result_set.iter_as_dicts() ) if created_table not in custom_table_names: + datasource["configured"] = False unconfigured_data_sources.append(datasource) + continue + + datasource["configured"] = True continue # Ok, handled use case. if datasource_helper.is_model_datasource: model_name = datasource["model_name"] - if model_name not in data_source_to_models.get( - datasource_name, [] - ): + found = data_source_to_model_name_to_model_state.get( + datasource_name, {} + ) + + if model_name not in found: + datasource["configured"] = False unconfigured_data_sources.append(datasource) + continue + + datasource["model_state"] = found[model_name] + datasource["configured"] = True continue if datasource_name.lower() not in data_source_names_in_data_server: + datasource["configured"] = False unconfigured_data_sources.append(datasource) + else: + datasource["configured"] = True return ActionResult[DataSourceStateDict].make_success(ret).as_dict() diff --git a/sema4ai/tests/sema4ai_code_tests/test_compute_data_sources_state.py b/sema4ai/tests/sema4ai_code_tests/test_compute_data_sources_state.py index 016cd9fa..9a9808fa 100644 --- a/sema4ai/tests/sema4ai_code_tests/test_compute_data_sources_state.py +++ b/sema4ai/tests/sema4ai_code_tests/test_compute_data_sources_state.py @@ -66,7 +66,7 @@ def wait_for_models_to_be_ready( row = next(iter(found)) status = row.get("status", "").lower() - if status in ("generating", "training"): + if status in ("generating", "training", "creating"): still_training.setdefault(project, []).append(model) log.info( f"Waiting for model {project}.{model} to complete. Current status: {status}" diff --git a/sema4ai/tests/sema4ai_code_tests/test_compute_data_sources_state/missing_data_source_all.yml b/sema4ai/tests/sema4ai_code_tests/test_compute_data_sources_state/missing_data_source_all.yml index b44aa26b..4161a364 100644 --- a/sema4ai/tests/sema4ai_code_tests/test_compute_data_sources_state/missing_data_source_all.yml +++ b/sema4ai/tests/sema4ai_code_tests/test_compute_data_sources_state/missing_data_source_all.yml @@ -1,11 +1,15 @@ data_sources_in_data_server: [] required_data_sources: -- created_table: customers_in_test_compute_data_sources_state +- configuration_errors: null + configuration_valid: true + configured: false + created_table: customers_in_test_compute_data_sources_state description: Data source for customers. engine: files file: files/customers.csv kind: datasource model_name: null + model_state: null name: files python_variable_name: CustomersDataSource range: @@ -18,12 +22,16 @@ required_data_sources: setup_sql: null setup_sql_files: null uri: data_sources.py -- created_table: null +- configuration_errors: null + configuration_valid: true + configured: false + created_table: null description: Predict something. engine: prediction:lightwood file: null kind: datasource model_name: predict_compute_data_sources_state + model_state: null name: models python_variable_name: PredictDataSource range: @@ -48,12 +56,16 @@ required_data_sources: HORIZON 4;' setup_sql_files: null uri: data_sources.py -- created_table: null +- configuration_errors: null + configuration_valid: true + configured: false + created_table: null description: Data source for tests. engine: sqlite file: null kind: datasource model_name: null + model_state: null name: test_compute_data_sources_state python_variable_name: TestSqliteDataSource range: @@ -67,12 +79,16 @@ required_data_sources: setup_sql_files: null uri: data_sources.py unconfigured_data_sources: -- created_table: customers_in_test_compute_data_sources_state +- configuration_errors: null + configuration_valid: true + configured: false + created_table: customers_in_test_compute_data_sources_state description: Data source for customers. engine: files file: files/customers.csv kind: datasource model_name: null + model_state: null name: files python_variable_name: CustomersDataSource range: @@ -85,12 +101,16 @@ unconfigured_data_sources: setup_sql: null setup_sql_files: null uri: data_sources.py -- created_table: null +- configuration_errors: null + configuration_valid: true + configured: false + created_table: null description: Predict something. engine: prediction:lightwood file: null kind: datasource model_name: predict_compute_data_sources_state + model_state: null name: models python_variable_name: PredictDataSource range: @@ -115,12 +135,16 @@ unconfigured_data_sources: HORIZON 4;' setup_sql_files: null uri: data_sources.py -- created_table: null +- configuration_errors: null + configuration_valid: true + configured: false + created_table: null description: Data source for tests. engine: sqlite file: null kind: datasource model_name: null + model_state: null name: test_compute_data_sources_state python_variable_name: TestSqliteDataSource range: diff --git a/sema4ai/tests/sema4ai_code_tests/test_compute_data_sources_state/missing_data_source_none.yml b/sema4ai/tests/sema4ai_code_tests/test_compute_data_sources_state/missing_data_source_none.yml index 956cedfd..820727a0 100644 --- a/sema4ai/tests/sema4ai_code_tests/test_compute_data_sources_state/missing_data_source_none.yml +++ b/sema4ai/tests/sema4ai_code_tests/test_compute_data_sources_state/missing_data_source_none.yml @@ -1,11 +1,15 @@ data_sources_in_data_server: [] required_data_sources: -- created_table: customers_in_test_compute_data_sources_state +- configuration_errors: null + configuration_valid: true + configured: true + created_table: customers_in_test_compute_data_sources_state description: Data source for customers. engine: files file: files/customers.csv kind: datasource model_name: null + model_state: null name: files python_variable_name: CustomersDataSource range: @@ -18,12 +22,18 @@ required_data_sources: setup_sql: null setup_sql_files: null uri: data_sources.py -- created_table: null +- configuration_errors: null + configuration_valid: true + configured: true + created_table: null description: Predict something. engine: prediction:lightwood file: null kind: datasource model_name: predict_compute_data_sources_state + model_state: + error: null + status: complete name: models python_variable_name: PredictDataSource range: @@ -48,12 +58,16 @@ required_data_sources: HORIZON 4;' setup_sql_files: null uri: data_sources.py -- created_table: null +- configuration_errors: null + configuration_valid: true + configured: true + created_table: null description: Data source for tests. engine: sqlite file: null kind: datasource model_name: null + model_state: null name: test_compute_data_sources_state python_variable_name: TestSqliteDataSource range: diff --git a/sema4ai/tests/sema4ai_code_tests/test_compute_data_sources_state/missing_data_source_prediction.yml b/sema4ai/tests/sema4ai_code_tests/test_compute_data_sources_state/missing_data_source_prediction.yml index efad67bb..7f6270fb 100644 --- a/sema4ai/tests/sema4ai_code_tests/test_compute_data_sources_state/missing_data_source_prediction.yml +++ b/sema4ai/tests/sema4ai_code_tests/test_compute_data_sources_state/missing_data_source_prediction.yml @@ -1,11 +1,15 @@ data_sources_in_data_server: [] required_data_sources: -- created_table: customers_in_test_compute_data_sources_state +- configuration_errors: null + configuration_valid: true + configured: true + created_table: customers_in_test_compute_data_sources_state description: Data source for customers. engine: files file: files/customers.csv kind: datasource model_name: null + model_state: null name: files python_variable_name: CustomersDataSource range: @@ -18,12 +22,16 @@ required_data_sources: setup_sql: null setup_sql_files: null uri: data_sources.py -- created_table: null +- configuration_errors: null + configuration_valid: true + configured: false + created_table: null description: Predict something. engine: prediction:lightwood file: null kind: datasource model_name: predict_compute_data_sources_state + model_state: null name: models python_variable_name: PredictDataSource range: @@ -48,12 +56,16 @@ required_data_sources: HORIZON 4;' setup_sql_files: null uri: data_sources.py -- created_table: null +- configuration_errors: null + configuration_valid: true + configured: true + created_table: null description: Data source for tests. engine: sqlite file: null kind: datasource model_name: null + model_state: null name: test_compute_data_sources_state python_variable_name: TestSqliteDataSource range: @@ -67,12 +79,16 @@ required_data_sources: setup_sql_files: null uri: data_sources.py unconfigured_data_sources: -- created_table: null +- configuration_errors: null + configuration_valid: true + configured: false + created_table: null description: Predict something. engine: prediction:lightwood file: null kind: datasource model_name: predict_compute_data_sources_state + model_state: null name: models python_variable_name: PredictDataSource range: diff --git a/sema4ai/tests/sema4ai_code_tests/test_compute_data_sources_state/missing_data_source_sqlite.yml b/sema4ai/tests/sema4ai_code_tests/test_compute_data_sources_state/missing_data_source_sqlite.yml index 3347c550..30c8df37 100644 --- a/sema4ai/tests/sema4ai_code_tests/test_compute_data_sources_state/missing_data_source_sqlite.yml +++ b/sema4ai/tests/sema4ai_code_tests/test_compute_data_sources_state/missing_data_source_sqlite.yml @@ -1,11 +1,15 @@ data_sources_in_data_server: [] required_data_sources: -- created_table: customers_in_test_compute_data_sources_state +- configuration_errors: null + configuration_valid: true + configured: true + created_table: customers_in_test_compute_data_sources_state description: Data source for customers. engine: files file: files/customers.csv kind: datasource model_name: null + model_state: null name: files python_variable_name: CustomersDataSource range: @@ -18,12 +22,16 @@ required_data_sources: setup_sql: null setup_sql_files: null uri: data_sources.py -- created_table: null +- configuration_errors: null + configuration_valid: true + configured: false + created_table: null description: Predict something. engine: prediction:lightwood file: null kind: datasource model_name: predict_compute_data_sources_state + model_state: null name: models python_variable_name: PredictDataSource range: @@ -48,12 +56,16 @@ required_data_sources: HORIZON 4;' setup_sql_files: null uri: data_sources.py -- created_table: null +- configuration_errors: null + configuration_valid: true + configured: false + created_table: null description: Data source for tests. engine: sqlite file: null kind: datasource model_name: null + model_state: null name: test_compute_data_sources_state python_variable_name: TestSqliteDataSource range: @@ -67,12 +79,16 @@ required_data_sources: setup_sql_files: null uri: data_sources.py unconfigured_data_sources: -- created_table: null +- configuration_errors: null + configuration_valid: true + configured: false + created_table: null description: Predict something. engine: prediction:lightwood file: null kind: datasource model_name: predict_compute_data_sources_state + model_state: null name: models python_variable_name: PredictDataSource range: @@ -97,12 +113,16 @@ unconfigured_data_sources: HORIZON 4;' setup_sql_files: null uri: data_sources.py -- created_table: null +- configuration_errors: null + configuration_valid: true + configured: false + created_table: null description: Data source for tests. engine: sqlite file: null kind: datasource model_name: null + model_state: null name: test_compute_data_sources_state python_variable_name: TestSqliteDataSource range: diff --git a/sema4ai/vscode-client/src/extension.ts b/sema4ai/vscode-client/src/extension.ts index 72c18e47..b899e127 100644 --- a/sema4ai/vscode-client/src/extension.ts +++ b/sema4ai/vscode-client/src/extension.ts @@ -1157,11 +1157,15 @@ export async function dropDataSource(entry?: RobotEntry) { cancellable: false, }, async (progress: Progress<{ message?: string; increment?: number }>, token: CancellationToken) => { - OUTPUT_CHANNEL.appendLine("Dropping Data Source: " + JSON.stringify(datasource)); - return await langServer.sendRequest("dropDataSource", { - "datasource": datasource, - "data_server_info": dataServerStatus["data"], - }); + OUTPUT_CHANNEL.appendLine("Dropping Data Source: " + JSON.stringify(datasource, null, 4)); + return await langServer.sendRequest( + "dropDataSource", + { + "datasource": datasource, + "data_server_info": dataServerStatus["data"], + }, + token + ); } ); diff --git a/sema4ai/vscode-client/src/protocols.ts b/sema4ai/vscode-client/src/protocols.ts index 0041241a..da702292 100644 --- a/sema4ai/vscode-client/src/protocols.ts +++ b/sema4ai/vscode-client/src/protocols.ts @@ -197,6 +197,11 @@ export interface ActionServerPackageUploadStatusOutput { }; } +export interface ModelState { + status: "complete" | "error" | "generating" | "training" | "creating"; + error?: string; +} + export interface DatasourceInfo { range: Range; name: string; @@ -209,6 +214,13 @@ export interface DatasourceInfo { python_variable_name?: string; setup_sql?: string | string[]; setup_sql_files?: string | string[]; + file?: string; + + // Data Source State below (just available when computeDataSourceState is called) + configured?: boolean; + model_state?: ModelState; + configuration_valid?: boolean; + configuration_errors?: string[]; } export interface DiagnosticInfo { diff --git a/sema4ai/vscode-client/src/robo/actionPackage.ts b/sema4ai/vscode-client/src/robo/actionPackage.ts index a2c10e3a..4abed1e4 100644 --- a/sema4ai/vscode-client/src/robo/actionPackage.ts +++ b/sema4ai/vscode-client/src/robo/actionPackage.ts @@ -451,7 +451,7 @@ export function getDataSourceCaption(dataSource: DatasourceInfo): string { return `${dataSource.name} (${dataSource.engine})`; } -async function computeDataSourceState( +export async function computeDataSourceState( actionPackageYamlDirectoryUri: string, dataServerInfo: DataServerConfig ): Promise> { diff --git a/sema4ai/vscode-client/src/robo/dataSourceHandling.ts b/sema4ai/vscode-client/src/robo/dataSourceHandling.ts index ee2f6255..bcdf844b 100644 --- a/sema4ai/vscode-client/src/robo/dataSourceHandling.ts +++ b/sema4ai/vscode-client/src/robo/dataSourceHandling.ts @@ -1,9 +1,11 @@ -import { Uri, window, commands } from "vscode"; +import { Uri, window, commands, ProgressLocation, Progress, CancellationToken } from "vscode"; import { logError, OUTPUT_CHANNEL } from "../channel"; import { RobotEntry } from "../viewsCommon"; -import { DatasourceInfo } from "../protocols"; +import { DatasourceInfo, DataSourceState } from "../protocols"; import { langServer } from "../extension"; import { startDataServerAndGetInfo } from "../dataExtension"; +import { computeDataSourceState, DataServerConfig } from "./actionPackage"; +import { sleep } from "../time"; function isExternalDatasource(datasource: DatasourceInfo): boolean { const externalEngines = ["custom", "files", "models"]; @@ -25,8 +27,8 @@ export const setupDataSource = async (entry?: RobotEntry) => { return false; } - OUTPUT_CHANNEL.appendLine("setupDataSource: " + JSON.stringify(entry)); const datasource: DatasourceInfo = entry.extraData.datasource; + OUTPUT_CHANNEL.appendLine("setupDataSource: " + JSON.stringify(datasource, null, 4)); if (isExternalDatasource(datasource)) { try { @@ -39,16 +41,130 @@ export const setupDataSource = async (entry?: RobotEntry) => { return; } - const result = await langServer.sendRequest("setupDataSource", { + const setupDataSourceResult = await langServer.sendRequest("setupDataSource", { action_package_yaml_directory_uri: Uri.file(entry.robot.directory).toString(), datasource: datasource, data_server_info: dataServerInfo, }); - if (result["success"]) { - const messages = result["result"]; + OUTPUT_CHANNEL.appendLine("setupDataSourceResult: " + JSON.stringify(setupDataSourceResult, null, 4)); + if (setupDataSourceResult["success"]) { + const messages = setupDataSourceResult["result"]; window.showInformationMessage(messages.join("\n")); + await waitForModelsToBeReady(Uri.file(entry.robot.directory).toString(), dataServerInfo); } else { - const error = result["message"]; - window.showErrorMessage(error); + const error = setupDataSourceResult["message"]; + window.showErrorMessage(`There was an error setting up the Data Source: ${error}`); } }; + +export const waitForModelsToBeReady = async ( + actionPackageYamlDirectoryUri: string, + dataServerInfo: DataServerConfig +) => { + const dataSourceStateResult = await computeDataSourceState(actionPackageYamlDirectoryUri, dataServerInfo); + + const modelsBeingTrained: DatasourceInfo[] = []; + + if (dataSourceStateResult.success) { + const dataSourceState: DataSourceState = dataSourceStateResult.result; + // Ok, now see if any model is still being trained or has errors in the model. + for (const datasource of dataSourceState.required_data_sources) { + if (datasource.model_name) { + if (datasource.model_state?.status === "error") { + window.showErrorMessage( + `There is an error with the model: ${datasource.name}.${datasource.model_name}: ${datasource.model_state.error}` + ); + } else if (datasource.model_state?.status !== "complete") { + modelsBeingTrained.push(datasource); + } + } + } + + if (modelsBeingTrained.length > 0) { + const modelNames = modelsBeingTrained.map((m) => `${m.name}.${m.model_name}`).join(", "); + await window.withProgress( + { + location: ProgressLocation.Notification, + title: `Waiting for model(s) to be ready: ${modelNames}`, + cancellable: true, + }, + async (progress: Progress<{ message?: string; increment?: number }>, token: CancellationToken) => { + return await showProgressUntilModelsAreAllReady( + actionPackageYamlDirectoryUri, + dataServerInfo, + modelsBeingTrained, + progress, + token + ); + } + ); + } + } else { + window.showErrorMessage( + `There was an error computing the Data Sources state: ${dataSourceStateResult.message}` + ); + } +}; + +const showProgressUntilModelsAreAllReady = async ( + actionPackageYamlDirectoryUri: string, + dataServerInfo: DataServerConfig, + models: DatasourceInfo[], + progress: Progress<{ message?: string; increment?: number }>, + token: CancellationToken +) => { + // Extract datasource name and model name + const waitForModels: string[] = models.map((m) => `${m.name}.${m.model_name}`); + const initialTime = Date.now(); + + let keepChecking = true; + while (keepChecking) { + await sleep(500); + if (token.isCancellationRequested) { + return; + } + const elapsedTime = Date.now() - initialTime; + const timeInSecondsFormattedAsString = (elapsedTime / 1000).toFixed(1); + progress.report({ + message: `elapsed: ${timeInSecondsFormattedAsString}s`, + }); + keepChecking = false; + const dataSourceStateResult = await computeDataSourceState(actionPackageYamlDirectoryUri, dataServerInfo); + const errors = []; + if (dataSourceStateResult.success) { + const dataSourceState: DataSourceState = dataSourceStateResult.result; + for (const datasource of dataSourceState.required_data_sources) { + if (waitForModels.includes(`${datasource.name}.${datasource.model_name}`)) { + if (!datasource.model_state) { + // It was deleted in the meantime, so we can stop waiting for it. + continue; + } + if (datasource.model_state.status === "error") { + errors.push( + `There is an error with the model: ${datasource.model_name}: ${datasource.model_state.error}` + ); + } else if (datasource.model_state.status !== "complete") { + // Check if the model is in the list of models to wait for + keepChecking = true; + break; + } + } + } + } else { + window.showErrorMessage( + `There was an error computing the Data Sources state: ${dataSourceStateResult.message} (will stop waiting for models to be ready).` + ); + return; + } + + if (!keepChecking) { + if (errors.length > 0) { + window.showErrorMessage( + `Finished training models, but the following models had errors:\n${errors.join("\n")}` + ); + } + return; + } + } + OUTPUT_CHANNEL.appendLine("Finished waiting for models to be ready (should not get here)."); +};