Skip to content

Commit ff9533a

Browse files
committed
Emit alignment for all memory ops
1 parent 6dc03a0 commit ff9533a

File tree

2 files changed

+59
-20
lines changed

2 files changed

+59
-20
lines changed

crates/rustc_codegen_spirv/src/builder/builder_methods.rs

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1473,11 +1473,17 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
14731473
self.fatal("dynamic alloca not supported yet")
14741474
}
14751475

1476-
fn load(&mut self, ty: Self::Type, ptr: Self::Value, _align: Align) -> Self::Value {
1476+
fn load(&mut self, ty: Self::Type, ptr: Self::Value, align: Align) -> Self::Value {
14771477
let (ptr, access_ty) = self.adjust_pointer_for_typed_access(ptr, ty);
14781478
let loaded_val = ptr.const_fold_load(self).unwrap_or_else(|| {
14791479
self.emit()
1480-
.load(access_ty, None, ptr.def(self), None, empty())
1480+
.load(
1481+
access_ty,
1482+
None,
1483+
ptr.def(self),
1484+
Some(MemoryAccess::ALIGNED),
1485+
std::iter::once(Operand::LiteralBit32(align.bytes() as _)),
1486+
)
14811487
.unwrap()
14821488
.with_type(access_ty)
14831489
});
@@ -1599,12 +1605,17 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
15991605
// ignore
16001606
}
16011607

1602-
fn store(&mut self, val: Self::Value, ptr: Self::Value, _align: Align) -> Self::Value {
1608+
fn store(&mut self, val: Self::Value, ptr: Self::Value, align: Align) -> Self::Value {
16031609
let (ptr, access_ty) = self.adjust_pointer_for_typed_access(ptr, val.ty);
16041610
let val = self.bitcast(val, access_ty);
16051611

16061612
self.emit()
1607-
.store(ptr.def(self), val.def(self), None, empty())
1613+
.store(
1614+
ptr.def(self),
1615+
val.def(self),
1616+
Some(MemoryAccess::ALIGNED),
1617+
std::iter::once(Operand::LiteralBit32(align.bytes() as _)),
1618+
)
16081619
.unwrap();
16091620
// FIXME(eddyb) this is meant to be a handle the store instruction itself.
16101621
val
@@ -2262,9 +2273,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
22622273
fn memcpy(
22632274
&mut self,
22642275
dst: Self::Value,
2265-
_dst_align: Align,
2276+
dst_align: Align,
22662277
src: Self::Value,
2267-
_src_align: Align,
2278+
src_align: Align,
22682279
size: Self::Value,
22692280
flags: MemFlags,
22702281
) {
@@ -2302,12 +2313,29 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
23022313
}
23032314
});
23042315

2316+
// Pass all operands as `additional_params` since rspirv doesn't allow specifying
2317+
// extra operands ofter the first `MemoryAccess`
2318+
let mut ops: SmallVec<[_; 4]> = Default::default();
2319+
ops.push(Operand::MemoryAccess(MemoryAccess::ALIGNED));
2320+
if src_align != dst_align {
2321+
if self.emit().version().unwrap() > (1, 3) {
2322+
ops.push(Operand::LiteralBit32(dst_align.bytes() as _));
2323+
ops.push(Operand::MemoryAccess(MemoryAccess::ALIGNED));
2324+
ops.push(Operand::LiteralBit32(src_align.bytes() as _));
2325+
} else {
2326+
let align = dst_align.min(src_align);
2327+
ops.push(Operand::LiteralBit32(align.bytes() as _));
2328+
}
2329+
} else {
2330+
ops.push(Operand::LiteralBit32(dst_align.bytes() as _));
2331+
}
2332+
23052333
if let Some((dst, src)) = typed_copy_dst_src {
23062334
if let Some(const_value) = src.const_fold_load(self) {
23072335
self.store(const_value, dst, Align::from_bytes(0).unwrap());
23082336
} else {
23092337
self.emit()
2310-
.copy_memory(dst.def(self), src.def(self), None, None, empty())
2338+
.copy_memory(dst.def(self), src.def(self), None, None, ops)
23112339
.unwrap();
23122340
}
23132341
} else {
@@ -2318,7 +2346,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
23182346
size.def(self),
23192347
None,
23202348
None,
2321-
empty(),
2349+
ops,
23222350
)
23232351
.unwrap();
23242352
self.zombie(dst.def(self), "cannot memcpy dynamically sized data");

crates/rustc_codegen_spirv/src/linker/mem2reg.rs

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111
1212
use super::simple_passes::outgoing_edges;
1313
use super::{apply_rewrite_rules, id};
14+
use itertools::Itertools;
1415
use rspirv::dr::{Block, Function, Instruction, ModuleHeader, Operand};
1516
use rspirv::spirv::{Op, Word};
1617
use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap};
17-
use rustc_middle::bug;
1818
use std::collections::hash_map;
19+
use std::iter;
1920

2021
// HACK(eddyb) newtype instead of type alias to avoid mistakes.
2122
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
@@ -328,10 +329,15 @@ fn split_copy_memory(
328329
if inst.class.opcode == Op::CopyMemory {
329330
let target = inst.operands[0].id_ref_any().unwrap();
330331
let source = inst.operands[1].id_ref_any().unwrap();
331-
if inst.operands.len() > 2 {
332-
// TODO: Copy the memory operands to the load/store
333-
bug!("mem2reg OpCopyMemory doesn't support memory operands yet");
334-
}
332+
let mem_ops = &inst.operands[2..];
333+
let (store_mem_ops, load_mem_ops) = if let Some((index, _)) = mem_ops[1..]
334+
.iter()
335+
.find_position(|op| matches!(op, Operand::MemoryAccess(..)))
336+
{
337+
mem_ops.split_at(index)
338+
} else {
339+
(mem_ops, mem_ops)
340+
};
335341
let ty = match (var_map.get(&target), var_map.get(&source)) {
336342
(None, None) => {
337343
inst_index += 1;
@@ -345,17 +351,22 @@ fn split_copy_memory(
345351
}
346352
};
347353
let temp_id = id(header);
354+
355+
let load_ops = iter::once(Operand::IdRef(source))
356+
.chain(load_mem_ops.iter().cloned())
357+
.collect();
358+
359+
let store_ops = [Operand::IdRef(target), Operand::IdRef(temp_id)]
360+
.into_iter()
361+
.chain(store_mem_ops.iter().cloned())
362+
.collect();
363+
348364
block.instructions[inst_index] =
349-
Instruction::new(Op::Load, Some(ty), Some(temp_id), vec![Operand::IdRef(
350-
source,
351-
)]);
365+
Instruction::new(Op::Load, Some(ty), Some(temp_id), load_ops);
352366
inst_index += 1;
353367
block.instructions.insert(
354368
inst_index,
355-
Instruction::new(Op::Store, None, None, vec![
356-
Operand::IdRef(target),
357-
Operand::IdRef(temp_id),
358-
]),
369+
Instruction::new(Op::Store, None, None, store_ops),
359370
);
360371
}
361372
inst_index += 1;

0 commit comments

Comments
 (0)