[mlir] propagate silenceable failures in transform.foreach_match (#86956)

The original implementation was eagerly reporting silenceable failures
from actions as definite failures. Since silenceable failures are
intended for cases when the IR has not been irreversibly modified, it's
okay to propagate them as silenceable failures of the parent op.

Fixes #86834.
This commit is contained in:
Oleksandr "Alex" Zinenko
2024-03-28 18:52:10 +01:00
committed by GitHub
parent 2af3b43642
commit 0b790572b1
2 changed files with 95 additions and 2 deletions

View File

@@ -1020,6 +1020,8 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
}
DiagnosedSilenceableFailure overallDiag =
DiagnosedSilenceableFailure::success();
for (Operation *root : state.getPayloadOps(getRoot())) {
WalkResult walkResult = root->walk([&](Operation *op) {
// If getRestrictRoot is not present, skip over the root op itself so we
@@ -1058,8 +1060,19 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
action.getFunctionBody().front().without_terminator()) {
DiagnosedSilenceableFailure result =
state.applyTransform(cast<TransformOpInterface>(transform));
if (failed(result.checkAndReport()))
if (result.isDefiniteFailure())
return WalkResult::interrupt();
if (result.isSilenceableFailure()) {
if (overallDiag.succeeded()) {
overallDiag = emitSilenceableError() << "actions failed";
}
overallDiag.attachNote(action->getLoc())
<< "failed action: " << result.getMessage();
overallDiag.attachNote(op->getLoc())
<< "when applied to this matching payload";
(void)result.silence();
continue;
}
}
break;
}
@@ -1075,7 +1088,7 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
// by actions, are invalidated.
results.set(llvm::cast<OpResult>(getUpdated()),
state.getPayloadOps(getRoot()));
return DiagnosedSilenceableFailure::success();
return overallDiag;
}
void transform::ForeachMatchOp::getEffects(

View File

@@ -0,0 +1,80 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics
// Silenceable diagnostics suppressed.
module attributes { transform.with_named_sequence } {
func.func @test_loop_peeling_not_beneficial() {
%lb = arith.constant 0 : index
%ub = arith.constant 40 : index
%step = arith.constant 5 : index
scf.for %i = %lb to %ub step %step {
arith.addi %i, %i : index
}
return
}
transform.named_sequence @peel(%arg0: !transform.op<"scf.for"> {transform.consumed}) {
transform.loop.peel %arg0 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
transform.yield
}
transform.named_sequence @match_for(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
transform.match.operation_name %arg0 ["scf.for"] : !transform.any_op
transform.yield %arg0 : !transform.any_op
}
transform.named_sequence @__transform_main(%root: !transform.any_op) {
transform.sequence %root : !transform.any_op failures(suppress) {
^bb0(%arg0: !transform.any_op):
transform.foreach_match in %arg0
@match_for -> @peel
: (!transform.any_op) -> !transform.any_op
transform.yield
}
transform.yield
}
}
// -----
// Silenceable diagnostics propagated.
module attributes { transform.with_named_sequence } {
func.func @test_loop_peeling_not_beneficial() {
%lb = arith.constant 0 : index
%ub = arith.constant 40 : index
%step = arith.constant 5 : index
// expected-note @below {{when applied to this matching payload}}
scf.for %i = %lb to %ub step %step {
arith.addi %i, %i : index
}
return
}
// expected-note @below {{failed to peel the last iteration}}
transform.named_sequence @peel(%arg0: !transform.op<"scf.for"> {transform.consumed}) {
transform.loop.peel %arg0 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
transform.yield
}
transform.named_sequence @match_for(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
transform.match.operation_name %arg0 ["scf.for"] : !transform.any_op
transform.yield %arg0 : !transform.any_op
}
transform.named_sequence @main_suppress(%root: !transform.any_op) {
transform.sequence %root : !transform.any_op failures(suppress) {
^bb0(%arg0: !transform.any_op):
transform.foreach_match in %arg0
@match_for -> @peel
: (!transform.any_op) -> !transform.any_op
transform.yield
}
transform.yield
}
transform.named_sequence @__transform_main(%root: !transform.any_op) {
transform.sequence %root : !transform.any_op failures(propagate) {
^bb0(%arg0: !transform.any_op):
// expected-error @below {{actions failed}}
transform.foreach_match in %arg0
@match_for -> @peel
: (!transform.any_op) -> !transform.any_op
transform.yield
}
transform.yield
}
}