2019-04-03 16:07:37 -07:00
|
|
|
//===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===//
|
2019-04-03 11:16:32 -07:00
|
|
|
//
|
|
|
|
|
// Copyright 2019 The MLIR Authors.
|
|
|
|
|
//
|
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
|
//
|
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
//
|
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
// =============================================================================
|
|
|
|
|
|
2019-05-14 11:03:55 -07:00
|
|
|
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"
|
|
|
|
|
#include "mlir/Dialect/QuantOps/Passes.h"
|
|
|
|
|
#include "mlir/Dialect/QuantOps/QuantOps.h"
|
|
|
|
|
#include "mlir/Dialect/QuantOps/UniformSupport.h"
|
2019-04-03 11:16:32 -07:00
|
|
|
#include "mlir/IR/Attributes.h"
|
|
|
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
|
#include "mlir/IR/StandardTypes.h"
|
|
|
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
using namespace mlir::quant;
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
|
2019-04-03 16:07:37 -07:00
|
|
|
class ConvertSimulatedQuantPass
|
|
|
|
|
: public FunctionPass<ConvertSimulatedQuantPass> {
|
2019-04-03 11:16:32 -07:00
|
|
|
public:
|
|
|
|
|
void runOnFunction() override;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // end anonymous namespace
|
|
|
|
|
|
2019-04-03 16:07:37 -07:00
|
|
|
/// Rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
|
|
|
|
|
class ConstFakeQuantRewrite : public RewritePattern {
|
2019-04-03 11:16:32 -07:00
|
|
|
public:
|
|
|
|
|
bool *hadFailure;
|
|
|
|
|
|
2019-04-03 16:07:37 -07:00
|
|
|
ConstFakeQuantRewrite(MLIRContext *context, bool *hadFailure)
|
|
|
|
|
: RewritePattern(ConstFakeQuant::getOperationName(), 1, context),
|
2019-04-03 11:16:32 -07:00
|
|
|
hadFailure(hadFailure) {}
|
|
|
|
|
|
2019-04-03 16:07:37 -07:00
|
|
|
PatternMatchResult matchAndRewrite(Operation *op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2019-04-03 11:16:32 -07:00
|
|
|
// TODO: If this pattern comes up more frequently, consider adding core
|
|
|
|
|
// support for failable rewrites.
|
|
|
|
|
if (failableRewrite(op, rewriter)) {
|
|
|
|
|
*hadFailure = true;
|
2019-04-22 16:06:09 -07:00
|
|
|
return matchFailure();
|
2019-04-03 11:16:32 -07:00
|
|
|
}
|
2019-04-03 16:07:37 -07:00
|
|
|
|
|
|
|
|
return matchSuccess();
|
2019-04-03 11:16:32 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool failableRewrite(Operation *op, PatternRewriter &rewriter) const {
|
2019-05-11 17:57:32 -07:00
|
|
|
auto fqOp = cast<ConstFakeQuant>(op);
|
2019-04-03 11:16:32 -07:00
|
|
|
|
|
|
|
|
auto converter =
|
|
|
|
|
ExpressedToUniformQuantizedConverter::forInputType(fqOp.getType());
|
|
|
|
|
if (!converter) {
|
|
|
|
|
return (op->emitError("unsupported quantized type conversion"), true);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
UniformQuantizedType uniformElementType = fakeQuantAttrsToType(
|
|
|
|
|
fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
|
2019-04-04 09:25:38 -07:00
|
|
|
fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
|
2019-07-18 11:25:53 -07:00
|
|
|
fqOp.narrow_range(), converter.expressedType, fqOp.is_signed());
|
2019-04-03 11:16:32 -07:00
|
|
|
|
|
|
|
|
if (!uniformElementType) {
|
|
|
|
|
// Note that the fakeQuantAttrsToType will have emitted the error.
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Type quantizedType = converter.convert(uniformElementType);
|
|
|
|
|
assert(quantizedType &&
|
|
|
|
|
"Converter accepted a type that it did not convert");
|
|
|
|
|
|
|
|
|
|
// TODO: Map to a qbarrier with an attribute like [Forced] to signal that
|
|
|
|
|
// this is a forced/hard-coded constraint.
|
2019-04-16 18:36:24 -07:00
|
|
|
auto qbarrier = rewriter.create<QuantizeCastOp>(op->getLoc(), quantizedType,
|
|
|
|
|
fqOp.inputs());
|
|
|
|
|
rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
|
|
|
|
|
qbarrier.getResult());
|
2019-04-03 11:16:32 -07:00
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
2019-04-03 16:07:37 -07:00
|
|
|
void ConvertSimulatedQuantPass::runOnFunction() {
|
2019-04-03 11:16:32 -07:00
|
|
|
bool hadFailure = false;
|
|
|
|
|
OwningRewritePatternList patterns;
|
2019-07-01 10:29:09 -07:00
|
|
|
auto func = getFunction();
|
2019-04-03 11:16:32 -07:00
|
|
|
auto *context = &getContext();
|
2019-08-05 18:37:56 -07:00
|
|
|
patterns.insert<ConstFakeQuantRewrite>(context, &hadFailure);
|
2019-08-09 17:20:02 -07:00
|
|
|
applyPatternsGreedily(func, patterns);
|
2019-04-03 11:16:32 -07:00
|
|
|
if (hadFailure)
|
|
|
|
|
signalPassFailure();
|
|
|
|
|
}
|
|
|
|
|
|
2019-04-30 15:32:55 -07:00
|
|
|
/// Returns a newly heap-allocated ConvertSimulatedQuantPass; the caller is
/// responsible for its lifetime.
FunctionPassBase *mlir::quant::createConvertSimulatedQuantPass() {
  auto *result = new ConvertSimulatedQuantPass();
  return result;
}
|
2019-04-03 11:16:32 -07:00
|
|
|
|
2019-04-03 16:07:37 -07:00
|
|
|
// Registers ConvertSimulatedQuantPass via PassRegistration under the name
// "quant-convert-simulated-quantization", with a short description.
static PassRegistration<ConvertSimulatedQuantPass> pass(
    "quant-convert-simulated-quantization",
    "Converts training-time simulated quantization ops "
    "to corresponding quantize/dequantize casts.");
|