2019-04-03 16:07:37 -07:00
|
|
|
//===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===//
|
2019-04-03 11:16:32 -07:00
|
|
|
//
|
2020-01-26 03:58:30 +00:00
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
2019-12-23 09:35:36 -08:00
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
2019-04-03 11:16:32 -07:00
|
|
|
//
|
2019-12-23 09:35:36 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-04-03 11:16:32 -07:00
|
|
|
|
2020-03-17 14:56:52 -07:00
|
|
|
#include "mlir/Dialect/Quant/FakeQuantSupport.h"
|
|
|
|
|
#include "mlir/Dialect/Quant/Passes.h"
|
|
|
|
|
#include "mlir/Dialect/Quant/QuantOps.h"
|
|
|
|
|
#include "mlir/Dialect/Quant/UniformSupport.h"
|
2019-04-03 11:16:32 -07:00
|
|
|
#include "mlir/IR/Attributes.h"
|
|
|
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
|
#include "mlir/IR/StandardTypes.h"
|
|
|
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
using namespace mlir::quant;
|
|
|
|
|
|
|
|
|
|
namespace {
|
2020-04-01 01:50:29 -07:00
|
|
|
/// Function pass whose work is done in runOnFunction(), defined out of line
/// later in this file.
class ConvertSimulatedQuantPass
    : public FunctionPass<ConvertSimulatedQuantPass> {
public:
/// Include the generated pass utilities.
#define GEN_PASS_QuantConvertSimulatedQuant
#include "mlir/Dialect/Quant/Passes.h.inc"

  /// Pass entry point, invoked once per function.
  void runOnFunction() override;
};
|
|
|
|
|
|
2019-09-10 10:50:16 -07:00
|
|
|
/// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
|
2019-10-04 04:37:14 -07:00
|
|
|
template <typename ConcreteRewriteClass, typename FakeQuantOp>
|
2019-09-10 10:50:16 -07:00
|
|
|
class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> {
|
2019-04-03 11:16:32 -07:00
|
|
|
public:
|
2019-09-10 10:50:16 -07:00
|
|
|
using OpRewritePattern<FakeQuantOp>::OpRewritePattern;
|
2019-04-03 11:16:32 -07:00
|
|
|
|
2019-09-10 10:50:16 -07:00
|
|
|
FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
|
|
|
|
|
: OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {}
|
2019-04-03 11:16:32 -07:00
|
|
|
|
2020-03-17 20:07:55 -07:00
|
|
|
LogicalResult matchAndRewrite(FakeQuantOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2019-04-03 11:16:32 -07:00
|
|
|
// TODO: If this pattern comes up more frequently, consider adding core
|
|
|
|
|
// support for failable rewrites.
|
|
|
|
|
if (failableRewrite(op, rewriter)) {
|
|
|
|
|
*hadFailure = true;
|
2020-03-17 20:07:55 -07:00
|
|
|
return failure();
|
2019-04-03 11:16:32 -07:00
|
|
|
}
|
2019-04-03 16:07:37 -07:00
|
|
|
|
2020-03-17 20:07:55 -07:00
|
|
|
return success();
|
2019-04-03 11:16:32 -07:00
|
|
|
}
|
|
|
|
|
|
2019-09-10 10:50:16 -07:00
|
|
|
private:
|
|
|
|
|
bool *hadFailure;
|
2019-04-03 11:16:32 -07:00
|
|
|
|
2019-09-10 10:50:16 -07:00
|
|
|
bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const {
|
|
|
|
|
auto converter = ExpressedToQuantizedConverter::forInputType(op.getType());
|
2019-04-03 11:16:32 -07:00
|
|
|
if (!converter) {
|
2019-09-10 10:50:16 -07:00
|
|
|
return (op.emitError("unsupported quantized type conversion"), true);
|
2019-04-03 11:16:32 -07:00
|
|
|
}
|
|
|
|
|
|
2019-09-10 10:50:16 -07:00
|
|
|
QuantizedType elementType =
|
2019-10-04 04:37:14 -07:00
|
|
|
static_cast<const ConcreteRewriteClass *>(this)
|
2019-09-10 10:50:16 -07:00
|
|
|
->convertFakeQuantAttrsToType(op, converter.expressedType);
|
2019-04-03 11:16:32 -07:00
|
|
|
|
2019-09-10 10:50:16 -07:00
|
|
|
if (!elementType) {
|
2019-04-03 11:16:32 -07:00
|
|
|
// Note that the fakeQuantAttrsToType will have emitted the error.
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
2019-09-10 10:50:16 -07:00
|
|
|
Type quantizedType = converter.convert(elementType);
|
2019-04-03 11:16:32 -07:00
|
|
|
assert(quantizedType &&
|
|
|
|
|
"Converter accepted a type that it did not convert");
|
|
|
|
|
|
|
|
|
|
// TODO: Map to a qbarrier with an attribute like [Forced] to signal that
|
|
|
|
|
// this is a forced/hard-coded constraint.
|
2019-09-10 10:50:16 -07:00
|
|
|
auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType,
|
|
|
|
|
op.inputs());
|
2019-04-16 18:36:24 -07:00
|
|
|
rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
|
|
|
|
|
qbarrier.getResult());
|
2019-04-03 11:16:32 -07:00
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
2019-09-10 10:50:16 -07:00
|
|
|
class ConstFakeQuantRewrite
|
|
|
|
|
: public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> {
|
|
|
|
|
public:
|
|
|
|
|
using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>;
|
|
|
|
|
|
|
|
|
|
ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
|
|
|
|
|
: BaseRewrite(ctx, hadFailure) {}
|
|
|
|
|
|
|
|
|
|
QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp,
|
|
|
|
|
Type expressedType) const {
|
|
|
|
|
return fakeQuantAttrsToType(
|
|
|
|
|
fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
|
|
|
|
|
fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
|
|
|
|
|
fqOp.narrow_range(), expressedType, fqOp.is_signed());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ConstFakeQuantPerAxisRewrite
|
|
|
|
|
: public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite,
|
|
|
|
|
ConstFakeQuantPerAxis> {
|
|
|
|
|
public:
|
|
|
|
|
using BaseRewrite =
|
|
|
|
|
FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>;
|
|
|
|
|
|
|
|
|
|
ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure)
|
|
|
|
|
: BaseRewrite(ctx, hadFailure) {}
|
|
|
|
|
|
|
|
|
|
QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp,
|
|
|
|
|
Type expressedType) const {
|
|
|
|
|
SmallVector<double, 4> min, max;
|
|
|
|
|
min.reserve(fqOp.min().size());
|
|
|
|
|
max.reserve(fqOp.max().size());
|
|
|
|
|
for (auto m : fqOp.min())
|
|
|
|
|
min.push_back(m.cast<FloatAttr>().getValueAsDouble());
|
|
|
|
|
for (auto m : fqOp.max())
|
|
|
|
|
max.push_back(m.cast<FloatAttr>().getValueAsDouble());
|
|
|
|
|
|
|
|
|
|
return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
|
|
|
|
|
fqOp.axis().getSExtValue(), min, max,
|
|
|
|
|
fqOp.narrow_range(), expressedType,
|
|
|
|
|
fqOp.is_signed());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
2020-01-14 14:06:12 +01:00
|
|
|
} // namespace
|
|
|
|
|
|
2019-04-03 16:07:37 -07:00
|
|
|
void ConvertSimulatedQuantPass::runOnFunction() {
|
2019-04-03 11:16:32 -07:00
|
|
|
bool hadFailure = false;
|
|
|
|
|
OwningRewritePatternList patterns;
|
2019-07-01 10:29:09 -07:00
|
|
|
auto func = getFunction();
|
2019-09-10 10:50:16 -07:00
|
|
|
auto ctx = func.getContext();
|
|
|
|
|
patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
|
|
|
|
|
ctx, &hadFailure);
|
2019-08-09 17:20:02 -07:00
|
|
|
applyPatternsGreedily(func, patterns);
|
2019-04-03 11:16:32 -07:00
|
|
|
if (hadFailure)
|
|
|
|
|
signalPassFailure();
|
|
|
|
|
}
|
|
|
|
|
|
2019-09-13 13:33:46 -07:00
|
|
|
/// Creates a new ConvertSimulatedQuantPass instance, owned by the caller.
std::unique_ptr<OpPassBase<FuncOp>>
mlir::quant::createConvertSimulatedQuantPass() {
  std::unique_ptr<OpPassBase<FuncOp>> pass =
      std::make_unique<ConvertSimulatedQuantPass>();
  return pass;
}
|