50 changes: 48 additions & 2 deletions examples/models/llama/export_llama_lib.py
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -37,6 +37,7 @@
get_mps_partitioner,
get_openvino_partitioner,
get_qnn_partitioner,
get_tosa_partitioner,
get_vulkan_partitioner,
get_xnnpack_partitioner,
)
@@ -46,6 +47,7 @@
get_pt2e_quantization_params,
get_pt2e_quantizers,
get_qnn_quantizer,
get_tosa_quantizer,
get_vulkan_quantizer,
)
from executorch.util.activation_memory_profiler import generate_memory_trace
@@ -210,6 +212,7 @@ def build_args_parser() -> argparse.ArgumentParser:
"coreml_baseline_8a_c8w",
"coreml_baseline_8a_c4w",
"vulkan_8w",
"tosa_8a8w",
],
help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
)
@@ -788,6 +791,11 @@ def get_quantizer_and_quant_params(llm_config):
llm_config.quantization.pt2e_quantize.value
)
quantizers.append(coreml_quantizer)
if llm_config.backend.tosa.enabled and llm_config.quantization.pt2e_quantize:
Review comment (Contributor): Is it possible to add a unit test to cover this logic branch?

Reply (Collaborator, author): Let me know what you think

tosa_quantizer = get_tosa_quantizer(
llm_config.backend.tosa.version, llm_config.quantization.pt2e_quantize.value
)
quantizers.append(tosa_quantizer)
if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
assert (
len(quantizers) == 0
@@ -930,6 +938,32 @@ def _to_edge_and_lower_llama_openvino(
return builder.to_executorch(passes=additional_passes)


def _to_edge_and_lower_llama_tosa(
builder_exported,
modelname,
quantizers,
additional_passes,
tosa_spec,
verbose: bool = False,
) -> LLMEdgeManager:

logging.info("Lowering model using TOSA partitioner")

partitioners = []
partitioners.append(get_tosa_partitioner(tosa_spec))

modelname = f"tosa_{modelname}"

builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
partitioners
)

if verbose:
print_delegation_info(builder.edge_manager.exported_program().graph_module)

return builder.to_executorch(passes=additional_passes)


def _to_edge_and_lower_llama( # noqa: C901
builder_exported,
modelname,
@@ -1119,7 +1153,10 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]

# export_to_edge
builder_exported = _prepare_for_llama_export(llm_config).export()
builder_manager = _prepare_for_llama_export(llm_config)
if llm_config.backend.tosa.enabled:
builder_manager.skip_dim_order = False
builder_exported = builder_manager.export()
builder_exported.run_canonical_optimizations()
modelname = builder_exported.modelname

@@ -1162,6 +1199,15 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
openvino_device=llm_config.backend.openvino.device,
verbose=llm_config.debug.verbose,
)
elif llm_config.backend.tosa.enabled:
builder = _to_edge_and_lower_llama_tosa(
builder_exported,
modelname,
quantizers,
additional_passes,
llm_config.backend.tosa.version,
verbose=llm_config.debug.verbose,
)
else:
builder = _to_edge_and_lower_llama(
builder_exported,
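Taken together, the export_llama_lib.py changes add a TOSA lowering branch next to the existing backends: a quantizer hook in get_quantizer_and_quant_params, a dedicated _to_edge_and_lower_llama_tosa helper, and a dispatch on llm_config.backend.tosa.enabled in _export_llama. A minimal sketch of driving the new path programmatically, using only names that appear in this diff (illustrative, not a tested recipe):

```python
# Sketch only: wire the new TOSA options through LlmConfig and the private
# _export_llama entry point extended in this diff.
from executorch.examples.models.llama.export_llama_lib import _export_llama
from executorch.extension.llm.export.config.llm_config import LlmConfig, Pt2eQuantize

llm_config = LlmConfig()
llm_config.backend.tosa.enabled = True  # routes lowering through _to_edge_and_lower_llama_tosa
llm_config.backend.tosa.version = "TOSA-1.0+INT"  # default from TosaConfig
llm_config.quantization.pt2e_quantize = Pt2eQuantize.tosa_8a8w  # selects the TOSA PT2E quantizer

builder = _export_llama(llm_config)  # an LLMEdgeManager holding the lowered program
```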
1 change: 1 addition & 0 deletions examples/models/llama/tests/TARGETS
@@ -21,6 +21,7 @@ python_unittest(
],
deps = [
"//caffe2:torch",
"//executorch/backends/arm/quantizer:lib",
"//executorch/examples/models/llama:export_library",
"//executorch/examples/models/llama:llama_transformer",
"//pytorch/ao:torchao",
20 changes: 19 additions & 1 deletion examples/models/llama/tests/test_export_llama_lib.py
@@ -1,17 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from executorch.backends.arm.quantizer.arm_quantizer import TOSAQuantizer

from executorch.devtools.backend_debug import get_delegation_info
from executorch.examples.models.llama.export_llama_lib import (
_export_llama,
build_args_parser,
get_quantizer_and_quant_params,
)
from executorch.extension.llm.export.config.llm_config import LlmConfig
from executorch.extension.llm.export.config.llm_config import LlmConfig, Pt2eQuantize

UNWANTED_OPS = [
"aten_permute_copy_default",
@@ -48,3 +52,17 @@ def test_has_expected_ops_and_op_counts(self):

for op, _op_info in delegation_info.delegation_by_operator.items():
self.assertTrue(op not in UNWANTED_OPS)

def test_get_quantizer_and_quant_params_returns_tosa_quantizer(self):
llm_config = LlmConfig()
llm_config.backend.tosa.enabled = True
llm_config.quantization.pt2e_quantize = Pt2eQuantize.tosa_8a8w

pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
llm_config
)

self.assertIsNone(pt2e_quant_params)
self.assertIsNone(quant_dtype)
self.assertEqual(len(quantizers), 1)
self.assertIsInstance(quantizers[0], TOSAQuantizer)
5 changes: 4 additions & 1 deletion extension/llm/export/builder.py
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -96,6 +97,7 @@ def __init__(
dynamic_shapes: Optional[Any] = None,
save_exported_program: bool = False,
generate_etrecord: bool = False,
skip_dim_order: bool = True,
):
# Store necessary constructor arguments.
self.model = model
@@ -118,6 +120,7 @@ def __init__(
self.dynamic_shapes = dynamic_shapes
self.save_exported_program = save_exported_program
self.generate_etrecord = generate_etrecord
self.skip_dim_order = skip_dim_order

# Note: treat this as the source of truth for the result of
# torch.export'ing a model. If the overall ExportedProgram is needed,
@@ -197,7 +200,7 @@ def _get_dynamic_shape(self) -> Any:
def _get_edge_config(self) -> EdgeCompileConfig:
edge_config = EdgeCompileConfig(
_check_ir_validity=False,
_skip_dim_order=True,
_skip_dim_order=self.skip_dim_order,
)
return edge_config

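The builder.py change is what the new skip_dim_order handling in _export_llama relies on: _skip_dim_order used to be hard-coded to True, and the new constructor flag lets the TOSA path keep dim-order ops in the edge dialect (presumably what the TOSA partitioner expects). A sketch of the edge config the TOSA path ends up with, assuming the usual executorch.exir import:

```python
from executorch.exir import EdgeCompileConfig

# With skip_dim_order=False (the value _export_llama sets when the TOSA backend
# is enabled), _get_edge_config() now builds:
edge_config = EdgeCompileConfig(
    _check_ir_validity=False,
    _skip_dim_order=False,
)
```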
13 changes: 13 additions & 0 deletions extension/llm/export/config/llm_config.py
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -288,6 +289,7 @@ class Pt2eQuantize(str, Enum):
coreml_baseline_8a_c8w = "coreml_baseline_8a_c8w"
coreml_baseline_8a_c4w = "coreml_baseline_8a_c4w"
vulkan_8w = "vulkan_8w"
tosa_8a8w = "tosa_8a8w"


class SpinQuant(str, Enum):
@@ -474,6 +476,16 @@ class TorchAOKernelsConfig:
use_torchao_kernels_tied_embedding: bool = False


@dataclass
class TosaConfig:
"""
Configures the TOSA backend.
"""

enabled: bool = False
version: str = "TOSA-1.0+INT"


@dataclass
class BackendConfig:
"""
@@ -488,6 +500,7 @@ class BackendConfig:
mps: MPSConfig = field(default_factory=MPSConfig)
openvino: OpenvinoConfig = field(default_factory=OpenvinoConfig)
torchao: TorchAOKernelsConfig = field(default_factory=TorchAOKernelsConfig)
tosa: TosaConfig = field(default_factory=TosaConfig)


################################################################################
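TosaConfig follows the same default_factory pattern as the other backend configs, so the defaults are easy to sanity-check. A short sketch, assuming the config dataclasses instantiate standalone as they do in the unit test above:

```python
from executorch.extension.llm.export.config.llm_config import LlmConfig

cfg = LlmConfig()
assert cfg.backend.tosa.enabled is False            # TOSA stays off unless requested
assert cfg.backend.tosa.version == "TOSA-1.0+INT"   # default TOSA specification string
```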
10 changes: 10 additions & 0 deletions extension/llm/export/partitioner_lib.py
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -236,3 +237,12 @@ def get_qnn_partitioner(
# TODO: if deprecated legacy export, skip_mutable_buffer can be set False
skip_mutable_buffer=True,
)


def get_tosa_partitioner(version: str):
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
from executorch.backends.arm.tosa.partitioner import TOSAPartitioner

compile_spec = TosaCompileSpec(version)

return TOSAPartitioner(compile_spec)
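Like the other helpers in partitioner_lib.py, get_tosa_partitioner defers the Arm backend imports until the backend is actually requested. A minimal usage sketch with the default spec string from TosaConfig:

```python
from executorch.extension.llm.export.partitioner_lib import get_tosa_partitioner

# Builds a TOSAPartitioner from a TosaCompileSpec for the given specification string.
partitioner = get_tosa_partitioner("TOSA-1.0+INT")
```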
20 changes: 20 additions & 0 deletions extension/llm/export/quantizer_lib.py
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -320,3 +321,22 @@ def get_vulkan_quantizer(pt2e_quantize: str):

quantizer = VulkanQuantizer().set_global(config)
return quantizer


def get_tosa_quantizer(version: str, pt2e_quantize: str):
from executorch.backends.arm.quantizer.arm_quantizer import (
get_symmetric_quantization_config,
TOSAQuantizer,
)
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec

compile_spec = TosaCompileSpec(version)

quantizer = TOSAQuantizer(compile_spec)

if pt2e_quantize == "tosa_8a8w":
quantizer.set_global(get_symmetric_quantization_config())
else:
raise ValueError(f"Unsupported quantizer specification {pt2e_quantize}")

return quantizer
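And the matching quantizer helper: only the tosa_8a8w option maps to a quantization config for now; anything else raises ValueError. A minimal sketch:

```python
from executorch.extension.llm.export.quantizer_lib import get_tosa_quantizer

# Symmetric 8-bit activation / 8-bit weight config applied globally on a TOSAQuantizer.
quantizer = get_tosa_quantizer("TOSA-1.0+INT", "tosa_8a8w")
```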