[mlir][Vector] Add fastmath flags to vector.reduction (#66905)

This revision pipes the fastmath attribute support through the
vector.reduction op. This seemingly simple first step already requires
quite a few genuflections, as well as file and builder reorganization.
In the process, retire the boolean reassoc flag deep in the LLVM dialect
builders and just use the fastmath attribute.

During conversions, templated builders for predicated intrinsics are
partially cleaned up. In the future, to finalize the cleanups, one
should consider adding fastmath to the VPIntrinsic ops.
This commit is contained in:
Nicolas Vasilache
2023-09-20 16:57:20 +02:00
committed by GitHub
parent ebefe83c09
commit 1b8b556443
15 changed files with 327 additions and 231 deletions

View File

@@ -654,32 +654,14 @@ class LLVM_VecReductionI<string mnem>
// LLVM vector reduction over a single vector, with an initial value,
// and with permission to reassociate the reduction operations.
class LLVM_VecReductionAccBase<string mnem, Type element>
: LLVM_OneResultIntrOp<"vector.reduce." # mnem, [], [0],
[Pure, SameOperandsAndResultElementType]>,
Arguments<(ins element:$start_value, LLVM_VectorOf<element>:$input,
DefaultValuedAttr<BoolAttr, "false">:$reassoc)> {
let llvmBuilder = [{
llvm::Module *module = builder.GetInsertBlock()->getModule();
llvm::Function *fn = llvm::Intrinsic::getDeclaration(
module,
llvm::Intrinsic::vector_reduce_}] # mnem # [{,
{ }] # !interleave(ListIntSubst<LLVM_IntrPatterns.operand, [1]>.lst,
", ") # [{
});
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
llvm::FastMathFlags origFM = builder.getFastMathFlags();
llvm::FastMathFlags tempFM = origFM;
tempFM.setAllowReassoc($reassoc);
builder.setFastMathFlags(tempFM); // set fastmath flag
$res = builder.CreateCall(fn, operands);
builder.setFastMathFlags(origFM); // restore fastmath flag
}];
let mlirBuilder = [{
bool allowReassoc = inst->getFastMathFlags().allowReassoc();
$res = $_builder.create<$_qualCppClassName>($_location,
$_resultType, $start_value, $input, allowReassoc);
}];
}
: LLVM_OneResultIntrOp</*mnem=*/"vector.reduce." # mnem,
/*overloadedResults=*/[],
/*overloadedOperands=*/[1],
/*traits=*/[Pure, SameOperandsAndResultElementType],
/*requiresFastmath=*/1>,
Arguments<(ins element:$start_value,
LLVM_VectorOf<element>:$input,
DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags)>;
class LLVM_VecReductionAccF<string mnem>
: LLVM_VecReductionAccBase<mnem, AnyFloat>;

View File

@@ -1,10 +1,18 @@
add_mlir_dialect(VectorOps vector)
add_mlir_doc(VectorOps VectorOps Dialects/ -gen-op-doc)
add_mlir_dialect(Vector vector)
add_mlir_doc(Vector Vector Dialects/ -gen-op-doc -dialect=vector)
# Add Vector operations
set(LLVM_TARGET_DEFINITIONS VectorOps.td)
mlir_tablegen(VectorOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(VectorOpsEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(VectorOpsAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(VectorOpsAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRVectorOpsEnumsIncGen)
add_dependencies(mlir-headers MLIRVectorOpsEnumsIncGen)
mlir_tablegen(VectorOps.h.inc -gen-op-decls)
mlir_tablegen(VectorOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRVectorOpsIncGen)
add_dependencies(mlir-generic-headers MLIRVectorOpsIncGen)
# Add Vector attributes
set(LLVM_TARGET_DEFINITIONS VectorAttributes.td)
mlir_tablegen(VectorEnums.h.inc -gen-enum-decls)
mlir_tablegen(VectorEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(VectorAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(VectorAttributes.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRVectorAttributesIncGen)
add_dependencies(mlir-generic-headers MLIRVectorAttributesIncGen)

View File

@@ -0,0 +1,31 @@
//===- Vector.td - Vector Dialect --------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the Vector dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR
#define MLIR_DIALECT_VECTOR_IR_VECTOR
include "mlir/IR/OpBase.td"
// Definition of the Vector dialect: registered under the `vector` name, with
// its generated C++ classes placed in the `::mlir::vector` namespace.
def Vector_Dialect : Dialect {
let name = "vector";
let cppNamespace = "::mlir::vector";
// Attributes of this dialect rely on the default attribute printer/parser.
let useDefaultAttributePrinterParser = 1;
// The dialect declares a constant materialization hook.
let hasConstantMaterializer = 1;
// Dialects this dialect depends on.
let dependentDialects = ["arith::ArithDialect"];
}
// Base class for Vector dialect ops.
// Binds an op to `Vector_Dialect` and forwards its mnemonic and (optional,
// empty by default) trait list to the generic `Op` class.
class Vector_Op<string mnemonic, list<Trait> traits = []> :
Op<Vector_Dialect, mnemonic, traits>;
#endif // MLIR_DIALECT_VECTOR_IR_VECTOR

View File

@@ -0,0 +1,85 @@
//===- VectorAttributes.td - Vector Dialect ----------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the attributes used in the Vector dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES
#define MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES
include "mlir/Dialect/Vector/IR/Vector.td"
include "mlir/IR/EnumAttr.td"
// The "kind" of combining function for contractions and reductions.
// Each case occupies its own bit (positions 0 through 12) of a 32-bit bit
// enum; the third template argument is the case's textual spelling.
def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">;
def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">;
def COMBINING_KIND_MINUI : I32BitEnumAttrCaseBit<"MINUI", 2, "minui">;
def COMBINING_KIND_MINSI : I32BitEnumAttrCaseBit<"MINSI", 3, "minsi">;
def COMBINING_KIND_MINF : I32BitEnumAttrCaseBit<"MINF", 4, "minf">;
def COMBINING_KIND_MAXUI : I32BitEnumAttrCaseBit<"MAXUI", 5, "maxui">;
def COMBINING_KIND_MAXSI : I32BitEnumAttrCaseBit<"MAXSI", 6, "maxsi">;
def COMBINING_KIND_MAXF : I32BitEnumAttrCaseBit<"MAXF", 7, "maxf">;
def COMBINING_KIND_AND : I32BitEnumAttrCaseBit<"AND", 8, "and">;
def COMBINING_KIND_OR : I32BitEnumAttrCaseBit<"OR", 9, "or">;
def COMBINING_KIND_XOR : I32BitEnumAttrCaseBit<"XOR", 10, "xor">;
def COMBINING_KIND_MINIMUMF : I32BitEnumAttrCaseBit<"MINIMUMF", 11, "minimumf">;
def COMBINING_KIND_MAXIMUMF : I32BitEnumAttrCaseBit<"MAXIMUMF", 12, "maximumf">;
// Bit enum grouping all of the combining-kind cases above, emitted in the
// `::mlir::vector` C++ namespace.
def CombiningKind : I32BitEnumAttr<
"CombiningKind",
"Kind of combining function for contractions and reductions",
[COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MINUI,
COMBINING_KIND_MINSI, COMBINING_KIND_MINF, COMBINING_KIND_MAXUI,
COMBINING_KIND_MAXSI, COMBINING_KIND_MAXF, COMBINING_KIND_AND,
COMBINING_KIND_OR, COMBINING_KIND_XOR,
COMBINING_KIND_MAXIMUMF, COMBINING_KIND_MINIMUMF]> {
let cppNamespace = "::mlir::vector";
// No specialized attribute is generated from this enum; the attribute is
// declared explicitly as `Vector_CombiningKindAttr`.
let genSpecializedAttr = 0;
}
/// An attribute that specifies the combining function for `vector.contract`,
/// and `vector.reduction`.
/// It wraps the `CombiningKind` enum under the `kind` mnemonic; the assembly
/// format encloses the enum value in angle brackets (e.g. `<add>`).
def Vector_CombiningKindAttr : EnumAttr<Vector_Dialect, CombiningKind, "kind"> {
let assemblyFormat = "`<` $value `>`";
}
// 32-bit enum describing an iterator type: `parallel` (0) or `reduction` (1).
// Emitted in the `::mlir::vector` C++ namespace.
def Vector_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
I32EnumAttrCase<"parallel", 0>,
I32EnumAttrCase<"reduction", 1>
]> {
// No specialized attribute is generated from this enum; the attribute is
// declared explicitly as `Vector_IteratorTypeEnum`.
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::vector";
}
// Attribute wrapping the `Vector_IteratorType` enum under the `iterator_type`
// mnemonic; the assembly format encloses the enum value in angle brackets.
def Vector_IteratorTypeEnum
: EnumAttr<Vector_Dialect, Vector_IteratorType, "iterator_type"> {
let assemblyFormat = "`<` $value `>`";
}
// Array attribute whose elements are `Vector_IteratorTypeEnum` attributes.
// The string below is the constraint's description.
def Vector_IteratorTypeArrayAttr
: TypedArrayAttrBase<Vector_IteratorTypeEnum,
"Iterator type should be an enum.">;
// 32-bit enum listing the punctuation kinds used for separating vectors or
// vector elements. Each case gives its C++ symbol, integer value, and textual
// spelling. Emitted in the `::mlir::vector` C++ namespace.
def PrintPunctuation : I32EnumAttr<"PrintPunctuation",
"Punctuation for separating vectors or vector elements", [
I32EnumAttrCase<"NoPunctuation", 0, "no_punctuation">,
I32EnumAttrCase<"NewLine", 1, "newline">,
I32EnumAttrCase<"Comma", 2, "comma">,
I32EnumAttrCase<"Open", 3, "open">,
I32EnumAttrCase<"Close", 4, "close">
]> {
let cppNamespace = "::mlir::vector";
// No specialized attribute is generated from this enum; the attribute is
// declared explicitly as `Vector_PrintPunctuation`.
let genSpecializedAttr = 0;
}
// Attribute wrapping the `PrintPunctuation` enum under the `punctuation`
// mnemonic; the assembly format encloses the enum value in angle brackets.
def Vector_PrintPunctuation : EnumAttr<Vector_Dialect, PrintPunctuation, "punctuation"> {
let assemblyFormat = "`<` $value `>`";
}
#endif // MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES

View File

@@ -14,6 +14,7 @@
#define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h"
#include "mlir/IR/AffineMap.h"
@@ -31,10 +32,10 @@
#include "llvm/ADT/StringExtras.h"
// Pull in all enum type definitions and utility function declarations.
#include "mlir/Dialect/Vector/IR/VectorOpsEnums.h.inc"
#include "mlir/Dialect/Vector/IR/VectorEnums.h.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.h.inc"
#include "mlir/Dialect/Vector/IR/VectorAttributes.h.inc"
namespace mlir {
class MLIRContext;
@@ -157,7 +158,7 @@ Value selectPassthru(OpBuilder &builder, Value mask, Value newValue,
} // namespace mlir
#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorDialect.h.inc"
#include "mlir/Dialect/Vector/IR/VectorOps.h.inc"
#include "mlir/Dialect/Vector/IR/VectorOpsDialect.h.inc"
#endif // MLIR_DIALECT_VECTOR_IR_VECTOROPS_H

View File

@@ -10,9 +10,13 @@
//
//===----------------------------------------------------------------------===//
#ifndef VECTOR_OPS
#define VECTOR_OPS
#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
#define MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
include "mlir/Dialect/Vector/IR/Vector.td"
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"
include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td"
include "mlir/IR/EnumAttr.td"
@@ -23,69 +27,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
def Vector_Dialect : Dialect {
let name = "vector";
let cppNamespace = "::mlir::vector";
let useDefaultAttributePrinterParser = 1;
let hasConstantMaterializer = 1;
let dependentDialects = ["arith::ArithDialect"];
}
// Base class for Vector dialect ops.
class Vector_Op<string mnemonic, list<Trait> traits = []> :
Op<Vector_Dialect, mnemonic, traits>;
// The "kind" of combining function for contractions and reductions.
def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">;
def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">;
def COMBINING_KIND_MINUI : I32BitEnumAttrCaseBit<"MINUI", 2, "minui">;
def COMBINING_KIND_MINSI : I32BitEnumAttrCaseBit<"MINSI", 3, "minsi">;
def COMBINING_KIND_MINF : I32BitEnumAttrCaseBit<"MINF", 4, "minf">;
def COMBINING_KIND_MAXUI : I32BitEnumAttrCaseBit<"MAXUI", 5, "maxui">;
def COMBINING_KIND_MAXSI : I32BitEnumAttrCaseBit<"MAXSI", 6, "maxsi">;
def COMBINING_KIND_MAXF : I32BitEnumAttrCaseBit<"MAXF", 7, "maxf">;
def COMBINING_KIND_AND : I32BitEnumAttrCaseBit<"AND", 8, "and">;
def COMBINING_KIND_OR : I32BitEnumAttrCaseBit<"OR", 9, "or">;
def COMBINING_KIND_XOR : I32BitEnumAttrCaseBit<"XOR", 10, "xor">;
def COMBINING_KIND_MINIMUMF : I32BitEnumAttrCaseBit<"MINIMUMF", 11, "minimumf">;
def COMBINING_KIND_MAXIMUMF : I32BitEnumAttrCaseBit<"MAXIMUMF", 12, "maximumf">;
def CombiningKind : I32BitEnumAttr<
"CombiningKind",
"Kind of combining function for contractions and reductions",
[COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MINUI,
COMBINING_KIND_MINSI, COMBINING_KIND_MINF, COMBINING_KIND_MAXUI,
COMBINING_KIND_MAXSI, COMBINING_KIND_MAXF, COMBINING_KIND_AND,
COMBINING_KIND_OR, COMBINING_KIND_XOR,
COMBINING_KIND_MAXIMUMF, COMBINING_KIND_MINIMUMF]> {
let cppNamespace = "::mlir::vector";
let genSpecializedAttr = 0;
}
/// An attribute that specifies the combining function for `vector.contract`,
/// and `vector.reduction`.
def Vector_CombiningKindAttr : EnumAttr<Vector_Dialect, CombiningKind, "kind"> {
let assemblyFormat = "`<` $value `>`";
}
def Vector_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
I32EnumAttrCase<"parallel", 0>,
I32EnumAttrCase<"reduction", 1>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::vector";
}
def Vector_IteratorTypeEnum
: EnumAttr<Vector_Dialect, Vector_IteratorType, "iterator_type"> {
let assemblyFormat = "`<` $value `>`";
}
def Vector_IteratorTypeArrayAttr
: TypedArrayAttrBase<Vector_IteratorTypeEnum,
"Iterator type should be an enum.">;
// TODO: Add an attribute to specify a different algebra with operators other
// than the current set: {*, +}.
def Vector_ContractionOp :
@@ -274,12 +215,16 @@ def Vector_ReductionOp :
Vector_Op<"reduction", [Pure,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface,
["getShapeForUnroll"]>]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
Arguments<(ins Vector_CombiningKindAttr:$kind,
AnyVectorOfAnyRank:$vector,
Optional<AnyType>:$acc)>,
Optional<AnyType>:$acc,
DefaultValuedAttr<
Arith_FastMathAttr,
"::mlir::arith::FastMathFlags::none">:$fastmath)>,
Results<(outs AnyType:$dest)> {
let summary = "reduction operation";
let description = [{
@@ -309,9 +254,13 @@ def Vector_ReductionOp :
}];
let builders = [
// Builder that infers the type of `dest`.
OpBuilder<(ins "CombiningKind":$kind, "Value":$vector, "Value":$acc)>,
OpBuilder<(ins "CombiningKind":$kind, "Value":$vector, "Value":$acc,
CArg<"::mlir::arith::FastMathFlags",
"::mlir::arith::FastMathFlags::none">:$fastMathFlags)>,
// Builder that infers the type of `dest` and has no accumulator.
OpBuilder<(ins "CombiningKind":$kind, "Value":$vector)>
OpBuilder<(ins "CombiningKind":$kind, "Value":$vector,
CArg<"::mlir::arith::FastMathFlags",
"::mlir::arith::FastMathFlags::none">:$fastMathFlags)>
];
// TODO: Migrate to assemblyFormat once `AllTypesMatch` supports optional
@@ -2469,22 +2418,6 @@ def Vector_TransposeOp :
let hasVerifier = 1;
}
def PrintPunctuation : I32EnumAttr<"PrintPunctuation",
"Punctuation for separating vectors or vector elements", [
I32EnumAttrCase<"NoPunctuation", 0, "no_punctuation">,
I32EnumAttrCase<"NewLine", 1, "newline">,
I32EnumAttrCase<"Comma", 2, "comma">,
I32EnumAttrCase<"Open", 3, "open">,
I32EnumAttrCase<"Close", 4, "close">
]> {
let cppNamespace = "::mlir::vector";
let genSpecializedAttr = 0;
}
def Vector_PrintPunctuation : EnumAttr<Vector_Dialect, PrintPunctuation, "punctuation"> {
let assemblyFormat = "`<` $value `>`";
}
def Vector_PrintOp :
Vector_Op<"print", []>,
Arguments<(ins Optional<Type<Or<[
@@ -2939,4 +2872,4 @@ def Vector_WarpExecuteOnLane0Op : Vector_Op<"warp_execute_on_lane_0",
}];
}
#endif // VECTOR_OPS
#endif // MLIR_DIALECT_VECTOR_IR_VECTOR_OPS

View File

@@ -8,6 +8,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -592,11 +593,11 @@ struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
} // namespace
template <class LLVMRedIntrinOp>
static Value
createFPReductionComparisonOpLowering(ConversionPatternRewriter &rewriter,
Location loc, Type llvmType,
Value vectorOperand, Value accumulator) {
Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
static Value createFPReductionComparisonOpLowering(
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
Value result =
rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
if (accumulator) {
result =
@@ -641,25 +642,39 @@ static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
/// `fmaximum`/`fminimum`.
/// More information: https://github.com/llvm/llvm-project/issues/64940
template <class LLVMRedIntrinOp, class MaskNeutral>
static Value lowerMaskedReductionWithRegular(
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
Value vectorOperand, Value accumulator, Value mask) {
static Value
lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
Location loc, Type llvmType,
Value vectorOperand, Value accumulator,
Value mask, LLVM::FastmathFlagsAttr fmf) {
const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
rewriter, loc, llvmType, vectorOperand.getType());
const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
loc, mask, vectorOperand, vectorMaskNeutral);
return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
rewriter, loc, llvmType, selectedVectorByMask, accumulator);
rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
}
/// Overloaded methods to lower a reduction to an llvm intrinsic that requires
/// a start value. This start value format spans across fp reductions without
/// mask and all the masked reduction intrinsics.
template <class LLVMRedIntrinOp, class ReductionNeutral>
static Value
lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
Type llvmType, Value vectorOperand,
Value accumulator, LLVM::FastmathFlagsAttr fmf) {
accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
llvmType, accumulator);
return rewriter.create<LLVMRedIntrinOp>(loc, llvmType,
/*startValue=*/accumulator,
vectorOperand, fmf);
}
/// Overloaded methods to lower a *predicated* reduction to an llvm intrinsic
/// that requires a start value. This start value format spans across fp
/// reductions without mask and all the masked reduction intrinsics.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
Location loc, Type llvmType,
Value vectorOperand,
Value accumulator) {
static Value
lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
Location loc, Type llvmType,
Value vectorOperand, Value accumulator) {
accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
llvmType, accumulator);
return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
@@ -668,22 +683,9 @@ static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
}
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value
lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
Type llvmType, Value vectorOperand,
Value accumulator, bool reassociateFPReds) {
accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
llvmType, accumulator);
return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
/*startValue=*/accumulator,
vectorOperand, reassociateFPReds);
}
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
Location loc, Type llvmType,
Value vectorOperand,
Value accumulator, Value mask) {
static Value lowerPredicatedReductionWithStartValue(
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
Value vectorOperand, Value accumulator, Value mask) {
accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
llvmType, accumulator);
Value vectorLength =
@@ -693,35 +695,19 @@ static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
vectorOperand, mask, vectorLength);
}
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
Location loc, Type llvmType,
Value vectorOperand,
Value accumulator, Value mask,
bool reassociateFPReds) {
accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
llvmType, accumulator);
Value vectorLength =
createVectorLengthValue(rewriter, loc, vectorOperand.getType());
return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
/*startValue=*/accumulator,
vectorOperand, mask, vectorLength,
reassociateFPReds);
}
template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
Location loc, Type llvmType,
Value vectorOperand,
Value accumulator, Value mask) {
static Value lowerPredicatedReductionWithStartValue(
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
Value vectorOperand, Value accumulator, Value mask) {
if (llvmType.isIntOrIndex())
return lowerReductionWithStartValue<LLVMIntVPRedIntrinOp,
IntReductionNeutral>(
return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
IntReductionNeutral>(
rewriter, loc, llvmType, vectorOperand, accumulator, mask);
// FP dispatch.
return lowerReductionWithStartValue<LLVMFPVPRedIntrinOp, FPReductionNeutral>(
return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
FPReductionNeutral>(
rewriter, loc, llvmType, vectorOperand, accumulator, mask);
}
@@ -809,30 +795,39 @@ public:
if (!isa<FloatType>(eltType))
return failure();
arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
reductionOp.getContext(),
convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
fmf = LLVM::FastmathFlagsAttr::get(
reductionOp.getContext(),
fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
: LLVM::FastmathFlags::none));
// Floating-point reductions: add/mul/min/max
Value result;
if (kind == vector::CombiningKind::ADD) {
result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
ReductionNeutralZero>(
rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
rewriter, loc, llvmType, operand, acc, fmf);
} else if (kind == vector::CombiningKind::MUL) {
result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
ReductionNeutralFPOne>(
rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
rewriter, loc, llvmType, operand, acc, fmf);
} else if (kind == vector::CombiningKind::MINIMUMF) {
result =
createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
rewriter, loc, llvmType, operand, acc);
rewriter, loc, llvmType, operand, acc, fmf);
} else if (kind == vector::CombiningKind::MAXIMUMF) {
result =
createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
rewriter, loc, llvmType, operand, acc);
rewriter, loc, llvmType, operand, acc, fmf);
} else if (kind == vector::CombiningKind::MINF) {
result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
rewriter, loc, llvmType, operand, acc);
rewriter, loc, llvmType, operand, acc, fmf);
} else if (kind == vector::CombiningKind::MAXF) {
result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
rewriter, loc, llvmType, operand, acc);
rewriter, loc, llvmType, operand, acc, fmf);
} else
return failure();
@@ -893,74 +888,79 @@ public:
Value acc = reductionOp.getAcc();
Location loc = reductionOp.getLoc();
arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
reductionOp.getContext(),
convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
Value result;
switch (kind) {
case vector::CombiningKind::ADD:
result = lowerReductionWithStartValue<
result = lowerPredicatedReductionWithStartValue<
LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
maskOp.getMask());
break;
case vector::CombiningKind::MUL:
result = lowerReductionWithStartValue<
result = lowerPredicatedReductionWithStartValue<
LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
maskOp.getMask());
break;
case vector::CombiningKind::MINUI:
result = lowerReductionWithStartValue<LLVM::VPReduceUMinOp,
ReductionNeutralUIntMax>(
result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
ReductionNeutralUIntMax>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
case vector::CombiningKind::MINSI:
result = lowerReductionWithStartValue<LLVM::VPReduceSMinOp,
ReductionNeutralSIntMax>(
result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
ReductionNeutralSIntMax>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
case vector::CombiningKind::MAXUI:
result = lowerReductionWithStartValue<LLVM::VPReduceUMaxOp,
ReductionNeutralUIntMin>(
result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
ReductionNeutralUIntMin>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
case vector::CombiningKind::MAXSI:
result = lowerReductionWithStartValue<LLVM::VPReduceSMaxOp,
ReductionNeutralSIntMin>(
result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
ReductionNeutralSIntMin>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
case vector::CombiningKind::AND:
result = lowerReductionWithStartValue<LLVM::VPReduceAndOp,
ReductionNeutralAllOnes>(
result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
ReductionNeutralAllOnes>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
case vector::CombiningKind::OR:
result = lowerReductionWithStartValue<LLVM::VPReduceOrOp,
ReductionNeutralZero>(
result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
ReductionNeutralZero>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
case vector::CombiningKind::XOR:
result = lowerReductionWithStartValue<LLVM::VPReduceXorOp,
ReductionNeutralZero>(
result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
ReductionNeutralZero>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
case vector::CombiningKind::MINF:
result = lowerReductionWithStartValue<LLVM::VPReduceFMinOp,
ReductionNeutralFPMax>(
result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
ReductionNeutralFPMax>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
case vector::CombiningKind::MAXF:
result = lowerReductionWithStartValue<LLVM::VPReduceFMaxOp,
ReductionNeutralFPMin>(
result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
ReductionNeutralFPMin>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
case CombiningKind::MAXIMUMF:
result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
MaskNeutralFMaximum>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
break;
case CombiningKind::MINIMUMF:
result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
MaskNeutralFMinimum>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
break;
}

View File

@@ -8,7 +8,7 @@ add_mlir_dialect_library(MLIRVectorDialect
MLIRMaskableOpInterfaceIncGen
MLIRMaskingOpInterfaceIncGen
MLIRVectorOpsIncGen
MLIRVectorOpsEnumsIncGen
MLIRVectorAttributesIncGen
LINK_LIBS PUBLIC
MLIRArithDialect

View File

@@ -42,9 +42,9 @@
#include <cstdint>
#include <numeric>
#include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc"
#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
// Pull in all enum type and utility function definitions.
#include "mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc"
#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
using namespace mlir;
using namespace mlir::vector;
@@ -256,7 +256,7 @@ struct BitmaskEnumStorage : public AttributeStorage {
void VectorDialect::initialize() {
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc"
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
>();
addOperations<
@@ -415,15 +415,17 @@ void MultiDimReductionOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===//
void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
CombiningKind kind, Value vector) {
build(builder, result, kind, vector, /*acc=*/Value());
CombiningKind kind, Value vector,
arith::FastMathFlags fastMathFlags) {
build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
}
void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
CombiningKind kind, Value vector, Value acc) {
CombiningKind kind, Value vector, Value acc,
arith::FastMathFlags fastMathFlags) {
build(builder, result,
llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
acc);
acc, fastMathFlags);
}
LogicalResult ReductionOp::verify() {
@@ -447,9 +449,13 @@ ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) {
Type redType;
Type resType;
CombiningKindAttr kindAttr;
arith::FastMathFlagsAttr fastMathAttr;
if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, "kind",
result.attributes) ||
parser.parseComma() || parser.parseOperandList(operandsInfo) ||
(succeeded(parser.parseOptionalKeyword("fastmath")) &&
parser.parseCustomAttributeWithFallback(fastMathAttr, Type{}, "fastmath",
result.attributes)) ||
parser.parseColonType(redType) ||
parser.parseKeywordType("into", resType) ||
(!operandsInfo.empty() &&
@@ -470,6 +476,12 @@ void ReductionOp::print(OpAsmPrinter &p) {
p << ", " << getVector();
if (getAcc())
p << ", " << getAcc();
if (getFastmathAttr() &&
getFastmathAttr().getValue() != arith::FastMathFlags::none) {
p << ' ' << getFastmathAttrName().getValue();
p.printStrippedAttrOrType(getFastmathAttr());
}
p << " : " << getVector().getType() << " into " << getDest().getType();
}
@@ -6049,7 +6061,7 @@ Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
//===----------------------------------------------------------------------===//
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc"
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"

View File

@@ -0,0 +1,14 @@
//===-- Vector.td - Entry point for Vector bindings --------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef PYTHON_BINDINGS_VECTOR
#define PYTHON_BINDINGS_VECTOR
include "mlir/Dialect/Vector/IR/Vector.td"
#endif // PYTHON_BINDINGS_VECTOR

View File

@@ -5,14 +5,14 @@
// CHECK-SAME: %[[A:.*]]: vector<16xf32>)
// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]])
// CHECK-SAME: <{reassoc = false}> : (f32, vector<16xf32>) -> f32
// CHECK-SAME: <{fastmathFlags = #llvm.fastmath<none>}> : (f32, vector<16xf32>) -> f32
// CHECK: return %[[V]] : f32
//
// REASSOC-LABEL: @reduce_add_f32(
// REASSOC-SAME: %[[A:.*]]: vector<16xf32>)
// REASSOC: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
// REASSOC: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]])
// REASSOC-SAME: <{reassoc = true}> : (f32, vector<16xf32>) -> f32
// REASSOC-SAME: <{fastmathFlags = #llvm.fastmath<reassoc>}> : (f32, vector<16xf32>) -> f32
// REASSOC: return %[[V]] : f32
//
func.func @reduce_add_f32(%arg0: vector<16xf32>) -> f32 {
@@ -22,22 +22,45 @@ func.func @reduce_add_f32(%arg0: vector<16xf32>) -> f32 {
// -----
// CHECK-LABEL: @reduce_add_f32_always_reassoc(
// CHECK-SAME: %[[A:.*]]: vector<16xf32>)
// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]])
/// Note: the reassoc flag remains even though the pass sets reassociate-fp-reductions to false.
/// Ponder whether this flag really is a property of the pass / pattern.
// CHECK-SAME: <{fastmathFlags = #llvm.fastmath<reassoc>}> : (f32, vector<16xf32>) -> f32
// CHECK: return %[[V]] : f32
//
// REASSOC-LABEL: @reduce_add_f32_always_reassoc(
// REASSOC-SAME: %[[A:.*]]: vector<16xf32>)
// REASSOC: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
// REASSOC: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]])
// REASSOC-SAME: <{fastmathFlags = #llvm.fastmath<reassoc>}> : (f32, vector<16xf32>) -> f32
// REASSOC: return %[[V]] : f32
//
// Sum-reduces the 16 f32 lanes of %arg0 into a single f32. The reduction op
// itself carries fastmath<reassoc>, so the expectations above require the
// reassoc flag on the lowered fadd reduction intrinsic under both check
// prefixes.
func.func @reduce_add_f32_always_reassoc(%arg0: vector<16xf32>) -> f32 {
%0 = vector.reduction <add>, %arg0 fastmath<reassoc> : vector<16xf32> into f32
return %0 : f32
}
// -----
// CHECK-LABEL: @reduce_mul_f32(
// CHECK-SAME: %[[A:.*]]: vector<16xf32>)
// CHECK: %[[C:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fmul"(%[[C]], %[[A]])
// CHECK-SAME: <{reassoc = false}> : (f32, vector<16xf32>) -> f32
// CHECK-SAME: <{fastmathFlags = #llvm.fastmath<nnan, ninf>}> : (f32, vector<16xf32>) -> f32
// CHECK: return %[[V]] : f32
//
// REASSOC-LABEL: @reduce_mul_f32(
// REASSOC-SAME: %[[A:.*]]: vector<16xf32>)
// REASSOC: %[[C:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
// REASSOC: %[[V:.*]] = "llvm.intr.vector.reduce.fmul"(%[[C]], %[[A]])
// REASSOC-SAME: <{reassoc = true}> : (f32, vector<16xf32>) -> f32
// REASSOC-SAME: <{fastmathFlags = #llvm.fastmath<nnan, ninf, reassoc>}> : (f32, vector<16xf32>) -> f32
// REASSOC: return %[[V]] : f32
//
func.func @reduce_mul_f32(%arg0: vector<16xf32>) -> f32 {
%0 = vector.reduction <mul>, %arg0 : vector<16xf32> into f32
%0 = vector.reduction <mul>, %arg0 fastmath<nnan, ninf> : vector<16xf32> into f32
return %0 : f32
}

View File

@@ -1216,7 +1216,7 @@ func.func @reduce_0d_f32(%arg0: vector<f32>) -> f32 {
// CHECK: %[[CA:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[CA]])
// CHECK-SAME: <{reassoc = false}> : (f32, vector<1xf32>) -> f32
// CHECK-SAME: <{fastmathFlags = #llvm.fastmath<none>}> : (f32, vector<1xf32>) -> f32
// CHECK: return %[[V]] : f32
// -----
@@ -1229,7 +1229,7 @@ func.func @reduce_f16(%arg0: vector<16xf16>) -> f16 {
// CHECK-SAME: %[[A:.*]]: vector<16xf16>)
// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f16) : f16
// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]])
// CHECK-SAME: <{reassoc = false}> : (f16, vector<16xf16>) -> f16
// CHECK-SAME: <{fastmathFlags = #llvm.fastmath<none>}> : (f16, vector<16xf16>) -> f16
// CHECK: return %[[V]] : f16
// -----
@@ -1242,7 +1242,7 @@ func.func @reduce_f32(%arg0: vector<16xf32>) -> f32 {
// CHECK-SAME: %[[A:.*]]: vector<16xf32>)
// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]])
// CHECK-SAME: <{reassoc = false}> : (f32, vector<16xf32>) -> f32
// CHECK-SAME: <{fastmathFlags = #llvm.fastmath<none>}> : (f32, vector<16xf32>) -> f32
// CHECK: return %[[V]] : f32
// -----
@@ -1255,7 +1255,7 @@ func.func @reduce_f64(%arg0: vector<16xf64>) -> f64 {
// CHECK-SAME: %[[A:.*]]: vector<16xf64>)
// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f64) : f64
// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]])
// CHECK-SAME: <{reassoc = false}> : (f64, vector<16xf64>) -> f64
// CHECK-SAME: <{fastmathFlags = #llvm.fastmath<none>}> : (f64, vector<16xf64>) -> f64
// CHECK: return %[[V]] : f64
// -----

View File

@@ -1011,3 +1011,10 @@ func.func @contraction_masked_scalable(%A: vector<3x4xf32>,
: vector<3x[8]x4xi1> -> vector<3x[8]xf32>
return %0 : vector<3x[8]xf32>
}
// CHECK-LABEL: func.func @fastmath(
// Exercises vector.reduction with a fastmath attribute that lists several
// flags at once (reassoc, nnan, ninf) on a <minf> reduction of a 42-lane f32
// vector.
func.func @fastmath(%x: vector<42xf32>) -> f32 {
// The expectation below requires the op to be printed with the same three
// flags, in the same order, as written on the op underneath it.
// CHECK: vector.reduction <minf>, %{{.*}} fastmath<reassoc,nnan,ninf>
%min = vector.reduction <minf>, %x fastmath<reassoc,nnan,ninf> : vector<42xf32> into f32
return %min: f32
}

View File

@@ -354,13 +354,13 @@ define void @vector_reductions(float %0, <8 x float> %1, <8 x i32> %2) {
%12 = call i32 @llvm.vector.reduce.umax.v8i32(<8 x i32> %2)
; CHECK: "llvm.intr.vector.reduce.umin"(%{{.*}}) : (vector<8xi32>) -> i32
%13 = call i32 @llvm.vector.reduce.umin.v8i32(<8 x i32> %2)
; CHECK: "llvm.intr.vector.reduce.fadd"(%{{.*}}, %{{.*}}) <{reassoc = false}> : (f32, vector<8xf32>) -> f32
; CHECK: "llvm.intr.vector.reduce.fadd"(%{{.*}}, %{{.*}}) <{fastmathFlags = #llvm.fastmath<none>}> : (f32, vector<8xf32>) -> f32
%14 = call float @llvm.vector.reduce.fadd.v8f32(float %0, <8 x float> %1)
; CHECK: "llvm.intr.vector.reduce.fmul"(%{{.*}}, %{{.*}}) <{reassoc = false}> : (f32, vector<8xf32>) -> f32
; CHECK: "llvm.intr.vector.reduce.fmul"(%{{.*}}, %{{.*}}) <{fastmathFlags = #llvm.fastmath<none>}> : (f32, vector<8xf32>) -> f32
%15 = call float @llvm.vector.reduce.fmul.v8f32(float %0, <8 x float> %1)
; CHECK: "llvm.intr.vector.reduce.fadd"(%{{.*}}, %{{.*}}) <{reassoc = true}> : (f32, vector<8xf32>) -> f32
; CHECK: "llvm.intr.vector.reduce.fadd"(%{{.*}}, %{{.*}}) <{fastmathFlags = #llvm.fastmath<reassoc>}> : (f32, vector<8xf32>) -> f32
%16 = call reassoc float @llvm.vector.reduce.fadd.v8f32(float %0, <8 x float> %1)
; CHECK: "llvm.intr.vector.reduce.fmul"(%{{.*}}, %{{.*}}) <{reassoc = true}> : (f32, vector<8xf32>) -> f32
; CHECK: "llvm.intr.vector.reduce.fmul"(%{{.*}}, %{{.*}}) <{fastmathFlags = #llvm.fastmath<reassoc>}> : (f32, vector<8xf32>) -> f32
%17 = call reassoc float @llvm.vector.reduce.fmul.v8f32(float %0, <8 x float> %1)
; CHECK: "llvm.intr.vector.reduce.xor"(%{{.*}}) : (vector<8xi32>) -> i32
%18 = call i32 @llvm.vector.reduce.xor.v8i32(<8 x i32> %2)

View File

@@ -375,9 +375,9 @@ llvm.func @vector_reductions(%arg0: f32, %arg1: vector<8xf32>, %arg2: vector<8xi
// CHECK: call float @llvm.vector.reduce.fmul.v8f32
"llvm.intr.vector.reduce.fmul"(%arg0, %arg1) : (f32, vector<8xf32>) -> f32
// CHECK: call reassoc float @llvm.vector.reduce.fadd.v8f32
"llvm.intr.vector.reduce.fadd"(%arg0, %arg1) {reassoc = true} : (f32, vector<8xf32>) -> f32
"llvm.intr.vector.reduce.fadd"(%arg0, %arg1) <{fastmathFlags = #llvm.fastmath<reassoc>}> : (f32, vector<8xf32>) -> f32
// CHECK: call reassoc float @llvm.vector.reduce.fmul.v8f32
"llvm.intr.vector.reduce.fmul"(%arg0, %arg1) {reassoc = true} : (f32, vector<8xf32>) -> f32
"llvm.intr.vector.reduce.fmul"(%arg0, %arg1) <{fastmathFlags = #llvm.fastmath<reassoc>}> : (f32, vector<8xf32>) -> f32
// CHECK: call i32 @llvm.vector.reduce.xor.v8i32
"llvm.intr.vector.reduce.xor"(%arg2) : (vector<8xi32>) -> i32
llvm.return