diff --git a/llvm/include/llvm/CodeGenTypes/LowLevelType.h b/llvm/include/llvm/CodeGenTypes/LowLevelType.h index 472a3f3e23b3..92def9209d47 100644 --- a/llvm/include/llvm/CodeGenTypes/LowLevelType.h +++ b/llvm/include/llvm/CodeGenTypes/LowLevelType.h @@ -207,10 +207,16 @@ public: return isVector() ? getElementType() : *this; } + /// Returns a vector with the same number of elements but the new element + /// type. Must only be called on vector types. + constexpr LLT changeVectorElementType(LLT NewEltTy) const { + return LLT::vector(getElementCount(), NewEltTy); + } + /// If this type is a vector, return a vector with the same number of elements /// but the new element type. Otherwise, return the new element type. constexpr LLT changeElementType(LLT NewEltTy) const { - return isVector() ? LLT::vector(getElementCount(), NewEltTy) : NewEltTy; + return isVector() ? changeVectorElementType(NewEltTy) : NewEltTy; } /// If this type is a vector, return a vector with the same number of elements @@ -223,6 +229,14 @@ public: : LLT::scalar(NewEltSize); } + /// Return a vector with the same element type and the new element count. Must + /// be called on vector types. + constexpr LLT changeVectorElementCount(ElementCount EC) const { + assert(isVector() && + "cannot change vector element count of non-vector type"); + return LLT::vector(EC, getElementType()); + } + /// Return a vector or scalar with the same element type and the new element /// count. constexpr LLT changeElementCount(ElementCount EC) const { diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp index 251ea4b1e019..802738ba19a7 100644 --- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp +++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp @@ -70,9 +70,8 @@ getNarrowTypeBreakDown(LLT OrigTy, LLT NarrowTy, LLT &LeftoverTy) { unsigned EltSize = OrigTy.getScalarSizeInBits(); if (LeftoverSize % EltSize != 0) return {-1, -1}; - LeftoverTy = - LLT::scalarOrVector(ElementCount::getFixed(LeftoverSize / EltSize), - OrigTy.getElementType()); + LeftoverTy = OrigTy.changeElementCount( + ElementCount::getFixed(LeftoverSize / EltSize)); } else { LeftoverTy = LLT::scalar(LeftoverSize); } @@ -1558,10 +1557,7 @@ LegalizerHelper::LegalizeResult LegalizerHelper::narrowScalar(MachineInstr &MI, // combines not being hit). This seems to be a problem related to the // artifact combiner. if (SizeOp0 % NarrowSize != 0) { - LLT ImplicitTy = NarrowTy; - if (DstTy.isVector()) - ImplicitTy = LLT::vector(DstTy.getElementCount(), ImplicitTy); - + LLT ImplicitTy = DstTy.changeElementType(NarrowTy); Register ImplicitReg = MIRBuilder.buildUndef(ImplicitTy).getReg(0); MIRBuilder.buildAnyExt(DstReg, ImplicitReg); @@ -3289,7 +3285,8 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) { Observer.changingInstr(MI); widenScalarSrc( - MI, LLT::vector(VecTy.getElementCount(), WideTy.getSizeInBits()), 1, + MI, + VecTy.changeVectorElementType(LLT::scalar(WideTy.getSizeInBits())), 1, TargetOpcode::G_ANYEXT); widenScalarDst(MI, WideTy, 0); @@ -3321,7 +3318,7 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) { Register VecReg = MI.getOperand(1).getReg(); LLT VecTy = MRI.getType(VecReg); - LLT WideVecTy = LLT::vector(VecTy.getElementCount(), WideTy); + LLT WideVecTy = VecTy.changeVectorElementType(WideTy); widenScalarSrc(MI, WideVecTy, 1, TargetOpcode::G_ANYEXT); widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ANYEXT); @@ -3522,9 +3519,7 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) { Observer.changingInstr(MI); Register VecReg = MI.getOperand(1).getReg(); LLT VecTy = MRI.getType(VecReg); - LLT WideVecTy = VecTy.isVector() - ? LLT::vector(VecTy.getElementCount(), WideTy) - : WideTy; + LLT WideVecTy = VecTy.changeElementType(WideTy); widenScalarSrc(MI, WideVecTy, 1, TargetOpcode::G_FPEXT); widenScalarDst(MI, WideTy, 0, TargetOpcode::G_FPTRUNC); Observer.changedInstr(MI); @@ -3658,7 +3653,8 @@ LegalizerHelper::lowerBitcast(MachineInstr &MI) { // %3:_(<2 x s8>) = G_BITCAST %2 // %4:_(<2 x s8>) = G_BITCAST %3 // %1:_(<4 x s16>) = G_CONCAT_VECTORS %3, %4 - DstCastTy = LLT::fixed_vector(NumDstElt / NumSrcElt, DstEltTy); + DstCastTy = DstTy.changeVectorElementCount( + ElementCount::getFixed(NumDstElt / NumSrcElt)); SrcPartTy = SrcEltTy; } else if (NumSrcElt > NumDstElt) { // Source element type is smaller. // @@ -3670,7 +3666,8 @@ LegalizerHelper::lowerBitcast(MachineInstr &MI) { // %3:_(s16) = G_BITCAST %2 // %4:_(s16) = G_BITCAST %3 // %1:_(<2 x s16>) = G_BUILD_VECTOR %3, %4 - SrcPartTy = LLT::fixed_vector(NumSrcElt / NumDstElt, SrcEltTy); + SrcPartTy = SrcTy.changeVectorElementCount( + ElementCount::getFixed(NumSrcElt / NumDstElt)); DstCastTy = DstEltTy; } @@ -3736,7 +3733,7 @@ LegalizerHelper::bitcastExtractVectorElt(MachineInstr &MI, unsigned TypeIdx, unsigned NewNumElts = CastTy.isVector() ? CastTy.getNumElements() : 1; unsigned OldNumElts = SrcVecTy.getNumElements(); - LLT NewEltTy = CastTy.isVector() ? CastTy.getElementType() : CastTy; + LLT NewEltTy = CastTy.getScalarType(); Register CastVec = MIRBuilder.buildBitcast(CastTy, SrcVec).getReg(0); const unsigned NewEltSize = NewEltTy.getSizeInBits(); @@ -3758,7 +3755,7 @@ LegalizerHelper::bitcastExtractVectorElt(MachineInstr &MI, unsigned TypeIdx, // Type of the intermediate result vector. const unsigned NewEltsPerOldElt = NewNumElts / OldNumElts; LLT MidTy = - LLT::scalarOrVector(ElementCount::getFixed(NewEltsPerOldElt), NewEltTy); + CastTy.changeElementCount(ElementCount::getFixed(NewEltsPerOldElt)); auto NewEltsPerOldEltK = MIRBuilder.buildConstant(IdxTy, NewEltsPerOldElt); @@ -4231,8 +4228,8 @@ LegalizerHelper::LegalizeResult LegalizerHelper::lowerLoad(GAnyLoad &LoadMI) { // the size of the load without needing to scalarize it. if (Alignment.value() * 8 > MemSizeInBits && isPowerOf2_64(DstTy.getScalarSizeInBits())) { - LLT MoreTy = LLT::fixed_vector(NextPowerOf2(DstTy.getNumElements()), - DstTy.getElementType()); + LLT MoreTy = DstTy.changeVectorElementCount( + ElementCount::getFixed(NextPowerOf2(DstTy.getNumElements()))); MachineMemOperand *NewMMO = MF.getMachineMemOperand(&MMO, 0, MoreTy); auto NewLoad = MIRBuilder.buildLoad(MoreTy, PtrReg, *NewMMO); MIRBuilder.buildDeleteTrailingVectorElements(LoadMI.getReg(0), @@ -5023,8 +5020,7 @@ static void makeDstOps(SmallVectorImpl &DstOps, LLT Ty, unsigned NumElts) { LLT LeftoverTy; assert(Ty.isVector() && "Expected vector type"); - LLT EltTy = Ty.getElementType(); - LLT NarrowTy = (NumElts == 1) ? EltTy : LLT::fixed_vector(NumElts, EltTy); + LLT NarrowTy = Ty.changeElementCount(ElementCount::getFixed(NumElts)); int NumParts, NumLeftover; std::tie(NumParts, NumLeftover) = getNarrowTypeBreakDown(Ty, NarrowTy, LeftoverTy); @@ -5705,7 +5701,8 @@ LegalizerHelper::fewerElementsBitcast(MachineInstr &MI, unsigned int TypeIdx, auto Unmerge = MIRBuilder.buildUnmerge(SrcNarrowTy, SrcReg); getUnmergeResults(SrcVRegs, *Unmerge); } else { - LLT SrcNarrowTy = LLT::fixed_vector(NewElemCount, SrcTy.getElementType()); + LLT SrcNarrowTy = + SrcTy.changeVectorElementCount(ElementCount::getFixed(NewElemCount)); // Split the Src and Dst Reg into smaller registers if (extractGCDType(SrcVRegs, DstTy, SrcNarrowTy, SrcReg) != SrcNarrowTy) @@ -6837,8 +6834,7 @@ LegalizerHelper::moreElementsVector(MachineInstr &MI, unsigned TypeIdx, Observer.changingInstr(MI); moreElementsVectorSrc(MI, MoreTy, 2); moreElementsVectorSrc(MI, MoreTy, 3); - LLT CondTy = LLT::fixed_vector( - MoreTy.getNumElements(), + LLT CondTy = MoreTy.changeVectorElementType( MRI.getType(MI.getOperand(0).getReg()).getElementType()); moreElementsVectorDst(MI, CondTy, 0); Observer.changedInstr(MI); @@ -6930,7 +6926,8 @@ LegalizerHelper::equalizeVectorShuffleLengths(MachineInstr &MI) { unsigned PaddedMaskNumElts = alignTo(MaskNumElts, SrcNumElts); unsigned NumConcat = PaddedMaskNumElts / SrcNumElts; - LLT PaddedTy = LLT::fixed_vector(PaddedMaskNumElts, DestEltTy); + LLT PaddedTy = + DstTy.changeVectorElementCount(ElementCount::getFixed(PaddedMaskNumElts)); // Create new source vectors by concatenating the initial // source vectors with undefined vectors of the same size. @@ -9894,9 +9891,7 @@ LegalizerHelper::lowerISFPCLASS(MachineInstr &MI) { unsigned BitSize = SrcTy.getScalarSizeInBits(); const fltSemantics &Semantics = getFltSemanticForLLT(SrcTy.getScalarType()); - LLT IntTy = LLT::scalar(BitSize); - if (SrcTy.isVector()) - IntTy = LLT::vector(SrcTy.getElementCount(), IntTy); + LLT IntTy = SrcTy.changeElementType(LLT::scalar(BitSize)); auto AsInt = MIRBuilder.buildCopy(IntTy, SrcReg); // Various masks.