mirror of
https://github.com/intel/llvm.git
synced 2026-02-08 08:57:43 +08:00
[mlir] optionally allow repeated handles in transform dialect
Some operations may be able to deal with handles pointing to the same operation when the handle is consumed. For example, merge handles with deduplication doesn't actually destroy payload operations and is specifically intended to remove the situation with duplicates. Add a method to the transform interface to allow ops to declare they can support repeated handles. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D140124
This commit is contained in:
@@ -48,6 +48,19 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
|
||||
"::mlir::transform::TransformResults &":$transformResults,
|
||||
"::mlir::transform::TransformState &":$state
|
||||
)>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Indicates whether the op instance allows its handle operands to be
|
||||
associated with the same payload operations.
|
||||
}],
|
||||
/*returnType=*/"bool",
|
||||
/*name=*/"allowsRepeatedHandleOperands",
|
||||
/*arguments=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return false;
|
||||
}]
|
||||
>,
|
||||
];
|
||||
|
||||
let extraSharedClassDeclaration = [{
|
||||
|
||||
@@ -210,7 +210,7 @@ def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
|
||||
}
|
||||
|
||||
def MergeHandlesOp : TransformDialectOp<"merge_handles",
|
||||
[DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
[DeclareOpInterfaceMethods<TransformOpInterface, ["allowsRepeatedHandleOperands"]>,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
SameOperandsAndResultType]> {
|
||||
let summary = "Merges handles into one pointing to the union of payload ops";
|
||||
|
||||
@@ -189,7 +189,8 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
|
||||
for (OpOperand &target : transform->getOpOperands()) {
|
||||
// If the operand uses an invalidated handle, report it.
|
||||
auto it = invalidatedHandles.find(target.get());
|
||||
if (it != invalidatedHandles.end())
|
||||
if (!transform.allowsRepeatedHandleOperands() &&
|
||||
it != invalidatedHandles.end())
|
||||
return it->getSecond()(transform->getLoc()), failure();
|
||||
|
||||
// Invalidate handles pointing to the operations nested in the operation
|
||||
@@ -201,6 +202,7 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
|
||||
if (llvm::any_of(effects, consumesTarget))
|
||||
recordHandleInvalidation(target);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
@@ -449,6 +449,11 @@ transform::MergeHandlesOp::apply(transform::TransformResults &results,
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
|
||||
// Handles may be the same if deduplicating is enabled.
|
||||
return getDeduplicate();
|
||||
}
|
||||
|
||||
void transform::MergeHandlesOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
consumesHandle(getHandles(), effects);
|
||||
|
||||
@@ -99,3 +99,17 @@ module {
|
||||
transform.test_consume_operand %1, %2
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Deduplication attribute allows "merge_handles" to take repeated operands.
|
||||
|
||||
module {
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%0: !pdl.operation):
|
||||
%1 = transform.test_copy_payload %0
|
||||
%2 = transform.test_copy_payload %0
|
||||
transform.merge_handles %1, %2 { deduplicate } : !pdl.operation
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user