From a61252419779a6d4a5ebf71e7e2fc4adc75cfddd Mon Sep 17 00:00:00 2001 From: Alexey Bataev Date: Sun, 7 Apr 2024 09:51:47 -0400 Subject: [PATCH] [SLP]Fix the cost of the reduction result to the final type. Need to fix the way the cost is calculated, otherwise wrong cast opcode can be selected and lead to the over-optimistic vector cost. Plus, need to take into account reduction type size. Reviewers: RKSimon Reviewed By: RKSimon Pull Request: https://github.com/llvm/llvm-project/pull/87528 --- .../Transforms/Vectorize/SLPVectorizer.cpp | 8 +++--- .../SLPVectorizer/RISCV/reductions.ll | 18 ++++++++++--- .../X86/minbitwidth-drop-wrapping-flags.ll | 25 +++++++++++-------- .../X86/minbitwidth-transformed-operand.ll | 2 +- 4 files changed, 35 insertions(+), 18 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 332877f35081..6a662b2791bd 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -9824,11 +9824,13 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef VectorizedVals) { if (BWIt != MinBWs.end()) { Type *DstTy = Root.Scalars.front()->getType(); unsigned OriginalSz = DL->getTypeSizeInBits(DstTy); - if (OriginalSz != BWIt->second.first) { + unsigned SrcSz = + ReductionBitWidth == 0 ? BWIt->second.first : ReductionBitWidth; + if (OriginalSz != SrcSz) { unsigned Opcode = Instruction::Trunc; - if (OriginalSz < BWIt->second.first) + if (OriginalSz > SrcSz) Opcode = BWIt->second.second ? Instruction::SExt : Instruction::ZExt; - Type *SrcTy = IntegerType::get(DstTy->getContext(), BWIt->second.first); + Type *SrcTy = IntegerType::get(DstTy->getContext(), SrcSz); Cost += TTI->getCastInstrCost(Opcode, DstTy, SrcTy, TTI::CastContextHint::None, TTI::TCK_RecipThroughput); diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll index 500f10659f04..1e7eb4a41672 100644 --- a/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll +++ b/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll @@ -801,10 +801,20 @@ entry: define i64 @red_zext_ld_4xi64(ptr %ptr) { ; CHECK-LABEL: @red_zext_ld_4xi64( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[TMP0:%.*]] = load <4 x i8>, ptr [[PTR:%.*]], align 1 -; CHECK-NEXT: [[TMP1:%.*]] = zext <4 x i8> [[TMP0]] to <4 x i16> -; CHECK-NEXT: [[TMP2:%.*]] = call i16 @llvm.vector.reduce.add.v4i16(<4 x i16> [[TMP1]]) -; CHECK-NEXT: [[TMP3:%.*]] = zext i16 [[TMP2]] to i64 +; CHECK-NEXT: [[LD0:%.*]] = load i8, ptr [[PTR:%.*]], align 1 +; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[LD0]] to i64 +; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 1 +; CHECK-NEXT: [[LD1:%.*]] = load i8, ptr [[GEP]], align 1 +; CHECK-NEXT: [[ZEXT_1:%.*]] = zext i8 [[LD1]] to i64 +; CHECK-NEXT: [[ADD_1:%.*]] = add nuw nsw i64 [[ZEXT]], [[ZEXT_1]] +; CHECK-NEXT: [[GEP_1:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 2 +; CHECK-NEXT: [[LD2:%.*]] = load i8, ptr [[GEP_1]], align 1 +; CHECK-NEXT: [[ZEXT_2:%.*]] = zext i8 [[LD2]] to i64 +; CHECK-NEXT: [[ADD_2:%.*]] = add nuw nsw i64 [[ADD_1]], [[ZEXT_2]] +; CHECK-NEXT: [[GEP_2:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 3 +; CHECK-NEXT: [[LD3:%.*]] = load i8, ptr [[GEP_2]], align 1 +; CHECK-NEXT: [[ZEXT_3:%.*]] = zext i8 [[LD3]] to i64 +; CHECK-NEXT: [[TMP3:%.*]] = add nuw nsw i64 [[ADD_2]], [[ZEXT_3]] ; CHECK-NEXT: ret i64 [[TMP3]] ; entry: diff --git a/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-drop-wrapping-flags.ll b/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-drop-wrapping-flags.ll index 44738aa1a674..a8d481a3e28a 100644 --- a/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-drop-wrapping-flags.ll +++ b/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-drop-wrapping-flags.ll @@ -5,17 +5,22 @@ define i32 @test() { ; CHECK-LABEL: define i32 @test() { ; CHECK-NEXT: entry: ; CHECK-NEXT: [[A_PROMOTED:%.*]] = load i8, ptr null, align 1 -; CHECK-NEXT: [[TMP0:%.*]] = insertelement <4 x i8> poison, i8 [[A_PROMOTED]], i32 0 -; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x i8> [[TMP0]], <4 x i8> poison, <4 x i32> zeroinitializer -; CHECK-NEXT: [[TMP2:%.*]] = add <4 x i8> [[TMP1]], zeroinitializer -; CHECK-NEXT: [[TMP3:%.*]] = or <4 x i8> [[TMP1]], zeroinitializer -; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <4 x i8> [[TMP2]], <4 x i8> [[TMP3]], <4 x i32> -; CHECK-NEXT: [[TMP5:%.*]] = zext <4 x i8> [[TMP4]] to <4 x i16> -; CHECK-NEXT: [[TMP6:%.*]] = add <4 x i16> [[TMP5]], -; CHECK-NEXT: [[TMP7:%.*]] = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> [[TMP6]]) -; CHECK-NEXT: [[TMP8:%.*]] = zext i16 [[TMP7]] to i32 +; CHECK-NEXT: [[DEC_4:%.*]] = add i8 [[A_PROMOTED]], 0 +; CHECK-NEXT: [[CONV_I_4:%.*]] = zext i8 [[DEC_4]] to i32 +; CHECK-NEXT: [[SUB_I_4:%.*]] = add nuw nsw i32 [[CONV_I_4]], 0 +; CHECK-NEXT: [[DEC_5:%.*]] = add i8 [[A_PROMOTED]], 0 +; CHECK-NEXT: [[CONV_I_5:%.*]] = zext i8 [[DEC_5]] to i32 +; CHECK-NEXT: [[SUB_I_5:%.*]] = add nuw nsw i32 [[CONV_I_5]], 65535 +; CHECK-NEXT: [[TMP0:%.*]] = or i32 [[SUB_I_4]], [[SUB_I_5]] +; CHECK-NEXT: [[DEC_6:%.*]] = or i8 [[A_PROMOTED]], 0 +; CHECK-NEXT: [[CONV_I_6:%.*]] = zext i8 [[DEC_6]] to i32 +; CHECK-NEXT: [[SUB_I_6:%.*]] = add nuw nsw i32 [[CONV_I_6]], 0 +; CHECK-NEXT: [[TMP1:%.*]] = or i32 [[TMP0]], [[SUB_I_6]] +; CHECK-NEXT: [[TMP10:%.*]] = or i8 [[A_PROMOTED]], 0 +; CHECK-NEXT: [[CONV_I_7:%.*]] = zext i8 [[TMP10]] to i32 +; CHECK-NEXT: [[SUB_I_7:%.*]] = add nuw nsw i32 [[CONV_I_7]], 0 +; CHECK-NEXT: [[TMP8:%.*]] = or i32 [[TMP1]], [[SUB_I_7]] ; CHECK-NEXT: [[TMP9:%.*]] = and i32 [[TMP8]], 65535 -; CHECK-NEXT: [[TMP10:%.*]] = extractelement <4 x i8> [[TMP4]], i32 3 ; CHECK-NEXT: store i8 [[TMP10]], ptr null, align 1 ; CHECK-NEXT: [[CALL3:%.*]] = tail call i32 (ptr, ...) null(ptr null, i32 [[TMP9]]) ; CHECK-NEXT: ret i32 0 diff --git a/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-transformed-operand.ll b/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-transformed-operand.ll index 4acd63078b82..4af69dff179e 100644 --- a/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-transformed-operand.ll +++ b/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-transformed-operand.ll @@ -1,5 +1,5 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4 -; RUN: opt -passes=slp-vectorizer -S -slp-threshold=-6 -mtriple=x86_64-unknown-linux-gnu < %s | FileCheck %s +; RUN: opt -passes=slp-vectorizer -S -slp-threshold=-7 -mtriple=x86_64-unknown-linux-gnu < %s | FileCheck %s define void @test(i64 %d.promoted.i) { ; CHECK-LABEL: define void @test(