[mlir][math] Uplift from arith to math.fma

Add pass to uplift from arith mulf + addf ops to math.fma if fastmath flags allow it.

Differential Revision: https://reviews.llvm.org/D152633
This commit is contained in:
Ivan Butygin
2023-06-10 22:59:24 +02:00
parent e1164c7a92
commit ee8b8d6b58
8 changed files with 162 additions and 0 deletions

View File

@@ -1 +1,2 @@
# Descend into the IR subdirectory and the Transforms subdirectory so that the
# CMake scripts located there are processed as well.
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,5 @@
# Run tablegen over Passes.td to produce Passes.h.inc with the pass
# declarations. NOTE(review): `-name Math` presumably determines the name of
# the generated registration hook (registerMathPasses) — confirm.
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Math)
# Public target that other libraries list under DEPENDS so the generated
# include exists before they are compiled.
add_public_tablegen_target(MLIRMathTransformsIncGen)
# Generate the pass documentation (MathPasses) from the same Passes.td.
add_mlir_doc(Passes MathPasses ./ -gen-pass-doc)

View File

@@ -9,7 +9,17 @@
#ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace math {
// Passes.h.inc is generated by tablegen from Passes.td. Each include below
// expands only the section selected by the macro defined right before it, so
// the define/include pairs must stay in this order.

// Declarations for the passes described in Passes.td.
#define GEN_PASS_DECL
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
// Declarations specific to the MathUpliftToFMA pass.
// NOTE(review): GEN_PASS_DECL above likely already covers every pass, which
// would make this pass-specific include redundant — confirm against the
// generated file before removing it.
#define GEN_PASS_DECL_MATHUPLIFTTOFMA
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
// Registration entry points for the passes described in Passes.td.
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
} // namespace math
class RewritePatternSet;
@@ -34,6 +44,8 @@ void populateMathPolynomialApproximationPatterns(
RewritePatternSet &patterns,
const MathPolynomialApproximationOptions &options = {});
/// Adds to `patterns` the rewrite pattern that uplifts an `arith.mulf` feeding
/// an `arith.addf` into a single `math.fma`, provided both ops carry the
/// `contract` fastmath flag.
void populateUpliftToFMAPatterns(RewritePatternSet &patterns);
} // namespace mlir
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_

View File

@@ -0,0 +1,22 @@
//===-- Passes.td - Math pass definition file --------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES
#define MLIR_DIALECT_MATH_TRANSFORMS_PASSES
include "mlir/Pass/PassBase.td"
// Pass that fuses an arith.mulf feeding an arith.addf into one math.fma op.
// The math dialect is listed as a dependent dialect because the pass creates
// math.fma ops.
def MathUpliftToFMA : Pass<"math-uplift-to-fma"> {
  let summary = "Uplift arith ops to math.fma.";
  let description = [{
    Uplift a sequence of addf and mulf ops to math.fma if the fastmath flags
    allow it, i.e. both the addf and the mulf carry the `contract` flag.
  }];
  let dependentDialects = ["math::MathDialect"];
}
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES

View File

@@ -25,6 +25,7 @@
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/NVGPU/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
@@ -70,6 +71,7 @@ inline void registerAllPasses() {
registerNVGPUPasses();
registerSparseTensorPasses();
LLVM::registerLLVMPasses();
math::registerMathPasses();
memref::registerMemRefPasses();
registerSCFPasses();
registerShapePasses();

View File

@@ -2,10 +2,14 @@ add_mlir_dialect_library(MLIRMathTransforms
AlgebraicSimplification.cpp
ExpandPatterns.cpp
PolynomialApproximation.cpp
UpliftToFMA.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Math/Transforms
DEPENDS
MLIRMathTransformsIncGen
LINK_LIBS PUBLIC
MLIRArithDialect
MLIRDialectUtils

View File

@@ -0,0 +1,79 @@
//===- UpliftToFMA.cpp - Arith to FMA uplifting ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements uplifting from arith ops to math.fma.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir::math {
// Pull in the tablegen-generated definition section for this pass; it provides
// the impl::MathUpliftToFMABase class that the pass below derives from.
#define GEN_PASS_DEF_MATHUPLIFTTOFMA
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
} // namespace mlir::math
using namespace mlir;
template <typename Op>
static bool isValidForFMA(Op op) {
return static_cast<bool>(op.getFastmath() & arith::FastMathFlags::contract);
}
namespace {
/// Rewrites `arith.addf` whose operand is produced by `arith.mulf` into a
/// single `math.fma`:
///
///   addf(mulf(a, b), c)  or  addf(c, mulf(a, b))  ->  fma(a, b, c)
///
/// Both the addf and the chosen mulf must carry the `contract` fastmath flag.
/// The resulting fma gets the intersection of the two ops' fastmath flags.
struct UpliftFma final : OpRewritePattern<arith::AddFOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::AddFOp op,
                                PatternRewriter &rewriter) const override {
    if (!isValidForFMA(op))
      return rewriter.notifyMatchFailure(op, "addf op is not suitable for fma");

    auto lhsMul = op.getLhs().getDefiningOp<arith::MulFOp>();
    auto rhsMul = op.getRhs().getDefiningOp<arith::MulFOp>();
    if (!lhsMul && !rhsMul)
      return rewriter.notifyMatchFailure(op, "no mulf op");

    // Prefer the LHS producer, but fall back to the RHS one when the LHS is
    // either not a mulf or a mulf that may not be contracted. Bailing out as
    // soon as the LHS mulf is rejected would miss a perfectly valid RHS mulf.
    arith::MulFOp ab;
    Value c;
    if (lhsMul && isValidForFMA(lhsMul)) {
      ab = lhsMul;
      c = op.getRhs();
    } else if (rhsMul && isValidForFMA(rhsMul)) {
      ab = rhsMul;
      c = op.getLhs();
    } else {
      // At least one of the two is non-null here; report the one we have.
      arith::MulFOp rejected = lhsMul ? lhsMul : rhsMul;
      return rewriter.notifyMatchFailure(rejected,
                                         "mulf op is not suitable for fma");
    }

    Value a = ab.getLhs();
    Value b = ab.getRhs();
    // Only keep flags that both original ops agreed on.
    arith::FastMathFlags fmf = op.getFastmath() & ab.getFastmath();
    // The mulf itself is left in place; it is not erased by this pattern.
    rewriter.replaceOpWithNewOp<math::FmaOp>(op, a, b, c, fmf);
    return success();
  }
};
/// Pass driver: collects the fma uplifting patterns and applies them greedily
/// to the operation the pass runs on, signalling pass failure if the greedy
/// driver reports failure.
struct MathUpliftToFMA final
    : math::impl::MathUpliftToFMABase<MathUpliftToFMA> {
  using MathUpliftToFMABase::MathUpliftToFMABase;

  void runOnOperation() override {
    RewritePatternSet upliftPatterns(&getContext());
    populateUpliftToFMAPatterns(upliftPatterns);

    LogicalResult status =
        applyPatternsAndFoldGreedily(getOperation(), std::move(upliftPatterns));
    if (failed(status))
      signalPassFailure();
  }
};
} // namespace
/// Registers the mulf + addf -> math.fma uplifting pattern in `patterns`,
/// constructing it with the context owned by the pattern set.
void mlir::populateUpliftToFMAPatterns(RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<UpliftFma>(ctx);
}

View File

@@ -0,0 +1,37 @@
// RUN: mlir-opt %s --split-input-file --math-uplift-to-fma | FileCheck %s
// No uplifting without fastmath flags: neither op specifies any fastmath
// flags, so the mulf + addf pair must be left untouched.
// CHECK-LABEL: func @test
// CHECK-SAME: (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
// CHECK: %[[V1:.*]] = arith.mulf %[[ARG1]], %[[ARG2]]
// CHECK: %[[V2:.*]] = arith.addf %[[V1]], %[[ARG3]]
// CHECK: return %[[V2]]
func.func @test(%arg1: f32, %arg2: f32, %arg3: f32) -> f32 {
%1 = arith.mulf %arg1, %arg2 : f32
%2 = arith.addf %1, %arg3 : f32
return %2 : f32
}
// -----
// Both ops carry fastmath<fast> and the mulf feeds the LHS of the addf: the
// pair is uplifted to one math.fma that keeps fastmath<fast>.
// CHECK-LABEL: func @test
// CHECK-SAME: (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
// CHECK: %[[RES:.*]] = math.fma %[[ARG1]], %[[ARG2]], %[[ARG3]] fastmath<fast> : f32
// CHECK: return %[[RES]]
func.func @test(%arg1: f32, %arg2: f32, %arg3: f32) -> f32 {
%1 = arith.mulf %arg1, %arg2 fastmath<fast> : f32
%2 = arith.addf %1, %arg3 fastmath<fast> : f32
return %2 : f32
}
// -----
// The mulf feeds the RHS of the addf and the two ops carry different flags
// (fast vs. contract): the pair is still uplifted, and the math.fma keeps only
// the flags common to both ops, i.e. fastmath<contract>.
// CHECK-LABEL: func @test
// CHECK-SAME: (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
// CHECK: %[[RES:.*]] = math.fma %[[ARG1]], %[[ARG2]], %[[ARG3]] fastmath<contract> : f32
// CHECK: return %[[RES]]
func.func @test(%arg1: f32, %arg2: f32, %arg3: f32) -> f32 {
%1 = arith.mulf %arg1, %arg2 fastmath<fast> : f32
%2 = arith.addf %arg3, %1 fastmath<contract> : f32
return %2 : f32
}