[mlir] Add getArgOperandsMutable method to CallOpInterface

Add a method to the CallOpInterface to get a mutable operand range over
the function arguments.  This allows to add, remove, or change the type
of call arguments in a generic manner without having to assume that the
argument operand range is at the end of the operand list, or having to
type switch on all supported concrete operation kinds.

Alternatively, a new OpInterface could be added which inherits from
CallOpInterface and appends it with the mutable variants of the base
interface.

There will be two users of this new function in the beginning:
(1) A few passes in the Arc dialect in CIRCT already use a downstream
implementation of the alternative case mentioned above: https://github.com/llvm/circt/blob/main/include/circt/Dialect/Arc/ArcInterfaces.td#L15
(2) The BufferDeallocation pass will be modified to be able to pass
ownership of memrefs to called private functions if the caller does not
need the memref anymore by appending the function argument list with a
boolean value per memref, thus enabling earlier deallocation of the
memref which can lead to lower peak memory usage.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D156675
This commit is contained in:
Martin Erhart
2023-08-02 08:08:03 +00:00
parent c2093b8504
commit d790a217a7
13 changed files with 70 additions and 1 deletions

View File

@@ -2347,6 +2347,12 @@ def fir_CallOp : fir_Op<"call",
return {arg_operand_begin() + 1, arg_operand_end()};
}
mlir::MutableOperandRange getArgOperandsMutable() {
if ((*this)->getAttrOfType<mlir::SymbolRefAttr>(getCalleeAttrName()))
return getArgsMutable();
return mlir::MutableOperandRange(*this, 1, getArgs().size() - 1);
}
operand_iterator arg_operand_begin() { return operand_begin(); }
operand_iterator arg_operand_end() { return operand_end(); }

View File

@@ -348,6 +348,12 @@ void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
/// Get the argument operands to the called function as a mutable range, this is
/// required by the call interface.
MutableOperandRange GenericCallOp::getArgOperandsMutable() {
return getInputsMutable();
}
//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

View File

@@ -348,6 +348,12 @@ void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
/// Get the argument operands to the called function as a mutable range, this is
/// required by the call interface.
MutableOperandRange GenericCallOp::getArgOperandsMutable() {
return getInputsMutable();
}
//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

View File

@@ -348,6 +348,12 @@ void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
/// Get the argument operands to the called function as a mutable range, this is
/// required by the call interface.
MutableOperandRange GenericCallOp::getArgOperandsMutable() {
return getInputsMutable();
}
//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

View File

@@ -377,6 +377,12 @@ void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
/// Get the argument operands to the called function as a mutable range, this is
/// required by the call interface.
MutableOperandRange GenericCallOp::getArgOperandsMutable() {
return getInputsMutable();
}
//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

View File

@@ -264,6 +264,10 @@ def Async_CallOp : Async_Op<"call",
return {arg_operand_begin(), arg_operand_end()};
}
MutableOperandRange getArgOperandsMutable() {
return getOperandsMutable();
}
operand_iterator arg_operand_begin() { return operand_begin(); }
operand_iterator arg_operand_end() { return operand_end(); }

View File

@@ -83,6 +83,10 @@ def CallOp : Func_Op<"call",
return {arg_operand_begin(), arg_operand_end()};
}
MutableOperandRange getArgOperandsMutable() {
return getOperandsMutable();
}
operand_iterator arg_operand_begin() { return operand_begin(); }
operand_iterator arg_operand_end() { return operand_end(); }
@@ -152,6 +156,10 @@ def CallIndirectOp : Func_Op<"call_indirect", [
return {arg_operand_begin(), arg_operand_end()};
}
MutableOperandRange getArgOperandsMutable() {
return getCalleeOperandsMutable();
}
operand_iterator arg_operand_begin() { return ++operand_begin(); }
operand_iterator arg_operand_end() { return operand_end(); }

View File

@@ -616,7 +616,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
}];
dag args = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>,
Variadic<LLVM_Type>:$callee_operands,
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
"{}">:$fastmathFlags,
OptionalAttr<DenseI32ArrayAttr>:$branch_weights);

View File

@@ -632,6 +632,10 @@ def IncludeOp : TransformDialectOp<"include",
::mlir::Operation::operand_range getArgOperands() {
return getOperands();
}
::mlir::MutableOperandRange getArgOperandsMutable() {
return getOperandsMutable();
}
}];
}

View File

@@ -55,6 +55,11 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
}],
"::mlir::Operation::operand_range", "getArgOperands"
>,
InterfaceMethod<[{
Returns the operands within this call that are used as arguments to the
callee as a mutable range.
}],
"::mlir::MutableOperandRange", "getArgOperandsMutable">,
];
let extraClassDeclaration = [{

View File

@@ -1003,6 +1003,11 @@ Operation::operand_range CallOp::getArgOperands() {
return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
}
MutableOperandRange CallOp::getArgOperandsMutable() {
return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
getCalleeOperands().size());
}
LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (getNumResults() > 1)
return emitOpError("must have 0 or 1 result");
@@ -1237,6 +1242,11 @@ Operation::operand_range InvokeOp::getArgOperands() {
return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
}
MutableOperandRange InvokeOp::getArgOperandsMutable() {
return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
getCalleeOperands().size());
}
LogicalResult InvokeOp::verify() {
if (getNumResults() > 1)
return emitOpError("must have 0 or 1 result");

View File

@@ -208,6 +208,10 @@ Operation::operand_range FunctionCallOp::getArgOperands() {
return getArguments();
}
MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
return getArgumentsMutable();
}
//===----------------------------------------------------------------------===//
// spirv.mlir.loop
//===----------------------------------------------------------------------===//

View File

@@ -1263,6 +1263,10 @@ Operation::operand_range TestCallAndStoreOp::getArgOperands() {
return getCalleeOperands();
}
MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() {
return getCalleeOperandsMutable();
}
void TestStoreWithARegion::getSuccessorRegions(
std::optional<unsigned> index, ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {