mirror of
https://github.com/intel/llvm.git
synced 2026-01-14 20:10:50 +08:00
[RISCVGatherScatterLowering] Support broadcast base pointer
A broadcast base pointer is the same as a scalar base pointer for GEP semantics (when there's at least one other vector operand). This is the form that SLP likes to emit, so we should handle it. Differential Revision: https://reviews.llvm.org/D157132
This commit is contained in:
committed by
Philip Reames
parent
56d92c1758
commit
999ac10d76
@@ -331,8 +331,12 @@ RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
|
||||
SmallVector<Value *, 2> Ops(GEP->operands());
|
||||
|
||||
// Base pointer needs to be a scalar.
|
||||
if (Ops[0]->getType()->isVectorTy())
|
||||
return std::make_pair(nullptr, nullptr);
|
||||
Value *ScalarBase = Ops[0];
|
||||
if (ScalarBase->getType()->isVectorTy()) {
|
||||
ScalarBase = getSplatValue(ScalarBase);
|
||||
if (!ScalarBase)
|
||||
return std::make_pair(nullptr, nullptr);
|
||||
}
|
||||
|
||||
std::optional<unsigned> VecOperand;
|
||||
unsigned TypeScale = 0;
|
||||
@@ -379,7 +383,7 @@ RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
|
||||
Ops[*VecOperand] = Start;
|
||||
Type *SourceTy = GEP->getSourceElementType();
|
||||
Value *BasePtr =
|
||||
Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front());
|
||||
Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());
|
||||
|
||||
// Convert stride to pointer size if needed.
|
||||
Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
|
||||
@@ -415,7 +419,7 @@ RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
|
||||
Ops[*VecOperand] = BasePhi;
|
||||
Type *SourceTy = GEP->getSourceElementType();
|
||||
Value *BasePtr =
|
||||
Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front());
|
||||
Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());
|
||||
|
||||
// Final adjustments to stride should go in the start block.
|
||||
Builder.SetInsertPoint(
|
||||
|
||||
@@ -947,3 +947,19 @@ bb4: ; preds = %bb4, %bb2
|
||||
bb16: ; preds = %bb4, %bb
|
||||
ret void
|
||||
}
|
||||
|
||||
define <8 x i8> @broadcast_ptr_base(ptr %a) {
|
||||
; CHECK-LABEL: @broadcast_ptr_base(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[TMP0:%.*]] = call <8 x i8> @llvm.riscv.masked.strided.load.v8i8.p0.i64(<8 x i8> poison, ptr [[A:%.*]], i64 64, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
|
||||
; CHECK-NEXT: ret <8 x i8> [[TMP0]]
|
||||
;
|
||||
entry:
|
||||
%0 = insertelement <8 x ptr> poison, ptr %a, i64 0
|
||||
%1 = shufflevector <8 x ptr> %0, <8 x ptr> poison, <8 x i32> zeroinitializer
|
||||
%2 = getelementptr i8, <8 x ptr> %1, <8 x i64> <i64 0, i64 64, i64 128, i64 192, i64 256, i64 320, i64 384, i64 448>
|
||||
%3 = tail call <8 x i8> @llvm.masked.gather.v8i8.v8p0(<8 x ptr> %2, i32 1, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i8> poison)
|
||||
ret <8 x i8> %3
|
||||
}
|
||||
|
||||
declare <8 x i8> @llvm.masked.gather.v8i8.v8p0(<8 x ptr>, i32 immarg, <8 x i1>, <8 x i8>)
|
||||
|
||||
Reference in New Issue
Block a user