diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py
index 989e0253a65..5a1d563ee0b 100644
--- a/backends/arm/process_node.py
+++ b/backends/arm/process_node.py
@@ -15,7 +15,6 @@
 from executorch.backends.arm.tosa.mapping import TosaArg, TosaSpecialDtype
 from executorch.backends.arm.tosa.specification import TosaSpecification
 from executorch.backends.arm.tosa.utils import tosa_shape
-from executorch.exir.graph_module import get_cond_while_submodules
 from torch._export.utils import (
     get_buffer,
     get_lifted_tensor_constant,
@@ -183,9 +182,10 @@ def process_inputs_to_lifted_tensor_constants(
         ) from e
     tensor = get_lifted_tensor_constant(edge_program, node)
     tensor_data = tensor.detach().numpy()  # type: ignore[union-attr]
+    tensor_values = np.transpose(tensor_data, tosa_arg.dim_order)
 
     tosa_graph.addConst(
-        tensor_data.shape, tosa_arg.dtype, tensor_data, name=tosa_arg.name
+        tensor_values.shape, tosa_arg.dtype, tensor_values, name=tosa_arg.name
     )
 
 
@@ -195,46 +195,7 @@ def _is_submodule_input(
     """Determines whether 'node' is an input to a submodule of 'containing_graph_module'."""
     if node.op != "placeholder":
         return False
-
-    for _, _, submodule_node in get_cond_while_submodules(containing_graph_module):
-        args = cast(list[torch.fx.Node], submodule_node.args[-1])
-        for arg in args:
-            if isinstance(arg.target, str):
-                # If argument is a buffer or similar, we can match exactly.
-                if arg.target == node.name:
-                    return True
-            # If argument target has a name, the submodule input is operator name + number to avoid duplication.
-            # For example: cond input namespace::my_op -> submodule input my_op_1
-            if (name_fn := (getattr(arg.target, "name", None))) is not None:
-                op_name = name_fn().split(":")[-1]
-                if op_name in node.name:
-                    return True
-    return False
-
-
-def _submodule_has_user_input(
-    containing_graph_module: torch.fx.GraphModule, edge_program: ExportedProgram
-):
-    # If argument is a user input, there is no such guarantee. We need to to a heuristic match.
-    for _, _, control_flow_node in get_cond_while_submodules(containing_graph_module):
-        match control_flow_node.target:
-            case torch.ops.higher_order.cond:
-                args = control_flow_node.args[-1]
-            case torch.ops.higher_order.while_loop:
-                args = cast(list, control_flow_node.args[-2]) + cast(
-                    list, control_flow_node.args[-1]
-                )
-            case _:
-                raise RuntimeError(
-                    f"Unexpected control flow target: {control_flow_node.target}"
-                )
-        args = cast(list[torch.fx.Node], args)
-        for arg in args:
-            if (
-                isinstance(arg.target, str)
-                and arg.target in edge_program.graph_signature.user_inputs
-            ):
-                return True
+    return node.meta.get("is_input", False)
 
 
 def process_placeholder(
@@ -268,11 +229,6 @@ def process_placeholder(
         raise NotImplementedError(
             "Placeholder is of type 'lifted custom object' which is not supported."
         )
-    elif containing_graph_module and _submodule_has_user_input(
-        containing_graph_module, edge_program
-    ):
-        # If we are in a submodule and it has user input, process as regular input.
-        process_inputs(node, tosa_graph, tosa_spec)
     else:
         raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.")
 
diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py
index a0abc543597..0f6d461aed9 100644
--- a/backends/arm/quantizer/quantization_annotator.py
+++ b/backends/arm/quantizer/quantization_annotator.py
@@ -712,6 +712,7 @@ def any_or_hardtanh_min_zero(n: Node):
     ):
         submodule_args_pos = -1 if node.target == torch.ops.higher_order.cond else -2
         submodule_args = node.args[submodule_args_pos]
+        output_qspec = output_act_qspec
         if len(submodule_args) > 0:  # type: ignore[arg-type]
             # The way the TOSA backend handles quantized inputs, arrays of input tensors (such as the input to a
             # conditional graph) need shared quantization.
@@ -727,7 +728,14 @@ def any_or_hardtanh_min_zero(n: Node):
                 ],
             )
         ]
-        quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
+        if node.target == torch.ops.higher_order.while_loop:
+            # The output of the while loop body can either re-enter the body or exit the while loop.
+            # Therefore, A and B in the diagram below need to share the same quantization parameters:
+            # A -> while ( RESCALE -> ... RESCALE -> ) -> B
+            output_qspec = shared_qspec
+
+        quant_properties.quant_output = _QuantProperty(0, output_qspec)
+
     else:
         return None
 
diff --git a/backends/arm/test/ops/test_cond.py b/backends/arm/test/ops/test_cond.py
index 77405354bd4..c0a6df69a2d 100644
--- a/backends/arm/test/ops/test_cond.py
+++ b/backends/arm/test/ops/test_cond.py
@@ -47,7 +47,7 @@ def false_branch(arg: torch.Tensor) -> torch.Tensor:
 class CondOneArgBufferOneOutput(torch.nn.Module):
     def __init__(self, *args: common.Any, **kwargs: common.Any) -> None:
         super().__init__(*args, **kwargs)
-        self.buffer = torch.rand(2, 3)
+        self.buffer = torch.rand(1, 1, 2, 2)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         def true_branch(arg: torch.Tensor, buffer: torch.Tensor) -> torch.Tensor:
@@ -159,7 +159,7 @@ def _single_input_case(
     module_factory: Callable[[], torch.nn.Module]
 ) -> Callable[[], tuple[torch.nn.Module, input_t1]]:
     def _create() -> tuple[torch.nn.Module, input_t1]:
-        return module_factory(), (torch.randn(2, 3),)
+        return module_factory(), (torch.randn(1, 1, 2, 2),)
 
     return _create
 
@@ -168,7 +168,7 @@ def _dual_input_case(
     module_factory: Callable[[], torch.nn.Module]
 ) -> Callable[[], tuple[torch.nn.Module, input_t2]]:
     def _create() -> tuple[torch.nn.Module, input_t2]:
-        return module_factory(), (torch.randn(2, 3), torch.randn(2, 3))
+        return module_factory(), (torch.randn(2, 3, 4, 6), torch.randn(2, 3, 4, 6))
 
     return _create
 
@@ -223,6 +223,7 @@ def test_cond_tosa_FP(case: Callable[[], tuple[torch.nn.Module, tuple]]):
     pipeline = TosaPipelineFP[tuple](
         module, example_inputs, aten_op, tosa_extensions=["cf"]
     )
+    # Make sure no cond ops are left after partitioning.
     pipeline.add_stage_after(
         "to_edge_transform_and_lower",
diff --git a/backends/arm/test/ops/test_while.py b/backends/arm/test/ops/test_while.py
index f66d8995683..d8ce782de2e 100644
--- a/backends/arm/test/ops/test_while.py
+++ b/backends/arm/test/ops/test_while.py
@@ -71,6 +71,27 @@ def body_fn(
         return result  # type: ignore
 
 
+class DecreasingOutput(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, value: torch.Tensor) -> torch.Tensor:
+        def cond_fn(value: torch.Tensor) -> torch.Tensor:
+            total = value.sum()
+            return torch.gt(total, torch.full((1,), 60.0)).squeeze()
+
+        def body_fn(value: torch.Tensor) -> Tuple[torch.Tensor]:
+            return (torch.div(value, torch.full((1,), 2.0)),)
+
+        result = torch.ops.higher_order.while_loop(
+            cond_fn,
+            body_fn,
+            (value,),
+            (),
+        )
+        return result[0]  # type: ignore
+
+
 class WhileAdditionalArg(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
@@ -121,7 +142,7 @@ def _single_input_case(
     module_factory: Callable[[], torch.nn.Module],
 ) -> Callable[[], Tuple[torch.nn.Module, input_single]]:
     def _create() -> Tuple[torch.nn.Module, input_single]:
-        return module_factory(), (torch.ones(2, 3),)
+        return module_factory(), (torch.ones(2, 3, 4, 6),)
 
     return _create
 
@@ -138,6 +159,7 @@ def _create() -> Tuple[torch.nn.Module, input_double]:
 test_cases: dict[str, Callable[[], Tuple[torch.nn.Module, Tuple]]] = {
     "two_in_two_out": _dual_input_case(WhileTwoInputsTwoOutputs),
     "one_in_one_buffer_two_out": _single_input_case(WhileOneInputOneBufferTwoOutputs),
+    "decreasing_output": _single_input_case(DecreasingOutput),
     "additional_arg": _single_input_case(WhileAdditionalArg),
     "two_in_one_captured_out": _single_input_case(WhileSingleCapturedOutput),
 }
diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py
index 69d16a2a708..913b5207767 100644
--- a/backends/arm/tosa/backend.py
+++ b/backends/arm/tosa/backend.py
@@ -20,6 +20,8 @@
 from itertools import count
 from typing import cast, Dict, final, List
 
+import torch
+
 import tosa_serializer as ts
 from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
 from executorch.backends.arm.common.debug import debug_fail, debug_tosa_dump
@@ -33,6 +35,7 @@
 from executorch.backends.arm.tosa.mapping import TOSA_TENSOR_NAME_META
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
 from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.exir.dim_order_utils import get_memory_format
 from executorch.exir.graph_module import get_cond_while_submodules
 from torch.export.exported_program import ExportedProgram
 from torch.fx import Graph, GraphModule, Node
@@ -112,6 +115,15 @@ def _sort_key(t: Node) -> int:
     return graph_module
 
 
+def _get_matching_fake_tensor(node: Node):
+    """Return a fake tensor with the same properties as node,
+    but with .dim_order() == node.meta["tosa_dim_order"]
+    """
+    fake_tensor = node.meta["val"]
+    desired_dim_order = node.meta["tosa_dim_order"]
+    return fake_tensor.to(memory_format=get_memory_format(list(desired_dim_order)))
+
+
 def arm_get_first_delegation_tag(graph_module) -> str:
     """Return the first delegation tag discovered in the FX graph.
 
@@ -253,6 +265,47 @@ def _preprocess(  # noqa: C901
 
         return PreprocessResult(processed_bytes=binary)
 
+    @staticmethod
+    def _regularize_submodule(submodule: GraphModule, submodule_node: Node):
+        """To make a submodule fit into the normal flow of a graph_module, we need to do some regularizations.
+
+        - Buffers created before passes are treated as inputs to the submodule. Buffers created during passes
+          are treated as "normal" buffers, i.e. gathered from the state_dict.
+          To make it easy to tell them apart, mark all placeholders with "is_input = True" before running passes.
+        - Make sure output node args[0] is always iterable.
+        - Match the dim_order() of the input tensors with the dim orders of the submodule_node inputs.
+        - Match the dim_order() of the output tensors with the dim orders of the submodule_node outputs.
+        """
+        submodule_inputs: list[Node] = []
+        for node in submodule.graph.nodes:
+            if node.op == "placeholder":
+                node.meta["is_input"] = True
+                submodule_inputs.append(node)
+        match submodule_node.target:
+            case torch.ops.higher_order.cond:
+                args = cast(list[Node], submodule_node.args[-1])
+            case torch.ops.higher_order.while_loop:
+                args = cast(list[Node], submodule_node.args[-2]) + cast(
+                    list, submodule_node.args[-1]
+                )
+            case _:
+                raise RuntimeError(
+                    f"Unexpected control flow target: {submodule_node.target}"
+                )
+
+        for submodule_input, submodule_arg in zip(submodule_inputs, args, strict=True):
+            submodule_input.meta["val"] = _get_matching_fake_tensor(submodule_arg)
+
+        output_node = submodule.graph.output_node()
+        if isinstance(output_node.args[0], Node):
+            output_node.update_arg(0, [output_node.args[0]])
+        output_args = cast(list[Node], output_node.args[0])
+
+        # Not all outputs might be used, causing len(users) < len(outputs).
+        # Therefore, strict != True in the zip.
+        for submodule_output, submodule_user in zip(output_args, submodule_node.users):
+            submodule_output.meta["val"] = _get_matching_fake_tensor(submodule_user)
+
     @staticmethod
     def _preprocess_module(  # noqa: C901
         graph_module: GraphModule,
@@ -278,9 +331,6 @@ def _preprocess_module(  # noqa: C901
         """
         tosa_spec = compile_spec.tosa_spec
 
-        output_node = graph_module.graph.output_node()
-        if isinstance(output_node.args[0], Node):
-            output_node.update_arg(0, [output_node.args[0]])
         node_to_id_map = _annotate_external_ids(graph_module.graph)
         artifact_path = compile_spec.get_intermediate_path()
         output_order_workaround = compile_spec.get_output_order_workaround()
@@ -351,7 +401,10 @@ def _preprocess_module(  # noqa: C901
                 raise
 
         # Recursively preprocess controlflow submodules.
-        for name, submodule, _ in get_cond_while_submodules(graph_module):
+        for name, submodule, control_flow_node in get_cond_while_submodules(
+            graph_module
+        ):
+            TOSABackend._regularize_submodule(submodule, control_flow_node)
             TOSABackend._preprocess_module(
                 submodule,
                 edge_program,
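
For reviewers, a standalone sketch of the dim-order permutation that `process_inputs_to_lifted_tensor_constants` now applies before `tosa_graph.addConst(...)`. The shape and dim order below are illustrative assumptions, not values taken from the patch or the tests:

```python
import numpy as np

# Hypothetical NCHW constant lowered into a TOSA graph whose tensors use
# NHWC dim order, i.e. tosa_arg.dim_order == (0, 2, 3, 1).
tensor_data = np.arange(1 * 8 * 4 * 4, dtype=np.float32).reshape(1, 8, 4, 4)
dim_order = (0, 2, 3, 1)

# Same permutation as the patch applies before serializing the constant.
tensor_values = np.transpose(tensor_data, dim_order)

# The serialized constant now has the shape the TOSA tensor expects.
assert tensor_values.shape == (1, 4, 4, 8)
```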
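Along the same lines, a minimal sketch of the layout matching done by the new `_get_matching_fake_tensor` helper, assuming ExecuTorch's `get_memory_format` and a PyTorch version that provides `Tensor.dim_order()`; the tensor shape and desired dim order are made up for illustration:

```python
import torch
from executorch.exir.dim_order_utils import get_memory_format

# A contiguous 4D tensor has dim order (0, 1, 2, 3); the submodule argument in
# this example is assumed to carry tosa_dim_order == (0, 2, 3, 1), i.e. NHWC.
fake = torch.randn(1, 8, 4, 4)
desired_dim_order = (0, 2, 3, 1)

matched = fake.to(memory_format=get_memory_format(list(desired_dim_order)))

assert matched.dim_order() == desired_dim_order  # layout now matches the arg
assert matched.shape == fake.shape               # logical shape is unchanged
```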
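Finally, a small FX sketch of the metadata handshake between `_regularize_submodule` and the simplified `_is_submodule_input`: placeholders that exist before the passes run are tagged, so buffers lifted later are naturally excluded. The traced function is a hypothetical stand-in for a cond/while submodule, not code from this patch:

```python
import torch
from torch.fx import symbolic_trace

def submodule(x: torch.Tensor) -> torch.Tensor:
    return x + 1

gm = symbolic_trace(submodule)

# What _regularize_submodule does before running passes: every placeholder
# present at this point is a real submodule input.
for node in gm.graph.nodes:
    if node.op == "placeholder":
        node.meta["is_input"] = True

# What _is_submodule_input reduces to after this patch: a metadata lookup
# instead of matching placeholder names against the control-flow node's args.
def is_submodule_input(node: torch.fx.Node) -> bool:
    return node.op == "placeholder" and node.meta.get("is_input", False)

# Nodes are: placeholder(x), add, output.
assert [is_submodule_input(n) for n in gm.graph.nodes] == [True, False, False]
```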