mirror of
https://github.com/intel/llvm.git
synced 2026-01-31 07:27:33 +08:00
[mlir] multi-argument binding for top-level transform ops
`applyTransforms` now takes an optional mapping to be associated with trailing block arguments of the top-level transform op, in addition to the payload root. This allows for more advanced forms of communication between C++ code and the transform dialect interpreter, in particular supplying operations without having to re-match them during interpretation. Reviewed By: shabalin Differential Revision: https://reviews.llvm.org/D142559
This commit is contained in:
@@ -109,13 +109,19 @@ A program transformation expressed using the Transform dialect can be
|
||||
programmatically triggered by calling:
|
||||
|
||||
```c++
|
||||
LogicalResult transform::applyTransforms(Operation *payloadRoot,
|
||||
TransformOpInterface transform,
|
||||
const TransformOptions &options);
|
||||
LogicalResult transform::applyTransforms(
|
||||
Operation *payloadRoot,
|
||||
ArrayRef<ArrayRef<PointerUnion<Operation *, Attribute>> extraMappings,
|
||||
TransformOpInterface transform,
|
||||
const TransformOptions &options);
|
||||
```
|
||||
|
||||
that applies the transformations specified by the top-level `transform` to
|
||||
payload IR contained in `payloadRoot`.
|
||||
payload IR contained in `payloadRoot`. The payload root operation will be
|
||||
associated with the first argument of the entry block of the top-level transform
|
||||
op. This block may have additional arguments, handles or parameters. They will
|
||||
be associated with values provided as `extraMappings`. The call will report an
|
||||
error and return if the wrong number of mappings is provided.
|
||||
|
||||
## Dialect Extension Mechanism
|
||||
|
||||
|
||||
@@ -42,6 +42,9 @@ private:
|
||||
bool expensiveChecksEnabled = true;
|
||||
};
|
||||
|
||||
using Param = Attribute;
|
||||
using MappedValue = llvm::PointerUnion<Operation *, Param>;
|
||||
|
||||
/// Entry point to the Transform dialect infrastructure. Applies the
|
||||
/// transformation specified by `transform` to payload IR contained in
|
||||
/// `payloadRoot`. The `transform` operation may contain other operations that
|
||||
@@ -50,6 +53,7 @@ private:
|
||||
/// This function internally keeps track of the transformation state.
|
||||
LogicalResult
|
||||
applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
|
||||
ArrayRef<ArrayRef<MappedValue>> extraMapping = {},
|
||||
const TransformOptions &options = TransformOptions());
|
||||
|
||||
/// The state maintained across applications of various ops implementing the
|
||||
@@ -85,7 +89,7 @@ applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
|
||||
/// using `mapBlockArguments`.
|
||||
class TransformState {
|
||||
public:
|
||||
using Param = Attribute;
|
||||
using Param = transform::Param;
|
||||
|
||||
private:
|
||||
/// Mapping between a Value in the transform IR and the corresponding set of
|
||||
@@ -109,15 +113,23 @@ private:
|
||||
ParamMapping params;
|
||||
};
|
||||
|
||||
friend LogicalResult applyTransforms(Operation *payloadRoot,
|
||||
TransformOpInterface transform,
|
||||
const TransformOptions &options);
|
||||
friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
|
||||
ArrayRef<ArrayRef<MappedValue>>,
|
||||
const TransformOptions &);
|
||||
|
||||
public:
|
||||
/// Returns the op at which the transformation state is rooted. This is
|
||||
/// typically helpful for transformations that apply globally.
|
||||
Operation *getTopLevel() const;
|
||||
|
||||
/// Returns the number of extra mappings for the top-level operation.
|
||||
size_t getNumTopLevelMappings() const { return topLevelMappedValues.size(); }
|
||||
|
||||
/// Returns the position-th extra mapping for the top-level operation.
|
||||
ArrayRef<MappedValue> getTopLevelMapping(size_t position) const {
|
||||
return topLevelMappedValues[position];
|
||||
}
|
||||
|
||||
/// Returns the list of ops that the given transform IR value corresponds to.
|
||||
/// This is helpful for transformations that apply to a particular handle.
|
||||
ArrayRef<Operation *> getPayloadOps(Value value) const;
|
||||
@@ -150,6 +162,8 @@ public:
|
||||
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
|
||||
return setPayloadOps(argument, operations);
|
||||
}
|
||||
LogicalResult mapBlockArgument(BlockArgument argument,
|
||||
ArrayRef<MappedValue> values);
|
||||
|
||||
// Forward declarations to support limited visibility.
|
||||
class RegionScope;
|
||||
@@ -302,6 +316,7 @@ private:
|
||||
/// which may or may not contain the region with transform ops. Additional
|
||||
/// options can be provided through the trailing configuration object.
|
||||
TransformState(Region *region, Operation *payloadRoot,
|
||||
ArrayRef<ArrayRef<MappedValue>> extraMappings = {},
|
||||
const TransformOptions &options = TransformOptions());
|
||||
|
||||
/// Returns the mappings frame for the reigon in which the value is defined.
|
||||
@@ -403,6 +418,15 @@ private:
|
||||
/// The top-level operation that contains all payload IR, typically a module.
|
||||
Operation *topLevel;
|
||||
|
||||
/// Storage for extra mapped values (payload operations or parameters) to be
|
||||
/// associated with additional entry block arguments of the top-level
|
||||
/// transform operation. Each entry in `topLevelMappedValues` is a reference
|
||||
/// to a contiguous block in `topLevelMappedValueStorage`.
|
||||
// TODO: turn this into a proper named data structure, there are several more
|
||||
// below.
|
||||
SmallVector<ArrayRef<MappedValue>> topLevelMappedValues;
|
||||
SmallVector<MappedValue> topLevelMappedValueStorage;
|
||||
|
||||
/// Additional options controlling the transformation state behavior.
|
||||
TransformOptions options;
|
||||
|
||||
|
||||
@@ -26,6 +26,9 @@ class FailurePropagationModeAttr;
|
||||
/// A builder function that populates the body of a SequenceOp.
|
||||
using SequenceBodyBuilderFn = ::llvm::function_ref<void(
|
||||
::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument)>;
|
||||
using SequenceBodyBuilderArgsFn =
|
||||
::llvm::function_ref<void(::mlir::OpBuilder &, ::mlir::Location,
|
||||
::mlir::BlockArgument, ::mlir::ValueRange)>;
|
||||
} // namespace transform
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -384,7 +384,8 @@ def SequenceOp : TransformDialectOp<"sequence",
|
||||
DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
OpAsmOpInterface, PossibleTopLevelTransformOpTrait,
|
||||
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
|
||||
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">,
|
||||
AttrSizedOperandSegments]> {
|
||||
let summary = "Contains a sequence of other transform ops to apply";
|
||||
let description = [{
|
||||
The transformations indicated by the sequence are applied in order of their
|
||||
@@ -417,12 +418,14 @@ def SequenceOp : TransformDialectOp<"sequence",
|
||||
}];
|
||||
|
||||
let arguments = (ins FailurePropagationMode:$failure_propagation_mode,
|
||||
Optional<TransformHandleTypeInterface>:$root);
|
||||
Optional<TransformHandleTypeInterface>:$root,
|
||||
Variadic<Transform_AnyHandleOrParamType>:$extra_bindings);
|
||||
let results = (outs Variadic<TransformHandleTypeInterface>:$results);
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let assemblyFormat =
|
||||
"($root^ `:` type($root))? (`->` type($results)^)? `failures` `(` "
|
||||
"custom<SequenceOpOperands>($root, type($root), $extra_bindings, type($extra_bindings))"
|
||||
" (`->` type($results)^)? `failures` `(` "
|
||||
"$failure_propagation_mode `)` attr-dict-with-keyword regions";
|
||||
|
||||
let builders = [
|
||||
@@ -432,11 +435,25 @@ def SequenceOp : TransformDialectOp<"sequence",
|
||||
"::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
|
||||
"::mlir::Value":$root, "SequenceBodyBuilderFn":$bodyBuilder)>,
|
||||
|
||||
// Build a sequence without a root but a certain bbArg type.
|
||||
// Build a sequence with a root and additional arguments.
|
||||
OpBuilder<(ins
|
||||
"::mlir::TypeRange":$resultTypes,
|
||||
"::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
|
||||
"::mlir::Type":$bbArgType, "SequenceBodyBuilderFn":$bodyBuilder)>
|
||||
"::mlir::Value":$root, "::mlir::ValueRange":$extraBindings,
|
||||
"SequenceBodyBuilderArgsFn":$bodyBuilder)>,
|
||||
|
||||
// Build a top-level sequence (no root).
|
||||
OpBuilder<(ins
|
||||
"::mlir::TypeRange":$resultTypes,
|
||||
"::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
|
||||
"::mlir::Type":$bbArgType, "SequenceBodyBuilderFn":$bodyBuilder)>,
|
||||
|
||||
// Build a top-level sequence (no root) with extra arguments.
|
||||
OpBuilder<(ins
|
||||
"::mlir::TypeRange":$resultTypes,
|
||||
"::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
|
||||
"::mlir::Type":$bbArgType, "::mlir::TypeRange":$extraBindingTypes,
|
||||
"SequenceBodyBuilderArgsFn":$bodyBuilder)>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
||||
@@ -27,10 +27,20 @@ using namespace mlir;
|
||||
|
||||
constexpr const Value transform::TransformState::kTopLevelValue;
|
||||
|
||||
transform::TransformState::TransformState(Region *region,
|
||||
Operation *payloadRoot,
|
||||
const TransformOptions &options)
|
||||
transform::TransformState::TransformState(
|
||||
Region *region, Operation *payloadRoot,
|
||||
ArrayRef<ArrayRef<MappedValue>> extraMappings,
|
||||
const TransformOptions &options)
|
||||
: topLevel(payloadRoot), options(options) {
|
||||
topLevelMappedValues.reserve(extraMappings.size());
|
||||
for (ArrayRef<MappedValue> mapping : extraMappings) {
|
||||
size_t start = topLevelMappedValueStorage.size();
|
||||
llvm::append_range(topLevelMappedValueStorage, mapping);
|
||||
topLevelMappedValues.push_back(
|
||||
ArrayRef<MappedValue>(topLevelMappedValueStorage)
|
||||
.slice(start, mapping.size()));
|
||||
}
|
||||
|
||||
auto result = mappings.try_emplace(region);
|
||||
assert(result.second && "the region scope is already present");
|
||||
(void)result;
|
||||
@@ -72,6 +82,38 @@ LogicalResult transform::TransformState::getHandlesForPayloadOp(
|
||||
return success(found);
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
transform::TransformState::mapBlockArgument(BlockArgument argument,
|
||||
ArrayRef<MappedValue> values) {
|
||||
if (argument.getType().isa<TransformHandleTypeInterface>()) {
|
||||
SmallVector<Operation *> operations;
|
||||
operations.reserve(values.size());
|
||||
for (MappedValue value : values) {
|
||||
if (auto *op = value.dyn_cast<Operation *>()) {
|
||||
operations.push_back(op);
|
||||
continue;
|
||||
}
|
||||
return emitError(argument.getLoc())
|
||||
<< "wrong kind of value provided for top-level operation handle";
|
||||
}
|
||||
return setPayloadOps(argument, operations);
|
||||
}
|
||||
|
||||
assert(argument.getType().isa<TransformParamTypeInterface>() &&
|
||||
"unsupported kind of block argument");
|
||||
SmallVector<Param> parameters;
|
||||
parameters.reserve(values.size());
|
||||
for (MappedValue value : values) {
|
||||
if (auto attr = value.dyn_cast<Attribute>()) {
|
||||
parameters.push_back(attr);
|
||||
continue;
|
||||
}
|
||||
return emitError(argument.getLoc())
|
||||
<< "wrong kind of value provided for top-level parameter";
|
||||
}
|
||||
return setParams(argument, parameters);
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
transform::TransformState::setPayloadOps(Value value,
|
||||
ArrayRef<Operation *> targets) {
|
||||
@@ -522,12 +564,43 @@ void transform::detail::setApplyToOneResults(
|
||||
LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
|
||||
TransformState &state, Operation *op, Region ®ion) {
|
||||
SmallVector<Operation *> targets;
|
||||
if (op->getNumOperands() != 0)
|
||||
SmallVector<SmallVector<MappedValue>> extraMappings;
|
||||
if (op->getNumOperands() != 0) {
|
||||
llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
|
||||
else
|
||||
targets.push_back(state.getTopLevel());
|
||||
for (Value operand : op->getOperands().drop_front()) {
|
||||
SmallVector<MappedValue> &mapped = extraMappings.emplace_back();
|
||||
if (operand.getType().isa<TransformHandleTypeInterface>()) {
|
||||
llvm::append_range(mapped, state.getPayloadOps(operand));
|
||||
} else {
|
||||
assert(operand.getType().isa<TransformParamTypeInterface>() &&
|
||||
"unsupported kind of transform dialect value");
|
||||
llvm::append_range(mapped, state.getParams(operand));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (state.getNumTopLevelMappings() !=
|
||||
region.front().getNumArguments() - 1) {
|
||||
return emitError(op->getLoc())
|
||||
<< "operation expects " << region.front().getNumArguments() - 1
|
||||
<< " extra value bindings, but " << state.getNumTopLevelMappings()
|
||||
<< " were provided to the interpreter";
|
||||
}
|
||||
|
||||
return state.mapBlockArguments(region.front().getArgument(0), targets);
|
||||
targets.push_back(state.getTopLevel());
|
||||
for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
|
||||
extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
|
||||
}
|
||||
|
||||
if (failed(state.mapBlockArguments(region.front().getArgument(0), targets)))
|
||||
return failure();
|
||||
|
||||
for (BlockArgument argument : region.front().getArguments().drop_front()) {
|
||||
if (failed(state.mapBlockArgument(
|
||||
argument, extraMappings[argument.getArgNumber() - 1])))
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
@@ -547,19 +620,42 @@ transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
|
||||
return op->emitOpError() << "expects a single-block region";
|
||||
|
||||
Block *body = &bodyRegion->front();
|
||||
if (body->getNumArguments() != 1 ||
|
||||
!body->getArgumentTypes()[0].isa<TransformHandleTypeInterface>()) {
|
||||
if (body->getNumArguments() == 0) {
|
||||
return op->emitOpError()
|
||||
<< "expects the entry block to have one argument "
|
||||
"of type implementing TransformHandleTypeInterface";
|
||||
<< "expects the entry block to have at least one argument";
|
||||
}
|
||||
if (!body->getArgument(0).getType().isa<TransformHandleTypeInterface>()) {
|
||||
return op->emitOpError()
|
||||
<< "expects the first entry block argument to be of type "
|
||||
"implementing TransformHandleTypeInterface";
|
||||
}
|
||||
BlockArgument arg = body->getArgument(0);
|
||||
if (op->getNumOperands() != 0) {
|
||||
if (arg.getType() != op->getOperand(0).getType()) {
|
||||
return op->emitOpError()
|
||||
<< "expects the type of the block argument to match "
|
||||
"the type of the operand";
|
||||
}
|
||||
}
|
||||
for (BlockArgument arg : body->getArguments().drop_front()) {
|
||||
if (arg.getType()
|
||||
.isa<TransformHandleTypeInterface, TransformParamTypeInterface>())
|
||||
continue;
|
||||
|
||||
InFlightDiagnostic diag =
|
||||
op->emitOpError()
|
||||
<< "expects trailing entry block arguments to be of type implementing "
|
||||
"TransformHandleTypeInterface or TransformParamTypeInterface";
|
||||
diag.attachNote() << "argument #" << arg.getArgNumber() << " does not";
|
||||
return diag;
|
||||
}
|
||||
|
||||
if (auto *parent =
|
||||
op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) {
|
||||
if (op->getNumOperands() == 0) {
|
||||
if (op->getNumOperands() != body->getNumArguments()) {
|
||||
InFlightDiagnostic diag =
|
||||
op->emitOpError()
|
||||
<< "expects the root operation to be provided for a nested op";
|
||||
<< "expects operands to be provided for a nested op";
|
||||
diag.attachNote(parent->getLoc())
|
||||
<< "nested in another possible top-level op";
|
||||
return diag;
|
||||
@@ -717,9 +813,11 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
|
||||
// Entry point.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult transform::applyTransforms(Operation *payloadRoot,
|
||||
TransformOpInterface transform,
|
||||
const TransformOptions &options) {
|
||||
LogicalResult
|
||||
transform::applyTransforms(Operation *payloadRoot,
|
||||
TransformOpInterface transform,
|
||||
ArrayRef<ArrayRef<MappedValue>> extraMapping,
|
||||
const TransformOptions &options) {
|
||||
#ifndef NDEBUG
|
||||
if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
|
||||
transform->getNumOperands() != 0) {
|
||||
@@ -730,7 +828,8 @@ LogicalResult transform::applyTransforms(Operation *payloadRoot,
|
||||
}
|
||||
#endif // NDEBUG
|
||||
|
||||
TransformState state(transform->getParentRegion(), payloadRoot, options);
|
||||
TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
|
||||
options);
|
||||
return state.applyTransform(transform).checkAndReport();
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,16 @@
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
static ParseResult parseSequenceOpOperands(
|
||||
OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &root,
|
||||
Type &rootType,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
|
||||
SmallVectorImpl<Type> &extraBindingTypes);
|
||||
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
|
||||
Value root, Type rootType,
|
||||
ValueRange extraBindings,
|
||||
TypeRange extraBindingTypes);
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
|
||||
|
||||
@@ -654,6 +664,76 @@ transform::SequenceOp::apply(transform::TransformResults &results,
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
static ParseResult parseSequenceOpOperands(
|
||||
OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &root,
|
||||
Type &rootType,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
|
||||
SmallVectorImpl<Type> &extraBindingTypes) {
|
||||
OpAsmParser::UnresolvedOperand rootOperand;
|
||||
OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
|
||||
if (!hasRoot.has_value()) {
|
||||
root = std::nullopt;
|
||||
return success();
|
||||
}
|
||||
if (failed(hasRoot.value()))
|
||||
return failure();
|
||||
root = rootOperand;
|
||||
|
||||
if (succeeded(parser.parseOptionalComma())) {
|
||||
if (failed(parser.parseOperandList(extraBindings)))
|
||||
return failure();
|
||||
}
|
||||
if (failed(parser.parseColon()))
|
||||
return failure();
|
||||
|
||||
// The paren is truly optional.
|
||||
(void)parser.parseOptionalLParen();
|
||||
|
||||
if (failed(parser.parseType(rootType))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (!extraBindings.empty()) {
|
||||
if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (extraBindingTypes.size() != extraBindings.size()) {
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"expected types to be provided for all operands");
|
||||
}
|
||||
|
||||
// The paren is truly optional.
|
||||
(void)parser.parseOptionalRParen();
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
|
||||
Value root, Type rootType,
|
||||
ValueRange extraBindings,
|
||||
TypeRange extraBindingTypes) {
|
||||
if (!root)
|
||||
return;
|
||||
|
||||
printer << root;
|
||||
bool hasExtras = !extraBindings.empty();
|
||||
if (hasExtras) {
|
||||
printer << ", ";
|
||||
printer.printOperands(extraBindings);
|
||||
}
|
||||
|
||||
printer << " : ";
|
||||
if (hasExtras)
|
||||
printer << "(";
|
||||
|
||||
printer << rootType;
|
||||
if (hasExtras) {
|
||||
printer << ", ";
|
||||
llvm::interleaveComma(extraBindingTypes, printer.getStream());
|
||||
printer << ")";
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if the given op operand may be consuming the handle value in
|
||||
/// the Transform IR. That is, if it may have a Free effect on it.
|
||||
static bool isValueUsePotentialConsumer(OpOperand &use) {
|
||||
@@ -691,22 +771,22 @@ checkDoubleConsume(Value value,
|
||||
}
|
||||
|
||||
LogicalResult transform::SequenceOp::verify() {
|
||||
assert(getBodyBlock()->getNumArguments() == 1 &&
|
||||
"the number of arguments must have been verified to be 1 by "
|
||||
assert(getBodyBlock()->getNumArguments() >= 1 &&
|
||||
"the number of arguments must have been verified to be more than 1 by "
|
||||
"PossibleTopLevelTransformOpTrait");
|
||||
|
||||
BlockArgument arg = getBodyBlock()->getArgument(0);
|
||||
if (getRoot()) {
|
||||
if (arg.getType() != getRoot().getType()) {
|
||||
return emitOpError() << "expects the type of the block argument to match "
|
||||
"the type of the operand";
|
||||
}
|
||||
if (!getRoot() && !getExtraBindings().empty()) {
|
||||
return emitOpError()
|
||||
<< "does not expect extra operands when used as top-level";
|
||||
}
|
||||
|
||||
// Check if the block argument has more than one consuming use.
|
||||
if (failed(checkDoubleConsume(
|
||||
arg, [this]() { return (emitOpError() << "block argument #0"); }))) {
|
||||
return failure();
|
||||
// Check if a block argument has more than one consuming use.
|
||||
for (BlockArgument arg : getBodyBlock()->getArguments()) {
|
||||
if (failed(checkDoubleConsume(arg, [this, arg]() {
|
||||
return (emitOpError() << "block argument #" << arg.getArgNumber());
|
||||
}))) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
// Check properties of the nested operations they cannot check themselves.
|
||||
@@ -740,26 +820,26 @@ LogicalResult transform::SequenceOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Appends to `effects` the memory effect instances on `target` with the same
|
||||
/// resource and effect as the ones the operation `iface` having on `source`.
|
||||
static void
|
||||
remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target,
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
SmallVector<MemoryEffects::EffectInstance> nestedEffects;
|
||||
iface.getEffectsOnValue(source, nestedEffects);
|
||||
for (const auto &effect : nestedEffects)
|
||||
effects.emplace_back(effect.getEffect(), target, effect.getResource());
|
||||
}
|
||||
|
||||
void transform::SequenceOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
auto *mappingResource = TransformMappingResource::get();
|
||||
effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource);
|
||||
|
||||
for (Value result : getResults()) {
|
||||
effects.emplace_back(MemoryEffects::Allocate::get(), result,
|
||||
mappingResource);
|
||||
effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource);
|
||||
}
|
||||
onlyReadsHandle(getRoot(), effects);
|
||||
onlyReadsHandle(getExtraBindings(), effects);
|
||||
producesHandle(getResults(), effects);
|
||||
|
||||
if (!getRoot()) {
|
||||
for (Operation &op : *getBodyBlock()) {
|
||||
auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
|
||||
if (!iface) {
|
||||
// TODO: fill all possible effects; or require ops to actually implement
|
||||
// the memory effect interface always
|
||||
assert(false);
|
||||
}
|
||||
|
||||
auto iface = cast<MemoryEffectOpInterface>(&op);
|
||||
SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
|
||||
iface.getEffects(effects);
|
||||
}
|
||||
@@ -769,24 +849,20 @@ void transform::SequenceOp::getEffects(
|
||||
// Carry over all effects on the argument of the entry block as those on the
|
||||
// operand, this is the same value just remapped.
|
||||
for (Operation &op : *getBodyBlock()) {
|
||||
auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
|
||||
if (!iface) {
|
||||
// TODO: fill all possible effects; or require ops to actually implement
|
||||
// the memory effect interface always
|
||||
assert(false);
|
||||
}
|
||||
auto iface = cast<MemoryEffectOpInterface>(&op);
|
||||
|
||||
SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
|
||||
iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects);
|
||||
for (const auto &effect : nestedEffects)
|
||||
effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource());
|
||||
remapEffects(iface, getBodyBlock()->getArgument(0), getRoot(), effects);
|
||||
for (auto [source, target] : llvm::zip(
|
||||
getBodyBlock()->getArguments().drop_front(), getExtraBindings())) {
|
||||
remapEffects(iface, source, target, effects);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
OperandRange transform::SequenceOp::getSuccessorEntryOperands(
|
||||
std::optional<unsigned> index) {
|
||||
assert(index && *index == 0 && "unexpected region index");
|
||||
if (getOperation()->getNumOperands() == 1)
|
||||
if (getOperation()->getNumOperands() > 0)
|
||||
return getOperation()->getOperands();
|
||||
return OperandRange(getOperation()->operand_end(),
|
||||
getOperation()->operand_end());
|
||||
@@ -813,21 +889,51 @@ void transform::SequenceOp::getRegionInvocationBounds(
|
||||
bounds.emplace_back(1, 1);
|
||||
}
|
||||
|
||||
template <typename FnTy>
|
||||
static void buildSequenceBody(OpBuilder &builder, OperationState &state,
|
||||
Type bbArgType, TypeRange extraBindingTypes,
|
||||
FnTy bodyBuilder) {
|
||||
SmallVector<Type> types;
|
||||
types.reserve(1 + extraBindingTypes.size());
|
||||
types.push_back(bbArgType);
|
||||
llvm::append_range(types, extraBindingTypes);
|
||||
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
Region *region = state.regions.back().get();
|
||||
Block *bodyBlock = builder.createBlock(region, region->begin(),
|
||||
extraBindingTypes, {state.location});
|
||||
|
||||
// Populate body.
|
||||
builder.setInsertionPointToStart(bodyBlock);
|
||||
if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
|
||||
bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
|
||||
} else {
|
||||
bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
|
||||
bodyBlock->getArguments().drop_front());
|
||||
}
|
||||
}
|
||||
|
||||
void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
|
||||
TypeRange resultTypes,
|
||||
FailurePropagationMode failurePropagationMode,
|
||||
Value root,
|
||||
SequenceBodyBuilderFn bodyBuilder) {
|
||||
build(builder, state, resultTypes, failurePropagationMode, root);
|
||||
Region *region = state.regions.back().get();
|
||||
build(builder, state, resultTypes, failurePropagationMode, root,
|
||||
/*extraBindings=*/ValueRange());
|
||||
Type bbArgType = root.getType();
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
Block *bodyBlock = builder.createBlock(
|
||||
region, region->begin(), TypeRange{bbArgType}, {state.location});
|
||||
buildSequenceBody(builder, state, bbArgType,
|
||||
/*extraBindingTypes=*/TypeRange(), bodyBuilder);
|
||||
}
|
||||
|
||||
// Populate body.
|
||||
builder.setInsertionPointToStart(bodyBlock);
|
||||
bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
|
||||
void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
|
||||
TypeRange resultTypes,
|
||||
FailurePropagationMode failurePropagationMode,
|
||||
Value root, ValueRange extraBindings,
|
||||
SequenceBodyBuilderArgsFn bodyBuilder) {
|
||||
build(builder, state, resultTypes, failurePropagationMode, root,
|
||||
extraBindings);
|
||||
buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
|
||||
bodyBuilder);
|
||||
}
|
||||
|
||||
void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
|
||||
@@ -835,15 +941,20 @@ void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
|
||||
FailurePropagationMode failurePropagationMode,
|
||||
Type bbArgType,
|
||||
SequenceBodyBuilderFn bodyBuilder) {
|
||||
build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value());
|
||||
Region *region = state.regions.back().get();
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
Block *bodyBlock = builder.createBlock(
|
||||
region, region->begin(), TypeRange{bbArgType}, {state.location});
|
||||
build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
|
||||
/*extraBindings=*/ValueRange());
|
||||
buildSequenceBody(builder, state, bbArgType,
|
||||
/*extraBindingTypes=*/TypeRange(), bodyBuilder);
|
||||
}
|
||||
|
||||
// Populate body.
|
||||
builder.setInsertionPointToStart(bodyBlock);
|
||||
bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
|
||||
void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
|
||||
TypeRange resultTypes,
|
||||
FailurePropagationMode failurePropagationMode,
|
||||
Type bbArgType, TypeRange extraBindingTypes,
|
||||
SequenceBodyBuilderArgsFn bodyBuilder) {
|
||||
build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
|
||||
/*extraBindings=*/ValueRange());
|
||||
buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -89,7 +89,9 @@ class ReplicateOp:
|
||||
class SequenceOp:
|
||||
|
||||
def __init__(self, failure_propagation_mode, results: Sequence[Type],
|
||||
target: Union[Operation, Value, Type]):
|
||||
target: Union[Operation, Value, Type],
|
||||
extra_bindings: Optional[Union[Sequence[Value], Sequence[Type],
|
||||
Operation, OpView]] = None):
|
||||
root = _get_op_result_or_value(target) if isinstance(
|
||||
target, (Operation, Value)) else None
|
||||
root_type = root.type if not isinstance(target, Type) else target
|
||||
@@ -98,10 +100,25 @@ class SequenceOp:
|
||||
IntegerType.get_signless(32), failure_propagation_mode._as_int())
|
||||
else:
|
||||
failure_propagation_mode = failure_propagation_mode
|
||||
|
||||
if extra_bindings is None:
|
||||
extra_bindings = []
|
||||
if isinstance(extra_bindings, (Operation, OpView)):
|
||||
extra_bindings = _get_op_results_or_values(extra_bindings)
|
||||
|
||||
extra_binding_types = []
|
||||
if len(extra_bindings) != 0:
|
||||
if isinstance(extra_bindings[0], Type):
|
||||
extra_binding_types = extra_bindings
|
||||
extra_bindings = []
|
||||
else:
|
||||
extra_binding_types = [v.type for v in extra_bindings]
|
||||
|
||||
super().__init__(results_=results,
|
||||
failure_propagation_mode=failure_propagation_mode_attr,
|
||||
root=root)
|
||||
self.regions[0].blocks.append(root_type)
|
||||
root=root,
|
||||
extra_bindings=extra_bindings)
|
||||
self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
|
||||
|
||||
@property
|
||||
def body(self) -> Block:
|
||||
@@ -111,6 +128,10 @@ class SequenceOp:
|
||||
def bodyTarget(self) -> Value:
|
||||
return self.body.arguments[0]
|
||||
|
||||
@property
|
||||
def bodyExtraArgs(self) -> BlockArgumentList:
|
||||
return self.body.arguments[1:]
|
||||
|
||||
|
||||
class WithPDLPatternsOp:
|
||||
|
||||
|
||||
71
mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir
Normal file
71
mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir
Normal file
@@ -0,0 +1,71 @@
|
||||
// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-ops=func.func bind-second-extra-to-ops=func.return})' \
|
||||
// RUN: --split-input-file --verify-diagnostics
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
|
||||
transform.test_print_remark_at_operand %arg1, "first extra" : !transform.any_op
|
||||
transform.test_print_remark_at_operand %arg2, "second extra" : !transform.any_op
|
||||
}
|
||||
|
||||
// expected-remark @below {{first extra}}
|
||||
func.func @foo() {
|
||||
// expected-remark @below {{second extra}}
|
||||
return
|
||||
}
|
||||
|
||||
// expected-remark @below {{first extra}}
|
||||
func.func @bar(%arg0: i1) {
|
||||
cf.cond_br %arg0, ^bb1, ^bb2
|
||||
^bb1:
|
||||
// expected-remark @below {{second extra}}
|
||||
return
|
||||
^bb2:
|
||||
// expected-remark @below {{second extra}}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.param<i64>):
|
||||
// expected-error @above {{wrong kind of value provided for top-level parameter}}
|
||||
}
|
||||
|
||||
func.func @foo() {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}}
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op):
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
|
||||
transform.sequence %arg0, %arg1, %arg2 : !transform.any_op, !transform.any_op, !transform.any_op failures(propagate) {
|
||||
^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
|
||||
transform.test_print_remark_at_operand %arg4, "first extra" : !transform.any_op
|
||||
transform.test_print_remark_at_operand %arg5, "second extra" : !transform.any_op
|
||||
}
|
||||
}
|
||||
|
||||
// expected-remark @below {{first extra}}
|
||||
func.func @foo() {
|
||||
// expected-remark @below {{second extra}}
|
||||
return
|
||||
}
|
||||
|
||||
// expected-remark @below {{first extra}}
|
||||
func.func @bar(%arg0: i1) {
|
||||
cf.cond_br %arg0, ^bb1, ^bb2
|
||||
^bb1:
|
||||
// expected-remark @below {{second extra}}
|
||||
return
|
||||
^bb2:
|
||||
// expected-remark @below {{second extra}}
|
||||
return
|
||||
}
|
||||
24
mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir
Normal file
24
mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir
Normal file
@@ -0,0 +1,24 @@
|
||||
// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-params=1,2,3 bind-second-extra-to-params=42,45})' \
|
||||
// RUN: --split-input-file --verify-diagnostics
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg0: !pdl.operation, %arg1: !transform.param<i64>, %arg2: !transform.param<i64>):
|
||||
// expected-remark @below {{1 : i64, 2 : i64, 3 : i64}}
|
||||
transform.test_print_param %arg1 : !transform.param<i64>
|
||||
// expected-remark @below {{42 : i64, 45 : i64}}
|
||||
transform.test_print_param %arg2 : !transform.param<i64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg0: !pdl.operation, %arg1: !transform.any_op, %arg2: !transform.param<i64>):
|
||||
// expected-error @above {{wrong kind of value provided for top-level operation handle}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @below {{operation expects 3 extra value bindings, but 2 were provided to the interpreter}}
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg0: !pdl.operation, %arg1: !transform.param<i64>, %arg2: !transform.param<i64>, %arg3: !transform.param<i64>):
|
||||
}
|
||||
@@ -1,15 +1,22 @@
|
||||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
|
||||
|
||||
// expected-error @below {{expects the entry block to have one argument of type implementing TransformHandleTypeInterface}}
|
||||
// expected-error @below {{expects the entry block to have at least one argument}}
|
||||
transform.sequence failures(propagate) {
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @below {{expects the first entry block argument to be of type implementing TransformHandleTypeInterface}}
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%rag0: i64):
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-note @below {{nested in another possible top-level op}}
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg0: !pdl.operation):
|
||||
// expected-error @below {{expects the root operation to be provided for a nested op}}
|
||||
// expected-error @below {{expects operands to be provided for a nested op}}
|
||||
transform.sequence failures(propagate) {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
}
|
||||
@@ -17,6 +24,14 @@ transform.sequence failures(propagate) {
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @below {{'transform.sequence' op expects trailing entry block arguments to be of type implementing TransformHandleTypeInterface or TransformParamTypeInterface}}
|
||||
// expected-note @below {{argument #1 does not}}
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg0: !transform.any_op, %arg1: i64):
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @below {{expected children ops to implement TransformOpInterface}}
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg0: !pdl.operation):
|
||||
@@ -46,10 +61,29 @@ transform.sequence failures(propagate) {
|
||||
|
||||
// -----
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
|
||||
// expected-error @below {{expected types to be provided for all operands}}
|
||||
transform.sequence %arg0, %arg1, %arg2 : (!transform.any_op, !transform.any_op) failures(propagate) {
|
||||
^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
%0 = "test.generate_something"() : () -> !transform.any_op
|
||||
// expected-error @below {{does not expect extra operands when used as top-level}}
|
||||
"transform.sequence"(%0) ({
|
||||
^bb0(%arg0: !transform.any_op):
|
||||
"transform.yield"() : () -> ()
|
||||
}) {failure_propagation_mode = 1 : i32, operand_segment_sizes = array<i32: 0, 1>} : (!transform.any_op) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-note @below {{nested in another possible top-level op}}
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !pdl.operation):
|
||||
// expected-error @below {{expects the root operation to be provided for a nested op}}
|
||||
// expected-error @below {{expects operands to be provided for a nested op}}
|
||||
transform.sequence failures(propagate) {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
}
|
||||
@@ -190,7 +224,7 @@ transform.sequence failures(propagate) {
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @below {{expects the entry block to have one argument of type implementing TransformHandleTypeInterface}}
|
||||
// expected-error @below {{expects the entry block to have at least one argument}}
|
||||
transform.alternatives {
|
||||
^bb0:
|
||||
transform.yield
|
||||
|
||||
@@ -50,6 +50,33 @@ transform.sequence failures(propagate) {
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK: transform.sequence failures(propagate)
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
|
||||
// CHECK: sequence %{{.*}}, %{{.*}}, %{{.*}} : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate)
|
||||
transform.sequence %arg0, %arg1, %arg2 : !transform.any_op, !transform.any_op, !transform.any_op failures(propagate) {
|
||||
^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK: transform.sequence failures(propagate)
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
|
||||
// CHECK: sequence %{{.*}}, %{{.*}}, %{{.*}} : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate)
|
||||
transform.sequence %arg0, %arg1, %arg2 : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate) {
|
||||
^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK: transform.sequence failures(propagate)
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
|
||||
// CHECK: sequence %{{.*}}, %{{.*}}, %{{.*}} : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate)
|
||||
transform.sequence %arg0, %arg1, %arg2 : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate) {
|
||||
^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK: transform.sequence
|
||||
// CHECK: foreach
|
||||
transform.sequence failures(propagate) {
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
@@ -39,12 +40,72 @@ public:
|
||||
return "apply transform dialect operations one by one";
|
||||
}
|
||||
|
||||
ArrayRef<transform::MappedValue>
|
||||
findOperationsByName(Operation *root, StringRef name,
|
||||
SmallVectorImpl<transform::MappedValue> &storage) {
|
||||
size_t start = storage.size();
|
||||
root->walk([&](Operation *op) {
|
||||
if (op->getName().getStringRef() == name) {
|
||||
storage.push_back(op);
|
||||
}
|
||||
});
|
||||
return ArrayRef(storage).drop_front(start);
|
||||
}
|
||||
|
||||
ArrayRef<transform::MappedValue>
|
||||
createParameterMapping(MLIRContext &context, ArrayRef<int> values,
|
||||
SmallVectorImpl<transform::MappedValue> &storage) {
|
||||
size_t start = storage.size();
|
||||
llvm::append_range(storage, llvm::map_range(values, [&](int v) {
|
||||
Builder b(&context);
|
||||
return transform::MappedValue(b.getI64IntegerAttr(v));
|
||||
}));
|
||||
return ArrayRef(storage).drop_front(start);
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
if (!bindFirstExtraToOps.empty() && !bindFirstExtraToParams.empty()) {
|
||||
emitError(UnknownLoc::get(&getContext()))
|
||||
<< "cannot bind the first extra top-level argument to both "
|
||||
"operations and parameters";
|
||||
return signalPassFailure();
|
||||
}
|
||||
if (!bindSecondExtraToOps.empty() && !bindSecondExtraToParams.empty()) {
|
||||
emitError(UnknownLoc::get(&getContext()))
|
||||
<< "cannot bind the second extra top-level argument to both "
|
||||
"operations and parameters";
|
||||
return signalPassFailure();
|
||||
}
|
||||
if ((!bindSecondExtraToOps.empty() || !bindSecondExtraToParams.empty()) &&
|
||||
bindFirstExtraToOps.empty() && bindFirstExtraToParams.empty()) {
|
||||
emitError(UnknownLoc::get(&getContext()))
|
||||
<< "cannot bind the second extra top-level argument without binding "
|
||||
"the first";
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
SmallVector<transform::MappedValue> extraMappingStorage;
|
||||
SmallVector<ArrayRef<transform::MappedValue>> extraMapping;
|
||||
if (!bindFirstExtraToOps.empty()) {
|
||||
extraMapping.push_back(findOperationsByName(
|
||||
getOperation(), bindFirstExtraToOps.getValue(), extraMappingStorage));
|
||||
} else if (!bindFirstExtraToParams.empty()) {
|
||||
extraMapping.push_back(createParameterMapping(
|
||||
getContext(), bindFirstExtraToParams, extraMappingStorage));
|
||||
}
|
||||
if (!bindSecondExtraToOps.empty()) {
|
||||
extraMapping.push_back(findOperationsByName(
|
||||
getOperation(), bindSecondExtraToOps, extraMappingStorage));
|
||||
} else if (!bindSecondExtraToParams.empty()) {
|
||||
extraMapping.push_back(createParameterMapping(
|
||||
getContext(), bindSecondExtraToParams, extraMappingStorage));
|
||||
}
|
||||
|
||||
ModuleOp module = getOperation();
|
||||
for (auto op :
|
||||
module.getBody()->getOps<transform::TransformOpInterface>()) {
|
||||
if (failed(transform::applyTransforms(
|
||||
module, op,
|
||||
module, op, extraMapping,
|
||||
transform::TransformOptions().enableExpensiveChecks(
|
||||
enableExpensiveChecks))))
|
||||
return signalPassFailure();
|
||||
@@ -55,6 +116,24 @@ public:
|
||||
*this, "enable-expensive-checks", llvm::cl::init(false),
|
||||
llvm::cl::desc("perform expensive checks to better report errors in the "
|
||||
"transform IR")};
|
||||
|
||||
Option<std::string> bindFirstExtraToOps{
|
||||
*this, "bind-first-extra-to-ops",
|
||||
llvm::cl::desc("bind the first extra argument of the top-level op to "
|
||||
"payload operations of the given kind")};
|
||||
ListOption<int> bindFirstExtraToParams{
|
||||
*this, "bind-first-extra-to-params",
|
||||
llvm::cl::desc("bind the first extra argument of the top-level op to "
|
||||
"the given integer parameters")};
|
||||
|
||||
Option<std::string> bindSecondExtraToOps{
|
||||
*this, "bind-second-extra-to-ops",
|
||||
llvm::cl::desc("bind the second extra argument of the top-level op to "
|
||||
"payload operations of the given kind")};
|
||||
ListOption<int> bindSecondExtraToParams{
|
||||
*this, "bind-second-extra-to-params",
|
||||
llvm::cl::desc("bind the second extra argument of the top-level op to "
|
||||
"the given integer parameters")};
|
||||
};
|
||||
|
||||
struct TestTransformDialectEraseSchedulePass
|
||||
|
||||
@@ -69,6 +69,38 @@ def testNestedSequenceOp():
|
||||
# CHECK: }
|
||||
|
||||
|
||||
@run
|
||||
def testSequenceOpWithExtras():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get(),
|
||||
[transform.AnyOpType.get(),
|
||||
transform.OperationType.get("foo.bar")])
|
||||
with InsertionPoint(sequence.body):
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: testSequenceOpWithExtras
|
||||
# CHECK: transform.sequence failures(propagate)
|
||||
# CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
|
||||
|
||||
|
||||
@run
|
||||
def testNestedSequenceOpWithExtras():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get(),
|
||||
[transform.AnyOpType.get(),
|
||||
transform.OperationType.get("foo.bar")])
|
||||
with InsertionPoint(sequence.body):
|
||||
nested = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
|
||||
[], sequence.bodyTarget,
|
||||
sequence.bodyExtraArgs)
|
||||
with InsertionPoint(nested.body):
|
||||
transform.YieldOp()
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
|
||||
# CHECK: transform.sequence failures(propagate)
|
||||
# CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
|
||||
# CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
|
||||
|
||||
|
||||
@run
|
||||
def testTransformPDLOps():
|
||||
withPdl = transform.WithPDLPatternsOp(pdl.OperationType.get())
|
||||
|
||||
Reference in New Issue
Block a user