81 changes: 81 additions & 0 deletions examples/llm_ptq/example_utils.py
@@ -15,7 +15,9 @@

import copy
import glob
import json
import os
import re
import shutil
import sys
import warnings
@@ -26,6 +28,7 @@
import transformers
from accelerate import infer_auto_device_map, init_empty_weights
from accelerate.utils import get_max_memory
from safetensors.torch import load_file
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
@@ -261,6 +264,81 @@ def get_processor(
    return None


def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> bool:
"""Load MTP weights from separate safetensors if needed (e.g., GLM-4.7).

Some models store additional layers in separate safetensors files with non-standard
names (e.g., mtp.safetensors). HuggingFace's from_pretrained() may not load these
files even though they're referenced in model.safetensors.index.json.

This function detects such cases and explicitly loads the missing weights.

Args:
model: The loaded model that may be missing weights
model_path: Path to the model directory

Returns:
True if additional weights were loaded, False otherwise
"""
    model_path = Path(model_path)
    index_file = model_path / "model.safetensors.index.json"

    if not index_file.exists():
        return False

    # Load the index to find all referenced safetensors files
    with open(index_file) as f:
        index = json.load(f)

    # Find all unique safetensors files referenced
    all_files = set(index["weight_map"].values())

    # Find non-standard shard files (not matching the model-XXXXX-of-XXXXX.safetensors pattern)
    standard_pattern = re.compile(r"model-\d{5}-of-\d{5}\.safetensors")
    non_standard_files = [f for f in all_files if not standard_pattern.match(f)]
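    # For example, "model-00001-of-00003.safetensors" matches the standard shard pattern
    # above, while a name like "mtp.safetensors" does not and is treated as non-standard.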

    if not non_standard_files:
        return False

    # Check which non-standard files exist and have missing weights
    model_state = model.state_dict()
    total_loaded = 0

    for filename in non_standard_files:
        filepath = model_path / filename
        if not filepath.exists():
            continue

        # Find keys that should be in this file
        expected_keys = [k for k, v in index["weight_map"].items() if v == filename]

        # Check which are missing from the model
        missing_keys = [k for k in expected_keys if k not in model_state]

        if not missing_keys:
            continue

        print(f"Loading {len(missing_keys)} missing weights from {filename}...")

        # Load the weights
        weights = load_file(str(filepath))
        weights_to_load = {k: v for k, v in weights.items() if k in missing_keys}

        # Load into the model. With strict=False, load_state_dict() reports every
        # model key absent from weights_to_load as "missing", so only check the
        # keys we expected to find in this particular file.
        model.load_state_dict(weights_to_load, strict=False)
        total_loaded += len(weights_to_load)

        still_missing = [k for k in missing_keys if k not in weights_to_load]
        if still_missing:
            print(f"  Warning: {len(still_missing)} keys still missing after loading {filename}")

    if total_loaded > 0:
        print(f"✓ Successfully loaded {total_loaded} weights from non-standard safetensors files")
        return True

    return False


def get_dtype(dtype):
if dtype == "bf16":
dtype = torch.bfloat16
@@ -409,6 +487,9 @@ def get_model(
if device == "cuda" and not is_model_on_gpu(model):
print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")

# Load any missing weights from non-standard safetensors files (e.g., GLM-4.7's mtp.safetensors)
load_mtp_weights_if_needed(model, ckpt_path)

return model


5 changes: 5 additions & 0 deletions examples/llm_ptq/hf_ptq.py
@@ -282,6 +282,11 @@ def load_model(args: argparse.Namespace):
        )
        calibration_only = True

    # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode)
    from example_utils import load_mtp_weights_if_needed

    load_mtp_weights_if_needed(full_model, args.pyt_ckpt_path)

    model_type = get_model_type(full_model)

    device = full_model.device