Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 189 additions & 75 deletions devtools/inspector/_inspector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,168 @@ def dfs(node_id, component):
return connected_components


def _map_sequence_aot_output(
aot_intermediate_output: Sequence,
runtime_intermediate_output: Any,
negative_index: int,
) -> Tuple[Tuple, Tuple]:
"""
Handle the case when aot_intermediate_output is a Sequence.

Returns:
Tuple of (aot_intermediate_output as tuple, mapped runtime output as tuple)
"""
if not isinstance(runtime_intermediate_output, Sequence):
raise TypeError(
"runtime intermediate output should be a sequence when aot intermediate output is a sequence"
)
last_element = runtime_intermediate_output[negative_index]
# TODO: this (last_element = list) is never really the case because runtime never returns output as a list
# for delegate case.
if isinstance(last_element, list) and all(
isinstance(t, torch.Tensor) for t in last_element
):
# If the last element is a list of tensors (delegate case)
aot_mapped_runtime_intermediate_output = last_element
elif isinstance(last_element, torch.Tensor):
# If the last element is a tensor, as is always the case for runtime.
# However, now we have a strange condition where aot_intermediate_output is a list of tensors
# while runtime_intermediate_output is a single tensor. So we should never really come here.
# TODO: fix this
aot_mapped_runtime_intermediate_output = runtime_intermediate_output
else:
raise ValueError(
"The last element of runtime argument list must be a tensor or a list of tensors when aot intermediate output is a sequence"
)
# List can't be used as a key, so convert to tuple
return tuple(aot_intermediate_output), tuple(aot_mapped_runtime_intermediate_output)


def _find_matching_runtime_output_by_shape_and_dtype(
aot_intermediate_output: torch.Tensor,
runtime_intermediate_output: Sequence,
) -> Any:
"""
Find the runtime output that matches the AOT output shape and dtype.
Used for multi-output operations (like native_layer_norm.out, native_dropout.out).

Returns:
The matching runtime output, or runtime_intermediate_output[-1] as fallback.
"""
# Find all runtime outputs that match the AOT shape
matching_indices = []
for idx, runtime_out in enumerate(runtime_intermediate_output):
if isinstance(runtime_out, torch.Tensor):
if runtime_out.shape == aot_intermediate_output.shape:
matching_indices.append(idx)

if len(matching_indices) == 1:
# Exactly one shape match - use it (native multi-output case like layer_norm)
return runtime_intermediate_output[matching_indices[0]]

if len(matching_indices) > 1:
# Multiple shape matches - try to distinguish by dtype
# For native_dropout, output is float and mask is bool; prefer matching dtype
dtype_matching_indices = []
for idx in matching_indices:
runtime_out = runtime_intermediate_output[idx]
if isinstance(runtime_out, torch.Tensor):
if runtime_out.dtype == aot_intermediate_output.dtype:
dtype_matching_indices.append(idx)

if len(dtype_matching_indices) == 1:
# Exactly one dtype match - use it (e.g., dropout case where mask is bool)
return runtime_intermediate_output[dtype_matching_indices[0]]

# No unique match found, return the last element as fallback
return runtime_intermediate_output[-1]


def _map_non_sequence_aot_output(
aot_intermediate_output: Any,
runtime_intermediate_output: Any,
num_outputs: int,
negative_index: int,
) -> Any:
"""
Handle the case when aot_intermediate_output is NOT a Sequence.

Returns:
The mapped runtime intermediate output.
"""
if not isinstance(runtime_intermediate_output, Sequence):
return runtime_intermediate_output

# Use the last element of the runtime output as fallback if no match is found
aot_mapped_runtime_intermediate_output = runtime_intermediate_output[negative_index]

# delegate runtime call and AOT intermediate is not a sequence.
# For multi-output operations (like native_layer_norm.out, native_dropout.out),
# the runtime captures all outputs but AOT only captures the primary output.
# We need to find the runtime output that matches the AOT output shape and dtype.
if (
num_outputs == 1
and len(runtime_intermediate_output) > 1
and isinstance(aot_intermediate_output, torch.Tensor)
):
aot_mapped_runtime_intermediate_output = (
_find_matching_runtime_output_by_shape_and_dtype(
aot_intermediate_output, runtime_intermediate_output
)
)

return aot_mapped_runtime_intermediate_output


def _process_single_runtime_output(
    aot_list: List[Tuple[DebugHandle, Any]],
    runtime_debug_handle: DebugHandle,
    runtime_intermediate_output: Any,
    num_outputs: int,
    output_index: int,
) -> Optional[Tuple[Tuple[DebugHandle, Any], Tuple[DebugHandle, Any]]]:
    """
    Map one runtime output (selected by ``output_index``) to its AOT output.

    Returns:
        ((aot_debug_handle, aot_output), (runtime_debug_handle, runtime_output)),
        or None when the combined AOT debug handle and the runtime debug
        handle do not exactly match and the mapping should be skipped.
    """
    # output_index 0 addresses the last runtime element, 1 the
    # second-to-last, and so on.
    negative_index = -(output_index + 1)

    # Combine the overlapping AOT debug handles into a single key.
    combined_handle, aot_output = _combine_aot_overlapped_intermediate_outputs(
        aot_list,
        (runtime_debug_handle, runtime_intermediate_output, num_outputs),
        negative_index,
    )

    # (-1,) is the sentinel meaning no exact AOT/runtime handle match.
    if combined_handle == (-1,):
        return None

    if isinstance(aot_output, Sequence):
        aot_output, mapped_runtime_output = _map_sequence_aot_output(
            aot_output, runtime_intermediate_output, negative_index
        )
    else:
        mapped_runtime_output = _map_non_sequence_aot_output(
            aot_output,
            runtime_intermediate_output,
            num_outputs,
            negative_index,
        )

    return (
        (combined_handle, aot_output),
        (runtime_debug_handle, mapped_runtime_output),
    )


def map_runtime_aot_intermediate_outputs(
aot_intermediate_outputs: Dict[DebugHandle, Any],
runtime_intermediate_outputs: Dict[DebugHandle, Tuple[Any, int]],
Expand Down Expand Up @@ -789,84 +951,36 @@ def map_runtime_aot_intermediate_outputs(
if nodes[node_id].source == NodeSource.RUNTIME
]

# Map only if both AOT and runtime data are present.
if len(aot_list) != 0 and len(runtime_list) != 0:
# The size of runtime_list should be 1 because all AOT debug_handles are tuples with one element.
# Additionally, runtime debug handles have already undergone pre-processing to merge overlapping debug_hanldes.
# As a result, there shouldn't be any 1-to-n or n-to-n (AOT to runtime) mappings.
if len(runtime_list) != 1:
raise ValueError(
f"Expected only one runtime debug handle, but found {len(runtime_list)}: {runtime_list}"
)
if len(aot_list) == 0 or len(runtime_list) == 0:
# Skip this mapping if there are no AOT or runtime data.
continue

runtime_debug_handle, runtime_intermediate_output, num_outputs = (
runtime_list[0]
# The size of runtime_list should be 1 because all AOT debug_handles are tuples with one element.
        # Additionally, runtime debug handles have already undergone pre-processing to merge overlapping debug_handles.
# As a result, there shouldn't be any 1-to-n or n-to-n (AOT to runtime) mappings.
if len(runtime_list) != 1:
raise ValueError(
f"Expected only one runtime debug handle, but found {len(runtime_list)}: {runtime_list}"
)
# iterate through each of the output from runtime,
# get the corresponding debug handle
# and map it to the aot debug handle
# and create a dictionary that maps aot debug handle + aot output to
# runtime debug handle + runtime output
# Note this works only for delegate case for now.
for i in range(num_outputs):

negative_index = -1 * (i + 1)
aot_mapped_runtime_intermediate_output = runtime_intermediate_output
# Combine aot debug handles into a single key
aot_combined_debug_handle, aot_intermediate_output = (
_combine_aot_overlapped_intermediate_outputs(
aot_list, runtime_list[0], negative_index
)
)

if aot_combined_debug_handle == (-1,):
# Skip this mapping if the aot combined debug handle and runtime debug handle do not exact match.
continue

if isinstance(aot_intermediate_output, Sequence):
if not isinstance(runtime_intermediate_output, Sequence):
raise TypeError(
"runtime intermediate output should be a sequence when aot intermediate output is a sequence"
)
last_element = runtime_intermediate_output[negative_index]
# TODO: this (last_element = list) is never really the case because runtime never returns output as a list
# for delegate case.
if isinstance(last_element, list) and all(
isinstance(t, torch.Tensor) for t in last_element
):
# If the last element is a list of tensors (delegate case)
aot_mapped_runtime_intermediate_output = last_element
elif isinstance(last_element, torch.Tensor):
# If the last element is a tensor, as is always the case for runtime.
# However, now we have a strange condition where aot_intermediate_output is a list of tensors
# while runtime_intermediate_output is a single tensor. So we should never really come here.
# TODO: fix this
aot_mapped_runtime_intermediate_output = (
runtime_intermediate_output
)
else:
raise ValueError(
"The last element of runtime argument list must be a tensor or a list of tensors when aot intermediate output is a sequence"
)
# List can't be used as a key, so convert to tuple
aot_intermediate_output = tuple(aot_intermediate_output)
aot_mapped_runtime_intermediate_output = tuple(
aot_mapped_runtime_intermediate_output
)

elif isinstance(runtime_intermediate_output, Sequence):
# delegate runtime call and AOT intermediate is not a sequence, just take the last element from runtime list
aot_mapped_runtime_intermediate_output = (
runtime_intermediate_output[negative_index]
)

# Create a mapping between runtime and aot
aot_runtime_mapping[
(aot_combined_debug_handle, aot_intermediate_output)
] = (
runtime_debug_handle,
aot_mapped_runtime_intermediate_output,
)
runtime_debug_handle, runtime_intermediate_output, num_outputs = runtime_list[0]
# iterate through each of the output from runtime,
# get the corresponding debug handle
# and map it to the aot debug handle
# and create a dictionary that maps aot debug handle + aot output to
# runtime debug handle + runtime output
# Note this works only for delegate case for now.
for i in range(num_outputs):
result = _process_single_runtime_output(
aot_list,
runtime_debug_handle,
runtime_intermediate_output,
num_outputs,
i,
)
if result is not None:
aot_key, runtime_value = result
aot_runtime_mapping[aot_key] = runtime_value

return aot_runtime_mapping

Expand Down
5 changes: 4 additions & 1 deletion devtools/inspector/tests/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@ python_unittest(
ci.buckconfig("executorch.event_tracer_enabled", "true"),
),
deps = [
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
"//executorch/backends/xnnpack/utils:xnnpack_utils",
"//executorch/devtools:lib",
"//executorch/devtools/debug_format:et_schema",
"//executorch/devtools/etdump:schema_flatcc",
"//executorch/devtools/etrecord/tests:etrecord_test_library",
"//executorch/devtools/inspector:inspector",
"//executorch/devtools/inspector:lib",
"//executorch/exir:lib",
"//executorch/devtools/inspector/tests:inspector_test_utils",
"//executorch/exir:lib",
"//executorch/runtime:runtime",
],
)

Expand Down
Loading
Loading