From 09d8892190360a21106511754883c41715f76b5c Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Fri, 24 May 2024 10:28:58 -0400 Subject: [PATCH] fix typing (#1235) --- llmfoundry/models/utils/config_moe_args.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/llmfoundry/models/utils/config_moe_args.py b/llmfoundry/models/utils/config_moe_args.py index e27514275c..963c596e76 100644 --- a/llmfoundry/models/utils/config_moe_args.py +++ b/llmfoundry/models/utils/config_moe_args.py @@ -19,7 +19,7 @@ ] -def create_process_group_ranks(ranks: tuple[int]): +def create_process_group_ranks(ranks: tuple[int, ...]): """Creates a new distributed group. Used in create_set_process_group and create_mod_process_group methods below. @@ -27,7 +27,7 @@ def create_process_group_ranks(ranks: tuple[int]): This function is an alternative to `distributed.new_group(ranks)`. Args: - ranks (tuple[int]): Tuple of ranks of group members. + ranks (tuple[int, ...]): Tuple of ranks of group members. Returns: A handle of distributed group that can be given to collective calls. @@ -66,14 +66,14 @@ def create_set_process_group(k: int): def get_megablocks_device_mesh( - device_mesh_cfg: Optional[tuple[int]], + device_mesh_cfg: Optional[tuple[int, ...]], moe_world_size: int, world_size: int, ) -> DeviceMesh: """Helper function to get the device mesh for MegaBlocks MoE. Args: - device_mesh_cfg (Optional[tuple[int]]): The device mesh configuration specification. + device_mesh_cfg (Optional[tuple[int, ...]]): The device mesh configuration specification. moe_world_size (int): The MoE world size. world_size (int): The world size.