Skip to content
This repository was archived by the owner on Oct 31, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 76 additions & 18 deletions crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ use crate::builder_spirv::SpirvValue;
use crate::spirv_type::SpirvType;
use rspirv::dr::Operand;
use rspirv::spirv::{Decoration, ExecutionModel, FunctionControl, StorageClass, Word};
use rustc_codegen_ssa::traits::BaseTypeMethods;
use rustc_hir as hir;
use rustc_middle::ty::layout::TyAndLayout;
use rustc_middle::ty::layout::{HasParamEnv, TyAndLayout};
use rustc_middle::ty::{Instance, Ty, TyKind};
use rustc_span::Span;
use rustc_target::abi::call::{FnAbi, PassMode};
use rustc_target::abi::LayoutOf;
use rustc_target::abi::{
call::{ArgAbi, ArgAttribute, ArgAttributes, FnAbi, PassMode},
LayoutOf, Size,
};
use std::collections::HashMap;

impl<'tcx> CodegenCx<'tcx> {
Expand All @@ -37,9 +40,27 @@ impl<'tcx> CodegenCx<'tcx> {
};
let fn_hir_id = self.tcx.hir().local_def_id_to_hir_id(local_id);
let body = self.tcx.hir().body(self.tcx.hir().body_owned_by(fn_hir_id));
const EMPTY: ArgAttribute = ArgAttribute::empty();
for (abi, arg) in fn_abi.args.iter().zip(body.params) {
match abi.mode {
PassMode::Direct(_) | PassMode::Indirect { .. } => {}
PassMode::Direct(_)
| PassMode::Indirect { .. }
// plain DST/RTA/VLA
| PassMode::Pair(
ArgAttributes {
pointee_size: Size::ZERO,
..
},
ArgAttributes { regular: EMPTY, .. },
)
// DST struct with fields before the DST member
| PassMode::Pair(
ArgAttributes { .. },
ArgAttributes {
pointee_size: Size::ZERO,
..
},
Comment on lines +56 to +62
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The second element in a pair for &SomeUnsizedType is always the "metadata" like usize for anything based on slices. I assume pointee_size is just zero for non-pointers (and I'm not sure checking pointee_size even makes sense?).

You could probably just check that the type of the PassMode::Pair argument is TyKind::Ref, that should limit it to &SomeUnsizedType.

) => {}
_ => self.tcx.sess.span_err(
arg.span,
&format!("PassMode {:?} invalid for entry point parameter", abi.mode),
Expand All @@ -63,7 +84,7 @@ impl<'tcx> CodegenCx<'tcx> {
self.shader_entry_stub(
self.tcx.def_span(instance.def_id()),
entry_func,
fn_abi,
&fn_abi.args,
body.params,
name,
execution_model,
Expand All @@ -82,7 +103,7 @@ impl<'tcx> CodegenCx<'tcx> {
&self,
span: Span,
entry_func: SpirvValue,
entry_fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
arg_abis: &[ArgAbi<'tcx, Ty<'tcx>>],
Comment on lines -85 to +106
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is a good idea, especially if we want to do something with the return type at some point.

hir_params: &[hir::Param<'tcx>],
name: String,
execution_model: ExecutionModel,
Expand All @@ -94,25 +115,22 @@ impl<'tcx> CodegenCx<'tcx> {
}
.def(span, self);
let entry_func_return_type = match self.lookup_type(entry_func.ty) {
SpirvType::Function {
return_type,
arguments: _,
} => return_type,
SpirvType::Function { return_type, .. } => return_type,
other => self.tcx.sess.fatal(&format!(
"Invalid entry_stub type: {}",
other.debug(entry_func.ty, self)
)),
};
let mut decoration_locations = HashMap::new();
// Create OpVariables before OpFunction so they're global instead of local vars.
let declared_params = entry_fn_abi
.args
let declared_params = arg_abis
.iter()
.zip(hir_params)
.map(|(entry_fn_arg, hir_param)| {
self.declare_parameter(entry_fn_arg.layout, hir_param, &mut decoration_locations)
})
.collect::<Vec<_>>();
let len_t = self.type_isize();
let mut emit = self.emit_global();
let fn_id = emit
.begin_function(void, None, FunctionControl::NONE, fn_void_void)
Expand All @@ -121,12 +139,19 @@ impl<'tcx> CodegenCx<'tcx> {
// Adjust any global `OpVariable`s as needed (e.g. loading from `Input`s).
let arguments: Vec<_> = declared_params
.iter()
.zip(&entry_fn_abi.args)
.zip(arg_abis)
.zip(hir_params)
.map(|((&(var, storage_class), entry_fn_arg), hir_param)| {
match entry_fn_arg.layout.ty.kind() {
TyKind::Ref(..) => var,

.flat_map(|((&(var, storage_class), entry_fn_arg), hir_param)| {
let mut dst_len_arg = None;
let arg = match entry_fn_arg.layout.ty.kind() {
TyKind::Ref(_, ty, _) => {
if !ty.is_sized(self.tcx.at(span), self.param_env()) {
dst_len_arg.replace(
self.dst_length_argument(&mut emit, ty, hir_param, len_t, var),
);
}
var
}
_ => match entry_fn_arg.mode {
PassMode::Indirect { .. } => var,
PassMode::Direct(_) => {
Expand All @@ -142,7 +167,8 @@ impl<'tcx> CodegenCx<'tcx> {
}
_ => unreachable!(),
},
}
};
std::iter::once(arg).chain(dst_len_arg)
})
.collect();
emit.function_call(
Expand Down Expand Up @@ -170,6 +196,38 @@ impl<'tcx> CodegenCx<'tcx> {
fn_id
}

fn dst_length_argument(
&self,
emit: &mut std::cell::RefMut<'_, rspirv::dr::Builder>,
ty: Ty<'tcx>,
hir_param: &hir::Param<'tcx>,
len_t: Word,
var: Word,
) -> Word {
match ty.kind() {
TyKind::Adt(adt_def, substs) => {
let (member_idx, field_def) = adt_def.all_fields().enumerate().last().unwrap();
let field_ty = field_def.ty(self.tcx, substs);
if !matches!(field_ty.kind(), TyKind::Slice(..)) {
self.tcx.sess.span_fatal(
hir_param.ty_span,
"DST parameters are currently restricted to a reference to a struct whose last field is a slice.",
)
}
emit.array_length(len_t, None, var, member_idx as u32)
.unwrap()
}
TyKind::Slice(..) | TyKind::Str => self.tcx.sess.span_fatal(
hir_param.ty_span,
"Straight slices are not yet supported, wrap the slice in a newtype.",
),
_ => self
.tcx
.sess
.span_fatal(hir_param.ty_span, "Unsupported parameter type."),
}
}

fn declare_parameter(
&self,
layout: TyAndLayout<'tcx>,
Expand Down
11 changes: 11 additions & 0 deletions crates/rustc_codegen_spirv/src/spirv_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,17 @@ impl SpirvType {
}
Self::RuntimeArray { element } => {
let result = cx.emit_global().type_runtime_array(element);
// ArrayStride decoration wants in *bytes*
let element_size = cx
.lookup_type(element)
.sizeof(cx)
.expect("Element of sized array must be sized")
.bytes();
cx.emit_global().decorate(
result,
Decoration::ArrayStride,
iter::once(Operand::LiteralInt32(element_size as u32)),
);
if cx.kernel_mode {
cx.zombie_with_span(result, def_span, "RuntimeArray in kernel mode");
}
Expand Down
82 changes: 67 additions & 15 deletions crates/spirv-builder/src/test/basic.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{dis_fn, dis_globals, val, val_vulkan};
use super::{dis_entry_fn, dis_fn, dis_globals, val, val_vulkan};
use std::ffi::OsStr;

struct SetEnvVar<'a> {
Expand Down Expand Up @@ -183,20 +183,21 @@ OpEntryPoint Fragment %1 "main"
OpExecutionMode %1 OriginUpperLeft
OpName %2 "test_project::add_decorate"
OpName %3 "test_project::main"
OpDecorate %4 DescriptorSet 0
OpDecorate %4 Binding 0
%5 = OpTypeVoid
%6 = OpTypeFunction %5
%7 = OpTypeInt 32 0
%8 = OpTypePointer Function %7
%9 = OpConstant %7 1
%10 = OpTypeFloat 32
%11 = OpTypeImage %10 2D 0 0 0 1 Unknown
%12 = OpTypeSampledImage %11
%13 = OpTypeRuntimeArray %12
%14 = OpTypePointer UniformConstant %13
%4 = OpVariable %14 UniformConstant
%15 = OpTypePointer UniformConstant %12"#,
OpDecorate %4 ArrayStride 4
OpDecorate %5 DescriptorSet 0
OpDecorate %5 Binding 0
%6 = OpTypeVoid
%7 = OpTypeFunction %6
%8 = OpTypeInt 32 0
%9 = OpTypePointer Function %8
%10 = OpConstant %8 1
%11 = OpTypeFloat 32
%12 = OpTypeImage %11 2D 0 0 0 1 Unknown
%13 = OpTypeSampledImage %12
%4 = OpTypeRuntimeArray %13
%14 = OpTypePointer UniformConstant %4
%5 = OpVariable %14 UniformConstant
%15 = OpTypePointer UniformConstant %13"#,
);
}

Expand Down Expand Up @@ -479,3 +480,54 @@ fn ptr_copy_from_method() {
"#
);
}

#[test]
fn index_user_dst() {
dis_entry_fn(
r#"
#[spirv(fragment)]
pub fn main(
#[spirv(uniform, descriptor_set = 0, binding = 0)] slice: &mut SliceF32,
) {
let float: f32 = slice.rta[0];
let _ = float;
}

pub struct SliceF32 {
rta: [f32],
}
"#,
"main",
r#"%1 = OpFunction %2 None %3
%4 = OpLabel
%5 = OpArrayLength %6 %7 0
%8 = OpCompositeInsert %9 %7 %10 0
%11 = OpCompositeInsert %9 %5 %8 1
%12 = OpAccessChain %13 %7 %14
%15 = OpULessThan %16 %14 %5
OpSelectionMerge %17 None
OpBranchConditional %15 %18 %19
%18 = OpLabel
%20 = OpAccessChain %13 %7 %14
%21 = OpInBoundsAccessChain %22 %20 %14
%23 = OpLoad %24 %21
OpReturn
%19 = OpLabel
OpBranch %25
%25 = OpLabel
OpBranch %26
%26 = OpLabel
%27 = OpPhi %16 %28 %25 %28 %29
OpLoopMerge %30 %29 None
OpBranchConditional %27 %31 %30
%31 = OpLabel
OpBranch %29
%29 = OpLabel
OpBranch %26
%30 = OpLabel
OpUnreachable
%17 = OpLabel
OpUnreachable
OpFunctionEnd"#,
)
}
27 changes: 27 additions & 0 deletions crates/spirv-builder/src/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,33 @@ fn dis_fn(src: &str, func: &str, expect: &str) {
assert_str_eq(expect, &func.disassemble())
}

fn dis_entry_fn(src: &str, func: &str, expect: &str) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was confused by this so maybe dis_entry_stub might be easier to get? A comment would also help, but the name made me think the high-level Rust fn, not the generated stub.

let _lock = global_lock();
let module = read_module(&build(src)).unwrap();
let id = module
.entry_points
.iter()
.find(|inst| inst.operands.last().unwrap().unwrap_literal_string() == func)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it always operands[1] for the name, whereas further operands are interface variables? https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpEntryPoint

.unwrap_or_else(|| {
panic!(
"no entry point with the name `{}` found in:\n{}\n",
func,
module.disassemble()
)
})
.operands[1]
.unwrap_id_ref();
let mut func = module
.functions
.into_iter()
.find(|f| f.def_id().unwrap() == id)
.unwrap();
// Compact to make IDs more stable
compact_ids(&mut func);
use rspirv::binary::Disassemble;
assert_str_eq(expect, &func.disassemble())
Comment on lines +178 to +186
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part could maybe be shared between the two functions (so only the picking the function ID would be different).

}

fn dis_globals(src: &str, expect: &str) {
let _lock = global_lock();
let module = read_module(&build(src)).unwrap();
Expand Down
2 changes: 1 addition & 1 deletion tests/ui/lang/core/ptr/allocate_const_scalar.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ error: pointer has non-null integer address
|
= note: Stack:
allocate_const_scalar::main
Unnamed function ID %4
Unnamed function ID %5

error: invalid binary:0:0 - No OpEntryPoint instruction was found. This is only allowed if the Linkage capability is being used.
|
Expand Down