//===- FxpMathConfig.cpp - Reference fixed point config -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a TargetConfiguration for a reference fixed-point math
// quantization scheme based on the FxpMathOps dialect (plus a small category
// of extension ops that can be added from other dialects).
//
//===----------------------------------------------------------------------===//

#include "mlir/Quantizer/Configurations/FxpMathConfig.h"
|
|
|
|
#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
|
|
#include "mlir/Dialect/Quant/QuantOps.h"
|
|
#include "mlir/Dialect/Quant/QuantTypes.h"
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/StandardTypes.h"
|
|
#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
|
|
#include "mlir/Quantizer/Support/Metadata.h"
|
|
#include "mlir/Quantizer/Support/Statistics.h"
|
|
#include "mlir/Quantizer/Support/UniformConstraints.h"
|
|
|
|
using namespace mlir;
using namespace mlir::quantizer;
using namespace mlir::fxpmath;
using namespace mlir::quant;
using namespace std::placeholders;

namespace {

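// Target configuration for the reference fixed-point math scheme: registers
// the candidate quantized types and binds constraint-graph handlers for the
// ops this config understands.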
struct FxpMathTargetConfigImpl : public FxpMathTargetConfig {
  FxpMathTargetConfigImpl(SolverContext &context)
      : FxpMathTargetConfig(context) {
    Builder b(&context.getMlirContext());
    IntegerType i8Type = b.getIntegerType(8);
    IntegerType i16Type = b.getIntegerType(16);
    IntegerType i32Type = b.getIntegerType(32);

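    // Candidate quantized types: signed 8-bit and 16-bit uniform per-layer
    // types, plus a signed 32-bit type with an explicit fixed-point scale
    // (used below for biases).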
    q8 = addCandidateType(
        AnyQuantizedType::get(QuantizationFlags::Signed, i8Type, nullptr,
                              std::numeric_limits<int8_t>::min(),
                              std::numeric_limits<int8_t>::max()),
        CandidateQuantizedType::Scheme::UniformPerLayer);
    q16 = addCandidateType(
        AnyQuantizedType::get(QuantizationFlags::Signed, i16Type, nullptr,
                              std::numeric_limits<int16_t>::min(),
                              std::numeric_limits<int16_t>::max()),
        CandidateQuantizedType::Scheme::UniformPerLayer);
    q32ExplicitFixedPoint = addCandidateType(
        AnyQuantizedType::get(QuantizationFlags::Signed, i32Type, nullptr,
                              std::numeric_limits<int32_t>::min(),
                              std::numeric_limits<int32_t>::max()),
        CandidateQuantizedType::Scheme::UniformExplicitFixedPointScale);

    // Op handlers.
    addOpHandler<ConstantOp>(
        std::bind(&FxpMathTargetConfigImpl::handleConstant, this, _1, _2));
    addOpHandler<mlir::ReturnOp>(
        std::bind(&FxpMathTargetConfigImpl::handleTerminal, this, _1, _2));
    addOpHandler<quant::StatisticsOp>(
        std::bind(&FxpMathTargetConfigImpl::handleStats, this, _1, _2));

    // FxpMathOps.
    addOpHandler<RealAddEwOp>(
        std::bind(&FxpMathTargetConfigImpl::handleAdd, this, _1, _2));
    addOpHandler<RealMulEwOp>(
        std::bind(&FxpMathTargetConfigImpl::handleMul, this, _1, _2));
    addOpHandler<RealMatMulOp>(
        std::bind(&FxpMathTargetConfigImpl::handleMatMul, this, _1, _2));
    addOpHandler<RealMatMulBiasOp>(
        std::bind(&FxpMathTargetConfigImpl::handleMatMulBias, this, _1, _2));

    // Require stats ops.
    addRequireStatsOp<RealAddEwOp>();
    addRequireStatsOp<RealSubEwOp>();
    addRequireStatsOp<RealDivEwOp>();
    addRequireStatsOp<RealMulEwOp>();
    addRequireStatsOp<RealMatMulOp>();
    addRequireStatsOp<RealMatMulBiasOp>();
  }

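  // Only floating-point scalars and vectors/tensors with a floating-point
  // element type are candidates for quantization by this config.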
  bool isHandledType(Type t) const final {
    if (t.isa<FloatType>())
      return true;
    return (t.isa<VectorType>() || t.isa<TensorType>()) &&
           t.cast<ShapedType>().getElementType().isa<FloatType>();
  }

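  // Couples anchors that the CAG reports as implied connections so they
  // resolve to the same uniform quantization parameters.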
  void finalizeAnchors(CAGSlice &cag) const override {
    cag.enumerateImpliedConnections(
        [&](CAGAnchorNode *from, CAGAnchorNode *to) {
          UniformConstraintsBuilder(cag).coupleAnchors(from, to);
        });
  }

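  // Registers an op (by name) whose operands and result pass values through
  // unchanged; such ops are handled by handleValueIdentity below.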
  void addValueIdentityOpByName(StringRef opName) override {
    addOpHandlerByName(
        opName,
        std::bind(&FxpMathTargetConfigImpl::handleValueIdentity, this, _1, _2));
  }

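  // Propagates quantization parameters unchanged through single-result
  // identity-like ops by coupling each handled operand to the result.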
  void handleValueIdentity(Operation *op, CAGSlice &cag) const {
    assert(op->getNumResults() == 1);
    if (!isHandledType(op->getResult(0).getType()))
      return;

    auto resultNode = cag.getResultAnchor(op, 0);
    resultNode->setTypeTransformRule(
        CAGAnchorNode::TypeTransformRule::DirectStorage);

    for (unsigned opIdx = 0, e = op->getNumOperands(); opIdx < e; ++opIdx) {
      if (!isHandledType(op->getOperand(opIdx).getType()))
        continue;
      auto operandNode = cag.getOperandAnchor(op, opIdx);
      operandNode->setTypeTransformRule(
          CAGAnchorNode::TypeTransformRule::DirectStorage);
      UniformConstraintsBuilder(cag).coupleAnchors(operandNode, resultNode);
    }
  }

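  // Derives layer statistics directly from a constant's value attribute and
  // applies them to the result anchor.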
  void handleConstant(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0).getType()))
      return;

    auto resultNode = cag.getResultAnchor(op, 0);
    resultNode->setTypeTransformRule(
        CAGAnchorNode::TypeTransformRule::ExpressedOnly);
    Attribute valueAttr;
    if (!matchPattern(op, m_Constant(&valueAttr))) {
      return;
    }

    AttributeTensorStatistics stats(valueAttr);
    TensorAxisStatistics layerStats;
    if (!stats.get(layerStats)) {
      op->emitOpError("could not compute statistics");
      return;
    }

    UniformConstraintsBuilder(cag).applyStats(resultNode, layerStats);
  }

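  // Terminal ops (e.g. returns) keep their operand in the expressed type.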
  void handleTerminal(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getOperand(0).getType()))
      return;
    auto operandNode = cag.getOperandAnchor(op, 0);
    operandNode->setTypeTransformRule(
        CAGAnchorNode::TypeTransformRule::ExpressedOnly);
  }

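  // Reads min/max layer statistics from a quant::StatisticsOp, couples its
  // operand and result anchors, and applies the statistics to the result.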
  void handleStats(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0).getType()))
      return;

    auto argNode = cag.getOperandAnchor(op, 0);
    auto resultNode = cag.getResultAnchor(op, 0);
    UniformConstraintsBuilder(cag).coupleAnchors(argNode, resultNode);

    TensorAxisStatistics layerStats;
    auto statsOp = cast<quant::StatisticsOp>(op);
    auto layerStatsAttr = statsOp.layerStats();
    layerStats.minValue =
        layerStatsAttr.getValue<FloatAttr>(0).getValueAsDouble();
    layerStats.maxValue =
        layerStatsAttr.getValue<FloatAttr>(1).getValueAsDouble();
    UniformConstraintsBuilder(cag).applyStats(resultNode, layerStats);
  }

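  // Element-wise add: both operands and the result are constrained to the
  // same 8/16-bit uniform parameters (see the coupling note below).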
  void handleAdd(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0).getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto resultNode = cag.getResultAnchor(op, 0);
    // Add supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    // NOTE: We couple the add such that the scale/zeroPoint match between
    // both args and the result. This is overly constrained in that it is
    // possible to write efficient add kernels with a bit more freedom (i.e.
    // zeroPoints can vary, scales can differ by a power of two, etc).
    // However, full coupling yields the simplest solutions on the fast path.
    // Further efficiency can be had by constraining the zeroPoint to 0, but
    // there isn't a constraint for this yet (and there are tradeoffs).
    UniformConstraintsBuilder(cag).coupleAnchors(lhs, resultNode);
    UniformConstraintsBuilder(cag).coupleAnchors(rhs, resultNode);
    addRealMathOptionalConstraints(op, resultNode, cag);
  }

  void handleMul(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0).getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto resultNode = cag.getResultAnchor(op, 0);
    // Mul supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    addRealMathOptionalConstraints(op, resultNode, cag);
  }

  void handleMatMul(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0).getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto resultNode = cag.getResultAnchor(op, 0);
    // MatMul supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    addRealMathOptionalConstraints(op, resultNode, cag);
  }

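  // Like handleMatMul, but the bias operand is restricted to the 32-bit
  // explicit fixed-point candidate type and its explicit scale is tied to the
  // result via propagateExplicitScale.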
  void handleMatMulBias(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0).getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto bias = cag.getOperandAnchor(op, 2);
    bias->getUniformMetadata().disabledCandidateTypes =
        getCandidateTypeDisabledExceptMask({q32ExplicitFixedPoint});

    auto resultNode = cag.getResultAnchor(op, 0);
    UniformConstraintsBuilder(cag).propagateExplicitScale(resultNode, bias);

    // MatMul supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    addRealMathOptionalConstraints(op, resultNode, cag);
  }

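  // Applies optional clamp_min/clamp_max constraints carried as attributes on
  // the real-math ops.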
  void addRealMathOptionalConstraints(Operation *op, CAGAnchorNode *anchor,
                                      CAGSlice &cag) const {
    // TODO: It would be nice if these all extended some base trait instead
    // of requiring name lookup.
    auto clampMinAttr = op->getAttrOfType<FloatAttr>("clamp_min");
    auto clampMaxAttr = op->getAttrOfType<FloatAttr>("clamp_max");

    if (clampMinAttr || clampMaxAttr) {
      auto nan = APFloat::getQNaN(APFloat::IEEEdouble());
      auto clampMin = clampMinAttr ? clampMinAttr.getValue() : nan;
      auto clampMax = clampMaxAttr ? clampMaxAttr.getValue() : nan;
      UniformConstraintsBuilder(cag).clamp(anchor, clampMin, clampMax);
    }
  }

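  // Indices of the candidate quantized types registered in the constructor.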
  unsigned q8;
  unsigned q16;
  unsigned q32ExplicitFixedPoint;
};

} // anonymous namespace

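// Factory for the reference fixed-point math target configuration.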
std::unique_ptr<FxpMathTargetConfig>
FxpMathTargetConfig::create(SolverContext &context) {
  return std::make_unique<FxpMathTargetConfigImpl>(context);
}