[X86][BF16] Improve float -> bfloat lowering under AVX512BF16 and AVXNECONVERT (#78042)

This commit is contained in:
Phoebe Wang
2024-01-17 10:09:26 +08:00
committed by GitHub
parent 46a395d8c4
commit 9745c13ca8
3 changed files with 389 additions and 488 deletions

View File

@@ -21523,9 +21523,19 @@ static SDValue LowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) {
SDValue X86TargetLowering::LowerFP_TO_BF16(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
MVT SVT = Op.getOperand(0).getSimpleValueType();
if (SVT == MVT::f32 && (Subtarget.hasBF16() || Subtarget.hasAVXNECONVERT())) {
SDValue Res;
Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4f32, Op.getOperand(0));
Res = DAG.getNode(X86ISD::CVTNEPS2BF16, DL, MVT::v8bf16, Res);
Res = DAG.getBitcast(MVT::v8i16, Res);
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i16, Res,
DAG.getIntPtrConstant(0, DL));
}
MakeLibCallOptions CallOptions;
RTLIB::Libcall LC =
RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::bf16);
RTLIB::Libcall LC = RTLIB::getFPROUND(SVT, MVT::bf16);
SDValue Res =
makeLibCall(DAG, LC, MVT::f16, Op.getOperand(0), CallOptions, DL).first;
return DAG.getBitcast(MVT::i16, Res);

View File

@@ -8331,6 +8331,10 @@ let Predicates = [HasAVXNECONVERT] in {
f256mem>, T8;
defm VCVTNEPS2BF16 : VCVTNEPS2BF16_BASE, VEX, T8, XS, ExplicitVEXPrefix;
// Register form: select VCVTNEPS2BF16rr for an X86cvtneps2bf16 node that takes
// a v4f32 register operand and produces a v8bf16 result.
// NOTE(review): the source operand is matched as VR128X but handed to the
// instruction as VR128 -- confirm the register-class difference is intended.
def : Pat<(v8bf16 (X86cvtneps2bf16 (v4f32 VR128X:$src))),
(VCVTNEPS2BF16rr VR128:$src)>;
// Memory form: select VCVTNEPS2BF16rm for an X86cvtneps2bf16 node whose
// operand is a loadv4f32 of addr:$src, producing a v8bf16 result.
def : Pat<(v8bf16 (X86cvtneps2bf16 (loadv4f32 addr:$src))),
(VCVTNEPS2BF16rm addr:$src)>;
// 256-bit register form: select VCVTNEPS2BF16Yrr for an X86vfpround node that
// takes a v8f32 operand in VR256 and produces a v8bf16 result.
def : Pat<(v8bf16 (X86vfpround (v8f32 VR256:$src))),
(VCVTNEPS2BF16Yrr VR256:$src)>;
def : Pat<(v8bf16 (X86vfpround (loadv8f32 addr:$src))),

File diff suppressed because it is too large Load Diff