[InstCombine] Generate better code for std::bit_ceil

Without this patch, std::bit_ceil<uint32_t> is compiled as:

  %dec = add i32 %x, -1
  %lz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
  %sub = sub i32 32, %lz
  %res = shl i32 1, %sub
  %ugt = icmp ugt i32 %x, 1
  %sel = select i1 %ugt, i32 %res, i32 1

With this patch, we generate:

  %dec = add i32 %x, -1
  %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
  %sub = sub nsw i32 0, %ctlz
  %and = and i32 %1, 31
  %sel = shl nuw i32 1, %and
  ret i32 %sel

https://alive2.llvm.org/ce/z/pwezvF

This patch recognizes the specific pattern from std::bit_ceil in
libc++ and libstdc++ and drops the conditional move.  In addition to
the LLVM IR generated for std::bit_ceil(X), this patch recognizes
variants like:

  std::bit_ceil(X - 1)
  std::bit_ceil(X + 1)
  std::bit_ceil(X + 2)
  std::bit_ceil(-X)
  std::bit_ceil(~X)

This patch fixes:

https://github.com/llvm/llvm-project/issues/60802

Differential Revision: https://reviews.llvm.org/D145299
This commit is contained in:
Kazu Hirata
2023-03-23 19:26:43 -07:00
parent 5f48b861f8
commit 231fa27435
2 changed files with 160 additions and 41 deletions

View File

@@ -3163,6 +3163,134 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
return nullptr;
}
// Return true if we can safely remove the select instruction for std::bit_ceil
// pattern.
static bool isSafeToRemoveBitCeilSelect(ICmpInst::Predicate Pred, Value *Cond0,
const APInt *Cond1, Value *CtlzOp,
unsigned BitWidth) {
// The challenge in recognizing std::bit_ceil(X) is that the operand is used
// for the CTLZ proper and select condition, each possibly with some
// operation like add and sub.
//
// Our aim is to make sure that -ctlz & (BitWidth - 1) == 0 even when the
// select instruction would select 1, which allows us to get rid of the select
// instruction.
//
// To see if we can do so, we do some symbolic execution with ConstantRange.
// Specifically, we compute the range of values that Cond0 could take when
// Cond == false. Then we successively transform the range until we obtain
// the range of values that CtlzOp could take.
//
// Conceptually, we follow the def-use chain backward from Cond0 while
// transforming the range for Cond0 until we meet the common ancestor of Cond0
// and CtlzOp. Then we follow the def-use chain forward until we obtain the
// range for CtlzOp. That said, we only follow at most one ancestor from
// Cond0. Likewise, we only follow at most one ancestor from CtrlOp.
ConstantRange CR = ConstantRange::makeExactICmpRegion(
CmpInst::getInversePredicate(Pred), *Cond1);
// Match the operation that's used to compute CtlzOp from CommonAncestor. If
// CtlzOp == CommonAncestor, return true as no operation is needed. If a
// match is found, execute the operation on CR, update CR, and return true.
// Otherwise, return false.
auto MatchForward = [&](Value *CommonAncestor) {
const APInt *C = nullptr;
if (CtlzOp == CommonAncestor)
return true;
if (match(CtlzOp, m_Add(m_Specific(CommonAncestor), m_APInt(C)))) {
CR = CR.add(*C);
return true;
}
if (match(CtlzOp, m_Sub(m_APInt(C), m_Specific(CommonAncestor)))) {
CR = ConstantRange(*C).sub(CR);
return true;
}
if (match(CtlzOp, m_Not(m_Specific(CommonAncestor)))) {
CR = CR.binaryNot();
return true;
}
return false;
};
const APInt *C = nullptr;
Value *CommonAncestor;
if (MatchForward(Cond0)) {
// Cond0 is either CtlzOp or CtlzOp's parent. CR has been updated.
} else if (match(Cond0, m_Add(m_Value(CommonAncestor), m_APInt(C)))) {
CR = CR.sub(*C);
if (!MatchForward(CommonAncestor))
return false;
// Cond0's parent is either CtlzOp or CtlzOp's parent. CR has been updated.
} else {
return false;
}
// Return true if all the values in the range are either 0 or negative (if
// treated as signed). We do so by evaluating:
//
// CR - 1 u>= (1 << BitWidth) - 1.
APInt IntMax = APInt::getSignMask(BitWidth) - 1;
CR = CR.sub(APInt(BitWidth, 1));
return CR.icmp(ICmpInst::ICMP_UGE, IntMax);
}
// Transform the std::bit_ceil(X) pattern like:
//
// %dec = add i32 %x, -1
// %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
// %sub = sub i32 32, %ctlz
// %shl = shl i32 1, %sub
// %ugt = icmp ugt i32 %x, 1
// %sel = select i1 %ugt, i32 %shl, i32 1
//
// into:
//
// %dec = add i32 %x, -1
// %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
// %neg = sub i32 0, %ctlz
// %masked = and i32 %ctlz, 31
// %shl = shl i32 1, %sub
//
// Note that the select is optimized away while the shift count is masked with
// 31. We handle some variations of the input operand like std::bit_ceil(X +
// 1).
static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) {
Type *SelType = SI.getType();
unsigned BitWidth = SelType->getScalarSizeInBits();
Value *FalseVal = SI.getFalseValue();
Value *TrueVal = SI.getTrueValue();
ICmpInst::Predicate Pred;
const APInt *Cond1;
Value *Cond0, *Ctlz, *CtlzOp;
if (!match(SI.getCondition(), m_ICmp(Pred, m_Value(Cond0), m_APInt(Cond1))))
return nullptr;
if (match(TrueVal, m_One())) {
std::swap(FalseVal, TrueVal);
Pred = CmpInst::getInversePredicate(Pred);
}
if (!match(FalseVal, m_One()) ||
!match(TrueVal,
m_OneUse(m_Shl(m_One(), m_OneUse(m_Sub(m_SpecificInt(BitWidth),
m_Value(Ctlz)))))) ||
!match(Ctlz, m_Intrinsic<Intrinsic::ctlz>(m_Value(CtlzOp), m_Zero())) ||
!isSafeToRemoveBitCeilSelect(Pred, Cond0, Cond1, CtlzOp, BitWidth))
return nullptr;
// Build 1 << (-CTLZ & (BitWidth-1)). The negation likely corresponds to a
// single hardware instruction as opposed to BitWidth - CTLZ, where BitWidth
// is an integer constant. Masking with BitWidth-1 comes free on some
// hardware as part of the shift instruction.
Value *Neg = Builder.CreateNeg(Ctlz);
Value *Masked =
Builder.CreateAnd(Neg, ConstantInt::get(SelType, BitWidth - 1));
return BinaryOperator::Create(Instruction::Shl, ConstantInt::get(SelType, 1),
Masked);
}
Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
Value *CondVal = SI.getCondition();
Value *TrueVal = SI.getTrueValue();
@@ -3590,5 +3718,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (sinkNotIntoOtherHandOfLogicalOp(SI))
return &SI;
if (Instruction *I = foldBitCeil(SI, Builder))
return I;
return nullptr;
}

View File

@@ -6,10 +6,9 @@ define i32 @bit_ceil_32(i32 %x) {
; CHECK-LABEL: @bit_ceil_32(
; CHECK-NEXT: [[DEC:%.*]] = add i32 [[X:%.*]], -1
; CHECK-NEXT: [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[DEC]], i1 false), !range [[RNG0:![0-9]+]]
; CHECK-NEXT: [[SUB:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 1, [[SUB]]
; CHECK-NEXT: [[UGT:%.*]] = icmp ugt i32 [[X]], 1
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[UGT]], i32 [[SHL]], i32 1
; CHECK-NEXT: [[TMP1:%.*]] = sub nsw i32 0, [[CTLZ]]
; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], 31
; CHECK-NEXT: [[SEL:%.*]] = shl nuw i32 1, [[TMP2]]
; CHECK-NEXT: ret i32 [[SEL]]
;
%dec = add i32 %x, -1
@@ -26,10 +25,9 @@ define i64 @bit_ceil_64(i64 %x) {
; CHECK-LABEL: @bit_ceil_64(
; CHECK-NEXT: [[DEC:%.*]] = add i64 [[X:%.*]], -1
; CHECK-NEXT: [[CTLZ:%.*]] = tail call i64 @llvm.ctlz.i64(i64 [[DEC]], i1 false), !range [[RNG1:![0-9]+]]
; CHECK-NEXT: [[SUB:%.*]] = sub nuw nsw i64 64, [[CTLZ]]
; CHECK-NEXT: [[SHL:%.*]] = shl nuw i64 1, [[SUB]]
; CHECK-NEXT: [[UGT:%.*]] = icmp ugt i64 [[X]], 1
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[UGT]], i64 [[SHL]], i64 1
; CHECK-NEXT: [[TMP1:%.*]] = sub nsw i64 0, [[CTLZ]]
; CHECK-NEXT: [[TMP2:%.*]] = and i64 [[TMP1]], 63
; CHECK-NEXT: [[SEL:%.*]] = shl nuw i64 1, [[TMP2]]
; CHECK-NEXT: ret i64 [[SEL]]
;
%dec = add i64 %x, -1
@@ -47,11 +45,9 @@ define i32 @bit_ceil_32_minus_1(i32 %x) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[SUB:%.*]] = add i32 [[X:%.*]], -2
; CHECK-NEXT: [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[SUB]], i1 false), !range [[RNG0]]
; CHECK-NEXT: [[SUB2:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 1, [[SUB2]]
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[X]], -3
; CHECK-NEXT: [[ULT:%.*]] = icmp ult i32 [[ADD]], -2
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
; CHECK-NEXT: [[TMP0:%.*]] = sub nsw i32 0, [[CTLZ]]
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[TMP0]], 31
; CHECK-NEXT: [[SEL:%.*]] = shl nuw i32 1, [[TMP1]]
; CHECK-NEXT: ret i32 [[SEL]]
;
entry:
@@ -69,11 +65,9 @@ entry:
define i32 @bit_ceil_32_plus_1(i32 %x) {
; CHECK-LABEL: @bit_ceil_32_plus_1(
; CHECK-NEXT: [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[X:%.*]], i1 false), !range [[RNG0]]
; CHECK-NEXT: [[SUB:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 1, [[SUB]]
; CHECK-NEXT: [[DEC:%.*]] = add i32 [[X]], -1
; CHECK-NEXT: [[ULT:%.*]] = icmp ult i32 [[DEC]], -2
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
; CHECK-NEXT: [[TMP1:%.*]] = sub nsw i32 0, [[CTLZ]]
; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], 31
; CHECK-NEXT: [[SEL:%.*]] = shl nuw i32 1, [[TMP2]]
; CHECK-NEXT: ret i32 [[SEL]]
;
%ctlz = tail call i32 @llvm.ctlz.i32(i32 %x, i1 false)
@@ -91,10 +85,9 @@ define i32 @bit_ceil_plus_2(i32 %x) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[SUB:%.*]] = add i32 [[X:%.*]], 1
; CHECK-NEXT: [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[SUB]], i1 false), !range [[RNG0]]
; CHECK-NEXT: [[SUB2:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 1, [[SUB2]]
; CHECK-NEXT: [[ULT:%.*]] = icmp ult i32 [[X]], -2
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
; CHECK-NEXT: [[TMP0:%.*]] = sub nsw i32 0, [[CTLZ]]
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[TMP0]], 31
; CHECK-NEXT: [[SEL:%.*]] = shl nuw i32 1, [[TMP1]]
; CHECK-NEXT: ret i32 [[SEL]]
;
entry:
@@ -113,11 +106,9 @@ define i32 @bit_ceil_32_neg(i32 %x) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[SUB:%.*]] = xor i32 [[X:%.*]], -1
; CHECK-NEXT: [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[SUB]], i1 false), !range [[RNG0]]
; CHECK-NEXT: [[SUB2:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 1, [[SUB2]]
; CHECK-NEXT: [[NOTSUB:%.*]] = add i32 [[X]], -1
; CHECK-NEXT: [[ULT:%.*]] = icmp ult i32 [[NOTSUB]], -2
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
; CHECK-NEXT: [[TMP0:%.*]] = sub nsw i32 0, [[CTLZ]]
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[TMP0]], 31
; CHECK-NEXT: [[SEL:%.*]] = shl nuw i32 1, [[TMP1]]
; CHECK-NEXT: ret i32 [[SEL]]
;
entry:
@@ -137,10 +128,9 @@ define i32 @bit_ceil_not(i32 %x) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[SUB:%.*]] = sub i32 -2, [[X:%.*]]
; CHECK-NEXT: [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[SUB]], i1 false), !range [[RNG0]]
; CHECK-NEXT: [[SUB2:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 1, [[SUB2]]
; CHECK-NEXT: [[ULT:%.*]] = icmp ult i32 [[X]], -2
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
; CHECK-NEXT: [[TMP0:%.*]] = sub nsw i32 0, [[CTLZ]]
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[TMP0]], 31
; CHECK-NEXT: [[SEL:%.*]] = shl nuw i32 1, [[TMP1]]
; CHECK-NEXT: ret i32 [[SEL]]
;
entry:
@@ -158,18 +148,17 @@ define i32 @bit_ceil_commuted_operands(i32 %x) {
; CHECK-LABEL: @bit_ceil_commuted_operands(
; CHECK-NEXT: [[DEC:%.*]] = add i32 [[X:%.*]], -1
; CHECK-NEXT: [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[DEC]], i1 false), !range [[RNG0]]
; CHECK-NEXT: [[SUB:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 1, [[SUB]]
; CHECK-NEXT: [[UGT_INV:%.*]] = icmp ugt i32 [[X]], 1
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[UGT_INV]], i32 [[SHL]], i32 1
; CHECK-NEXT: [[TMP1:%.*]] = sub nsw i32 0, [[CTLZ]]
; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], 31
; CHECK-NEXT: [[SEL:%.*]] = shl nuw i32 1, [[TMP2]]
; CHECK-NEXT: ret i32 [[SEL]]
;
%dec = add i32 %x, -1
%ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
%sub = sub i32 32, %ctlz
%shl = shl i32 1, %sub
%ugt = icmp ule i32 %x, 1
%sel = select i1 %ugt, i32 1, i32 %shl
%eq = icmp eq i32 %dec, 0
%sel = select i1 %eq, i32 1, i32 %shl
ret i32 %sel
}
@@ -282,10 +271,9 @@ define <4 x i32> @bit_ceil_v4i32(<4 x i32> %x) {
; CHECK-LABEL: @bit_ceil_v4i32(
; CHECK-NEXT: [[DEC:%.*]] = add <4 x i32> [[X:%.*]], <i32 -1, i32 -1, i32 -1, i32 -1>
; CHECK-NEXT: [[CTLZ:%.*]] = tail call <4 x i32> @llvm.ctlz.v4i32(<4 x i32> [[DEC]], i1 false), !range [[RNG0]]
; CHECK-NEXT: [[SUB:%.*]] = sub nuw nsw <4 x i32> <i32 32, i32 32, i32 32, i32 32>, [[CTLZ]]
; CHECK-NEXT: [[SHL:%.*]] = shl nuw <4 x i32> <i32 1, i32 1, i32 1, i32 1>, [[SUB]]
; CHECK-NEXT: [[UGT:%.*]] = icmp ugt <4 x i32> [[X]], <i32 1, i32 1, i32 1, i32 1>
; CHECK-NEXT: [[SEL:%.*]] = select <4 x i1> [[UGT]], <4 x i32> [[SHL]], <4 x i32> <i32 1, i32 1, i32 1, i32 1>
; CHECK-NEXT: [[TMP1:%.*]] = sub nsw <4 x i32> zeroinitializer, [[CTLZ]]
; CHECK-NEXT: [[TMP2:%.*]] = and <4 x i32> [[TMP1]], <i32 31, i32 31, i32 31, i32 31>
; CHECK-NEXT: [[SEL:%.*]] = shl nuw <4 x i32> <i32 1, i32 1, i32 1, i32 1>, [[TMP2]]
; CHECK-NEXT: ret <4 x i32> [[SEL]]
;
%dec = add <4 x i32> %x, <i32 -1, i32 -1, i32 -1, i32 -1>