[GlobalISel](NFC) Refactor construction of LLTs in LegalizerHelper (#170664)

I spotted a number of places where we're duplicating logic provided by
the `LLT` class inline in `LegalizerHelper`. This PR tidies up these
spots.
This commit is contained in:
Nathan Corbyn
2025-12-15 12:26:27 +00:00
committed by GitHub
parent 72f3995363
commit 2f9bf3f292
2 changed files with 37 additions and 28 deletions

View File

@@ -207,10 +207,16 @@ public:
return isVector() ? getElementType() : *this;
}
/// Returns a vector with the same number of elements but the new element
/// type. Must only be called on vector types.
constexpr LLT changeVectorElementType(LLT NewEltTy) const {
return LLT::vector(getElementCount(), NewEltTy);
}
/// If this type is a vector, return a vector with the same number of elements
/// but the new element type. Otherwise, return the new element type.
constexpr LLT changeElementType(LLT NewEltTy) const {
return isVector() ? LLT::vector(getElementCount(), NewEltTy) : NewEltTy;
return isVector() ? changeVectorElementType(NewEltTy) : NewEltTy;
}
/// If this type is a vector, return a vector with the same number of elements
@@ -223,6 +229,14 @@ public:
: LLT::scalar(NewEltSize);
}
/// Return a vector with the same element type and the new element count. Must
/// be called on vector types.
constexpr LLT changeVectorElementCount(ElementCount EC) const {
assert(isVector() &&
"cannot change vector element count of non-vector type");
return LLT::vector(EC, getElementType());
}
/// Return a vector or scalar with the same element type and the new element
/// count.
constexpr LLT changeElementCount(ElementCount EC) const {

View File

@@ -70,9 +70,8 @@ getNarrowTypeBreakDown(LLT OrigTy, LLT NarrowTy, LLT &LeftoverTy) {
unsigned EltSize = OrigTy.getScalarSizeInBits();
if (LeftoverSize % EltSize != 0)
return {-1, -1};
LeftoverTy =
LLT::scalarOrVector(ElementCount::getFixed(LeftoverSize / EltSize),
OrigTy.getElementType());
LeftoverTy = OrigTy.changeElementCount(
ElementCount::getFixed(LeftoverSize / EltSize));
} else {
LeftoverTy = LLT::scalar(LeftoverSize);
}
@@ -1558,10 +1557,7 @@ LegalizerHelper::LegalizeResult LegalizerHelper::narrowScalar(MachineInstr &MI,
// combines not being hit). This seems to be a problem related to the
// artifact combiner.
if (SizeOp0 % NarrowSize != 0) {
LLT ImplicitTy = NarrowTy;
if (DstTy.isVector())
ImplicitTy = LLT::vector(DstTy.getElementCount(), ImplicitTy);
LLT ImplicitTy = DstTy.changeElementType(NarrowTy);
Register ImplicitReg = MIRBuilder.buildUndef(ImplicitTy).getReg(0);
MIRBuilder.buildAnyExt(DstReg, ImplicitReg);
@@ -3289,7 +3285,8 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
Observer.changingInstr(MI);
widenScalarSrc(
MI, LLT::vector(VecTy.getElementCount(), WideTy.getSizeInBits()), 1,
MI,
VecTy.changeVectorElementType(LLT::scalar(WideTy.getSizeInBits())), 1,
TargetOpcode::G_ANYEXT);
widenScalarDst(MI, WideTy, 0);
@@ -3321,7 +3318,7 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
Register VecReg = MI.getOperand(1).getReg();
LLT VecTy = MRI.getType(VecReg);
LLT WideVecTy = LLT::vector(VecTy.getElementCount(), WideTy);
LLT WideVecTy = VecTy.changeVectorElementType(WideTy);
widenScalarSrc(MI, WideVecTy, 1, TargetOpcode::G_ANYEXT);
widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ANYEXT);
@@ -3522,9 +3519,7 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
Observer.changingInstr(MI);
Register VecReg = MI.getOperand(1).getReg();
LLT VecTy = MRI.getType(VecReg);
LLT WideVecTy = VecTy.isVector()
? LLT::vector(VecTy.getElementCount(), WideTy)
: WideTy;
LLT WideVecTy = VecTy.changeElementType(WideTy);
widenScalarSrc(MI, WideVecTy, 1, TargetOpcode::G_FPEXT);
widenScalarDst(MI, WideTy, 0, TargetOpcode::G_FPTRUNC);
Observer.changedInstr(MI);
@@ -3658,7 +3653,8 @@ LegalizerHelper::lowerBitcast(MachineInstr &MI) {
// %3:_(<2 x s8>) = G_BITCAST %2
// %4:_(<2 x s8>) = G_BITCAST %3
// %1:_(<4 x s16>) = G_CONCAT_VECTORS %3, %4
DstCastTy = LLT::fixed_vector(NumDstElt / NumSrcElt, DstEltTy);
DstCastTy = DstTy.changeVectorElementCount(
ElementCount::getFixed(NumDstElt / NumSrcElt));
SrcPartTy = SrcEltTy;
} else if (NumSrcElt > NumDstElt) { // Source element type is smaller.
//
@@ -3670,7 +3666,8 @@ LegalizerHelper::lowerBitcast(MachineInstr &MI) {
// %3:_(s16) = G_BITCAST %2
// %4:_(s16) = G_BITCAST %3
// %1:_(<2 x s16>) = G_BUILD_VECTOR %3, %4
SrcPartTy = LLT::fixed_vector(NumSrcElt / NumDstElt, SrcEltTy);
SrcPartTy = SrcTy.changeVectorElementCount(
ElementCount::getFixed(NumSrcElt / NumDstElt));
DstCastTy = DstEltTy;
}
@@ -3736,7 +3733,7 @@ LegalizerHelper::bitcastExtractVectorElt(MachineInstr &MI, unsigned TypeIdx,
unsigned NewNumElts = CastTy.isVector() ? CastTy.getNumElements() : 1;
unsigned OldNumElts = SrcVecTy.getNumElements();
LLT NewEltTy = CastTy.isVector() ? CastTy.getElementType() : CastTy;
LLT NewEltTy = CastTy.getScalarType();
Register CastVec = MIRBuilder.buildBitcast(CastTy, SrcVec).getReg(0);
const unsigned NewEltSize = NewEltTy.getSizeInBits();
@@ -3758,7 +3755,7 @@ LegalizerHelper::bitcastExtractVectorElt(MachineInstr &MI, unsigned TypeIdx,
// Type of the intermediate result vector.
const unsigned NewEltsPerOldElt = NewNumElts / OldNumElts;
LLT MidTy =
LLT::scalarOrVector(ElementCount::getFixed(NewEltsPerOldElt), NewEltTy);
CastTy.changeElementCount(ElementCount::getFixed(NewEltsPerOldElt));
auto NewEltsPerOldEltK = MIRBuilder.buildConstant(IdxTy, NewEltsPerOldElt);
@@ -4231,8 +4228,8 @@ LegalizerHelper::LegalizeResult LegalizerHelper::lowerLoad(GAnyLoad &LoadMI) {
// the size of the load without needing to scalarize it.
if (Alignment.value() * 8 > MemSizeInBits &&
isPowerOf2_64(DstTy.getScalarSizeInBits())) {
LLT MoreTy = LLT::fixed_vector(NextPowerOf2(DstTy.getNumElements()),
DstTy.getElementType());
LLT MoreTy = DstTy.changeVectorElementCount(
ElementCount::getFixed(NextPowerOf2(DstTy.getNumElements())));
MachineMemOperand *NewMMO = MF.getMachineMemOperand(&MMO, 0, MoreTy);
auto NewLoad = MIRBuilder.buildLoad(MoreTy, PtrReg, *NewMMO);
MIRBuilder.buildDeleteTrailingVectorElements(LoadMI.getReg(0),
@@ -5023,8 +5020,7 @@ static void makeDstOps(SmallVectorImpl<DstOp> &DstOps, LLT Ty,
unsigned NumElts) {
LLT LeftoverTy;
assert(Ty.isVector() && "Expected vector type");
LLT EltTy = Ty.getElementType();
LLT NarrowTy = (NumElts == 1) ? EltTy : LLT::fixed_vector(NumElts, EltTy);
LLT NarrowTy = Ty.changeElementCount(ElementCount::getFixed(NumElts));
int NumParts, NumLeftover;
std::tie(NumParts, NumLeftover) =
getNarrowTypeBreakDown(Ty, NarrowTy, LeftoverTy);
@@ -5705,7 +5701,8 @@ LegalizerHelper::fewerElementsBitcast(MachineInstr &MI, unsigned int TypeIdx,
auto Unmerge = MIRBuilder.buildUnmerge(SrcNarrowTy, SrcReg);
getUnmergeResults(SrcVRegs, *Unmerge);
} else {
LLT SrcNarrowTy = LLT::fixed_vector(NewElemCount, SrcTy.getElementType());
LLT SrcNarrowTy =
SrcTy.changeVectorElementCount(ElementCount::getFixed(NewElemCount));
// Split the Src and Dst Reg into smaller registers
if (extractGCDType(SrcVRegs, DstTy, SrcNarrowTy, SrcReg) != SrcNarrowTy)
@@ -6837,8 +6834,7 @@ LegalizerHelper::moreElementsVector(MachineInstr &MI, unsigned TypeIdx,
Observer.changingInstr(MI);
moreElementsVectorSrc(MI, MoreTy, 2);
moreElementsVectorSrc(MI, MoreTy, 3);
LLT CondTy = LLT::fixed_vector(
MoreTy.getNumElements(),
LLT CondTy = MoreTy.changeVectorElementType(
MRI.getType(MI.getOperand(0).getReg()).getElementType());
moreElementsVectorDst(MI, CondTy, 0);
Observer.changedInstr(MI);
@@ -6930,7 +6926,8 @@ LegalizerHelper::equalizeVectorShuffleLengths(MachineInstr &MI) {
unsigned PaddedMaskNumElts = alignTo(MaskNumElts, SrcNumElts);
unsigned NumConcat = PaddedMaskNumElts / SrcNumElts;
LLT PaddedTy = LLT::fixed_vector(PaddedMaskNumElts, DestEltTy);
LLT PaddedTy =
DstTy.changeVectorElementCount(ElementCount::getFixed(PaddedMaskNumElts));
// Create new source vectors by concatenating the initial
// source vectors with undefined vectors of the same size.
@@ -9894,9 +9891,7 @@ LegalizerHelper::lowerISFPCLASS(MachineInstr &MI) {
unsigned BitSize = SrcTy.getScalarSizeInBits();
const fltSemantics &Semantics = getFltSemanticForLLT(SrcTy.getScalarType());
LLT IntTy = LLT::scalar(BitSize);
if (SrcTy.isVector())
IntTy = LLT::vector(SrcTy.getElementCount(), IntTy);
LLT IntTy = SrcTy.changeElementType(LLT::scalar(BitSize));
auto AsInt = MIRBuilder.buildCopy(IntTy, SrcReg);
// Various masks.