36 changes: 36 additions & 0 deletions backends/qualcomm/quantizer/annotators.py
@@ -503,6 +503,42 @@ def annotate_scalar_tensor(node: Node, quantization_config: QuantizationConfig)
@register_annotator([torch.ops.aten.tanh.default])
def annotate_tanh(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)
    qconfig_16a = get_16a16w_qnn_ptq_config()
    qmax, qmin = (
        qconfig_16a.output_activation.quant_max,
        qconfig_16a.output_activation.quant_min,
    )
    if (
        quantization_config.output_activation.quant_max == qmax
        and quantization_config.output_activation.quant_min == qmin
        and _is_float_tensor(node)
    ):
        scale = 1 / 32768.0
        zero_point = 32768
        if isinstance(
            quantization_config.output_activation.observer_or_fake_quant_ctr,
            torch.ao.quantization.fake_quantize.FakeQuantizeBase,
        ):
            observer_ctr = FixedQParamsFakeQuantize
        else:
            observer_ctr = FixedQParamsObserver
        observer = observer_ctr.with_args(
            scale=scale,
            zero_point=zero_point,
            dtype=quantization_config.output_activation.dtype,
            qscheme=torch.torch.per_tensor_affine,
            quant_max=quantization_config.output_activation.quant_max,
            quant_min=quantization_config.output_activation.quant_min,
        )

        annotate_output_qspec = QuantizationSpec(
            dtype=quantization_config.output_activation.dtype,
            quant_max=quantization_config.output_activation.quant_max,
            quant_min=quantization_config.output_activation.quant_min,
            observer_or_fake_quant_ctr=observer,
            qscheme=torch.torch.per_tensor_affine,
        )
        node.meta["quantization_annotation"].output_qspec = annotate_output_qspec


@register_annotator([torch.ops.aten.full_like.default, torch.ops.aten.full.default])
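For context on the fixed parameters above: with unsigned 16-bit activations, `scale = 1/32768` and `zero_point = 32768` pin the quantization grid to tanh's output range of [-1, 1). Below is a minimal illustrative sketch of the affine mapping these values imply, assuming a quint16-style dtype; it is not part of the change itself.

```python
import torch

# Fixed quantization parameters used for the tanh output annotation above.
SCALE = 1 / 32768.0
ZERO_POINT = 32768


def dequantize(q: torch.Tensor) -> torch.Tensor:
    # Affine dequantization: real = (q - zero_point) * scale.
    return (q.to(torch.float32) - ZERO_POINT) * SCALE


# Integer codes 0, 32768, 65535 map to -1.0, 0.0, and 32767/32768 ~= 0.99997,
# covering tanh's output range with essentially no wasted dynamic range.
codes = torch.tensor([0, 32768, 65535])
print(dequantize(codes))
```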
3 changes: 3 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -6071,6 +6071,9 @@ def setUp(self):
"gemma-2b": TestExampleLLMScript.LlmSpecs(
SM8650=32, SM8750=36, ppl=35, pte_size=2_700_000_000
), # 2.7 GB
"gemma2-2b": TestExampleLLMScript.LlmSpecs(
SM8650=32, SM8750=36, ppl=14, pte_size=2_860_000_000
), # 2.86 GB
"gemma3-1b": TestExampleLLMScript.LlmSpecs(
SM8650=70, SM8750=100, ppl=23, pte_size=1_200_000_000
), # 1.2 GB
3 changes: 3 additions & 0 deletions examples/models/llama/model_args.py
@@ -134,6 +134,9 @@ class ModelArgs:
    model_architecture: Optional[str] = (
        None # Architecture of model. For HF models, please refer to the HF model.config.architectures. This is used in QNN backend only for now.
    )
    # gemma2 attn and output soft capping
    final_logit_softcapping: Optional[float] = None
    attn_logit_softcapping: Optional[float] = None

    def __post_init__(self):
        if self.n_kv_heads is None:
6 changes: 6 additions & 0 deletions examples/qualcomm/oss_scripts/llama/README.md
@@ -11,6 +11,7 @@ This file provides you the instructions to run LLM Decoder model with different
1. LLAMA3.2 3B
1. Codegen2 1B
1. Gemma 2B
1. Gemma2 2B
1. Gemma3 1B
1. GLM 1.5B
1. Granite3.3 2B
@@ -136,6 +137,11 @@ Default example using hybrid mode
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma-2b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
```

#### Gemma2 2B
Default example using hybrid mode
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma2-2b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
```

#### Gemma3 1B
Default example using hybrid mode
23 changes: 23 additions & 0 deletions examples/qualcomm/oss_scripts/llama/__init__.py
@@ -15,6 +15,7 @@
    convert_weights as convert_codegen_weights,
)
from executorch.examples.models.gemma import convert_weights as convert_gemma_weights
from executorch.examples.models.gemma2 import convert_weights as convert_gemma2_weights
from executorch.examples.models.gemma3 import convert_weights as convert_gemma3_weights

from executorch.examples.models.glm import convert_weights as convert_glm_weights
@@ -59,6 +60,7 @@

from executorch.examples.qualcomm.oss_scripts.llama.static_llm_quant_recipe import (
    CodegenQuantRecipe,
    Gemma2QuantRecipe,
    Gemma3QuantRecipe,
    Gemma_2BQuantRecipe,
    GLM_1_5B_InstructQuantRecipe,
@@ -88,6 +90,7 @@
"gemma3-1b": MultiScopeAwareLlamaModel,
"smolvlm_500m_instruct": LlamaModelWithoutEmbedding,
"internvl3_1b": LlamaModelWithoutEmbedding,
"gemma2-2b": MultiScopeAwareLlamaModel,
}


@@ -303,6 +306,26 @@ class Gemma_2B(LLMModelConfig):
    quant_recipe = Gemma_2BQuantRecipe


@register_llm_model("gemma2-2b")
@dataclass(init=False, frozen=True)
class Gemma2(LLMModelConfig):
repo_id: str = "google/gemma-2-2b-it"
params_path: str = os.path.join(
BASE_DIR, "../../../models/gemma2/config/2b_config.json"
)
convert_weights = convert_gemma2_weights
transform_weight = False
instruct_model = True

num_sharding = 4
masked_softmax = True
seq_mse_candidates = 0
r1 = False
r2 = False
r3 = False
quant_recipe = Gemma2QuantRecipe


@register_llm_model("gemma3-1b")
@dataclass(init=False, frozen=True)
class Gemma3(LLMModelConfig):
1 change: 1 addition & 0 deletions examples/qualcomm/oss_scripts/llama/decoder_constants.py
@@ -38,6 +38,7 @@
"llama3_2-3b_instruct": "llama3",
"codegen2_1b": "codegen",
"gemma-2b": "gemma",
"gemma2-2b": "gemma2",
"gemma3-1b": "gemma3",
"granite_3_3-2b_instruct": "granite",
"phi_4_mini": "phi_4_mini",
1 change: 1 addition & 0 deletions examples/qualcomm/oss_scripts/llama/model/layernorm.py
@@ -42,6 +42,7 @@ def __init__(self, hidden_size: int, eps=1e-5):
        super().__init__(hidden_size, eps=eps)


@register_norm("gemma2")
@register_norm("gemma3")
@register_norm("rmsnorm")
class RMSNorm(torch.nn.RMSNorm, Norm):
15 changes: 15 additions & 0 deletions examples/qualcomm/oss_scripts/llama/model/static_llama.py
@@ -102,6 +102,9 @@ def __init__(self, layer_idx: int, config: ModelArgs, output_new_cache_only=Fals
            else 1.0 / config.attention_multiplier
        )

        # gemma 2 uses soft-capping on attention and logits
        self.attn_logit_softcapping = config.attn_logit_softcapping

        if getattr(config, "enable_r3", False):
            self.register_buffer(
                "r3_weight",
@@ -276,6 +279,11 @@ def forward(
        vh = repeat_kv(vh, self.num_key_value_groups)

        attn = q @ kh
        # gemma2-2b
        if self.attn_logit_softcapping is not None:
            attn = attn / self.attn_logit_softcapping
            attn = torch.tanh(attn)
            attn = attn * self.attn_logit_softcapping
        attn = attn / self.scale
        if self.enable_masked_softmax:
            attn_min = torch.amin(attn, dim=-1, keepdim=True)
@@ -742,6 +750,8 @@ def __init__(
        self.sliding_window = kwargs["sliding_window"]
        # Get local freq base for sliding attention
        rope_freq_base = kwargs["rope_local_base_freq"]
        # Parameter final_logit_softcapping is not necessary for all models
        self.final_logit_softcapping = kwargs.get("final_logit_softcapping")

        local_freqs_cos, local_freqs_sin = hf_precompute_freqs_cis(
            config.head_dim,
@@ -817,6 +827,11 @@ def forward(

        hidden_states = self.norm(hidden_states)
        logits = self.output(hidden_states)
        if self.final_logit_softcapping:
            logits = logits / self.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.final_logit_softcapping

        if self.output_cache:
            return logits, output_k_cache, output_v_cache
        return logits
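Both soft-capping sites above apply the same squashing used by Gemma 2: divide by the cap, take tanh, then multiply back by the cap, which smoothly bounds the values to the open interval (-cap, cap). A small standalone sketch of the operation follows; the cap values shown are the commonly published Gemma 2 defaults and are assumed here purely for illustration.

```python
import torch


def soft_cap(x: torch.Tensor, cap: float) -> torch.Tensor:
    # Smoothly bounds x to (-cap, cap), matching the
    # attn_logit_softcapping / final_logit_softcapping blocks above.
    return cap * torch.tanh(x / cap)


# Assumed Gemma 2 defaults: attn_logit_softcapping=50.0, final_logit_softcapping=30.0.
attn_scores = torch.tensor([-200.0, 0.0, 200.0])
print(soft_cap(attn_scores, 50.0))  # approx. [-49.97, 0.00, 49.97]

logits = torch.tensor([-100.0, 5.0, 100.0])
print(soft_cap(logits, 30.0))  # approx. [-29.92, 4.95, 29.92]
```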
12 changes: 9 additions & 3 deletions examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
@@ -9,9 +9,9 @@
/**
* @file
*
* This tool can run Llama2 110M, Llama3.2 1B / 3B, Gemma 2B, Gemma3 1B,
* Granite3.3 2B, phi4-mini-instruct, Qwen2.5 0.5B / 1.5B, Qwen3 0.6B / 1.7B,
* SmolLM2 135M, SmolLM3 3B with Qualcomm AI Engine Direct.
* This tool can run Llama2 110M, Llama3.2 1B / 3B, Gemma 2B, Gemma2 2B, Gemma3
* 1B, Granite3.3 2B, phi4-mini-instruct, Qwen2.5 0.5B / 1.5B, Qwen3 0.6B
* / 1.7B, SmolLM2 135M, SmolLM3 3B with Qualcomm AI Engine Direct.
*
*/

@@ -130,6 +130,12 @@ std::string get_formatted_prompt(
formatted_prompt.append("<end_of_turn>\n");
}
break;
case example::DecoderModelVersion::kGemma2:
formatted_prompt.append("<start_of_turn>user\n");
formatted_prompt.append(prompt);
formatted_prompt.append("<end_of_turn>\n");
formatted_prompt.append("<start_of_turn>model\n");
break;
case example::DecoderModelVersion::kGranite:
if (!system_prompt.empty()) {
formatted_prompt.append("<|start_of_role|>system<|end_of_role|>");
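The `kGemma2` branch above wraps the user prompt in the standard Gemma chat-turn markers. A quick sketch of the string those appends produce, rendered in Python purely for illustration; note that `<end_of_turn>` is also registered as an EOS token for Gemma models in runner.cpp below, so generation stops at the end of the model turn.

```python
prompt = "I would like to learn python, could you teach me with a simple example?"

# Mirrors the appends in the kGemma2 case above.
formatted = (
    "<start_of_turn>user\n"
    f"{prompt}<end_of_turn>\n"
    "<start_of_turn>model\n"
)
print(formatted)
```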
4 changes: 4 additions & 0 deletions examples/qualcomm/oss_scripts/llama/runner/runner.cpp
@@ -118,6 +118,9 @@ Runner<T>::Runner(
    decoder_model_version_ = DecoderModelVersion::kLlama3;
  } else if (decoder_model_version == "gemma") {
    decoder_model_version_ = DecoderModelVersion::kGemma;
  } else if (decoder_model_version == "gemma2") {
    decoder_model_version_ = DecoderModelVersion::kGemma2;
    cache_mode_ = CacheMode::HybridCache;
  } else if (decoder_model_version == "gemma3") {
    decoder_model_version_ = DecoderModelVersion::kGemma3;
    cache_mode_ = CacheMode::HybridCache;
@@ -202,6 +205,7 @@ Error Runner<T>::load() {
    eos_ids->insert(tokenizer_->encode("<|im_end|>", 0, 0).get()[0]);
  } else if (
      decoder_model_version_ == DecoderModelVersion::kGemma ||
      decoder_model_version_ == DecoderModelVersion::kGemma2 ||
      decoder_model_version_ == DecoderModelVersion::kGemma3) {
    eos_ids->insert(tokenizer_->encode("<end_of_turn>", 0, 0).get()[0]);
  } else if (decoder_model_version_ == DecoderModelVersion::kCodegen) {
1 change: 1 addition & 0 deletions examples/qualcomm/oss_scripts/llama/runner/runner.h
@@ -42,6 +42,7 @@ enum DecoderModelVersion {
  kSmollm3,
  kCodegen,
  kGlm,
  kGemma2,
};

enum KvBitWidth {
37 changes: 37 additions & 0 deletions examples/qualcomm/oss_scripts/llama/static_llm_quant_recipe.py
@@ -263,6 +263,43 @@ def __init__(self, verbose: bool = False):
        self.recipe.custom_quant_annotations.append(annotate_kv_8bit)


class Gemma2QuantRecipe(StaticLLMQuantRecipe):
    default_quant_dtype = QuantDtype.use_16a4w

    def __init__(self, verbose: bool = False):
        super().__init__()

        self.recipe = (
            QuantRecipe(
                self.default_quant_dtype,
                False,
                act_observer=MinMaxObserver,
                granularity=QuantGranularity.PER_TENSOR,
                verbose=verbose,
            )
            .add_node_target(
                {
                    torch.ops.aten.conv2d.default,
                },
                QuantDtype.use_16a4w_block,
                False,
                act_observer=MinMaxObserver,
                granularity=QuantGranularity.PER_BLOCK,
                extra_kwargs={"block_size": (1, 32, 1, 1)},
            )
            .add_regex(
                {
                    r"layers\..*\.attention\.wv.*",
                },
                QuantDtype.use_16a8w,
                False,
                act_observer=MinMaxObserver,
                granularity=QuantGranularity.PER_CHANNEL,
            )
        )
        self.recipe.custom_quant_annotations.append(annotate_kv_8bit)


class Gemma3QuantRecipe(StaticLLMQuantRecipe):
    default_quant_dtype = QuantDtype.use_16a4w

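In the recipe above, the `.add_regex` override moves the value-projection weights to 16a8w per-channel while everything else stays at the 16a4w default (with conv2d nodes getting 16a4w block quantization). A quick check of the kind of module names the pattern is meant to capture; the concrete names below are hypothetical and only assume a `layers.<i>.attention.wv*` layout like the static llama model's.

```python
import re

pattern = re.compile(r"layers\..*\.attention\.wv.*")

# Hypothetical module names following a layers.<i>.attention.wv* layout.
names = [
    "layers.0.attention.wv",
    "layers.25.attention.wv_proj",
    "layers.0.attention.wq",
    "output",
]
print([n for n in names if pattern.search(n)])
# ['layers.0.attention.wv', 'layers.25.attention.wv_proj']
```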
40 changes: 29 additions & 11 deletions examples/qualcomm/oss_scripts/llama/wrappers.py
@@ -188,16 +188,6 @@ def __init__(
            get_passes_dependency_for_capture_program() if apply_embedding else None
        )

        # check if sharding required
        if self.config.num_sharding > 1:
            SplitGraph, setting = model_sharding.get_split_graph_pass(
                self.meta["get_n_layers"],
                shares=self.config.num_sharding,
            )
            self.passes_job[SplitGraph] = setting
            self.dep_table[SplitGraph] = [FoldQDQ]
            self.dep_table[TagQuantIO] = [SplitGraph]

        # load static llama model args
        params_path = (
            config.params_path if control_args.params is None else control_args.params
@@ -210,6 +200,17 @@ def __init__(
        self.decoder = None
        if (instance := self._prepare_model()) is not None:
            self.tok_embedding, self.decoder = instance
            self.meta = self.decoder.get_metadata()

        # check if sharding required
        if instance and self.config.num_sharding > 1:
            SplitGraph, setting = model_sharding.get_split_graph_pass(
                self.meta["get_n_layers"],
                shares=self.config.num_sharding,
            )
            self.passes_job[SplitGraph] = setting
            self.dep_table[SplitGraph] = [FoldQDQ]
            self.dep_table[TagQuantIO] = [SplitGraph]

    def _process_model_args(self, model_args: ModelArgs):
        # TODO: support batch inputs if necessary
@@ -246,7 +247,11 @@ def _prepare_model(self): # noqa: C901
        state_dict = torch.load(
            checkpoint, weights_only=True, map_location="cpu", mmap=True
        )
        if self.control_args.decoder_model in {"gemma-2b", "gemma3-1b"}:
        if self.control_args.decoder_model in {
            "gemma-2b",
            "gemma2-2b",
            "gemma3-1b",
        }:
            for k, v in state_dict.items():
                if "norm" not in k:
                    continue
@@ -395,6 +400,19 @@ def _get_model_specific_kwargs(self):
                    hf_config.text_config.rope_local_base_freq
                )
                kwargs["sliding_window"] = hf_config.sliding_window
            case "gemma2-2b":
                from transformers import Gemma2Config

                hf_config = Gemma2Config.from_pretrained(self.config.repo_id)
                kwargs["layer_types"] = hf_config.layer_types
                kwargs["rope_local_base_freq"] = hf_config.rope_parameters[
                    "rope_theta"
                ]
                kwargs["sliding_window"] = hf_config.sliding_window
                kwargs["final_logit_softcapping"] = (
                    hf_config.final_logit_softcapping
                )
                kwargs["attn_logit_softcapping"] = hf_config.attn_logit_softcapping

        return kwargs
