Mirror of https://github.com/intel/llvm.git (synced 2026-01-24 08:30:34 +08:00).
This revision adds support for generating utilities for passes such as options/statistics/etc. that can be inferred from the tablegen definition. This removes additional boilerplate from the pass, and also makes it easier to remove the reliance on the pass registry to provide certain things(e.g. the pass argument). Differential Revision: https://reviews.llvm.org/D76659
395 lines · 14 KiB · C++
//===- LowerUniformRealMath.cpp ------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "UniformKernelUtils.h"
|
|
|
|
#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
|
|
#include "mlir/Dialect/FxpMathOps/Passes.h"
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/IR/Diagnostics.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::fxpmath;
|
|
using namespace mlir::fxpmath::detail;
|
|
using namespace mlir::quant;
|
|
|
|
namespace {

/// Function pass that rewrites fxpmath real-math ops (RealAddEwOp,
/// RealMulEwOp) on uniform-quantized operands into fixed-point integer
/// arithmetic (see the patterns registered in runOnFunction below).
struct LowerUniformRealMathPass
    : public FunctionPass<LowerUniformRealMathPass> {
  /// Include the generated pass utilities.
#define GEN_PASS_FxpMathLowerUniformRealMath
#include "mlir/Dialect/FxpMathOps/Passes.h.inc"

  void runOnFunction() override;
};

/// Function pass that rewrites uniform-quantized cast ops
/// (DequantizeCastOp) into explicit dequantization arithmetic.
struct LowerUniformCastsPass : public FunctionPass<LowerUniformCastsPass> {
  /// Include the generated pass utilities.
#define GEN_PASS_FxpMathLowerUniformCasts
#include "mlir/Dialect/FxpMathOps/Passes.h.inc"

  void runOnFunction() override;
};

} // end anonymous namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Dequantize
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Emits the op sequence that dequantizes `input`, whose element type is the
/// uniform per-layer quantized `elementType`, into its expressed (real) type:
///   storage cast -> widen to i32 -> subtract zero point (if any) ->
///   convert to float -> multiply by scale.
/// Returns the dequantized value, or nullptr if the element type is
/// unsupported (currently: unsigned storage), after emitting a warning.
static Value emitUniformPerLayerDequantize(Location loc, Value input,
                                           UniformQuantizedType elementType,
                                           PatternRewriter &rewriter) {
  // Pre-conditions.
  if (!elementType.isSigned()) {
    // TODO: Support unsigned storage type.
    // Note: the original message said "signed", but this branch is reached
    // precisely when the storage type is unsigned.
    emitWarning(loc, "unimplemented: dequantize unsigned uniform");
    return nullptr;
  }

  Type storageType = elementType.castToStorageType(input.getType());
  Type realType = elementType.castToExpressedType(input.getType());
  // Widen to 32-bit so the zero-point subtraction cannot overflow the
  // (narrower) storage width.
  Type intermediateType =
      castElementType(storageType, IntegerType::get(32, rewriter.getContext()));
  assert(storageType && "cannot cast to storage type");
  assert(realType && "cannot cast to expressed type");

  // Cast to storage type.
  input = rewriter.create<StorageCastOp>(loc, storageType, input);

  // Promote to intermediate type.
  input = rewriter.create<ConvertISOp>(loc, intermediateType, input);

  // Apply zero-point offset: quantized = real/scale + zp, so subtract zp
  // before scaling back.
  if (elementType.getZeroPoint() != 0) {
    Value negZeroPointConst = rewriter.create<ConstantOp>(
        loc, broadcastScalarConstIntValue(intermediateType,
                                          -elementType.getZeroPoint()));
    input = rewriter.create<AddIOp>(loc, input, negZeroPointConst);
  }

  // Convert to float.
  input = rewriter.create<ConvertISToFOp>(loc, realType, input);

  // Mul by scale.
  Value scaleConst = rewriter.create<ConstantOp>(
      loc, broadcastScalarConstFloatValue(realType,
                                          APFloat(elementType.getScale())));
  return rewriter.create<MulFOp>(loc, input, scaleConst);
}
|
|
|
|
/// Per-axis uniform dequantization is not implemented yet; emit a warning
/// diagnostic and signal failure to the caller by returning nullptr.
static Value
emitUniformPerAxisDequantize(Location loc, Value input,
                             UniformQuantizedPerAxisType elementType,
                             PatternRewriter &rewriter) {
  // TODO: Support per-axis dequantize.
  auto &diagEngine = rewriter.getContext()->getDiagEngine();
  diagEngine.emit(loc, DiagnosticSeverity::Warning)
      << "unimplemented: per-axis uniform dequantization";
  return nullptr;
}
|
|
|
|
/// Dispatches dequantization of `input` based on its quantized element type.
/// Supports uniform per-layer types (and stubs per-axis); returns nullptr for
/// anything else.
static Value emitDequantize(Location loc, Value input,
                            PatternRewriter &rewriter) {
  QuantizedType quantizedElementType =
      QuantizedType::getQuantizedElementType(input.getType());
  if (auto perLayerType =
          quantizedElementType.dyn_cast_or_null<UniformQuantizedType>())
    return emitUniformPerLayerDequantize(loc, input, perLayerType, rewriter);
  if (auto perAxisType = quantizedElementType
                             .dyn_cast_or_null<UniformQuantizedPerAxisType>())
    return emitUniformPerAxisDequantize(loc, input, perAxisType, rewriter);
  return nullptr;
}
|
|
|
|
namespace {
|
|
|
|
/// Rewrites a DequantizeCastOp whose result type is exactly the expressed
/// form of its quantized input into explicit dequantization arithmetic.
struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
  using OpRewritePattern<DequantizeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DequantizeCastOp op,
                                PatternRewriter &rewriter) const override {
    Type argType = op.arg().getType();
    QuantizedType argElementType =
        QuantizedType::getQuantizedElementType(argType);

    // Only a cast from the quantized type to its own expressed type is a
    // valid uniform dequantize cast.
    if (argElementType.castToExpressedType(argType) != op.getResult().getType())
      return failure();

    Value dequantized = emitDequantize(op.getLoc(), op.arg(), rewriter);
    if (!dequantized)
      return failure();

    rewriter.replaceOp(op, dequantized);
    return success();
  }
};
|
|
|
|
} // end anonymous namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Elementwise add
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Lowers a real elementwise add where lhs, rhs and result all carry the
/// same signed uniform-quantized type ("isomorphic" operands).
///
/// Pipeline: cast to storage ints -> widen to an overflow-safe width ->
/// integer add -> re-apply the result zero point -> clamp -> narrow back ->
/// storage cast to the quantized result type.
/// Returns failure() without modifying the IR if the type preconditions do
/// not hold.
static LogicalResult
tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info,
                                      PatternRewriter &rewriter) {
  // Precondition: signed, and both operands match the result type exactly.
  if (!info.resultType.isSigned() || info.lhsType != info.resultType ||
      info.rhsType != info.resultType) {
    return failure();
  }

  // Choose a byte aligned intermediate width big enough to perform the
  // calculation without overflow.
  // TODO: This should probably be made just big enough to avoid overflow and
  // leave the downstream tooling to decide how to align that to machine
  // word sizes.
  unsigned intermediateWidth =
      info.resultType.getStorageTypeIntegralWidth() <= 8 ? 16 : 32;
  IntegerType intermediateElementType =
      IntegerType::get(intermediateWidth, rewriter.getContext());
  Type intermediateType =
      castElementType(info.resultStorageType, intermediateElementType);

  // Cast operands to storage type.
  Value lhsValue = rewriter
                       .create<StorageCastOp>(info.op->getLoc(),
                                              info.lhsStorageType, info.lhs)
                       .getResult();
  Value rhsValue = rewriter
                       .create<StorageCastOp>(info.op->getLoc(),
                                              info.rhsStorageType, info.rhs)
                       .getResult();

  // Cast to the intermediate sized type.
  lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          lhsValue);
  rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          rhsValue);

  // Add.
  Value resultValue =
      rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, rhsValue);

  // Zero point offset adjustment: both operands carried the zero point, so
  // one copy of it must be removed from the sum.
  // result = (lhs - zp) + (rhs - zp) + zp
  // zpOffset = -zp
  int zpOffset = -1 * info.resultType.getZeroPoint();
  if (zpOffset != 0) {
    Value zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(),
        broadcastScalarConstIntValue(intermediateType, zpOffset));
    resultValue =
        rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
  }

  // Clamp to the configured min/max in the intermediate element type.
  auto clampMinMax = info.getClampMinMax(intermediateElementType);
  resultValue = rewriter.create<ClampISOp>(
      info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);

  // Convert back to original type.
  resultValue = rewriter.create<ConvertISOp>(
      info.op->getLoc(), info.resultStorageType, resultValue);

  // Cast back for new result.
  rewriter.replaceOpWithNewOp<StorageCastOp>(
      info.op, info.getQuantizedResultType(), resultValue);

  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Elementwise mul
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Lowers a real elementwise mul on signed uniform-quantized operands.
///
/// Pipeline: cast to storage ints -> widen to 32-bit -> subtract each
/// operand's zero point -> integer mul -> rescale via a fixed-point
/// multiplier (saturating rounding doubling high-mul plus rounding
/// right-shift) -> add the result zero point -> clamp -> narrow back ->
/// storage cast to the quantized result type.
/// Returns failure() (after a warning when the combined multiplier exceeds
/// 1.0) if the op cannot be lowered; the IR is then left unmodified.
static LogicalResult
tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
                            PatternRewriter &rewriter) {
  if (!info.resultType.isSigned()) {
    return failure();
  }

  // Combined real multiplier: (scale_lhs * scale_rhs) / scale_result.
  double outputMultiplierReal = info.lhsType.getScale() *
                                info.rhsType.getScale() /
                                info.resultType.getScale();
  // The fixed-point rescale below only represents multipliers <= 1.0.
  if (outputMultiplierReal > 1.0) {
    info.op->emitWarning(
        "unimplemented: cannot multiply with multiplier > 1.0");
    return failure();
  }

  // TODO: Choose an appropriate intermediate width for muls > 8 bits to
  // avoid overflow.
  unsigned intermediateWidth = 32;
  IntegerType intermediateElementType =
      IntegerType::get(intermediateWidth, rewriter.getContext());
  Type intermediateType =
      castElementType(info.resultStorageType, intermediateElementType);

  // Cast operands to storage type.
  Value lhsValue = rewriter
                       .create<StorageCastOp>(info.op->getLoc(),
                                              info.lhsStorageType, info.lhs)
                       .getResult();
  Value rhsValue = rewriter
                       .create<StorageCastOp>(info.op->getLoc(),
                                              info.rhsStorageType, info.rhs)
                       .getResult();

  // Cast to the intermediate sized type.
  lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          lhsValue);
  rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          rhsValue);

  // Apply argument zeroPoints (subtract before multiplying so the product is
  // of the real-valued offsets).
  if (info.lhsType.getZeroPoint() != 0) {
    Value zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(), broadcastScalarConstIntValue(
                               intermediateType, -info.lhsType.getZeroPoint()));
    lhsValue =
        rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, zpOffsetConst);
  }

  if (info.rhsType.getZeroPoint() != 0) {
    Value zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(), broadcastScalarConstIntValue(
                               intermediateType, -info.rhsType.getZeroPoint()));
    rhsValue =
        rewriter.create<AddIOp>(info.op->getLoc(), rhsValue, zpOffsetConst);
  }

  // Mul.
  Value resultValue =
      rewriter.create<MulIOp>(info.op->getLoc(), lhsValue, rhsValue);

  // Scale output: decompose the real multiplier into a normalized integer
  // multiplier and a right-shift exponent, then apply both.
  QuantizedMultiplierSmallerThanOneExp outputMultiplier(outputMultiplierReal);
  resultValue = rewriter.create<VecScalarSaturatingRoundingDoublingHighMulISOp>(
      info.op->getLoc(), resultValue,
      IntegerAttr::get(intermediateElementType, outputMultiplier.multiplier));
  resultValue = rewriter.create<RoundingDivideByPotISOp>(
      info.op->getLoc(), resultValue,
      IntegerAttr::get(intermediateElementType, -outputMultiplier.exponent));

  // Zero point offset adjustment (re-apply the result type's zero point).
  if (info.resultType.getZeroPoint() != 0) {
    Value zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(),
        broadcastScalarConstIntValue(intermediateType,
                                     info.resultType.getZeroPoint()));
    resultValue =
        rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
  }

  // Clamp to the configured min/max in the intermediate element type.
  auto clampMinMax = info.getClampMinMax(intermediateElementType);
  resultValue = rewriter.create<ClampISOp>(
      info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);

  // Convert back to original type.
  resultValue = rewriter.create<ConvertISOp>(
      info.op->getLoc(), info.resultStorageType, resultValue);

  // Cast back for new result.
  rewriter.replaceOpWithNewOp<StorageCastOp>(
      info.op, info.getQuantizedResultType(), resultValue);

  return success();
}
|
|
|
|
namespace {
|
|
|
|
/// Matches a RealAddEwOp over uniform-quantized operands and lowers it via
/// the signed isomorphic-add rewrite.
struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> {
  using OpRewritePattern<RealAddEwOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(RealAddEwOp op,
                                PatternRewriter &rewriter) const override {
    const UniformBinaryOpInfo opInfo(op, op.lhs(), op.rhs(), op.clamp_min(),
                                     op.clamp_max());
    if (!opInfo.isValid())
      return failure();

    // Try all of the permutations we support (currently just one).
    return tryRewriteAffineAddEwIsomorphicSigned(opInfo, rewriter);
  }
};
|
|
|
|
/// Matches a RealMulEwOp over uniform-quantized operands and lowers it via
/// the signed mul rewrite.
struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
  using OpRewritePattern<RealMulEwOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(RealMulEwOp op,
                                PatternRewriter &rewriter) const override {
    const UniformBinaryOpInfo opInfo(op, op.lhs(), op.rhs(), op.clamp_min(),
                                     op.clamp_max());
    if (!opInfo.isValid())
      return failure();

    // Try all of the permutations we support (currently just one).
    return tryRewriteAffineMulEwSigned(opInfo, rewriter);
  }
};
|
|
|
|
} // end anonymous namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// LowerUniformRealMath pass
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void LowerUniformRealMathPass::runOnFunction() {
|
|
auto fn = getFunction();
|
|
OwningRewritePatternList patterns;
|
|
auto *context = &getContext();
|
|
patterns.insert<UniformRealAddEwPattern, UniformRealMulEwPattern>(context);
|
|
applyPatternsGreedily(fn, patterns);
|
|
}
|
|
|
|
/// Factory for the uniform real-math lowering pass.
std::unique_ptr<OpPassBase<FuncOp>>
mlir::fxpmath::createLowerUniformRealMathPass() {
  auto pass = std::make_unique<LowerUniformRealMathPass>();
  return pass;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// LowerUniformCasts pass
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void LowerUniformCastsPass::runOnFunction() {
|
|
auto fn = getFunction();
|
|
OwningRewritePatternList patterns;
|
|
auto *context = &getContext();
|
|
patterns.insert<UniformDequantizePattern>(context);
|
|
applyPatternsGreedily(fn, patterns);
|
|
}
|
|
|
|
/// Factory for the uniform-casts lowering pass.
std::unique_ptr<OpPassBase<FuncOp>>
mlir::fxpmath::createLowerUniformCastsPass() {
  auto pass = std::make_unique<LowerUniformCastsPass>();
  return pass;
}
|