def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> bool:
    """Load MTP weights from separate safetensors files if needed (e.g., GLM-4.7).

    Some models store additional layers (e.g., Multi-Token-Prediction heads) in
    separate safetensors files with non-standard names (e.g., ``mtp.safetensors``).
    HuggingFace's ``from_pretrained()`` may not load these files even though they
    are referenced in ``model.safetensors.index.json``.

    This function detects such cases and explicitly loads the missing weights.

    Args:
        model: The loaded model that may be missing weights.
        model_path: Path to the local model checkpoint directory.

    Returns:
        True if any additional weights were loaded, False otherwise.
    """
    ckpt_dir = Path(model_path)
    index_file = ckpt_dir / "model.safetensors.index.json"

    # Single-file checkpoints (or non-local paths) have no shard index: nothing to do.
    if not index_file.exists():
        return False

    with open(index_file) as f:
        index = json.load(f)

    # Maps parameter name -> safetensors shard filename.
    weight_map = index.get("weight_map", {})

    # Shards that do not follow the standard model-XXXXX-of-XXXXX.safetensors
    # naming scheme are the ones from_pretrained() may have skipped.
    standard_pattern = re.compile(r"model-\d{5}-of-\d{5}\.safetensors")
    non_standard_files = {f for f in weight_map.values() if not standard_pattern.fullmatch(f)}

    if not non_standard_files:
        return False

    model_state = model.state_dict()
    total_loaded = 0

    # Sorted for deterministic load order / log output.
    for filename in sorted(non_standard_files):
        filepath = ckpt_dir / filename
        if not filepath.exists():
            continue

        # Keys the index says live in this shard but that the model lacks.
        expected_keys = {k for k, v in weight_map.items() if v == filename}
        missing_keys = {k for k in expected_keys if k not in model_state}
        if not missing_keys:
            continue

        print(f"Loading {len(missing_keys)} missing weights from {filename}...")

        weights = load_file(str(filepath))
        weights_to_load = {k: v for k, v in weights.items() if k in missing_keys}

        # strict=False: we intentionally load only a subset of the model's keys.
        # NOTE: with strict=False, load_state_dict's returned missing-key list
        # contains *every* model key absent from weights_to_load, so it cannot be
        # used directly for this warning — compare against this shard's expected
        # keys instead (this fixes a warning that previously fired on every call).
        model.load_state_dict(weights_to_load, strict=False)
        total_loaded += len(weights_to_load)

        unresolved = missing_keys - set(weights_to_load)
        if unresolved:
            print(f"  Warning: {len(unresolved)} keys still missing after loading {filename}")

    if total_loaded > 0:
        print(f"✓ Successfully loaded {total_loaded} weights from non-standard safetensors files")
        return True

    return False
torch.bfloat16 @@ -409,6 +484,9 @@ def get_model( if device == "cuda" and not is_model_on_gpu(model): print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM") + # Load any missing weights from non-standard safetensors files (e.g., GLM-4.7's mtp.safetensors) + load_mtp_weights_if_needed(model, ckpt_path) + return model diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index a9862a742..32424f445 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -282,6 +282,11 @@ def load_model(args: argparse.Namespace): ) calibration_only = True + # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) + from example_utils import load_mtp_weights_if_needed + + load_mtp_weights_if_needed(full_model, args.pyt_ckpt_path) + model_type = get_model_type(full_model) device = full_model.device From 72ae28a05e0542fcd589719678efebe22aee0e38 Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Thu, 15 Jan 2026 23:06:44 -0800 Subject: [PATCH 2/4] add support for MTP loading for GLM-4.7 in PTQ flow Signed-off-by: Zhiyu Cheng --- examples/llm_ptq/example_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 0d10f3a8b..866e30ae5 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -292,6 +292,7 @@ def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> bool: # Find non-standard shard files (not matching model-XXXXX-of-XXXXX.safetensors pattern) import re + standard_pattern = re.compile(r"model-\d{5}-of-\d{5}\.safetensors") non_standard_files = [f for f in all_files if not standard_pattern.match(f)] From ee965fa3c0f72cb57f2e655d0d64a7595a1522e5 Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Thu, 15 Jan 2026 23:08:21 -0800 Subject: [PATCH 3/4] add support for MTP loading for GLM-4.7 in PTQ flow Signed-off-by: Zhiyu Cheng --- 
examples/llm_ptq/example_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 866e30ae5..9e68b935e 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -15,7 +15,9 @@ import copy import glob +import json import os +import re import shutil import sys import warnings @@ -26,6 +28,7 @@ import transformers from accelerate import infer_auto_device_map, init_empty_weights from accelerate.utils import get_max_memory +from safetensors.torch import load_file from transformers import ( AutoConfig, AutoModelForCausalLM, From 3842c634d5eeec60a8ff18015652219396f1acdd Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Thu, 15 Jan 2026 23:08:47 -0800 Subject: [PATCH 4/4] add support for MTP loading for GLM-4.7 in PTQ flow Signed-off-by: Zhiyu Cheng --- examples/llm_ptq/example_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 9e68b935e..c34423478 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -294,7 +294,6 @@ def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> bool: all_files = set(index["weight_map"].values()) # Find non-standard shard files (not matching model-XXXXX-of-XXXXX.safetensors pattern) - import re standard_pattern = re.compile(r"model-\d{5}-of-\d{5}\.safetensors") non_standard_files = [f for f in all_files if not standard_pattern.match(f)]