Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 64 additions & 8 deletions compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use std::ffi::CString;

use llvm::Linkage::*;
use rustc_abi::Align;
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_middle::bug;
use rustc_middle::ty::offload_meta::OffloadMetadata;

use crate::builder::Builder;
Expand Down Expand Up @@ -69,6 +71,57 @@ impl<'ll> OffloadGlobals<'ll> {
}
}

/// Launch dimensions of an offloaded kernel, as LLVM values.
///
/// Built by `OffloadKernelDims::from_operands`.
pub(crate) struct OffloadKernelDims<'ll> {
    /// The loaded three-element `i32` array holding the workgroup dimensions.
    workgroup_dims: &'ll Value,
    /// The loaded three-element `i32` array holding the thread dimensions.
    thread_dims: &'ll Value,
    /// Product of the three components of `workgroup_dims`.
    num_workgroups: &'ll Value,
    /// Product of the three components of `thread_dims`.
    threads_per_block: &'ll Value,
}

impl<'ll> OffloadKernelDims<'ll> {
pub(crate) fn from_operands<'tcx>(
builder: &mut Builder<'_, 'll, 'tcx>,
workgroup_op: &OperandRef<'tcx, &'ll llvm::Value>,
thread_op: &OperandRef<'tcx, &'ll llvm::Value>,
) -> Self {
let cx = builder.cx;
let arr_ty = cx.type_array(cx.type_i32(), 3);
let four = Align::from_bytes(4).unwrap();

let OperandValue::Ref(place) = workgroup_op.val else {
bug!("expected array operand by reference");
};
let workgroup_val = builder.load(arr_ty, place.llval, four);

let OperandValue::Ref(place) = thread_op.val else {
bug!("expected array operand by reference");
};
let thread_val = builder.load(arr_ty, place.llval, four);

fn mul_dim3<'ll, 'tcx>(
builder: &mut Builder<'_, 'll, 'tcx>,
arr: &'ll Value,
) -> &'ll Value {
let x = builder.extract_value(arr, 0);
let y = builder.extract_value(arr, 1);
let z = builder.extract_value(arr, 2);

let xy = builder.mul(x, y);
builder.mul(xy, z)
}

let num_workgroups = mul_dim3(builder, workgroup_val);
let threads_per_block = mul_dim3(builder, thread_val);

OffloadKernelDims {
workgroup_dims: workgroup_val,
thread_dims: thread_val,
num_workgroups,
threads_per_block,
}
}
}

// ; Function Attrs: nounwind
// declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) #2
fn generate_launcher<'ll>(cx: &CodegenCx<'ll, '_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
Expand Down Expand Up @@ -204,12 +257,12 @@ impl KernelArgsTy {
num_args: u64,
memtransfer_types: &'ll Value,
geps: [&'ll Value; 3],
workgroup_dims: &'ll Value,
thread_dims: &'ll Value,
) -> [(Align, &'ll Value); 13] {
let four = Align::from_bytes(4).expect("4 Byte alignment should work");
let eight = Align::EIGHT;

let ti32 = cx.type_i32();
let ci32_0 = cx.get_const_i32(0);
[
(four, cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)),
(four, cx.get_const_i32(num_args)),
Expand All @@ -222,8 +275,8 @@ impl KernelArgsTy {
(eight, cx.const_null(cx.type_ptr())), // dbg
(eight, cx.get_const_i64(KernelArgsTy::TRIPCOUNT)),
(eight, cx.get_const_i64(KernelArgsTy::FLAGS)),
(four, cx.const_array(ti32, &[cx.get_const_i32(2097152), ci32_0, ci32_0])),
(four, cx.const_array(ti32, &[cx.get_const_i32(256), ci32_0, ci32_0])),
(four, workgroup_dims),
(four, thread_dims),
(four, cx.get_const_i32(0)),
]
}
Expand Down Expand Up @@ -413,10 +466,13 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
types: &[&Type],
metadata: &[OffloadMetadata],
offload_globals: &OffloadGlobals<'ll>,
offload_dims: &OffloadKernelDims<'ll>,
) {
let cx = builder.cx;
let OffloadKernelGlobals { offload_sizes, offload_entry, memtransfer_types, region_id } =
offload_data;
let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } =
offload_dims;

let tgt_decl = offload_globals.launcher_fn;
let tgt_target_kernel_ty = offload_globals.launcher_ty;
Expand Down Expand Up @@ -554,7 +610,8 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
num_args,
s_ident_t,
);
let values = KernelArgsTy::new(&cx, num_args, memtransfer_types, geps);
let values =
KernelArgsTy::new(&cx, num_args, memtransfer_types, geps, workgroup_dims, thread_dims);

// Step 3)
// Here we fill the KernelArgsTy, see the documentation above
Expand All @@ -567,9 +624,8 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
s_ident_t,
// FIXME(offload) give users a way to select which GPU to use.
cx.get_const_i64(u64::MAX), // MAX == -1.
// FIXME(offload): Don't hardcode the numbers of threads in the future.
cx.get_const_i32(2097152),
cx.get_const_i32(256),
num_workgroups,
threads_per_block,
region_id,
a5,
];
Expand Down
7 changes: 4 additions & 3 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use tracing::debug;
use crate::abi::FnAbiLlvmExt;
use crate::builder::Builder;
use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call};
use crate::builder::gpu_offload::{gen_call_handling, gen_define_handling};
use crate::builder::gpu_offload::{OffloadKernelDims, gen_call_handling, gen_define_handling};
use crate::context::CodegenCx;
use crate::errors::{
AutoDiffWithoutEnable, AutoDiffWithoutLto, OffloadWithoutEnable, OffloadWithoutFatLTO,
Expand Down Expand Up @@ -1286,7 +1286,8 @@ fn codegen_offload<'ll, 'tcx>(
}
};

let args = get_args_from_tuple(bx, args[1], fn_target);
let offload_dims = OffloadKernelDims::from_operands(bx, &args[1], &args[2]);
let args = get_args_from_tuple(bx, args[3], fn_target);
let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target, LOCAL_CRATE);

let sig = tcx.fn_sig(fn_target.def_id()).skip_binder().skip_binder();
Expand All @@ -1305,7 +1306,7 @@ fn codegen_offload<'ll, 'tcx>(
}
};
let offload_data = gen_define_handling(&cx, &metadata, &types, target_symbol, offload_globals);
gen_call_handling(bx, &offload_data, &args, &types, &metadata, offload_globals);
gen_call_handling(bx, &offload_data, &args, &types, &metadata, offload_globals, &offload_dims);
}

fn get_args_from_tuple<'ll, 'tcx>(
Expand Down
14 changes: 12 additions & 2 deletions compiler/rustc_hir_analysis/src/check/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use rustc_abi::ExternAbi;
use rustc_errors::DiagMessage;
use rustc_hir::{self as hir, LangItem};
use rustc_middle::traits::{ObligationCause, ObligationCauseCode};
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_middle::ty::{self, Const, Ty, TyCtxt};
use rustc_span::def_id::LocalDefId;
use rustc_span::{Span, Symbol, sym};

Expand Down Expand Up @@ -315,7 +315,17 @@ pub(crate) fn check_intrinsic_type(
let type_id = tcx.type_of(tcx.lang_items().type_id().unwrap()).instantiate_identity();
(0, 0, vec![type_id, type_id], tcx.types.bool)
}
sym::offload => (3, 0, vec![param(0), param(1)], param(2)),
sym::offload => (
3,
0,
vec![
param(0),
Ty::new_array_with_const_len(tcx, tcx.types.u32, Const::from_target_usize(tcx, 3)),
Ty::new_array_with_const_len(tcx, tcx.types.u32, Const::from_target_usize(tcx, 3)),
param(1),
],
param(2),
),
sym::offset => (2, 0, vec![param(0), param(1)], param(0)),
sym::arith_offset => (
1,
Expand Down
15 changes: 13 additions & 2 deletions library/core/src/intrinsics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3407,11 +3407,17 @@ pub const fn autodiff<F, G, T: crate::marker::Tuple, R>(f: F, df: G, args: T) ->
/// - `T`: A tuple of arguments passed to `f`.
/// - `R`: The return type of the kernel.
///
/// Arguments:
/// - `f`: The kernel function to offload.
/// - `workgroup_dim`: A 3D size specifying the number of workgroups to launch.
/// - `thread_dim`: A 3D size specifying the number of threads per workgroup.
/// - `args`: A tuple of arguments forwarded to `f`.
///
/// Example usage (pseudocode):
///
/// ```rust,ignore (pseudocode)
/// fn kernel(x: *mut [f64; 128]) {
/// core::intrinsics::offload(kernel_1, (x,))
/// core::intrinsics::offload(kernel_1, [256, 1, 1], [32, 1, 1], (x,))
/// }
///
/// #[cfg(target_os = "linux")]
Expand All @@ -3430,7 +3436,12 @@ pub const fn autodiff<F, G, T: crate::marker::Tuple, R>(f: F, df: G, args: T) ->
/// <https://clang.llvm.org/docs/OffloadingDesign.html>.
#[rustc_nounwind]
#[rustc_intrinsic]
pub const fn offload<F, T: crate::marker::Tuple, R>(f: F, args: T) -> R;
pub const fn offload<F, T: crate::marker::Tuple, R>(
f: F,
workgroup_dim: [u32; 3],
thread_dim: [u32; 3],
args: T,
) -> R;

/// Inform Miri that a given pointer definitely has a certain alignment.
#[cfg(miri)]
Expand Down
2 changes: 1 addition & 1 deletion src/doc/rustc-dev-guide/src/offload/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ fn main() {

#[inline(never)]
unsafe fn kernel(x: *mut [f64; 256]) {
core::intrinsics::offload(kernel_1, (x,))
core::intrinsics::offload(kernel_1, [256, 1, 1], [32, 1, 1], (x,))
}

#[cfg(target_os = "linux")]
Expand Down
14 changes: 7 additions & 7 deletions tests/codegen-llvm/gpu_offload/gpu_host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,14 @@ fn main() {
// CHECK-NEXT: %5 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 40
// CHECK-NEXT: %6 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 72
// CHECK-NEXT: call void @llvm.memset.p0.i64(ptr noundef nonnull align 8 dereferenceable(32) %5, i8 0, i64 32, i1 false)
// CHECK-NEXT: store <4 x i32> <i32 2097152, i32 0, i32 0, i32 256>, ptr %6, align 8
// CHECK-NEXT: %.fca.1.gep3 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 88
// CHECK-NEXT: store i32 0, ptr %.fca.1.gep3, align 8
// CHECK-NEXT: %.fca.2.gep4 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 92
// CHECK-NEXT: store i32 0, ptr %.fca.2.gep4, align 4
// CHECK-NEXT: store <4 x i32> <i32 256, i32 1, i32 1, i32 32>, ptr %6, align 8
// CHECK-NEXT: %.fca.1.gep5 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 88
// CHECK-NEXT: store i32 1, ptr %.fca.1.gep5, align 8
// CHECK-NEXT: %.fca.2.gep7 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 92
// CHECK-NEXT: store i32 1, ptr %.fca.2.gep7, align 4
// CHECK-NEXT: %7 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 96
// CHECK-NEXT: store i32 0, ptr %7, align 8
// CHECK-NEXT: %8 = call i32 @__tgt_target_kernel(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 2097152, i32 256, ptr nonnull @._kernel_1.region_id, ptr nonnull %kernel_args)
// CHECK-NEXT: %8 = call i32 @__tgt_target_kernel(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 256, i32 32, ptr nonnull @._kernel_1.region_id, ptr nonnull %kernel_args)
// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes._kernel_1, ptr null, ptr null)
// CHECK-NEXT: call void @__tgt_unregister_lib(ptr nonnull %EmptyDesc)
// CHECK-NEXT: ret void
Expand All @@ -98,7 +98,7 @@ fn main() {
#[unsafe(no_mangle)]
#[inline(never)]
pub fn kernel_1(x: &mut [f32; 256]) {
core::intrinsics::offload(_kernel_1, (x,))
core::intrinsics::offload(_kernel_1, [256, 1, 1], [32, 1, 1], (x,))
}

#[unsafe(no_mangle)]
Expand Down
Loading