[MLIR] Add replaceUsesWithIf on Operation

Add replaceUsesWithIf on Operation along the lines of
Value::replaceUsesWithIf. This had been missing on Operation and is
convenient to replace multi-result operations' results conditionally.

Reviewed By: lattner

Differential Revision: https://reviews.llvm.org/D144348
This commit is contained in:
Uday Bondhugula
2023-02-21 10:10:15 +05:30
parent 4638edaeaa
commit 8a583dd220
4 changed files with 37 additions and 6 deletions

View File

@@ -257,6 +257,15 @@ public:
getResults().replaceAllUsesWith(std::forward<ValuesT>(values));
}
/// Replace uses of results of this operation with the provided `values` if
/// the given callback returns true.
template <typename ValuesT>
void replaceUsesWithIf(ValuesT &&values,
function_ref<bool(OpOperand &)> shouldReplace) {
getResults().replaceUsesWithIf(std::forward<ValuesT>(values),
shouldReplace);
}
/// Destroys this operation and its subclass data.
void destroy();

View File

@@ -279,6 +279,26 @@ public:
/// Replace all uses of results of this range with results of 'op'.
void replaceAllUsesWith(Operation *op);
/// Replace uses of results of this range with the provided 'values' if the
/// given callback returns true. The size of `values` must match the size of
/// this range.
template <typename ValuesT>
std::enable_if_t<!std::is_convertible<ValuesT, Operation *>::value>
replaceUsesWithIf(ValuesT &&values,
function_ref<bool(OpOperand &)> shouldReplace) {
assert(static_cast<size_t>(std::distance(values.begin(), values.end())) ==
size() &&
"expected 'values' to correspond 1-1 with the number of results");
for (auto it : llvm::zip(*this, values))
std::get<0>(it).replaceUsesWithIf(std::get<1>(it), shouldReplace);
}
/// Replace uses of results of this range with results of `op` if the given
/// callback returns true.
void replaceUsesWithIf(Operation *op,
function_ref<bool(OpOperand &)> shouldReplace);
//===--------------------------------------------------------------------===//
// Users
//===--------------------------------------------------------------------===//

View File

@@ -589,6 +589,11 @@ void ResultRange::replaceAllUsesWith(Operation *op) {
replaceAllUsesWith(op->getResults());
}
void ResultRange::replaceUsesWithIf(
Operation *op, function_ref<bool(OpOperand &)> shouldReplace) {
replaceUsesWithIf(op->getResults(), shouldReplace);
}
//===----------------------------------------------------------------------===//
// ValueRange

View File

@@ -124,12 +124,9 @@ void CSE::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
} else {
// When the region does not have SSA dominance, we need to check if we
// have visited a use before replacing any use.
for (auto it : llvm::zip(op->getResults(), existing->getResults())) {
std::get<0>(it).replaceUsesWithIf(
std::get<1>(it), [&](OpOperand &operand) {
return !knownValues.count(operand.getOwner());
});
}
op->replaceUsesWithIf(existing->getResults(), [&](OpOperand &operand) {
return !knownValues.count(operand.getOwner());
});
// There may be some remaining uses of the operation.
if (op->use_empty())