[mlir] Add a new MutableOperandRange class for adding/remove operands

This class allows for mutating an operand range in-place, and provides vector like API for adding/erasing/setting. ODS now uses this class to generate mutable wrappers for named operands, with the name `MutableOperandRange <operand-name>Mutable()`

Differential Revision: https://reviews.llvm.org/D78892
This commit is contained in:
River Riddle
2020-04-29 16:09:11 -07:00
parent 9b16ece6ca
commit 108abd2f2e
10 changed files with 431 additions and 86 deletions

View File

@@ -296,6 +296,10 @@ public:
Attribute get(StringRef name) const;
Attribute get(Identifier name) const;
/// Return the specified named attribute if present, None otherwise.
Optional<NamedAttribute> getNamed(StringRef name) const;
Optional<NamedAttribute> getNamed(Identifier name) const;
/// Support range iteration.
using iterator = llvm::ArrayRef<NamedAttribute>::iterator;
iterator begin() const;
@@ -1513,6 +1517,10 @@ public:
Attribute get(StringRef name) const;
Attribute get(Identifier name) const;
/// Return the specified named attribute if present, None otherwise.
Optional<NamedAttribute> getNamed(StringRef name) const;
Optional<NamedAttribute> getNamed(Identifier name) const;
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void set(Identifier name, Attribute value);

View File

@@ -205,6 +205,14 @@ public:
/// 'operands'.
void setOperands(ValueRange operands);
/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
/// with the ones provided in 'operands'. 'operands' may be smaller or larger
/// than the range pointed to by 'start'+'length'.
void setOperands(unsigned start, unsigned length, ValueRange operands);
/// Insert the given operands into the operand list at the given 'index'.
void insertOperands(unsigned index, ValueRange operands);
unsigned getNumOperands() {
return LLVM_LIKELY(hasOperandStorage) ? getOperandStorage().size() : 0;
}
@@ -214,6 +222,15 @@ public:
return getOpOperand(idx).set(value);
}
/// Erase the operand at position `idx`.
void eraseOperand(unsigned idx) { eraseOperands(idx); }
/// Erase the operands starting at position `idx` and ending at position
/// 'idx'+'length'.
void eraseOperands(unsigned idx, unsigned length = 1) {
getOperandStorage().eraseOperands(idx, length);
}
// Support operand iteration.
using operand_range = OperandRange;
using operand_iterator = operand_range::iterator;
@@ -221,12 +238,9 @@ public:
operand_iterator operand_begin() { return getOperands().begin(); }
operand_iterator operand_end() { return getOperands().end(); }
/// Returns an iterator on the underlying Value's (Value ).
/// Returns an iterator on the underlying Value's.
operand_range getOperands() { return operand_range(this); }
/// Erase the operand at position `idx`.
void eraseOperand(unsigned idx) { getOperandStorage().eraseOperand(idx); }
MutableArrayRef<OpOperand> getOpOperands() {
return LLVM_LIKELY(hasOperandStorage) ? getOperandStorage().getOperands()
: MutableArrayRef<OpOperand>();

View File

@@ -369,8 +369,14 @@ public:
/// 'values'.
void setOperands(Operation *owner, ValueRange values);
/// Erase an operand held by the storage.
void eraseOperand(unsigned index);
/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
/// with the ones provided in 'operands'. 'operands' may be smaller or larger
/// than the range pointed to by 'start'+'length'.
void setOperands(Operation *owner, unsigned start, unsigned length,
ValueRange operands);
/// Erase the operands held by the storage within the given range.
void eraseOperands(unsigned start, unsigned length);
/// Get the operation operands held by the storage.
MutableArrayRef<OpOperand> getOperands() {
@@ -653,6 +659,62 @@ private:
friend RangeBaseT;
};
//===----------------------------------------------------------------------===//
// MutableOperandRange
/// This class provides a mutable adaptor for a range of operands. It allows for
/// setting, inserting, and erasing operands from the given range.
class MutableOperandRange {
public:
/// A pair of a named attribute corresponding to an operand segment attribute,
/// and the index within that attribute. The attribute should correspond to an
/// i32 DenseElementsAttr.
using OperandSegment = std::pair<unsigned, NamedAttribute>;
/// Construct a new mutable range from the given operand, operand start index,
/// and range length. `operandSegments` is an optional set of operand segments
/// to be updated when mutating the operand list.
MutableOperandRange(Operation *owner, unsigned start, unsigned length,
ArrayRef<OperandSegment> operandSegments = llvm::None);
MutableOperandRange(Operation *owner);
/// Append the given values to the range.
void append(ValueRange values);
/// Assign this range to the given values.
void assign(ValueRange values);
/// Assign the range to the given value.
void assign(Value value);
/// Erase the operands within the given sub-range.
void erase(unsigned subStart, unsigned subLen = 1);
/// Clear this range and erase all of the operands.
void clear();
/// Returns the current size of the range.
unsigned size() const { return length; }
/// Allow implicit conversion to an OperandRange.
operator OperandRange() const;
private:
/// Update the length of this range to the one provided.
void updateLength(unsigned newLength);
/// The owning operation of this range.
Operation *owner;
/// The start index of the operand range within the owner operand list, and
/// the length starting from `start`.
unsigned start, length;
/// Optional set of operand segments that should be updated when mutating the
/// length of this range.
SmallVector<std::pair<unsigned, NamedAttribute>, 1> operandSegments;
};
//===----------------------------------------------------------------------===//
// ResultRange

View File

@@ -164,7 +164,8 @@ public:
other.back = nullptr;
nextUse = nullptr;
back = nullptr;
insertIntoCurrent();
if (value)
insertIntoCurrent();
return *this;
}

View File

@@ -196,15 +196,26 @@ ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
/// Return the specified attribute if present, null otherwise.
Attribute DictionaryAttr::get(StringRef name) const {
ArrayRef<NamedAttribute> values = getValue();
auto it = llvm::lower_bound(values, name, compareNamedAttributeWithName);
return it != values.end() && it->first == name ? it->second : Attribute();
Optional<NamedAttribute> attr = getNamed(name);
return attr ? attr->second : nullptr;
}
Attribute DictionaryAttr::get(Identifier name) const {
Optional<NamedAttribute> attr = getNamed(name);
return attr ? attr->second : nullptr;
}
/// Return the specified named attribute if present, None otherwise.
Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
ArrayRef<NamedAttribute> values = getValue();
auto it = llvm::lower_bound(values, name, compareNamedAttributeWithName);
return it != values.end() && it->first == name ? *it
: Optional<NamedAttribute>();
}
Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
for (auto elt : getValue())
if (elt.first == name)
return elt.second;
return nullptr;
return elt;
return llvm::None;
}
DictionaryAttr::iterator DictionaryAttr::begin() const {
@@ -1191,6 +1202,15 @@ Attribute MutableDictionaryAttr::get(Identifier name) const {
return attrs ? attrs.get(name) : nullptr;
}
/// Return the specified named attribute if present, None otherwise.
Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const {
return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
}
Optional<NamedAttribute>
MutableDictionaryAttr::getNamed(Identifier name) const {
return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
}
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void MutableDictionaryAttr::set(Identifier name, Attribute value) {

View File

@@ -244,6 +244,25 @@ void Operation::setOperands(ValueRange operands) {
assert(operands.empty() && "setting operands without an operand storage");
}
/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
/// with the ones provided in 'operands'. 'operands' may be smaller or larger
/// than the range pointed to by 'start'+'length'.
void Operation::setOperands(unsigned start, unsigned length,
ValueRange operands) {
assert((start + length) <= getNumOperands() &&
"invalid operand range specified");
if (LLVM_LIKELY(hasOperandStorage))
return getOperandStorage().setOperands(this, start, length, operands);
assert(operands.empty() && "setting operands without an operand storage");
}
/// Insert the given operands into the operand list at the given 'index'.
void Operation::insertOperands(unsigned index, ValueRange operands) {
if (LLVM_LIKELY(hasOperandStorage))
return setOperands(index, /*length=*/0, operands);
assert(operands.empty() && "inserting operands without an operand storage");
}
//===----------------------------------------------------------------------===//
// Diagnostics
//===----------------------------------------------------------------------===//

View File

@@ -13,7 +13,9 @@
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
@@ -89,6 +91,55 @@ void detail::OperandStorage::setOperands(Operation *owner, ValueRange values) {
storageOperands[i].set(values[i]);
}
/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
/// with the ones provided in 'operands'. 'operands' may be smaller or larger
/// than the range pointed to by 'start'+'length'.
void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
unsigned length, ValueRange operands) {
// If the new size is the same, we can update inplace.
unsigned newSize = operands.size();
if (newSize == length) {
MutableArrayRef<OpOperand> storageOperands = getOperands();
for (unsigned i = 0, e = length; i != e; ++i)
storageOperands[start + i].set(operands[i]);
return;
}
// If the new size is greater, remove the extra operands and set the rest
// inplace.
if (newSize < length) {
eraseOperands(start + operands.size(), length - newSize);
setOperands(owner, start, newSize, operands);
return;
}
// Otherwise, the new size is greater so we need to grow the storage.
auto storageOperands = resize(owner, size() + (newSize - length));
// Shift operands to the right to make space for the new operands.
unsigned rotateSize = storageOperands.size() - (start + length);
auto rbegin = storageOperands.rbegin();
std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize);
// Update the operands inplace.
for (unsigned i = 0, e = operands.size(); i != e; ++i)
storageOperands[start + i].set(operands[i]);
}
/// Erase an operand held by the storage.
void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
TrailingOperandStorage &storage = getStorage();
MutableArrayRef<OpOperand> operands = storage.getOperands();
assert((start + length) <= operands.size());
storage.numOperands -= length;
// Shift all operands down if the operand to remove is not at the end.
if (start != storage.numOperands) {
auto indexIt = std::next(operands.begin(), start);
std::rotate(indexIt, std::next(indexIt, length), operands.end());
}
for (unsigned i = 0; i != length; ++i)
operands[storage.numOperands + i].~OpOperand();
}
/// Resize the storage to the given size. Returns the array containing the new
/// operands.
MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
@@ -149,20 +200,6 @@ MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
return newOperands;
}
/// Erase an operand held by the storage.
void detail::OperandStorage::eraseOperand(unsigned index) {
assert(index < size());
TrailingOperandStorage &storage = getStorage();
MutableArrayRef<OpOperand> operands = storage.getOperands();
--storage.numOperands;
// Shift all operands down by 1 if the operand to remove is not at the end.
auto indexIt = std::next(operands.begin(), index);
if (index != storage.numOperands)
std::rotate(indexIt, std::next(indexIt), operands.end());
operands[storage.numOperands].~OpOperand();
}
//===----------------------------------------------------------------------===//
// ResultStorage
//===----------------------------------------------------------------------===//
@@ -235,6 +272,83 @@ unsigned OperandRange::getBeginOperandIndex() const {
return base->getOperandNumber();
}
//===----------------------------------------------------------------------===//
// MutableOperandRange
/// Construct a new mutable range from the given operand, operand start index,
/// and range length.
MutableOperandRange::MutableOperandRange(
Operation *owner, unsigned start, unsigned length,
ArrayRef<OperandSegment> operandSegments)
: owner(owner), start(start), length(length),
operandSegments(operandSegments.begin(), operandSegments.end()) {
assert((start + length) <= owner->getNumOperands() && "invalid range");
}
MutableOperandRange::MutableOperandRange(Operation *owner)
: MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
/// Append the given values to the range.
void MutableOperandRange::append(ValueRange values) {
if (values.empty())
return;
owner->insertOperands(start + length, values);
updateLength(length + values.size());
}
/// Assign this range to the given values.
void MutableOperandRange::assign(ValueRange values) {
owner->setOperands(start, length, values);
if (length != values.size())
updateLength(/*newLength=*/values.size());
}
/// Assign the range to the given value.
void MutableOperandRange::assign(Value value) {
if (length == 1) {
owner->setOperand(start, value);
} else {
owner->setOperands(start, length, value);
updateLength(/*newLength=*/1);
}
}
/// Erase the operands within the given sub-range.
void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
assert((subStart + subLen) <= length && "invalid sub-range");
if (length == 0)
return;
owner->eraseOperands(start + subStart, subLen);
updateLength(length - subLen);
}
/// Clear this range and erase all of the operands.
void MutableOperandRange::clear() {
if (length != 0) {
owner->eraseOperands(start, length);
updateLength(/*newLength=*/0);
}
}
/// Allow implicit conversion to an OperandRange.
MutableOperandRange::operator OperandRange() const {
return owner->getOperands().slice(start, length);
}
/// Update the length of this range to the one provided.
void MutableOperandRange::updateLength(unsigned newLength) {
int32_t diff = int32_t(newLength) - int32_t(length);
length = newLength;
// Update any of the provided segment attributes.
for (OperandSegment &segment : operandSegments) {
auto attr = segment.second.second.cast<DenseIntElementsAttr>();
SmallVector<int32_t, 8> segments(attr.getValues<int32_t>());
segments[segment.first] += diff;
segment.second.second = DenseIntElementsAttr::get(attr.getType(), segments);
owner->setAttr(segment.second.first, segment.second.second);
}
}
//===----------------------------------------------------------------------===//
// ResultRange

View File

@@ -67,6 +67,8 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
// CHECK: Operation::operand_range getODSOperands(unsigned index);
// CHECK: Value a();
// CHECK: Operation::operand_range b();
// CHECK: ::mlir::MutableOperandRange aMutable();
// CHECK: ::mlir::MutableOperandRange bMutable();
// CHECK: Operation::result_range getODSResults(unsigned index);
// CHECK: Value r();
// CHECK: Region &someRegion();
@@ -119,6 +121,7 @@ def NS_EOp : NS_Op<"op_with_optionals", []> {
// CHECK-LABEL: NS::EOp declarations
// CHECK: Value a();
// CHECK: ::mlir::MutableOperandRange aMutable();
// CHECK: Value b();
// CHECK: static void build(OpBuilder &odsBuilder, OperationState &odsState, /*optional*/Type b, /*optional*/Value a)

View File

@@ -45,25 +45,23 @@ static const char *const builderOpState = "odsState";
// {1}: The total number of non-variadic operands/results.
// {2}: The total number of variadic operands/results.
// {3}: The total number of actual values.
// {4}: The begin iterator of the actual values.
// {5}: "operand" or "result".
// {4}: "operand" or "result".
const char *sameVariadicSizeValueRangeCalcCode = R"(
bool isVariadic[] = {{{0}};
int prevVariadicCount = 0;
for (unsigned i = 0; i < index; ++i)
if (isVariadic[i]) ++prevVariadicCount;
// Calculate how many dynamic values a static variadic {5} corresponds to.
// This assumes all static variadic {5}s have the same dynamic value count.
// Calculate how many dynamic values a static variadic {4} corresponds to.
// This assumes all static variadic {4}s have the same dynamic value count.
int variadicSize = ({3} - {1}) / {2};
// `index` passed in as the parameter is the static index which counts each
// {5} (variadic or not) as size 1. So here for each previous static variadic
// {5}, we need to offset by (variadicSize - 1) to get where the dynamic
// value pack for this static {5} starts.
int offset = index + (variadicSize - 1) * prevVariadicCount;
// {4} (variadic or not) as size 1. So here for each previous static variadic
// {4}, we need to offset by (variadicSize - 1) to get where the dynamic
// value pack for this static {4} starts.
int start = index + (variadicSize - 1) * prevVariadicCount;
int size = isVariadic[index] ? variadicSize : 1;
return {{std::next({4}, offset), std::next({4}, offset + size)};
return {{start, size};
)";
// The logic to calculate the actual value range for a declared operand/result
@@ -72,14 +70,23 @@ const char *sameVariadicSizeValueRangeCalcCode = R"(
// (variadic or not).
//
// {0}: The name of the attribute specifying the segment sizes.
// {1}: The begin iterator of the actual values.
const char *attrSizedSegmentValueRangeCalcCode = R"(
auto sizeAttr = getAttrOfType<DenseIntElementsAttr>("{0}");
unsigned start = 0;
for (unsigned i = 0; i < index; ++i)
start += (*(sizeAttr.begin() + i)).getZExtValue();
unsigned end = start + (*(sizeAttr.begin() + index)).getZExtValue();
return {{std::next({1}, start), std::next({1}, end)};
unsigned size = (*(sizeAttr.begin() + index)).getZExtValue();
return {{start, size};
)";
// The logic to build a range of either operand or result values.
//
// {0}: The begin iterator of the actual values.
// {1}: The call to generate the start and length of the value range.
const char *valueRangeReturnCode = R"(
auto valueRange = {1};
return {{std::next({0}, valueRange.first),
std::next({0}, valueRange.first + valueRange.second)};
)";
static const char *const opCommentHeader = R"(
@@ -177,6 +184,9 @@ private:
// Generates getters for named operands.
void genNamedOperandGetters();
// Generates setters for named operands.
void genNamedOperandSetters();
// Generates getters for named results.
void genNamedResultGetters();
@@ -310,6 +320,7 @@ OpEmitter::OpEmitter(const Operator &op)
genOpAsmInterface();
genOpNameGetter();
genNamedOperandGetters();
genNamedOperandSetters();
genNamedResultGetters();
genNamedRegionGetters();
genNamedSuccessorGetters();
@@ -478,6 +489,37 @@ void OpEmitter::genAttrSetters() {
}
}
// Generates the code to compute the start and end index of an operand or result
// range.
template <typename RangeT>
static void
generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
int numVariadic, int numNonVariadic,
StringRef rangeSizeCall, bool hasAttrSegmentSize,
StringRef segmentSizeAttr, RangeT &&odsValues) {
auto &method = opClass.newMethod("std::pair<unsigned, unsigned>", methodName,
"unsigned index");
if (numVariadic == 0) {
method.body() << " return {index, 1};\n";
} else if (hasAttrSegmentSize) {
method.body() << formatv(attrSizedSegmentValueRangeCalcCode,
segmentSizeAttr);
} else {
// Because the op can have arbitrarily interleaved variadic and non-variadic
// operands, we need to embed a list in the "sink" getter method for
// calculation at run-time.
llvm::SmallVector<StringRef, 4> isVariadic;
isVariadic.reserve(llvm::size(odsValues));
for (auto &it : odsValues)
isVariadic.push_back(it.isVariableLength() ? "true" : "false");
std::string isVariadicList = llvm::join(isVariadic, ", ");
method.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
numNonVariadic, numVariadic, rangeSizeCall,
"operand");
}
}
// Generates the named operand getter methods for the given Operator `op` and
// puts them in `opClass`. Uses `rangeType` as the return type of getters that
// return a range of operands (individual operands are `Value ` and each
@@ -519,32 +561,16 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
"'SameVariadicOperandSize' traits");
}
// First emit a "sink" getter method upon which we layer all nicer named
// First emit a few "sink" getter methods upon which we layer all nicer named
// getter methods.
generateValueRangeStartAndEnd(
opClass, "getODSOperandIndexAndLength", numVariadicOperands,
numNormalOperands, rangeSizeCall, attrSizedOperands,
"operand_segment_sizes", const_cast<Operator &>(op).getOperands());
auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index");
if (numVariadicOperands == 0) {
// We still need to match the return type, which is a range.
m.body() << " return {std::next(" << rangeBeginCall
<< ", index), std::next(" << rangeBeginCall << ", index + 1)};";
} else if (attrSizedOperands) {
m.body() << formatv(attrSizedSegmentValueRangeCalcCode,
"operand_segment_sizes", rangeBeginCall);
} else {
// Because the op can have arbitrarily interleaved variadic and non-variadic
// operands, we need to embed a list in the "sink" getter method for
// calculation at run-time.
llvm::SmallVector<StringRef, 4> isVariadic;
isVariadic.reserve(numOperands);
for (int i = 0; i < numOperands; ++i)
isVariadic.push_back(op.getOperand(i).isVariableLength() ? "true"
: "false");
std::string isVariadicList = llvm::join(isVariadic, ", ");
m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
numNormalOperands, numVariadicOperands, rangeSizeCall,
rangeBeginCall, "operand");
}
m.body() << formatv(valueRangeReturnCode, rangeBeginCall,
"getODSOperandIndexAndLength(index)");
// Then we emit nicer named getter methods by redirecting to the "sink" getter
// method.
@@ -579,6 +605,26 @@ void OpEmitter::genNamedOperandGetters() {
/*getOperandCallPattern=*/"getOperation()->getOperand({0})");
}
void OpEmitter::genNamedOperandSetters() {
auto *attrSizedOperands = op.getTrait("OpTrait::AttrSizedOperandSegments");
for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
const auto &operand = op.getOperand(i);
if (operand.name.empty())
continue;
auto &m = opClass.newMethod("::mlir::MutableOperandRange",
(operand.name + "Mutable").str());
auto &body = m.body();
body << " auto range = getODSOperandIndexAndLength(" << i << ");\n"
<< " return ::mlir::MutableOperandRange(getOperation(), "
"range.first, range.second";
if (attrSizedOperands)
body << ", ::mlir::MutableOperandRange::OperandSegment(" << i
<< "u, *getOperation()->getMutableAttrDict().getNamed("
"\"operand_segment_sizes\"))";
body << ");\n";
}
}
void OpEmitter::genNamedResultGetters() {
const int numResults = op.getNumResults();
const int numVariadicResults = op.getNumVariableLengthResults();
@@ -607,29 +653,14 @@ void OpEmitter::genNamedResultGetters() {
"'SameVariadicResultSize' traits");
}
generateValueRangeStartAndEnd(
opClass, "getODSResultIndexAndLength", numVariadicResults,
numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
"result_segment_sizes", op.getResults());
auto &m = opClass.newMethod("Operation::result_range", "getODSResults",
"unsigned index");
if (numVariadicResults == 0) {
m.body() << " return {std::next(getOperation()->result_begin(), index), "
"std::next(getOperation()->result_begin(), index + 1)};";
} else if (attrSizedResults) {
m.body() << formatv(attrSizedSegmentValueRangeCalcCode,
"result_segment_sizes",
"getOperation()->result_begin()");
} else {
llvm::SmallVector<StringRef, 4> isVariadic;
isVariadic.reserve(numResults);
for (int i = 0; i < numResults; ++i)
isVariadic.push_back(op.getResult(i).isVariableLength() ? "true"
: "false");
std::string isVariadicList = llvm::join(isVariadic, ", ");
m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
numNormalResults, numVariadicResults,
"getOperation()->getNumResults()",
"getOperation()->result_begin()", "result");
}
m.body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
"getODSResultIndexAndLength(index)");
for (int i = 0; i != numResults; ++i) {
const auto &result = op.getResult(i);

View File

@@ -33,7 +33,7 @@ TEST(OperandStorageTest, NonResizable) {
Value operand = useOp->getResult(0);
// Create a non-resizable operation with one operand.
Operation *user = createOp(&context, operand, builder.getIntegerType(16));
Operation *user = createOp(&context, operand);
// The same number of operands is okay.
user->setOperands(operand);
@@ -57,7 +57,7 @@ TEST(OperandStorageTest, Resizable) {
Value operand = useOp->getResult(0);
// Create a resizable operation with one operand.
Operation *user = createOp(&context, operand, builder.getIntegerType(16));
Operation *user = createOp(&context, operand);
// The same number of operands is okay.
user->setOperands(operand);
@@ -76,4 +76,77 @@ TEST(OperandStorageTest, Resizable) {
useOp->destroy();
}
TEST(OperandStorageTest, RangeReplace) {
MLIRContext context;
Builder builder(&context);
Operation *useOp =
createOp(&context, /*operands=*/llvm::None, builder.getIntegerType(16));
Value operand = useOp->getResult(0);
// Create a resizable operation with one operand.
Operation *user = createOp(&context, operand);
// Check setting with the same number of operands.
user->setOperands(/*start=*/0, /*length=*/1, operand);
EXPECT_EQ(user->getNumOperands(), 1u);
// Check setting with more operands.
user->setOperands(/*start=*/0, /*length=*/1, {operand, operand, operand});
EXPECT_EQ(user->getNumOperands(), 3u);
// Check setting with less operands.
user->setOperands(/*start=*/1, /*length=*/2, {operand});
EXPECT_EQ(user->getNumOperands(), 2u);
// Check inserting without replacing operands.
user->setOperands(/*start=*/2, /*length=*/0, {operand});
EXPECT_EQ(user->getNumOperands(), 3u);
// Check erasing operands.
user->setOperands(/*start=*/0, /*length=*/3, {});
EXPECT_EQ(user->getNumOperands(), 0u);
// Destroy the operations.
user->destroy();
useOp->destroy();
}
TEST(OperandStorageTest, MutableRange) {
MLIRContext context;
Builder builder(&context);
Operation *useOp =
createOp(&context, /*operands=*/llvm::None, builder.getIntegerType(16));
Value operand = useOp->getResult(0);
// Create a resizable operation with one operand.
Operation *user = createOp(&context, operand);
// Check setting with the same number of operands.
MutableOperandRange mutableOperands(user);
mutableOperands.assign(operand);
EXPECT_EQ(mutableOperands.size(), 1u);
EXPECT_EQ(user->getNumOperands(), 1u);
// Check setting with more operands.
mutableOperands.assign({operand, operand, operand});
EXPECT_EQ(mutableOperands.size(), 3u);
EXPECT_EQ(user->getNumOperands(), 3u);
// Check with inserting a new operand.
mutableOperands.append({operand, operand});
EXPECT_EQ(mutableOperands.size(), 5u);
EXPECT_EQ(user->getNumOperands(), 5u);
// Check erasing operands.
mutableOperands.clear();
EXPECT_EQ(mutableOperands.size(), 0u);
EXPECT_EQ(user->getNumOperands(), 0u);
// Destroy the operations.
user->destroy();
useOp->destroy();
}
} // end namespace