[mlir][Arith] Add pass for emulating unsupported float ops (#1079)

To complement the bf16 expansion and truncation patterns added to
ExpandOps, define a pass that replaces, for any arithmetic operation
op,
%y = arith.op %v0, %v1, ... : T
with
%e0 = arith.extf %v0 : T to U
%e1 = arith.extf %v1 : T to U
...
%y.exp = arith.op %e0, %e1, ... : U
%y = arith.truncf %y.exp : U to T

This allows for "emulating" floating-point operations not supported on
a given target (such as bfloat operations or most arithmetic on 8-bit
floats) by extending those types to supported ones, performing the
arithmetic operation, and then truncating back to the original
type (which ensures appropriate rounding behavior).

The lowering of the extf and truncf ops introduced by this
transformation should be handled by subsequent passes.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D154539
This commit is contained in:
Krzysztof Drewniak
2023-06-15 13:42:54 -05:00
parent 980cd18354
commit 10b56e0210
5 changed files with 311 additions and 0 deletions

View File

@@ -13,6 +13,8 @@
namespace mlir {
class DataFlowSolver;
class ConversionTarget;
class TypeConverter;
namespace arith {
@@ -42,6 +44,21 @@ void populateArithWideIntEmulationPatterns(
void populateArithNarrowTypeEmulationPatterns(
NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns);
/// Populate the type conversions needed to emulate the unsupported
/// `sourceTypes` with `targetType`.
///
/// The registered conversion maps each type in `sourceTypes`, and each shaped
/// type whose element type is in `sourceTypes`, to `targetType` (cloning the
/// shaped type with the new element type). Every other type is returned
/// unchanged. A target materialization that builds `arith.extf` is registered
/// as well.
void populateEmulateUnsupportedFloatsConversions(TypeConverter &converter,
ArrayRef<Type> sourceTypes,
Type targetType);
/// Add rewrite patterns for converting operations that use illegal float types
/// to ones that use legal ones.
///
/// The pattern applies to any region-free operation that `converter` reports
/// as illegal. It recreates the operation with converted result types and
/// truncates each changed result back to its original type with
/// `arith.truncf`.
void populateEmulateUnsupportedFloatsPatterns(RewritePatternSet &patterns,
TypeConverter &converter);
/// Set up a dialect conversion to reject arithmetic operations on unsupported
/// float types.
///
/// Arith dialect operations and a fixed list of arithmetic-performing vector
/// operations are legal only when `converter` considers them legal.
/// `arith.extf`, `arith.truncf`, `arith.constant` and `vector.splat` are always
/// legal, and operations not otherwise mentioned are treated as legal.
/// The legality callbacks hold a reference to `converter`, so it has to stay
/// alive for as long as `target` is used.
void populateEmulateUnsupportedFloatsLegality(ConversionTarget &target,
TypeConverter &converter);
/// Add patterns to expand Arith ceil/floor division ops.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);

View File

@@ -63,6 +63,28 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
}];
}
// Pass that emulates arithmetic on float types a target cannot operate on by
// extending the operands to a supported type and truncating the results back.
def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
  let summary = "Emulate operations on unsupported floats with extf/truncf";
  let description = [{
    Emulate arith and vector floating point operations that use float types
    which are unsupported on a target by inserting extf/truncf pairs around all
    such operations in order to produce arithmetic that can be performed while
    preserving the original rounding behavior.

    This pass does not attempt to reason about the operations being performed
    to determine when type conversions can be elided.
  }];

  let options = [
    ListOption<"sourceTypeStrs", "source-types", "std::string",
      "MLIR types without arithmetic support on a given target">,
    Option<"targetTypeStr", "target-type", "std::string", "\"f32\"",
      "MLIR type to convert the unsupported source types to">,
  ];

  // The legality setup refers to vector dialect operations by type.
  let dependentDialects = ["vector::VectorDialect"];
}
def ArithEmulateWideInt : Pass<"arith-emulate-wide-int"> {
let summary = "Emulate 2*N-bit integer operations using N-bit operations";
let description = [{

View File

@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRArithTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
EmulateUnsupportedFloats.cpp
EmulateWideInt.cpp
EmulateNarrowType.cpp
ExpandOps.cpp

View File

@@ -0,0 +1,184 @@
//===- EmulateUnsupportedFloats.cpp - Promote small floats --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This pass promotes small floats (of some unsupported types T) to a supported
// type U by wrapping all float operations on Ts with expansion to and
// truncation from U, then operating on U.
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include <optional>
namespace mlir::arith {
#define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace mlir::arith
using namespace mlir;
namespace {
/// Pass driver for unsupported-float emulation. It derives from the base class
/// generated by `GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS`, which provides
/// the `sourceTypeStrs` and `targetTypeStr` options read in `runOnOperation`.
struct EmulateUnsupportedFloatsPass
: arith::impl::ArithEmulateUnsupportedFloatsBase<
EmulateUnsupportedFloatsPass> {
// Re-use the constructors of the generated base class.
using arith::impl::ArithEmulateUnsupportedFloatsBase<
EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase;
// Parses the type options, builds the converter/patterns/target and runs a
// partial conversion over the operation the pass is scheduled on.
void runOnOperation() override;
};
/// Conversion pattern that is not bound to a single operation name: it is
/// registered with `MatchAnyOpTypeTag` and benefit 1, and decides per
/// operation in `match` whether the rewrite applies.
struct EmulateFloatPattern final : ConversionPattern {
EmulateFloatPattern(TypeConverter &converter, MLIRContext *ctx)
: ConversionPattern(converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
// Succeeds for region-free operations the type converter reports as illegal.
LogicalResult match(Operation *op) const override;
// Recreates `op` on the already-converted `operands` with converted result
// types and truncates the changed results back to their original types.
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
} // end namespace
/// Translate the textual spelling of a float type (for example "bf16") into
/// the matching `FloatType` of `ctx`. Spellings outside the list below yield
/// `std::nullopt`. The helper is file-local because nothing else needs it yet;
/// feel free to move it somewhere shared.
static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
                                               StringRef name) {
  Builder builder(ctx);
  if (name == "f8E5M2")
    return builder.getFloat8E5M2Type();
  if (name == "f8E4M3FN")
    return builder.getFloat8E4M3FNType();
  if (name == "f8E5M2FNUZ")
    return builder.getFloat8E5M2FNUZType();
  if (name == "f8E4M3FNUZ")
    return builder.getFloat8E4M3FNUZType();
  if (name == "bf16")
    return builder.getBF16Type();
  if (name == "f16")
    return builder.getF16Type();
  if (name == "f32")
    return builder.getF32Type();
  if (name == "f64")
    return builder.getF64Type();
  if (name == "f80")
    return builder.getF80Type();
  if (name == "f128")
    return builder.getF128Type();
  return std::nullopt;
}
/// Decide whether `op` should be emulated: it has to be illegal according to
/// the type converter and must not carry regions, because the rewrite does not
/// clone regions.
LogicalResult EmulateFloatPattern::match(Operation *op) const {
  const bool alreadyLegal = getTypeConverter()->isLegal(op);
  return success(!alreadyLegal && op->getNumRegions() == 0);
}
/// Replace `op` with a copy of itself that works on the converted types.
///
/// `operands` are the operands as remapped by the conversion framework. The
/// new operation keeps the name, attributes and successors of `op` but uses
/// the converted result types. Each result whose type changed is truncated
/// back to its original type with `arith.truncf`, so users of `op` keep seeing
/// the types they expect.
void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
  Location loc = op->getLoc();
  TypeConverter *converter = getTypeConverter();
  SmallVector<Type> resultTypes;
  // The conversion has to run outside of `assert`: with NDEBUG the assert
  // expression is not evaluated, which would leave `resultTypes` empty and
  // create the replacement operation without any results.
  LogicalResult typesConverted =
      converter->convertTypes(op->getResultTypes(), resultTypes);
  (void)typesConverted;
  assert(succeeded(typesConverted) &&
         "type conversions shouldn't fail in this pass");
  Operation *expandedOp =
      rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes,
                      op->getAttrs(), op->getSuccessors(), /*regions=*/{});
  SmallVector<Value> newResults(expandedOp->getResults());
  // `res` refers to the element of `newResults`, so assigning to it swaps the
  // widened result for its truncated counterpart in place.
  for (auto [res, oldType, newType] : llvm::zip_equal(
           MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
    if (oldType != newType)
      res = rewriter.create<arith::TruncFOp>(loc, oldType, res);
  }
  rewriter.replaceOp(op, newResults);
}
/// Register on `converter` the rules that widen the unsupported `sourceTypes`
/// to `targetType`, plus a target materialization that builds `arith.extf`.
void mlir::arith::populateEmulateUnsupportedFloatsConversions(
    TypeConverter &converter, ArrayRef<Type> sourceTypes, Type targetType) {
  // The callback is stored inside the converter, so it keeps its own copy of
  // the unsupported types instead of the caller's ArrayRef.
  converter.addConversion(
      [unsupported = SmallVector<Type>(sourceTypes),
       targetType](Type candidate) -> std::optional<Type> {
        // A bare unsupported float becomes the target type.
        if (llvm::is_contained(unsupported, candidate))
          return targetType;
        // A shaped type with an unsupported element type keeps its shape and
        // only swaps the element type.
        auto shapedCandidate = dyn_cast<ShapedType>(candidate);
        if (shapedCandidate &&
            llvm::is_contained(unsupported, shapedCandidate.getElementType()))
          return shapedCandidate.clone(targetType);
        // Everything else is already legal.
        return candidate;
      });
  converter.addTargetMaterialization([](OpBuilder &builder, Type widenedType,
                                        ValueRange narrowValues,
                                        Location loc) {
    return builder.create<arith::ExtFOp>(loc, widenedType, narrowValues);
  });
}
/// Add the catch-all float emulation pattern, driven by `converter`, to
/// `patterns`.
void mlir::arith::populateEmulateUnsupportedFloatsPatterns(
    RewritePatternSet &patterns, TypeConverter &converter) {
  MLIRContext *context = patterns.getContext();
  patterns.add<EmulateFloatPattern>(converter, context);
}
/// Configure `target` so that only arithmetic on unsupported float types is
/// illegal. The callbacks reference `converter`, which therefore has to
/// outlive `target`.
void mlir::arith::populateEmulateUnsupportedFloatsLegality(
    ConversionTarget &target, TypeConverter &converter) {
  // An operation is acceptable when all the types it touches are legal.
  auto usesOnlyLegalTypes = [&converter](Operation *candidate) {
    return converter.isLegal(candidate);
  };

  // Functions and other operations that don't need expansion are left alone.
  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
  target.addDynamicallyLegalDialect<arith::ArithDialect>(usesOnlyLegalTypes);
  // The vector operations that perform arithmetic are listed by hand.
  target.addDynamicallyLegalOp<
      vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
      vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
      usesOnlyLegalTypes);
  // Conversions, constants and splats may keep using the unsupported types.
  target.addLegalOp<arith::ExtFOp, arith::TruncFOp, arith::ConstantOp,
                    vector::SplatOp>();
}
void EmulateUnsupportedFloatsPass::runOnOperation() {
MLIRContext *ctx = &getContext();
Operation *op = getOperation();
SmallVector<Type> sourceTypes;
Type targetType;
std::optional<FloatType> maybeTargetType = parseFloatType(ctx, targetTypeStr);
if (!maybeTargetType) {
emitError(UnknownLoc::get(ctx), "could not map target type '" +
targetTypeStr +
"' to a known floating-point type");
return signalPassFailure();
}
targetType = *maybeTargetType;
for (StringRef sourceTypeStr : sourceTypeStrs) {
std::optional<FloatType> maybeSourceType =
parseFloatType(ctx, sourceTypeStr);
if (!maybeSourceType) {
emitError(UnknownLoc::get(ctx), "could not map source type '" +
sourceTypeStr +
"' to a known floating-point type");
return signalPassFailure();
}
sourceTypes.push_back(*maybeSourceType);
}
if (sourceTypes.empty())
(void)emitOptionalWarning(
std::nullopt,
"no source types specified, float emulation will do nothing");
if (llvm::is_contained(sourceTypes, targetType)) {
emitError(UnknownLoc::get(ctx),
"target type cannot be an unsupported source type");
return signalPassFailure();
}
TypeConverter converter;
arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes,
targetType);
RewritePatternSet patterns(ctx);
arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter);
ConversionTarget target(getContext());
arith::populateEmulateUnsupportedFloatsLegality(target, converter);
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}

View File

@@ -0,0 +1,87 @@
// RUN: mlir-opt --split-input-file --arith-emulate-unsupported-floats="source-types=bf16,f8E4M3FNUZ target-type=f32" %s | FileCheck %s
// A single bf16 addf: both operands (the argument and the bf16 constant, which
// itself stays bf16) are extended to f32, the add is done in f32, and the
// result is truncated back to bf16 before being returned.
func.func @basic_expansion(%x: bf16) -> bf16 {
// CHECK-LABEL: @basic_expansion
// CHECK-SAME: [[X:%.+]]: bf16
// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16
// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] : bf16 to f32
// CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32
// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] : f32 to bf16
// CHECK: return [[Y]]
%c = arith.constant 1.0 : bf16
%y = arith.addf %x, %c : bf16
func.return %y : bf16
}
// -----
// Chained bf16 operations: every emulated op truncates its result to bf16 and
// later users re-extend it, i.e. truncf/extf pairs between dependent ops are
// not elided. The cmpf result is i1 and therefore needs no truncation.
func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
// CHECK-LABEL: @chained
// CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16
// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] : bf16 to f32
// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] : bf16 to f32
// CHECK: [[P_EXP:%.+]] = arith.addf [[X_EXP]], [[Y_EXP]] : f32
// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] : f32 to bf16
// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] : bf16 to f32
// CHECK: [[Q_EXP:%.+]] = arith.mulf [[P_EXP2]], [[Z_EXP]]
// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] : f32 to bf16
// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] : bf16 to f32
// CHECK: [[RES:%.+]] = arith.cmpf ole, [[P_EXP2]], [[Q_EXP2]] : f32
// CHECK: return [[RES]]
%p = arith.addf %x, %y : bf16
%q = arith.mulf %p, %z : bf16
%res = arith.cmpf ole, %p, %q : bf16
func.return %res : i1
}
// -----
// Loads and stores of f8E4M3FNUZ values are expected to keep the narrow type:
// the first store uses the loaded value directly, and only the addf is carried
// out in f32 with its result truncated before the second store.
func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
// CHECK-LABEL: @memops
// CHECK: [[V:%.+]] = memref.load {{.*}} : memref<4xf8E4M3FNUZ>
// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] : f8E4M3FNUZ to f32
// CHECK: memref.store [[V]]
// CHECK: [[W:%.+]] = memref.load
// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] : f8E4M3FNUZ to f32
// CHECK: [[X_EXP:%.+]] = arith.addf [[V_EXP]], [[W_EXP]] : f32
// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] : f32 to f8E4M3FNUZ
// CHECK: memref.store [[X]]
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%v = memref.load %a[%c0] : memref<4xf8E4M3FNUZ>
memref.store %v, %b[%c0] : memref<4xf8E4M3FNUZ>
%w = memref.load %a[%c1] : memref<4xf8E4M3FNUZ>
%x = arith.addf %v, %w : f8E4M3FNUZ
memref.store %x, %b[%c1] : memref<4xf8E4M3FNUZ>
func.return
}
// -----
// Vector case: the mulf on vector<4xf8E4M3FNUZ> is emulated on vector<4xf32>,
// and the extf already present in the input stays in place after the
// truncation of the product.
func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> {
// CHECK-LABEL: @vectors
// CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ>
// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
// CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32>
// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] : vector<4xf32> to vector<4xf8E4M3FNUZ>
// CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
// CHECK: return [[RET]]
%b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ>
%ret = arith.extf %b : vector<4xf8E4M3FNUZ> to vector<4xf32>
func.return %ret : vector<4xf32>
}
// -----
// Negative case: f32 is not among the source types given on the RUN line, so
// the addf is expected to stay on f32 without any extf/truncf around it.
func.func @no_expansion(%x: f32) -> f32 {
// CHECK-LABEL: @no_expansion
// CHECK-SAME: [[X:%.+]]: f32
// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : f32
// CHECK: [[Y:%.+]] = arith.addf [[X]], [[C]] : f32
// CHECK: return [[Y]]
%c = arith.constant 1.0 : f32
%y = arith.addf %x, %c : f32
func.return %y : f32
}