The test test_sharded_resnet_block_with_iree fails with torch 2.5.0 and above:
FAILED tests/models/punet/sharded_resnet_block_with_iree_test.py::test_sharded_resnet_block_with_iree - ValueError: Node type mismatch; expected <class 'tuple'>, but got <class 'list'>.
The test passes with torch==2.4.0+cpu as well as torch==2.4.1+cpu.
Full error log:
mlir_path = PosixPath('/tmp/tmpjiqhho0l/model.mlir'), module_path = PosixPath('/tmp/tmpjiqhho0l/module.vmfb'), parameters_path = PosixPath('/tmp/tmpjiqhho0l/params.irpa'), caching = False
@pytest.mark.xfail(
    reason="Maybe numerical issues with low accuracy.",
    strict=True,
    raises=AssertionError,
)
def test_sharded_resnet_block_with_iree(
    mlir_path: Optional[Path],
    module_path: Optional[Path],
    parameters_path: Optional[Path],
    caching: bool,
):
    """Test sharding, exportation and execution with IREE local-task of a Resnet block.
    The result is compared against execution with torch.
    The model is tensor sharded across 2 devices.
    """
    with tempfile.TemporaryDirectory(
        # TODO: verify hypothesis and remove ignore_cleanup_errors=True
        # torch.export.export is spawning some processes that don't exit when the
        # function returns, this causes some objects to not get destroyed, which
        # in turn holds files params.rank0.irpa and params.rank1.irpa open.
        ignore_cleanup_errors=True
    ) as tmp_dir:
        mlir_path = Path(tmp_dir) / "model.mlir" if mlir_path is None else mlir_path
        module_path = (
            Path(tmp_dir) / "module.vmfb" if module_path is None else module_path
        )
        parameters_path = (
            Path(tmp_dir) / "params.irpa"
            if parameters_path is None
            else parameters_path
        )
>       run_test_sharded_resnet_block_with_iree(
            mlir_path, module_path, parameters_path, caching
        )
tests/models/punet/sharded_resnet_block_with_iree_test.py:258:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/models/punet/sharded_resnet_block_with_iree_test.py:189: in run_test_sharded_resnet_block_with_iree
exported_resnet_block = aot.export(
.venv-uv-3.11/lib/python3.11/site-packages/iree/turbine/aot/exporter.py:310: in export
exported_program = torch.export.export(
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/__init__.py:270: in export
return _export(
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/_trace.py:1017: in wrapper
raise e
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/_trace.py:990: in wrapper
ep = fn(*args, **kwargs)
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/exported_program.py:114: in wrapper
return fn(*args, **kwargs)
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/_trace.py:1880: in _export
export_artifact = export_func( # type: ignore[operator]
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/_trace.py:1224: in _strict_export
return _strict_export_lower_to_aten_ir(
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/_trace.py:1252: in _strict_export_lower_to_aten_ir
gm_torch_level = _export_to_torch_ir(
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/_trace.py:550: in _export_to_torch_ir
transformed_dynamic_shapes = _transform_shapes_for_default_dynamic(
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/dynamic_shapes.py:911: in _transform_shapes_for_default_dynamic
result = _tree_map_with_path(
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/dynamic_shapes.py:481: in _tree_map_with_path
return tree_map_with_path(f, tree, *dynamic_shapes, is_leaf=is_leaf)
.venv-uv-3.11/lib/python3.11/site-packages/torch/utils/_pytree.py:1608: in tree_map_with_path
return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves))
.venv-uv-3.11/lib/python3.11/site-packages/torch/utils/_pytree.py:803: in unflatten
leaves = list(leaves)
.venv-uv-3.11/lib/python3.11/site-packages/torch/utils/_pytree.py:1608: in <genexpr>
return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves))
.venv-uv-3.11/lib/python3.11/site-packages/torch/export/dynamic_shapes.py:471: in f
return tree_map_with_path(
.venv-uv-3.11/lib/python3.11/site-packages/torch/utils/_pytree.py:1607: in tree_map_with_path
all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests]
.venv-uv-3.11/lib/python3.11/site-packages/torch/utils/_pytree.py:1607: in <listcomp>
all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests]
.venv-uv-3.11/lib/python3.11/site-packages/torch/utils/_pytree.py:798: in flatten_up_to
self._flatten_up_to_helper(tree, subtrees)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = TreeSpec(tuple, None, [*, *]), tree = [[None], [None]], subtrees = []
def _flatten_up_to_helper(self, tree: PyTree, subtrees: List[PyTree]) -> None:
    if self.is_leaf():
        subtrees.append(tree)
        return
    node_type = _get_node_type(tree)
    if self.type not in BUILTIN_TYPES:
        # Always require custom node types to match exactly
        if node_type != self.type:
            raise ValueError(
                f"Type mismatch; "
                f"expected {self.type!r}, but got {node_type!r}.",
            )
        flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
        child_pytrees, context = flatten_fn(tree)
        if len(child_pytrees) != self.num_children:
            raise ValueError(
                f"Node arity mismatch; "
                f"expected {self.num_children}, but got {len(child_pytrees)}.",
            )
        if context != self.context:
            raise ValueError(
                f"Node context mismatch for custom node type {self.type!r}.",
            )
    else:
        # For builtin dictionary types, we allow some flexibility
        # Otherwise, we require exact matches
        both_standard_dict = (
            self.type in STANDARD_DICT_TYPES and node_type in STANDARD_DICT_TYPES
        )
        if node_type != self.type and not both_standard_dict:
>           raise ValueError(
                f"Node type mismatch; "
                f"expected {self.type!r}, but got {node_type!r}.",
            )
E ValueError: Node type mismatch; expected <class 'tuple'>, but got <class 'list'>.
.venv-uv-3.11/lib/python3.11/site-packages/torch/utils/_pytree.py:751: ValueError
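For context, the snippet below is a minimal sketch (an illustration, not code from the repository or the test) that triggers the same check directly through torch.utils._pytree, the module raising in the traceback above: a TreeSpec recorded from a tuple refuses to flatten a list, even when the arity matches.

```python
# Hedged, standalone illustration of the failing pytree check (not taken from
# the test): a spec built from a tuple will not flatten_up_to a list.
import torch.utils._pytree as pytree

_, spec = pytree.tree_flatten((1, 2))  # spec describes a tuple with two leaves

# Raises:
# ValueError: Node type mismatch; expected <class 'tuple'>, but got <class 'list'>.
spec.flatten_up_to([1, 2])
```

This mirrors the locals shown in the traceback, where the spec expects a tuple (TreeSpec(tuple, None, [*, *])) while the dynamic-shapes tree arrives as a list ([[None], [None]]) inside _transform_shapes_for_default_dynamic, a code path that appears to be new in torch 2.5.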