mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 01:07:04 +08:00
[mlir] Flip PDL to use Both accessors
This allows for incrementally updating the old API usages without needing to update everything at once. PDL will be left on Both for a little bit and then flipped to prefixed when all APIs have been updated. Differential Revision: https://reviews.llvm.org/D134387
This commit is contained in:
@@ -70,9 +70,8 @@ def PDL_Dialect : Dialect {
|
||||
void registerTypes();
|
||||
}];
|
||||
|
||||
// FIXME: Prefixed accessors overlap with builtin Operation members. Flip
|
||||
// once resolved.
|
||||
let emitAccessorPrefix = kEmitAccessorPrefix_Raw;
|
||||
// FIXME: Flip to prefixed.
|
||||
let emitAccessorPrefix = kEmitAccessorPrefix_Both;
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_PDL_IR_PDLDIALECT
|
||||
|
||||
@@ -122,10 +122,10 @@ def PDL_AttributeOp : PDL_Op<"attribute"> {
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Optional<PDL_Type>:$type,
|
||||
let arguments = (ins Optional<PDL_Type>:$valueType,
|
||||
OptionalAttr<AnyAttr>:$value);
|
||||
let results = (outs PDL_Attribute:$attr);
|
||||
let assemblyFormat = "(`:` $type^)? (`=` $value^)? attr-dict-with-keyword";
|
||||
let assemblyFormat = "(`:` $valueType^)? (`=` $value^)? attr-dict-with-keyword";
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins CArg<"Value", "Value()">:$type), [{
|
||||
@@ -156,8 +156,8 @@ def PDL_EraseOp : PDL_Op<"erase", [HasParent<"pdl::RewriteOp">]> {
|
||||
pdl.erase %root
|
||||
```
|
||||
}];
|
||||
let arguments = (ins PDL_Operation:$operation);
|
||||
let assemblyFormat = "$operation attr-dict";
|
||||
let arguments = (ins PDL_Operation:$opValue);
|
||||
let assemblyFormat = "$opValue attr-dict";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -187,9 +187,9 @@ def PDL_OperandOp
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Optional<PDL_Type>:$type);
|
||||
let results = (outs PDL_Value:$val);
|
||||
let assemblyFormat = "(`:` $type^)? attr-dict";
|
||||
let arguments = (ins Optional<PDL_Type>:$valueType);
|
||||
let results = (outs PDL_Value:$value);
|
||||
let assemblyFormat = "(`:` $valueType^)? attr-dict";
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins), [{
|
||||
@@ -226,9 +226,9 @@ def PDL_OperandsOp
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Optional<PDL_RangeOf<PDL_Type>>:$type);
|
||||
let results = (outs PDL_RangeOf<PDL_Value>:$val);
|
||||
let assemblyFormat = "(`:` $type^)? attr-dict";
|
||||
let arguments = (ins Optional<PDL_RangeOf<PDL_Type>>:$valueType);
|
||||
let results = (outs PDL_RangeOf<PDL_Value>:$value);
|
||||
let assemblyFormat = "(`:` $valueType^)? attr-dict";
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins), [{
|
||||
@@ -341,16 +341,16 @@ def PDL_OperationOp : PDL_Op<"operation", [AttrSizedOperandSegments]> {
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins OptionalAttr<StrAttr>:$name,
|
||||
Variadic<PDL_InstOrRangeOf<PDL_Value>>:$operands,
|
||||
Variadic<PDL_Attribute>:$attributes,
|
||||
StrArrayAttr:$attributeNames,
|
||||
Variadic<PDL_InstOrRangeOf<PDL_Type>>:$types);
|
||||
let arguments = (ins OptionalAttr<StrAttr>:$opName,
|
||||
Variadic<PDL_InstOrRangeOf<PDL_Value>>:$operandValues,
|
||||
Variadic<PDL_Attribute>:$attributeValues,
|
||||
StrArrayAttr:$attributeValueNames,
|
||||
Variadic<PDL_InstOrRangeOf<PDL_Type>>:$typeValues);
|
||||
let results = (outs PDL_Operation:$op);
|
||||
let assemblyFormat = [{
|
||||
($name^)? (`(` $operands^ `:` type($operands) `)`)?
|
||||
custom<OperationOpAttributes>($attributes, $attributeNames)
|
||||
(`->` `(` $types^ `:` type($types) `)`)? attr-dict
|
||||
($opName^)? (`(` $operandValues^ `:` type($operandValues) `)`)?
|
||||
custom<OperationOpAttributes>($attributeValues, $attributeValueNames)
|
||||
(`->` `(` $typeValues^ `:` type($typeValues) `)`)? attr-dict
|
||||
}];
|
||||
|
||||
let builders = [
|
||||
@@ -413,9 +413,9 @@ def PDL_PatternOp : PDL_Op<"pattern", [
|
||||
|
||||
let arguments = (ins ConfinedAttr<I16Attr, [IntNonNegative]>:$benefit,
|
||||
OptionalAttr<SymbolNameAttr>:$sym_name);
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
let regions = (region SizedRegion<1>:$bodyRegion);
|
||||
let assemblyFormat = [{
|
||||
($sym_name^)? `:` `benefit` `(` $benefit `)` attr-dict-with-keyword $body
|
||||
($sym_name^)? `:` `benefit` `(` $benefit `)` attr-dict-with-keyword $bodyRegion
|
||||
}];
|
||||
|
||||
let builders = [
|
||||
@@ -467,11 +467,11 @@ def PDL_ReplaceOp : PDL_Op<"replace", [
|
||||
pdl.replace %root with %otherOp
|
||||
```
|
||||
}];
|
||||
let arguments = (ins PDL_Operation:$operation,
|
||||
let arguments = (ins PDL_Operation:$opValue,
|
||||
Optional<PDL_Operation>:$replOperation,
|
||||
Variadic<PDL_InstOrRangeOf<PDL_Value>>:$replValues);
|
||||
let assemblyFormat = [{
|
||||
$operation `with` (`(` $replValues^ `:` type($replValues) `)`)?
|
||||
$opValue `with` (`(` $replValues^ `:` type($replValues) `)`)?
|
||||
($replOperation^)? attr-dict
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
@@ -603,10 +603,10 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [
|
||||
let arguments = (ins Optional<PDL_Operation>:$root,
|
||||
OptionalAttr<StrAttr>:$name,
|
||||
Variadic<PDL_AnyType>:$externalArgs);
|
||||
let regions = (region AnyRegion:$body);
|
||||
let regions = (region AnyRegion:$bodyRegion);
|
||||
let assemblyFormat = [{
|
||||
($root^)? (`with` $name^ (`(` $externalArgs^ `:` type($externalArgs) `)`)?)?
|
||||
($body^)?
|
||||
($bodyRegion^)?
|
||||
attr-dict-with-keyword
|
||||
}];
|
||||
let hasRegionVerifier = 1;
|
||||
@@ -635,9 +635,9 @@ def PDL_TypeOp : PDL_Op<"type"> {
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins OptionalAttr<TypeAttr>:$type);
|
||||
let arguments = (ins OptionalAttr<TypeAttr>:$constantType);
|
||||
let results = (outs PDL_Type:$result);
|
||||
let assemblyFormat = "attr-dict (`:` $type^)?";
|
||||
let assemblyFormat = "attr-dict (`:` $constantType^)?";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
@@ -664,9 +664,9 @@ def PDL_TypesOp : PDL_Op<"types"> {
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins OptionalAttr<TypeArrayAttr>:$types);
|
||||
let arguments = (ins OptionalAttr<TypeArrayAttr>:$constantTypes);
|
||||
let results = (outs PDL_RangeOf<PDL_Type>:$result);
|
||||
let assemblyFormat = "attr-dict (`:` $types^)?";
|
||||
let assemblyFormat = "attr-dict (`:` $constantTypes^)?";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
|
||||
@@ -575,8 +575,9 @@ void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) {
|
||||
|
||||
// Collect the set of operations generated by the rewriter.
|
||||
SmallVector<StringRef, 4> generatedOps;
|
||||
for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>())
|
||||
generatedOps.push_back(*op.name());
|
||||
for (auto op :
|
||||
pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>())
|
||||
generatedOps.push_back(*op.getOpName());
|
||||
ArrayAttr generatedOpsAttr;
|
||||
if (!generatedOps.empty())
|
||||
generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
|
||||
@@ -584,7 +585,7 @@ void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) {
|
||||
// Grab the root kind if present.
|
||||
StringAttr rootKindAttr;
|
||||
if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
|
||||
if (Optional<StringRef> rootKind = rootOp.name())
|
||||
if (Optional<StringRef> rootKind = rootOp.getOpName())
|
||||
rootKindAttr = builder.getStringAttr(*rootKind);
|
||||
|
||||
builder.setInsertionPointToEnd(currentBlock);
|
||||
@@ -620,12 +621,12 @@ SymbolRefAttr PatternLowering::generateRewriter(
|
||||
attrOp.getLoc(), value);
|
||||
}
|
||||
} else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
|
||||
if (TypeAttr type = typeOp.typeAttr()) {
|
||||
if (TypeAttr type = typeOp.getConstantTypeAttr()) {
|
||||
return newValue = builder.create<pdl_interp::CreateTypeOp>(
|
||||
typeOp.getLoc(), type);
|
||||
}
|
||||
} else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
|
||||
if (ArrayAttr type = typeOp.typesAttr()) {
|
||||
if (ArrayAttr type = typeOp.getConstantTypesAttr()) {
|
||||
return newValue = builder.create<pdl_interp::CreateTypesOp>(
|
||||
typeOp.getLoc(), typeOp.getType(), type);
|
||||
}
|
||||
@@ -699,18 +700,18 @@ void PatternLowering::generateRewriter(
|
||||
pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
|
||||
function_ref<Value(Value)> mapRewriteValue) {
|
||||
builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
|
||||
mapRewriteValue(eraseOp.operation()));
|
||||
mapRewriteValue(eraseOp.getOpValue()));
|
||||
}
|
||||
|
||||
void PatternLowering::generateRewriter(
|
||||
pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
|
||||
function_ref<Value(Value)> mapRewriteValue) {
|
||||
SmallVector<Value, 4> operands;
|
||||
for (Value operand : operationOp.operands())
|
||||
for (Value operand : operationOp.getOperandValues())
|
||||
operands.push_back(mapRewriteValue(operand));
|
||||
|
||||
SmallVector<Value, 4> attributes;
|
||||
for (Value attr : operationOp.attributes())
|
||||
for (Value attr : operationOp.getAttributeValues())
|
||||
attributes.push_back(mapRewriteValue(attr));
|
||||
|
||||
bool hasInferredResultTypes = false;
|
||||
@@ -721,14 +722,14 @@ void PatternLowering::generateRewriter(
|
||||
// Create the new operation.
|
||||
Location loc = operationOp.getLoc();
|
||||
Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
|
||||
loc, *operationOp.name(), types, hasInferredResultTypes, operands,
|
||||
attributes, operationOp.attributeNames());
|
||||
loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands,
|
||||
attributes, operationOp.getAttributeValueNames());
|
||||
rewriteValues[operationOp.op()] = createdOp;
|
||||
|
||||
// Generate accesses for any results that have their types constrained.
|
||||
// Handle the case where there is a single range representing all of the
|
||||
// result types.
|
||||
OperandRange resultTys = operationOp.types();
|
||||
OperandRange resultTys = operationOp.getTypeValues();
|
||||
if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) {
|
||||
Value &type = rewriteValues[resultTys[0]];
|
||||
if (!type) {
|
||||
@@ -772,8 +773,8 @@ void PatternLowering::generateRewriter(
|
||||
// user facing.
|
||||
if (Value replOp = replaceOp.replOperation()) {
|
||||
// Don't use replace if we know the replaced operation has no results.
|
||||
auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>();
|
||||
if (!opOp || !opOp.types().empty()) {
|
||||
auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
|
||||
if (!opOp || !opOp.getTypeValues().empty()) {
|
||||
replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
|
||||
replOp.getLoc(), mapRewriteValue(replOp)));
|
||||
}
|
||||
@@ -784,13 +785,14 @@ void PatternLowering::generateRewriter(
|
||||
|
||||
// If there are no replacement values, just create an erase instead.
|
||||
if (replOperands.empty()) {
|
||||
builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(),
|
||||
mapRewriteValue(replaceOp.operation()));
|
||||
builder.create<pdl_interp::EraseOp>(
|
||||
replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue()));
|
||||
return;
|
||||
}
|
||||
|
||||
builder.create<pdl_interp::ReplaceOp>(
|
||||
replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands);
|
||||
builder.create<pdl_interp::ReplaceOp>(replaceOp.getLoc(),
|
||||
mapRewriteValue(replaceOp.getOpValue()),
|
||||
replOperands);
|
||||
}
|
||||
|
||||
void PatternLowering::generateRewriter(
|
||||
@@ -814,7 +816,7 @@ void PatternLowering::generateRewriter(
|
||||
function_ref<Value(Value)> mapRewriteValue) {
|
||||
// If the type isn't constant, the users (e.g. OperationOp) will resolve this
|
||||
// type.
|
||||
if (TypeAttr typeAttr = typeOp.typeAttr()) {
|
||||
if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) {
|
||||
rewriteValues[typeOp] =
|
||||
builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
|
||||
}
|
||||
@@ -825,7 +827,7 @@ void PatternLowering::generateRewriter(
|
||||
function_ref<Value(Value)> mapRewriteValue) {
|
||||
// If the type isn't constant, the users (e.g. OperationOp) will resolve this
|
||||
// type.
|
||||
if (ArrayAttr typeAttr = typeOp.typesAttr()) {
|
||||
if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) {
|
||||
rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
|
||||
typeOp.getLoc(), typeOp.getType(), typeAttr);
|
||||
}
|
||||
@@ -840,7 +842,7 @@ void PatternLowering::generateOperationResultTypeRewriter(
|
||||
// Try to handle resolution for each of the result types individually. This is
|
||||
// preferred over type inferrence because it will allow for us to use existing
|
||||
// types directly, as opposed to trying to rebuild the type list.
|
||||
OperandRange resultTypeValues = op.types();
|
||||
OperandRange resultTypeValues = op.getTypeValues();
|
||||
auto tryResolveResultTypes = [&] {
|
||||
types.reserve(resultTypeValues.size());
|
||||
for (const auto &it : llvm::enumerate(resultTypeValues)) {
|
||||
@@ -886,7 +888,7 @@ void PatternLowering::generateOperationResultTypeRewriter(
|
||||
// rewrites only have single block regions, so if the op isn't in the
|
||||
// rewriter block (i.e. the current block of the operation) we already know
|
||||
// it dominates (i.e. it's in the matcher).
|
||||
Value replOpVal = replOpUser.operation();
|
||||
Value replOpVal = replOpUser.getOpValue();
|
||||
Operation *replacedOp = replOpVal.getDefiningOp();
|
||||
if (replacedOp->getBlock() == rewriterBlock &&
|
||||
!replacedOp->isBeforeInBlock(op))
|
||||
|
||||
@@ -53,7 +53,7 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
|
||||
predList.emplace_back(pos, builder.getIsNotNull());
|
||||
|
||||
// If the attribute has a type or value, add a constraint.
|
||||
if (Value type = attr.type())
|
||||
if (Value type = attr.getValueType())
|
||||
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
|
||||
else if (Attribute value = attr.valueAttr())
|
||||
predList.emplace_back(pos, builder.getAttributeConstraint(value));
|
||||
@@ -76,7 +76,7 @@ static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
|
||||
cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
|
||||
predList.emplace_back(pos, builder.getIsNotNull());
|
||||
|
||||
if (Value type = op.type())
|
||||
if (Value type = op.getValueType())
|
||||
getTreePredicates(predList, type, builder, inputs,
|
||||
builder.getType(pos));
|
||||
})
|
||||
@@ -120,12 +120,12 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
|
||||
predList.emplace_back(pos, builder.getIsNotNull());
|
||||
|
||||
// Check that this is the correct root operation.
|
||||
if (Optional<StringRef> opName = op.name())
|
||||
if (Optional<StringRef> opName = op.getOpName())
|
||||
predList.emplace_back(pos, builder.getOperationName(*opName));
|
||||
|
||||
// Check that the operation has the proper number of operands. If there are
|
||||
// any variable length operands, we check a minimum instead of an exact count.
|
||||
OperandRange operands = op.operands();
|
||||
OperandRange operands = op.getOperandValues();
|
||||
unsigned minOperands = getNumNonRangeValues(operands);
|
||||
if (minOperands != operands.size()) {
|
||||
if (minOperands)
|
||||
@@ -136,7 +136,7 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
|
||||
|
||||
// Check that the operation has the proper number of results. If there are
|
||||
// any variable length results, we check a minimum instead of an exact count.
|
||||
OperandRange types = op.types();
|
||||
OperandRange types = op.getTypeValues();
|
||||
unsigned minResults = getNumNonRangeValues(types);
|
||||
if (minResults == types.size())
|
||||
predList.emplace_back(pos, builder.getResultCount(types.size()));
|
||||
@@ -144,11 +144,11 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
|
||||
predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));
|
||||
|
||||
// Recurse into any attributes, operands, or results.
|
||||
for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
|
||||
for (auto [attrName, attr] :
|
||||
llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) {
|
||||
getTreePredicates(
|
||||
predList, std::get<1>(it), builder, inputs,
|
||||
builder.getAttribute(opPos,
|
||||
std::get<0>(it).cast<StringAttr>().getValue()));
|
||||
predList, attr, builder, inputs,
|
||||
builder.getAttribute(opPos, attrName.cast<StringAttr>().getValue()));
|
||||
}
|
||||
|
||||
// Process the operands and results of the operation. For all values up to
|
||||
@@ -208,10 +208,10 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
|
||||
TypePosition *pos) {
|
||||
// Check for a constraint on a constant type.
|
||||
if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
|
||||
if (Attribute type = typeOp.typeAttr())
|
||||
if (Attribute type = typeOp.getConstantTypeAttr())
|
||||
predList.emplace_back(pos, builder.getTypeConstraint(type));
|
||||
} else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
|
||||
if (Attribute typeAttr = typeOp.typesAttr())
|
||||
if (Attribute typeAttr = typeOp.getConstantTypesAttr())
|
||||
predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
|
||||
}
|
||||
}
|
||||
@@ -327,7 +327,7 @@ static void getNonTreePredicates(pdl::PatternOp pattern,
|
||||
std::vector<PositionalPredicate> &predList,
|
||||
PredicateBuilder &builder,
|
||||
DenseMap<Value, Position *> &inputs) {
|
||||
for (Operation &op : pattern.body().getOps()) {
|
||||
for (Operation &op : pattern.getBodyRegion().getOps()) {
|
||||
TypeSwitch<Operation *>(&op)
|
||||
.Case([&](pdl::AttributeOp attrOp) {
|
||||
getAttributePredicates(attrOp, predList, builder, inputs);
|
||||
@@ -340,11 +340,13 @@ static void getNonTreePredicates(pdl::PatternOp pattern,
|
||||
})
|
||||
.Case([&](pdl::TypeOp typeOp) {
|
||||
getTypePredicates(
|
||||
typeOp, [&] { return typeOp.typeAttr(); }, builder, inputs);
|
||||
typeOp, [&] { return typeOp.getConstantTypeAttr(); }, builder,
|
||||
inputs);
|
||||
})
|
||||
.Case([&](pdl::TypesOp typeOp) {
|
||||
getTypePredicates(
|
||||
typeOp, [&] { return typeOp.typesAttr(); }, builder, inputs);
|
||||
typeOp, [&] { return typeOp.getConstantTypesAttr(); }, builder,
|
||||
inputs);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -369,8 +371,8 @@ static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
|
||||
// First, collect all the operations that are used as operands
|
||||
// to other operations. These are not roots by default.
|
||||
DenseSet<Value> used;
|
||||
for (auto operationOp : pattern.body().getOps<pdl::OperationOp>()) {
|
||||
for (Value operand : operationOp.operands())
|
||||
for (auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) {
|
||||
for (Value operand : operationOp.getOperandValues())
|
||||
TypeSwitch<Operation *>(operand.getDefiningOp())
|
||||
.Case<pdl::ResultOp, pdl::ResultsOp>(
|
||||
[&used](auto resultOp) { used.insert(resultOp.parent()); });
|
||||
@@ -383,7 +385,7 @@ static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
|
||||
|
||||
// Finally, collect all the unused operations.
|
||||
SmallVector<Value> roots;
|
||||
for (Value operationOp : pattern.body().getOps<pdl::OperationOp>())
|
||||
for (Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>())
|
||||
if (!used.contains(operationOp))
|
||||
roots.push_back(operationOp);
|
||||
|
||||
@@ -451,7 +453,7 @@ static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
|
||||
// are expensive to join on.
|
||||
TypeSwitch<Operation *>(entry.value.getDefiningOp())
|
||||
.Case<pdl::OperationOp>([&](auto operationOp) {
|
||||
OperandRange operands = operationOp.operands();
|
||||
OperandRange operands = operationOp.getOperandValues();
|
||||
// Special case when we pass all the operands in one range.
|
||||
// For those, the index is empty.
|
||||
if (operands.size() == 1 &&
|
||||
@@ -462,7 +464,8 @@ static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
|
||||
}
|
||||
|
||||
// Default case: visit all the operands.
|
||||
for (const auto &p : llvm::enumerate(operationOp.operands()))
|
||||
for (const auto &p :
|
||||
llvm::enumerate(operationOp.getOperandValues()))
|
||||
toVisit.emplace(p.value(), entry.value, p.index(),
|
||||
entry.depth + 1);
|
||||
})
|
||||
@@ -507,7 +510,7 @@ static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
|
||||
/// Returns true if the operand at the given index needs to be queried using an
|
||||
/// operand group, i.e., if it is variadic itself or follows a variadic operand.
|
||||
static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
|
||||
OperandRange operands = op.operands();
|
||||
OperandRange operands = op.getOperandValues();
|
||||
assert(index < operands.size() && "operand index out of range");
|
||||
for (unsigned i = 0; i <= index; ++i)
|
||||
if (operands[i].getType().isa<pdl::RangeType>())
|
||||
@@ -537,7 +540,7 @@ static void visitUpward(std::vector<PositionalPredicate> &predList,
|
||||
operandPos = builder.getAllOperands(opPos);
|
||||
} else if (useOperandGroup(operationOp, *opIndex.index)) {
|
||||
// We are querying an operand group.
|
||||
Type type = operationOp.operands()[*opIndex.index].getType();
|
||||
Type type = operationOp.getOperandValues()[*opIndex.index].getType();
|
||||
bool variadic = type.isa<pdl::RangeType>();
|
||||
operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
|
||||
} else {
|
||||
|
||||
@@ -74,7 +74,7 @@ static void visit(Operation *op, DenseSet<Operation *> &visited) {
|
||||
// Traverse the operands / parent.
|
||||
TypeSwitch<Operation *>(op)
|
||||
.Case<OperationOp>([&visited](auto operation) {
|
||||
for (Value operand : operation.operands())
|
||||
for (Value operand : operation.getOperandValues())
|
||||
visit(operand.getDefiningOp(), visited);
|
||||
})
|
||||
.Case<ResultOp, ResultsOp>([&visited](auto result) {
|
||||
@@ -111,7 +111,7 @@ LogicalResult ApplyNativeRewriteOp::verify() {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult AttributeOp::verify() {
|
||||
Value attrType = type();
|
||||
Value attrType = getValueType();
|
||||
Optional<Attribute> attrValue = value();
|
||||
|
||||
if (!attrValue) {
|
||||
@@ -189,7 +189,7 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
|
||||
if (!replOpUser || use.getOperandNumber() == 0)
|
||||
return false;
|
||||
// Make sure the replaced operation was defined before this one.
|
||||
Operation *replacedOp = replOpUser.operation().getDefiningOp();
|
||||
Operation *replacedOp = replOpUser.getOpValue().getDefiningOp();
|
||||
return replacedOp->getBlock() != rewriterBlock ||
|
||||
replacedOp->isBeforeInBlock(op);
|
||||
};
|
||||
@@ -203,7 +203,7 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
|
||||
if (resultTypes.empty()) {
|
||||
// If we don't know the concrete operation, don't attempt any verification.
|
||||
// We can't make assumptions if we don't know the concrete operation.
|
||||
Optional<StringRef> rawOpName = op.name();
|
||||
Optional<StringRef> rawOpName = op.getOpName();
|
||||
if (!rawOpName)
|
||||
return success();
|
||||
Optional<RegisteredOperationName> opName =
|
||||
@@ -246,10 +246,12 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
|
||||
isa<OperandOp, OperandsOp, OperationOp>(user);
|
||||
};
|
||||
if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) {
|
||||
if (typeOp.type() || llvm::any_of(typeOp->getUsers(), constrainsInput))
|
||||
if (typeOp.getConstantType() ||
|
||||
llvm::any_of(typeOp->getUsers(), constrainsInput))
|
||||
continue;
|
||||
} else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) {
|
||||
if (typeOp.types() || llvm::any_of(typeOp->getUsers(), constrainsInput))
|
||||
if (typeOp.getConstantTypes() ||
|
||||
llvm::any_of(typeOp->getUsers(), constrainsInput))
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -264,11 +266,11 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
|
||||
|
||||
LogicalResult OperationOp::verify() {
|
||||
bool isWithinRewrite = isa<RewriteOp>((*this)->getParentOp());
|
||||
if (isWithinRewrite && !name())
|
||||
if (isWithinRewrite && !getOpName())
|
||||
return emitOpError("must have an operation name when nested within "
|
||||
"a `pdl.rewrite`");
|
||||
ArrayAttr attributeNames = attributeNamesAttr();
|
||||
auto attributeValues = attributes();
|
||||
ArrayAttr attributeNames = getAttributeValueNamesAttr();
|
||||
auto attributeValues = getAttributeValues();
|
||||
if (attributeNames.size() != attributeValues.size()) {
|
||||
return emitOpError()
|
||||
<< "expected the same number of attribute values and attribute "
|
||||
@@ -280,7 +282,7 @@ LogicalResult OperationOp::verify() {
|
||||
// If the operation is within a rewrite body and doesn't have type inference,
|
||||
// ensure that the result types can be resolved.
|
||||
if (isWithinRewrite && !mightHaveTypeInference()) {
|
||||
if (failed(verifyResultTypesAreInferrable(*this, types())))
|
||||
if (failed(verifyResultTypesAreInferrable(*this, getTypeValues())))
|
||||
return failure();
|
||||
}
|
||||
|
||||
@@ -288,7 +290,7 @@ LogicalResult OperationOp::verify() {
|
||||
}
|
||||
|
||||
bool OperationOp::hasTypeInference() {
|
||||
if (Optional<StringRef> rawOpName = name()) {
|
||||
if (Optional<StringRef> rawOpName = getOpName()) {
|
||||
OperationName opName(*rawOpName, getContext());
|
||||
return opName.hasInterface<InferTypeOpInterface>();
|
||||
}
|
||||
@@ -296,7 +298,7 @@ bool OperationOp::hasTypeInference() {
|
||||
}
|
||||
|
||||
bool OperationOp::mightHaveTypeInference() {
|
||||
if (Optional<StringRef> rawOpName = name()) {
|
||||
if (Optional<StringRef> rawOpName = getOpName()) {
|
||||
OperationName opName(*rawOpName, getContext());
|
||||
return opName.mightHaveInterface<InferTypeOpInterface>();
|
||||
}
|
||||
@@ -387,7 +389,7 @@ void PatternOp::build(OpBuilder &builder, OperationState &state,
|
||||
|
||||
/// Returns the rewrite operation of this pattern.
|
||||
RewriteOp PatternOp::getRewriter() {
|
||||
return cast<RewriteOp>(body().front().getTerminator());
|
||||
return cast<RewriteOp>(getBodyRegion().front().getTerminator());
|
||||
}
|
||||
|
||||
/// The default dialect is `pdl`.
|
||||
@@ -441,7 +443,7 @@ LogicalResult ResultsOp::verify() {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult RewriteOp::verifyRegions() {
|
||||
Region &rewriteRegion = body();
|
||||
Region &rewriteRegion = getBodyRegion();
|
||||
|
||||
// Handle the case where the rewrite is external.
|
||||
if (name()) {
|
||||
@@ -477,7 +479,7 @@ StringRef RewriteOp::getDefaultDialect() {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult TypeOp::verify() {
|
||||
if (!typeAttr())
|
||||
if (!getConstantTypeAttr())
|
||||
return verifyHasBindingUse(*this);
|
||||
return success();
|
||||
}
|
||||
@@ -487,7 +489,7 @@ LogicalResult TypeOp::verify() {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult TypesOp::verify() {
|
||||
if (!typesAttr())
|
||||
if (!getConstantTypesAttr())
|
||||
return verifyHasBindingUse(*this);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -203,7 +203,7 @@ static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr,
|
||||
pdl::RewriteOp rewrite =
|
||||
builder.create<pdl::RewriteOp>(loc, rootExpr, /*name=*/StringAttr(),
|
||||
/*externalArgs=*/ValueRange());
|
||||
builder.createBlock(&rewrite.body());
|
||||
builder.createBlock(&rewrite.getBodyRegion());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -86,14 +86,14 @@ class AttributeOp:
|
||||
"""Specialization for PDL attribute op class."""
|
||||
|
||||
def __init__(self,
|
||||
type: Optional[Union[OpView, Operation, Value]] = None,
|
||||
valueType: Optional[Union[OpView, Operation, Value]] = None,
|
||||
value: Optional[Attribute] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None):
|
||||
type = type if type is None else _get_value(type)
|
||||
valueType = valueType if valueType is None else _get_value(valueType)
|
||||
result = pdl.AttributeType.get()
|
||||
super().__init__(result, type=type, value=value, loc=loc, ip=ip)
|
||||
super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class EraseOp:
|
||||
@@ -118,7 +118,7 @@ class OperandOp:
|
||||
ip=None):
|
||||
type = type if type is None else _get_value(type)
|
||||
result = pdl.ValueType.get()
|
||||
super().__init__(result, type=type, loc=loc, ip=ip)
|
||||
super().__init__(result, valueType=type, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class OperandsOp:
|
||||
@@ -131,7 +131,7 @@ class OperandsOp:
|
||||
ip=None):
|
||||
types = types if types is None else _get_value(types)
|
||||
result = pdl.RangeType.get(pdl.ValueType.get())
|
||||
super().__init__(result, type=types, loc=loc, ip=ip)
|
||||
super().__init__(result, valueType=types, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class OperationOp:
|
||||
@@ -147,15 +147,15 @@ class OperationOp:
|
||||
ip=None):
|
||||
name = name if name is None else _get_str_attr(name)
|
||||
args = _get_values(args)
|
||||
attributeNames = []
|
||||
attributeValues = []
|
||||
attrNames = []
|
||||
attrValues = []
|
||||
for attrName, attrValue in attributes.items():
|
||||
attributeNames.append(StringAttr.get(attrName))
|
||||
attributeValues.append(_get_value(attrValue))
|
||||
attributeNames = ArrayAttr.get(attributeNames)
|
||||
attrNames.append(StringAttr.get(attrName))
|
||||
attrValues.append(_get_value(attrValue))
|
||||
attrNames = ArrayAttr.get(attrNames)
|
||||
types = _get_values(types)
|
||||
result = pdl.OperationType.get()
|
||||
super().__init__(result, args, attributeValues, attributeNames, types, name=name, loc=loc, ip=ip)
|
||||
super().__init__(result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class PatternOp:
|
||||
@@ -255,24 +255,26 @@ class TypeOp:
|
||||
"""Specialization for PDL type op class."""
|
||||
|
||||
def __init__(self,
|
||||
type: Optional[Union[TypeAttr, Type]] = None,
|
||||
constantType: Optional[Union[TypeAttr, Type]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None):
|
||||
type = type if type is None else _get_type_attr(type)
|
||||
constantType = constantType if constantType is None else _get_type_attr(
|
||||
constantType)
|
||||
result = pdl.TypeType.get()
|
||||
super().__init__(result, type=type, loc=loc, ip=ip)
|
||||
super().__init__(result, constantType=constantType, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class TypesOp:
|
||||
"""Specialization for PDL types op class."""
|
||||
|
||||
def __init__(self,
|
||||
types: Sequence[Union[TypeAttr, Type]] = [],
|
||||
constantTypes: Sequence[Union[TypeAttr, Type]] = [],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None):
|
||||
types = _get_array_attr([_get_type_attr(ty) for ty in types])
|
||||
types = None if not types else types
|
||||
constantTypes = _get_array_attr(
|
||||
[_get_type_attr(ty) for ty in constantTypes])
|
||||
constantTypes = None if not constantTypes else constantTypes
|
||||
result = pdl.RangeType.get(pdl.TypeType.get())
|
||||
super().__init__(result, types=types, loc=loc, ip=ip)
|
||||
super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
|
||||
|
||||
@@ -121,7 +121,7 @@ pdl.pattern : benefit(1) {
|
||||
pdl.pattern : benefit(1) {
|
||||
// expected-error@below {{expected the same number of attribute values and attribute names, got 1 names and 0 values}}
|
||||
%op = "pdl.operation"() {
|
||||
attributeNames = ["attr"],
|
||||
attributeValueNames = ["attr"],
|
||||
operand_segment_sizes = array<i32: 0, 0, 0>
|
||||
} : () -> (!pdl.operation)
|
||||
rewrite %op with "rewriter"
|
||||
|
||||
Reference in New Issue
Block a user