[mlir][Vector] Add LLVM lowering for masked reductions

This patch adds the conversion patterns to lower masked reduction
operations to the corresponding vp intrinsics in LLVM.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D142177
This commit is contained in:
Diego Caballero
2023-02-13 19:31:23 +00:00
parent a409f3c069
commit e9b82a5c4f
2 changed files with 564 additions and 38 deletions

View File

@@ -15,6 +15,7 @@
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
@@ -408,15 +409,154 @@ public:
}
};
/// Reduction neutral classes for overloading.
class ReductionNeutralZero {};
class ReductionNeutralIntOne {};
class ReductionNeutralFPOne {};
class ReductionNeutralAllOnes {};
class ReductionNeutralSIntMin {};
class ReductionNeutralUIntMin {};
class ReductionNeutralSIntMax {};
class ReductionNeutralUIntMax {};
class ReductionNeutralFPMin {};
class ReductionNeutralFPMax {};
/// Create the reduction neutral zero value.
static Value createReductionNeutralValue(ReductionNeutralZero neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
rewriter.getZeroAttr(llvmType));
}
/// Create the reduction neutral integer one value.
static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
return rewriter.create<LLVM::ConstantOp>(
loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
}
/// Create the reduction neutral fp one value.
static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
return rewriter.create<LLVM::ConstantOp>(
loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
}
/// Create the reduction neutral all-ones value.
static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
return rewriter.create<LLVM::ConstantOp>(
loc, llvmType,
rewriter.getIntegerAttr(
llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
}
/// Create the reduction neutral signed int minimum value.
static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
return rewriter.create<LLVM::ConstantOp>(
loc, llvmType,
rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
llvmType.getIntOrFloatBitWidth())));
}
/// Create the reduction neutral unsigned int minimum value.
static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
return rewriter.create<LLVM::ConstantOp>(
loc, llvmType,
rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
llvmType.getIntOrFloatBitWidth())));
}
/// Create the reduction neutral signed int maximum value.
static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
return rewriter.create<LLVM::ConstantOp>(
loc, llvmType,
rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
llvmType.getIntOrFloatBitWidth())));
}
/// Create the reduction neutral unsigned int maximum value.
static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
return rewriter.create<LLVM::ConstantOp>(
loc, llvmType,
rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
llvmType.getIntOrFloatBitWidth())));
}
/// Create the reduction neutral fp minimum value.
static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
auto floatType = llvmType.cast<FloatType>();
return rewriter.create<LLVM::ConstantOp>(
loc, llvmType,
rewriter.getFloatAttr(
llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
/*Negative=*/false)));
}
/// Create the reduction neutral fp maximum value.
static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
auto floatType = llvmType.cast<FloatType>();
return rewriter.create<LLVM::ConstantOp>(
loc, llvmType,
rewriter.getFloatAttr(
llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
/*Negative=*/true)));
}
/// Returns `accumulator` if it has a valid value. Otherwise, creates and
/// returns a new accumulator value using `ReductionNeutral`.
template <class ReductionNeutral>
static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
Location loc, Type llvmType,
Value accumulator) {
if (accumulator)
return accumulator;
return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
llvmType);
}
/// Creates a constant value with the 1-D vector shape provided in `llvmType`.
/// This is used as effective vector length by some intrinsics supporting
/// dynamic vector lengths at runtime.
static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
VectorType vType = cast<VectorType>(llvmType);
auto vShape = vType.getShape();
assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
return rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
}
/// Helper method to lower a `vector.reduction` op that performs an arithmetic
/// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
/// and `ScalarOp` is the scalar operation used to add the accumulation value if
/// non-null.
template <class VectorOp, class ScalarOp>
template <class LLVMRedIntrinOp, class ScalarOp>
static Value createIntegerReductionArithmeticOpLowering(
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
Value vectorOperand, Value accumulator) {
Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
if (accumulator)
result = rewriter.create<ScalarOp>(loc, accumulator, result);
return result;
@@ -426,11 +566,11 @@ static Value createIntegerReductionArithmeticOpLowering(
/// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
/// intrinsic to use and `predicate` is the predicate to use to compare+combine
/// the accumulator value if non-null.
template <class VectorOp>
template <class LLVMRedIntrinOp>
static Value createIntegerReductionComparisonOpLowering(
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
if (accumulator) {
Value cmp =
rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
@@ -460,6 +600,91 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
}
template <class LLVMRedIntrinOp>
static Value createFPReductionComparisonOpLowering(
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
Value vectorOperand, Value accumulator, bool isMin) {
Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
if (accumulator)
result = createMinMaxF(rewriter, loc, result, accumulator, /*isMin=*/isMin);
return result;
}
/// Overloaded methods to lower a reduction to an llvm instrinsic that requires
/// a start value. This start value format spans across fp reductions without
/// mask and all the masked reduction intrinsics.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
Location loc, Type llvmType,
Value vectorOperand,
Value accumulator) {
accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
llvmType, accumulator);
return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
/*startValue=*/accumulator,
vectorOperand);
}
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value
lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
Type llvmType, Value vectorOperand,
Value accumulator, bool reassociateFPReds) {
accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
llvmType, accumulator);
return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
/*startValue=*/accumulator,
vectorOperand, reassociateFPReds);
}
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
Location loc, Type llvmType,
Value vectorOperand,
Value accumulator, Value mask) {
accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
llvmType, accumulator);
Value vectorLength =
createVectorLengthValue(rewriter, loc, vectorOperand.getType());
return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
/*startValue=*/accumulator,
vectorOperand, mask, vectorLength);
}
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
Location loc, Type llvmType,
Value vectorOperand,
Value accumulator, Value mask,
bool reassociateFPReds) {
accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
llvmType, accumulator);
Value vectorLength =
createVectorLengthValue(rewriter, loc, vectorOperand.getType());
return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
/*startValue=*/accumulator,
vectorOperand, mask, vectorLength,
reassociateFPReds);
}
template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
Location loc, Type llvmType,
Value vectorOperand,
Value accumulator, Value mask) {
if (llvmType.isIntOrIndex())
return lowerReductionWithStartValue<LLVMIntVPRedIntrinOp,
IntReductionNeutral>(
rewriter, loc, llvmType, vectorOperand, accumulator, mask);
// FP dispatch.
return lowerReductionWithStartValue<LLVMFPVPRedIntrinOp, FPReductionNeutral>(
rewriter, loc, llvmType, vectorOperand, accumulator, mask);
}
/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion
: public ConvertOpToLLVMPattern<vector::ReductionOp> {
@@ -478,6 +703,12 @@ public:
Value operand = adaptor.getVector();
Value acc = adaptor.getAcc();
Location loc = reductionOp.getLoc();
// Masked reductions are lowered separately.
auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
if (maskableOp.isMasked())
return failure();
if (eltType.isIntOrIndex()) {
// Integer reductions: add/mul/min/max/and/or/xor.
Value result;
@@ -544,45 +775,31 @@ public:
return failure();
// Floating-point reductions: add/mul/min/max
Value result;
if (kind == vector::CombiningKind::ADD) {
// Optional accumulator (or zero).
Value acc = adaptor.getOperands().size() > 1
? adaptor.getOperands()[1]
: rewriter.create<LLVM::ConstantOp>(
reductionOp->getLoc(), llvmType,
rewriter.getZeroAttr(eltType));
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
reductionOp, llvmType, acc, operand,
rewriter.getBoolAttr(reassociateFPReductions));
result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
ReductionNeutralZero>(
rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
} else if (kind == vector::CombiningKind::MUL) {
// Optional accumulator (or one).
Value acc = adaptor.getOperands().size() > 1
? adaptor.getOperands()[1]
: rewriter.create<LLVM::ConstantOp>(
reductionOp->getLoc(), llvmType,
rewriter.getFloatAttr(eltType, 1.0));
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
reductionOp, llvmType, acc, operand,
rewriter.getBoolAttr(reassociateFPReductions));
result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
ReductionNeutralFPOne>(
rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
} else if (kind == vector::CombiningKind::MINF) {
// FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
// NaNs/-0.0/+0.0 in the same way.
Value result =
rewriter.create<LLVM::vector_reduce_fmin>(loc, llvmType, operand);
if (acc)
result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/true);
rewriter.replaceOp(reductionOp, result);
result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
rewriter, loc, llvmType, operand, acc,
/*isMin=*/true);
} else if (kind == vector::CombiningKind::MAXF) {
// FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
// NaNs/-0.0/+0.0 in the same way.
Value result =
rewriter.create<LLVM::vector_reduce_fmax>(loc, llvmType, operand);
if (acc)
result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/false);
rewriter.replaceOp(reductionOp, result);
result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
rewriter, loc, llvmType, operand, acc,
/*isMin=*/false);
} else
return failure();
rewriter.replaceOp(reductionOp, result);
return success();
}
@@ -590,6 +807,127 @@ private:
const bool reassociateFPReductions;
};
/// Base class to convert a `vector.mask` operation while matching traits
/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
/// method performs a second match against the maskable operation `MaskedOp`.
/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
/// implemented by the concrete conversion classes. This method can match
/// against specific traits of the `vector.mask` and the maskable operation. It
/// must replace the `vector.mask` operation.
template <class MaskedOp>
class VectorMaskOpConversionBase
: public ConvertOpToLLVMPattern<vector::MaskOp> {
public:
using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override final {
// Match against the maskable operation kind.
Operation *maskableOp = maskOp.getMaskableOp();
if (!isa<MaskedOp>(maskableOp))
return failure();
return matchAndRewriteMaskableOp(
maskOp, cast<MaskedOp>(maskOp.getMaskableOp()), rewriter);
}
protected:
virtual LogicalResult
matchAndRewriteMaskableOp(vector::MaskOp maskOp,
vector::MaskableOpInterface maskableOp,
ConversionPatternRewriter &rewriter) const = 0;
};
class MaskedReductionOpConversion
: public VectorMaskOpConversionBase<vector::ReductionOp> {
public:
using VectorMaskOpConversionBase<
vector::ReductionOp>::VectorMaskOpConversionBase;
virtual LogicalResult matchAndRewriteMaskableOp(
vector::MaskOp maskOp, MaskableOpInterface maskableOp,
ConversionPatternRewriter &rewriter) const override {
auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
auto kind = reductionOp.getKind();
Type eltType = reductionOp.getDest().getType();
Type llvmType = typeConverter->convertType(eltType);
Value operand = reductionOp.getVector();
Value acc = reductionOp.getAcc();
Location loc = reductionOp.getLoc();
Value result;
switch (kind) {
case vector::CombiningKind::ADD:
result = lowerReductionWithStartValue<
LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
maskOp.getMask());
break;
case vector::CombiningKind::MUL:
result = lowerReductionWithStartValue<
LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
maskOp.getMask());
break;
case vector::CombiningKind::MINUI:
result = lowerReductionWithStartValue<LLVM::VPReduceUMinOp,
ReductionNeutralUIntMax>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
case vector::CombiningKind::MINSI:
result = lowerReductionWithStartValue<LLVM::VPReduceSMinOp,
ReductionNeutralSIntMax>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
case vector::CombiningKind::MAXUI:
result = lowerReductionWithStartValue<LLVM::VPReduceUMaxOp,
ReductionNeutralUIntMin>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
case vector::CombiningKind::MAXSI:
result = lowerReductionWithStartValue<LLVM::VPReduceSMaxOp,
ReductionNeutralSIntMin>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
case vector::CombiningKind::AND:
result = lowerReductionWithStartValue<LLVM::VPReduceAndOp,
ReductionNeutralAllOnes>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
case vector::CombiningKind::OR:
result = lowerReductionWithStartValue<LLVM::VPReduceOrOp,
ReductionNeutralZero>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
case vector::CombiningKind::XOR:
result = lowerReductionWithStartValue<LLVM::VPReduceXorOp,
ReductionNeutralZero>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
case vector::CombiningKind::MINF:
// FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
// NaNs/-0.0/+0.0 in the same way.
result = lowerReductionWithStartValue<LLVM::VPReduceFMinOp,
ReductionNeutralFPMax>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
case vector::CombiningKind::MAXF:
// FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
// NaNs/-0.0/+0.0 in the same way.
result = lowerReductionWithStartValue<LLVM::VPReduceFMaxOp,
ReductionNeutralFPMin>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
}
// Replace `vector.mask` operation altogether.
rewriter.replaceOp(maskOp, result);
return success();
}
};
class VectorShuffleOpConversion
: public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
@@ -1381,8 +1719,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorGatherOpConversion, VectorScatterOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
VectorSplatOpLowering, VectorSplatNdOpLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering>(
converter);
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
MaskedReductionOpConversion>(converter);
// Transfer ops with rank > 1 are handled by VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}

View File

@@ -1,7 +1,6 @@
// RUN: mlir-opt %s -convert-vector-to-llvm='use-opaque-pointers=1' | FileCheck %s
// RUN: mlir-opt %s -convert-vector-to-llvm='reassociate-fp-reductions use-opaque-pointers=1' | FileCheck %s --check-prefix=REASSOC
// RUN: mlir-opt %s -convert-vector-to-llvm='use-opaque-pointers=1' -split-input-file | FileCheck %s
// RUN: mlir-opt %s -convert-vector-to-llvm='reassociate-fp-reductions use-opaque-pointers=1' -split-input-file | FileCheck %s --check-prefix=REASSOC
//
// CHECK-LABEL: @reduce_add_f32(
// CHECK-SAME: %[[A:.*]]: vector<16xf32>)
// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
@@ -21,7 +20,8 @@ func.func @reduce_add_f32(%arg0: vector<16xf32>) -> f32 {
return %0 : f32
}
//
// -----
// CHECK-LABEL: @reduce_mul_f32(
// CHECK-SAME: %[[A:.*]]: vector<16xf32>)
// CHECK: %[[C:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
@@ -40,3 +40,191 @@ func.func @reduce_mul_f32(%arg0: vector<16xf32>) -> f32 {
%0 = vector.reduction <mul>, %arg0 : vector<16xf32> into f32
return %0 : f32
}
// -----
func.func @masked_reduce_add_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
%0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
return %0 : f32
}
// CHECK-LABEL: func.func @masked_reduce_add_f32(
// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
// CHECK: %[[VL:.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
// -----
func.func @masked_reduce_mul_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
%0 = vector.mask %mask { vector.reduction <mul>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
return %0 : f32
}
// CHECK-LABEL: func.func @masked_reduce_mul_f32(
// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
// CHECK: %[[VL:.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: "llvm.intr.vp.reduce.fmul"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
// -----
func.func @masked_reduce_minf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
%0 = vector.mask %mask { vector.reduction <minf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
return %0 : f32
}
// CHECK-LABEL: func.func @masked_reduce_minf_f32(
// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0xFFC00000 : f32) : f32
// CHECK: %[[VL:.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: "llvm.intr.vp.reduce.fmin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
// -----
func.func @masked_reduce_maxf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
%0 = vector.mask %mask { vector.reduction <maxf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
return %0 : f32
}
// CHECK-LABEL: func.func @masked_reduce_maxf_f32(
// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
// CHECK: %[[VL:.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: "llvm.intr.vp.reduce.fmax"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
// -----
func.func @masked_reduce_add_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
%0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
return %0 : i8
}
// CHECK-LABEL: func.func @masked_reduce_add_i8(
// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
// -----
func.func @masked_reduce_mul_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
%0 = vector.mask %mask { vector.reduction <mul>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
return %0 : i8
}
// CHECK-LABEL: func.func @masked_reduce_mul_i8(
// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(1 : i8) : i8
// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: %[[VAL_4:.*]] = "llvm.intr.vp.reduce.mul"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
// -----
func.func @masked_reduce_minui_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
%0 = vector.mask %mask { vector.reduction <minui>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
return %0 : i8
}
// CHECK-LABEL: func.func @masked_reduce_minui_i8(
// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-1 : i8) : i8
// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: "llvm.intr.vp.reduce.umin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
// -----
func.func @masked_reduce_maxui_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
%0 = vector.mask %mask { vector.reduction <maxui>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
return %0 : i8
}
// CHECK-LABEL: func.func @masked_reduce_maxui_i8(
// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: "llvm.intr.vp.reduce.umax"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
// -----
func.func @masked_reduce_minsi_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
%0 = vector.mask %mask { vector.reduction <minsi>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
return %0 : i8
}
// CHECK-LABEL: func.func @masked_reduce_minsi_i8(
// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(127 : i8) : i8
// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: "llvm.intr.vp.reduce.smin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
// -----
func.func @masked_reduce_maxsi_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
%0 = vector.mask %mask { vector.reduction <maxsi>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
return %0 : i8
}
// CHECK-LABEL: func.func @masked_reduce_maxsi_i8(
// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-128 : i8) : i8
// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: "llvm.intr.vp.reduce.smax"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
// -----
func.func @masked_reduce_or_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
%0 = vector.mask %mask { vector.reduction <or>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
return %0 : i8
}
// CHECK-LABEL: func.func @masked_reduce_or_i8(
// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: "llvm.intr.vp.reduce.or"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
// -----
func.func @masked_reduce_and_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
%0 = vector.mask %mask { vector.reduction <and>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
return %0 : i8
}
// CHECK-LABEL: func.func @masked_reduce_and_i8(
// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-1 : i8) : i8
// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: "llvm.intr.vp.reduce.and"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
// -----
func.func @masked_reduce_xor_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
%0 = vector.mask %mask { vector.reduction <xor>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
return %0 : i8
}
// CHECK-LABEL: func.func @masked_reduce_xor_i8(
// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: "llvm.intr.vp.reduce.xor"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8