//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
#include "mlir/Transforms/Bufferize.h"
|
|
|
|
|
#include "mlir/IR/Operation.h"
|
|
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//

|
|
|
/// Registers conversions into BufferizeTypeConverter
|
|
|
|
|
BufferizeTypeConverter::BufferizeTypeConverter() {
|
2020-10-12 14:03:09 -07:00
|
|
|
// Keep all types unchanged.
|
|
|
|
|
addConversion([](Type type) { return type; });
|
|
|
|
|
// Convert RankedTensorType to MemRefType.
|
2020-10-12 14:47:31 -07:00
|
|
|
addConversion([](RankedTensorType type) -> Type {
|
|
|
|
|
return MemRefType::get(type.getShape(), type.getElementType());
|
2020-10-12 14:03:09 -07:00
|
|
|
});
|
|
|
|
|
// Convert UnrankedTensorType to UnrankedMemRefType.
|
2020-10-12 14:47:31 -07:00
|
|
|
addConversion([](UnrankedTensorType type) -> Type {
|
|
|
|
|
return UnrankedMemRefType::get(type.getElementType(), 0);
|
2020-10-12 14:03:09 -07:00
|
|
|
});
|
2020-11-02 15:12:55 -08:00
|
|
|
addSourceMaterialization([](OpBuilder &builder, TensorType type,
|
2020-10-14 11:26:22 -07:00
|
|
|
ValueRange inputs, Location loc) -> Value {
|
|
|
|
|
assert(inputs.size() == 1);
|
|
|
|
|
assert(inputs[0].getType().isa<BaseMemRefType>());
|
|
|
|
|
return builder.create<TensorLoadOp>(loc, type, inputs[0]);
|
|
|
|
|
});
|
2020-11-02 15:12:55 -08:00
|
|
|
addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
|
2020-10-14 11:26:22 -07:00
|
|
|
ValueRange inputs, Location loc) -> Value {
|
|
|
|
|
assert(inputs.size() == 1);
|
|
|
|
|
assert(inputs[0].getType().isa<TensorType>());
|
|
|
|
|
return builder.create<TensorToMemrefOp>(loc, type, inputs[0]);
|
|
|
|
|
});
|
2020-10-12 14:03:09 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
|
/// Marks the ops created by BufferizeTypeConverter's materializations
/// (TensorLoadOp and TensorToMemrefOp) as legal on `target`, so that a
/// partial bufferization pass may leave them in the IR.
void mlir::populateBufferizeMaterializationLegality(ConversionTarget &target) {
  // Memref -> tensor bridge created by the source materialization.
  target.addLegalOp<TensorLoadOp>();
  // Tensor -> memref bridge created by the target materialization.
  target.addLegalOp<TensorToMemrefOp>();
}
|
|
|
|
namespace {
|
|
|
|
|
// In a finalizing bufferize conversion, we know that all tensors have been
|
|
|
|
|
// converted to memrefs, thus, this op becomes an identity.
|
|
|
|
|
class BufferizeTensorLoadOp : public OpConversionPattern<TensorLoadOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(TensorLoadOp op, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
TensorLoadOp::Adaptor adaptor(operands);
|
|
|
|
|
rewriter.replaceOp(op, adaptor.memref());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
// In a finalizing bufferize conversion, we know that all tensors have been
|
|
|
|
|
// converted to memrefs, thus, this op becomes an identity.
|
|
|
|
|
class BufferizeTensorToMemrefOp : public OpConversionPattern<TensorToMemrefOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(TensorToMemrefOp op, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
TensorToMemrefOp::Adaptor adaptor(operands);
|
|
|
|
|
rewriter.replaceOp(op, adaptor.tensor());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
/// Adds the patterns that erase bufferization materializations
/// (TensorLoadOp and TensorToMemrefOp) in a finalizing bufferize conversion.
/// Each pattern is constructed with `typeConverter` and `context`.
void mlir::populateEliminateBufferizeMaterializationsPatterns(
    MLIRContext *context, BufferizeTypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  // Memref -> tensor bridge becomes an identity.
  patterns.insert<BufferizeTensorLoadOp>(typeConverter, context);
  // Tensor -> memref bridge becomes an identity.
  patterns.insert<BufferizeTensorToMemrefOp>(typeConverter, context);
}
|