GEMM + ReduceScatter with Workgroup Specialization Example #317
base: main
Conversation
```python
tl.store(C + local_offset, c, mask=sub_mask, cache_modifier=".wt")
tl.debug_barrier()
tl.store(locks + tile_id, 1, cache_modifier=".wt")
```
Use atomic_cas with release semantics.
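A minimal sketch of what that fix could look like, borrowing the iris.atomic_cas signature from the message-passing example referenced below; cur_rank and heap_bases are placeholders for whatever rank and heap-base arguments the surrounding kernel actually uses:

```python
# Sketch: write the tile normally, then publish the lock with a
# release-ordered atomic so the data stores above are guaranteed visible
# to other workgroups before the flag flips to 1.
# cur_rank and heap_bases stand in for the kernel's real arguments.
tl.store(C + local_offset, c, mask=sub_mask)
iris.atomic_cas(locks + tile_id, 0, 1, cur_rank, cur_rank, heap_bases, sem="release", scope="sys")
```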
I know some other examples are using these barrier/volatile stores, but they are not correct and we will fix them. The correct pattern is shown in this example:
iris/examples/06_message_passing/message_passing_load_store.py, lines 18 to 94 at commit 0dcaba0:
```python
def producer_kernel(
    source_buffer,  # tl.tensor: pointer to source data
    target_buffer,  # tl.tensor: pointer to target data
    flag,  # tl.tensor: pointer to flags
    buffer_size,  # int32: total number of elements
    producer_rank: tl.constexpr,
    consumer_rank: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases_ptr: tl.tensor,  # tl.tensor: pointer to heap bases pointers
):
    pid = tl.program_id(0)
    # Compute start index of this block
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard for out-of-bounds accesses
    mask = offsets < buffer_size
    # Load chunk from source buffer
    values = iris.load(source_buffer + offsets, producer_rank, producer_rank, heap_bases_ptr, mask=mask)
    # Store chunk to target buffer
    iris.store(
        target_buffer + offsets,
        values,
        producer_rank,
        consumer_rank,
        heap_bases_ptr,
        mask=mask,
    )
    # Set flag to signal completion
    iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, sem="release", scope="sys")


@triton.jit
def consumer_kernel(
    buffer,  # tl.tensor: pointer to shared buffer (read from target_rank)
    flag,  # tl.tensor: sync flag per block
    buffer_size,  # int32: total number of elements
    consumer_rank: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases_ptr: tl.tensor,  # tl.tensor: pointer to heap bases pointers
):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < buffer_size
    # Spin-wait until writer sets flag[pid] = 1
    done = 0
    while done == 0:
        done = iris.atomic_cas(
            flag + pid, 1, 0, consumer_rank, consumer_rank, heap_bases_ptr, sem="acquire", scope="sys"
        )
    # Read from the target buffer (written by producer)
    values = iris.load(buffer + offsets, consumer_rank, consumer_rank, heap_bases_ptr, mask=mask)
    # Do something with values...
    # (Here you might write to output, do computation, etc.)
    values = values * 2
    # Store chunk to target buffer
    iris.store(
        buffer + offsets,
        values,
        consumer_rank,
        consumer_rank,
        heap_bases_ptr,
        mask=mask,
    )
    # Optionally reset the flag for next iteration
    tl.store(flag + pid, 0)
```
mawad-amd left a comment:
Thanks for the PR, Kyle! I know it is a draft but I left a couple of comments.
```python
local_offset = rm[:, None] * stride_cm + rn[None, :] * stride_cn
# ...
while tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) != 1:
```
Use atomic_cas with acquire semantics here.
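A sketch of the matching acquire side, again with placeholder rank and heap-base arguments; comparing-and-swapping 1 -> 1 observes the lock without consuming it, unlike the 1 -> 0 swap in the consumer example above:

```python
# Sketch: spin on an acquire-ordered CAS instead of a volatile load, so
# the subsequent reads of C are ordered after the lock acquisition.
# cur_rank and heap_bases stand in for the kernel's real arguments.
done = 0
while done == 0:
    done = iris.atomic_cas(locks + tile_id, 1, 1, cur_rank, cur_rank, heap_bases, sem="acquire", scope="sys")
```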
| args["gsize_m"], | ||
| args["num_stages"], | ||
| shmem.get_heap_bases(), | ||
| "gfx942", |
Can we avoid the hardcoded arch here and maybe find it via torch.cuda.get_device_properties?
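One way this could look, assuming a ROCm build of PyTorch (where the device properties expose gcnArchName, e.g. "gfx942:sramecc+:xnack-"):

```python
import torch

# Query the architecture of the active device instead of hardcoding "gfx942".
# gcnArchName may carry feature suffixes, so keep only the part before the
# first colon.
arch = torch.cuda.get_device_properties(torch.cuda.current_device()).gcnArchName.split(":")[0]
```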
Motivation
To add an example of GEMM + ReduceScatter using workgroup specialization.
Technical Details
Test Plan
Test Result
Submission Checklist