Skip to content

Commit

Permalink
Tosa specification handling (pytorch#6688)
Browse files Browse the repository at this point in the history
Add TOSA specification details to the Arm Backend.

* Mandate the need for a TOSA version in the compile spec list passed
to the Arm backend and propagate the information to node visitors
for serialization handling.
* Add TOSA version string to all TOSA tests
* Add handling of the TOSA 0.80 BI and MI profiles as separate serialization
handlers, using ADD as an example.

Signed-off-by: Per Åstrand <[email protected]>
  • Loading branch information
per authored Nov 8, 2024
1 parent 6d6630e commit ddc8ea6
Show file tree
Hide file tree
Showing 56 changed files with 618 additions and 133 deletions.
31 changes: 27 additions & 4 deletions backends/arm/arm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from executorch.backends.arm.operators.node_visitor import get_node_visitors
from executorch.backends.arm.operators.op_output import process_output
from executorch.backends.arm.operators.op_placeholder import process_placeholder

from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm._passes.arm_pass_manager import (
ArmPassManager,
) # usort: skip
Expand Down Expand Up @@ -86,16 +88,23 @@ def ethosu_compile_spec(
if extra_flags is not None:
self.compiler_flags.append(extra_flags)

base_tosa_version = "TOSA-0.80.0+BI"
if "U55" in config:
# Add the Ethos-U55 extension marker
base_tosa_version += "+u55"
self.tosa_version = TosaSpecification.create_from_string(base_tosa_version)

return self

def tosa_compile_spec(self) -> "ArmCompileSpecBuilder":
def tosa_compile_spec(self, tosa_version: str) -> "ArmCompileSpecBuilder":
"""
Generate compile spec for TOSA flatbuffer output
"""
assert (
self.output_format is None
), f"Output format already set: {self.output_format}"
self.output_format = "tosa"
self.tosa_version = TosaSpecification.create_from_string(tosa_version)
return self

def dump_intermediate_artifacts_to(
Expand Down Expand Up @@ -129,6 +138,13 @@ def build(self) -> List[CompileSpec]:
"""
Generate a list of compile spec objects from the builder
"""
assert self.tosa_version

# Always supply a TOSA version
self.compile_spec = [
CompileSpec("tosa_version", str(self.tosa_version).encode())
]

if self.output_format == "vela":
self.compile_spec += [
CompileSpec("output_format", "vela".encode()),
Expand Down Expand Up @@ -210,25 +226,32 @@ def preprocess( # noqa: C901
if not output_format:
raise RuntimeError("output format is required")

tosa_spec = TosaSpecification.create_from_compilespecs(compile_spec)
assert (
tosa_spec is not None
), "TOSA backend needs a TOSA version specified in the CompileSpec!"

if output_format == "vela" and len(compile_flags) == 0:
# Not testing for compile_flags correctness here, just that they are
# present. The compiler will give errors if they are not valid.
raise RuntimeError("compile flags are required for vela output format")

logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}")

# Converted output for this subgraph, serializer needs path early as it emits
# const data directly. Path created and data written only in debug builds.
tosa_graph = ts.TosaSerializer(artifact_path)
graph_module = ArmPassManager().transform_to_backend_pipeline(
exported_program=edge_program, compile_spec=compile_spec
)

node_visitors = get_node_visitors(edge_program)
node_visitors = get_node_visitors(edge_program, tosa_spec)

for node in graph_module.graph.nodes:
if node.op == "call_function":
process_call_function(node, tosa_graph, node_visitors)
process_call_function(node, tosa_graph, node_visitors, tosa_spec)
elif node.op == "placeholder":
process_placeholder(node, tosa_graph, edge_program)
process_placeholder(node, tosa_graph, edge_program, tosa_spec)
elif node.op == "output":
process_output(node, tosa_graph)
else:
Expand Down
36 changes: 31 additions & 5 deletions backends/arm/operators/node_visitor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Arm Limited and/or its affiliates.
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -10,6 +10,7 @@
import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from torch.export import ExportedProgram


Expand All @@ -18,8 +19,19 @@ class NodeVisitor:
Node Visitor pattern for lowering edge IR to TOSA
"""

def __init__(self, exported_program: ExportedProgram):
# Add the currently supported node_visitor specs as default.
# This should be overridden in the NodeVisitor subclasses to target
# a specific TOSA version.
# When all node_visitors have been refactored to target a specific
# version, this list should be removed.
tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
]

def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification):
self._exported_program = exported_program or None
self.tosa_spec = tosa_spec

def define_node(
self,
Expand All @@ -33,16 +45,30 @@ def define_node(


# container for all node visitors
_node_visitor_dict = {}
_node_visitor_dicts = {
TosaSpecification.create_from_string("TOSA-0.80.0+BI"): {},
TosaSpecification.create_from_string("TOSA-0.80.0+MI"): {},
}


def register_node_visitor(visitor):
_node_visitor_dict[visitor.target] = visitor
for tosa_spec in visitor.tosa_specs:
_node_visitor_dicts[tosa_spec][visitor.target] = visitor
return visitor


def get_node_visitors(*args) -> Dict[str, NodeVisitor]:
node_visitors = {}
for target, visitor in _node_visitor_dict.items():
tosa_spec = None
for arg in args:
if isinstance(arg, TosaSpecification):
tosa_spec = arg
break

if tosa_spec is None:
raise RuntimeError("No TOSA specification supplied.")

for target, visitor in _node_visitor_dicts[tosa_spec].items():
node_visitors[target] = visitor(*args)

return node_visitors
73 changes: 60 additions & 13 deletions backends/arm/operators/op_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,25 @@
import executorch.backends.arm.tosa_utils as tutils

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class AddVisitor(NodeVisitor):
class AddVisitor_080_BI(NodeVisitor):
target = "aten.add.Tensor"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
]

def __init__(self, *args):
super().__init__(*args)

Expand All @@ -35,9 +41,22 @@ def define_node(
output: TosaArg,
is_quant_node: bool,
) -> None:
if is_quant_node:
input_nodes = tutils.get_two_inputs(node)
input_nodes = tutils.get_two_inputs(node)

if not is_quant_node and not all(
tensor.meta["val"].dtype in (torch.int8, torch.int32)
for tensor in input_nodes
):
raise RuntimeError(
f"Unexpected non quantized {AddVisitor_080_BI.target} node."
)

needs_rescale = not (
all(tensor.meta["val"].dtype == torch.int32 for tensor in input_nodes)
and node.meta["val"].dtype == torch.int32
)

if needs_rescale:
# Rescale inputs to 32 bit
rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
input_nodes, tosa_graph
Expand All @@ -48,20 +67,48 @@ def define_node(
rescaled_inputs[0].shape, rescaled_inputs[0].shape
)
add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
else:
add_output = output
rescaled_inputs = inputs

# Do the INT32 Add
tosa_graph.addOperator(
TosaOp.Op().ADD,
[
rescaled_inputs[0].name,
rescaled_inputs[1].name,
],
[add_output.name],
None,
)
# Do the INT32 Add
tosa_graph.addOperator(
TosaOp.Op().ADD,
[
rescaled_inputs[0].name,
rescaled_inputs[1].name,
],
[add_output.name],
None,
)

if needs_rescale:
# Scale output back to 8 bit
tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph)


@register_node_visitor
class AddVisitor_080_MI(AddVisitor_080_BI):
# inheriting 'target' from BI class

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
) -> None:
if is_quant_node:
# Call the inherited define_node for handling integers
super().define_node(node, tosa_graph, inputs, output, is_quant_node)
else:
# FP32 Add lowering
tosa_graph.addOperator(
Expand Down
12 changes: 10 additions & 2 deletions backends/arm/operators/op_placeholder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
get_quant_node_args,
is_quant_arg,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_utils import (
is_bias_node_for_quantized_addmm,
is_bias_node_for_quantized_conv,
Expand All @@ -26,6 +27,7 @@
def process_inputs(
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
tosa_spec: TosaSpecification,
):
"""Serialize an input node"""
# inputs need to be in default dim_order (contiguous memory format)
Expand Down Expand Up @@ -95,6 +97,7 @@ def process_inputs_to_parameters(
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
edge_program: ExportedProgram,
tosa_spec: TosaSpecification,
):
"""Serialize bias and non-quantized weights"""
inputs = [TosaArg(node)]
Expand All @@ -106,9 +109,13 @@ def process_inputs_to_parameters(

if is_bias_node_for_quantized_addmm(node) or is_bias_node_for_quantized_conv(node):
# BI bias
assert tosa_spec.support_integer(), f"{tosa_spec} doesnt't support integer"
process_quantized_bias(node, tosa_graph, parameter_values)
else:
# MI weights or bias
if inputs[0].dtype == torch.float32:
assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"

parameter_values = np.transpose(parameter_values, inputs[0].dim_order)

tosa_graph.addConst(
Expand Down Expand Up @@ -158,15 +165,16 @@ def process_placeholder(
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
edge_program: ExportedProgram,
tosa_spec: TosaSpecification,
):
"""Wrapper for processing and serializing all types of placeholders"""
assert node.name == node.target, "Expect placeholder name and target to match"
assert 0 == len(node.args), "Can't handle default input values"

if node.name in edge_program.graph_signature.user_inputs:
process_inputs(node, tosa_graph)
process_inputs(node, tosa_graph, tosa_spec)
elif node.name in edge_program.graph_signature.inputs_to_parameters:
process_inputs_to_parameters(node, tosa_graph, edge_program)
process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec)
elif node.name in edge_program.graph_signature.inputs_to_buffers:
process_inputs_to_buffers(node, tosa_graph, edge_program)
elif node.name in edge_program.graph_signature.inputs_to_lifted_tensor_constants:
Expand Down
10 changes: 6 additions & 4 deletions backends/arm/test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,18 @@ def maybe_get_tosa_collate_path() -> str | None:


def get_tosa_compile_spec(
permute_memory_to_nhwc=True, custom_path=None
tosa_version: str, permute_memory_to_nhwc=True, custom_path=None
) -> list[CompileSpec]:
"""
Default compile spec for TOSA tests.
"""
return get_tosa_compile_spec_unbuilt(permute_memory_to_nhwc, custom_path).build()
return get_tosa_compile_spec_unbuilt(
tosa_version, permute_memory_to_nhwc, custom_path
).build()


def get_tosa_compile_spec_unbuilt(
permute_memory_to_nhwc=False, custom_path=None
tosa_version: str, permute_memory_to_nhwc=False, custom_path=None
) -> ArmCompileSpecBuilder:
"""Get the ArmCompileSpecBuilder for the default TOSA tests, to modify
the compile spec before calling .build() to finalize it.
Expand All @@ -202,7 +204,7 @@ def get_tosa_compile_spec_unbuilt(
os.makedirs(intermediate_path, exist_ok=True)
compile_spec_builder = (
ArmCompileSpecBuilder()
.tosa_compile_spec()
.tosa_compile_spec(tosa_version)
.set_permute_memory_format(permute_memory_to_nhwc)
.dump_intermediate_artifacts_to(intermediate_path)
)
Expand Down
Loading

0 comments on commit ddc8ea6

Please sign in to comment.