mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 12:26:52 +08:00
Add support for inlining calls with different arg/result types from the callable.
Some dialects have implicit conversions inherent in their modeling, meaning that a call may have a different type that the type that the callable expects. To support this, a hook is added to the dialect interface that allows for materializing conversion operations during inlining when there is a mismatch. A hook is also added to the callable interface to allow for introspecting the expected result types. PiperOrigin-RevId: 272814379
This commit is contained in:
committed by
A. Unique TensorFlower
parent
a20d96e436
commit
5830f71a45
@@ -80,11 +80,17 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
|
||||
"Region *", "getCallableRegion", (ins "CallInterfaceCallable":$callable)
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
Returns all of the callable regions of this operation
|
||||
Returns all of the callable regions of this operation.
|
||||
}],
|
||||
"void", "getCallableRegions",
|
||||
(ins "SmallVectorImpl<Region *> &":$callables)
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
Returns the results types that the given callable region produces when
|
||||
executed.
|
||||
}],
|
||||
"ArrayRef<Type>", "getCallableResults", (ins "Region *":$callable)
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
||||
|
||||
@@ -128,6 +128,13 @@ public:
|
||||
callables.push_back(&getBody());
|
||||
}
|
||||
|
||||
/// Returns the results types that the given callable region produces when
|
||||
/// executed.
|
||||
ArrayRef<Type> getCallableResults(Region *region) {
|
||||
assert(!isExternal() && region == &getBody() && "invalid callable");
|
||||
return getType().getResults();
|
||||
}
|
||||
|
||||
private:
|
||||
// This trait needs access to `getNumFuncArguments` and `verifyType` hooks
|
||||
// defined below.
|
||||
|
||||
@@ -30,7 +30,10 @@ namespace mlir {
|
||||
|
||||
class Block;
|
||||
class BlockAndValueMapping;
|
||||
class CallableOpInterface;
|
||||
class CallOpInterface;
|
||||
class FuncOp;
|
||||
class OpBuilder;
|
||||
class Operation;
|
||||
class Region;
|
||||
class Value;
|
||||
@@ -106,6 +109,27 @@ public:
|
||||
llvm_unreachable(
|
||||
"must implement handleTerminator in the case of one inlined block");
|
||||
}
|
||||
|
||||
/// Attempt to materialize a conversion for a type mismatch between a call
|
||||
/// from this dialect, and a callable region. This method should generate an
|
||||
/// operation that takes 'input' as the only operand, and produces a single
|
||||
/// result of 'resultType'. If a conversion can not be generated, nullptr
|
||||
/// should be returned. For example, this hook may be invoked in the following
|
||||
/// scenarios:
|
||||
/// func @foo(i32) -> i32 { ... }
|
||||
///
|
||||
/// // Mismatched input operand
|
||||
/// ... = foo.call @foo(%input : i16) -> i32
|
||||
///
|
||||
/// // Mismatched result type.
|
||||
/// ... = foo.call @foo(%input : i32) -> i16
|
||||
///
|
||||
/// NOTE: This hook may be invoked before the 'isLegal' checks above.
|
||||
virtual Operation *materializeCallConversion(OpBuilder &builder, Value *input,
|
||||
Type resultType,
|
||||
Location conversionLoc) const {
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
/// This interface provides the hooks into the inlining interface.
|
||||
@@ -115,7 +139,6 @@ class InlinerInterface
|
||||
: public DialectInterfaceCollection<DialectInlinerInterface> {
|
||||
public:
|
||||
using Base::Base;
|
||||
virtual ~InlinerInterface();
|
||||
|
||||
/// Process a set of blocks that have been inlined. This callback is invoked
|
||||
/// *before* inlined terminator operations have been processed.
|
||||
@@ -178,24 +201,15 @@ LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
|
||||
llvm::Optional<Location> inlineLoc = llvm::None,
|
||||
bool shouldCloneInlinedRegion = true);
|
||||
|
||||
/// This function inlines a FuncOp into another. This function returns failure
|
||||
/// if it is not possible to inline this FuncOp. If the function returned
|
||||
/// failure, then no changes to the module have been made.
|
||||
///
|
||||
/// Note that this only does one level of inlining. For example, if the
|
||||
/// instruction 'call B' is inlined into function 'A', and function 'B' also
|
||||
/// calls 'C', then the call to 'C' now exists inside the body of 'A'. Similarly
|
||||
/// this will inline a recursive FuncOp by one level.
|
||||
///
|
||||
/// 'callOperands' must correspond, 1-1, with the arguments to the provided
|
||||
/// FuncOp. 'callResults' must correspond, 1-1, with the results of the
|
||||
/// provided FuncOp. These results will be replaced by the operands of any
|
||||
/// return operations that are inlined. 'inlineLoc' should refer to the location
|
||||
/// that the FuncOp is being inlined into.
|
||||
LogicalResult inlineFunction(InlinerInterface &interface, FuncOp callee,
|
||||
Operation *inlinePoint,
|
||||
ArrayRef<Value *> callOperands,
|
||||
ArrayRef<Value *> callResults, Location inlineLoc);
|
||||
/// This function inlines a given region, 'src', of a callable operation,
|
||||
/// 'callable', into the location defined by the given call operation. This
|
||||
/// function returns failure if inlining is not possible, success otherwise. On
|
||||
/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
|
||||
/// corresponds to whether the source region should be cloned into the 'call' or
|
||||
/// spliced directly.
|
||||
LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call,
|
||||
CallableOpInterface callable, Region *src,
|
||||
bool shouldCloneInlinedRegion = true);
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
|
||||
@@ -157,10 +157,10 @@ static void inlineCallsInSCC(Inliner &inliner,
|
||||
continue;
|
||||
|
||||
CallOpInterface call = it.call;
|
||||
LogicalResult inlineResult = inlineRegion(
|
||||
inliner, it.targetNode->getCallableRegion(), call,
|
||||
llvm::to_vector<8>(call.getArgOperands()),
|
||||
llvm::to_vector<8>(call.getOperation()->getResults()), call.getLoc());
|
||||
Region *targetRegion = it.targetNode->getCallableRegion();
|
||||
LogicalResult inlineResult = inlineCall(
|
||||
inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
|
||||
targetRegion);
|
||||
if (failed(inlineResult))
|
||||
continue;
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
@@ -65,8 +66,6 @@ remapInlinedOperands(llvm::iterator_range<Region::iterator> inlinedBlocks,
|
||||
// InlinerInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
InlinerInterface::~InlinerInterface() {}
|
||||
|
||||
bool InlinerInterface::isLegalToInline(
|
||||
Region *dest, Region *src, BlockAndValueMapping &valueMapping) const {
|
||||
// Regions can always be inlined into functions.
|
||||
@@ -74,7 +73,7 @@ bool InlinerInterface::isLegalToInline(
|
||||
return true;
|
||||
|
||||
auto *handler = getInterfaceFor(dest->getParentOp());
|
||||
return handler ? handler->isLegalToInline(src, dest, valueMapping) : false;
|
||||
return handler ? handler->isLegalToInline(dest, src, valueMapping) : false;
|
||||
}
|
||||
|
||||
bool InlinerInterface::isLegalToInline(
|
||||
@@ -253,38 +252,109 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
|
||||
inlineLoc, shouldCloneInlinedRegion);
|
||||
}
|
||||
|
||||
/// This function inlines a FuncOp into another. This function returns failure
|
||||
/// if it is not possible to inline this FuncOp. If the function returned
|
||||
/// failure, then no changes to the module have been made.
|
||||
///
|
||||
/// Note that this only does one level of inlining. For example, if the
|
||||
/// instruction 'call B' is inlined, and 'B' calls 'C', then the call to 'C' now
|
||||
/// exists in the instruction stream. Similarly this will inline a recursive
|
||||
/// FuncOp by one level.
|
||||
///
|
||||
LogicalResult mlir::inlineFunction(InlinerInterface &interface, FuncOp callee,
|
||||
Operation *inlinePoint,
|
||||
ArrayRef<Value *> callOperands,
|
||||
ArrayRef<Value *> callResults,
|
||||
Location inlineLoc) {
|
||||
// We don't inline if the provided callee function is a declaration.
|
||||
assert(callee && "expected valid function to inline");
|
||||
if (callee.isExternal())
|
||||
return failure();
|
||||
/// Utility function used to generate a cast operation from the given interface,
|
||||
/// or return nullptr if a cast could not be generated.
|
||||
static Value *materializeConversion(const DialectInlinerInterface *interface,
|
||||
SmallVectorImpl<Operation *> &castOps,
|
||||
OpBuilder &castBuilder, Value *arg,
|
||||
Type type, Location conversionLoc) {
|
||||
if (!interface)
|
||||
return nullptr;
|
||||
|
||||
// Verify that the provided arguments match the function arguments.
|
||||
if (callOperands.size() != callee.getNumArguments())
|
||||
return failure();
|
||||
// Check to see if the interface for the call can materialize a conversion.
|
||||
Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
|
||||
type, conversionLoc);
|
||||
if (!castOp)
|
||||
return nullptr;
|
||||
castOps.push_back(castOp);
|
||||
|
||||
// Verify that the provided values to replace match the function results.
|
||||
auto funcResultTypes = callee.getType().getResults();
|
||||
if (callResults.size() != funcResultTypes.size())
|
||||
return failure();
|
||||
for (unsigned i = 0, e = callResults.size(); i != e; ++i)
|
||||
if (callResults[i]->getType() != funcResultTypes[i])
|
||||
return failure();
|
||||
|
||||
// Call into the main region inliner function.
|
||||
return inlineRegion(interface, &callee.getBody(), inlinePoint, callOperands,
|
||||
callResults, inlineLoc);
|
||||
// Ensure that the generated cast is correct.
|
||||
assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
|
||||
castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
|
||||
return castOp->getResult(0);
|
||||
}
|
||||
|
||||
/// This function inlines a given region, 'src', of a callable operation,
|
||||
/// 'callable', into the location defined by the given call operation. This
|
||||
/// function returns failure if inlining is not possible, success otherwise. On
|
||||
/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
|
||||
/// corresponds to whether the source region should be cloned into the 'call' or
|
||||
/// spliced directly.
|
||||
LogicalResult mlir::inlineCall(InlinerInterface &interface,
|
||||
CallOpInterface call,
|
||||
CallableOpInterface callable, Region *src,
|
||||
bool shouldCloneInlinedRegion) {
|
||||
// We expect the region to have at least one block.
|
||||
if (src->empty())
|
||||
return failure();
|
||||
auto *entryBlock = &src->front();
|
||||
ArrayRef<Type> callableResultTypes = callable.getCallableResults(src);
|
||||
|
||||
// Make sure that the number of arguments and results matchup between the call
|
||||
// and the region.
|
||||
SmallVector<Value *, 8> callOperands(call.getArgOperands());
|
||||
SmallVector<Value *, 8> callResults(call.getOperation()->getResults());
|
||||
if (callOperands.size() != entryBlock->getNumArguments() ||
|
||||
callResults.size() != callableResultTypes.size())
|
||||
return failure();
|
||||
|
||||
// A set of cast operations generated to matchup the signature of the region
|
||||
// with the signature of the call.
|
||||
SmallVector<Operation *, 4> castOps;
|
||||
castOps.reserve(callOperands.size() + callResults.size());
|
||||
|
||||
// Functor used to cleanup generated state on failure.
|
||||
auto cleanupState = [&] {
|
||||
for (auto *op : castOps) {
|
||||
op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
|
||||
op->erase();
|
||||
}
|
||||
return failure();
|
||||
};
|
||||
|
||||
// Builder used for any conversion operations that need to be materialized.
|
||||
OpBuilder castBuilder(call);
|
||||
Location castLoc = call.getLoc();
|
||||
auto *callInterface = interface.getInterfaceFor(call.getDialect());
|
||||
|
||||
// Map the provided call operands to the arguments of the region.
|
||||
BlockAndValueMapping mapper;
|
||||
for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
|
||||
BlockArgument *regionArg = entryBlock->getArgument(i);
|
||||
Value *operand = callOperands[i];
|
||||
|
||||
// If the call operand doesn't match the expected region argument, try to
|
||||
// generate a cast.
|
||||
Type regionArgType = regionArg->getType();
|
||||
if (operand->getType() != regionArgType) {
|
||||
if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
|
||||
operand, regionArgType, castLoc)))
|
||||
return cleanupState();
|
||||
}
|
||||
mapper.map(regionArg, operand);
|
||||
}
|
||||
|
||||
// Ensure that the resultant values of the call, match the callable.
|
||||
castBuilder.setInsertionPointAfter(call);
|
||||
for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
|
||||
Value *callResult = callResults[i];
|
||||
if (callResult->getType() == callableResultTypes[i])
|
||||
continue;
|
||||
|
||||
// Generate a conversion that will produce the original type, so that the IR
|
||||
// is still valid after the original call gets replaced.
|
||||
Value *castResult =
|
||||
materializeConversion(callInterface, castOps, castBuilder, callResult,
|
||||
callResult->getType(), castLoc);
|
||||
if (!castResult)
|
||||
return cleanupState();
|
||||
callResult->replaceAllUsesWith(castResult);
|
||||
castResult->getDefiningOp()->replaceUsesOfWith(castResult, callResult);
|
||||
}
|
||||
|
||||
// Attempt to inline the call.
|
||||
if (failed(inlineRegion(interface, src, call, mapper, callResults,
|
||||
call.getLoc(), shouldCloneInlinedRegion)))
|
||||
return cleanupState();
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -105,3 +105,39 @@ func @no_inline_recursive() {
|
||||
}) : () -> (() -> ())
|
||||
return
|
||||
}
|
||||
|
||||
// Check that we can convert types for inputs and results as necessary.
|
||||
func @convert_callee_fn(%arg : i32) -> i32 {
|
||||
return %arg : i32
|
||||
}
|
||||
func @convert_callee_fn_multi_arg(%a : i32, %b : i32) -> () {
|
||||
return
|
||||
}
|
||||
func @convert_callee_fn_multi_res() -> (i32, i32) {
|
||||
%res = constant 0 : i32
|
||||
return %res, %res : i32, i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @inline_convert_call
|
||||
func @inline_convert_call() -> i16 {
|
||||
// CHECK: %[[INPUT:.*]] = constant
|
||||
%test_input = constant 0 : i16
|
||||
|
||||
// CHECK: %[[CAST_INPUT:.*]] = "test.cast"(%[[INPUT]]) : (i16) -> i32
|
||||
// CHECK: %[[CAST_RESULT:.*]] = "test.cast"(%[[CAST_INPUT]]) : (i32) -> i16
|
||||
// CHECK-NEXT: return %[[CAST_RESULT]]
|
||||
%res = "test.conversion_call_op"(%test_input) { callee=@convert_callee_fn } : (i16) -> (i16)
|
||||
return %res : i16
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @no_inline_convert_call
|
||||
func @no_inline_convert_call() {
|
||||
// CHECK: "test.conversion_call_op"
|
||||
%test_input_i16 = constant 0 : i16
|
||||
%test_input_i64 = constant 0 : i64
|
||||
"test.conversion_call_op"(%test_input_i16, %test_input_i64) { callee=@convert_callee_fn_multi_arg } : (i16, i64) -> ()
|
||||
|
||||
// CHECK: "test.conversion_call_op"
|
||||
%res_2:2 = "test.conversion_call_op"() { callee=@convert_callee_fn_multi_res } : () -> (i16, i64)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ struct TestInlinerInterface : public DialectInlinerInterface {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool shouldAnalyzeRecursively(Operation *op) const override {
|
||||
bool shouldAnalyzeRecursively(Operation *op) const final {
|
||||
// Analyze recursively if this is not a functional region operation, it
|
||||
// froms a separate functional scope.
|
||||
return !isa<FunctionalRegionOp>(op);
|
||||
@@ -82,6 +82,21 @@ struct TestInlinerInterface : public DialectInlinerInterface {
|
||||
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
|
||||
valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
|
||||
}
|
||||
|
||||
/// Attempt to materialize a conversion for a type mismatch between a call
|
||||
/// from this dialect, and a callable region. This method should generate an
|
||||
/// operation that takes 'input' as the only operand, and produces a single
|
||||
/// result of 'resultType'. If a conversion can not be generated, nullptr
|
||||
/// should be returned.
|
||||
Operation *materializeCallConversion(OpBuilder &builder, Value *input,
|
||||
Type resultType,
|
||||
Location conversionLoc) const final {
|
||||
// Only allow conversion for i16/i32 types.
|
||||
if (!(resultType.isInteger(16) || resultType.isInteger(32)) ||
|
||||
!(input->getType().isInteger(16) || input->getType().isInteger(32)))
|
||||
return nullptr;
|
||||
return builder.create<TestCastOp>(conversionLoc, resultType, input);
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
||||
@@ -194,6 +194,26 @@ def SizedRegionOp : TEST_Op<"sized_region_op", []> {
|
||||
let regions = (region SizedRegion<2>:$my_region, SizedRegion<1>);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Call Interfaces
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConversionCallOp : TEST_Op<"conversion_call_op",
|
||||
[CallOpInterface]> {
|
||||
let arguments = (ins Variadic<AnyType>:$inputs, SymbolRefAttr:$callee);
|
||||
let results = (outs Variadic<AnyType>);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Get the argument operands to the called function.
|
||||
operand_range getArgOperands() { return inputs(); }
|
||||
|
||||
/// Return the callee of this operation.
|
||||
CallInterfaceCallable getCallableForCallee() {
|
||||
return getAttrOfType<SymbolRefAttr>("callee");
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def FunctionalRegionOp : TEST_Op<"functional_region_op",
|
||||
[CallableOpInterface]> {
|
||||
let regions = (region AnyRegion:$body);
|
||||
@@ -204,6 +224,9 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op",
|
||||
void getCallableRegions(SmallVectorImpl<Region *> &callables) {
|
||||
callables.push_back(&body());
|
||||
}
|
||||
ArrayRef<Type> getCallableResults(Region *) {
|
||||
return getType().cast<FunctionType>().getResults();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user