From 064ada4ec6bb4cb77d809ba366c90ca59e95d4ba Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Tue, 16 Feb 2021 09:22:44 -0800 Subject: [PATCH] [SelectionDAG][AArch64] Restrict matchUnaryPredicate to only handle SPLAT_VECTOR for scalable vectors. fde24661718c7812a20a10e518cd853e8e060107 added support for scalable vectors to matchUnaryPredicate by handling SPLAT_VECTOR in addition to BUILD_VECTOR. This was used to enabled UDIV/SDIV/UREM/SREM by constant expansion in BuildUDIV/BuildSDIV in TargetLowering.cpp The caller there expects to call getBuildVector from the match factors. This leads to a crash right now if there is a SPLAT_VECTOR of fixed vectors since the number of vectors won't match the number of elements. To fix this, this patch updates the callers to check the opcode instead of whether the type is fixed or scalable. This assumes that only 3 opcodes are handled by matchUnaryPredicate so I've added an assertion to the final else to check that opcode. Reviewed By: RKSimon Differential Revision: https://reviews.llvm.org/D96174 --- .../CodeGen/SelectionDAG/TargetLowering.cpp | 20 +++++++++++-------- .../AArch64/sve-fixed-length-int-div.ll | 16 +++++++++++++++ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index ce1cd3d40cae..500bdb401b95 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -5032,16 +5032,17 @@ static SDValue BuildExactSDIV(const TargetLowering &TLI, SDNode *N, return SDValue(); SDValue Shift, Factor; - if (VT.isFixedLengthVector()) { + if (Op1.getOpcode() == ISD::BUILD_VECTOR) { Shift = DAG.getBuildVector(ShVT, dl, Shifts); Factor = DAG.getBuildVector(VT, dl, Factors); - } else if (VT.isScalableVector()) { + } else if (Op1.getOpcode() == ISD::SPLAT_VECTOR) { assert(Shifts.size() == 1 && Factors.size() == 1 && "Expected matchUnaryPredicate to return one element for scalable " "vectors"); Shift = DAG.getSplatVector(ShVT, dl, Shifts[0]); Factor = DAG.getSplatVector(VT, dl, Factors[0]); } else { + assert(isa(Op1) && "Expected a constant"); Shift = Shifts[0]; Factor = Factors[0]; } @@ -5147,12 +5148,12 @@ SDValue TargetLowering::BuildSDIV(SDNode *N, SelectionDAG &DAG, return SDValue(); SDValue MagicFactor, Factor, Shift, ShiftMask; - if (VT.isFixedLengthVector()) { + if (N1.getOpcode() == ISD::BUILD_VECTOR) { MagicFactor = DAG.getBuildVector(VT, dl, MagicFactors); Factor = DAG.getBuildVector(VT, dl, Factors); Shift = DAG.getBuildVector(ShVT, dl, Shifts); ShiftMask = DAG.getBuildVector(VT, dl, ShiftMasks); - } else if (VT.isScalableVector()) { + } else if (N1.getOpcode() == ISD::SPLAT_VECTOR) { assert(MagicFactors.size() == 1 && Factors.size() == 1 && Shifts.size() == 1 && ShiftMasks.size() == 1 && "Expected matchUnaryPredicate to return one element for scalable " @@ -5162,6 +5163,7 @@ SDValue TargetLowering::BuildSDIV(SDNode *N, SelectionDAG &DAG, Shift = DAG.getSplatVector(ShVT, dl, Shifts[0]); ShiftMask = DAG.getSplatVector(VT, dl, ShiftMasks[0]); } else { + assert(isa(N1) && "Expected a constant"); MagicFactor = MagicFactors[0]; Factor = Factors[0]; Shift = Shifts[0]; @@ -5303,12 +5305,12 @@ SDValue TargetLowering::BuildUDIV(SDNode *N, SelectionDAG &DAG, return SDValue(); SDValue PreShift, PostShift, MagicFactor, NPQFactor; - if (VT.isFixedLengthVector()) { + if (N1.getOpcode() == ISD::BUILD_VECTOR) { PreShift = DAG.getBuildVector(ShVT, dl, PreShifts); MagicFactor = DAG.getBuildVector(VT, dl, MagicFactors); NPQFactor = DAG.getBuildVector(VT, dl, NPQFactors); PostShift = DAG.getBuildVector(ShVT, dl, PostShifts); - } else if (VT.isScalableVector()) { + } else if (N1.getOpcode() == ISD::SPLAT_VECTOR) { assert(PreShifts.size() == 1 && MagicFactors.size() == 1 && NPQFactors.size() == 1 && PostShifts.size() == 1 && "Expected matchUnaryPredicate to return one for scalable vectors"); @@ -5317,6 +5319,7 @@ SDValue TargetLowering::BuildUDIV(SDNode *N, SelectionDAG &DAG, NPQFactor = DAG.getSplatVector(VT, dl, NPQFactors[0]); PostShift = DAG.getSplatVector(ShVT, dl, PostShifts[0]); } else { + assert(isa(N1) && "Expected a constant"); PreShift = PreShifts[0]; MagicFactor = MagicFactors[0]; PostShift = PostShifts[0]; @@ -5806,7 +5809,7 @@ TargetLowering::prepareSREMEqFold(EVT SETCCVT, SDValue REMNode, return SDValue(); SDValue PVal, AVal, KVal, QVal; - if (VT.isFixedLengthVector()) { + if (D.getOpcode() == ISD::BUILD_VECTOR) { if (HadOneDivisor) { // Try to turn PAmts into a splat, since we don't care about the values // that are currently '0'. If we can't, just keep '0'`s. @@ -5825,7 +5828,7 @@ TargetLowering::prepareSREMEqFold(EVT SETCCVT, SDValue REMNode, AVal = DAG.getBuildVector(VT, DL, AAmts); KVal = DAG.getBuildVector(ShVT, DL, KAmts); QVal = DAG.getBuildVector(VT, DL, QAmts); - } else if (VT.isScalableVector()) { + } else if (D.getOpcode() == ISD::SPLAT_VECTOR) { assert(PAmts.size() == 1 && AAmts.size() == 1 && KAmts.size() == 1 && QAmts.size() == 1 && "Expected matchUnaryPredicate to return one element for scalable " @@ -5835,6 +5838,7 @@ TargetLowering::prepareSREMEqFold(EVT SETCCVT, SDValue REMNode, KVal = DAG.getSplatVector(ShVT, DL, KAmts[0]); QVal = DAG.getSplatVector(VT, DL, QAmts[0]); } else { + assert(isa(D) && "Expected a constant"); PVal = PAmts[0]; AVal = AAmts[0]; KVal = KAmts[0]; diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-int-div.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-int-div.ll index b6e06556e4bf..b70228260850 100644 --- a/llvm/test/CodeGen/AArch64/sve-fixed-length-int-div.ll +++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-int-div.ll @@ -968,4 +968,20 @@ define void @udiv_v32i64(<32 x i64>* %a, <32 x i64>* %b) #0 { ret void } +; This used to crash because isUnaryPredicate and BuildUDIV don't know how +; a SPLAT_VECTOR of fixed vector type should be handled. +define void @udiv_constantsplat_v8i32(<8 x i32>* %a) #0 { +; CHECK-LABEL: udiv_constantsplat_v8i32: +; CHECK: ptrue [[PG:p[0-9]+]].s, vl[[#min(div(VBYTES,4),8)]] +; CHECK-NEXT: ld1w { [[OP1:z[0-9]+]].s }, [[PG]]/z, [x0] +; CHECK-NEXT: mov [[OP2:z[0-9]+]].s, #95 +; CHECK-NEXT: udiv [[RES:z[0-9]+]].s, [[PG]]/m, [[OP1]].s, [[OP2]].s +; CHECK-NEXT: st1w { [[RES]].s }, [[PG]], [x0] +; CHECK-NEXT: ret + %op1 = load <8 x i32>, <8 x i32>* %a + %res = udiv <8 x i32> %op1, + store <8 x i32> %res, <8 x i32>* %a + ret void +} + attributes #0 = { "target-features"="+sve" }