[mlir] use irdl as matcher description in transform (#89779)

Introduce a new Transform dialect extension that uses IRDL op
definitions as matcher descriptors. IRDL allows one to essentially
define additional op constraits to be verified and, unlike PDL, does not
assume rewriting will happen. Leverage IRDL verification capability to
filter out ops that match an IRDL definition without actually
registering the corresponding operation with the system.
This commit is contained in:
Oleksandr "Alex" Zinenko
2024-05-02 15:03:30 +02:00
committed by GitHub
parent 11bda17254
commit 105c992c83
14 changed files with 308 additions and 33 deletions

View File

@@ -30,7 +30,10 @@ class DynamicTypeDefinition;
namespace mlir {
namespace irdl {
class AttributeOp;
class Constraint;
class OperationOp;
class TypeOp;
/// Provides context to the verification of constraints.
/// It contains the assignment of variables to attributes, and the assignment
@@ -246,6 +249,14 @@ private:
std::optional<SmallVector<unsigned>> argumentConstraints;
std::optional<size_t> blockCount;
};
/// Generate an op verifier function from the given IRDL operation definition.
llvm::unique_function<LogicalResult(Operation *) const> createVerifier(
OperationOp operation,
const DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>>
&typeDefs,
const DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
&attrDefs);
} // namespace irdl
} // namespace mlir

View File

@@ -1,6 +1,7 @@
add_subdirectory(DebugExtension)
add_subdirectory(Interfaces)
add_subdirectory(IR)
add_subdirectory(IRDLExtension)
add_subdirectory(LoopExtension)
add_subdirectory(PDLExtension)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,6 @@
set(LLVM_TARGET_DEFINITIONS IRDLExtensionOps.td)
mlir_tablegen(IRDLExtensionOps.h.inc -gen-op-decls)
mlir_tablegen(IRDLExtensionOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRTransformDialectIRDLExtensionOpsIncGen)
add_mlir_doc(IRDLExtensionOps IRDLExtensionOps Dialects/ -gen-op-doc)

View File

@@ -0,0 +1,21 @@
//===- IRDLExtension.h - IRDL extension for Transform dialect ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSION_H
#define MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSION_H
namespace mlir {
class DialectRegistry;
namespace transform {
/// Registers the IRDL extension of the Transform dialect in the given registry.
void registerIRDLExtension(DialectRegistry &dialectRegistry);
} // namespace transform
} // namespace mlir
#endif // MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSION_H

View File

@@ -0,0 +1,20 @@
//===- IRDLExtensionOps.h - IRDL Transform dialect extension ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS_H
#define MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h.inc"
#endif // MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS_H

View File

@@ -0,0 +1,36 @@
//===- IRDLExtensionOps.td - Transform dialect extension ---*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS
#define MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
def IRDLCollectMatchingOp : TransformDialectOp<"irdl.collect_matching",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
SymbolTable,
NoTerminator]> {
let summary =
"Finds ops that match the IRDL definition without registering them.";
let arguments = (ins TransformHandleTypeInterface:$root);
let regions = (region SizedRegion<1>:$body);
let results = (outs TransformHandleTypeInterface:$matched);
let assemblyFormat =
"`in` $root `:` functional-type(operands, results) attr-dict-with-keyword "
"regions";
let hasVerifier = 1;
}
#endif // MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS

View File

@@ -6,6 +6,9 @@
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSION_H
#define MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSION_H
namespace mlir {
class DialectRegistry;
@@ -14,3 +17,5 @@ namespace transform {
void registerPDLExtension(DialectRegistry &dialectRegistry);
} // namespace transform
} // namespace mlir
#endif // MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSION_H

View File

@@ -35,6 +35,7 @@
#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
@@ -77,6 +78,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
sparse_tensor::registerTransformDialectExtension(registry);
tensor::registerTransformDialectExtension(registry);
transform::registerDebugExtension(registry);
transform::registerIRDLExtension(registry);
transform::registerLoopExtension(registry);
transform::registerPDLExtension(registry);
vector::registerTransformDialectExtension(registry);

View File

@@ -270,26 +270,30 @@ static LogicalResult irdlRegionVerifier(
return success();
}
/// Define and load an operation represented by a `irdl.operation`
/// operation.
static WalkResult loadOperation(
OperationOp op, ExtensibleDialect *dialect,
DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
llvm::unique_function<LogicalResult(Operation *) const>
mlir::irdl::createVerifier(
OperationOp op,
const DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
const DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
&attrs) {
// Resolve SSA values to verifier constraint slots
SmallVector<Value> constrToValue;
SmallVector<Value> regionToValue;
for (Operation &op : op->getRegion(0).getOps()) {
if (isa<VerifyConstraintInterface>(op)) {
if (op.getNumResults() != 1)
return op.emitError()
<< "IRDL constraint operations must have exactly one result";
if (op.getNumResults() != 1) {
op.emitError()
<< "IRDL constraint operations must have exactly one result";
return nullptr;
}
constrToValue.push_back(op.getResult(0));
}
if (isa<VerifyRegionInterface>(op)) {
if (op.getNumResults() != 1)
return op.emitError()
<< "IRDL constraint operations must have exactly one result";
if (op.getNumResults() != 1) {
op.emitError()
<< "IRDL constraint operations must have exactly one result";
return nullptr;
}
regionToValue.push_back(op.getResult(0));
}
}
@@ -302,7 +306,7 @@ static WalkResult loadOperation(
std::unique_ptr<Constraint> verifier =
op.getVerifier(constrToValue, types, attrs);
if (!verifier)
return WalkResult::interrupt();
return nullptr;
constraints.push_back(std::move(verifier));
}
@@ -358,7 +362,7 @@ static WalkResult loadOperation(
}
// Gather which constraint slots correspond to attributes constraints
DenseMap<StringAttr, size_t> attributesContraints;
DenseMap<StringAttr, size_t> attributeConstraints;
auto attributesOp = op.getOp<AttributesOp>();
if (attributesOp.has_value()) {
const Operation::operand_range values = attributesOp->getAttributeValues();
@@ -367,13 +371,40 @@ static WalkResult loadOperation(
for (const auto &[name, value] : llvm::zip(names, values)) {
for (auto [i, constr] : enumerate(constrToValue)) {
if (constr == value) {
attributesContraints[cast<StringAttr>(name)] = i;
attributeConstraints[cast<StringAttr>(name)] = i;
break;
}
}
}
}
return
[constraints{std::move(constraints)},
regionConstraints{std::move(regionConstraints)},
operandConstraints{std::move(operandConstraints)},
operandVariadicity{std::move(operandVariadicity)},
resultConstraints{std::move(resultConstraints)},
resultVariadicity{std::move(resultVariadicity)},
attributeConstraints{std::move(attributeConstraints)}](Operation *op) {
ConstraintVerifier verifier(constraints);
const LogicalResult opVerifierResult = irdlOpVerifier(
op, verifier, operandConstraints, operandVariadicity,
resultConstraints, resultVariadicity, attributeConstraints);
const LogicalResult opRegionVerifierResult =
irdlRegionVerifier(op, verifier, regionConstraints);
return LogicalResult::success(opVerifierResult.succeeded() &&
opRegionVerifierResult.succeeded());
};
}
/// Define and load an operation represented by a `irdl.operation`
/// operation.
static WalkResult loadOperation(
OperationOp op, ExtensibleDialect *dialect,
const DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
const DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
&attrs) {
// IRDL does not support defining custom parsers or printers.
auto parser = [](OpAsmParser &parser, OperationState &result) {
return failure();
@@ -382,25 +413,11 @@ static WalkResult loadOperation(
printer.printGenericOp(op);
};
auto verifier =
[constraints{std::move(constraints)},
regionConstraints{std::move(regionConstraints)},
operandConstraints{std::move(operandConstraints)},
operandVariadicity{std::move(operandVariadicity)},
resultConstraints{std::move(resultConstraints)},
resultVariadicity{std::move(resultVariadicity)},
attributesContraints{std::move(attributesContraints)}](Operation *op) {
ConstraintVerifier verifier(constraints);
const LogicalResult opVerifierResult = irdlOpVerifier(
op, verifier, operandConstraints, operandVariadicity,
resultConstraints, resultVariadicity, attributesContraints);
const LogicalResult opRegionVerifierResult =
irdlRegionVerifier(op, verifier, regionConstraints);
return LogicalResult::success(opVerifierResult.succeeded() &&
opRegionVerifierResult.succeeded());
};
auto verifier = createVerifier(op, types, attrs);
if (!verifier)
return WalkResult::interrupt();
// IRDL supports only checking number of blocks and argument contraints
// IRDL supports only checking number of blocks and argument constraints
// It is done in the main verifier to reuse `ConstraintVerifier` context
auto regionVerifier = [](Operation *op) { return LogicalResult::success(); };

View File

@@ -1,6 +1,7 @@
add_subdirectory(DebugExtension)
add_subdirectory(Interfaces)
add_subdirectory(IR)
add_subdirectory(IRDLExtension)
add_subdirectory(LoopExtension)
add_subdirectory(PDLExtension)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,12 @@
add_mlir_dialect_library(MLIRTransformDialectIRDLExtension
IRDLExtension.cpp
IRDLExtensionOps.cpp
DEPENDS
MLIRTransformDialectIRDLExtensionOpsIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRTransformDialect
MLIRIRDL
)

View File

@@ -0,0 +1,34 @@
//===- IRDLExtension.cpp - IRDL extension for the Transform dialect -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
#include "mlir/Dialect/IRDL/IR/IRDL.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h"
#include "mlir/IR/DialectRegistry.h"
using namespace mlir;
namespace {
class IRDLExtension
: public transform::TransformDialectExtension<IRDLExtension> {
public:
void init() {
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp.inc"
>();
declareDependentDialect<irdl::IRDLDialect>();
}
};
} // namespace
void mlir::transform::registerIRDLExtension(DialectRegistry &dialectRegistry) {
dialectRegistry.addExtensions<IRDLExtension>();
}

View File

@@ -0,0 +1,84 @@
//===- IRDLExtensionOps.cpp - IRDL extension for the Transform dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h"
#include "mlir/Dialect/IRDL/IR/IRDL.h"
#include "mlir/Dialect/IRDL/IRDLVerifiers.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
using namespace mlir;
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp.inc"
namespace mlir::transform {
DiagnosedSilenceableFailure
IRDLCollectMatchingOp::apply(TransformRewriter &rewriter,
TransformResults &results, TransformState &state) {
auto dialect = cast<irdl::DialectOp>(getBody().front().front());
Block &body = dialect.getBody().front();
irdl::OperationOp operation = *body.getOps<irdl::OperationOp>().begin();
auto verifier = irdl::createVerifier(
operation,
DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>>(),
DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>>());
auto handlerID = getContext()->getDiagEngine().registerHandler(
[](Diagnostic &) { return success(); });
SmallVector<Operation *> matched;
for (Operation *payload : state.getPayloadOps(getRoot())) {
payload->walk([&](Operation *target) {
if (succeeded(verifier(target))) {
matched.push_back(target);
}
});
}
getContext()->getDiagEngine().eraseHandler(handlerID);
results.set(cast<OpResult>(getMatched()), matched);
return DiagnosedSilenceableFailure::success();
}
void IRDLCollectMatchingOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getRoot(), effects);
producesHandle(getMatched(), effects);
onlyReadsPayload(effects);
}
LogicalResult IRDLCollectMatchingOp::verify() {
Block &bodyBlock = getBody().front();
if (!llvm::hasSingleElement(bodyBlock))
return emitOpError() << "expects a single operation in the body";
auto dialect = dyn_cast<irdl::DialectOp>(bodyBlock.front());
if (!dialect) {
return emitOpError() << "expects the body operation to be "
<< irdl::DialectOp::getOperationName();
}
// TODO: relax this by taking a symbol name of the operation to match, note
// that symbol name is also the name of the operation and we may want to
// divert from that to have constraints on-the-fly using IRDL.
auto irdlOperations = dialect.getOps<irdl::OperationOp>();
if (!llvm::hasSingleElement(irdlOperations))
return emitOpError() << "expects IRDL to contain exactly one operation";
if (!dialect.getOps<irdl::TypeOp>().empty() ||
!dialect.getOps<irdl::AttributeOp>().empty()) {
return emitOpError() << "IRDL types and attributes are not yet supported";
}
return success();
}
} // namespace mlir::transform

View File

@@ -0,0 +1,25 @@
// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
%0 = transform.irdl.collect_matching in %arg0 : (!transform.any_op) -> (!transform.any_op){
^bb0(%arg1: !transform.any_op):
irdl.dialect @test {
irdl.operation @whatever {
%0 = irdl.is i32
%1 = irdl.is i64
%2 = irdl.any_of(%0, %1)
irdl.results(%2)
}
}
}
transform.debug.emit_remark_at %0, "matched" : !transform.any_op
transform.yield
}
// expected-remark @below {{matched}}
"test.whatever"() : () -> i32
"test.whatever"() : () -> f32
// expected-remark @below {{matched}}
"test.whatever"() : () -> i64
}