Preserve param alignment in NVPTXLowerArgs pass.

NVPTXLowerArgs works as follows.

  * Create a regular alloca with alignment identical to arg.
  * Copy arg from param space (and ASC'ing it from generic AS first) to
    the alloca (it's still in generic AS).
  * Replace loads of arg with loads of alloca.

The bug here is that we did not preserve the arg's alignment when
loading from the alloca.

The impact of this bug is that sometimes param loads would be lowered as
a series of u8 loads, because we're incorrectly assuming everything has
alignment 1.

Differential Revision: https://reviews.llvm.org/D89404
This commit is contained in:
Justin Lebar
2020-10-14 09:30:05 -07:00
parent 008c0ea6a4
commit e9ac1869a8
2 changed files with 32 additions and 1 deletions

View File

@@ -172,8 +172,12 @@ void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
Value *ArgInParam = new AddrSpaceCastInst(
Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
FirstInst);
// Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
// addrspacecast preserves alignment. Since params are constant, this load is
// definitely not volatile.
LoadInst *LI =
new LoadInst(StructType, ArgInParam, Arg->getName(), FirstInst);
new LoadInst(StructType, ArgInParam, Arg->getName(),
/*isVolatile=*/false, AllocA->getAlign(), FirstInst);
new StoreInst(LI, AllocA, FirstInst);
}

View File

@@ -0,0 +1,27 @@
; RUN: opt < %s -S -nvptx-lower-args | FileCheck %s --check-prefix IR
; RUN: llc < %s -mcpu=sm_20 | FileCheck %s --check-prefix PTX
target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
target triple = "nvptx64-nvidia-cuda"
%class.outer = type <{ %class.inner, i32, [4 x i8] }>
%class.inner = type { i32*, i32* }
; Check that nvptx-lower-args preserves arg alignment
define void @load_alignment(%class.outer* nocapture readonly byval(%class.outer) align 8 %arg) {
entry:
; IR: load %class.outer, %class.outer addrspace(101)*
; IR-SAME: align 8
; PTX: ld.param.u64
; PTX-NOT: ld.param.u8
%arg.idx = getelementptr %class.outer, %class.outer* %arg, i64 0, i32 0, i32 0
%arg.idx.val = load i32*, i32** %arg.idx, align 8
%arg.idx1 = getelementptr %class.outer, %class.outer* %arg, i64 0, i32 0, i32 1
%arg.idx1.val = load i32*, i32** %arg.idx1, align 8
%arg.idx2 = getelementptr %class.outer, %class.outer* %arg, i64 0, i32 1
%arg.idx2.val = load i32, i32* %arg.idx2, align 8
%arg.idx.val.val = load i32, i32* %arg.idx.val, align 4
%add.i = add nsw i32 %arg.idx.val.val, %arg.idx2.val
store i32 %add.i, i32* %arg.idx1.val, align 4
ret void
}