mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 21:53:12 +08:00
[mlir][math] Uplift from arith to math.fma
Add pass to uplift from arith mulf + addf ops to math.fma if fastmath flags allow it. Differential Revision: https://reviews.llvm.org/D152633
This commit is contained in:
@@ -1 +1,2 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
||||
5
mlir/include/mlir/Dialect/Math/Transforms/CMakeLists.txt
Normal file
5
mlir/include/mlir/Dialect/Math/Transforms/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Math)
|
||||
add_public_tablegen_target(MLIRMathTransformsIncGen)
|
||||
|
||||
add_mlir_doc(Passes MathPasses ./ -gen-pass-doc)
|
||||
@@ -9,7 +9,17 @@
|
||||
#ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
|
||||
#define MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace math {
|
||||
#define GEN_PASS_DECL
|
||||
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
|
||||
#define GEN_PASS_DECL_MATHUPLIFTTOFMA
|
||||
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
|
||||
} // namespace math
|
||||
|
||||
class RewritePatternSet;
|
||||
|
||||
@@ -34,6 +44,8 @@ void populateMathPolynomialApproximationPatterns(
|
||||
RewritePatternSet &patterns,
|
||||
const MathPolynomialApproximationOptions &options = {});
|
||||
|
||||
void populateUpliftToFMAPatterns(RewritePatternSet &patterns);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
|
||||
|
||||
22
mlir/include/mlir/Dialect/Math/Transforms/Passes.td
Normal file
22
mlir/include/mlir/Dialect/Math/Transforms/Passes.td
Normal file
@@ -0,0 +1,22 @@
|
||||
//===-- Passes.td - Math pass definition file --------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES
|
||||
#define MLIR_DIALECT_MATH_TRANSFORMS_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def MathUpliftToFMA : Pass<"math-uplift-to-fma"> {
|
||||
let summary = "Uplift arith ops to math.fma.";
|
||||
let description = [{
|
||||
Uplift sequence of addf and mulf ops to math.fma if fastmath flags allows it.
|
||||
}];
|
||||
let dependentDialects = ["math::MathDialect"];
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
|
||||
@@ -25,6 +25,7 @@
|
||||
#include "mlir/Dialect/GPU/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Dialect/Math/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/NVGPU/Passes.h"
|
||||
#include "mlir/Dialect/SCF/Transforms/Passes.h"
|
||||
@@ -70,6 +71,7 @@ inline void registerAllPasses() {
|
||||
registerNVGPUPasses();
|
||||
registerSparseTensorPasses();
|
||||
LLVM::registerLLVMPasses();
|
||||
math::registerMathPasses();
|
||||
memref::registerMemRefPasses();
|
||||
registerSCFPasses();
|
||||
registerShapePasses();
|
||||
|
||||
@@ -2,10 +2,14 @@ add_mlir_dialect_library(MLIRMathTransforms
|
||||
AlgebraicSimplification.cpp
|
||||
ExpandPatterns.cpp
|
||||
PolynomialApproximation.cpp
|
||||
UpliftToFMA.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Math/Transforms
|
||||
|
||||
DEPENDS
|
||||
MLIRMathTransformsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRArithDialect
|
||||
MLIRDialectUtils
|
||||
|
||||
79
mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp
Normal file
79
mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp
Normal file
@@ -0,0 +1,79 @@
|
||||
//===- UpliftToFMA.cpp - Arith to FMA uplifting ---------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements uplifting from arith ops to math.fma.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/Math/Transforms/Passes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
namespace mlir::math {
|
||||
#define GEN_PASS_DEF_MATHUPLIFTTOFMA
|
||||
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
|
||||
} // namespace mlir::math
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
template <typename Op>
|
||||
static bool isValidForFMA(Op op) {
|
||||
return static_cast<bool>(op.getFastmath() & arith::FastMathFlags::contract);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
struct UpliftFma final : OpRewritePattern<arith::AddFOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(arith::AddFOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!isValidForFMA(op))
|
||||
return rewriter.notifyMatchFailure(op, "addf op is not suitable for fma");
|
||||
|
||||
Value c;
|
||||
arith::MulFOp ab;
|
||||
if ((ab = op.getLhs().getDefiningOp<arith::MulFOp>())) {
|
||||
c = op.getRhs();
|
||||
} else if ((ab = op.getRhs().getDefiningOp<arith::MulFOp>())) {
|
||||
c = op.getLhs();
|
||||
} else {
|
||||
return rewriter.notifyMatchFailure(op, "no mulf op");
|
||||
}
|
||||
|
||||
if (!isValidForFMA(ab))
|
||||
return rewriter.notifyMatchFailure(ab, "mulf op is not suitable for fma");
|
||||
|
||||
Value a = ab.getLhs();
|
||||
Value b = ab.getRhs();
|
||||
arith::FastMathFlags fmf = op.getFastmath() & ab.getFastmath();
|
||||
rewriter.replaceOpWithNewOp<math::FmaOp>(op, a, b, c, fmf);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct MathUpliftToFMA final
|
||||
: math::impl::MathUpliftToFMABase<MathUpliftToFMA> {
|
||||
using MathUpliftToFMABase::MathUpliftToFMABase;
|
||||
|
||||
void runOnOperation() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateUpliftToFMAPatterns(patterns);
|
||||
if (failed(
|
||||
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::populateUpliftToFMAPatterns(RewritePatternSet &patterns) {
|
||||
patterns.insert<UpliftFma>(patterns.getContext());
|
||||
}
|
||||
37
mlir/test/Dialect/Math/uplift-to-fma.mlir
Normal file
37
mlir/test/Dialect/Math/uplift-to-fma.mlir
Normal file
@@ -0,0 +1,37 @@
|
||||
// RUN: mlir-opt %s --split-input-file --math-uplift-to-fma | FileCheck %s
|
||||
|
||||
// No uplifting without fastmath flags.
|
||||
// CHECK-LABEL: func @test
|
||||
// CHECK-SAME: (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
|
||||
// CHECK: %[[V1:.*]] = arith.mulf %[[ARG1]], %[[ARG2]]
|
||||
// CHECK: %[[V2:.*]] = arith.addf %[[V1]], %[[ARG3]]
|
||||
// CHECK: return %[[V2]]
|
||||
func.func @test(%arg1: f32, %arg2: f32, %arg3: f32) -> f32 {
|
||||
%1 = arith.mulf %arg1, %arg2 : f32
|
||||
%2 = arith.addf %1, %arg3 : f32
|
||||
return %2 : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @test
|
||||
// CHECK-SAME: (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
|
||||
// CHECK: %[[RES:.*]] = math.fma %[[ARG1]], %[[ARG2]], %[[ARG3]] fastmath<fast> : f32
|
||||
// CHECK: return %[[RES]]
|
||||
func.func @test(%arg1: f32, %arg2: f32, %arg3: f32) -> f32 {
|
||||
%1 = arith.mulf %arg1, %arg2 fastmath<fast> : f32
|
||||
%2 = arith.addf %1, %arg3 fastmath<fast> : f32
|
||||
return %2 : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @test
|
||||
// CHECK-SAME: (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
|
||||
// CHECK: %[[RES:.*]] = math.fma %[[ARG1]], %[[ARG2]], %[[ARG3]] fastmath<contract> : f32
|
||||
// CHECK: return %[[RES]]
|
||||
func.func @test(%arg1: f32, %arg2: f32, %arg3: f32) -> f32 {
|
||||
%1 = arith.mulf %arg1, %arg2 fastmath<fast> : f32
|
||||
%2 = arith.addf %arg3, %1 fastmath<contract> : f32
|
||||
return %2 : f32
|
||||
}
|
||||
Reference in New Issue
Block a user