[llvm][RISCV] Support mulh for P extension codegen (#171581)

For the mulh pattern whose operands are both signed or both unsigned, the
combine is performed automatically. However, for mulh with one signed and one
unsigned operand, we need to combine it manually, using the same approach as
for PASUB*.

Note: This is the first patch for mulh and only handles basic high-part
multiplication; follow-up patches will handle the rest of the mulh-related
instructions.
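
As an illustration, here is a minimal IR sketch of the mixed-signedness pattern this combine targets (the function name is made up; it mirrors the v2i16 tests added below): sign-extend one operand, zero-extend the other, multiply in the wider type, shift right by the element width, and truncate. This should now select pmulhsu.h.

; Illustrative only: signed * unsigned high-half multiply on <2 x i16>.
define <2 x i16> @mulhsu_sketch(<2 x i16> %a, <2 x i16> %b) {
  %a_ext = sext <2 x i16> %a to <2 x i32>        ; signed operand
  %b_ext = zext <2 x i16> %b to <2 x i32>        ; unsigned operand
  %mul = mul <2 x i32> %a_ext, %b_ext
  %shift = lshr <2 x i32> %mul, <i32 16, i32 16> ; shift by element width
  %res = trunc <2 x i32> %shift to <2 x i16>     ; keep only the high halves
  ret <2 x i16> %res
}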
Brandon Wu
2025-12-15 22:55:26 +08:00
committed by GitHub
parent 8975eb3274
commit ef927ae263
4 changed files with 297 additions and 36 deletions

View File

@@ -15265,18 +15265,22 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
     break;
   }
   case RISCVISD::PASUB:
-  case RISCVISD::PASUBU: {
+  case RISCVISD::PASUBU:
+  case RISCVISD::PMULHSU: {
     MVT VT = N->getSimpleValueType(0);
     SDValue Op0 = N->getOperand(0);
     SDValue Op1 = N->getOperand(1);
-    assert(VT == MVT::v2i16 || VT == MVT::v4i8);
+    unsigned Opcode = N->getOpcode();
+    // PMULHSU doesn't support i8 variants
+    assert(VT == MVT::v2i16 ||
+           (Opcode != RISCVISD::PMULHSU && VT == MVT::v4i8));
     MVT NewVT = MVT::v4i16;
     if (VT == MVT::v4i8)
       NewVT = MVT::v8i8;
     SDValue Undef = DAG.getUNDEF(VT);
     Op0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op0, Undef});
     Op1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op1, Undef});
-    Results.push_back(DAG.getNode(N->getOpcode(), DL, NewVT, {Op0, Op1}));
+    Results.push_back(DAG.getNode(Opcode, DL, NewVT, {Op0, Op1}));
     return;
   }
   case ISD::EXTRACT_VECTOR_ELT: {
@@ -16386,9 +16390,9 @@ static SDValue combineTruncSelectToSMaxUSat(SDNode *N, SelectionDAG &DAG) {
   return DAG.getNode(ISD::TRUNCATE, DL, VT, Min);
 }
-// Handle P extension averaging subtraction pattern:
-// (vXiY (trunc (srl (sub ([s|z]ext vXiY:$a), ([s|z]ext vXiY:$b)), 1)))
-// -> PASUB/PASUBU
+// Handle P extension truncate patterns:
+// PASUB/PASUBU: (trunc (srl (sub ([s|z]ext a), ([s|z]ext b)), 1))
+// PMULHSU: (trunc (srl (mul (sext a), (zext b)), EltBits))
 static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
                                    const RISCVSubtarget &Subtarget) {
   SDValue N0 = N->getOperand(0);
@@ -16401,7 +16405,7 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
       VecVT != MVT::v4i8 && VecVT != MVT::v2i32)
     return SDValue();
-  // Check if shift amount is 1
+  // Check if shift amount is a splat constant
   SDValue ShAmt = N0.getOperand(1);
   if (ShAmt.getOpcode() != ISD::BUILD_VECTOR)
     return SDValue();
@@ -16415,44 +16419,57 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
   ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat);
   if (!C)
     return SDValue();
-  if (C->getZExtValue() != 1)
-    return SDValue();
-  // Check for SUB operation
-  SDValue Sub = N0.getOperand(0);
-  if (Sub.getOpcode() != ISD::SUB)
-    return SDValue();
+  SDValue Op = N0.getOperand(0);
+  unsigned ShAmtVal = C->getZExtValue();
-  SDValue LHS = Sub.getOperand(0);
-  SDValue RHS = Sub.getOperand(1);
+  SDValue LHS = Op.getOperand(0);
+  SDValue RHS = Op.getOperand(1);
   // Check if both operands are sign/zero extends from the target
   // type
-  bool IsSignExt = LHS.getOpcode() == ISD::SIGN_EXTEND &&
-                   RHS.getOpcode() == ISD::SIGN_EXTEND;
-  bool IsZeroExt = LHS.getOpcode() == ISD::ZERO_EXTEND &&
-                   RHS.getOpcode() == ISD::ZERO_EXTEND;
+  bool LHSIsSExt = LHS.getOpcode() == ISD::SIGN_EXTEND;
+  bool LHSIsZExt = LHS.getOpcode() == ISD::ZERO_EXTEND;
+  bool RHSIsSExt = RHS.getOpcode() == ISD::SIGN_EXTEND;
+  bool RHSIsZExt = RHS.getOpcode() == ISD::ZERO_EXTEND;
-  if (!IsSignExt && !IsZeroExt)
+  if (!(LHSIsSExt || LHSIsZExt) || !(RHSIsSExt || RHSIsZExt))
     return SDValue();
   SDValue A = LHS.getOperand(0);
   SDValue B = RHS.getOperand(0);
   // Check if the extends are from our target vector type
   if (A.getValueType() != VT || B.getValueType() != VT)
     return SDValue();
   // Determine the instruction based on type and signedness
   unsigned Opc;
-  if (IsSignExt)
-    Opc = RISCVISD::PASUB;
-  else if (IsZeroExt)
-    Opc = RISCVISD::PASUBU;
-  else
+  switch (Op.getOpcode()) {
+  default:
     return SDValue();
+  case ISD::SUB:
+    // PASUB/PASUBU: shift amount must be 1
+    if (ShAmtVal != 1)
+      return SDValue();
+    if (LHSIsSExt && RHSIsSExt)
+      Opc = RISCVISD::PASUB;
+    else if (LHSIsZExt && RHSIsZExt)
+      Opc = RISCVISD::PASUBU;
+    else
+      return SDValue();
+    break;
+  case ISD::MUL:
+    // PMULHSU: shift amount must be element size, only for i16/i32
+    unsigned EltBits = VecVT.getScalarSizeInBits();
+    if (ShAmtVal != EltBits || (EltBits != 16 && EltBits != 32))
+      return SDValue();
+    if ((LHSIsSExt && RHSIsZExt) || (LHSIsZExt && RHSIsSExt)) {
+      Opc = RISCVISD::PMULHSU;
+      // commuted case
+      if (LHSIsZExt && RHSIsSExt)
+        std::swap(A, B);
+    } else
+      return SDValue();
+    break;
+  }
   // Create the machine node directly
   return DAG.getNode(Opc, SDLoc(N), VT, {A, B});
 }

View File

@@ -1463,12 +1463,13 @@ let Predicates = [HasStdExtP, IsRV32] in {
 def riscv_absw : RVSDNode<"ABSW", SDT_RISCVIntUnaryOpW>;
-def SDT_RISCVPASUB : SDTypeProfile<1, 2, [SDTCisVec<0>,
-                                          SDTCisInt<0>,
-                                          SDTCisSameAs<0, 1>,
-                                          SDTCisSameAs<0, 2>]>;
-def riscv_pasub : RVSDNode<"PASUB", SDT_RISCVPASUB>;
-def riscv_pasubu : RVSDNode<"PASUBU", SDT_RISCVPASUB>;
+def SDT_RISCVPBinOp : SDTypeProfile<1, 2, [SDTCisVec<0>,
+                                           SDTCisInt<0>,
+                                           SDTCisSameAs<0, 1>,
+                                           SDTCisSameAs<0, 2>]>;
+def riscv_pasub : RVSDNode<"PASUB", SDT_RISCVPBinOp>;
+def riscv_pasubu : RVSDNode<"PASUBU", SDT_RISCVPBinOp>;
+def riscv_pmulhsu : RVSDNode<"PMULHSU", SDT_RISCVPBinOp>;
 let Predicates = [HasStdExtP] in {
 def : PatGpr<abs, ABS>;
@@ -1513,6 +1514,11 @@ let Predicates = [HasStdExtP] in {
 def: Pat<(XLenVecI16VT (abds GPR:$rs1, GPR:$rs2)), (PABD_H GPR:$rs1, GPR:$rs2)>;
 def: Pat<(XLenVecI16VT (abdu GPR:$rs1, GPR:$rs2)), (PABDU_H GPR:$rs1, GPR:$rs2)>;
+// 16-bit multiply high patterns
+def: Pat<(XLenVecI16VT (mulhs GPR:$rs1, GPR:$rs2)), (PMULH_H GPR:$rs1, GPR:$rs2)>;
+def: Pat<(XLenVecI16VT (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_H GPR:$rs1, GPR:$rs2)>;
+def: Pat<(XLenVecI16VT (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_H GPR:$rs1, GPR:$rs2)>;
 // 8-bit logical shift left/right patterns
 def: Pat<(XLenVecI8VT (shl GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))),
          (PSLLI_B GPR:$rs1, uimm3:$shamt)>;
@@ -1609,6 +1615,11 @@ let Predicates = [HasStdExtP, IsRV64] in {
 def: Pat<(v2i32 (riscv_pasub GPR:$rs1, GPR:$rs2)), (PASUB_W GPR:$rs1, GPR:$rs2)>;
 def: Pat<(v2i32 (riscv_pasubu GPR:$rs1, GPR:$rs2)), (PASUBU_W GPR:$rs1, GPR:$rs2)>;
+// 32-bit multiply high patterns
+def: Pat<(v2i32 (mulhs GPR:$rs1, GPR:$rs2)), (PMULH_W GPR:$rs1, GPR:$rs2)>;
+def: Pat<(v2i32 (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_W GPR:$rs1, GPR:$rs2)>;
+def: Pat<(v2i32 (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_W GPR:$rs1, GPR:$rs2)>;
 // 32-bit logical shift left/right
 def: Pat<(v2i32 (shl GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
          (PSLL_WS GPR:$rs1, GPR:$rs2)>;

View File

@@ -1040,3 +1040,81 @@ define void @test_psra_bs_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
store <4 x i8> %res, ptr %ret_ptr
ret void
}
; Test packed multiply high signed for v2i16
define void @test_pmulh_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
; CHECK-LABEL: test_pmulh_h:
; CHECK: # %bb.0:
; CHECK-NEXT: lw a1, 0(a1)
; CHECK-NEXT: lw a2, 0(a2)
; CHECK-NEXT: pmulh.h a1, a1, a2
; CHECK-NEXT: sw a1, 0(a0)
; CHECK-NEXT: ret
%a = load <2 x i16>, ptr %a_ptr
%b = load <2 x i16>, ptr %b_ptr
%a_ext = sext <2 x i16> %a to <2 x i32>
%b_ext = sext <2 x i16> %b to <2 x i32>
%mul = mul <2 x i32> %a_ext, %b_ext
%shift = lshr <2 x i32> %mul, <i32 16, i32 16>
%res = trunc <2 x i32> %shift to <2 x i16>
store <2 x i16> %res, ptr %ret_ptr
ret void
}
; Test packed multiply high unsigned for v2i16
define void @test_pmulhu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
; CHECK-LABEL: test_pmulhu_h:
; CHECK: # %bb.0:
; CHECK-NEXT: lw a1, 0(a1)
; CHECK-NEXT: lw a2, 0(a2)
; CHECK-NEXT: pmulhu.h a1, a1, a2
; CHECK-NEXT: sw a1, 0(a0)
; CHECK-NEXT: ret
%a = load <2 x i16>, ptr %a_ptr
%b = load <2 x i16>, ptr %b_ptr
%a_ext = zext <2 x i16> %a to <2 x i32>
%b_ext = zext <2 x i16> %b to <2 x i32>
%mul = mul <2 x i32> %a_ext, %b_ext
%shift = lshr <2 x i32> %mul, <i32 16, i32 16>
%res = trunc <2 x i32> %shift to <2 x i16>
store <2 x i16> %res, ptr %ret_ptr
ret void
}
; Test packed multiply high signed-unsigned for v2i16
define void @test_pmulhsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
; CHECK-LABEL: test_pmulhsu_h:
; CHECK: # %bb.0:
; CHECK-NEXT: lw a1, 0(a1)
; CHECK-NEXT: lw a2, 0(a2)
; CHECK-NEXT: pmulhsu.h a1, a1, a2
; CHECK-NEXT: sw a1, 0(a0)
; CHECK-NEXT: ret
%a = load <2 x i16>, ptr %a_ptr
%b = load <2 x i16>, ptr %b_ptr
%a_ext = sext <2 x i16> %a to <2 x i32>
%b_ext = zext <2 x i16> %b to <2 x i32>
%mul = mul <2 x i32> %a_ext, %b_ext
%shift = lshr <2 x i32> %mul, <i32 16, i32 16>
%res = trunc <2 x i32> %shift to <2 x i16>
store <2 x i16> %res, ptr %ret_ptr
ret void
}
define void @test_pmulhsu_h_commuted(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
; CHECK-LABEL: test_pmulhsu_h_commuted:
; CHECK: # %bb.0:
; CHECK-NEXT: lw a1, 0(a1)
; CHECK-NEXT: lw a2, 0(a2)
; CHECK-NEXT: pmulhsu.h a1, a2, a1
; CHECK-NEXT: sw a1, 0(a0)
; CHECK-NEXT: ret
%a = load <2 x i16>, ptr %a_ptr
%b = load <2 x i16>, ptr %b_ptr
%a_ext = zext <2 x i16> %a to <2 x i32>
%b_ext = sext <2 x i16> %b to <2 x i32>
%mul = mul <2 x i32> %a_ext, %b_ext
%shift = lshr <2 x i32> %mul, <i32 16, i32 16>
%res = trunc <2 x i32> %shift to <2 x i16>
store <2 x i16> %res, ptr %ret_ptr
ret void
}

View File

@@ -993,3 +993,158 @@ define void @test_psra_ws_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
store <2 x i32> %res, ptr %ret_ptr
ret void
}
; Test packed multiply high signed
define void @test_pmulh_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
; CHECK-LABEL: test_pmulh_h:
; CHECK: # %bb.0:
; CHECK-NEXT: ld a1, 0(a1)
; CHECK-NEXT: ld a2, 0(a2)
; CHECK-NEXT: pmulh.h a1, a1, a2
; CHECK-NEXT: sd a1, 0(a0)
; CHECK-NEXT: ret
%a = load <4 x i16>, ptr %a_ptr
%b = load <4 x i16>, ptr %b_ptr
%a_ext = sext <4 x i16> %a to <4 x i32>
%b_ext = sext <4 x i16> %b to <4 x i32>
%mul = mul <4 x i32> %a_ext, %b_ext
%shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
%res = trunc <4 x i32> %shift to <4 x i16>
store <4 x i16> %res, ptr %ret_ptr
ret void
}
define void @test_pmulh_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
; CHECK-LABEL: test_pmulh_w:
; CHECK: # %bb.0:
; CHECK-NEXT: ld a1, 0(a1)
; CHECK-NEXT: ld a2, 0(a2)
; CHECK-NEXT: pmulh.w a1, a1, a2
; CHECK-NEXT: sd a1, 0(a0)
; CHECK-NEXT: ret
%a = load <2 x i32>, ptr %a_ptr
%b = load <2 x i32>, ptr %b_ptr
%a_ext = sext <2 x i32> %a to <2 x i64>
%b_ext = sext <2 x i32> %b to <2 x i64>
%mul = mul <2 x i64> %a_ext, %b_ext
%shift = lshr <2 x i64> %mul, <i64 32, i64 32>
%res = trunc <2 x i64> %shift to <2 x i32>
store <2 x i32> %res, ptr %ret_ptr
ret void
}
; Test packed multiply high unsigned
define void @test_pmulhu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
; CHECK-LABEL: test_pmulhu_h:
; CHECK: # %bb.0:
; CHECK-NEXT: ld a1, 0(a1)
; CHECK-NEXT: ld a2, 0(a2)
; CHECK-NEXT: pmulhu.h a1, a1, a2
; CHECK-NEXT: sd a1, 0(a0)
; CHECK-NEXT: ret
%a = load <4 x i16>, ptr %a_ptr
%b = load <4 x i16>, ptr %b_ptr
%a_ext = zext <4 x i16> %a to <4 x i32>
%b_ext = zext <4 x i16> %b to <4 x i32>
%mul = mul <4 x i32> %a_ext, %b_ext
%shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
%res = trunc <4 x i32> %shift to <4 x i16>
store <4 x i16> %res, ptr %ret_ptr
ret void
}
define void @test_pmulhu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
; CHECK-LABEL: test_pmulhu_w:
; CHECK: # %bb.0:
; CHECK-NEXT: ld a1, 0(a1)
; CHECK-NEXT: ld a2, 0(a2)
; CHECK-NEXT: pmulhu.w a1, a1, a2
; CHECK-NEXT: sd a1, 0(a0)
; CHECK-NEXT: ret
%a = load <2 x i32>, ptr %a_ptr
%b = load <2 x i32>, ptr %b_ptr
%a_ext = zext <2 x i32> %a to <2 x i64>
%b_ext = zext <2 x i32> %b to <2 x i64>
%mul = mul <2 x i64> %a_ext, %b_ext
%shift = lshr <2 x i64> %mul, <i64 32, i64 32>
%res = trunc <2 x i64> %shift to <2 x i32>
store <2 x i32> %res, ptr %ret_ptr
ret void
}
; Test packed multiply high signed-unsigned
define void @test_pmulhsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
; CHECK-LABEL: test_pmulhsu_h:
; CHECK: # %bb.0:
; CHECK-NEXT: ld a1, 0(a1)
; CHECK-NEXT: ld a2, 0(a2)
; CHECK-NEXT: pmulhsu.h a1, a1, a2
; CHECK-NEXT: sd a1, 0(a0)
; CHECK-NEXT: ret
%a = load <4 x i16>, ptr %a_ptr
%b = load <4 x i16>, ptr %b_ptr
%a_ext = sext <4 x i16> %a to <4 x i32>
%b_ext = zext <4 x i16> %b to <4 x i32>
%mul = mul <4 x i32> %a_ext, %b_ext
%shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
%res = trunc <4 x i32> %shift to <4 x i16>
store <4 x i16> %res, ptr %ret_ptr
ret void
}
define void @test_pmulhsu_h_commuted(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
; CHECK-LABEL: test_pmulhsu_h_commuted:
; CHECK: # %bb.0:
; CHECK-NEXT: ld a1, 0(a1)
; CHECK-NEXT: ld a2, 0(a2)
; CHECK-NEXT: pmulhsu.h a1, a2, a1
; CHECK-NEXT: sd a1, 0(a0)
; CHECK-NEXT: ret
%a = load <4 x i16>, ptr %a_ptr
%b = load <4 x i16>, ptr %b_ptr
%a_ext = zext <4 x i16> %a to <4 x i32>
%b_ext = sext <4 x i16> %b to <4 x i32>
%mul = mul <4 x i32> %a_ext, %b_ext
%shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
%res = trunc <4 x i32> %shift to <4 x i16>
store <4 x i16> %res, ptr %ret_ptr
ret void
}
define void @test_pmulhsu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
; CHECK-LABEL: test_pmulhsu_w:
; CHECK: # %bb.0:
; CHECK-NEXT: ld a1, 0(a1)
; CHECK-NEXT: ld a2, 0(a2)
; CHECK-NEXT: pmulhsu.w a1, a1, a2
; CHECK-NEXT: sd a1, 0(a0)
; CHECK-NEXT: ret
%a = load <2 x i32>, ptr %a_ptr
%b = load <2 x i32>, ptr %b_ptr
%a_ext = sext <2 x i32> %a to <2 x i64>
%b_ext = zext <2 x i32> %b to <2 x i64>
%mul = mul <2 x i64> %a_ext, %b_ext
%shift = lshr <2 x i64> %mul, <i64 32, i64 32>
%res = trunc <2 x i64> %shift to <2 x i32>
store <2 x i32> %res, ptr %ret_ptr
ret void
}
define void @test_pmulhsu_w_commuted(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
; CHECK-LABEL: test_pmulhsu_w_commuted:
; CHECK: # %bb.0:
; CHECK-NEXT: ld a1, 0(a1)
; CHECK-NEXT: ld a2, 0(a2)
; CHECK-NEXT: pmulhsu.w a1, a2, a1
; CHECK-NEXT: sd a1, 0(a0)
; CHECK-NEXT: ret
%a = load <2 x i32>, ptr %a_ptr
%b = load <2 x i32>, ptr %b_ptr
%a_ext = zext <2 x i32> %a to <2 x i64>
%b_ext = sext <2 x i32> %b to <2 x i64>
%mul = mul <2 x i64> %a_ext, %b_ext
%shift = lshr <2 x i64> %mul, <i64 32, i64 32>
%res = trunc <2 x i64> %shift to <2 x i32>
store <2 x i32> %res, ptr %ret_ptr
ret void
}