[mlir] optionally allow repeated handles in transform dialect

Some operations may be able to deal with handles pointing to the same
operation when the handle is consumed. For example, merge handles with
deduplication doesn't actually destroy payload operations and is
specifically intended to remove the situation with duplicates. Add a
method to the transform interface to allow ops to declare they can
support repeated handles.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D140124
This commit is contained in:
Alex Zinenko
2022-12-15 18:48:09 +00:00
parent 58e9cc13e2
commit 4299be1a08
5 changed files with 36 additions and 2 deletions

View File

@@ -48,6 +48,19 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
"::mlir::transform::TransformResults &":$transformResults,
"::mlir::transform::TransformState &":$state
)>,
InterfaceMethod<
/*desc=*/[{
Indicates whether the op instance allows its handle operands to be
associated with the same payload operations.
}],
/*returnType=*/"bool",
/*name=*/"allowsRepeatedHandleOperands",
/*arguments=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return false;
}]
>,
];
let extraSharedClassDeclaration = [{

View File

@@ -210,7 +210,7 @@ def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
}
def MergeHandlesOp : TransformDialectOp<"merge_handles",
[DeclareOpInterfaceMethods<TransformOpInterface>,
[DeclareOpInterfaceMethods<TransformOpInterface, ["allowsRepeatedHandleOperands"]>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
SameOperandsAndResultType]> {
let summary = "Merges handles into one pointing to the union of payload ops";

View File

@@ -189,7 +189,8 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
for (OpOperand &target : transform->getOpOperands()) {
// If the operand uses an invalidated handle, report it.
auto it = invalidatedHandles.find(target.get());
if (it != invalidatedHandles.end())
if (!transform.allowsRepeatedHandleOperands() &&
it != invalidatedHandles.end())
return it->getSecond()(transform->getLoc()), failure();
// Invalidate handles pointing to the operations nested in the operation
@@ -201,6 +202,7 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
if (llvm::any_of(effects, consumesTarget))
recordHandleInvalidation(target);
}
return success();
}

View File

@@ -449,6 +449,11 @@ transform::MergeHandlesOp::apply(transform::TransformResults &results,
return DiagnosedSilenceableFailure::success();
}
bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
// Handles may be the same if deduplicating is enabled.
return getDeduplicate();
}
void transform::MergeHandlesOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getHandles(), effects);

View File

@@ -99,3 +99,17 @@ module {
transform.test_consume_operand %1, %2
}
}
// -----
// Deduplication attribute allows "merge_handles" to take repeated operands.
module {
transform.sequence failures(propagate) {
^bb0(%0: !pdl.operation):
%1 = transform.test_copy_payload %0
%2 = transform.test_copy_payload %0
transform.merge_handles %1, %2 { deduplicate } : !pdl.operation
}
}