mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 12:26:52 +08:00
[mlir] Add a new MutableOperandRange class for adding/remove operands
This class allows for mutating an operand range in-place, and provides vector like API for adding/erasing/setting. ODS now uses this class to generate mutable wrappers for named operands, with the name `MutableOperandRange <operand-name>Mutable()` Differential Revision: https://reviews.llvm.org/D78892
This commit is contained in:
@@ -296,6 +296,10 @@ public:
|
||||
Attribute get(StringRef name) const;
|
||||
Attribute get(Identifier name) const;
|
||||
|
||||
/// Return the specified named attribute if present, None otherwise.
|
||||
Optional<NamedAttribute> getNamed(StringRef name) const;
|
||||
Optional<NamedAttribute> getNamed(Identifier name) const;
|
||||
|
||||
/// Support range iteration.
|
||||
using iterator = llvm::ArrayRef<NamedAttribute>::iterator;
|
||||
iterator begin() const;
|
||||
@@ -1513,6 +1517,10 @@ public:
|
||||
Attribute get(StringRef name) const;
|
||||
Attribute get(Identifier name) const;
|
||||
|
||||
/// Return the specified named attribute if present, None otherwise.
|
||||
Optional<NamedAttribute> getNamed(StringRef name) const;
|
||||
Optional<NamedAttribute> getNamed(Identifier name) const;
|
||||
|
||||
/// If the an attribute exists with the specified name, change it to the new
|
||||
/// value. Otherwise, add a new attribute with the specified name/value.
|
||||
void set(Identifier name, Attribute value);
|
||||
|
||||
@@ -205,6 +205,14 @@ public:
|
||||
/// 'operands'.
|
||||
void setOperands(ValueRange operands);
|
||||
|
||||
/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
|
||||
/// with the ones provided in 'operands'. 'operands' may be smaller or larger
|
||||
/// than the range pointed to by 'start'+'length'.
|
||||
void setOperands(unsigned start, unsigned length, ValueRange operands);
|
||||
|
||||
/// Insert the given operands into the operand list at the given 'index'.
|
||||
void insertOperands(unsigned index, ValueRange operands);
|
||||
|
||||
unsigned getNumOperands() {
|
||||
return LLVM_LIKELY(hasOperandStorage) ? getOperandStorage().size() : 0;
|
||||
}
|
||||
@@ -214,6 +222,15 @@ public:
|
||||
return getOpOperand(idx).set(value);
|
||||
}
|
||||
|
||||
/// Erase the operand at position `idx`.
|
||||
void eraseOperand(unsigned idx) { eraseOperands(idx); }
|
||||
|
||||
/// Erase the operands starting at position `idx` and ending at position
|
||||
/// 'idx'+'length'.
|
||||
void eraseOperands(unsigned idx, unsigned length = 1) {
|
||||
getOperandStorage().eraseOperands(idx, length);
|
||||
}
|
||||
|
||||
// Support operand iteration.
|
||||
using operand_range = OperandRange;
|
||||
using operand_iterator = operand_range::iterator;
|
||||
@@ -221,12 +238,9 @@ public:
|
||||
operand_iterator operand_begin() { return getOperands().begin(); }
|
||||
operand_iterator operand_end() { return getOperands().end(); }
|
||||
|
||||
/// Returns an iterator on the underlying Value's (Value ).
|
||||
/// Returns an iterator on the underlying Value's.
|
||||
operand_range getOperands() { return operand_range(this); }
|
||||
|
||||
/// Erase the operand at position `idx`.
|
||||
void eraseOperand(unsigned idx) { getOperandStorage().eraseOperand(idx); }
|
||||
|
||||
MutableArrayRef<OpOperand> getOpOperands() {
|
||||
return LLVM_LIKELY(hasOperandStorage) ? getOperandStorage().getOperands()
|
||||
: MutableArrayRef<OpOperand>();
|
||||
|
||||
@@ -369,8 +369,14 @@ public:
|
||||
/// 'values'.
|
||||
void setOperands(Operation *owner, ValueRange values);
|
||||
|
||||
/// Erase an operand held by the storage.
|
||||
void eraseOperand(unsigned index);
|
||||
/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
|
||||
/// with the ones provided in 'operands'. 'operands' may be smaller or larger
|
||||
/// than the range pointed to by 'start'+'length'.
|
||||
void setOperands(Operation *owner, unsigned start, unsigned length,
|
||||
ValueRange operands);
|
||||
|
||||
/// Erase the operands held by the storage within the given range.
|
||||
void eraseOperands(unsigned start, unsigned length);
|
||||
|
||||
/// Get the operation operands held by the storage.
|
||||
MutableArrayRef<OpOperand> getOperands() {
|
||||
@@ -653,6 +659,62 @@ private:
|
||||
friend RangeBaseT;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MutableOperandRange
|
||||
|
||||
/// This class provides a mutable adaptor for a range of operands. It allows for
|
||||
/// setting, inserting, and erasing operands from the given range.
|
||||
class MutableOperandRange {
|
||||
public:
|
||||
/// A pair of a named attribute corresponding to an operand segment attribute,
|
||||
/// and the index within that attribute. The attribute should correspond to an
|
||||
/// i32 DenseElementsAttr.
|
||||
using OperandSegment = std::pair<unsigned, NamedAttribute>;
|
||||
|
||||
/// Construct a new mutable range from the given operand, operand start index,
|
||||
/// and range length. `operandSegments` is an optional set of operand segments
|
||||
/// to be updated when mutating the operand list.
|
||||
MutableOperandRange(Operation *owner, unsigned start, unsigned length,
|
||||
ArrayRef<OperandSegment> operandSegments = llvm::None);
|
||||
MutableOperandRange(Operation *owner);
|
||||
|
||||
/// Append the given values to the range.
|
||||
void append(ValueRange values);
|
||||
|
||||
/// Assign this range to the given values.
|
||||
void assign(ValueRange values);
|
||||
|
||||
/// Assign the range to the given value.
|
||||
void assign(Value value);
|
||||
|
||||
/// Erase the operands within the given sub-range.
|
||||
void erase(unsigned subStart, unsigned subLen = 1);
|
||||
|
||||
/// Clear this range and erase all of the operands.
|
||||
void clear();
|
||||
|
||||
/// Returns the current size of the range.
|
||||
unsigned size() const { return length; }
|
||||
|
||||
/// Allow implicit conversion to an OperandRange.
|
||||
operator OperandRange() const;
|
||||
|
||||
private:
|
||||
/// Update the length of this range to the one provided.
|
||||
void updateLength(unsigned newLength);
|
||||
|
||||
/// The owning operation of this range.
|
||||
Operation *owner;
|
||||
|
||||
/// The start index of the operand range within the owner operand list, and
|
||||
/// the length starting from `start`.
|
||||
unsigned start, length;
|
||||
|
||||
/// Optional set of operand segments that should be updated when mutating the
|
||||
/// length of this range.
|
||||
SmallVector<std::pair<unsigned, NamedAttribute>, 1> operandSegments;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ResultRange
|
||||
|
||||
|
||||
@@ -164,7 +164,8 @@ public:
|
||||
other.back = nullptr;
|
||||
nextUse = nullptr;
|
||||
back = nullptr;
|
||||
insertIntoCurrent();
|
||||
if (value)
|
||||
insertIntoCurrent();
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
@@ -196,15 +196,26 @@ ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
|
||||
|
||||
/// Return the specified attribute if present, null otherwise.
|
||||
Attribute DictionaryAttr::get(StringRef name) const {
|
||||
ArrayRef<NamedAttribute> values = getValue();
|
||||
auto it = llvm::lower_bound(values, name, compareNamedAttributeWithName);
|
||||
return it != values.end() && it->first == name ? it->second : Attribute();
|
||||
Optional<NamedAttribute> attr = getNamed(name);
|
||||
return attr ? attr->second : nullptr;
|
||||
}
|
||||
Attribute DictionaryAttr::get(Identifier name) const {
|
||||
Optional<NamedAttribute> attr = getNamed(name);
|
||||
return attr ? attr->second : nullptr;
|
||||
}
|
||||
|
||||
/// Return the specified named attribute if present, None otherwise.
|
||||
Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
|
||||
ArrayRef<NamedAttribute> values = getValue();
|
||||
auto it = llvm::lower_bound(values, name, compareNamedAttributeWithName);
|
||||
return it != values.end() && it->first == name ? *it
|
||||
: Optional<NamedAttribute>();
|
||||
}
|
||||
Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
|
||||
for (auto elt : getValue())
|
||||
if (elt.first == name)
|
||||
return elt.second;
|
||||
return nullptr;
|
||||
return elt;
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
DictionaryAttr::iterator DictionaryAttr::begin() const {
|
||||
@@ -1191,6 +1202,15 @@ Attribute MutableDictionaryAttr::get(Identifier name) const {
|
||||
return attrs ? attrs.get(name) : nullptr;
|
||||
}
|
||||
|
||||
/// Return the specified named attribute if present, None otherwise.
|
||||
Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const {
|
||||
return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
|
||||
}
|
||||
Optional<NamedAttribute>
|
||||
MutableDictionaryAttr::getNamed(Identifier name) const {
|
||||
return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
|
||||
}
|
||||
|
||||
/// If the an attribute exists with the specified name, change it to the new
|
||||
/// value. Otherwise, add a new attribute with the specified name/value.
|
||||
void MutableDictionaryAttr::set(Identifier name, Attribute value) {
|
||||
|
||||
@@ -244,6 +244,25 @@ void Operation::setOperands(ValueRange operands) {
|
||||
assert(operands.empty() && "setting operands without an operand storage");
|
||||
}
|
||||
|
||||
/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
|
||||
/// with the ones provided in 'operands'. 'operands' may be smaller or larger
|
||||
/// than the range pointed to by 'start'+'length'.
|
||||
void Operation::setOperands(unsigned start, unsigned length,
|
||||
ValueRange operands) {
|
||||
assert((start + length) <= getNumOperands() &&
|
||||
"invalid operand range specified");
|
||||
if (LLVM_LIKELY(hasOperandStorage))
|
||||
return getOperandStorage().setOperands(this, start, length, operands);
|
||||
assert(operands.empty() && "setting operands without an operand storage");
|
||||
}
|
||||
|
||||
/// Insert the given operands into the operand list at the given 'index'.
|
||||
void Operation::insertOperands(unsigned index, ValueRange operands) {
|
||||
if (LLVM_LIKELY(hasOperandStorage))
|
||||
return setOperands(index, /*length=*/0, operands);
|
||||
assert(operands.empty() && "inserting operands without an operand storage");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Diagnostics
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -13,7 +13,9 @@
|
||||
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -89,6 +91,55 @@ void detail::OperandStorage::setOperands(Operation *owner, ValueRange values) {
|
||||
storageOperands[i].set(values[i]);
|
||||
}
|
||||
|
||||
/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
|
||||
/// with the ones provided in 'operands'. 'operands' may be smaller or larger
|
||||
/// than the range pointed to by 'start'+'length'.
|
||||
void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
|
||||
unsigned length, ValueRange operands) {
|
||||
// If the new size is the same, we can update inplace.
|
||||
unsigned newSize = operands.size();
|
||||
if (newSize == length) {
|
||||
MutableArrayRef<OpOperand> storageOperands = getOperands();
|
||||
for (unsigned i = 0, e = length; i != e; ++i)
|
||||
storageOperands[start + i].set(operands[i]);
|
||||
return;
|
||||
}
|
||||
// If the new size is greater, remove the extra operands and set the rest
|
||||
// inplace.
|
||||
if (newSize < length) {
|
||||
eraseOperands(start + operands.size(), length - newSize);
|
||||
setOperands(owner, start, newSize, operands);
|
||||
return;
|
||||
}
|
||||
// Otherwise, the new size is greater so we need to grow the storage.
|
||||
auto storageOperands = resize(owner, size() + (newSize - length));
|
||||
|
||||
// Shift operands to the right to make space for the new operands.
|
||||
unsigned rotateSize = storageOperands.size() - (start + length);
|
||||
auto rbegin = storageOperands.rbegin();
|
||||
std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize);
|
||||
|
||||
// Update the operands inplace.
|
||||
for (unsigned i = 0, e = operands.size(); i != e; ++i)
|
||||
storageOperands[start + i].set(operands[i]);
|
||||
}
|
||||
|
||||
/// Erase an operand held by the storage.
|
||||
void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
|
||||
TrailingOperandStorage &storage = getStorage();
|
||||
MutableArrayRef<OpOperand> operands = storage.getOperands();
|
||||
assert((start + length) <= operands.size());
|
||||
storage.numOperands -= length;
|
||||
|
||||
// Shift all operands down if the operand to remove is not at the end.
|
||||
if (start != storage.numOperands) {
|
||||
auto indexIt = std::next(operands.begin(), start);
|
||||
std::rotate(indexIt, std::next(indexIt, length), operands.end());
|
||||
}
|
||||
for (unsigned i = 0; i != length; ++i)
|
||||
operands[storage.numOperands + i].~OpOperand();
|
||||
}
|
||||
|
||||
/// Resize the storage to the given size. Returns the array containing the new
|
||||
/// operands.
|
||||
MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
|
||||
@@ -149,20 +200,6 @@ MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
|
||||
return newOperands;
|
||||
}
|
||||
|
||||
/// Erase an operand held by the storage.
|
||||
void detail::OperandStorage::eraseOperand(unsigned index) {
|
||||
assert(index < size());
|
||||
TrailingOperandStorage &storage = getStorage();
|
||||
MutableArrayRef<OpOperand> operands = storage.getOperands();
|
||||
--storage.numOperands;
|
||||
|
||||
// Shift all operands down by 1 if the operand to remove is not at the end.
|
||||
auto indexIt = std::next(operands.begin(), index);
|
||||
if (index != storage.numOperands)
|
||||
std::rotate(indexIt, std::next(indexIt), operands.end());
|
||||
operands[storage.numOperands].~OpOperand();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ResultStorage
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -235,6 +272,83 @@ unsigned OperandRange::getBeginOperandIndex() const {
|
||||
return base->getOperandNumber();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MutableOperandRange
|
||||
|
||||
/// Construct a new mutable range from the given operand, operand start index,
|
||||
/// and range length.
|
||||
MutableOperandRange::MutableOperandRange(
|
||||
Operation *owner, unsigned start, unsigned length,
|
||||
ArrayRef<OperandSegment> operandSegments)
|
||||
: owner(owner), start(start), length(length),
|
||||
operandSegments(operandSegments.begin(), operandSegments.end()) {
|
||||
assert((start + length) <= owner->getNumOperands() && "invalid range");
|
||||
}
|
||||
MutableOperandRange::MutableOperandRange(Operation *owner)
|
||||
: MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
|
||||
|
||||
/// Append the given values to the range.
|
||||
void MutableOperandRange::append(ValueRange values) {
|
||||
if (values.empty())
|
||||
return;
|
||||
owner->insertOperands(start + length, values);
|
||||
updateLength(length + values.size());
|
||||
}
|
||||
|
||||
/// Assign this range to the given values.
|
||||
void MutableOperandRange::assign(ValueRange values) {
|
||||
owner->setOperands(start, length, values);
|
||||
if (length != values.size())
|
||||
updateLength(/*newLength=*/values.size());
|
||||
}
|
||||
|
||||
/// Assign the range to the given value.
|
||||
void MutableOperandRange::assign(Value value) {
|
||||
if (length == 1) {
|
||||
owner->setOperand(start, value);
|
||||
} else {
|
||||
owner->setOperands(start, length, value);
|
||||
updateLength(/*newLength=*/1);
|
||||
}
|
||||
}
|
||||
|
||||
/// Erase the operands within the given sub-range.
|
||||
void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
|
||||
assert((subStart + subLen) <= length && "invalid sub-range");
|
||||
if (length == 0)
|
||||
return;
|
||||
owner->eraseOperands(start + subStart, subLen);
|
||||
updateLength(length - subLen);
|
||||
}
|
||||
|
||||
/// Clear this range and erase all of the operands.
|
||||
void MutableOperandRange::clear() {
|
||||
if (length != 0) {
|
||||
owner->eraseOperands(start, length);
|
||||
updateLength(/*newLength=*/0);
|
||||
}
|
||||
}
|
||||
|
||||
/// Allow implicit conversion to an OperandRange.
|
||||
MutableOperandRange::operator OperandRange() const {
|
||||
return owner->getOperands().slice(start, length);
|
||||
}
|
||||
|
||||
/// Update the length of this range to the one provided.
|
||||
void MutableOperandRange::updateLength(unsigned newLength) {
|
||||
int32_t diff = int32_t(newLength) - int32_t(length);
|
||||
length = newLength;
|
||||
|
||||
// Update any of the provided segment attributes.
|
||||
for (OperandSegment &segment : operandSegments) {
|
||||
auto attr = segment.second.second.cast<DenseIntElementsAttr>();
|
||||
SmallVector<int32_t, 8> segments(attr.getValues<int32_t>());
|
||||
segments[segment.first] += diff;
|
||||
segment.second.second = DenseIntElementsAttr::get(attr.getType(), segments);
|
||||
owner->setAttr(segment.second.first, segment.second.second);
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ResultRange
|
||||
|
||||
|
||||
@@ -67,6 +67,8 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
|
||||
// CHECK: Operation::operand_range getODSOperands(unsigned index);
|
||||
// CHECK: Value a();
|
||||
// CHECK: Operation::operand_range b();
|
||||
// CHECK: ::mlir::MutableOperandRange aMutable();
|
||||
// CHECK: ::mlir::MutableOperandRange bMutable();
|
||||
// CHECK: Operation::result_range getODSResults(unsigned index);
|
||||
// CHECK: Value r();
|
||||
// CHECK: Region &someRegion();
|
||||
@@ -119,6 +121,7 @@ def NS_EOp : NS_Op<"op_with_optionals", []> {
|
||||
|
||||
// CHECK-LABEL: NS::EOp declarations
|
||||
// CHECK: Value a();
|
||||
// CHECK: ::mlir::MutableOperandRange aMutable();
|
||||
// CHECK: Value b();
|
||||
// CHECK: static void build(OpBuilder &odsBuilder, OperationState &odsState, /*optional*/Type b, /*optional*/Value a)
|
||||
|
||||
|
||||
@@ -45,25 +45,23 @@ static const char *const builderOpState = "odsState";
|
||||
// {1}: The total number of non-variadic operands/results.
|
||||
// {2}: The total number of variadic operands/results.
|
||||
// {3}: The total number of actual values.
|
||||
// {4}: The begin iterator of the actual values.
|
||||
// {5}: "operand" or "result".
|
||||
// {4}: "operand" or "result".
|
||||
const char *sameVariadicSizeValueRangeCalcCode = R"(
|
||||
bool isVariadic[] = {{{0}};
|
||||
int prevVariadicCount = 0;
|
||||
for (unsigned i = 0; i < index; ++i)
|
||||
if (isVariadic[i]) ++prevVariadicCount;
|
||||
|
||||
// Calculate how many dynamic values a static variadic {5} corresponds to.
|
||||
// This assumes all static variadic {5}s have the same dynamic value count.
|
||||
// Calculate how many dynamic values a static variadic {4} corresponds to.
|
||||
// This assumes all static variadic {4}s have the same dynamic value count.
|
||||
int variadicSize = ({3} - {1}) / {2};
|
||||
// `index` passed in as the parameter is the static index which counts each
|
||||
// {5} (variadic or not) as size 1. So here for each previous static variadic
|
||||
// {5}, we need to offset by (variadicSize - 1) to get where the dynamic
|
||||
// value pack for this static {5} starts.
|
||||
int offset = index + (variadicSize - 1) * prevVariadicCount;
|
||||
// {4} (variadic or not) as size 1. So here for each previous static variadic
|
||||
// {4}, we need to offset by (variadicSize - 1) to get where the dynamic
|
||||
// value pack for this static {4} starts.
|
||||
int start = index + (variadicSize - 1) * prevVariadicCount;
|
||||
int size = isVariadic[index] ? variadicSize : 1;
|
||||
|
||||
return {{std::next({4}, offset), std::next({4}, offset + size)};
|
||||
return {{start, size};
|
||||
)";
|
||||
|
||||
// The logic to calculate the actual value range for a declared operand/result
|
||||
@@ -72,14 +70,23 @@ const char *sameVariadicSizeValueRangeCalcCode = R"(
|
||||
// (variadic or not).
|
||||
//
|
||||
// {0}: The name of the attribute specifying the segment sizes.
|
||||
// {1}: The begin iterator of the actual values.
|
||||
const char *attrSizedSegmentValueRangeCalcCode = R"(
|
||||
auto sizeAttr = getAttrOfType<DenseIntElementsAttr>("{0}");
|
||||
unsigned start = 0;
|
||||
for (unsigned i = 0; i < index; ++i)
|
||||
start += (*(sizeAttr.begin() + i)).getZExtValue();
|
||||
unsigned end = start + (*(sizeAttr.begin() + index)).getZExtValue();
|
||||
return {{std::next({1}, start), std::next({1}, end)};
|
||||
unsigned size = (*(sizeAttr.begin() + index)).getZExtValue();
|
||||
return {{start, size};
|
||||
)";
|
||||
|
||||
// The logic to build a range of either operand or result values.
|
||||
//
|
||||
// {0}: The begin iterator of the actual values.
|
||||
// {1}: The call to generate the start and length of the value range.
|
||||
const char *valueRangeReturnCode = R"(
|
||||
auto valueRange = {1};
|
||||
return {{std::next({0}, valueRange.first),
|
||||
std::next({0}, valueRange.first + valueRange.second)};
|
||||
)";
|
||||
|
||||
static const char *const opCommentHeader = R"(
|
||||
@@ -177,6 +184,9 @@ private:
|
||||
// Generates getters for named operands.
|
||||
void genNamedOperandGetters();
|
||||
|
||||
// Generates setters for named operands.
|
||||
void genNamedOperandSetters();
|
||||
|
||||
// Generates getters for named results.
|
||||
void genNamedResultGetters();
|
||||
|
||||
@@ -310,6 +320,7 @@ OpEmitter::OpEmitter(const Operator &op)
|
||||
genOpAsmInterface();
|
||||
genOpNameGetter();
|
||||
genNamedOperandGetters();
|
||||
genNamedOperandSetters();
|
||||
genNamedResultGetters();
|
||||
genNamedRegionGetters();
|
||||
genNamedSuccessorGetters();
|
||||
@@ -478,6 +489,37 @@ void OpEmitter::genAttrSetters() {
|
||||
}
|
||||
}
|
||||
|
||||
// Generates the code to compute the start and end index of an operand or result
|
||||
// range.
|
||||
template <typename RangeT>
|
||||
static void
|
||||
generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
|
||||
int numVariadic, int numNonVariadic,
|
||||
StringRef rangeSizeCall, bool hasAttrSegmentSize,
|
||||
StringRef segmentSizeAttr, RangeT &&odsValues) {
|
||||
auto &method = opClass.newMethod("std::pair<unsigned, unsigned>", methodName,
|
||||
"unsigned index");
|
||||
|
||||
if (numVariadic == 0) {
|
||||
method.body() << " return {index, 1};\n";
|
||||
} else if (hasAttrSegmentSize) {
|
||||
method.body() << formatv(attrSizedSegmentValueRangeCalcCode,
|
||||
segmentSizeAttr);
|
||||
} else {
|
||||
// Because the op can have arbitrarily interleaved variadic and non-variadic
|
||||
// operands, we need to embed a list in the "sink" getter method for
|
||||
// calculation at run-time.
|
||||
llvm::SmallVector<StringRef, 4> isVariadic;
|
||||
isVariadic.reserve(llvm::size(odsValues));
|
||||
for (auto &it : odsValues)
|
||||
isVariadic.push_back(it.isVariableLength() ? "true" : "false");
|
||||
std::string isVariadicList = llvm::join(isVariadic, ", ");
|
||||
method.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
|
||||
numNonVariadic, numVariadic, rangeSizeCall,
|
||||
"operand");
|
||||
}
|
||||
}
|
||||
|
||||
// Generates the named operand getter methods for the given Operator `op` and
|
||||
// puts them in `opClass`. Uses `rangeType` as the return type of getters that
|
||||
// return a range of operands (individual operands are `Value ` and each
|
||||
@@ -519,32 +561,16 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
|
||||
"'SameVariadicOperandSize' traits");
|
||||
}
|
||||
|
||||
// First emit a "sink" getter method upon which we layer all nicer named
|
||||
// First emit a few "sink" getter methods upon which we layer all nicer named
|
||||
// getter methods.
|
||||
generateValueRangeStartAndEnd(
|
||||
opClass, "getODSOperandIndexAndLength", numVariadicOperands,
|
||||
numNormalOperands, rangeSizeCall, attrSizedOperands,
|
||||
"operand_segment_sizes", const_cast<Operator &>(op).getOperands());
|
||||
|
||||
auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index");
|
||||
|
||||
if (numVariadicOperands == 0) {
|
||||
// We still need to match the return type, which is a range.
|
||||
m.body() << " return {std::next(" << rangeBeginCall
|
||||
<< ", index), std::next(" << rangeBeginCall << ", index + 1)};";
|
||||
} else if (attrSizedOperands) {
|
||||
m.body() << formatv(attrSizedSegmentValueRangeCalcCode,
|
||||
"operand_segment_sizes", rangeBeginCall);
|
||||
} else {
|
||||
// Because the op can have arbitrarily interleaved variadic and non-variadic
|
||||
// operands, we need to embed a list in the "sink" getter method for
|
||||
// calculation at run-time.
|
||||
llvm::SmallVector<StringRef, 4> isVariadic;
|
||||
isVariadic.reserve(numOperands);
|
||||
for (int i = 0; i < numOperands; ++i)
|
||||
isVariadic.push_back(op.getOperand(i).isVariableLength() ? "true"
|
||||
: "false");
|
||||
std::string isVariadicList = llvm::join(isVariadic, ", ");
|
||||
|
||||
m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
|
||||
numNormalOperands, numVariadicOperands, rangeSizeCall,
|
||||
rangeBeginCall, "operand");
|
||||
}
|
||||
m.body() << formatv(valueRangeReturnCode, rangeBeginCall,
|
||||
"getODSOperandIndexAndLength(index)");
|
||||
|
||||
// Then we emit nicer named getter methods by redirecting to the "sink" getter
|
||||
// method.
|
||||
@@ -579,6 +605,26 @@ void OpEmitter::genNamedOperandGetters() {
|
||||
/*getOperandCallPattern=*/"getOperation()->getOperand({0})");
|
||||
}
|
||||
|
||||
void OpEmitter::genNamedOperandSetters() {
|
||||
auto *attrSizedOperands = op.getTrait("OpTrait::AttrSizedOperandSegments");
|
||||
for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
|
||||
const auto &operand = op.getOperand(i);
|
||||
if (operand.name.empty())
|
||||
continue;
|
||||
auto &m = opClass.newMethod("::mlir::MutableOperandRange",
|
||||
(operand.name + "Mutable").str());
|
||||
auto &body = m.body();
|
||||
body << " auto range = getODSOperandIndexAndLength(" << i << ");\n"
|
||||
<< " return ::mlir::MutableOperandRange(getOperation(), "
|
||||
"range.first, range.second";
|
||||
if (attrSizedOperands)
|
||||
body << ", ::mlir::MutableOperandRange::OperandSegment(" << i
|
||||
<< "u, *getOperation()->getMutableAttrDict().getNamed("
|
||||
"\"operand_segment_sizes\"))";
|
||||
body << ");\n";
|
||||
}
|
||||
}
|
||||
|
||||
void OpEmitter::genNamedResultGetters() {
|
||||
const int numResults = op.getNumResults();
|
||||
const int numVariadicResults = op.getNumVariableLengthResults();
|
||||
@@ -607,29 +653,14 @@ void OpEmitter::genNamedResultGetters() {
|
||||
"'SameVariadicResultSize' traits");
|
||||
}
|
||||
|
||||
generateValueRangeStartAndEnd(
|
||||
opClass, "getODSResultIndexAndLength", numVariadicResults,
|
||||
numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
|
||||
"result_segment_sizes", op.getResults());
|
||||
auto &m = opClass.newMethod("Operation::result_range", "getODSResults",
|
||||
"unsigned index");
|
||||
|
||||
if (numVariadicResults == 0) {
|
||||
m.body() << " return {std::next(getOperation()->result_begin(), index), "
|
||||
"std::next(getOperation()->result_begin(), index + 1)};";
|
||||
} else if (attrSizedResults) {
|
||||
m.body() << formatv(attrSizedSegmentValueRangeCalcCode,
|
||||
"result_segment_sizes",
|
||||
"getOperation()->result_begin()");
|
||||
} else {
|
||||
llvm::SmallVector<StringRef, 4> isVariadic;
|
||||
isVariadic.reserve(numResults);
|
||||
for (int i = 0; i < numResults; ++i)
|
||||
isVariadic.push_back(op.getResult(i).isVariableLength() ? "true"
|
||||
: "false");
|
||||
std::string isVariadicList = llvm::join(isVariadic, ", ");
|
||||
|
||||
m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
|
||||
numNormalResults, numVariadicResults,
|
||||
"getOperation()->getNumResults()",
|
||||
"getOperation()->result_begin()", "result");
|
||||
}
|
||||
m.body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
|
||||
"getODSResultIndexAndLength(index)");
|
||||
|
||||
for (int i = 0; i != numResults; ++i) {
|
||||
const auto &result = op.getResult(i);
|
||||
|
||||
@@ -33,7 +33,7 @@ TEST(OperandStorageTest, NonResizable) {
|
||||
Value operand = useOp->getResult(0);
|
||||
|
||||
// Create a non-resizable operation with one operand.
|
||||
Operation *user = createOp(&context, operand, builder.getIntegerType(16));
|
||||
Operation *user = createOp(&context, operand);
|
||||
|
||||
// The same number of operands is okay.
|
||||
user->setOperands(operand);
|
||||
@@ -57,7 +57,7 @@ TEST(OperandStorageTest, Resizable) {
|
||||
Value operand = useOp->getResult(0);
|
||||
|
||||
// Create a resizable operation with one operand.
|
||||
Operation *user = createOp(&context, operand, builder.getIntegerType(16));
|
||||
Operation *user = createOp(&context, operand);
|
||||
|
||||
// The same number of operands is okay.
|
||||
user->setOperands(operand);
|
||||
@@ -76,4 +76,77 @@ TEST(OperandStorageTest, Resizable) {
|
||||
useOp->destroy();
|
||||
}
|
||||
|
||||
TEST(OperandStorageTest, RangeReplace) {
|
||||
MLIRContext context;
|
||||
Builder builder(&context);
|
||||
|
||||
Operation *useOp =
|
||||
createOp(&context, /*operands=*/llvm::None, builder.getIntegerType(16));
|
||||
Value operand = useOp->getResult(0);
|
||||
|
||||
// Create a resizable operation with one operand.
|
||||
Operation *user = createOp(&context, operand);
|
||||
|
||||
// Check setting with the same number of operands.
|
||||
user->setOperands(/*start=*/0, /*length=*/1, operand);
|
||||
EXPECT_EQ(user->getNumOperands(), 1u);
|
||||
|
||||
// Check setting with more operands.
|
||||
user->setOperands(/*start=*/0, /*length=*/1, {operand, operand, operand});
|
||||
EXPECT_EQ(user->getNumOperands(), 3u);
|
||||
|
||||
// Check setting with less operands.
|
||||
user->setOperands(/*start=*/1, /*length=*/2, {operand});
|
||||
EXPECT_EQ(user->getNumOperands(), 2u);
|
||||
|
||||
// Check inserting without replacing operands.
|
||||
user->setOperands(/*start=*/2, /*length=*/0, {operand});
|
||||
EXPECT_EQ(user->getNumOperands(), 3u);
|
||||
|
||||
// Check erasing operands.
|
||||
user->setOperands(/*start=*/0, /*length=*/3, {});
|
||||
EXPECT_EQ(user->getNumOperands(), 0u);
|
||||
|
||||
// Destroy the operations.
|
||||
user->destroy();
|
||||
useOp->destroy();
|
||||
}
|
||||
|
||||
TEST(OperandStorageTest, MutableRange) {
|
||||
MLIRContext context;
|
||||
Builder builder(&context);
|
||||
|
||||
Operation *useOp =
|
||||
createOp(&context, /*operands=*/llvm::None, builder.getIntegerType(16));
|
||||
Value operand = useOp->getResult(0);
|
||||
|
||||
// Create a resizable operation with one operand.
|
||||
Operation *user = createOp(&context, operand);
|
||||
|
||||
// Check setting with the same number of operands.
|
||||
MutableOperandRange mutableOperands(user);
|
||||
mutableOperands.assign(operand);
|
||||
EXPECT_EQ(mutableOperands.size(), 1u);
|
||||
EXPECT_EQ(user->getNumOperands(), 1u);
|
||||
|
||||
// Check setting with more operands.
|
||||
mutableOperands.assign({operand, operand, operand});
|
||||
EXPECT_EQ(mutableOperands.size(), 3u);
|
||||
EXPECT_EQ(user->getNumOperands(), 3u);
|
||||
|
||||
// Check with inserting a new operand.
|
||||
mutableOperands.append({operand, operand});
|
||||
EXPECT_EQ(mutableOperands.size(), 5u);
|
||||
EXPECT_EQ(user->getNumOperands(), 5u);
|
||||
|
||||
// Check erasing operands.
|
||||
mutableOperands.clear();
|
||||
EXPECT_EQ(mutableOperands.size(), 0u);
|
||||
EXPECT_EQ(user->getNumOperands(), 0u);
|
||||
|
||||
// Destroy the operations.
|
||||
user->destroy();
|
||||
useOp->destroy();
|
||||
}
|
||||
|
||||
} // end namespace
|
||||
|
||||
Reference in New Issue
Block a user