50 changes: 3 additions & 47 deletions backends/arm/process_node.py
@@ -15,7 +15,6 @@
from executorch.backends.arm.tosa.mapping import TosaArg, TosaSpecialDtype
from executorch.backends.arm.tosa.specification import TosaSpecification
from executorch.backends.arm.tosa.utils import tosa_shape
from executorch.exir.graph_module import get_cond_while_submodules
from torch._export.utils import (
get_buffer,
get_lifted_tensor_constant,
@@ -183,9 +182,10 @@ def process_inputs_to_lifted_tensor_constants(
) from e
tensor = get_lifted_tensor_constant(edge_program, node)
tensor_data = tensor.detach().numpy() # type: ignore[union-attr]
tensor_values = np.transpose(tensor_data, tosa_arg.dim_order)

tosa_graph.addConst(
tensor_data.shape, tosa_arg.dtype, tensor_data, name=tosa_arg.name
tensor_values.shape, tosa_arg.dtype, tensor_values, name=tosa_arg.name
)


@@ -195,46 +195,7 @@ def _is_submodule_input(
"""Determines whether 'node' is an input to a submodule of 'containing_graph_module'."""
if node.op != "placeholder":
return False

for _, _, submodule_node in get_cond_while_submodules(containing_graph_module):
args = cast(list[torch.fx.Node], submodule_node.args[-1])
for arg in args:
if isinstance(arg.target, str):
# If argument is a buffer or similar, we can match exactly.
if arg.target == node.name:
return True
# If argument target has a name, the submodule input is operator name + number to avoid duplication.
# For example: cond input namespace::my_op -> submodule input my_op_1
if (name_fn := (getattr(arg.target, "name", None))) is not None:
op_name = name_fn().split(":")[-1]
if op_name in node.name:
return True
return False


def _submodule_has_user_input(
containing_graph_module: torch.fx.GraphModule, edge_program: ExportedProgram
):
# If argument is a user input, there is no such guarantee. We need to do a heuristic match.
for _, _, control_flow_node in get_cond_while_submodules(containing_graph_module):
match control_flow_node.target:
case torch.ops.higher_order.cond:
args = control_flow_node.args[-1]
case torch.ops.higher_order.while_loop:
args = cast(list, control_flow_node.args[-2]) + cast(
list, control_flow_node.args[-1]
)
case _:
raise RuntimeError(
f"Unexpected control flow target: {control_flow_node.target}"
)
args = cast(list[torch.fx.Node], args)
for arg in args:
if (
isinstance(arg.target, str)
and arg.target in edge_program.graph_signature.user_inputs
):
return True
return node.meta.get("is_input", False)


def process_placeholder(
@@ -268,11 +229,6 @@ def process_placeholder(
raise NotImplementedError(
"Placeholder is of type 'lifted custom object' which is not supported."
)
elif containing_graph_module and _submodule_has_user_input(
containing_graph_module, edge_program
):
# If we are in a submodule and it has user input, process as regular input.
process_inputs(node, tosa_graph, tosa_spec)
else:
raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.")

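Note on the simplification above: with the submodule regularization added in backends/arm/tosa/backend.py further down, every submodule placeholder is marked before passes run, so the old traversal of cond/while call sites is no longer needed. A minimal sketch of the resulting flow, using illustrative function names that mirror the real ones:

import torch

def mark_submodule_inputs(submodule: torch.fx.GraphModule) -> None:
    # Done by TOSABackend._regularize_submodule before any passes run; buffers
    # lifted later by passes never get the flag and are read from the state_dict.
    for node in submodule.graph.nodes:
        if node.op == "placeholder":
            node.meta["is_input"] = True

def is_submodule_input(node: torch.fx.Node) -> bool:
    # Replaces the old heuristic match against the parent cond/while args.
    return node.op == "placeholder" and node.meta.get("is_input", False)
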
10 changes: 9 additions & 1 deletion backends/arm/quantizer/quantization_annotator.py
@@ -712,6 +712,7 @@ def any_or_hardtanh_min_zero(n: Node):
):
submodule_args_pos = -1 if node.target == torch.ops.higher_order.cond else -2
submodule_args = node.args[submodule_args_pos]
output_qspec = output_act_qspec
if len(submodule_args) > 0: # type: ignore[arg-type]
# The way the TOSA backend handles quantized inputs, arrays of input tensors (such as the input to a
# conditional graph) need shared quantization.
@@ -727,7 +728,14 @@
],
)
]
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
if node.target == torch.ops.higher_order.while_loop:
# The output of the while loop body can either re-enter the body, or exit the while loop.
# Therefore, A and B in the diagram below need to share the same quantization parameters.
# A -> while ( RESCALE -> ... RESCALE -> ) -> B
output_qspec = shared_qspec

quant_properties.quant_output = _QuantProperty(0, output_qspec)

else:
return None

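For context on the shared output qspec above, here are toy numbers (not ExecuTorch quantizer code; the scales are made up) showing why a mismatch between the while-loop carry A and the body output B would force a RESCALE on every loop-back edge:

# Body input A uses scale 0.10, body output B uses scale 0.07 (hypothetical).
scale_in, scale_out = 0.10, 0.07

q_in = 25                           # A stores 25, i.e. the real value 2.5
real = q_in * scale_in              # dequantize: 2.5
q_out = round(real / scale_out)     # re-quantize for B: 36
print(q_out, q_out * scale_out)     # 36, 2.52 -> needs a RESCALE and adds rounding error

# With a shared scale the loop-back edge is exact and needs no rescale:
print(round(real / scale_in))       # 25 again
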
7 changes: 4 additions & 3 deletions backends/arm/test/ops/test_cond.py
@@ -47,7 +47,7 @@ def false_branch(arg: torch.Tensor) -> torch.Tensor:
class CondOneArgBufferOneOutput(torch.nn.Module):
def __init__(self, *args: common.Any, **kwargs: common.Any) -> None:
super().__init__(*args, **kwargs)
self.buffer = torch.rand(2, 3)
self.buffer = torch.rand(1, 1, 2, 2)

def forward(self, x: torch.Tensor) -> torch.Tensor:
def true_branch(arg: torch.Tensor, buffer: torch.Tensor) -> torch.Tensor:
@@ -159,7 +159,7 @@ def _single_input_case(
module_factory: Callable[[], torch.nn.Module]
) -> Callable[[], tuple[torch.nn.Module, input_t1]]:
def _create() -> tuple[torch.nn.Module, input_t1]:
return module_factory(), (torch.randn(2, 3),)
return module_factory(), (torch.randn(1, 1, 2, 2),)

return _create

@@ -168,7 +168,7 @@ def _dual_input_case(
module_factory: Callable[[], torch.nn.Module]
) -> Callable[[], tuple[torch.nn.Module, input_t2]]:
def _create() -> tuple[torch.nn.Module, input_t2]:
return module_factory(), (torch.randn(2, 3), torch.randn(2, 3))
return module_factory(), (torch.randn(2, 3, 4, 6), torch.randn(2, 3, 4, 6))

return _create

@@ -223,6 +223,7 @@ def test_cond_tosa_FP(case: Callable[[], tuple[torch.nn.Module, tuple]]):
pipeline = TosaPipelineFP[tuple](
module, example_inputs, aten_op, tosa_extensions=["cf"]
)

# Make sure no cond ops are left after partitioning.
pipeline.add_stage_after(
"to_edge_transform_and_lower",
24 changes: 23 additions & 1 deletion backends/arm/test/ops/test_while.py
@@ -71,6 +71,27 @@ def body_fn(
return result # type: ignore


class DecreasingOutput(torch.nn.Module):
def __init__(self) -> None:
super().__init__()

def forward(self, value: torch.Tensor) -> torch.Tensor:
def cond_fn(value: torch.Tensor) -> torch.Tensor:
total = value.sum()
return torch.gt(total, torch.full((1,), 60.0)).squeeze()

def body_fn(value: torch.Tensor) -> Tuple[torch.Tensor]:
return (torch.div(value, torch.full((1,), 2.0)),)

result = torch.ops.higher_order.while_loop(
cond_fn,
body_fn,
(value,),
(),
)
return result[0] # type: ignore


class WhileAdditionalArg(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@@ -121,7 +142,7 @@ def _single_input_case(
module_factory: Callable[[], torch.nn.Module],
) -> Callable[[], Tuple[torch.nn.Module, input_single]]:
def _create() -> Tuple[torch.nn.Module, input_single]:
return module_factory(), (torch.ones(2, 3),)
return module_factory(), (torch.ones(2, 3, 4, 6),)

return _create

Expand All @@ -138,6 +159,7 @@ def _create() -> Tuple[torch.nn.Module, input_double]:
test_cases: dict[str, Callable[[], Tuple[torch.nn.Module, Tuple]]] = {
"two_in_two_out": _dual_input_case(WhileTwoInputsTwoOutputs),
"one_in_one_buffer_two_out": _single_input_case(WhileOneInputOneBufferTwoOutputs),
"decreasing_output": _single_input_case(DecreasingOutput),
"additional_arg": _single_input_case(WhileAdditionalArg),
"two_in_one_captured_out": _single_input_case(WhileSingleCapturedOutput),
}
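A quick eager sanity check of the new decreasing_output case, assuming torch.ops.higher_order.while_loop also executes eagerly and using the DecreasingOutput module added above:

import torch

module = DecreasingOutput()
# 2 * 3 * 4 * 6 ones scaled by 10 sum to 1440; the body halves the tensor
# until the sum is no longer greater than 60: 1440 -> 720 -> 360 -> 180 -> 90 -> 45.
out = module(torch.ones(2, 3, 4, 6) * 10.0)
print(out.sum())  # tensor(45.)
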
61 changes: 57 additions & 4 deletions backends/arm/tosa/backend.py
@@ -20,6 +20,8 @@
from itertools import count
from typing import cast, Dict, final, List

import torch

import tosa_serializer as ts
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
from executorch.backends.arm.common.debug import debug_fail, debug_tosa_dump
@@ -33,6 +35,7 @@
from executorch.backends.arm.tosa.mapping import TOSA_TENSOR_NAME_META
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.dim_order_utils import get_memory_format
from executorch.exir.graph_module import get_cond_while_submodules
from torch.export.exported_program import ExportedProgram
from torch.fx import Graph, GraphModule, Node
@@ -112,6 +115,15 @@ def _sort_key(t: Node) -> int:
return graph_module


def _get_matching_fake_tensor(node: Node):
"""Return a fake tensor with the same properties as node,
but with .dim_order() == node.meta["tosa_dim_order"]
"""
fake_tensor = node.meta["val"]
desired_dim_order = node.meta["tosa_dim_order"]
return fake_tensor.to(memory_format=get_memory_format(list(desired_dim_order)))


def arm_get_first_delegation_tag(graph_module) -> str:
"""Return the first delegation tag discovered in the FX graph.

@@ -253,6 +265,47 @@ def _preprocess( # noqa: C901

return PreprocessResult(processed_bytes=binary)

@staticmethod
def _regularize_submodule(submodule: GraphModule, submodule_node: Node):
"""To make a submodule fit into the normal flow of a graph_module, we need to do some regularizations.

- Buffers created before passes are treated as input to the submodule. Buffers created during passes
are treated as "normal" buffers, i.e. gathered from the state_dict.
To make it easy to tell them apart, mark all placeholders with "is_input = True" before running passes.
- Make sure output node args[0] is always iterable.
- Match the dim_order() of the input tensors with the dim orders of the submodule_node inputs.
- Match the dim_order() of the out tensors with the dim orders of the submodule_node outputs.
"""
submodule_inputs: list[Node] = []
for node in submodule.graph.nodes:
if node.op == "placeholder":
node.meta["is_input"] = True
submodule_inputs.append(node)
match submodule_node.target:
case torch.ops.higher_order.cond:
args = cast(list[Node], submodule_node.args[-1])
case torch.ops.higher_order.while_loop:
args = cast(list[Node], submodule_node.args[-2]) + cast(
list, submodule_node.args[-1]
)
case _:
raise RuntimeError(
f"Unexpected control flow target: {submodule_node.target}"
)

for submodule_input, submodule_arg in zip(submodule_inputs, args, strict=True):
submodule_input.meta["val"] = _get_matching_fake_tensor(submodule_arg)

output_node = submodule.graph.output_node()
if isinstance(output_node.args[0], Node):
output_node.update_arg(0, [output_node.args[0]])
output_args = cast(list[Node], output_node.args[0])

# Not all outputs might be used, causing len(users) < len(outputs)
# Therefore, strict != True in the zip
for submodule_output, submodule_user in zip(output_args, submodule_node.users):
submodule_output.meta["val"] = _get_matching_fake_tensor(submodule_user)

@staticmethod
def _preprocess_module( # noqa: C901
graph_module: GraphModule,
@@ -278,9 +331,6 @@ def _preprocess_module( # noqa: C901

"""
tosa_spec = compile_spec.tosa_spec
output_node = graph_module.graph.output_node()
if isinstance(output_node.args[0], Node):
output_node.update_arg(0, [output_node.args[0]])
node_to_id_map = _annotate_external_ids(graph_module.graph)
artifact_path = compile_spec.get_intermediate_path()
output_order_workaround = compile_spec.get_output_order_workaround()
@@ -351,7 +401,10 @@
raise

# Recursively preprocess controlflow submodules.
for name, submodule, _ in get_cond_while_submodules(graph_module):
for name, submodule, control_flow_node in get_cond_while_submodules(
graph_module
):
TOSABackend._regularize_submodule(submodule, control_flow_node)
TOSABackend._preprocess_module(
submodule,
edge_program,
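As a reference for _get_matching_fake_tensor, here is a small sketch of the dim-order round trip it relies on, assuming get_memory_format (imported above from executorch.exir.dim_order_utils) maps a dim-order list such as [0, 2, 3, 1] to the corresponding torch memory format and that Tensor.dim_order() is available in the PyTorch build:

import torch
from executorch.exir.dim_order_utils import get_memory_format

x = torch.randn(1, 8, 4, 4)                          # NCHW, dim_order (0, 1, 2, 3)
nhwc = x.to(memory_format=get_memory_format([0, 2, 3, 1]))
print(nhwc.dim_order())                              # (0, 2, 3, 1), i.e. channels-last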