4 changes: 2 additions & 2 deletions README.md
@@ -135,8 +135,8 @@ If you're interviewing for any role touching LLMs or Transformers, expect at lea

| # | Problem | What You'll Implement | Difficulty | Freq | Key Concepts |
|:---:|---------|----------------------|:----------:|:----:|--------------|
| 23 | <a href="https://github.com/duoan/TorchCode/blob/master/templates/23_cross_attention.ipynb" target="_blank">Cross-Attention</a> <a href="https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/23_cross_attention.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" height="20"></a> | `MultiHeadCrossAttention` (nn.Module) | ![Medium](https://img.shields.io/badge/Medium-FF9800?style=flat-square) | ⭐ | Encoder-decoder, Q from decoder, K/V from encoder |
| 5 | <a href="https://github.com/duoan/TorchCode/blob/master/templates/05_attention.ipynb" target="_blank">Scaled Dot-Product Attention</a> <a href="https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/05_attention.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" height="20"></a> | `scaled_dot_product_attention(Q, K, V)` | ![Hard](https://img.shields.io/badge/Hard-F44336?style=flat-square) | 🔥 | `softmax(QK^T/√d_k)V`, the foundation of everything |
| 23 | <a href="https://github.com/duoan/TorchCode/blob/master/templates/23_cross_attention.ipynb" target="_blank">Cross-Attention</a> <a href="https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/23_cross_attention.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" height="20"></a> | `MultiHeadCrossAttention` (nn.Module) | ![Hard](https://img.shields.io/badge/Hard-F44336?style=flat-square) | ⭐ | Encoder-decoder, Q from decoder, K/V from encoder |
| 5 | <a href="https://github.com/duoan/TorchCode/blob/master/templates/05_attention.ipynb" target="_blank">Scaled Dot-Product Attention</a> <a href="https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/05_attention.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" height="20"></a> | `scaled_dot_product_attention(Q, K, V)` | ![Medium](https://img.shields.io/badge/Medium-FF9800?style=flat-square) | 🔥 | `softmax(QK^T/√d_k)V`, the foundation of everything |
| 6 | <a href="https://github.com/duoan/TorchCode/blob/master/templates/06_multihead_attention.ipynb" target="_blank">Multi-Head Attention</a> <a href="https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/06_multihead_attention.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" height="20"></a> | `MultiHeadAttention` (nn.Module) | ![Hard](https://img.shields.io/badge/Hard-F44336?style=flat-square) | 🔥 | Parallel heads, split/concat, projection matrices |
| 9 | <a href="https://github.com/duoan/TorchCode/blob/master/templates/09_causal_attention.ipynb" target="_blank">Causal Self-Attention</a> <a href="https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/09_causal_attention.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" height="20"></a> | `causal_attention(Q, K, V)` | ![Hard](https://img.shields.io/badge/Hard-F44336?style=flat-square) | 🔥 | Autoregressive masking with `-inf`, GPT-style |
| 10 | <a href="https://github.com/duoan/TorchCode/blob/master/templates/10_gqa.ipynb" target="_blank">Grouped Query Attention</a> <a href="https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/10_gqa.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" height="20"></a> | `GroupQueryAttention` (nn.Module) | ![Hard](https://img.shields.io/badge/Hard-F44336?style=flat-square) | ⭐ | GQA (LLaMA 2), KV sharing across heads |
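The table rows above swap the difficulty ratings for Cross-Attention (#23) and Scaled Dot-Product Attention (#5). For reviewers unfamiliar with the cross-attention pattern the rows describe (Q from the decoder, K/V from the encoder), here is a minimal single-file sketch; it is illustrative only, not the repository's actual template, and the class and parameter names are assumptions:

```python
import math
import torch
import torch.nn as nn

class MultiHeadCrossAttention(nn.Module):
    """Illustrative sketch: queries come from the decoder stream,
    keys/values come from the encoder output."""
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, decoder_x, encoder_out):
        B, T_q, _ = decoder_x.shape
        T_k = encoder_out.size(1)
        # Project, then split into heads: (B, H, T, d_head)
        q = self.w_q(decoder_x).view(B, T_q, self.num_heads, self.d_head).transpose(1, 2)
        k = self.w_k(encoder_out).view(B, T_k, self.num_heads, self.d_head).transpose(1, 2)
        v = self.w_v(encoder_out).view(B, T_k, self.num_heads, self.d_head).transpose(1, 2)
        # Scaled dot-product attention per head, then merge heads back
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)
        weights = torch.softmax(scores, dim=-1)
        out = (weights @ v).transpose(1, 2).reshape(B, T_q, -1)
        return self.w_o(out)

mha = MultiHeadCrossAttention(d_model=32, num_heads=4)
dec = torch.randn(2, 3, 32)   # decoder side: 3 query positions
enc = torch.randn(2, 5, 32)   # encoder side: 5 key/value positions
print(mha(dec, enc).shape)    # torch.Size([2, 3, 32])
```

Note the output sequence length follows the query side (3), while keys and values share the encoder's length (5) — the same shape behavior the notebook's cross-attention check below exercises.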
209 changes: 105 additions & 104 deletions solutions/05_attention_solution.ipynb
@@ -1,106 +1,107 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "5f63d076",
"metadata": {},
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/solutions/05_attention_solution.ipynb)\n\n",
"# 🔴 Solution: Softmax Attention\n",
"\n",
"Reference solution for the core Transformer attention mechanism.\n",
"\n",
"$$\\text{Attention}(Q, K, V) = \\text{softmax}\\!\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)V$$"
]
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/solutions/05_attention_solution.ipynb)\n",
"\n",
"# 🟠 Solution: Softmax Attention\n",
"\n",
"Reference solution for the core Transformer attention mechanism.\n",
"\n",
"$$\\text{Attention}(Q, K, V) = \\text{softmax}\\!\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)V$$"
],
"id": "5f63d076"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n",
"try:\n",
" import google.colab\n",
" get_ipython().run_line_magic('pip', 'install -q torch-judge')\n",
"except ImportError:\n",
" pass\n"
],
"execution_count": null,
"outputs": [],
"id": "ce663fb0"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import torch\n",
"import math"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# ✅ SOLUTION\n",
"\n",
"def scaled_dot_product_attention(Q, K, V):\n",
" d_k = K.size(-1)\n",
" scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d_k)\n",
" weights = torch.softmax(scores, dim=-1)\n",
" return torch.bmm(weights, V)"
],
"execution_count": null,
"outputs": [],
"id": "828be673"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Verify\n",
"torch.manual_seed(42)\n",
"Q = torch.randn(2, 4, 8)\n",
"K = torch.randn(2, 4, 8)\n",
"V = torch.randn(2, 4, 8)\n",
"\n",
"out = scaled_dot_product_attention(Q, K, V)\n",
"print(\"Output shape:\", out.shape)\n",
"weights = torch.softmax(torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(K.size(-1)), dim=-1)\n",
"print(\"Attention weights sum to 1?\", torch.allclose(weights.sum(dim=-1), torch.ones(2, 4)))\n",
"\n",
"# Cross-attention (seq_q != seq_k)\n",
"Q2 = torch.randn(1, 3, 16)\n",
"K2 = torch.randn(1, 5, 16)\n",
"V2 = torch.randn(1, 5, 32)\n",
"out2 = scaled_dot_product_attention(Q2, K2, V2)\n",
"print(\"Cross-attention shape:\", out2.shape, \"(expected: 1, 3, 32)\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Run judge\n",
"from torch_judge import check\n",
"check(\"attention\")"
],
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.0"
}
},
{
"cell_type": "code",
"execution_count": null,
"id": "ce663fb0",
"metadata": {},
"outputs": [],
"source": [
"# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n",
"try:\n",
" import google.colab\n",
" get_ipython().run_line_magic('pip', 'install -q torch-judge')\n",
"except ImportError:\n",
" pass\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import math"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "828be673",
"metadata": {},
"outputs": [],
"source": [
"# ✅ SOLUTION\n",
"\n",
"def scaled_dot_product_attention(Q, K, V):\n",
" d_k = K.size(-1)\n",
" scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d_k)\n",
" weights = torch.softmax(scores, dim=-1)\n",
" return torch.bmm(weights, V)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Verify\n",
"torch.manual_seed(42)\n",
"Q = torch.randn(2, 4, 8)\n",
"K = torch.randn(2, 4, 8)\n",
"V = torch.randn(2, 4, 8)\n",
"\n",
"out = scaled_dot_product_attention(Q, K, V)\n",
"print(\"Output shape:\", out.shape)\n",
"weights = torch.softmax(torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(K.size(-1)), dim=-1)\n",
"print(\"Attention weights sum to 1?\", torch.allclose(weights.sum(dim=-1), torch.ones(2, 4)))\n",
"\n",
"# Cross-attention (seq_q != seq_k)\n",
"Q2 = torch.randn(1, 3, 16)\n",
"K2 = torch.randn(1, 5, 16)\n",
"V2 = torch.randn(1, 5, 32)\n",
"out2 = scaled_dot_product_attention(Q2, K2, V2)\n",
"print(\"Cross-attention shape:\", out2.shape, \"(expected: 1, 3, 32)\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Run judge\n",
"from torch_judge import check\n",
"check(\"attention\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
"nbformat": 4,
"nbformat_minor": 5
}
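A quick way to sanity-check the notebook's `scaled_dot_product_attention` solution cell is to compare it against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` (available since PyTorch 2.0). The sketch below reproduces the solution verbatim and checks numerical agreement; the tolerance is an assumption to absorb backend rounding differences:

```python
import math
import torch
import torch.nn.functional as F

# Same implementation as the solution cell in the diff above
def scaled_dot_product_attention(Q, K, V):
    d_k = K.size(-1)
    scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d_k)
    weights = torch.softmax(scores, dim=-1)
    return torch.bmm(weights, V)

torch.manual_seed(0)
Q, K, V = (torch.randn(2, 4, 8) for _ in range(3))

ours = scaled_dot_product_attention(Q, K, V)
ref = F.scaled_dot_product_attention(Q, K, V)  # PyTorch >= 2.0 built-in

print(torch.allclose(ours, ref, atol=1e-5))  # True
```

Both compute `softmax(QK^T/sqrt(d_k))V` with the default `1/sqrt(d_k)` scale, so they should agree to floating-point precision; any larger discrepancy would indicate a bug in the hand-rolled version.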