[mlir] Flip PDL to use Both accessors

This allows for incrementally updating the old API usages without
needing to update everything at once. PDL will be left on Both
for a little bit and then flipped to prefixed when all APIs have been
updated.

Differential Revision: https://reviews.llvm.org/D134387
This commit is contained in:
River Riddle
2022-09-21 13:40:47 -07:00
parent 986b5c56ea
commit 72fddfb599
8 changed files with 117 additions and 109 deletions

View File

@@ -70,9 +70,8 @@ def PDL_Dialect : Dialect {
void registerTypes();
}];
// FIXME: Prefixed accessors overlap with builtin Operation members. Flip
// once resolved.
let emitAccessorPrefix = kEmitAccessorPrefix_Raw;
// FIXME: Flip to prefixed.
let emitAccessorPrefix = kEmitAccessorPrefix_Both;
}
#endif // MLIR_DIALECT_PDL_IR_PDLDIALECT

View File

@@ -122,10 +122,10 @@ def PDL_AttributeOp : PDL_Op<"attribute"> {
```
}];
let arguments = (ins Optional<PDL_Type>:$type,
let arguments = (ins Optional<PDL_Type>:$valueType,
OptionalAttr<AnyAttr>:$value);
let results = (outs PDL_Attribute:$attr);
let assemblyFormat = "(`:` $type^)? (`=` $value^)? attr-dict-with-keyword";
let assemblyFormat = "(`:` $valueType^)? (`=` $value^)? attr-dict-with-keyword";
let builders = [
OpBuilder<(ins CArg<"Value", "Value()">:$type), [{
@@ -156,8 +156,8 @@ def PDL_EraseOp : PDL_Op<"erase", [HasParent<"pdl::RewriteOp">]> {
pdl.erase %root
```
}];
let arguments = (ins PDL_Operation:$operation);
let assemblyFormat = "$operation attr-dict";
let arguments = (ins PDL_Operation:$opValue);
let assemblyFormat = "$opValue attr-dict";
}
//===----------------------------------------------------------------------===//
@@ -187,9 +187,9 @@ def PDL_OperandOp
```
}];
let arguments = (ins Optional<PDL_Type>:$type);
let results = (outs PDL_Value:$val);
let assemblyFormat = "(`:` $type^)? attr-dict";
let arguments = (ins Optional<PDL_Type>:$valueType);
let results = (outs PDL_Value:$value);
let assemblyFormat = "(`:` $valueType^)? attr-dict";
let builders = [
OpBuilder<(ins), [{
@@ -226,9 +226,9 @@ def PDL_OperandsOp
```
}];
let arguments = (ins Optional<PDL_RangeOf<PDL_Type>>:$type);
let results = (outs PDL_RangeOf<PDL_Value>:$val);
let assemblyFormat = "(`:` $type^)? attr-dict";
let arguments = (ins Optional<PDL_RangeOf<PDL_Type>>:$valueType);
let results = (outs PDL_RangeOf<PDL_Value>:$value);
let assemblyFormat = "(`:` $valueType^)? attr-dict";
let builders = [
OpBuilder<(ins), [{
@@ -341,16 +341,16 @@ def PDL_OperationOp : PDL_Op<"operation", [AttrSizedOperandSegments]> {
```
}];
let arguments = (ins OptionalAttr<StrAttr>:$name,
Variadic<PDL_InstOrRangeOf<PDL_Value>>:$operands,
Variadic<PDL_Attribute>:$attributes,
StrArrayAttr:$attributeNames,
Variadic<PDL_InstOrRangeOf<PDL_Type>>:$types);
let arguments = (ins OptionalAttr<StrAttr>:$opName,
Variadic<PDL_InstOrRangeOf<PDL_Value>>:$operandValues,
Variadic<PDL_Attribute>:$attributeValues,
StrArrayAttr:$attributeValueNames,
Variadic<PDL_InstOrRangeOf<PDL_Type>>:$typeValues);
let results = (outs PDL_Operation:$op);
let assemblyFormat = [{
($name^)? (`(` $operands^ `:` type($operands) `)`)?
custom<OperationOpAttributes>($attributes, $attributeNames)
(`->` `(` $types^ `:` type($types) `)`)? attr-dict
($opName^)? (`(` $operandValues^ `:` type($operandValues) `)`)?
custom<OperationOpAttributes>($attributeValues, $attributeValueNames)
(`->` `(` $typeValues^ `:` type($typeValues) `)`)? attr-dict
}];
let builders = [
@@ -413,9 +413,9 @@ def PDL_PatternOp : PDL_Op<"pattern", [
let arguments = (ins ConfinedAttr<I16Attr, [IntNonNegative]>:$benefit,
OptionalAttr<SymbolNameAttr>:$sym_name);
let regions = (region SizedRegion<1>:$body);
let regions = (region SizedRegion<1>:$bodyRegion);
let assemblyFormat = [{
($sym_name^)? `:` `benefit` `(` $benefit `)` attr-dict-with-keyword $body
($sym_name^)? `:` `benefit` `(` $benefit `)` attr-dict-with-keyword $bodyRegion
}];
let builders = [
@@ -467,11 +467,11 @@ def PDL_ReplaceOp : PDL_Op<"replace", [
pdl.replace %root with %otherOp
```
}];
let arguments = (ins PDL_Operation:$operation,
let arguments = (ins PDL_Operation:$opValue,
Optional<PDL_Operation>:$replOperation,
Variadic<PDL_InstOrRangeOf<PDL_Value>>:$replValues);
let assemblyFormat = [{
$operation `with` (`(` $replValues^ `:` type($replValues) `)`)?
$opValue `with` (`(` $replValues^ `:` type($replValues) `)`)?
($replOperation^)? attr-dict
}];
let hasVerifier = 1;
@@ -603,10 +603,10 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [
let arguments = (ins Optional<PDL_Operation>:$root,
OptionalAttr<StrAttr>:$name,
Variadic<PDL_AnyType>:$externalArgs);
let regions = (region AnyRegion:$body);
let regions = (region AnyRegion:$bodyRegion);
let assemblyFormat = [{
($root^)? (`with` $name^ (`(` $externalArgs^ `:` type($externalArgs) `)`)?)?
($body^)?
($bodyRegion^)?
attr-dict-with-keyword
}];
let hasRegionVerifier = 1;
@@ -635,9 +635,9 @@ def PDL_TypeOp : PDL_Op<"type"> {
```
}];
let arguments = (ins OptionalAttr<TypeAttr>:$type);
let arguments = (ins OptionalAttr<TypeAttr>:$constantType);
let results = (outs PDL_Type:$result);
let assemblyFormat = "attr-dict (`:` $type^)?";
let assemblyFormat = "attr-dict (`:` $constantType^)?";
let hasVerifier = 1;
}
@@ -664,9 +664,9 @@ def PDL_TypesOp : PDL_Op<"types"> {
```
}];
let arguments = (ins OptionalAttr<TypeArrayAttr>:$types);
let arguments = (ins OptionalAttr<TypeArrayAttr>:$constantTypes);
let results = (outs PDL_RangeOf<PDL_Type>:$result);
let assemblyFormat = "attr-dict (`:` $types^)?";
let assemblyFormat = "attr-dict (`:` $constantTypes^)?";
let hasVerifier = 1;
}

View File

@@ -575,8 +575,9 @@ void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
// Collect the set of operations generated by the rewriter.
SmallVector<StringRef, 4> generatedOps;
for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>())
generatedOps.push_back(*op.name());
for (auto op :
pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>())
generatedOps.push_back(*op.getOpName());
ArrayAttr generatedOpsAttr;
if (!generatedOps.empty())
generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
@@ -584,7 +585,7 @@ void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
// Grab the root kind if present.
StringAttr rootKindAttr;
if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
if (Optional<StringRef> rootKind = rootOp.name())
if (Optional<StringRef> rootKind = rootOp.getOpName())
rootKindAttr = builder.getStringAttr(*rootKind);
builder.setInsertionPointToEnd(currentBlock);
@@ -620,12 +621,12 @@ SymbolRefAttr PatternLowering::generateRewriter(
attrOp.getLoc(), value);
}
} else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
if (TypeAttr type = typeOp.typeAttr()) {
if (TypeAttr type = typeOp.getConstantTypeAttr()) {
return newValue = builder.create<pdl_interp::CreateTypeOp>(
typeOp.getLoc(), type);
}
} else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
if (ArrayAttr type = typeOp.typesAttr()) {
if (ArrayAttr type = typeOp.getConstantTypesAttr()) {
return newValue = builder.create<pdl_interp::CreateTypesOp>(
typeOp.getLoc(), typeOp.getType(), type);
}
@@ -699,18 +700,18 @@ void PatternLowering::generateRewriter(
pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
mapRewriteValue(eraseOp.operation()));
mapRewriteValue(eraseOp.getOpValue()));
}
void PatternLowering::generateRewriter(
pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
SmallVector<Value, 4> operands;
for (Value operand : operationOp.operands())
for (Value operand : operationOp.getOperandValues())
operands.push_back(mapRewriteValue(operand));
SmallVector<Value, 4> attributes;
for (Value attr : operationOp.attributes())
for (Value attr : operationOp.getAttributeValues())
attributes.push_back(mapRewriteValue(attr));
bool hasInferredResultTypes = false;
@@ -721,14 +722,14 @@ void PatternLowering::generateRewriter(
// Create the new operation.
Location loc = operationOp.getLoc();
Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
loc, *operationOp.name(), types, hasInferredResultTypes, operands,
attributes, operationOp.attributeNames());
loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands,
attributes, operationOp.getAttributeValueNames());
rewriteValues[operationOp.op()] = createdOp;
// Generate accesses for any results that have their types constrained.
// Handle the case where there is a single range representing all of the
// result types.
OperandRange resultTys = operationOp.types();
OperandRange resultTys = operationOp.getTypeValues();
if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) {
Value &type = rewriteValues[resultTys[0]];
if (!type) {
@@ -772,8 +773,8 @@ void PatternLowering::generateRewriter(
// user facing.
if (Value replOp = replaceOp.replOperation()) {
// Don't use replace if we know the replaced operation has no results.
auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>();
if (!opOp || !opOp.types().empty()) {
auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
if (!opOp || !opOp.getTypeValues().empty()) {
replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
replOp.getLoc(), mapRewriteValue(replOp)));
}
@@ -784,13 +785,14 @@ void PatternLowering::generateRewriter(
// If there are no replacement values, just create an erase instead.
if (replOperands.empty()) {
builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(),
mapRewriteValue(replaceOp.operation()));
builder.create<pdl_interp::EraseOp>(
replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue()));
return;
}
builder.create<pdl_interp::ReplaceOp>(
replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands);
builder.create<pdl_interp::ReplaceOp>(replaceOp.getLoc(),
mapRewriteValue(replaceOp.getOpValue()),
replOperands);
}
void PatternLowering::generateRewriter(
@@ -814,7 +816,7 @@ void PatternLowering::generateRewriter(
function_ref<Value(Value)> mapRewriteValue) {
// If the type isn't constant, the users (e.g. OperationOp) will resolve this
// type.
if (TypeAttr typeAttr = typeOp.typeAttr()) {
if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) {
rewriteValues[typeOp] =
builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
}
@@ -825,7 +827,7 @@ void PatternLowering::generateRewriter(
function_ref<Value(Value)> mapRewriteValue) {
// If the type isn't constant, the users (e.g. OperationOp) will resolve this
// type.
if (ArrayAttr typeAttr = typeOp.typesAttr()) {
if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) {
rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
typeOp.getLoc(), typeOp.getType(), typeAttr);
}
@@ -840,7 +842,7 @@ void PatternLowering::generateOperationResultTypeRewriter(
// Try to handle resolution for each of the result types individually. This is
// preferred over type inferrence because it will allow for us to use existing
// types directly, as opposed to trying to rebuild the type list.
OperandRange resultTypeValues = op.types();
OperandRange resultTypeValues = op.getTypeValues();
auto tryResolveResultTypes = [&] {
types.reserve(resultTypeValues.size());
for (const auto &it : llvm::enumerate(resultTypeValues)) {
@@ -886,7 +888,7 @@ void PatternLowering::generateOperationResultTypeRewriter(
// rewrites only have single block regions, so if the op isn't in the
// rewriter block (i.e. the current block of the operation) we already know
// it dominates (i.e. it's in the matcher).
Value replOpVal = replOpUser.operation();
Value replOpVal = replOpUser.getOpValue();
Operation *replacedOp = replOpVal.getDefiningOp();
if (replacedOp->getBlock() == rewriterBlock &&
!replacedOp->isBeforeInBlock(op))

View File

@@ -53,7 +53,7 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
predList.emplace_back(pos, builder.getIsNotNull());
// If the attribute has a type or value, add a constraint.
if (Value type = attr.type())
if (Value type = attr.getValueType())
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
else if (Attribute value = attr.valueAttr())
predList.emplace_back(pos, builder.getAttributeConstraint(value));
@@ -76,7 +76,7 @@ static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
predList.emplace_back(pos, builder.getIsNotNull());
if (Value type = op.type())
if (Value type = op.getValueType())
getTreePredicates(predList, type, builder, inputs,
builder.getType(pos));
})
@@ -120,12 +120,12 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
predList.emplace_back(pos, builder.getIsNotNull());
// Check that this is the correct root operation.
if (Optional<StringRef> opName = op.name())
if (Optional<StringRef> opName = op.getOpName())
predList.emplace_back(pos, builder.getOperationName(*opName));
// Check that the operation has the proper number of operands. If there are
// any variable length operands, we check a minimum instead of an exact count.
OperandRange operands = op.operands();
OperandRange operands = op.getOperandValues();
unsigned minOperands = getNumNonRangeValues(operands);
if (minOperands != operands.size()) {
if (minOperands)
@@ -136,7 +136,7 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
// Check that the operation has the proper number of results. If there are
// any variable length results, we check a minimum instead of an exact count.
OperandRange types = op.types();
OperandRange types = op.getTypeValues();
unsigned minResults = getNumNonRangeValues(types);
if (minResults == types.size())
predList.emplace_back(pos, builder.getResultCount(types.size()));
@@ -144,11 +144,11 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));
// Recurse into any attributes, operands, or results.
for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
for (auto [attrName, attr] :
llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) {
getTreePredicates(
predList, std::get<1>(it), builder, inputs,
builder.getAttribute(opPos,
std::get<0>(it).cast<StringAttr>().getValue()));
predList, attr, builder, inputs,
builder.getAttribute(opPos, attrName.cast<StringAttr>().getValue()));
}
// Process the operands and results of the operation. For all values up to
@@ -208,10 +208,10 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
TypePosition *pos) {
// Check for a constraint on a constant type.
if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
if (Attribute type = typeOp.typeAttr())
if (Attribute type = typeOp.getConstantTypeAttr())
predList.emplace_back(pos, builder.getTypeConstraint(type));
} else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
if (Attribute typeAttr = typeOp.typesAttr())
if (Attribute typeAttr = typeOp.getConstantTypesAttr())
predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
}
}
@@ -327,7 +327,7 @@ static void getNonTreePredicates(pdl::PatternOp pattern,
std::vector<PositionalPredicate> &predList,
PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs) {
for (Operation &op : pattern.body().getOps()) {
for (Operation &op : pattern.getBodyRegion().getOps()) {
TypeSwitch<Operation *>(&op)
.Case([&](pdl::AttributeOp attrOp) {
getAttributePredicates(attrOp, predList, builder, inputs);
@@ -340,11 +340,13 @@ static void getNonTreePredicates(pdl::PatternOp pattern,
})
.Case([&](pdl::TypeOp typeOp) {
getTypePredicates(
typeOp, [&] { return typeOp.typeAttr(); }, builder, inputs);
typeOp, [&] { return typeOp.getConstantTypeAttr(); }, builder,
inputs);
})
.Case([&](pdl::TypesOp typeOp) {
getTypePredicates(
typeOp, [&] { return typeOp.typesAttr(); }, builder, inputs);
typeOp, [&] { return typeOp.getConstantTypesAttr(); }, builder,
inputs);
});
}
}
@@ -369,8 +371,8 @@ static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
// First, collect all the operations that are used as operands
// to other operations. These are not roots by default.
DenseSet<Value> used;
for (auto operationOp : pattern.body().getOps<pdl::OperationOp>()) {
for (Value operand : operationOp.operands())
for (auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) {
for (Value operand : operationOp.getOperandValues())
TypeSwitch<Operation *>(operand.getDefiningOp())
.Case<pdl::ResultOp, pdl::ResultsOp>(
[&used](auto resultOp) { used.insert(resultOp.parent()); });
@@ -383,7 +385,7 @@ static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
// Finally, collect all the unused operations.
SmallVector<Value> roots;
for (Value operationOp : pattern.body().getOps<pdl::OperationOp>())
for (Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>())
if (!used.contains(operationOp))
roots.push_back(operationOp);
@@ -451,7 +453,7 @@ static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
// are expensive to join on.
TypeSwitch<Operation *>(entry.value.getDefiningOp())
.Case<pdl::OperationOp>([&](auto operationOp) {
OperandRange operands = operationOp.operands();
OperandRange operands = operationOp.getOperandValues();
// Special case when we pass all the operands in one range.
// For those, the index is empty.
if (operands.size() == 1 &&
@@ -462,7 +464,8 @@ static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
}
// Default case: visit all the operands.
for (const auto &p : llvm::enumerate(operationOp.operands()))
for (const auto &p :
llvm::enumerate(operationOp.getOperandValues()))
toVisit.emplace(p.value(), entry.value, p.index(),
entry.depth + 1);
})
@@ -507,7 +510,7 @@ static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
/// Returns true if the operand at the given index needs to be queried using an
/// operand group, i.e., if it is variadic itself or follows a variadic operand.
static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
OperandRange operands = op.operands();
OperandRange operands = op.getOperandValues();
assert(index < operands.size() && "operand index out of range");
for (unsigned i = 0; i <= index; ++i)
if (operands[i].getType().isa<pdl::RangeType>())
@@ -537,7 +540,7 @@ static void visitUpward(std::vector<PositionalPredicate> &predList,
operandPos = builder.getAllOperands(opPos);
} else if (useOperandGroup(operationOp, *opIndex.index)) {
// We are querying an operand group.
Type type = operationOp.operands()[*opIndex.index].getType();
Type type = operationOp.getOperandValues()[*opIndex.index].getType();
bool variadic = type.isa<pdl::RangeType>();
operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
} else {

View File

@@ -74,7 +74,7 @@ static void visit(Operation *op, DenseSet<Operation *> &visited) {
// Traverse the operands / parent.
TypeSwitch<Operation *>(op)
.Case<OperationOp>([&visited](auto operation) {
for (Value operand : operation.operands())
for (Value operand : operation.getOperandValues())
visit(operand.getDefiningOp(), visited);
})
.Case<ResultOp, ResultsOp>([&visited](auto result) {
@@ -111,7 +111,7 @@ LogicalResult ApplyNativeRewriteOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AttributeOp::verify() {
Value attrType = type();
Value attrType = getValueType();
Optional<Attribute> attrValue = value();
if (!attrValue) {
@@ -189,7 +189,7 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
if (!replOpUser || use.getOperandNumber() == 0)
return false;
// Make sure the replaced operation was defined before this one.
Operation *replacedOp = replOpUser.operation().getDefiningOp();
Operation *replacedOp = replOpUser.getOpValue().getDefiningOp();
return replacedOp->getBlock() != rewriterBlock ||
replacedOp->isBeforeInBlock(op);
};
@@ -203,7 +203,7 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
if (resultTypes.empty()) {
// If we don't know the concrete operation, don't attempt any verification.
// We can't make assumptions if we don't know the concrete operation.
Optional<StringRef> rawOpName = op.name();
Optional<StringRef> rawOpName = op.getOpName();
if (!rawOpName)
return success();
Optional<RegisteredOperationName> opName =
@@ -246,10 +246,12 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
isa<OperandOp, OperandsOp, OperationOp>(user);
};
if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) {
if (typeOp.type() || llvm::any_of(typeOp->getUsers(), constrainsInput))
if (typeOp.getConstantType() ||
llvm::any_of(typeOp->getUsers(), constrainsInput))
continue;
} else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) {
if (typeOp.types() || llvm::any_of(typeOp->getUsers(), constrainsInput))
if (typeOp.getConstantTypes() ||
llvm::any_of(typeOp->getUsers(), constrainsInput))
continue;
}
@@ -264,11 +266,11 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
LogicalResult OperationOp::verify() {
bool isWithinRewrite = isa<RewriteOp>((*this)->getParentOp());
if (isWithinRewrite && !name())
if (isWithinRewrite && !getOpName())
return emitOpError("must have an operation name when nested within "
"a `pdl.rewrite`");
ArrayAttr attributeNames = attributeNamesAttr();
auto attributeValues = attributes();
ArrayAttr attributeNames = getAttributeValueNamesAttr();
auto attributeValues = getAttributeValues();
if (attributeNames.size() != attributeValues.size()) {
return emitOpError()
<< "expected the same number of attribute values and attribute "
@@ -280,7 +282,7 @@ LogicalResult OperationOp::verify() {
// If the operation is within a rewrite body and doesn't have type inference,
// ensure that the result types can be resolved.
if (isWithinRewrite && !mightHaveTypeInference()) {
if (failed(verifyResultTypesAreInferrable(*this, types())))
if (failed(verifyResultTypesAreInferrable(*this, getTypeValues())))
return failure();
}
@@ -288,7 +290,7 @@ LogicalResult OperationOp::verify() {
}
bool OperationOp::hasTypeInference() {
if (Optional<StringRef> rawOpName = name()) {
if (Optional<StringRef> rawOpName = getOpName()) {
OperationName opName(*rawOpName, getContext());
return opName.hasInterface<InferTypeOpInterface>();
}
@@ -296,7 +298,7 @@ bool OperationOp::hasTypeInference() {
}
bool OperationOp::mightHaveTypeInference() {
if (Optional<StringRef> rawOpName = name()) {
if (Optional<StringRef> rawOpName = getOpName()) {
OperationName opName(*rawOpName, getContext());
return opName.mightHaveInterface<InferTypeOpInterface>();
}
@@ -387,7 +389,7 @@ void PatternOp::build(OpBuilder &builder, OperationState &state,
/// Returns the rewrite operation of this pattern.
RewriteOp PatternOp::getRewriter() {
return cast<RewriteOp>(body().front().getTerminator());
return cast<RewriteOp>(getBodyRegion().front().getTerminator());
}
/// The default dialect is `pdl`.
@@ -441,7 +443,7 @@ LogicalResult ResultsOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult RewriteOp::verifyRegions() {
Region &rewriteRegion = body();
Region &rewriteRegion = getBodyRegion();
// Handle the case where the rewrite is external.
if (name()) {
@@ -477,7 +479,7 @@ StringRef RewriteOp::getDefaultDialect() {
//===----------------------------------------------------------------------===//
LogicalResult TypeOp::verify() {
if (!typeAttr())
if (!getConstantTypeAttr())
return verifyHasBindingUse(*this);
return success();
}
@@ -487,7 +489,7 @@ LogicalResult TypeOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult TypesOp::verify() {
if (!typesAttr())
if (!getConstantTypesAttr())
return verifyHasBindingUse(*this);
return success();
}

View File

@@ -203,7 +203,7 @@ static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr,
pdl::RewriteOp rewrite =
builder.create<pdl::RewriteOp>(loc, rootExpr, /*name=*/StringAttr(),
/*externalArgs=*/ValueRange());
builder.createBlock(&rewrite.body());
builder.createBlock(&rewrite.getBodyRegion());
}
}

View File

@@ -86,14 +86,14 @@ class AttributeOp:
"""Specialization for PDL attribute op class."""
def __init__(self,
type: Optional[Union[OpView, Operation, Value]] = None,
valueType: Optional[Union[OpView, Operation, Value]] = None,
value: Optional[Attribute] = None,
*,
loc=None,
ip=None):
type = type if type is None else _get_value(type)
valueType = valueType if valueType is None else _get_value(valueType)
result = pdl.AttributeType.get()
super().__init__(result, type=type, value=value, loc=loc, ip=ip)
super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
class EraseOp:
@@ -118,7 +118,7 @@ class OperandOp:
ip=None):
type = type if type is None else _get_value(type)
result = pdl.ValueType.get()
super().__init__(result, type=type, loc=loc, ip=ip)
super().__init__(result, valueType=type, loc=loc, ip=ip)
class OperandsOp:
@@ -131,7 +131,7 @@ class OperandsOp:
ip=None):
types = types if types is None else _get_value(types)
result = pdl.RangeType.get(pdl.ValueType.get())
super().__init__(result, type=types, loc=loc, ip=ip)
super().__init__(result, valueType=types, loc=loc, ip=ip)
class OperationOp:
@@ -147,15 +147,15 @@ class OperationOp:
ip=None):
name = name if name is None else _get_str_attr(name)
args = _get_values(args)
attributeNames = []
attributeValues = []
attrNames = []
attrValues = []
for attrName, attrValue in attributes.items():
attributeNames.append(StringAttr.get(attrName))
attributeValues.append(_get_value(attrValue))
attributeNames = ArrayAttr.get(attributeNames)
attrNames.append(StringAttr.get(attrName))
attrValues.append(_get_value(attrValue))
attrNames = ArrayAttr.get(attrNames)
types = _get_values(types)
result = pdl.OperationType.get()
super().__init__(result, args, attributeValues, attributeNames, types, name=name, loc=loc, ip=ip)
super().__init__(result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip)
class PatternOp:
@@ -255,24 +255,26 @@ class TypeOp:
"""Specialization for PDL type op class."""
def __init__(self,
type: Optional[Union[TypeAttr, Type]] = None,
constantType: Optional[Union[TypeAttr, Type]] = None,
*,
loc=None,
ip=None):
type = type if type is None else _get_type_attr(type)
constantType = constantType if constantType is None else _get_type_attr(
constantType)
result = pdl.TypeType.get()
super().__init__(result, type=type, loc=loc, ip=ip)
super().__init__(result, constantType=constantType, loc=loc, ip=ip)
class TypesOp:
"""Specialization for PDL types op class."""
def __init__(self,
types: Sequence[Union[TypeAttr, Type]] = [],
constantTypes: Sequence[Union[TypeAttr, Type]] = [],
*,
loc=None,
ip=None):
types = _get_array_attr([_get_type_attr(ty) for ty in types])
types = None if not types else types
constantTypes = _get_array_attr(
[_get_type_attr(ty) for ty in constantTypes])
constantTypes = None if not constantTypes else constantTypes
result = pdl.RangeType.get(pdl.TypeType.get())
super().__init__(result, types=types, loc=loc, ip=ip)
super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)

View File

@@ -121,7 +121,7 @@ pdl.pattern : benefit(1) {
pdl.pattern : benefit(1) {
// expected-error@below {{expected the same number of attribute values and attribute names, got 1 names and 0 values}}
%op = "pdl.operation"() {
attributeNames = ["attr"],
attributeValueNames = ["attr"],
operand_segment_sizes = array<i32: 0, 0, 0>
} : () -> (!pdl.operation)
rewrite %op with "rewriter"