17 changes: 15 additions & 2 deletions megatron/core/inference/batch_dimensions_utils.py
@@ -183,6 +183,8 @@ def adjust_batch_dims_for_expert_parallelism(
local_batch_dims.token_count,
int(is_non_decode),
int(has_explicit_chunked_prefill_req),
local_batch_dims.prefill_req_count,
local_batch_dims.decode_req_count,
],
dtype=torch.int32,
device=torch.cuda.current_device(),
@@ -208,10 +210,21 @@ def adjust_batch_dims_for_expert_parallelism(
return None # indicate no match, run in eager mode

assert not has_explicit_chunked_prefill_req

# If strict matching is enabled, we sync the request counts across EP ranks
# to ensure the graph captures the maximum needed capacity.
# TODO(ksanthanam): Add functional test for this scenario
adjusted_prefill_req_count = (
int(sync_tensor[3].item()) if strict else local_batch_dims.prefill_req_count
)
adjusted_decode_req_count = (
int(sync_tensor[4].item()) if strict else local_batch_dims.decode_req_count
)

adjusted_batch_dim = InferenceBatchDimensions(
token_count=int(sync_tensor[0].item()),
prefill_req_count=local_batch_dims.prefill_req_count,
decode_req_count=local_batch_dims.decode_req_count,
prefill_req_count=adjusted_prefill_req_count,
decode_req_count=adjusted_decode_req_count,
has_explicit_chunked_prefill_req=False,
)
return adjusted_batch_dim
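The hunk above assumes `sync_tensor` has already been max-reduced across the expert-parallel group, so indices 3 and 4 hold the largest prefill and decode request counts seen by any rank. A minimal sketch of that synchronization pattern, assuming a `torch.distributed` expert-parallel process group named `ep_group` (hypothetical) and the same field order as the diff; the real logic lives in `adjust_batch_dims_for_expert_parallelism`:

```python
import torch
import torch.distributed as dist


def sync_batch_dims_across_ep(token_count, is_non_decode, has_chunked_prefill,
                              prefill_req_count, decode_req_count, ep_group, strict):
    """Max-reduce local batch dimensions so every EP rank sizes its CUDA graph
    for the largest batch seen anywhere in the group."""
    sync_tensor = torch.tensor(
        [token_count, int(is_non_decode), int(has_chunked_prefill),
         prefill_req_count, decode_req_count],
        dtype=torch.int32,
        device=torch.cuda.current_device(),
    )
    dist.all_reduce(sync_tensor, op=dist.ReduceOp.MAX, group=ep_group)
    # Strict matching adopts the synced request counts; otherwise keep the local ones.
    adjusted_prefill = int(sync_tensor[3].item()) if strict else prefill_req_count
    adjusted_decode = int(sync_tensor[4].item()) if strict else decode_req_count
    return int(sync_tensor[0].item()), adjusted_prefill, adjusted_decode
```

Using a max-reduction (rather than, say, an average) is what "captures the maximum needed capacity": every rank picks a graph shape large enough for the busiest rank in the group.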
4 changes: 4 additions & 0 deletions megatron/core/inference/engines/dynamic_engine.py
@@ -1652,7 +1652,11 @@ async def run_engine_with_coordinator(
if ep_group_has_work and local_pending_requests == 0:
# run dummy forward pass if EP group as a whole has work,
# but this rank does not have any work.
self.step_start_event.record()
self.controller.dummy_forward()
self.step_end_event.record()
self.step_end_event.synchronize()
self.step_count += 1
continue

# 3. No work in EP group
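The added lines wrap the dummy forward in the same CUDA-event bookkeeping the engine uses for real steps, so per-step timing and `step_count` stay consistent even when this rank only runs to keep expert-parallel collectives aligned. A minimal sketch of that pattern, assuming pre-created `torch.cuda.Event` objects with timing enabled; `controller.dummy_forward()` is taken from the diff, the wrapper function is illustrative:

```python
import torch

step_start_event = torch.cuda.Event(enable_timing=True)
step_end_event = torch.cuda.Event(enable_timing=True)


def run_dummy_step(controller):
    """Keep this rank in lock-step with its EP group and time the step."""
    step_start_event.record()      # mark step start on the current stream
    controller.dummy_forward()     # participate in EP collectives with no real requests
    step_end_event.record()        # mark step end
    step_end_event.synchronize()   # block until the step has finished on the GPU
    return step_start_event.elapsed_time(step_end_event)  # elapsed time in milliseconds
```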
31 changes: 31 additions & 0 deletions megatron/core/ssm/mamba_block.py
@@ -202,6 +202,37 @@ def mamba_state_shapes_per_request(self) -> Optional[Tuple[Tuple[int], Tuple[int
return layer.mamba_state_shapes_per_request()
return None

def _should_call_local_cudagraph(self, *args, **kwargs):
"""
Check if we should call the local cudagraph path.
"""
if not self.training and (
hasattr(self, 'cudagraph_manager')
and kwargs['attention_mask'] is None
and (
kwargs.get('inference_context') is not None
or kwargs.get('inference_params') is not None
)
and CudaGraphScope.full_iteration in self.config.cuda_graph_scope
):
if kwargs['inference_context'].is_static_batching():
using_cuda_graph = kwargs['inference_context'].is_decode_only()
else:
using_cuda_graph = kwargs['inference_context'].using_cuda_graph_this_step()

if using_cuda_graph:
return True
return False

def __call__(self, *args, **kwargs):
if self._should_call_local_cudagraph(*args, **kwargs):
kwargs['hidden_states'] = (
kwargs['hidden_states'].unwrap()
if isinstance(kwargs['hidden_states'], WrappedTensor)
else kwargs['hidden_states']
)
return super().__call__(*args, **kwargs)

def forward(
self,
hidden_states: Union[Tensor, WrappedTensor],
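The new `__call__` override unwraps `hidden_states` when it arrives as a `WrappedTensor` and routes the step through the block-level CUDA-graph path whenever `_should_call_local_cudagraph` returns True. A condensed, standalone sketch of that decision logic, assuming the `CudaGraphScope` enum and the inference-context methods referenced in the diff; the free function and its use of `dict.get` for the optional keys are illustrative, not the exact in-tree code:

```python
from megatron.core.transformer.enums import CudaGraphScope


def uses_full_iteration_graph(module, kwargs):
    """Rough restatement of MambaStack._should_call_local_cudagraph as a free function."""
    if module.training or not hasattr(module, 'cudagraph_manager'):
        return False
    if kwargs.get('attention_mask') is not None:
        return False
    context = kwargs.get('inference_context') or kwargs.get('inference_params')
    if context is None:
        return False
    if CudaGraphScope.full_iteration not in module.config.cuda_graph_scope:
        return False
    if context.is_static_batching():
        # Static batching only replays a captured graph for decode-only steps.
        return context.is_decode_only()
    # Dynamic batching lets the context decide per step.
    return context.using_cuda_graph_this_step()
```

This complements the change in `mamba_layer.py` below, where the per-layer CUDA-graph path is skipped when `CudaGraphScope.full_iteration` is enabled, keeping the two capture levels mutually exclusive.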
2 changes: 2 additions & 0 deletions megatron/core/ssm/mamba_layer.py
@@ -16,6 +16,7 @@
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.module import GraphableMegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
@@ -184,6 +185,7 @@ def _should_call_local_cudagraph(self, *args, **kwargs):
hasattr(self, 'cudagraph_manager')
and kwargs.get('attention_mask') is None
and kwargs.get('inference_context') is not None
and CudaGraphScope.full_iteration not in self.config.cuda_graph_scope
):
using_cuda_graph = kwargs['inference_context'].using_cuda_graph_this_step()
return using_cuda_graph
10 changes: 0 additions & 10 deletions megatron/core/transformer/transformer_block.py
@@ -597,16 +597,6 @@ def __call__(self, *args, **kwargs):
if isinstance(kwargs['hidden_states'], WrappedTensor)
else kwargs['hidden_states']
)
# dynamic_inference_decode_only is not a real argument to forward, it is only used
# to differentiate the cuda graph used for decode from the one used for non-decode
# inference.
dynamic_inference_decode_only = kwargs['inference_context'].is_decode_only()
# cudagraphmanager returns a singleton tuple, whereas the
# normal forward returns a tensor, therefore we need
# to extract the tensor from the tuple
return super().__call__(
*args, dynamic_inference_decode_only=dynamic_inference_decode_only, **kwargs
)[0]
return super().__call__(*args, **kwargs)

def forward(
16 changes: 0 additions & 16 deletions megatron/core/transformer/transformer_layer.py
@@ -483,10 +483,6 @@ def forward(self, *args, **kwargs):
This method calls the core computation of a transformer layer, including
self-attention, cross-attention (if applicable), and feed-forward operations.
"""
# Remove 'dynamic_inference_decode_only' from kwargs if present
# this is only used to uniquely identify decode and non-decode cuda graph
# runners in the cuda graph manager
kwargs.pop("dynamic_inference_decode_only", None)
hidden_states, context = self._forward_attention(*args, **kwargs)
output = self._forward_mlp(hidden_states, kwargs.get("inference_context", None))
return output, context
@@ -1156,18 +1152,6 @@ def _should_call_local_cudagraph(self, *args, **kwargs):
return True
return False

def __call__(self, *args, **kwargs):
if self._should_call_local_cudagraph(*args, **kwargs):
# Inference mode.
if kwargs.get('inference_context') is not None:
# dynamic_inference_decode_only is not a real argument to forward, it is only used
# to differentiate the cuda graph used for decode from the one used for non-decode
# inference.
kwargs["dynamic_inference_decode_only"] = kwargs[
'inference_context'
].is_decode_only()
return super().__call__(*args, **kwargs)

def get_layer_norm_weights(self):
"""
Get the weights of all layernorms (attention and MLP) in the transformer layer.