(shartank) test_sharded_conv2d_with_iree fails with torch 2.5.0 and above #682

Open
marbre opened this issue Dec 12, 2024 · 0 comments
marbre commented Dec 12, 2024

Using

  • HEAD at commit d279aff
  • Python 3.11.10
  • iree-base-compiler==3.1.0rc20241212
  • iree-base-runtime==3.1.0rc20241212
  • iree-turbine==3.1.0rc20241211
  • torch==2.5.1+cpu and torch==2.5.0+cpu

the test `test_sharded_conv2d_with_iree` fails with

FAILED tests/layers/sharded_conv2d_with_iree_test.py::test_sharded_conv2d_with_iree - ValueError: Node type mismatch; expected <class 'tuple'>, but got <class 'list'>.

The test passes with torch==2.4.0+cpu as well as torch==2.4.1+cpu.

Full error log:

mlir_path = PosixPath('/tmp/tmp5xxfha5g/model.mlir'), module_path = PosixPath('/tmp/tmp5xxfha5g/module.vmfb'), parameters_path = PosixPath('/tmp/tmp5xxfha5g/params.irpa'), caching = False

    def test_sharded_conv2d_with_iree(
        mlir_path: Optional[Path],
        module_path: Optional[Path],
        parameters_path: Optional[Path],
        caching: bool,
    ):
        """Test sharding, exporting and running with IREE a 2D convolution layer."""
    
        with tempfile.TemporaryDirectory(
            # TODO: verify hypothesis and remove ignore_cleanup_errors=True after a fix.
            # torch.export.export is spawning some processes that don't exit when the
            # function returns, this causes some objects to not get destroyed, which
            # in turn holds files params.rank0.irpa and params.rank1.irpa open.
            ignore_cleanup_errors=True
        ) as tmp_dir:
            mlir_path = Path(tmp_dir) / "model.mlir" if mlir_path is None else mlir_path
            module_path = (
                Path(tmp_dir) / "module.vmfb" if module_path is None else module_path
            )
            parameters_path = (
                Path(tmp_dir) / "params.irpa"
                if parameters_path is None
                else parameters_path
            )
>           run_test_sharded_conv2d_with_iree(
                mlir_path, module_path, parameters_path, caching
            )

tests/layers/sharded_conv2d_with_iree_test.py:208: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/layers/sharded_conv2d_with_iree_test.py:157: in run_test_sharded_conv2d_with_iree
    exported_module = aot.export(
.venv-uv-3.11/lib/python3.11/site-packages/iree/turbine/aot/exporter.py:310: in export
    exported_program = torch.export.export(
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/__init__.py:270: in export
    return _export(
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/_trace.py:1017: in wrapper
    raise e
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/_trace.py:990: in wrapper
    ep = fn(*args, **kwargs)
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/exported_program.py:114: in wrapper
    return fn(*args, **kwargs)
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/_trace.py:1880: in _export
    export_artifact = export_func(  # type: ignore[operator]
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/_trace.py:1224: in _strict_export
    return _strict_export_lower_to_aten_ir(
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/_trace.py:1252: in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/_trace.py:550: in _export_to_torch_ir
    transformed_dynamic_shapes = _transform_shapes_for_default_dynamic(
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/dynamic_shapes.py:911: in _transform_shapes_for_default_dynamic
    result = _tree_map_with_path(
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/dynamic_shapes.py:481: in _tree_map_with_path
    return tree_map_with_path(f, tree, *dynamic_shapes, is_leaf=is_leaf)
.venv-uv-3.11/lib/python3.11/site-packages/torch/utils/_pytree.py:1608: in tree_map_with_path
    return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves))
.venv-uv-3.11/lib/python3.11/site-packages/torch/utils/_pytree.py:803: in unflatten
    leaves = list(leaves)
.venv-uv-3.11/lib/python3.11/site-packages/torch/utils/_pytree.py:1608: in <genexpr>
    return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves))
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/dynamic_shapes.py:471: in f
    return tree_map_with_path(
.venv-uv-3.11/lib/python3.11/site-packages/torch/utils/_pytree.py:1607: in tree_map_with_path
    all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests]
.venv-uv-3.11/lib/python3.11/site-packages/torch/utils/_pytree.py:1607: in <listcomp>
    all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests]
.venv-uv-3.11/lib/python3.11/site-packages/torch/utils/_pytree.py:798: in flatten_up_to
    self._flatten_up_to_helper(tree, subtrees)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = TreeSpec(tuple, None, [*,
  *]), tree = [[None], [None]], subtrees = []

    def _flatten_up_to_helper(self, tree: PyTree, subtrees: List[PyTree]) -> None:
        if self.is_leaf():
            subtrees.append(tree)
            return
    
        node_type = _get_node_type(tree)
        if self.type not in BUILTIN_TYPES:
            # Always require custom node types to match exactly
            if node_type != self.type:
                raise ValueError(
                    f"Type mismatch; "
                    f"expected {self.type!r}, but got {node_type!r}.",
                )
            flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
            child_pytrees, context = flatten_fn(tree)
            if len(child_pytrees) != self.num_children:
                raise ValueError(
                    f"Node arity mismatch; "
                    f"expected {self.num_children}, but got {len(child_pytrees)}.",
                )
            if context != self.context:
                raise ValueError(
                    f"Node context mismatch for custom node type {self.type!r}.",
                )
        else:
            # For builtin dictionary types, we allow some flexibility
            # Otherwise, we require exact matches
            both_standard_dict = (
                self.type in STANDARD_DICT_TYPES and node_type in STANDARD_DICT_TYPES
            )
            if node_type != self.type and not both_standard_dict:
>               raise ValueError(
                    f"Node type mismatch; "
                    f"expected {self.type!r}, but got {node_type!r}.",
                )
E               ValueError: Node type mismatch; expected <class 'tuple'>, but got <class 'list'>.

.venv-uv-3.11/lib/python3.11/site-packages/torch/utils/_pytree.py:751: ValueError
@marbre changed the title from "(shartank) test_sharded_conv2d_with_iree fails with torch 2.5.0 and aboce" to "(shartank) test_sharded_conv2d_with_iree fails with torch 2.5.0 and above" on Dec 12, 2024
marbre added a commit to marbre/shark-ai that referenced this issue Dec 12, 2024
Marks `test_sharded_conv2d_with_iree` as expected to fail if running
with torch>=2.5.0, see nod-ai#682.
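The commit above gates the test on the torch version. A hedged sketch of what such an xfail marker could look like (the helper name and decorator placement are assumptions; only the `torch>=2.5.0` gate is taken from the commit message):

```python
# Sketch of marking the test as expected-to-fail on torch>=2.5.0.
# The real test lives in tests/layers/sharded_conv2d_with_iree_test.py;
# torch_at_least is a hypothetical helper, not part of that file.
import pytest
import torch


def torch_at_least(major: int, minor: int) -> bool:
    """Compare the installed torch version, ignoring suffixes like '+cpu'."""
    parts = torch.__version__.split("+")[0].split(".")
    return (int(parts[0]), int(parts[1])) >= (major, minor)


@pytest.mark.xfail(
    torch_at_least(2, 5),
    reason="pytree node type mismatch in torch.export, see nod-ai#682",
)
def test_sharded_conv2d_with_iree():
    ...
```

Using a version condition rather than an unconditional `xfail` keeps the test running (and strictly passing) on torch 2.4.x while documenting the regression on newer releases.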