mirror of
https://github.com/intel/llvm.git
synced 2026-02-07 16:11:27 +08:00
@@ -2216,6 +2216,14 @@ APInt avgCeilS(const APInt &C1, const APInt &C2);
|
||||
/// Compute the ceil of the unsigned average of C1 and C2
|
||||
APInt avgCeilU(const APInt &C1, const APInt &C2);
|
||||
|
||||
/// Performs (2*N)-bit multiplication on sign-extended operands.
|
||||
/// Returns the high N bits of the multiplication result.
|
||||
APInt mulhs(const APInt &C1, const APInt &C2);
|
||||
|
||||
/// Performs (2*N)-bit multiplication on zero-extended operands.
|
||||
/// Returns the high N bits of the multiplication result.
|
||||
APInt mulhu(const APInt &C1, const APInt &C2);
|
||||
|
||||
/// Compute GCD of two unsigned APInt values.
|
||||
///
|
||||
/// This function returns the greatest common divisor of the two APInt values
|
||||
|
||||
@@ -6078,18 +6078,6 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
|
||||
if (!C2.getBoolValue())
|
||||
break;
|
||||
return C1.srem(C2);
|
||||
case ISD::MULHS: {
|
||||
unsigned FullWidth = C1.getBitWidth() * 2;
|
||||
APInt C1Ext = C1.sext(FullWidth);
|
||||
APInt C2Ext = C2.sext(FullWidth);
|
||||
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
|
||||
}
|
||||
case ISD::MULHU: {
|
||||
unsigned FullWidth = C1.getBitWidth() * 2;
|
||||
APInt C1Ext = C1.zext(FullWidth);
|
||||
APInt C2Ext = C2.zext(FullWidth);
|
||||
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
|
||||
}
|
||||
case ISD::AVGFLOORS:
|
||||
return APIntOps::avgFloorS(C1, C2);
|
||||
case ISD::AVGFLOORU:
|
||||
@@ -6102,10 +6090,13 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
|
||||
return APIntOps::abds(C1, C2);
|
||||
case ISD::ABDU:
|
||||
return APIntOps::abdu(C1, C2);
|
||||
case ISD::MULHS:
|
||||
return APIntOps::mulhs(C1, C2);
|
||||
case ISD::MULHU:
|
||||
return APIntOps::mulhu(C1, C2);
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Handle constant folding with UNDEF.
|
||||
// TODO: Handle more cases.
|
||||
static std::optional<APInt> FoldValueWithUndef(unsigned Opcode, const APInt &C1,
|
||||
|
||||
@@ -3121,3 +3121,19 @@ APInt APIntOps::avgCeilU(const APInt &C1, const APInt &C2) {
|
||||
// Return ceil((C1 + C2) / 2)
|
||||
return (C1 | C2) - (C1 ^ C2).lshr(1);
|
||||
}
|
||||
|
||||
/// Signed multiply-high.
///
/// Sign-extends both operands to twice their bit width, multiplies them at
/// that width, and returns the upper half of the product. The result has the
/// same bit width as the operands, which must have equal widths.
APInt APIntOps::mulhs(const APInt &C1, const APInt &C2) {
  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
  const unsigned Width = C1.getBitWidth();
  const unsigned DoubleWidth = Width * 2;

  // Accumulate the double-width product in place.
  APInt Product = C1.sext(DoubleWidth);
  Product *= C2.sext(DoubleWidth);

  // Keep Width bits starting at bit position Width, i.e. the high half.
  return Product.extractBits(Width, Width);
}
|
||||
|
||||
/// Unsigned multiply-high.
///
/// Zero-extends both operands to twice their bit width, multiplies them at
/// that width, and returns the upper half of the product. The result has the
/// same bit width as the operands, which must have equal widths.
APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) {
  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
  const unsigned Width = C1.getBitWidth();
  const unsigned DoubleWidth = Width * 2;

  // Accumulate the double-width product in place.
  APInt Product = C1.zext(DoubleWidth);
  Product *= C2.zext(DoubleWidth);

  // Keep Width bits starting at bit position Width, i.e. the high half.
  return Product.extractBits(Width, Width);
}
|
||||
|
||||
@@ -2841,6 +2841,58 @@ TEST(APIntTest, multiply) {
|
||||
EXPECT_EQ(64U, i96.countr_zero());
|
||||
}
|
||||
|
||||
// Checks APIntOps::mulhu and APIntOps::mulhs against fixed expected values
// for 32-, 64- and 128-bit operands. Each width lives in its own scope so the
// operand names can be reused.
TEST(APIntOpsTest, Mulh) {
  // ---- mulhu: operands are treated as unsigned ----
  {
    // 32-bit operands.
    const APInt LHS(32, 0x0001'E235);
    const APInt RHS(32, 0xF623'55AD);
    EXPECT_EQ(0x0001'CFA1, APIntOps::mulhu(LHS, RHS));
  }
  {
    // 64-bit operands.
    const APInt LHS(64, 0x1234'5678'90AB'CDEF);
    const APInt RHS(64, 0xFEDC'BA09'8765'4321);
    EXPECT_EQ(0x121F'A000'A372'3A57, APIntOps::mulhu(LHS, RHS));
  }
  {
    // 128-bit operands; values this wide are built from hex strings.
    const APInt LHS(128, "1234567890ABCDEF1234567890ABCDEF", 16);
    const APInt RHS(128, "FEDCBA0987654321FEDCBA0987654321", 16);
    const APInt High = APIntOps::mulhu(LHS, RHS);
    EXPECT_EQ(APInt(128, "121FA000A3723A57E68984312C3A8D7E", 16), High);
  }

  // ---- mulhs: operands are treated as signed ----
  {
    // 32-bit operands.
    const APInt PosA(32, 0x1234'5678); // sign bit clear
    const APInt PosB(32, 0x10AB'CDEF); // sign bit clear
    const APInt Neg(32, 0xFEDC'BA09);  // sign bit set

    EXPECT_EQ(0x012F'7D02, APIntOps::mulhs(PosA, PosB));
    EXPECT_EQ(0xFFEB'4988, APIntOps::mulhs(PosA, Neg));
    EXPECT_EQ(0x0001'4B68, APIntOps::mulhs(Neg, Neg));
  }
  {
    // 64-bit operands.
    const APInt PosA(64, 0x1234'5678'90AB'CDEF); // sign bit clear
    const APInt PosB(64, 0x1234'5678'90FE'DCBA); // sign bit clear
    const APInt Neg(64, 0xFEDC'BA09'8765'4321);  // sign bit set

    EXPECT_EQ(0x014B'66DC'328E'10C1, APIntOps::mulhs(PosA, PosB));
    EXPECT_EQ(0xFFEB'4988'12C6'6C68, APIntOps::mulhs(PosA, Neg));
    EXPECT_EQ(0x0001'4B68'2174'FA18, APIntOps::mulhs(Neg, Neg));
  }
  {
    // 128-bit operands.
    const APInt PosA(128, "1234567890ABCDEF1234567890ABCDEF", 16); // sign bit clear
    const APInt PosB(128, "1234567890FEDCBA1234567890FEDCBA", 16); // sign bit clear
    const APInt Neg(128, "FEDCBA0987654321FEDCBA0987654321", 16);  // sign bit set

    const APInt PosTimesPos = APIntOps::mulhs(PosA, PosB);
    EXPECT_EQ(APInt(128, "14B66DC328E10C1FE303DF9EA0B2529", 16), PosTimesPos);

    const APInt PosTimesNeg = APIntOps::mulhs(PosA, Neg);
    EXPECT_EQ(APInt(128, "FFEB498812C66C68D4552DB89B8EBF8F", 16), PosTimesNeg);
  }
}
|
||||
|
||||
TEST(APIntTest, RoundingUDiv) {
|
||||
for (uint64_t Ai = 1; Ai <= 255; Ai++) {
|
||||
APInt A(8, Ai);
|
||||
|
||||
@@ -553,17 +553,11 @@ TEST(KnownBitsTest, BinaryExhaustive) {
|
||||
checkCorrectnessOnlyBinary);
|
||||
testBinaryOpExhaustive(
|
||||
KnownBits::mulhs,
|
||||
[](const APInt &N1, const APInt &N2) {
|
||||
unsigned Bits = N1.getBitWidth();
|
||||
return (N1.sext(2 * Bits) * N2.sext(2 * Bits)).extractBits(Bits, Bits);
|
||||
},
|
||||
[](const APInt &N1, const APInt &N2) { return APIntOps::mulhs(N1, N2); },
|
||||
checkCorrectnessOnlyBinary);
|
||||
testBinaryOpExhaustive(
|
||||
KnownBits::mulhu,
|
||||
[](const APInt &N1, const APInt &N2) {
|
||||
unsigned Bits = N1.getBitWidth();
|
||||
return (N1.zext(2 * Bits) * N2.zext(2 * Bits)).extractBits(Bits, Bits);
|
||||
},
|
||||
[](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); },
|
||||
checkCorrectnessOnlyBinary);
|
||||
}
|
||||
|
||||
|
||||
@@ -457,9 +457,7 @@ arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
|
||||
// Invoke the constant fold helper again to calculate the 'high' result.
|
||||
Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
|
||||
adaptor.getOperands(), [](const APInt &a, const APInt &b) {
|
||||
unsigned bitWidth = a.getBitWidth();
|
||||
APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
|
||||
return fullProduct.extractBits(bitWidth, bitWidth);
|
||||
return llvm::APIntOps::mulhs(a, b);
|
||||
});
|
||||
assert(highAttr && "Unexpected constant-folding failure");
|
||||
|
||||
@@ -514,9 +512,7 @@ arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
|
||||
// Invoke the constant fold helper again to calculate the 'high' result.
|
||||
Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
|
||||
adaptor.getOperands(), [](const APInt &a, const APInt &b) {
|
||||
unsigned bitWidth = a.getBitWidth();
|
||||
APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
|
||||
return fullProduct.extractBits(bitWidth, bitWidth);
|
||||
return llvm::APIntOps::mulhu(a, b);
|
||||
});
|
||||
assert(highAttr && "Unexpected constant-folding failure");
|
||||
|
||||
|
||||
@@ -250,14 +250,11 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
|
||||
|
||||
auto highBits = constFoldBinaryOp<IntegerAttr>(
|
||||
{lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
|
||||
unsigned bitWidth = a.getBitWidth();
|
||||
APInt c;
|
||||
if (IsSigned) {
|
||||
c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
|
||||
return llvm::APIntOps::mulhs(a, b);
|
||||
} else {
|
||||
c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
|
||||
return llvm::APIntOps::mulhu(a, b);
|
||||
}
|
||||
return c.extractBits(bitWidth, bitWidth); // Extract high result
|
||||
});
|
||||
|
||||
if (!highBits)
|
||||
|
||||
Reference in New Issue
Block a user