Add support for inlining calls with different arg/result types from the callable.

Some dialects have implicit conversions inherent in their modeling, meaning that a call may have a different type that the type that the callable expects. To support this, a hook is added to the dialect interface that allows for materializing conversion operations during inlining when there is a mismatch. A hook is also added to the callable interface to allow for introspecting the expected result types.

PiperOrigin-RevId: 272814379
This commit is contained in:
River Riddle
2019-10-03 23:10:25 -07:00
committed by A. Unique TensorFlower
parent a20d96e436
commit 5830f71a45
8 changed files with 231 additions and 60 deletions

View File

@@ -80,11 +80,17 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
"Region *", "getCallableRegion", (ins "CallInterfaceCallable":$callable)
>,
InterfaceMethod<[{
Returns all of the callable regions of this operation
Returns all of the callable regions of this operation.
}],
"void", "getCallableRegions",
(ins "SmallVectorImpl<Region *> &":$callables)
>,
InterfaceMethod<[{
Returns the results types that the given callable region produces when
executed.
}],
"ArrayRef<Type>", "getCallableResults", (ins "Region *":$callable)
>,
];
}

View File

@@ -128,6 +128,13 @@ public:
callables.push_back(&getBody());
}
/// Returns the results types that the given callable region produces when
/// executed.
ArrayRef<Type> getCallableResults(Region *region) {
assert(!isExternal() && region == &getBody() && "invalid callable");
return getType().getResults();
}
private:
// This trait needs access to `getNumFuncArguments` and `verifyType` hooks
// defined below.

View File

@@ -30,7 +30,10 @@ namespace mlir {
class Block;
class BlockAndValueMapping;
class CallableOpInterface;
class CallOpInterface;
class FuncOp;
class OpBuilder;
class Operation;
class Region;
class Value;
@@ -106,6 +109,27 @@ public:
llvm_unreachable(
"must implement handleTerminator in the case of one inlined block");
}
/// Attempt to materialize a conversion for a type mismatch between a call
/// from this dialect, and a callable region. This method should generate an
/// operation that takes 'input' as the only operand, and produces a single
/// result of 'resultType'. If a conversion can not be generated, nullptr
/// should be returned. For example, this hook may be invoked in the following
/// scenarios:
/// func @foo(i32) -> i32 { ... }
///
/// // Mismatched input operand
/// ... = foo.call @foo(%input : i16) -> i32
///
/// // Mismatched result type.
/// ... = foo.call @foo(%input : i32) -> i16
///
/// NOTE: This hook may be invoked before the 'isLegal' checks above.
virtual Operation *materializeCallConversion(OpBuilder &builder, Value *input,
Type resultType,
Location conversionLoc) const {
return nullptr;
}
};
/// This interface provides the hooks into the inlining interface.
@@ -115,7 +139,6 @@ class InlinerInterface
: public DialectInterfaceCollection<DialectInlinerInterface> {
public:
using Base::Base;
virtual ~InlinerInterface();
/// Process a set of blocks that have been inlined. This callback is invoked
/// *before* inlined terminator operations have been processed.
@@ -178,24 +201,15 @@ LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
llvm::Optional<Location> inlineLoc = llvm::None,
bool shouldCloneInlinedRegion = true);
/// This function inlines a FuncOp into another. This function returns failure
/// if it is not possible to inline this FuncOp. If the function returned
/// failure, then no changes to the module have been made.
///
/// Note that this only does one level of inlining. For example, if the
/// instruction 'call B' is inlined into function 'A', and function 'B' also
/// calls 'C', then the call to 'C' now exists inside the body of 'A'. Similarly
/// this will inline a recursive FuncOp by one level.
///
/// 'callOperands' must correspond, 1-1, with the arguments to the provided
/// FuncOp. 'callResults' must correspond, 1-1, with the results of the
/// provided FuncOp. These results will be replaced by the operands of any
/// return operations that are inlined. 'inlineLoc' should refer to the location
/// that the FuncOp is being inlined into.
LogicalResult inlineFunction(InlinerInterface &interface, FuncOp callee,
Operation *inlinePoint,
ArrayRef<Value *> callOperands,
ArrayRef<Value *> callResults, Location inlineLoc);
/// This function inlines a given region, 'src', of a callable operation,
/// 'callable', into the location defined by the given call operation. This
/// function returns failure if inlining is not possible, success otherwise. On
/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
/// corresponds to whether the source region should be cloned into the 'call' or
/// spliced directly.
LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call,
CallableOpInterface callable, Region *src,
bool shouldCloneInlinedRegion = true);
} // end namespace mlir

View File

@@ -157,10 +157,10 @@ static void inlineCallsInSCC(Inliner &inliner,
continue;
CallOpInterface call = it.call;
LogicalResult inlineResult = inlineRegion(
inliner, it.targetNode->getCallableRegion(), call,
llvm::to_vector<8>(call.getArgOperands()),
llvm::to_vector<8>(call.getOperation()->getResults()), call.getLoc());
Region *targetRegion = it.targetNode->getCallableRegion();
LogicalResult inlineResult = inlineCall(
inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
targetRegion);
if (failed(inlineResult))
continue;

View File

@@ -22,6 +22,7 @@
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/MapVector.h"
@@ -65,8 +66,6 @@ remapInlinedOperands(llvm::iterator_range<Region::iterator> inlinedBlocks,
// InlinerInterface
//===----------------------------------------------------------------------===//
InlinerInterface::~InlinerInterface() {}
bool InlinerInterface::isLegalToInline(
Region *dest, Region *src, BlockAndValueMapping &valueMapping) const {
// Regions can always be inlined into functions.
@@ -74,7 +73,7 @@ bool InlinerInterface::isLegalToInline(
return true;
auto *handler = getInterfaceFor(dest->getParentOp());
return handler ? handler->isLegalToInline(src, dest, valueMapping) : false;
return handler ? handler->isLegalToInline(dest, src, valueMapping) : false;
}
bool InlinerInterface::isLegalToInline(
@@ -253,38 +252,109 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
inlineLoc, shouldCloneInlinedRegion);
}
/// This function inlines a FuncOp into another. This function returns failure
/// if it is not possible to inline this FuncOp. If the function returned
/// failure, then no changes to the module have been made.
///
/// Note that this only does one level of inlining. For example, if the
/// instruction 'call B' is inlined, and 'B' calls 'C', then the call to 'C' now
/// exists in the instruction stream. Similarly this will inline a recursive
/// FuncOp by one level.
///
LogicalResult mlir::inlineFunction(InlinerInterface &interface, FuncOp callee,
Operation *inlinePoint,
ArrayRef<Value *> callOperands,
ArrayRef<Value *> callResults,
Location inlineLoc) {
// We don't inline if the provided callee function is a declaration.
assert(callee && "expected valid function to inline");
if (callee.isExternal())
return failure();
/// Utility function used to generate a cast operation from the given interface,
/// or return nullptr if a cast could not be generated.
static Value *materializeConversion(const DialectInlinerInterface *interface,
SmallVectorImpl<Operation *> &castOps,
OpBuilder &castBuilder, Value *arg,
Type type, Location conversionLoc) {
if (!interface)
return nullptr;
// Verify that the provided arguments match the function arguments.
if (callOperands.size() != callee.getNumArguments())
return failure();
// Check to see if the interface for the call can materialize a conversion.
Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
type, conversionLoc);
if (!castOp)
return nullptr;
castOps.push_back(castOp);
// Verify that the provided values to replace match the function results.
auto funcResultTypes = callee.getType().getResults();
if (callResults.size() != funcResultTypes.size())
return failure();
for (unsigned i = 0, e = callResults.size(); i != e; ++i)
if (callResults[i]->getType() != funcResultTypes[i])
return failure();
// Call into the main region inliner function.
return inlineRegion(interface, &callee.getBody(), inlinePoint, callOperands,
callResults, inlineLoc);
// Ensure that the generated cast is correct.
assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
return castOp->getResult(0);
}
/// This function inlines a given region, 'src', of a callable operation,
/// 'callable', into the location defined by the given call operation. This
/// function returns failure if inlining is not possible, success otherwise. On
/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
/// corresponds to whether the source region should be cloned into the 'call' or
/// spliced directly.
LogicalResult mlir::inlineCall(InlinerInterface &interface,
CallOpInterface call,
CallableOpInterface callable, Region *src,
bool shouldCloneInlinedRegion) {
// We expect the region to have at least one block.
if (src->empty())
return failure();
auto *entryBlock = &src->front();
ArrayRef<Type> callableResultTypes = callable.getCallableResults(src);
// Make sure that the number of arguments and results matchup between the call
// and the region.
SmallVector<Value *, 8> callOperands(call.getArgOperands());
SmallVector<Value *, 8> callResults(call.getOperation()->getResults());
if (callOperands.size() != entryBlock->getNumArguments() ||
callResults.size() != callableResultTypes.size())
return failure();
// A set of cast operations generated to matchup the signature of the region
// with the signature of the call.
SmallVector<Operation *, 4> castOps;
castOps.reserve(callOperands.size() + callResults.size());
// Functor used to cleanup generated state on failure.
auto cleanupState = [&] {
for (auto *op : castOps) {
op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
op->erase();
}
return failure();
};
// Builder used for any conversion operations that need to be materialized.
OpBuilder castBuilder(call);
Location castLoc = call.getLoc();
auto *callInterface = interface.getInterfaceFor(call.getDialect());
// Map the provided call operands to the arguments of the region.
BlockAndValueMapping mapper;
for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
BlockArgument *regionArg = entryBlock->getArgument(i);
Value *operand = callOperands[i];
// If the call operand doesn't match the expected region argument, try to
// generate a cast.
Type regionArgType = regionArg->getType();
if (operand->getType() != regionArgType) {
if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
operand, regionArgType, castLoc)))
return cleanupState();
}
mapper.map(regionArg, operand);
}
// Ensure that the resultant values of the call, match the callable.
castBuilder.setInsertionPointAfter(call);
for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
Value *callResult = callResults[i];
if (callResult->getType() == callableResultTypes[i])
continue;
// Generate a conversion that will produce the original type, so that the IR
// is still valid after the original call gets replaced.
Value *castResult =
materializeConversion(callInterface, castOps, castBuilder, callResult,
callResult->getType(), castLoc);
if (!castResult)
return cleanupState();
callResult->replaceAllUsesWith(castResult);
castResult->getDefiningOp()->replaceUsesOfWith(castResult, callResult);
}
// Attempt to inline the call.
if (failed(inlineRegion(interface, src, call, mapper, callResults,
call.getLoc(), shouldCloneInlinedRegion)))
return cleanupState();
return success();
}

View File

@@ -105,3 +105,39 @@ func @no_inline_recursive() {
}) : () -> (() -> ())
return
}
// Check that we can convert types for inputs and results as necessary.
func @convert_callee_fn(%arg : i32) -> i32 {
return %arg : i32
}
func @convert_callee_fn_multi_arg(%a : i32, %b : i32) -> () {
return
}
func @convert_callee_fn_multi_res() -> (i32, i32) {
%res = constant 0 : i32
return %res, %res : i32, i32
}
// CHECK-LABEL: func @inline_convert_call
func @inline_convert_call() -> i16 {
// CHECK: %[[INPUT:.*]] = constant
%test_input = constant 0 : i16
// CHECK: %[[CAST_INPUT:.*]] = "test.cast"(%[[INPUT]]) : (i16) -> i32
// CHECK: %[[CAST_RESULT:.*]] = "test.cast"(%[[CAST_INPUT]]) : (i32) -> i16
// CHECK-NEXT: return %[[CAST_RESULT]]
%res = "test.conversion_call_op"(%test_input) { callee=@convert_callee_fn } : (i16) -> (i16)
return %res : i16
}
// CHECK-LABEL: func @no_inline_convert_call
func @no_inline_convert_call() {
// CHECK: "test.conversion_call_op"
%test_input_i16 = constant 0 : i16
%test_input_i64 = constant 0 : i64
"test.conversion_call_op"(%test_input_i16, %test_input_i64) { callee=@convert_callee_fn_multi_arg } : (i16, i64) -> ()
// CHECK: "test.conversion_call_op"
%res_2:2 = "test.conversion_call_op"() { callee=@convert_callee_fn_multi_res } : () -> (i16, i64)
return
}

View File

@@ -58,7 +58,7 @@ struct TestInlinerInterface : public DialectInlinerInterface {
return true;
}
bool shouldAnalyzeRecursively(Operation *op) const override {
bool shouldAnalyzeRecursively(Operation *op) const final {
// Analyze recursively if this is not a functional region operation, it
// froms a separate functional scope.
return !isa<FunctionalRegionOp>(op);
@@ -82,6 +82,21 @@ struct TestInlinerInterface : public DialectInlinerInterface {
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
}
/// Attempt to materialize a conversion for a type mismatch between a call
/// from this dialect, and a callable region. This method should generate an
/// operation that takes 'input' as the only operand, and produces a single
/// result of 'resultType'. If a conversion can not be generated, nullptr
/// should be returned.
Operation *materializeCallConversion(OpBuilder &builder, Value *input,
Type resultType,
Location conversionLoc) const final {
// Only allow conversion for i16/i32 types.
if (!(resultType.isInteger(16) || resultType.isInteger(32)) ||
!(input->getType().isInteger(16) || input->getType().isInteger(32)))
return nullptr;
return builder.create<TestCastOp>(conversionLoc, resultType, input);
}
};
} // end anonymous namespace

View File

@@ -194,6 +194,26 @@ def SizedRegionOp : TEST_Op<"sized_region_op", []> {
let regions = (region SizedRegion<2>:$my_region, SizedRegion<1>);
}
//===----------------------------------------------------------------------===//
// Test Call Interfaces
//===----------------------------------------------------------------------===//
def ConversionCallOp : TEST_Op<"conversion_call_op",
[CallOpInterface]> {
let arguments = (ins Variadic<AnyType>:$inputs, SymbolRefAttr:$callee);
let results = (outs Variadic<AnyType>);
let extraClassDeclaration = [{
/// Get the argument operands to the called function.
operand_range getArgOperands() { return inputs(); }
/// Return the callee of this operation.
CallInterfaceCallable getCallableForCallee() {
return getAttrOfType<SymbolRefAttr>("callee");
}
}];
}
def FunctionalRegionOp : TEST_Op<"functional_region_op",
[CallableOpInterface]> {
let regions = (region AnyRegion:$body);
@@ -204,6 +224,9 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op",
void getCallableRegions(SmallVectorImpl<Region *> &callables) {
callables.push_back(&body());
}
ArrayRef<Type> getCallableResults(Region *) {
return getType().cast<FunctionType>().getResults();
}
}];
}