diff --git a/src/server/package/src/model_explorer/apis.py b/src/server/package/src/model_explorer/apis.py index 120ad49f..87b3bbc5 100644 --- a/src/server/package/src/model_explorer/apis.py +++ b/src/server/package/src/model_explorer/apis.py @@ -13,10 +13,10 @@ # limitations under the License. # ============================================================================== -from typing import Union, TypedDict -from typing_extensions import NotRequired +from typing import TypedDict, Union import torch +from typing_extensions import NotRequired from . import server from .config import ModelExplorerConfig, NodeData @@ -58,7 +58,7 @@ def visualize( host=DEFAULT_HOST, port=DEFAULT_PORT, extensions: list[str] = [], - node_data_list: list[NodeDataInfo] = [], + node_data: Union[NodeDataInfo, list[NodeDataInfo]] = [], colab_height=DEFAULT_COLAB_HEIGHT, reuse_server: bool = False, reuse_server_host: str = DEFAULT_HOST, @@ -71,7 +71,7 @@ def visualize( host: The host of the server. Default to localhost. port: The port of the server. Default to 8080. extensions: List of extension names to be run with model explorer. - node_data_list: The list of node data to display. + node_data: The node data or a list of node data to display. colab_height: The height of the embedded iFrame when running in colab. reuse_server: Whether to reuse the current server/browser tab(s) to visualize. @@ -88,9 +88,7 @@ def visualize( for model_path in model_paths_list: cur_config.add_model_from_path(path=model_path) - _add_node_data_list_to_config( - node_data_list=node_data_list, config=cur_config - ) + _add_node_data_to_config(node_data=node_data, config=cur_config) if reuse_server: cur_config.set_reuse_server( @@ -113,7 +111,7 @@ def visualize_pytorch( host=DEFAULT_HOST, port=DEFAULT_PORT, extensions: list[str] = [], - node_data_list: list[NodeDataInfo] = [], + node_data: Union[NodeDataInfo, list[NodeDataInfo]] = [], colab_height=DEFAULT_COLAB_HEIGHT, settings=DEFAULT_SETTINGS, ) -> None: @@ -125,7 +123,7 @@ def visualize_pytorch( host: The host of the server. Default to localhost. port: The port of the server. Default to 8080. extensions: List of extension names to be run with model explorer. - node_data_list: The list of node data to display. + node_data: The node data or a list of node data to display. colab_height: The height of the embedded iFrame when running in colab. settings: The settings that config the visualization. """ @@ -135,9 +133,7 @@ def visualize_pytorch( name, exported_program=exported_program, settings=settings ) - _add_node_data_list_to_config( - node_data_list=node_data_list, config=cur_config - ) + _add_node_data_to_config(node_data=node_data, config=cur_config) # Start server. server.start( @@ -182,17 +178,25 @@ def visualize_from_config( ) -def _add_node_data_list_to_config( - node_data_list: list[NodeDataInfo], config: ModelExplorerConfig +def _add_node_data_to_config( + node_data: Union[NodeDataInfo, list[NodeDataInfo]], + config: ModelExplorerConfig, ): + # Convert NodeDataInfo to [NodeDataInfo] if necessary. + node_data_list: list[NodeDataInfo] = [] + if isinstance(node_data, list): + node_data_list = node_data + else: + node_data_list = [node_data] + for node_data_info in node_data_list: name = node_data_info.get('name', 'node data') node_data_path = node_data_info.get('node_data_path') - node_data = node_data_info.get('node_data') + node_data_obj = node_data_info.get('node_data') model_name = node_data_info.get('model_name') - if node_data: + if node_data_obj: config.add_node_data( - name=name, node_data=node_data, model_name=model_name + name=name, node_data=node_data_obj, model_name=model_name ) elif node_data_path: config.add_node_data_from_path(path=node_data_path, model_name=model_name)