mirror of
https://github.com/intel/llvm.git
synced 2026-01-16 13:35:38 +08:00
[mlir] Add getArgOperandsMutable method to CallOpInterface
Add a method to the CallOpInterface to get a mutable operand range over the function arguments. This allows to add, remove, or change the type of call arguments in a generic manner without having to assume that the argument operand range is at the end of the operand list, or having to type switch on all supported concrete operation kinds. Alternatively, a new OpInterface could be added which inherits from CallOpInterface and appends it with the mutable variants of the base interface. There will be two users of this new function in the beginning: (1) A few passes in the Arc dialect in CIRCT already use a downstream implementation of the alternative case mentioned above: https://github.com/llvm/circt/blob/main/include/circt/Dialect/Arc/ArcInterfaces.td#L15 (2) The BufferDeallocation pass will be modified to be able to pass ownership of memrefs to called private functions if the caller does not need the memref anymore by appending the function argument list with a boolean value per memref, thus enabling earlier deallocation of the memref which can lead to lower peak memory usage. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D156675
This commit is contained in:
@@ -2347,6 +2347,12 @@ def fir_CallOp : fir_Op<"call",
|
||||
return {arg_operand_begin() + 1, arg_operand_end()};
|
||||
}
|
||||
|
||||
mlir::MutableOperandRange getArgOperandsMutable() {
|
||||
if ((*this)->getAttrOfType<mlir::SymbolRefAttr>(getCalleeAttrName()))
|
||||
return getArgsMutable();
|
||||
return mlir::MutableOperandRange(*this, 1, getArgs().size() - 1);
|
||||
}
|
||||
|
||||
operand_iterator arg_operand_begin() { return operand_begin(); }
|
||||
operand_iterator arg_operand_end() { return operand_end(); }
|
||||
|
||||
|
||||
@@ -348,6 +348,12 @@ void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
|
||||
/// call interface.
|
||||
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
|
||||
|
||||
/// Get the argument operands to the called function as a mutable range, this is
|
||||
/// required by the call interface.
|
||||
MutableOperandRange GenericCallOp::getArgOperandsMutable() {
|
||||
return getInputsMutable();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MulOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -348,6 +348,12 @@ void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
|
||||
/// call interface.
|
||||
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
|
||||
|
||||
/// Get the argument operands to the called function as a mutable range, this is
|
||||
/// required by the call interface.
|
||||
MutableOperandRange GenericCallOp::getArgOperandsMutable() {
|
||||
return getInputsMutable();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MulOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -348,6 +348,12 @@ void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
|
||||
/// call interface.
|
||||
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
|
||||
|
||||
/// Get the argument operands to the called function as a mutable range, this is
|
||||
/// required by the call interface.
|
||||
MutableOperandRange GenericCallOp::getArgOperandsMutable() {
|
||||
return getInputsMutable();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MulOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -377,6 +377,12 @@ void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
|
||||
/// call interface.
|
||||
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
|
||||
|
||||
/// Get the argument operands to the called function as a mutable range, this is
|
||||
/// required by the call interface.
|
||||
MutableOperandRange GenericCallOp::getArgOperandsMutable() {
|
||||
return getInputsMutable();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MulOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -264,6 +264,10 @@ def Async_CallOp : Async_Op<"call",
|
||||
return {arg_operand_begin(), arg_operand_end()};
|
||||
}
|
||||
|
||||
MutableOperandRange getArgOperandsMutable() {
|
||||
return getOperandsMutable();
|
||||
}
|
||||
|
||||
operand_iterator arg_operand_begin() { return operand_begin(); }
|
||||
operand_iterator arg_operand_end() { return operand_end(); }
|
||||
|
||||
|
||||
@@ -83,6 +83,10 @@ def CallOp : Func_Op<"call",
|
||||
return {arg_operand_begin(), arg_operand_end()};
|
||||
}
|
||||
|
||||
MutableOperandRange getArgOperandsMutable() {
|
||||
return getOperandsMutable();
|
||||
}
|
||||
|
||||
operand_iterator arg_operand_begin() { return operand_begin(); }
|
||||
operand_iterator arg_operand_end() { return operand_end(); }
|
||||
|
||||
@@ -152,6 +156,10 @@ def CallIndirectOp : Func_Op<"call_indirect", [
|
||||
return {arg_operand_begin(), arg_operand_end()};
|
||||
}
|
||||
|
||||
MutableOperandRange getArgOperandsMutable() {
|
||||
return getCalleeOperandsMutable();
|
||||
}
|
||||
|
||||
operand_iterator arg_operand_begin() { return ++operand_begin(); }
|
||||
operand_iterator arg_operand_end() { return operand_end(); }
|
||||
|
||||
|
||||
@@ -616,7 +616,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
|
||||
}];
|
||||
|
||||
dag args = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
|
||||
Variadic<LLVM_Type>,
|
||||
Variadic<LLVM_Type>:$callee_operands,
|
||||
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
|
||||
"{}">:$fastmathFlags,
|
||||
OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
|
||||
|
||||
@@ -632,6 +632,10 @@ def IncludeOp : TransformDialectOp<"include",
|
||||
::mlir::Operation::operand_range getArgOperands() {
|
||||
return getOperands();
|
||||
}
|
||||
|
||||
::mlir::MutableOperandRange getArgOperandsMutable() {
|
||||
return getOperandsMutable();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
@@ -55,6 +55,11 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
|
||||
}],
|
||||
"::mlir::Operation::operand_range", "getArgOperands"
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
Returns the operands within this call that are used as arguments to the
|
||||
callee as a mutable range.
|
||||
}],
|
||||
"::mlir::MutableOperandRange", "getArgOperandsMutable">,
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
||||
@@ -1003,6 +1003,11 @@ Operation::operand_range CallOp::getArgOperands() {
|
||||
return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
|
||||
}
|
||||
|
||||
MutableOperandRange CallOp::getArgOperandsMutable() {
|
||||
return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
|
||||
getCalleeOperands().size());
|
||||
}
|
||||
|
||||
LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
if (getNumResults() > 1)
|
||||
return emitOpError("must have 0 or 1 result");
|
||||
@@ -1237,6 +1242,11 @@ Operation::operand_range InvokeOp::getArgOperands() {
|
||||
return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
|
||||
}
|
||||
|
||||
MutableOperandRange InvokeOp::getArgOperandsMutable() {
|
||||
return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
|
||||
getCalleeOperands().size());
|
||||
}
|
||||
|
||||
LogicalResult InvokeOp::verify() {
|
||||
if (getNumResults() > 1)
|
||||
return emitOpError("must have 0 or 1 result");
|
||||
|
||||
@@ -208,6 +208,10 @@ Operation::operand_range FunctionCallOp::getArgOperands() {
|
||||
return getArguments();
|
||||
}
|
||||
|
||||
MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
|
||||
return getArgumentsMutable();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.mlir.loop
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -1263,6 +1263,10 @@ Operation::operand_range TestCallAndStoreOp::getArgOperands() {
|
||||
return getCalleeOperands();
|
||||
}
|
||||
|
||||
MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() {
|
||||
return getCalleeOperandsMutable();
|
||||
}
|
||||
|
||||
void TestStoreWithARegion::getSuccessorRegions(
|
||||
std::optional<unsigned> index, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
|
||||
Reference in New Issue
Block a user