
Conversation

@knwng commented Jan 13, 2026

Motivation

Add an example of GEMM + ReduceScatter using workgroup specialization.

Technical Details

Test Plan

Test Result

Submission Checklist


```python
tl.store(C + local_offset, c, mask=sub_mask, cache_modifier=".wt")
tl.debug_barrier()
tl.store(locks + tile_id, 1, cache_modifier=".wt")
```
Collaborator

Use atomic_cas with release semantics.
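A minimal sketch of what that could look like, assuming plain Triton's `tl.atomic_cas(pointer, cmp, val, sem=..., scope=...)` signature (the exact helper used in this repo may differ):

```python
# Hedged sketch: publish the tile with release semantics instead of
# relying on a write-through store plus tl.debug_barrier().
tl.store(C + local_offset, c, mask=sub_mask)
tl.atomic_cas(locks + tile_id, 0, 1, sem="release", scope="sys")
```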

Collaborator

I know some other examples use this barrier/volatile-store pattern, but it is not correct and we will fix them. The correct pattern is shown in this example:

```python
import triton
import triton.language as tl

import iris


@triton.jit
def producer_kernel(
    source_buffer,  # tl.tensor: pointer to source data
    target_buffer,  # tl.tensor: pointer to target data
    flag,  # tl.tensor: pointer to flags
    buffer_size,  # int32: total number of elements
    producer_rank: tl.constexpr,
    consumer_rank: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases_ptr: tl.tensor,  # tl.tensor: pointer to heap bases pointers
):
    pid = tl.program_id(0)
    # Compute start index of this block
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard for out-of-bounds accesses
    mask = offsets < buffer_size
    # Load chunk from source buffer
    values = iris.load(source_buffer + offsets, producer_rank, producer_rank, heap_bases_ptr, mask=mask)
    # Store chunk to target buffer
    iris.store(
        target_buffer + offsets,
        values,
        producer_rank,
        consumer_rank,
        heap_bases_ptr,
        mask=mask,
    )
    # Set flag to signal completion
    iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, sem="release", scope="sys")


@triton.jit
def consumer_kernel(
    buffer,  # tl.tensor: pointer to shared buffer (read from target_rank)
    flag,  # tl.tensor: sync flag per block
    buffer_size,  # int32: total number of elements
    consumer_rank: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases_ptr: tl.tensor,  # tl.tensor: pointer to heap bases pointers
):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < buffer_size
    # Spin-wait until writer sets flag[pid] = 1
    done = 0
    while done == 0:
        done = iris.atomic_cas(
            flag + pid, 1, 0, consumer_rank, consumer_rank, heap_bases_ptr, sem="acquire", scope="sys"
        )
    # Read from the target buffer (written by producer)
    values = iris.load(buffer + offsets, consumer_rank, consumer_rank, heap_bases_ptr, mask=mask)
    # Do something with values...
    # (Here you might write to output, do computation, etc.)
    values = values * 2
    # Store chunk to target buffer
    iris.store(
        buffer + offsets,
        values,
        consumer_rank,
        consumer_rank,
        heap_bases_ptr,
        mask=mask,
    )
    # Optionally reset the flag for next iteration
    tl.store(flag + pid, 0)
```
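The release CAS on the producer's flag pairs with the acquire CAS in the consumer's spin loop, so the stores to `target_buffer` are guaranteed to be visible before the consumer observes the flag as set.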

@mawad-amd (Collaborator) left a comment

Thanks for the PR, Kyle! I know it is a draft but I left a couple of comments.


```python
local_offset = rm[:, None] * stride_cm + rn[None, :] * stride_cn

while tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) != 1:
```
Collaborator

Use atomic_cas with acquire semantics here.
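A hedged sketch of the acquire-side spin, again assuming Triton's `tl.atomic_cas` signature; comparing and swapping 1 with 1 leaves the lock value unchanged while giving the read acquire ordering:

```python
# Hedged sketch: spin with an acquire CAS instead of a volatile load.
# CAS(1, 1) is a no-op on the value; it returns the old value with
# acquire ordering, so the loop exits once the producer has published.
done = 0
while done == 0:
    done = tl.atomic_cas(locks + tile_id, 1, 1, sem="acquire", scope="sys")
```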

args["gsize_m"],
args["num_stages"],
shmem.get_heap_bases(),
"gfx942",
Collaborator

Can we avoid the hardcoded arch here and maybe find it via torch.cuda.get_device_properties?
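Something along these lines might work (a sketch; on ROCm builds of PyTorch the device properties expose `gcnArchName`, which may carry feature suffixes):

```python
import torch

# Sketch: query the GPU arch instead of hardcoding "gfx942".
# On ROCm, gcnArchName looks like "gfx942:sramecc+:xnack-",
# so keep only the base arch token.
props = torch.cuda.get_device_properties(0)
arch = props.gcnArchName.split(":")[0]
```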
