From 22db20cc9c7cdbb487b6b4ad00232dfa2d1d9b2e Mon Sep 17 00:00:00 2001 From: Pete Date: Thu, 16 May 2024 10:09:35 -0700 Subject: [PATCH] Refactor `ShardedFlatTensor` class (#21) * Refactor `ShardedFlatTensor` class * maybe fix? * make dispatch work when unsharded * revert some changes * another try * another try * another fix * updates * update PyTorch version in CI * fix? * updates * add a little more safety --- .github/actions/setup-venv/action.yml | 9 +- .github/workflows/main.yml | 4 +- src/olmo_core/distributed/checkpoint.py | 21 +- .../distributed/fsdp/flat_param_handle.py | 12 +- .../tensors/sharded_flat_parameter.py | 23 +-- .../tensors/sharded_flat_tensor.py | 185 ++++++++++-------- .../tensors/sharded_flat_parameter_test.py | 3 +- .../tensors/sharded_flat_tensor_test.py | 27 ++- 8 files changed, 157 insertions(+), 127 deletions(-) diff --git a/.github/actions/setup-venv/action.yml b/.github/actions/setup-venv/action.yml index 16a7d04c..3b7e713a 100644 --- a/.github/actions/setup-venv/action.yml +++ b/.github/actions/setup-venv/action.yml @@ -11,7 +11,7 @@ inputs: torch-version: description: The PyTorch version to install required: false - default: '==2.2.1' + default: '==2.3.0' runs: using: composite steps: @@ -34,7 +34,9 @@ runs: id: virtualenv-cache with: path: .venv - key: ${{ inputs.cache-prefix }}-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('*requirements.txt', '*pyproject.toml') }} + key: ${{ inputs.cache-prefix }}-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ inputs.torch-version }}-${{ hashFiles('*requirements.txt', '*pyproject.toml') }} + restore-keys: | + ${{ inputs.cache-prefix }}-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ inputs.torch-version }} - if: steps.virtualenv-cache.outputs.cache-hit != 'true' shell: bash @@ -42,8 +44,7 @@ runs: # Set up virtual environment without cache hit. test -d .venv || virtualenv -p $(which python) --copies --reset-app-data .venv . .venv/bin/activate - #pip install 'torch${{ inputs.torch-version }}' --extra-index-url https://download.pytorch.org/whl/cpu - pip install 'torch${{ inputs.torch-version }}' + pip install 'torch${{ inputs.torch-version }}' --extra-index-url https://download.pytorch.org/whl/cpu pip install -e .[all] - if: steps.virtualenv-cache.outputs.cache-hit == 'true' diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 89e47376..022203cf 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -24,9 +24,7 @@ env: jobs: checks: name: ${{ matrix.task.name }} - # TODO: change to 'ubuntu-latest' once repo is public (will have more RAM then), and update the torch - # install command in the setup-venv action. - runs-on: [macos-13] + runs-on: [ubuntu-latest] timeout-minutes: 5 strategy: fail-fast: false diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index 7b8adf6c..6bb63088 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -1321,33 +1321,20 @@ def _patch_key(model: nn.Module, key: str) -> str: def _get_local_tensor_data(tensor: torch.Tensor) -> torch.Tensor: if isinstance(tensor, DTensor): return tensor.to_local() + elif isinstance(tensor, ShardedFlatTensor): + return tensor.sharded_data else: return tensor.data def _wrap_tensor_for_sharded_parameter(tensor: torch.Tensor, param: Optional[torch.Tensor]) -> torch.Tensor: - if isinstance(tensor, DTensor): + if isinstance(tensor, DTensor) or (isinstance(tensor, ShardedFlatTensor) and tensor.metadata_set): return tensor - # TODO: (fixme) when you call `torch.empty_like(x)` on a `ShardedFlatTensor`, `x`, you get - # a `ShardedFlatTensor` without the metadata. Since PyTorch optimizer's use `torch.empty_like()` - # on each param to initialize its state, we run into an issue unless we still call `ShardedFlatTensor.wrap()` - # below. - # if isinstance(tensor, ShardedFlatTensor): - # return tensor - if isinstance(param, ShardedFlatTensor): return param.wrap(tensor, requires_grad=False) elif isinstance(param, DTensor): - return DTensor( # type: ignore - tensor, - param.device_mesh, - param.placements, - shape=param.size(), - dtype=tensor.dtype, - requires_grad=False, - stride=param.stride(), - ) + return DTensor.from_local(tensor, device_mesh=param.device_mesh, placements=param.placements) elif isinstance(param, nn.Parameter) and isinstance(param.data, DTensor): return _wrap_tensor_for_sharded_parameter(tensor, param.data) else: diff --git a/src/olmo_core/distributed/fsdp/flat_param_handle.py b/src/olmo_core/distributed/fsdp/flat_param_handle.py index a3e58de7..cc4f09e5 100644 --- a/src/olmo_core/distributed/fsdp/flat_param_handle.py +++ b/src/olmo_core/distributed/fsdp/flat_param_handle.py @@ -223,7 +223,7 @@ def shard_params( # Now that we have all of the flat parameters we need to collate all of their data into a single # sharded flat tensor, then set the data for each flat parameter as a view into that flat tensor. - local_flat_sharded_data = torch.cat([flat_param.data for flat_param in flat_params]) + local_flat_sharded_data = torch.cat([flat_param.sharded_data for flat_param in flat_params]) params_data = ShardedFlatTensor( F.pad(local_flat_sharded_data, (0, padded_sharded_numel - local_flat_sharded_data.numel())) ) @@ -245,7 +245,7 @@ def shard_params( ) offset = 0 for flat_param in flat_params: - flat_param.data = params_data[offset : offset + flat_param.numel()] + flat_param.sharded_data = params_data[offset : offset + flat_param.numel()] offset += flat_param.numel() return cls( @@ -319,7 +319,7 @@ def unshard_( assert self.params_data.is_sharded self.params_data.unshard_(dtype=dtype, rank0_only=rank0_only) if set_grads and self.requires_grad: - self.params_unsharded_grad = torch.zeros_like(self.params_data) + self.params_unsharded_grad = torch.zeros_like(self.params_data.data) else: assert not self.params_data.is_sharded # We prefer to use `all_gather_into_tensor()` directly when possible as it involves @@ -342,9 +342,9 @@ def unshard_( offset = 0 for param in self.params: if rank0_only and local_rank != 0: - unsharded_data = torch.empty_like(self.params_data) + unsharded_data = torch.empty_like(self.params_data.data) else: - unsharded_data = self.params_data[offset : offset + param.unsharded_numel] + unsharded_data = self.params_data.data[offset : offset + param.unsharded_numel] param.unshard_(unsharded_data, dtype=dtype, rank0_only=rank0_only) @@ -369,7 +369,7 @@ def reshard_(self, writeback: bool = False): flat_param.reshard_(writeback=False) if writeback: # Reset the view into the new `params_data`. - flat_param.data = self.params_data[offset : offset + flat_param.sharded_numel] + flat_param.sharded_data = self.params_data[offset : offset + flat_param.sharded_numel] offset += flat_param.sharded_numel def pre_reduce_scatter_grads_( diff --git a/src/olmo_core/distributed/tensors/sharded_flat_parameter.py b/src/olmo_core/distributed/tensors/sharded_flat_parameter.py index f8d509d6..4a52703e 100644 --- a/src/olmo_core/distributed/tensors/sharded_flat_parameter.py +++ b/src/olmo_core/distributed/tensors/sharded_flat_parameter.py @@ -27,19 +27,20 @@ def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True ) if isinstance(data, ShardedFlatTensor): - setattr( - param, - cls.SHARDED_FLAT_TENSOR_METADATA_NAME, - getattr(data, cls.SHARDED_FLAT_TENSOR_METADATA_NAME, {}).copy(), - ) + param._local_tensor = data._local_tensor + param._sharding_spec = data._sharding_spec + param._process_group = data._process_group else: - setattr(param, cls.SHARDED_FLAT_TENSOR_METADATA_NAME, {}) + param._local_tensor = None if data is None else data.data.detach() + param._sharding_spec = None # type: ignore[assignment] + param._process_group = None + + param._global_tensor = None return param def __repr__(self) -> str: - r = torch.Tensor.__repr__(self) - if r.startswith("Parameter("): # ) -- the open parenthesis confuses treesitter sometimes - r = r.replace("Parameter(", "", 1) # ) -- the open parenthesis confuses treesitter sometimes - r = r[:-1] - return r + if self._global_tensor is not None: + return f"ShardedFlatParameter(local_tensor={self._local_tensor}, global_tensor={self._global_tensor}, requires_grad={self.requires_grad})" + else: + return f"ShardedFlatParameter(local_tensor={self._local_tensor}, requires_grad={self.requires_grad})" diff --git a/src/olmo_core/distributed/tensors/sharded_flat_tensor.py b/src/olmo_core/distributed/tensors/sharded_flat_tensor.py index f546788a..d537894d 100644 --- a/src/olmo_core/distributed/tensors/sharded_flat_tensor.py +++ b/src/olmo_core/distributed/tensors/sharded_flat_tensor.py @@ -3,13 +3,17 @@ import math from dataclasses import dataclass from functools import reduce -from typing import Any, List, Optional, Tuple, Type, TypeVar +from typing import List, Optional, Tuple, Type, TypeVar import torch import torch.distributed as dist -import torch.nn as nn import torch.nn.functional as F +try: + from torch.utils import _cxx_pytree as pytree +except ImportError: + from torch.utils import _pytree as pytree # type: ignore[no-redef] + from ..utils import get_rank, get_world_size __all__ = ["ShardedFlatTensor", "ShardingSpec"] @@ -77,43 +81,58 @@ class ShardedFlatTensor(torch.Tensor): a contiguous slice into the flattened unsharded tensor. """ - SHARDED_FLAT_TENSOR_METADATA_NAME = "__sharded_metadata__" - SHARDED_FLAT_TENSOR_SHARDING_SPEC_KEY = "sharding_spec" - SHARDED_FLAT_TENSOR_PROCESS_GROUP_KEY = "process_group" - SHARDED_FLAT_TENSOR_CACHED_SHARDED_DATA_KEY = "sharded_data" + __slots__ = ["_local_tensor", "_global_tensor", "_sharding_spec", "_process_group"] @staticmethod def __new__(cls, data: torch.Tensor, requires_grad: bool = False) -> ShardedFlatTensor: if data.ndim != 1: raise ValueError(f"{cls.__name__} requires flat data! Got {data.shape}") + sharding_spec: Optional[ShardingSpec] = None + process_group: Optional[dist.ProcessGroup] = None + tensor: ShardedFlatTensor if isinstance(data, ShardedFlatTensor): - tensor = data - setattr( - tensor, - cls.SHARDED_FLAT_TENSOR_METADATA_NAME, - getattr(data, cls.SHARDED_FLAT_TENSOR_METADATA_NAME, {}).copy(), - ) - elif type(data) is torch.Tensor or type(data) is nn.Parameter: - # For ease of BC maintenance, keep this path for standard Tensor. - # Eventually (tm), we should change the behavior for standard Tensor to match. - tensor = torch.Tensor._make_subclass(cls, data, requires_grad) - setattr(tensor, cls.SHARDED_FLAT_TENSOR_METADATA_NAME, {}) - else: - raise TypeError(f"found unexpected type for {cls.__name__} data: {type(data)}") + sharding_spec = data._sharding_spec + process_group = data._process_group + data = data._local_tensor + + tensor = torch.Tensor._make_subclass(cls, data, requires_grad) + tensor._local_tensor = data + tensor._global_tensor = None + tensor._sharding_spec = sharding_spec # type: ignore[assignment] + tensor._process_group = process_group return tensor - def __repr__(self) -> str: - return torch.Tensor.__repr__(self) + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + del types + kwargs = kwargs or {} + + def unwrap(x): + if isinstance(x, ShardedFlatTensor): + return x._global_tensor if x._global_tensor is not None else x._local_tensor + else: + return x + + def wrap(x): + if isinstance(x, torch.Tensor): + if x.shape == self.shape: + return self.wrap(x, requires_grad=x.requires_grad) + return x + + out = func(*pytree.tree_map(unwrap, args), **pytree.tree_map(unwrap, kwargs)) - def _set_metadata(self, key: str, value: Any, force: bool = False): - metadata = getattr(self, self.SHARDED_FLAT_TENSOR_METADATA_NAME) - if not force and key in metadata: - raise ValueError(f"Metadata key '{key}' already exists in {self.__class__.__name__}") + if func in {torch.ops.aten.empty_like.default, torch.ops.aten.zeros_like.default, torch.ops.aten.ones_like.default}: # type: ignore + out = pytree.tree_map(wrap, out) + + return out + + def __repr__(self) -> str: + if self._global_tensor is not None: + return f"ShardedFlatTensor(local_tensor={self._local_tensor}, global_tensor={self._global_tensor})" else: - metadata[key] = value + return f"ShardedFlatTensor(local_tensor={self._local_tensor})" def _gather_data(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = False) -> torch.Tensor: # NOTE: ``all_gather_into_tensor`` is not supported on Gloo. @@ -123,7 +142,7 @@ def _gather_data(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = F local_padding = (0, max_numel - sharded_numels[local_rank]) flat_sharded_tensor_list: Optional[List[torch.Tensor]] = None - local_flat_padded_tensor = F.pad(self.data.to(dtype or self.dtype), local_padding) + local_flat_padded_tensor = F.pad(self._local_tensor.to(dtype or self.dtype), local_padding) # Pad sharded tensors to the same size. if not rank0_only or local_rank == 0: @@ -241,11 +260,11 @@ def gather(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = False) """ Gather the sharded flat parameter across a process group into a full unsharded parameter. """ - if self.is_sharded: + if self._global_tensor is not None: + unsharded_data = self._global_tensor + else: unsharded_data = self._gather_data(dtype=dtype, rank0_only=rank0_only) unsharded_data.requires_grad = self.requires_grad - else: - unsharded_data = self.data return unsharded_data def unshard_( @@ -261,12 +280,17 @@ def unshard_( """ if unsharded_data is None: unsharded_data = ( - self.data if not self.is_sharded else self._gather_data(dtype=dtype, rank0_only=rank0_only) + self._global_tensor + if self._global_tensor is not None + else self._gather_data(dtype=dtype, rank0_only=rank0_only) ) elif not rank0_only or get_rank(self.process_group) == 0: unsharded_data = unsharded_data.view(*self.unsharded_shape) - self._set_metadata(self.SHARDED_FLAT_TENSOR_CACHED_SHARDED_DATA_KEY, self.data) - self.data = unsharded_data + self._global_tensor = unsharded_data + # NOTE: despite `__torch_dispatch__`, we still need to set `self.data` to the unsharded + # data in order for `self.shape()`, `self.numel()`, and other methods to return the + # right values corresponding to the unsharded data. + self.data = unsharded_data # type: ignore[misc] def reshard_(self, writeback: bool = False): """ @@ -274,20 +298,10 @@ def reshard_(self, writeback: bool = False): This does *not* do anything with the parameter's gradient, if it has one. That should be handled separately by the calling code. """ - if self.is_sharded: + if (unsharded_data := self._global_tensor) is None: return - metadata = getattr(self, self.SHARDED_FLAT_TENSOR_METADATA_NAME) - try: - sharded_data = metadata[self.SHARDED_FLAT_TENSOR_CACHED_SHARDED_DATA_KEY] - except KeyError: - raise ValueError( - f"{self.__class__.__name__} has not been unsharded in place yet, " - "did you forget to class '.unshard_()'?" - ) - if writeback: - unsharded_data = self.data if unsharded_data.shape != self.unsharded_shape: # unsharded data could be empty if `.unshard_` was called with `rank0_only=True`. if unsharded_data.numel() > 0: @@ -303,28 +317,40 @@ def reshard_(self, writeback: bool = False): 0 if self.process_group is None else dist.get_global_rank(self.process_group, 0), group=self.process_group, ) - self.data = self.sharded_chunk(unsharded_data).to(dtype=sharded_data.dtype).clone() - else: - self.data = sharded_data - del metadata[self.SHARDED_FLAT_TENSOR_CACHED_SHARDED_DATA_KEY] + self._local_tensor = self.sharded_chunk(unsharded_data).to(dtype=self._local_tensor.dtype).clone() + + self._global_tensor = None + + # NOTE: despite `__torch_dispatch__`, we still need to set `self.data` back to the sharded + # data in order for `self.shape()`, `self.numel()`, and other methods to return the + # right values corresponding to the sharded data. + self.data = self._local_tensor # type: ignore[misc] def mark_as_sharded(self, sharding_spec: ShardingSpec, process_group: Optional[dist.ProcessGroup] = None): if self.numel() != (shard_numel := sharding_spec.sharded_numels[get_rank(group=process_group)]): raise ValueError( f"invalid sharding spec, numel in spec ({shard_numel}) does not match numel in shard ({self.numel()})" ) - self._set_metadata(self.SHARDED_FLAT_TENSOR_SHARDING_SPEC_KEY, sharding_spec) - self._set_metadata(self.SHARDED_FLAT_TENSOR_PROCESS_GROUP_KEY, process_group) + self._sharding_spec = sharding_spec + self._process_group = process_group - def wrap(self, tensor: torch.Tensor, requires_grad: bool = True) -> ShardedFlatTensor: + def wrap(self, tensor: torch.Tensor, requires_grad: Optional[bool] = None) -> ShardedFlatTensor: """ Wrap another tensor and mark as sharded with the same sharding spec. - ``tensor`` should have the same shape as ``self.data``, the sharded data. + ``tensor`` should have the same shape. """ - if tensor.shape != self.data.shape: - raise ValueError(f"shape mismatched, expected {self.data.shape}, got {tensor.shape}") - wrapped = ShardedFlatTensor(tensor.data, requires_grad=requires_grad) # type: ignore - wrapped.mark_as_sharded(self.sharding_spec, process_group=self.process_group) + if self.is_sharded and tensor.shape != self.sharded_shape: + raise ValueError(f"shape mismatched, expected {self.sharded_shape}, got {tensor.shape}") + elif not self.is_sharded and tensor.shape != self.unsharded_shape: + raise ValueError(f"shape mismatched, expected {self.unsharded_shape}, got {tensor.shape}") + requires_grad = requires_grad if requires_grad is not None else tensor.requires_grad + if self.is_sharded: + wrapped = ShardedFlatTensor(tensor.data, requires_grad=requires_grad) # type: ignore + wrapped.mark_as_sharded(self.sharding_spec, process_group=self.process_group) + else: + wrapped = ShardedFlatTensor(self.sharded_chunk(tensor), requires_grad=requires_grad) # type: ignore + wrapped.mark_as_sharded(self.sharding_spec, process_group=self.process_group) + wrapped.unshard_(tensor) return wrapped def chunk_unsharded(self, tensor: torch.Tensor, pad: bool = False) -> List[torch.Tensor]: @@ -361,39 +387,29 @@ def sharded_chunk(self, tensor: torch.Tensor) -> torch.Tensor: rank_chunks.append(flat_tensor[start_idx:end_idx]) return rank_chunks[0] if len(rank_chunks) == 1 else torch.cat(rank_chunks) + @property + def metadata_set(self) -> bool: + for slot in self.__slots__: + if not hasattr(self, slot): + return False + return True + @property def is_sharded(self) -> bool: - try: - metadata = getattr(self, self.SHARDED_FLAT_TENSOR_METADATA_NAME) - return ( - self.SHARDED_FLAT_TENSOR_SHARDING_SPEC_KEY in metadata - and self.SHARDED_FLAT_TENSOR_CACHED_SHARDED_DATA_KEY not in metadata - ) - except AttributeError: - return False + return self._global_tensor is None @property def sharding_spec(self) -> ShardingSpec: - try: - metadata = getattr(self, self.SHARDED_FLAT_TENSOR_METADATA_NAME) - return metadata[self.SHARDED_FLAT_TENSOR_SHARDING_SPEC_KEY] - except (KeyError, AttributeError): + if self._sharding_spec is None: raise ValueError( f"{self.__class__.__name__} has not been marked as sharded yet, " "did you forget to class '.mark_as_sharded()'?" ) + return self._sharding_spec @property def process_group(self) -> Optional[dist.ProcessGroup]: - try: - return getattr(self, self.SHARDED_FLAT_TENSOR_METADATA_NAME)[ - self.SHARDED_FLAT_TENSOR_PROCESS_GROUP_KEY - ] - except KeyError: - raise ValueError( - f"{self.__class__.__name__} has not been marked as sharded yet, " - "did you forget to class '.mark_as_sharded()'?" - ) + return self._process_group @property def unsharded_flattened_offsets(self) -> Tuple[Tuple[int, int], ...]: @@ -417,8 +433,13 @@ def sharded_shape(self) -> Tuple[int, ...]: @property def sharded_data(self) -> torch.Tensor: - metadata = getattr(self, self.SHARDED_FLAT_TENSOR_METADATA_NAME) - try: - return metadata[self.SHARDED_FLAT_TENSOR_CACHED_SHARDED_DATA_KEY] - except KeyError: - return self.data + return self._local_tensor + + @sharded_data.setter + def sharded_data(self, sharded_data: torch.Tensor): + self._local_tensor = sharded_data + self.data = sharded_data + + @property + def unsharded_data(self) -> Optional[torch.Tensor]: + return self._global_tensor diff --git a/src/test/distributed/tensors/sharded_flat_parameter_test.py b/src/test/distributed/tensors/sharded_flat_parameter_test.py index 6b84eaa2..048ac7ab 100644 --- a/src/test/distributed/tensors/sharded_flat_parameter_test.py +++ b/src/test/distributed/tensors/sharded_flat_parameter_test.py @@ -19,8 +19,7 @@ def test_init_empty_sharded_parameter(): assert isinstance(sp, torch.nn.Parameter) assert isinstance(sp, ShardedFlatTensor) assert isinstance(sp, torch.Tensor) - assert repr(sp) == "ShardedFlatParameter([], requires_grad=True)" - assert not sp.is_sharded # hasn't been marked sharded yet + assert repr(sp) == "ShardedFlatParameter(local_tensor=None, requires_grad=True)" def test_init_sharded_parameter_from_tensor(): diff --git a/src/test/distributed/tensors/sharded_flat_tensor_test.py b/src/test/distributed/tensors/sharded_flat_tensor_test.py index be7965fe..2444ee3f 100644 --- a/src/test/distributed/tensors/sharded_flat_tensor_test.py +++ b/src/test/distributed/tensors/sharded_flat_tensor_test.py @@ -17,8 +17,14 @@ def test_init_sharded(): assert isinstance(tensor, ShardedFlatTensor) assert isinstance(tensor, ShardedFlatTensor) assert isinstance(tensor, torch.Tensor) - assert repr(tensor) == "ShardedFlatTensor([0])" - assert not tensor.is_sharded # hasn't been marked sharded yet + assert tensor.metadata_set + assert repr(tensor) == "ShardedFlatTensor(local_tensor=tensor([0]))" + + +def test_not_has_metadata(): + tensor = torch.Tensor._make_subclass(ShardedFlatTensor, torch.rand(3), False) + assert isinstance(tensor, ShardedFlatTensor) + assert not tensor.metadata_set def test_init_sharded_tensor_from_tensor(): @@ -28,6 +34,23 @@ def test_init_sharded_tensor_from_tensor(): assert tensor.shape == (6,) +def test_init_new_tensor_from_sharded_tensor(): + x = ShardedFlatTensor(torch.rand(6)) + x.mark_as_sharded(ShardingSpec(unsharded_shape=(2, 6), unsharded_flattened_offsets=(((0, 6),), ((6, 12),)))) + + y1 = torch.empty_like(x) + assert isinstance(y1, ShardedFlatTensor) + assert y1.is_sharded + + y2 = torch.zeros_like(x) + assert isinstance(y2, ShardedFlatTensor) + assert y2.is_sharded + + y3 = torch.ones_like(x) + assert isinstance(y3, ShardedFlatTensor) + assert y3.is_sharded + + def test_init_sharded_tensor_from_param(): tensor = ShardedFlatTensor(torch.nn.Parameter(torch.rand(6))) assert isinstance(tensor, ShardedFlatTensor)