mirror of
https://github.com/intel/llvm.git
synced 2026-02-05 13:21:04 +08:00
[mlir][Vector] Add LLVM lowering for masked reductions
This patch adds the conversion patterns to lower masked reduction operations to the corresponding vp intrinsics in LLVM. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D142177
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
|
||||
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
@@ -408,15 +409,154 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Reduction neutral classes for overloading.
|
||||
class ReductionNeutralZero {};
|
||||
class ReductionNeutralIntOne {};
|
||||
class ReductionNeutralFPOne {};
|
||||
class ReductionNeutralAllOnes {};
|
||||
class ReductionNeutralSIntMin {};
|
||||
class ReductionNeutralUIntMin {};
|
||||
class ReductionNeutralSIntMax {};
|
||||
class ReductionNeutralUIntMax {};
|
||||
class ReductionNeutralFPMin {};
|
||||
class ReductionNeutralFPMax {};
|
||||
|
||||
/// Create the reduction neutral zero value.
|
||||
static Value createReductionNeutralValue(ReductionNeutralZero neutral,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type llvmType) {
|
||||
return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
|
||||
rewriter.getZeroAttr(llvmType));
|
||||
}
|
||||
|
||||
/// Create the reduction neutral integer one value.
|
||||
static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type llvmType) {
|
||||
return rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
|
||||
}
|
||||
|
||||
/// Create the reduction neutral fp one value.
|
||||
static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type llvmType) {
|
||||
return rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
|
||||
}
|
||||
|
||||
/// Create the reduction neutral all-ones value.
|
||||
static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type llvmType) {
|
||||
return rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvmType,
|
||||
rewriter.getIntegerAttr(
|
||||
llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
|
||||
}
|
||||
|
||||
/// Create the reduction neutral signed int minimum value.
|
||||
static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type llvmType) {
|
||||
return rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvmType,
|
||||
rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
|
||||
llvmType.getIntOrFloatBitWidth())));
|
||||
}
|
||||
|
||||
/// Create the reduction neutral unsigned int minimum value.
|
||||
static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type llvmType) {
|
||||
return rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvmType,
|
||||
rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
|
||||
llvmType.getIntOrFloatBitWidth())));
|
||||
}
|
||||
|
||||
/// Create the reduction neutral signed int maximum value.
|
||||
static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type llvmType) {
|
||||
return rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvmType,
|
||||
rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
|
||||
llvmType.getIntOrFloatBitWidth())));
|
||||
}
|
||||
|
||||
/// Create the reduction neutral unsigned int maximum value.
|
||||
static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type llvmType) {
|
||||
return rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvmType,
|
||||
rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
|
||||
llvmType.getIntOrFloatBitWidth())));
|
||||
}
|
||||
|
||||
/// Create the reduction neutral fp minimum value.
|
||||
static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type llvmType) {
|
||||
auto floatType = llvmType.cast<FloatType>();
|
||||
return rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvmType,
|
||||
rewriter.getFloatAttr(
|
||||
llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
|
||||
/*Negative=*/false)));
|
||||
}
|
||||
|
||||
/// Create the reduction neutral fp maximum value.
|
||||
static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type llvmType) {
|
||||
auto floatType = llvmType.cast<FloatType>();
|
||||
return rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvmType,
|
||||
rewriter.getFloatAttr(
|
||||
llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
|
||||
/*Negative=*/true)));
|
||||
}
|
||||
|
||||
/// Returns `accumulator` if it has a valid value. Otherwise, creates and
|
||||
/// returns a new accumulator value using `ReductionNeutral`.
|
||||
template <class ReductionNeutral>
|
||||
static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type llvmType,
|
||||
Value accumulator) {
|
||||
if (accumulator)
|
||||
return accumulator;
|
||||
|
||||
return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
|
||||
llvmType);
|
||||
}
|
||||
|
||||
/// Creates a constant value with the 1-D vector shape provided in `llvmType`.
|
||||
/// This is used as effective vector length by some intrinsics supporting
|
||||
/// dynamic vector lengths at runtime.
|
||||
static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type llvmType) {
|
||||
VectorType vType = cast<VectorType>(llvmType);
|
||||
auto vShape = vType.getShape();
|
||||
assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
|
||||
|
||||
return rewriter.create<LLVM::ConstantOp>(
|
||||
loc, rewriter.getI32Type(),
|
||||
rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
|
||||
}
|
||||
|
||||
/// Helper method to lower a `vector.reduction` op that performs an arithmetic
|
||||
/// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
|
||||
/// and `ScalarOp` is the scalar operation used to add the accumulation value if
|
||||
/// non-null.
|
||||
template <class VectorOp, class ScalarOp>
|
||||
template <class LLVMRedIntrinOp, class ScalarOp>
|
||||
static Value createIntegerReductionArithmeticOpLowering(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
|
||||
Value vectorOperand, Value accumulator) {
|
||||
Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
|
||||
|
||||
Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
|
||||
|
||||
if (accumulator)
|
||||
result = rewriter.create<ScalarOp>(loc, accumulator, result);
|
||||
return result;
|
||||
@@ -426,11 +566,11 @@ static Value createIntegerReductionArithmeticOpLowering(
|
||||
/// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
|
||||
/// intrinsic to use and `predicate` is the predicate to use to compare+combine
|
||||
/// the accumulator value if non-null.
|
||||
template <class VectorOp>
|
||||
template <class LLVMRedIntrinOp>
|
||||
static Value createIntegerReductionComparisonOpLowering(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
|
||||
Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
|
||||
Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
|
||||
Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
|
||||
if (accumulator) {
|
||||
Value cmp =
|
||||
rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
|
||||
@@ -460,6 +600,91 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
|
||||
return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
|
||||
}
|
||||
|
||||
template <class LLVMRedIntrinOp>
|
||||
static Value createFPReductionComparisonOpLowering(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
|
||||
Value vectorOperand, Value accumulator, bool isMin) {
|
||||
Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
|
||||
|
||||
if (accumulator)
|
||||
result = createMinMaxF(rewriter, loc, result, accumulator, /*isMin=*/isMin);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Overloaded methods to lower a reduction to an llvm instrinsic that requires
|
||||
/// a start value. This start value format spans across fp reductions without
|
||||
/// mask and all the masked reduction intrinsics.
|
||||
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
|
||||
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type llvmType,
|
||||
Value vectorOperand,
|
||||
Value accumulator) {
|
||||
accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
|
||||
llvmType, accumulator);
|
||||
return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
|
||||
/*startValue=*/accumulator,
|
||||
vectorOperand);
|
||||
}
|
||||
|
||||
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
|
||||
static Value
|
||||
lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Type llvmType, Value vectorOperand,
|
||||
Value accumulator, bool reassociateFPReds) {
|
||||
accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
|
||||
llvmType, accumulator);
|
||||
return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
|
||||
/*startValue=*/accumulator,
|
||||
vectorOperand, reassociateFPReds);
|
||||
}
|
||||
|
||||
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
|
||||
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type llvmType,
|
||||
Value vectorOperand,
|
||||
Value accumulator, Value mask) {
|
||||
accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
|
||||
llvmType, accumulator);
|
||||
Value vectorLength =
|
||||
createVectorLengthValue(rewriter, loc, vectorOperand.getType());
|
||||
return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
|
||||
/*startValue=*/accumulator,
|
||||
vectorOperand, mask, vectorLength);
|
||||
}
|
||||
|
||||
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
|
||||
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type llvmType,
|
||||
Value vectorOperand,
|
||||
Value accumulator, Value mask,
|
||||
bool reassociateFPReds) {
|
||||
accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
|
||||
llvmType, accumulator);
|
||||
Value vectorLength =
|
||||
createVectorLengthValue(rewriter, loc, vectorOperand.getType());
|
||||
return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
|
||||
/*startValue=*/accumulator,
|
||||
vectorOperand, mask, vectorLength,
|
||||
reassociateFPReds);
|
||||
}
|
||||
|
||||
template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
|
||||
class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
|
||||
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type llvmType,
|
||||
Value vectorOperand,
|
||||
Value accumulator, Value mask) {
|
||||
if (llvmType.isIntOrIndex())
|
||||
return lowerReductionWithStartValue<LLVMIntVPRedIntrinOp,
|
||||
IntReductionNeutral>(
|
||||
rewriter, loc, llvmType, vectorOperand, accumulator, mask);
|
||||
|
||||
// FP dispatch.
|
||||
return lowerReductionWithStartValue<LLVMFPVPRedIntrinOp, FPReductionNeutral>(
|
||||
rewriter, loc, llvmType, vectorOperand, accumulator, mask);
|
||||
}
|
||||
|
||||
/// Conversion pattern for all vector reductions.
|
||||
class VectorReductionOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::ReductionOp> {
|
||||
@@ -478,6 +703,12 @@ public:
|
||||
Value operand = adaptor.getVector();
|
||||
Value acc = adaptor.getAcc();
|
||||
Location loc = reductionOp.getLoc();
|
||||
|
||||
// Masked reductions are lowered separately.
|
||||
auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
|
||||
if (maskableOp.isMasked())
|
||||
return failure();
|
||||
|
||||
if (eltType.isIntOrIndex()) {
|
||||
// Integer reductions: add/mul/min/max/and/or/xor.
|
||||
Value result;
|
||||
@@ -544,45 +775,31 @@ public:
|
||||
return failure();
|
||||
|
||||
// Floating-point reductions: add/mul/min/max
|
||||
Value result;
|
||||
if (kind == vector::CombiningKind::ADD) {
|
||||
// Optional accumulator (or zero).
|
||||
Value acc = adaptor.getOperands().size() > 1
|
||||
? adaptor.getOperands()[1]
|
||||
: rewriter.create<LLVM::ConstantOp>(
|
||||
reductionOp->getLoc(), llvmType,
|
||||
rewriter.getZeroAttr(eltType));
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
|
||||
reductionOp, llvmType, acc, operand,
|
||||
rewriter.getBoolAttr(reassociateFPReductions));
|
||||
result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
|
||||
ReductionNeutralZero>(
|
||||
rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
|
||||
} else if (kind == vector::CombiningKind::MUL) {
|
||||
// Optional accumulator (or one).
|
||||
Value acc = adaptor.getOperands().size() > 1
|
||||
? adaptor.getOperands()[1]
|
||||
: rewriter.create<LLVM::ConstantOp>(
|
||||
reductionOp->getLoc(), llvmType,
|
||||
rewriter.getFloatAttr(eltType, 1.0));
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
|
||||
reductionOp, llvmType, acc, operand,
|
||||
rewriter.getBoolAttr(reassociateFPReductions));
|
||||
result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
|
||||
ReductionNeutralFPOne>(
|
||||
rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
|
||||
} else if (kind == vector::CombiningKind::MINF) {
|
||||
// FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
|
||||
// NaNs/-0.0/+0.0 in the same way.
|
||||
Value result =
|
||||
rewriter.create<LLVM::vector_reduce_fmin>(loc, llvmType, operand);
|
||||
if (acc)
|
||||
result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/true);
|
||||
rewriter.replaceOp(reductionOp, result);
|
||||
result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
|
||||
rewriter, loc, llvmType, operand, acc,
|
||||
/*isMin=*/true);
|
||||
} else if (kind == vector::CombiningKind::MAXF) {
|
||||
// FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
|
||||
// NaNs/-0.0/+0.0 in the same way.
|
||||
Value result =
|
||||
rewriter.create<LLVM::vector_reduce_fmax>(loc, llvmType, operand);
|
||||
if (acc)
|
||||
result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/false);
|
||||
rewriter.replaceOp(reductionOp, result);
|
||||
result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
|
||||
rewriter, loc, llvmType, operand, acc,
|
||||
/*isMin=*/false);
|
||||
} else
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(reductionOp, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -590,6 +807,127 @@ private:
|
||||
const bool reassociateFPReductions;
|
||||
};
|
||||
|
||||
/// Base class to convert a `vector.mask` operation while matching traits
|
||||
/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
|
||||
/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
|
||||
/// method performs a second match against the maskable operation `MaskedOp`.
|
||||
/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
|
||||
/// implemented by the concrete conversion classes. This method can match
|
||||
/// against specific traits of the `vector.mask` and the maskable operation. It
|
||||
/// must replace the `vector.mask` operation.
|
||||
template <class MaskedOp>
|
||||
class VectorMaskOpConversionBase
|
||||
: public ConvertOpToLLVMPattern<vector::MaskOp> {
|
||||
public:
|
||||
using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override final {
|
||||
// Match against the maskable operation kind.
|
||||
Operation *maskableOp = maskOp.getMaskableOp();
|
||||
if (!isa<MaskedOp>(maskableOp))
|
||||
return failure();
|
||||
return matchAndRewriteMaskableOp(
|
||||
maskOp, cast<MaskedOp>(maskOp.getMaskableOp()), rewriter);
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual LogicalResult
|
||||
matchAndRewriteMaskableOp(vector::MaskOp maskOp,
|
||||
vector::MaskableOpInterface maskableOp,
|
||||
ConversionPatternRewriter &rewriter) const = 0;
|
||||
};
|
||||
|
||||
class MaskedReductionOpConversion
|
||||
: public VectorMaskOpConversionBase<vector::ReductionOp> {
|
||||
|
||||
public:
|
||||
using VectorMaskOpConversionBase<
|
||||
vector::ReductionOp>::VectorMaskOpConversionBase;
|
||||
|
||||
virtual LogicalResult matchAndRewriteMaskableOp(
|
||||
vector::MaskOp maskOp, MaskableOpInterface maskableOp,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
|
||||
auto kind = reductionOp.getKind();
|
||||
Type eltType = reductionOp.getDest().getType();
|
||||
Type llvmType = typeConverter->convertType(eltType);
|
||||
Value operand = reductionOp.getVector();
|
||||
Value acc = reductionOp.getAcc();
|
||||
Location loc = reductionOp.getLoc();
|
||||
|
||||
Value result;
|
||||
switch (kind) {
|
||||
case vector::CombiningKind::ADD:
|
||||
result = lowerReductionWithStartValue<
|
||||
LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
|
||||
ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
|
||||
maskOp.getMask());
|
||||
break;
|
||||
case vector::CombiningKind::MUL:
|
||||
result = lowerReductionWithStartValue<
|
||||
LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
|
||||
ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
|
||||
maskOp.getMask());
|
||||
break;
|
||||
case vector::CombiningKind::MINUI:
|
||||
result = lowerReductionWithStartValue<LLVM::VPReduceUMinOp,
|
||||
ReductionNeutralUIntMax>(
|
||||
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
|
||||
break;
|
||||
case vector::CombiningKind::MINSI:
|
||||
result = lowerReductionWithStartValue<LLVM::VPReduceSMinOp,
|
||||
ReductionNeutralSIntMax>(
|
||||
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
|
||||
break;
|
||||
case vector::CombiningKind::MAXUI:
|
||||
result = lowerReductionWithStartValue<LLVM::VPReduceUMaxOp,
|
||||
ReductionNeutralUIntMin>(
|
||||
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
|
||||
break;
|
||||
case vector::CombiningKind::MAXSI:
|
||||
result = lowerReductionWithStartValue<LLVM::VPReduceSMaxOp,
|
||||
ReductionNeutralSIntMin>(
|
||||
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
|
||||
break;
|
||||
case vector::CombiningKind::AND:
|
||||
result = lowerReductionWithStartValue<LLVM::VPReduceAndOp,
|
||||
ReductionNeutralAllOnes>(
|
||||
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
|
||||
break;
|
||||
case vector::CombiningKind::OR:
|
||||
result = lowerReductionWithStartValue<LLVM::VPReduceOrOp,
|
||||
ReductionNeutralZero>(
|
||||
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
|
||||
break;
|
||||
case vector::CombiningKind::XOR:
|
||||
result = lowerReductionWithStartValue<LLVM::VPReduceXorOp,
|
||||
ReductionNeutralZero>(
|
||||
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
|
||||
break;
|
||||
case vector::CombiningKind::MINF:
|
||||
// FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
|
||||
// NaNs/-0.0/+0.0 in the same way.
|
||||
result = lowerReductionWithStartValue<LLVM::VPReduceFMinOp,
|
||||
ReductionNeutralFPMax>(
|
||||
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
|
||||
break;
|
||||
case vector::CombiningKind::MAXF:
|
||||
// FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
|
||||
// NaNs/-0.0/+0.0 in the same way.
|
||||
result = lowerReductionWithStartValue<LLVM::VPReduceFMaxOp,
|
||||
ReductionNeutralFPMin>(
|
||||
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
|
||||
break;
|
||||
}
|
||||
|
||||
// Replace `vector.mask` operation altogether.
|
||||
rewriter.replaceOp(maskOp, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class VectorShuffleOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::ShuffleOp> {
|
||||
public:
|
||||
@@ -1381,8 +1719,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
|
||||
VectorGatherOpConversion, VectorScatterOpConversion,
|
||||
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
|
||||
VectorSplatOpLowering, VectorSplatNdOpLowering,
|
||||
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering>(
|
||||
converter);
|
||||
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
|
||||
MaskedReductionOpConversion>(converter);
|
||||
// Transfer ops with rank > 1 are handled by VectorToSCF.
|
||||
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
// RUN: mlir-opt %s -convert-vector-to-llvm='use-opaque-pointers=1' | FileCheck %s
|
||||
// RUN: mlir-opt %s -convert-vector-to-llvm='reassociate-fp-reductions use-opaque-pointers=1' | FileCheck %s --check-prefix=REASSOC
|
||||
// RUN: mlir-opt %s -convert-vector-to-llvm='use-opaque-pointers=1' -split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt %s -convert-vector-to-llvm='reassociate-fp-reductions use-opaque-pointers=1' -split-input-file | FileCheck %s --check-prefix=REASSOC
|
||||
|
||||
//
|
||||
// CHECK-LABEL: @reduce_add_f32(
|
||||
// CHECK-SAME: %[[A:.*]]: vector<16xf32>)
|
||||
// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
|
||||
@@ -21,7 +20,8 @@ func.func @reduce_add_f32(%arg0: vector<16xf32>) -> f32 {
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
//
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @reduce_mul_f32(
|
||||
// CHECK-SAME: %[[A:.*]]: vector<16xf32>)
|
||||
// CHECK: %[[C:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
|
||||
@@ -40,3 +40,191 @@ func.func @reduce_mul_f32(%arg0: vector<16xf32>) -> f32 {
|
||||
%0 = vector.reduction <mul>, %arg0 : vector<16xf32> into f32
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @masked_reduce_add_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
|
||||
%0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @masked_reduce_add_f32(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
|
||||
// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
|
||||
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
|
||||
// CHECK: %[[VL:.*]] = llvm.mlir.constant(16 : i32) : i32
|
||||
// CHECK: "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
func.func @masked_reduce_mul_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
|
||||
%0 = vector.mask %mask { vector.reduction <mul>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @masked_reduce_mul_f32(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
|
||||
// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
|
||||
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
|
||||
// CHECK: %[[VL:.*]] = llvm.mlir.constant(16 : i32) : i32
|
||||
// CHECK: "llvm.intr.vp.reduce.fmul"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
func.func @masked_reduce_minf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
|
||||
%0 = vector.mask %mask { vector.reduction <minf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @masked_reduce_minf_f32(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
|
||||
// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
|
||||
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0xFFC00000 : f32) : f32
|
||||
// CHECK: %[[VL:.*]] = llvm.mlir.constant(16 : i32) : i32
|
||||
// CHECK: "llvm.intr.vp.reduce.fmin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
|
||||
|
||||
// -----
|
||||
|
||||
func.func @masked_reduce_maxf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
|
||||
%0 = vector.mask %mask { vector.reduction <maxf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @masked_reduce_maxf_f32(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
|
||||
// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
|
||||
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
|
||||
// CHECK: %[[VL:.*]] = llvm.mlir.constant(16 : i32) : i32
|
||||
// CHECK: "llvm.intr.vp.reduce.fmax"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
|
||||
|
||||
// -----
|
||||
|
||||
func.func @masked_reduce_add_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
|
||||
%0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
|
||||
return %0 : i8
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @masked_reduce_add_i8(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
|
||||
// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
|
||||
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
|
||||
// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
|
||||
// CHECK: "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
func.func @masked_reduce_mul_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
|
||||
%0 = vector.mask %mask { vector.reduction <mul>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
|
||||
return %0 : i8
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @masked_reduce_mul_i8(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
|
||||
// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
|
||||
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(1 : i8) : i8
|
||||
// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
|
||||
// CHECK: %[[VAL_4:.*]] = "llvm.intr.vp.reduce.mul"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
|
||||
|
||||
// -----
|
||||
|
||||
func.func @masked_reduce_minui_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
|
||||
%0 = vector.mask %mask { vector.reduction <minui>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
|
||||
return %0 : i8
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @masked_reduce_minui_i8(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
|
||||
// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
|
||||
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-1 : i8) : i8
|
||||
// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
|
||||
// CHECK: "llvm.intr.vp.reduce.umin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
|
||||
|
||||
// -----
|
||||
|
||||
func.func @masked_reduce_maxui_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
|
||||
%0 = vector.mask %mask { vector.reduction <maxui>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
|
||||
return %0 : i8
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @masked_reduce_maxui_i8(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
|
||||
// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
|
||||
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
|
||||
// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
|
||||
// CHECK: "llvm.intr.vp.reduce.umax"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
|
||||
|
||||
// -----
|
||||
|
||||
func.func @masked_reduce_minsi_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
|
||||
%0 = vector.mask %mask { vector.reduction <minsi>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
|
||||
return %0 : i8
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @masked_reduce_minsi_i8(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
|
||||
// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
|
||||
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(127 : i8) : i8
|
||||
// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
|
||||
// CHECK: "llvm.intr.vp.reduce.smin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
|
||||
|
||||
// -----
|
||||
|
||||
func.func @masked_reduce_maxsi_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
|
||||
%0 = vector.mask %mask { vector.reduction <maxsi>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
|
||||
return %0 : i8
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @masked_reduce_maxsi_i8(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
|
||||
// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
|
||||
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-128 : i8) : i8
|
||||
// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
|
||||
// CHECK: "llvm.intr.vp.reduce.smax"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
|
||||
|
||||
// -----
|
||||
|
||||
func.func @masked_reduce_or_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
|
||||
%0 = vector.mask %mask { vector.reduction <or>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
|
||||
return %0 : i8
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @masked_reduce_or_i8(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
|
||||
// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
|
||||
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
|
||||
// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
|
||||
// CHECK: "llvm.intr.vp.reduce.or"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
func.func @masked_reduce_and_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
|
||||
%0 = vector.mask %mask { vector.reduction <and>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
|
||||
return %0 : i8
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @masked_reduce_and_i8(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
|
||||
// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
|
||||
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-1 : i8) : i8
|
||||
// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
|
||||
// CHECK: "llvm.intr.vp.reduce.and"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
|
||||
|
||||
// -----
|
||||
|
||||
func.func @masked_reduce_xor_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
|
||||
%0 = vector.mask %mask { vector.reduction <xor>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
|
||||
return %0 : i8
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @masked_reduce_xor_i8(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
|
||||
// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
|
||||
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
|
||||
// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
|
||||
// CHECK: "llvm.intr.vp.reduce.xor"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user