Binary file modified: .gitignore (contents not shown)
42 changes: 29 additions & 13 deletions scripts/eval_from_generations.py
@@ -60,20 +60,36 @@
 operating_sys = "ubuntu22.04"
 tag = f"{cuda_version}-{flavor}-{operating_sys}"

-image = (
-    modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
-    .apt_install("git",
-                 "gcc-10",
-                 "g++-10",
-                 "clang"
-    )
-    .pip_install_from_requirements(os.path.join(REPO_TOP_DIR, "requirements.txt"))
-    .add_local_dir(
-        KERNEL_BENCH_PATH,
-        remote_path="/root/KernelBench"
-    )
-    .add_local_python_source("src")
-)
+# ThunderKittens support - use TK image if directory exists locally
+THUNDERKITTENS_LOCAL_PATH = os.path.join(REPO_TOP_DIR, "ThunderKittens")
+SRC_PATH = os.path.join(REPO_TOP_DIR, "src")
+
+if os.path.isdir(THUNDERKITTENS_LOCAL_PATH):
+    # ThunderKittens image with TK environment and mounting
+    image = (
+        modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
+        .apt_install("git", "gcc-10", "g++-10", "clang")
+        .pip_install_from_requirements(os.path.join(REPO_TOP_DIR, "requirements.txt"))
+        .env({
+            "THUNDERKITTENS_ROOT": "/root/ThunderKittens",
+            "THUNDERKITTENS_PATH": "/root/ThunderKittens",
+            "TORCH_CUDA_ARCH_LIST": "9.0",
+            "CXX": "g++-10",
+            "CC": "gcc-10",
+        })
+        .add_local_dir(THUNDERKITTENS_LOCAL_PATH, remote_path="/root/ThunderKittens", copy=True)
+        .add_local_dir(KERNEL_BENCH_PATH, remote_path="/root/KernelBench")
+        .add_local_dir(SRC_PATH, remote_path="/root/src")
+    )
+else:
+    # Standard image
+    image = (
+        modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
+        .apt_install("git", "gcc-10", "g++-10", "clang")
+        .pip_install_from_requirements(os.path.join(REPO_TOP_DIR, "requirements.txt"))
+        .add_local_dir(KERNEL_BENCH_PATH, remote_path="/root/KernelBench")
+        .add_local_dir(SRC_PATH, remote_path="/root/src")
+    )


 class EvalConfig(Config):
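When the ThunderKittens branch is taken, the env vars and the /root/ThunderKittens mount must agree. A minimal in-container sanity check - a sketch, assuming the ThunderKittens checkout keeps its headers under include/ (the same layout the prompt example's include paths assume):

    import os

    # Confirm the TK mount landed where the image's env vars point
    tk_root = os.environ.get("THUNDERKITTENS_ROOT", "")
    header = os.path.join(tk_root, "include", "kittens.cuh")
    print(f"TK root: {tk_root!r}  header present: {os.path.isfile(header)}")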
6 changes: 5 additions & 1 deletion scripts/generate_and_eval_single_sample.py
@@ -197,12 +197,16 @@ def main(config: EvalConfig):
         include_hardware = include_hardware.lower() in ["true", "1", "yes"]
         config.include_hardware_info = include_hardware

-    supported_backends = {"cuda", "triton", "tilelang", "cute"}
+    supported_backends = {"cuda", "triton", "tilelang", "cute", "thunderkittens"}
     backend = config.backend.lower()
     if backend not in supported_backends:
         raise ValueError(
             f"Unsupported backend: {config.backend}. Must be one of {sorted(supported_backends)}."
         )

+    # ThunderKittens uses fp32 by default
+    if backend == "thunderkittens":
+        config.precision = "fp32"
+
     if backend == "tilelang":
         config.precision = "fp16"  # tilelang only operates with fp16
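With "thunderkittens" registered, the single-sample script can target it from the command line. A hedged example invocation (argument names follow the repo's pydra-style key=value configs; the exact flag set may differ):

    python scripts/generate_and_eval_single_sample.py \
        dataset_src=huggingface level=1 problem_id=1 \
        backend=thunderkittens   # precision is forced to fp32 above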
45 changes: 34 additions & 11 deletions scripts/generate_and_eval_single_sample_modal.py
@@ -95,16 +95,35 @@ def __repr__(self):
 operating_sys = "ubuntu22.04"
 tag = f"{cuda_version}-{flavor}-{operating_sys}"

-image = (
-    modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
-    .apt_install("git",
-                 "gcc-10",
-                 "g++-10",
-                 "clang"  # note i skip a step
-    )
-    .pip_install_from_requirements(os.path.join(REPO_TOP_DIR, "requirements.txt"))
-    .add_local_python_source("src")
-)
+# ThunderKittens support - use TK image if directory exists locally
+THUNDERKITTENS_LOCAL_PATH = os.path.join(REPO_TOP_DIR, "ThunderKittens")
+
+SRC_PATH = os.path.join(REPO_TOP_DIR, "src")
+
+if os.path.isdir(THUNDERKITTENS_LOCAL_PATH):
+    # ThunderKittens image with TK environment and mounting
+    image = (
+        modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
+        .apt_install("git", "gcc-10", "g++-10", "clang")
+        .pip_install_from_requirements(os.path.join(REPO_TOP_DIR, "requirements.txt"))
+        .env({
+            "THUNDERKITTENS_ROOT": "/root/ThunderKittens",
+            "THUNDERKITTENS_PATH": "/root/ThunderKittens",
+            "TORCH_CUDA_ARCH_LIST": "9.0",
+            "CXX": "g++-10",
+            "CC": "gcc-10",
+        })
+        .add_local_dir(THUNDERKITTENS_LOCAL_PATH, remote_path="/root/ThunderKittens", copy=True)
+        .add_local_dir(SRC_PATH, remote_path="/root/src")
+    )
+else:
+    # Standard image
+    image = (
+        modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
+        .apt_install("git", "gcc-10", "g++-10", "clang")
+        .pip_install_from_requirements(os.path.join(REPO_TOP_DIR, "requirements.txt"))
+        .add_local_dir(SRC_PATH, remote_path="/root/src")
+    )

 @app.cls(image=image)
 class EvalFunc:
@@ -215,12 +234,16 @@ def main(config: EvalConfig):
         include_hardware = include_hardware.lower() in ["true", "1", "yes"]
         config.include_hardware_info = include_hardware

-    supported_backends = {"cuda", "triton", "tilelang", "cute"}
+    supported_backends = {"cuda", "triton", "tilelang", "cute", "thunderkittens"}
     backend = config.backend.lower()
     if backend not in supported_backends:
         raise ValueError(
             f"Unsupported backend: {config.backend}. Must be one of {sorted(supported_backends)}."
         )

+    # ThunderKittens uses fp32 by default
+    if backend == "thunderkittens":
+        config.precision = "fp32"
+
     # tilelang only supports fp16 or bf16
     if backend == "tilelang":
4 changes: 3 additions & 1 deletion scripts/generate_samples.py
@@ -239,7 +239,7 @@ def main(config: GenerationConfig):
         include_hardware = include_hardware.lower() in ["true", "1", "yes"]
         config.include_hardware_info = include_hardware

-    supported_backends = {"cuda", "triton", "cute", "tilelang"}
+    supported_backends = {"cuda", "triton", "cute", "tilelang", "thunderkittens"}
     backend = config.backend.lower()
     if backend not in supported_backends:
         raise ValueError(
@@ -248,6 +248,8 @@
     config.backend = backend
     if backend == "tilelang":
        config.precision = "fp16"
+    if backend == "thunderkittens":
+        config.precision = "fp32"  # ThunderKittens uses fp32 by default

     config.prompt_option = str(config.prompt_option).lower()
     valid_prompt_options = {"zero_shot", "one_shot", "few_shot"}
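The same backend value plugs into batch generation. A sketch of an invocation (flags mirror the repo's existing pydra configs; provider and model names are placeholders):

    python scripts/generate_samples.py \
        run_name=tk_level1 dataset_src=huggingface level=1 \
        backend=thunderkittens server_type=<provider> model_name=<model>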
40 changes: 32 additions & 8 deletions scripts/run_and_check.py
@@ -26,20 +26,44 @@

 REPO_TOP_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 KERNEL_BENCH_PATH = os.path.join(REPO_TOP_PATH, "KernelBench")
+THUNDERKITTENS_LOCAL_PATH = os.path.join(REPO_TOP_PATH, "ThunderKittens")
+SRC_PATH = os.path.join(REPO_TOP_PATH, "src")

 cuda_version = "12.8.0"
 flavor = "devel"
 operating_sys = "ubuntu22.04"
 tag = f"{cuda_version}-{flavor}-{operating_sys}"

-image = (
-    modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
-    .apt_install("git", "gcc-10", "g++-10", "clang")
-    .pip_install_from_requirements(os.path.join(REPO_TOP_PATH, "requirements.txt"))
-    .add_local_dir(KERNEL_BENCH_PATH, remote_path="/root/KernelBench")
-    .add_local_python_source("src")
-    .add_local_python_source("scripts")
-)
+# ThunderKittens support - use TK image if directory exists locally
+if os.path.isdir(THUNDERKITTENS_LOCAL_PATH):
+    # ThunderKittens image with TK environment and mounting
+    image = (
+        modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
+        .apt_install("git", "gcc-10", "g++-10", "clang")
+        .pip_install_from_requirements(os.path.join(REPO_TOP_PATH, "requirements.txt"))
+        .env({
+            "THUNDERKITTENS_ROOT": "/root/ThunderKittens",
+            "THUNDERKITTENS_PATH": "/root/ThunderKittens",
+            "TORCH_CUDA_ARCH_LIST": "9.0",
+            "CXX": "g++-10",
+            "CC": "gcc-10",
+        })
+        .add_local_dir(THUNDERKITTENS_LOCAL_PATH, remote_path="/root/ThunderKittens", copy=True)
+        .add_local_dir(KERNEL_BENCH_PATH, remote_path="/root/KernelBench")
+        .add_local_dir(SRC_PATH, remote_path="/root/src")
+        .add_local_python_source("src")
+        .add_local_python_source("scripts")
+    )
+else:
+    # Standard image without ThunderKittens
+    image = (
+        modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
+        .apt_install("git", "gcc-10", "g++-10", "clang")
+        .pip_install_from_requirements(os.path.join(REPO_TOP_PATH, "requirements.txt"))
+        .add_local_dir(KERNEL_BENCH_PATH, remote_path="/root/KernelBench")
+        .add_local_python_source("src")
+        .add_local_python_source("scripts")
+    )

 """
 Run a pair of KernelBench format (problem, solution) to check if solution is correct and compute speedup
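The TORCH_CUDA_ARCH_LIST="9.0" pin in the TK image branches matches the KITTENS_HOPPER define in the new prompt example: everything assumes a Hopper (sm_90) GPU. A minimal sketch of how the pin reaches the compiler (torch.utils.cpp_extension reads this env var when building extensions):

    import os

    # With this set, load_inline/load emit -gencode arch=compute_90,code=sm_90
    # instead of auto-detecting the local GPU's architecture.
    os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"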
142 changes: 142 additions & 0 deletions src/prompts/model_new_ex_add_thunderkittens.py
@@ -0,0 +1,142 @@
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
import os

# ThunderKittens header-only library path (set via environment variable).
# Default to /root/ThunderKittens for Modal containers, or use THUNDERKITTENS_PATH env var.
TK_PATH = os.environ.get("THUNDERKITTENS_PATH", os.environ.get("THUNDERKITTENS_ROOT", "/root/ThunderKittens"))

# C++ source: function declaration for binding
elementwise_add_cpp_source = """
torch::Tensor elementwise_add_cuda(torch::Tensor a, torch::Tensor b);
"""

# CUDA source: ThunderKittens kernel implementation
#
# IMPORTANT ThunderKittens API notes:
# 1. Define KITTENS_HOPPER before including kittens.cuh for H100/Hopper GPUs.
# 2. Operations like load, store, zero, mma_AB are NOT free functions!
#    They are static member functions inside the kittens::group<N> template struct.
# 3. Create an alias like: using warp = kittens::group<1>;
# 4. Then call: warp::load(...), warp::zero(...), etc.
#
elementwise_add_cuda_source = """
// IMPORTANT: Define KITTENS_HOPPER before including ThunderKittens headers for H100/Hopper GPUs.
// This enables FP8 types and Hopper-specific features.
#define KITTENS_HOPPER

#include <torch/extension.h>
#include <cuda_runtime.h>

// Include ThunderKittens headers
#include "kittens.cuh"

// ThunderKittens namespace and group aliases.
// Operations are accessed through these group types, NOT as free functions.
using namespace kittens;
using warp = kittens::group<1>;  // For single-warp operations (32 threads)
// For multi-warp operations, use: using warpgroup = kittens::group<4>;

// Constants for tile dimensions
constexpr int TILE_DIM = 16;

// ThunderKittens elementwise add kernel using shared memory tiles.
// This example demonstrates the ThunderKittens API pattern.
__global__ void tk_elementwise_add_kernel(const float* __restrict__ a_ptr,
                                          const float* __restrict__ b_ptr,
                                          float* __restrict__ out_ptr,
                                          int rows, int cols) {
    // For simple element-wise ops, we use a straightforward approach.
    // ThunderKittens shines for matrix ops with tiles, but here we show the basic pattern.

    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total = rows * cols;

    // Grid-stride loop for simple element-wise addition
    for (int i = idx; i < total; i += blockDim.x * gridDim.x) {
        out_ptr[i] = a_ptr[i] + b_ptr[i];
    }
}

// Alternative: ThunderKittens tiled version for larger matrices.
// Shows proper usage of ThunderKittens tile types and group operations.
// Uncomment and adapt for matrix operations:
/*
__global__ void tk_matmul_kernel(const bf16* A, const bf16* B, bf16* C,
                                 int M, int N, int K) {
    // Define aliases for the group - THIS IS REQUIRED for ThunderKittens ops
    using warpgroup = kittens::group<4>;  // 4 warps = 128 threads

    // ThunderKittens register tiles for accumulation
    rt_fl<16, 16> acc;  // 16x16 float register tile

    // Shared memory tiles
    extern __shared__ alignment_dummy __shm[];
    st_bf<16, 16> (&a_smem)[2] = *reinterpret_cast<st_bf<16, 16>(*)[2]>(__shm);
    st_bf<16, 16> (&b_smem)[2] = *reinterpret_cast<st_bf<16, 16>(*)[2]>(__shm + sizeof(st_bf<16, 16>) * 2);

    // Initialize accumulator to zero - NOTE: use the warpgroup:: prefix!
    warpgroup::zero(acc);

    // Main loop would go here with:
    // warpgroup::load(a_smem[...], ...);       // Load from global to shared
    // warpgroup::mma_AB(acc, a_tile, b_tile);  // Matrix multiply-accumulate
    // warpgroup::store(C_ptr, acc, ...);       // Store result
}
*/

torch::Tensor elementwise_add_cuda(torch::Tensor a, torch::Tensor b) {
    TORCH_CHECK(a.is_cuda(), "Input tensor a must be on CUDA");
    TORCH_CHECK(b.is_cuda(), "Input tensor b must be on CUDA");
    TORCH_CHECK(a.sizes() == b.sizes(), "Input tensors must have the same shape");

    auto out = torch::empty_like(a);
    int rows = a.size(0);
    int cols = a.numel() / rows;

    const int block_size = 256;
    const int num_blocks = (a.numel() + block_size - 1) / block_size;

    tk_elementwise_add_kernel<<<num_blocks, block_size>>>(
        a.data_ptr<float>(),
        b.data_ptr<float>(),
        out.data_ptr<float>(),
        rows, cols
    );

    return out;
}
"""

# Compile the ThunderKittens kernel inline
elementwise_add = load_inline(
    name="elementwise_add_tk",
    cpp_sources=elementwise_add_cpp_source,
    cuda_sources=elementwise_add_cuda_source,
    functions=["elementwise_add_cuda"],
    verbose=True,
    extra_include_paths=[
        TK_PATH,
        os.path.join(TK_PATH, "include"),
    ],
    extra_cflags=["-std=c++20", "-O3"],
    extra_cuda_cflags=[
        "-std=c++20",
        "-O3",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "-Xcompiler", "-fPIC",
        "-DNDEBUG",
        "-DKITTENS_HOPPER",
    ],
)


class ModelNew(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.elementwise_add = elementwise_add

    def forward(self, a, b):
        return self.elementwise_add.elementwise_add_cuda(a, b)
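A quick end-to-end smoke test for the example - a hypothetical sketch, not part of the PR, assuming a Hopper GPU and THUNDERKITTENS_PATH pointing at a TK checkout:

    if __name__ == "__main__":
        # Shapes are arbitrary; fp32 matches the backend's default precision
        a = torch.randn(1024, 1024, device="cuda", dtype=torch.float32)
        b = torch.randn_like(a)
        out = ModelNew()(a, b)
        assert torch.allclose(out, a + b), "ThunderKittens add mismatch"
        print("elementwise_add_tk OK")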
5 changes: 5 additions & 0 deletions src/prompts/prompts.toml
Expand Up @@ -49,6 +49,11 @@ backend_display = "TileLang kernels"
one_shot_new_arch = "src/prompts/model_new_ex_add_tilelang.py"
# No few_shot_examples - will use one-shot when few_shot option is selected

[backends.thunderkittens]
backend_display = "ThunderKittens kernels"
one_shot_new_arch = "src/prompts/model_new_ex_add_thunderkittens.py"
# No few_shot_examples - will use one-shot when few_shot option is selected

# -------------------------------------------------------------------------
# Precision: Precision-specific configuration
# -------------------------------------------------------------------------
Expand Down
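For reference, a hedged sketch of how a prompt builder could resolve the new entry (KernelBench's actual loader may differ; tomllib requires Python 3.11+):

    import tomllib

    with open("src/prompts/prompts.toml", "rb") as f:
        cfg = tomllib.load(f)

    tk = cfg["backends"]["thunderkittens"]
    print(tk["backend_display"])            # "ThunderKittens kernels"
    with open(tk["one_shot_new_arch"]) as f:
        one_shot_example = f.read()         # the TK example module above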