12 changes: 0 additions & 12 deletions examples/models/llama/export_llama_lib.py
@@ -229,12 +229,6 @@ def build_args_parser() -> argparse.ArgumentParser:
help="Path to the checkpoint .pth file. When not provided, the model will be initialized with random weights.",
)

parser.add_argument(
"--checkpoint_dir",
default=None,
help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
)

parser.add_argument(
"--adapter_checkpoint",
required=False,
@@ -678,18 +672,12 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
if llm_config.base.checkpoint
else None
)
checkpoint_dir = (
canonical_path(llm_config.base.checkpoint_dir)
if llm_config.base.checkpoint_dir
else None
)
params_path = (
canonical_path(llm_config.base.params) if llm_config.base.params else None
)
output_dir_path = canonical_path(llm_config.export.output_dir, dir=True)

llm_config.base.checkpoint = checkpoint_path
llm_config.base.checkpoint_dir = checkpoint_dir
llm_config.base.params = params_path
llm_config.export.output_dir = output_dir_path

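Note on the export path: after this change, `_prepare_for_llama_export` only canonicalizes a single `base.checkpoint` (plus `base.params` and `export.output_dir`); there is no longer a `checkpoint_dir` to resolve. A minimal sketch of driving the exporter config programmatically with one merged checkpoint, assuming only the field names visible in this diff (the file paths are placeholders):

```python
# Minimal sketch, not part of this PR: configure the export with a single
# merged checkpoint now that base.checkpoint_dir is gone. Field names come
# from this diff; the paths below are placeholders.
from executorch.extension.llm.export.config.llm_config import LlmConfig

llm_config = LlmConfig()
llm_config.base.checkpoint = "/models/llama/consolidated_merged.pth"  # single .pth
llm_config.base.params = "/models/llama/params.json"
llm_config.export.output_dir = "./export_out"
# llm_config would then be handed to _prepare_for_llama_export(llm_config).
```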
38 changes: 2 additions & 36 deletions examples/models/llama/model.py
@@ -7,8 +7,7 @@
# pyre-unsafe

import json
import os
from typing import Dict, Optional, Tuple
from typing import Optional

import torch
from executorch.examples.models.checkpoint import (
@@ -18,7 +17,6 @@

from executorch.examples.models.llama.llama_transformer import construct_transformer
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope

from executorch.extension.llm.export.config.llm_config import LlmConfig
from torchao.utils import TorchAOBaseTensor
@@ -39,12 +37,9 @@ def convert_to_llama_checkpoint(**kwargs):

class Llama2Model(EagerModelBase):
def __init__(self, llm_config: Optional[LlmConfig] = None):
resource_dir = get_default_model_resource_dir(__file__)

self.llm_config = llm_config if llm_config else LlmConfig()

checkpoint_path = self.llm_config.base.checkpoint
checkpoint_dir = self.llm_config.base.checkpoint_dir
params_path = self.llm_config.base.params

# Adapter checkpoint and config.
@@ -72,37 +67,8 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
# Follow the instruction in https://github.com/facebookresearch/llama to download the model.
device = "cpu"
# flake8: noqa: TOR102
cps = []
# Load sharded checkpoint.
checkpoint = {}
if checkpoint_dir is not None:
# Load multiple checkpoint; ignore the single path.
checkpoint_path = None
for i in range(4):
cp_name = f"consolidated.{i}.pth"
print(f"Loading {cp_name}")
cps.append(
torch.load(
os.path.join(checkpoint_dir, cp_name),
map_location=device,
mmap=True,
)
)
checkpoint = {}
for key in cps[0].keys():
if not torch.allclose(cps[0][key], cps[1][key]):
values = (cps[0][key], cps[1][key], cps[2][key], cps[3][key])
if "wo" in key or "w2" in key:
# Concat on dim=1 for "wo" and "w2".
checkpoint[key] = torch.cat(values, dim=1)
else:
# Concat on dim=0 for everything else.
checkpoint[key] = torch.cat(values, dim=0)
else:
# Do not duplicate layers shared between each checkpoint.
checkpoint[key] = cps[0][key]
# Load single checkpoint.
elif checkpoint_path:
if checkpoint_path:
checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)

# If given checkpoint is fairseq, convert to llama checkpoint.
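Since the in-tree sharded loader is removed here, users holding sharded Llama checkpoints would need to merge the `consolidated.{i}.pth` shards into a single file before export. A standalone sketch that mirrors the concat rules deleted above (dim=1 for `wo`/`w2`, dim=0 otherwise, replicated tensors copied once); the four-shard default and file naming follow the deleted loop, and anything beyond that is an assumption:

```python
# Standalone sketch (not part of this PR): merge sharded Llama checkpoints into
# a single .pth that can be passed to the exporter via base.checkpoint.
import os

import torch


def merge_sharded_checkpoint(checkpoint_dir: str, out_path: str, num_shards: int = 4) -> None:
    shards = [
        torch.load(
            os.path.join(checkpoint_dir, f"consolidated.{i}.pth"),
            map_location="cpu",
            mmap=True,
        )
        for i in range(num_shards)
    ]
    if num_shards == 1:
        torch.save(shards[0], out_path)
        return

    merged = {}
    for key in shards[0].keys():
        if torch.allclose(shards[0][key], shards[1][key]):
            # Tensors replicated across shards are copied once.
            merged[key] = shards[0][key]
        elif "wo" in key or "w2" in key:
            # Concat on dim=1 for "wo" and "w2", as in the removed loader.
            merged[key] = torch.cat([s[key] for s in shards], dim=1)
        else:
            # Concat on dim=0 for everything else.
            merged[key] = torch.cat([s[key] for s in shards], dim=0)
    torch.save(merged, out_path)
```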
4 changes: 0 additions & 4 deletions extension/llm/export/config/llm_config.py
@@ -76,7 +76,6 @@ class BaseConfig:
If left empty, the model will either be initialized with random weights
if it is a Llama model or the weights will be downloaded from HuggingFace
if it is a non-Llama model.
checkpoint_dir: Path to directory containing sharded checkpoint files.
adapter_checkpoint: Path to the adapter.pt file from torchtune. Used if
the model has trained LoRA adapters. Must provide
adapter_config.json.
@@ -99,7 +98,6 @@ class BaseConfig:
model_class: ModelType = ModelType.llama3
params: Optional[str] = None
checkpoint: Optional[str] = None
checkpoint_dir: Optional[str] = None
adapter_checkpoint: Optional[str] = None
adapter_config: Optional[str] = None
tokenizer_path: Optional[str] = None
@@ -527,8 +525,6 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
llm_config.base.params = args.params
if hasattr(args, "checkpoint"):
llm_config.base.checkpoint = args.checkpoint
if hasattr(args, "checkpoint_dir"):
llm_config.base.checkpoint_dir = args.checkpoint_dir
if hasattr(args, "adapter_checkpoint"):
llm_config.base.adapter_checkpoint = args.adapter_checkpoint
if hasattr(args, "adapter_config"):
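With `checkpoint_dir` dropped from both `BaseConfig` and `from_args`, the CLI-to-config mapping for weights reduces to the single `--checkpoint` flag. A hedged sketch of that round trip, assuming `from_args` tolerates a namespace that only carries the flags shown here (the `hasattr` guards above suggest it does):

```python
# Sketch only: exercise the remaining checkpoint mapping in LlmConfig.from_args.
# The --checkpoint/--params flags appear in this diff; everything else is assumed.
import argparse

from executorch.extension.llm.export.config.llm_config import LlmConfig

parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", default=None)
parser.add_argument("--params", default=None)
args = parser.parse_args(["--checkpoint", "merged.pth", "--params", "params.json"])

llm_config = LlmConfig.from_args(args)
assert llm_config.base.checkpoint == "merged.pth"
assert llm_config.base.params == "params.json"
```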