Performance Optimization Suggestions for DFlash Decoding
After analyzing the DFlash codebase, I've identified several optimization opportunities that could significantly improve decoding speed. Here are the key areas for improvement:
1. Optimized Concatenation Operations
Current Issue: In Qwen3DFlashAttention.forward(), there are inefficient torch.cat operations followed by view and transpose.
Proposed Solution: Pre-allocate tensors and use direct assignment:
def forward_optimized(self, hidden_states, target_hidden, ...):
    bsz, q_len = hidden_states.shape[:-1]
    ctx_len = target_hidden.shape[1]
    # Pre-allocate tensors
    total_len = ctx_len + q_len
    k = torch.empty((bsz, total_len, self.num_key_value_heads, self.head_dim),
                    device=hidden_states.device, dtype=hidden_states.dtype)
    v = torch.empty_like(k)
    # Direct assignment instead of concat
    k[:, :ctx_len] = self.k_proj(target_hidden).view(bsz, ctx_len, -1, self.head_dim)
    k[:, ctx_len:] = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim)
    v[:, :ctx_len] = self.v_proj(target_hidden).view(bsz, ctx_len, -1, self.head_dim)
    v[:, ctx_len:] = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim)
2. Embedding Cache
Current Issue: Embeddings are recalculated for each block in spec_generate().
Proposed Solution: Implement an embedding cache:
class EmbeddingCache:
    def __init__(self, embed_fn, max_size=1000):
        self.embed_fn = embed_fn
        self.cache = {}
        self.max_size = max_size

    def __call__(self, tokens):
        # Key on the token values themselves; keying on tokens.data_ptr() is
        # unreliable because freed storage can be reused by a different tensor.
        key = tuple(tokens.flatten().tolist())
        if key not in self.cache:
            if len(self.cache) >= self.max_size:
                self.cache.clear()  # simple eviction: drop everything
            self.cache[key] = self.embed_fn(tokens)
        return self.cache[key]

# Usage
embedding_cache = EmbeddingCache(target.model.embed_tokens)
3. Optimized Sampling
Current Issue: Sampling operations are repeated per draft position instead of being batched.
Proposed Solution: Batch sampling for better performance:
import torch

def batch_sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor:
    # Greedy decoding when temperature is effectively zero
    if temperature < 1e-5:
        return torch.argmax(logits, dim=-1)
    bsz, seq_len, vocab_size = logits.shape
    logits = logits.view(-1, vocab_size) / temperature
    probs = torch.softmax(logits, dim=-1)
    # Draw one token per flattened (batch, position) row
    samples = torch.multinomial(probs, num_samples=1)
    return samples.view(bsz, seq_len)
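For instance, called on a dummy block of draft logits (shapes here are illustrative, not DFlash's actual dimensions):
logits = torch.randn(2, 8, 32000)           # (batch, draft_len, vocab)
tokens = batch_sample(logits, temperature=0.8)
print(tokens.shape)                          # torch.Size([2, 8])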
4. Cache Management Optimization
Current Issue: Frequent crop operations in cache management.
Proposed Solution: Reduce crop frequency:
def optimized_cache_strategy(past_key_values_draft, start, block_size):
    # Crop only when necessary (every N blocks) instead of on every block
    if start % (block_size * 4) == 0:
        past_key_values_draft.crop(start)
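A rough sketch of where this could plug into a block-by-block decoding loop; the loop and num_blocks below are illustrative, not DFlash's actual spec_generate structure:
for block_idx in range(num_blocks):          # num_blocks is hypothetical
    start = block_idx * block_size
    # ... draft and verify one block here ...
    optimized_cache_strategy(past_key_values_draft, start, block_size)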
5. Tensor Parallelism
Current Issue: Independent operations in the decoding path run sequentially.
Proposed Solution: Parallelize independent operations:
def parallel_attention_forward(q, k, v, attention_mask=None):
    # Launch the attention math on a side CUDA stream so independent work
    # can overlap with it on the default stream.
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())  # q/k/v must be ready first
    with torch.cuda.stream(stream):
        # Standard scaled dot-product attention
        qk = torch.matmul(q, k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
        if attention_mask is not None:
            qk = qk + attention_mask
        attn_weights = torch.softmax(qk, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
    # The default stream must wait before consuming the outputs.
    torch.cuda.current_stream().wait_stream(stream)
    return attn_output, attn_weights
6. Kernel Fusion
Current Issue: Multiple separate projection operations (Q, K and V each run as their own matmul).
Proposed Solution: Create fused kernels for common operations:
import torch
import torch.nn as nn

class FusedQKVProjection(nn.Module):
    def __init__(self, hidden_size, num_heads, head_dim):
        super().__init__()
        # One linear layer produces Q, K and V in a single GEMM
        # (for GQA, K/V would use a separate num_key_value_heads)
        self.qkv_proj = nn.Linear(hidden_size, 3 * num_heads * head_dim)
        self.num_heads = num_heads
        self.head_dim = head_dim

    def forward(self, x):
        qkv = self.qkv_proj(x)
        bsz, seq_len, _ = qkv.shape
        qkv = qkv.view(bsz, seq_len, 3 * self.num_heads, self.head_dim)
        # Carve out Q, K, V and move heads in front: (bsz, heads, seq, head_dim)
        q = qkv[:, :, :self.num_heads].transpose(1, 2)
        k = qkv[:, :, self.num_heads:2 * self.num_heads].transpose(1, 2)
        v = qkv[:, :, 2 * self.num_heads:].transpose(1, 2)
        return q, k, v
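Example usage with illustrative sizes (not DFlash's actual configuration), mainly to show the resulting shapes:
proj = FusedQKVProjection(hidden_size=1024, num_heads=16, head_dim=64)
x = torch.randn(2, 128, 1024)                # (batch, seq_len, hidden)
q, k, v = proj(x)
print(q.shape, k.shape, v.shape)             # each (2, 16, 128, 64)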
7. Additional Optimizations
Memory Layout Optimization:
- Use torch.contiguous() to ensure optimal memory access patterns (see the sketch below)
- Consider channel-first memory layout when beneficial
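A minimal sketch of the contiguous() point, on an illustrative tensor rather than anything from DFlash: a transpose produces a strided view, and materializing it with .contiguous() restores dense memory access before subsequent heavy ops:
import torch

x = torch.randn(8, 1024, 64)
x_t = x.transpose(1, 2)        # strided (non-contiguous) view
print(x_t.is_contiguous())     # False
x_t = x_t.contiguous()         # densely packed copy
print(x_t.is_contiguous())     # True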
Quantization:
- Implement mixed-precision training/inference (a sketch follows this list)
- Explore INT8 quantization for specific operations
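A hedged sketch of the mixed-precision idea using torch.autocast; model and input_ids are placeholders for whatever the target/draft forward pass looks like, not DFlash code:
import torch

@torch.no_grad()
def decode_step_fp16(model, input_ids):
    # Run the forward pass in half precision where safe; autocast keeps
    # numerically sensitive ops in float32.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        return model(input_ids).logits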
CUDA Kernels:
- Custom CUDA kernels for bottleneck operations
- Optimized memory access patterns
Expected Performance Gains
Rough, unbenchmarked estimates of what these optimizations could provide:
- 20-30% speedup from optimized tensor operations
- 15-25% improvement from embedding caching
- 10-20% gain from parallelization
- 5-15% improvement from kernel fusion
Implementation Priority
- Embedding Cache (easiest, high impact)
- Optimized Concatenation Operations
- Batch Sampling
- Cache Management
- Tensor Parallelism
- Kernel Fusion (most complex)
I'd be happy to collaborate on implementing these optimizations!