[mlir] multi-argument binding for top-level transform ops

`applyTransforms` now takes an optional mapping to be associated with
trailing block arguments of the top-level transform op, in addition to
the payload root. This allows for more advanced forms of communication
between C++ code and the transform dialect interpreter, in particular
supplying operations without having to re-match them during
interpretation.

Reviewed By: shabalin

Differential Revision: https://reviews.llvm.org/D142559
This commit is contained in:
Alex Zinenko
2023-01-25 16:53:25 +00:00
parent a51ad873bf
commit b9e40cde3b
13 changed files with 640 additions and 92 deletions

View File

@@ -109,13 +109,19 @@ A program transformation expressed using the Transform dialect can be
programmatically triggered by calling:
```c++
LogicalResult transform::applyTransforms(Operation *payloadRoot,
TransformOpInterface transform,
const TransformOptions &options);
LogicalResult transform::applyTransforms(
Operation *payloadRoot,
ArrayRef<ArrayRef<PointerUnion<Operation *, Attribute>> extraMappings,
TransformOpInterface transform,
const TransformOptions &options);
```
that applies the transformations specified by the top-level `transform` to
payload IR contained in `payloadRoot`.
payload IR contained in `payloadRoot`. The payload root operation will be
associated with the first argument of the entry block of the top-level transform
op. This block may have additional arguments, handles or parameters. They will
be associated with values provided as `extraMappings`. The call will report an
error and return if the wrong number of mappings is provided.
## Dialect Extension Mechanism

View File

@@ -42,6 +42,9 @@ private:
bool expensiveChecksEnabled = true;
};
using Param = Attribute;
using MappedValue = llvm::PointerUnion<Operation *, Param>;
/// Entry point to the Transform dialect infrastructure. Applies the
/// transformation specified by `transform` to payload IR contained in
/// `payloadRoot`. The `transform` operation may contain other operations that
@@ -50,6 +53,7 @@ private:
/// This function internally keeps track of the transformation state.
LogicalResult
applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
ArrayRef<ArrayRef<MappedValue>> extraMapping = {},
const TransformOptions &options = TransformOptions());
/// The state maintained across applications of various ops implementing the
@@ -85,7 +89,7 @@ applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
/// using `mapBlockArguments`.
class TransformState {
public:
using Param = Attribute;
using Param = transform::Param;
private:
/// Mapping between a Value in the transform IR and the corresponding set of
@@ -109,15 +113,23 @@ private:
ParamMapping params;
};
friend LogicalResult applyTransforms(Operation *payloadRoot,
TransformOpInterface transform,
const TransformOptions &options);
friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
ArrayRef<ArrayRef<MappedValue>>,
const TransformOptions &);
public:
/// Returns the op at which the transformation state is rooted. This is
/// typically helpful for transformations that apply globally.
Operation *getTopLevel() const;
/// Returns the number of extra mappings for the top-level operation.
size_t getNumTopLevelMappings() const { return topLevelMappedValues.size(); }
/// Returns the position-th extra mapping for the top-level operation.
ArrayRef<MappedValue> getTopLevelMapping(size_t position) const {
return topLevelMappedValues[position];
}
/// Returns the list of ops that the given transform IR value corresponds to.
/// This is helpful for transformations that apply to a particular handle.
ArrayRef<Operation *> getPayloadOps(Value value) const;
@@ -150,6 +162,8 @@ public:
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
return setPayloadOps(argument, operations);
}
LogicalResult mapBlockArgument(BlockArgument argument,
ArrayRef<MappedValue> values);
// Forward declarations to support limited visibility.
class RegionScope;
@@ -302,6 +316,7 @@ private:
/// which may or may not contain the region with transform ops. Additional
/// options can be provided through the trailing configuration object.
TransformState(Region *region, Operation *payloadRoot,
ArrayRef<ArrayRef<MappedValue>> extraMappings = {},
const TransformOptions &options = TransformOptions());
/// Returns the mappings frame for the reigon in which the value is defined.
@@ -403,6 +418,15 @@ private:
/// The top-level operation that contains all payload IR, typically a module.
Operation *topLevel;
/// Storage for extra mapped values (payload operations or parameters) to be
/// associated with additional entry block arguments of the top-level
/// transform operation. Each entry in `topLevelMappedValues` is a reference
/// to a contiguous block in `topLevelMappedValueStorage`.
// TODO: turn this into a proper named data structure, there are several more
// below.
SmallVector<ArrayRef<MappedValue>> topLevelMappedValues;
SmallVector<MappedValue> topLevelMappedValueStorage;
/// Additional options controlling the transformation state behavior.
TransformOptions options;

View File

@@ -26,6 +26,9 @@ class FailurePropagationModeAttr;
/// A builder function that populates the body of a SequenceOp.
using SequenceBodyBuilderFn = ::llvm::function_ref<void(
::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument)>;
using SequenceBodyBuilderArgsFn =
::llvm::function_ref<void(::mlir::OpBuilder &, ::mlir::Location,
::mlir::BlockArgument, ::mlir::ValueRange)>;
} // namespace transform
} // namespace mlir

View File

@@ -384,7 +384,8 @@ def SequenceOp : TransformDialectOp<"sequence",
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
OpAsmOpInterface, PossibleTopLevelTransformOpTrait,
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">,
AttrSizedOperandSegments]> {
let summary = "Contains a sequence of other transform ops to apply";
let description = [{
The transformations indicated by the sequence are applied in order of their
@@ -417,12 +418,14 @@ def SequenceOp : TransformDialectOp<"sequence",
}];
let arguments = (ins FailurePropagationMode:$failure_propagation_mode,
Optional<TransformHandleTypeInterface>:$root);
Optional<TransformHandleTypeInterface>:$root,
Variadic<Transform_AnyHandleOrParamType>:$extra_bindings);
let results = (outs Variadic<TransformHandleTypeInterface>:$results);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat =
"($root^ `:` type($root))? (`->` type($results)^)? `failures` `(` "
"custom<SequenceOpOperands>($root, type($root), $extra_bindings, type($extra_bindings))"
" (`->` type($results)^)? `failures` `(` "
"$failure_propagation_mode `)` attr-dict-with-keyword regions";
let builders = [
@@ -432,11 +435,25 @@ def SequenceOp : TransformDialectOp<"sequence",
"::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
"::mlir::Value":$root, "SequenceBodyBuilderFn":$bodyBuilder)>,
// Build a sequence without a root but a certain bbArg type.
// Build a sequence with a root and additional arguments.
OpBuilder<(ins
"::mlir::TypeRange":$resultTypes,
"::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
"::mlir::Type":$bbArgType, "SequenceBodyBuilderFn":$bodyBuilder)>
"::mlir::Value":$root, "::mlir::ValueRange":$extraBindings,
"SequenceBodyBuilderArgsFn":$bodyBuilder)>,
// Build a top-level sequence (no root).
OpBuilder<(ins
"::mlir::TypeRange":$resultTypes,
"::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
"::mlir::Type":$bbArgType, "SequenceBodyBuilderFn":$bodyBuilder)>,
// Build a top-level sequence (no root) with extra arguments.
OpBuilder<(ins
"::mlir::TypeRange":$resultTypes,
"::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
"::mlir::Type":$bbArgType, "::mlir::TypeRange":$extraBindingTypes,
"SequenceBodyBuilderArgsFn":$bodyBuilder)>
];
let extraClassDeclaration = [{

View File

@@ -27,10 +27,20 @@ using namespace mlir;
constexpr const Value transform::TransformState::kTopLevelValue;
transform::TransformState::TransformState(Region *region,
Operation *payloadRoot,
const TransformOptions &options)
transform::TransformState::TransformState(
Region *region, Operation *payloadRoot,
ArrayRef<ArrayRef<MappedValue>> extraMappings,
const TransformOptions &options)
: topLevel(payloadRoot), options(options) {
topLevelMappedValues.reserve(extraMappings.size());
for (ArrayRef<MappedValue> mapping : extraMappings) {
size_t start = topLevelMappedValueStorage.size();
llvm::append_range(topLevelMappedValueStorage, mapping);
topLevelMappedValues.push_back(
ArrayRef<MappedValue>(topLevelMappedValueStorage)
.slice(start, mapping.size()));
}
auto result = mappings.try_emplace(region);
assert(result.second && "the region scope is already present");
(void)result;
@@ -72,6 +82,38 @@ LogicalResult transform::TransformState::getHandlesForPayloadOp(
return success(found);
}
LogicalResult
transform::TransformState::mapBlockArgument(BlockArgument argument,
ArrayRef<MappedValue> values) {
if (argument.getType().isa<TransformHandleTypeInterface>()) {
SmallVector<Operation *> operations;
operations.reserve(values.size());
for (MappedValue value : values) {
if (auto *op = value.dyn_cast<Operation *>()) {
operations.push_back(op);
continue;
}
return emitError(argument.getLoc())
<< "wrong kind of value provided for top-level operation handle";
}
return setPayloadOps(argument, operations);
}
assert(argument.getType().isa<TransformParamTypeInterface>() &&
"unsupported kind of block argument");
SmallVector<Param> parameters;
parameters.reserve(values.size());
for (MappedValue value : values) {
if (auto attr = value.dyn_cast<Attribute>()) {
parameters.push_back(attr);
continue;
}
return emitError(argument.getLoc())
<< "wrong kind of value provided for top-level parameter";
}
return setParams(argument, parameters);
}
LogicalResult
transform::TransformState::setPayloadOps(Value value,
ArrayRef<Operation *> targets) {
@@ -522,12 +564,43 @@ void transform::detail::setApplyToOneResults(
LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
TransformState &state, Operation *op, Region &region) {
SmallVector<Operation *> targets;
if (op->getNumOperands() != 0)
SmallVector<SmallVector<MappedValue>> extraMappings;
if (op->getNumOperands() != 0) {
llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
else
targets.push_back(state.getTopLevel());
for (Value operand : op->getOperands().drop_front()) {
SmallVector<MappedValue> &mapped = extraMappings.emplace_back();
if (operand.getType().isa<TransformHandleTypeInterface>()) {
llvm::append_range(mapped, state.getPayloadOps(operand));
} else {
assert(operand.getType().isa<TransformParamTypeInterface>() &&
"unsupported kind of transform dialect value");
llvm::append_range(mapped, state.getParams(operand));
}
}
} else {
if (state.getNumTopLevelMappings() !=
region.front().getNumArguments() - 1) {
return emitError(op->getLoc())
<< "operation expects " << region.front().getNumArguments() - 1
<< " extra value bindings, but " << state.getNumTopLevelMappings()
<< " were provided to the interpreter";
}
return state.mapBlockArguments(region.front().getArgument(0), targets);
targets.push_back(state.getTopLevel());
for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
}
if (failed(state.mapBlockArguments(region.front().getArgument(0), targets)))
return failure();
for (BlockArgument argument : region.front().getArguments().drop_front()) {
if (failed(state.mapBlockArgument(
argument, extraMappings[argument.getArgNumber() - 1])))
return failure();
}
return success();
}
LogicalResult
@@ -547,19 +620,42 @@ transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
return op->emitOpError() << "expects a single-block region";
Block *body = &bodyRegion->front();
if (body->getNumArguments() != 1 ||
!body->getArgumentTypes()[0].isa<TransformHandleTypeInterface>()) {
if (body->getNumArguments() == 0) {
return op->emitOpError()
<< "expects the entry block to have one argument "
"of type implementing TransformHandleTypeInterface";
<< "expects the entry block to have at least one argument";
}
if (!body->getArgument(0).getType().isa<TransformHandleTypeInterface>()) {
return op->emitOpError()
<< "expects the first entry block argument to be of type "
"implementing TransformHandleTypeInterface";
}
BlockArgument arg = body->getArgument(0);
if (op->getNumOperands() != 0) {
if (arg.getType() != op->getOperand(0).getType()) {
return op->emitOpError()
<< "expects the type of the block argument to match "
"the type of the operand";
}
}
for (BlockArgument arg : body->getArguments().drop_front()) {
if (arg.getType()
.isa<TransformHandleTypeInterface, TransformParamTypeInterface>())
continue;
InFlightDiagnostic diag =
op->emitOpError()
<< "expects trailing entry block arguments to be of type implementing "
"TransformHandleTypeInterface or TransformParamTypeInterface";
diag.attachNote() << "argument #" << arg.getArgNumber() << " does not";
return diag;
}
if (auto *parent =
op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) {
if (op->getNumOperands() == 0) {
if (op->getNumOperands() != body->getNumArguments()) {
InFlightDiagnostic diag =
op->emitOpError()
<< "expects the root operation to be provided for a nested op";
<< "expects operands to be provided for a nested op";
diag.attachNote(parent->getLoc())
<< "nested in another possible top-level op";
return diag;
@@ -717,9 +813,11 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
// Entry point.
//===----------------------------------------------------------------------===//
LogicalResult transform::applyTransforms(Operation *payloadRoot,
TransformOpInterface transform,
const TransformOptions &options) {
LogicalResult
transform::applyTransforms(Operation *payloadRoot,
TransformOpInterface transform,
ArrayRef<ArrayRef<MappedValue>> extraMapping,
const TransformOptions &options) {
#ifndef NDEBUG
if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
transform->getNumOperands() != 0) {
@@ -730,7 +828,8 @@ LogicalResult transform::applyTransforms(Operation *payloadRoot,
}
#endif // NDEBUG
TransformState state(transform->getParentRegion(), payloadRoot, options);
TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
options);
return state.applyTransform(transform).checkAndReport();
}

View File

@@ -26,6 +26,16 @@
using namespace mlir;
static ParseResult parseSequenceOpOperands(
OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &root,
Type &rootType,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
SmallVectorImpl<Type> &extraBindingTypes);
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
Value root, Type rootType,
ValueRange extraBindings,
TypeRange extraBindingTypes);
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
@@ -654,6 +664,76 @@ transform::SequenceOp::apply(transform::TransformResults &results,
return DiagnosedSilenceableFailure::success();
}
static ParseResult parseSequenceOpOperands(
OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &root,
Type &rootType,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
SmallVectorImpl<Type> &extraBindingTypes) {
OpAsmParser::UnresolvedOperand rootOperand;
OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
if (!hasRoot.has_value()) {
root = std::nullopt;
return success();
}
if (failed(hasRoot.value()))
return failure();
root = rootOperand;
if (succeeded(parser.parseOptionalComma())) {
if (failed(parser.parseOperandList(extraBindings)))
return failure();
}
if (failed(parser.parseColon()))
return failure();
// The paren is truly optional.
(void)parser.parseOptionalLParen();
if (failed(parser.parseType(rootType))) {
return failure();
}
if (!extraBindings.empty()) {
if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
return failure();
}
if (extraBindingTypes.size() != extraBindings.size()) {
return parser.emitError(parser.getNameLoc(),
"expected types to be provided for all operands");
}
// The paren is truly optional.
(void)parser.parseOptionalRParen();
return success();
}
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
Value root, Type rootType,
ValueRange extraBindings,
TypeRange extraBindingTypes) {
if (!root)
return;
printer << root;
bool hasExtras = !extraBindings.empty();
if (hasExtras) {
printer << ", ";
printer.printOperands(extraBindings);
}
printer << " : ";
if (hasExtras)
printer << "(";
printer << rootType;
if (hasExtras) {
printer << ", ";
llvm::interleaveComma(extraBindingTypes, printer.getStream());
printer << ")";
}
}
/// Returns `true` if the given op operand may be consuming the handle value in
/// the Transform IR. That is, if it may have a Free effect on it.
static bool isValueUsePotentialConsumer(OpOperand &use) {
@@ -691,22 +771,22 @@ checkDoubleConsume(Value value,
}
LogicalResult transform::SequenceOp::verify() {
assert(getBodyBlock()->getNumArguments() == 1 &&
"the number of arguments must have been verified to be 1 by "
assert(getBodyBlock()->getNumArguments() >= 1 &&
"the number of arguments must have been verified to be more than 1 by "
"PossibleTopLevelTransformOpTrait");
BlockArgument arg = getBodyBlock()->getArgument(0);
if (getRoot()) {
if (arg.getType() != getRoot().getType()) {
return emitOpError() << "expects the type of the block argument to match "
"the type of the operand";
}
if (!getRoot() && !getExtraBindings().empty()) {
return emitOpError()
<< "does not expect extra operands when used as top-level";
}
// Check if the block argument has more than one consuming use.
if (failed(checkDoubleConsume(
arg, [this]() { return (emitOpError() << "block argument #0"); }))) {
return failure();
// Check if a block argument has more than one consuming use.
for (BlockArgument arg : getBodyBlock()->getArguments()) {
if (failed(checkDoubleConsume(arg, [this, arg]() {
return (emitOpError() << "block argument #" << arg.getArgNumber());
}))) {
return failure();
}
}
// Check properties of the nested operations they cannot check themselves.
@@ -740,26 +820,26 @@ LogicalResult transform::SequenceOp::verify() {
return success();
}
/// Appends to `effects` the memory effect instances on `target` with the same
/// resource and effect as the ones the operation `iface` having on `source`.
static void
remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target,
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
SmallVector<MemoryEffects::EffectInstance> nestedEffects;
iface.getEffectsOnValue(source, nestedEffects);
for (const auto &effect : nestedEffects)
effects.emplace_back(effect.getEffect(), target, effect.getResource());
}
void transform::SequenceOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
auto *mappingResource = TransformMappingResource::get();
effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource);
for (Value result : getResults()) {
effects.emplace_back(MemoryEffects::Allocate::get(), result,
mappingResource);
effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource);
}
onlyReadsHandle(getRoot(), effects);
onlyReadsHandle(getExtraBindings(), effects);
producesHandle(getResults(), effects);
if (!getRoot()) {
for (Operation &op : *getBodyBlock()) {
auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
if (!iface) {
// TODO: fill all possible effects; or require ops to actually implement
// the memory effect interface always
assert(false);
}
auto iface = cast<MemoryEffectOpInterface>(&op);
SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
iface.getEffects(effects);
}
@@ -769,24 +849,20 @@ void transform::SequenceOp::getEffects(
// Carry over all effects on the argument of the entry block as those on the
// operand, this is the same value just remapped.
for (Operation &op : *getBodyBlock()) {
auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
if (!iface) {
// TODO: fill all possible effects; or require ops to actually implement
// the memory effect interface always
assert(false);
}
auto iface = cast<MemoryEffectOpInterface>(&op);
SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects);
for (const auto &effect : nestedEffects)
effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource());
remapEffects(iface, getBodyBlock()->getArgument(0), getRoot(), effects);
for (auto [source, target] : llvm::zip(
getBodyBlock()->getArguments().drop_front(), getExtraBindings())) {
remapEffects(iface, source, target, effects);
}
}
}
OperandRange transform::SequenceOp::getSuccessorEntryOperands(
std::optional<unsigned> index) {
assert(index && *index == 0 && "unexpected region index");
if (getOperation()->getNumOperands() == 1)
if (getOperation()->getNumOperands() > 0)
return getOperation()->getOperands();
return OperandRange(getOperation()->operand_end(),
getOperation()->operand_end());
@@ -813,21 +889,51 @@ void transform::SequenceOp::getRegionInvocationBounds(
bounds.emplace_back(1, 1);
}
template <typename FnTy>
static void buildSequenceBody(OpBuilder &builder, OperationState &state,
Type bbArgType, TypeRange extraBindingTypes,
FnTy bodyBuilder) {
SmallVector<Type> types;
types.reserve(1 + extraBindingTypes.size());
types.push_back(bbArgType);
llvm::append_range(types, extraBindingTypes);
OpBuilder::InsertionGuard guard(builder);
Region *region = state.regions.back().get();
Block *bodyBlock = builder.createBlock(region, region->begin(),
extraBindingTypes, {state.location});
// Populate body.
builder.setInsertionPointToStart(bodyBlock);
if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
} else {
bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
bodyBlock->getArguments().drop_front());
}
}
void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
TypeRange resultTypes,
FailurePropagationMode failurePropagationMode,
Value root,
SequenceBodyBuilderFn bodyBuilder) {
build(builder, state, resultTypes, failurePropagationMode, root);
Region *region = state.regions.back().get();
build(builder, state, resultTypes, failurePropagationMode, root,
/*extraBindings=*/ValueRange());
Type bbArgType = root.getType();
OpBuilder::InsertionGuard guard(builder);
Block *bodyBlock = builder.createBlock(
region, region->begin(), TypeRange{bbArgType}, {state.location});
buildSequenceBody(builder, state, bbArgType,
/*extraBindingTypes=*/TypeRange(), bodyBuilder);
}
// Populate body.
builder.setInsertionPointToStart(bodyBlock);
bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
TypeRange resultTypes,
FailurePropagationMode failurePropagationMode,
Value root, ValueRange extraBindings,
SequenceBodyBuilderArgsFn bodyBuilder) {
build(builder, state, resultTypes, failurePropagationMode, root,
extraBindings);
buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
bodyBuilder);
}
void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
@@ -835,15 +941,20 @@ void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
FailurePropagationMode failurePropagationMode,
Type bbArgType,
SequenceBodyBuilderFn bodyBuilder) {
build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value());
Region *region = state.regions.back().get();
OpBuilder::InsertionGuard guard(builder);
Block *bodyBlock = builder.createBlock(
region, region->begin(), TypeRange{bbArgType}, {state.location});
build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
/*extraBindings=*/ValueRange());
buildSequenceBody(builder, state, bbArgType,
/*extraBindingTypes=*/TypeRange(), bodyBuilder);
}
// Populate body.
builder.setInsertionPointToStart(bodyBlock);
bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
TypeRange resultTypes,
FailurePropagationMode failurePropagationMode,
Type bbArgType, TypeRange extraBindingTypes,
SequenceBodyBuilderArgsFn bodyBuilder) {
build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
/*extraBindings=*/ValueRange());
buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
}
//===----------------------------------------------------------------------===//

View File

@@ -89,7 +89,9 @@ class ReplicateOp:
class SequenceOp:
def __init__(self, failure_propagation_mode, results: Sequence[Type],
target: Union[Operation, Value, Type]):
target: Union[Operation, Value, Type],
extra_bindings: Optional[Union[Sequence[Value], Sequence[Type],
Operation, OpView]] = None):
root = _get_op_result_or_value(target) if isinstance(
target, (Operation, Value)) else None
root_type = root.type if not isinstance(target, Type) else target
@@ -98,10 +100,25 @@ class SequenceOp:
IntegerType.get_signless(32), failure_propagation_mode._as_int())
else:
failure_propagation_mode = failure_propagation_mode
if extra_bindings is None:
extra_bindings = []
if isinstance(extra_bindings, (Operation, OpView)):
extra_bindings = _get_op_results_or_values(extra_bindings)
extra_binding_types = []
if len(extra_bindings) != 0:
if isinstance(extra_bindings[0], Type):
extra_binding_types = extra_bindings
extra_bindings = []
else:
extra_binding_types = [v.type for v in extra_bindings]
super().__init__(results_=results,
failure_propagation_mode=failure_propagation_mode_attr,
root=root)
self.regions[0].blocks.append(root_type)
root=root,
extra_bindings=extra_bindings)
self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
@property
def body(self) -> Block:
@@ -111,6 +128,10 @@ class SequenceOp:
def bodyTarget(self) -> Value:
return self.body.arguments[0]
@property
def bodyExtraArgs(self) -> BlockArgumentList:
return self.body.arguments[1:]
class WithPDLPatternsOp:

View File

@@ -0,0 +1,71 @@
// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-ops=func.func bind-second-extra-to-ops=func.return})' \
// RUN: --split-input-file --verify-diagnostics
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
transform.test_print_remark_at_operand %arg1, "first extra" : !transform.any_op
transform.test_print_remark_at_operand %arg2, "second extra" : !transform.any_op
}
// expected-remark @below {{first extra}}
func.func @foo() {
// expected-remark @below {{second extra}}
return
}
// expected-remark @below {{first extra}}
func.func @bar(%arg0: i1) {
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
// expected-remark @below {{second extra}}
return
^bb2:
// expected-remark @below {{second extra}}
return
}
// -----
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.param<i64>):
// expected-error @above {{wrong kind of value provided for top-level parameter}}
}
func.func @foo() {
return
}
// -----
// expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}}
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op):
}
// -----
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
transform.sequence %arg0, %arg1, %arg2 : !transform.any_op, !transform.any_op, !transform.any_op failures(propagate) {
^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
transform.test_print_remark_at_operand %arg4, "first extra" : !transform.any_op
transform.test_print_remark_at_operand %arg5, "second extra" : !transform.any_op
}
}
// expected-remark @below {{first extra}}
func.func @foo() {
// expected-remark @below {{second extra}}
return
}
// expected-remark @below {{first extra}}
func.func @bar(%arg0: i1) {
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
// expected-remark @below {{second extra}}
return
^bb2:
// expected-remark @below {{second extra}}
return
}

View File

@@ -0,0 +1,24 @@
// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-params=1,2,3 bind-second-extra-to-params=42,45})' \
// RUN: --split-input-file --verify-diagnostics
transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation, %arg1: !transform.param<i64>, %arg2: !transform.param<i64>):
// expected-remark @below {{1 : i64, 2 : i64, 3 : i64}}
transform.test_print_param %arg1 : !transform.param<i64>
// expected-remark @below {{42 : i64, 45 : i64}}
transform.test_print_param %arg2 : !transform.param<i64>
}
// -----
transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation, %arg1: !transform.any_op, %arg2: !transform.param<i64>):
// expected-error @above {{wrong kind of value provided for top-level operation handle}}
}
// -----
// expected-error @below {{operation expects 3 extra value bindings, but 2 were provided to the interpreter}}
transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation, %arg1: !transform.param<i64>, %arg2: !transform.param<i64>, %arg3: !transform.param<i64>):
}

View File

@@ -1,15 +1,22 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
// expected-error @below {{expects the entry block to have one argument of type implementing TransformHandleTypeInterface}}
// expected-error @below {{expects the entry block to have at least one argument}}
transform.sequence failures(propagate) {
}
// -----
// expected-error @below {{expects the first entry block argument to be of type implementing TransformHandleTypeInterface}}
transform.sequence failures(propagate) {
^bb0(%rag0: i64):
}
// -----
// expected-note @below {{nested in another possible top-level op}}
transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation):
// expected-error @below {{expects the root operation to be provided for a nested op}}
// expected-error @below {{expects operands to be provided for a nested op}}
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
}
@@ -17,6 +24,14 @@ transform.sequence failures(propagate) {
// -----
// expected-error @below {{'transform.sequence' op expects trailing entry block arguments to be of type implementing TransformHandleTypeInterface or TransformParamTypeInterface}}
// expected-note @below {{argument #1 does not}}
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op, %arg1: i64):
}
// -----
// expected-error @below {{expected children ops to implement TransformOpInterface}}
transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation):
@@ -46,10 +61,29 @@ transform.sequence failures(propagate) {
// -----
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
// expected-error @below {{expected types to be provided for all operands}}
transform.sequence %arg0, %arg1, %arg2 : (!transform.any_op, !transform.any_op) failures(propagate) {
^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
}
}
// -----
%0 = "test.generate_something"() : () -> !transform.any_op
// expected-error @below {{does not expect extra operands when used as top-level}}
"transform.sequence"(%0) ({
^bb0(%arg0: !transform.any_op):
"transform.yield"() : () -> ()
}) {failure_propagation_mode = 1 : i32, operand_segment_sizes = array<i32: 0, 1>} : (!transform.any_op) -> ()
// -----
// expected-note @below {{nested in another possible top-level op}}
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
// expected-error @below {{expects the root operation to be provided for a nested op}}
// expected-error @below {{expects operands to be provided for a nested op}}
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
}
@@ -190,7 +224,7 @@ transform.sequence failures(propagate) {
// -----
// expected-error @below {{expects the entry block to have one argument of type implementing TransformHandleTypeInterface}}
// expected-error @below {{expects the entry block to have at least one argument}}
transform.alternatives {
^bb0:
transform.yield

View File

@@ -50,6 +50,33 @@ transform.sequence failures(propagate) {
}
}
// CHECK: transform.sequence failures(propagate)
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
// CHECK: sequence %{{.*}}, %{{.*}}, %{{.*}} : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate)
transform.sequence %arg0, %arg1, %arg2 : !transform.any_op, !transform.any_op, !transform.any_op failures(propagate) {
^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
}
}
// CHECK: transform.sequence failures(propagate)
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
// CHECK: sequence %{{.*}}, %{{.*}}, %{{.*}} : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate)
transform.sequence %arg0, %arg1, %arg2 : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate) {
^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
}
}
// CHECK: transform.sequence failures(propagate)
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
// CHECK: sequence %{{.*}}, %{{.*}}, %{{.*}} : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate)
transform.sequence %arg0, %arg1, %arg2 : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate) {
^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
}
}
// CHECK: transform.sequence
// CHECK: foreach
transform.sequence failures(propagate) {

View File

@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
@@ -39,12 +40,72 @@ public:
return "apply transform dialect operations one by one";
}
ArrayRef<transform::MappedValue>
findOperationsByName(Operation *root, StringRef name,
SmallVectorImpl<transform::MappedValue> &storage) {
size_t start = storage.size();
root->walk([&](Operation *op) {
if (op->getName().getStringRef() == name) {
storage.push_back(op);
}
});
return ArrayRef(storage).drop_front(start);
}
ArrayRef<transform::MappedValue>
createParameterMapping(MLIRContext &context, ArrayRef<int> values,
SmallVectorImpl<transform::MappedValue> &storage) {
size_t start = storage.size();
llvm::append_range(storage, llvm::map_range(values, [&](int v) {
Builder b(&context);
return transform::MappedValue(b.getI64IntegerAttr(v));
}));
return ArrayRef(storage).drop_front(start);
}
void runOnOperation() override {
if (!bindFirstExtraToOps.empty() && !bindFirstExtraToParams.empty()) {
emitError(UnknownLoc::get(&getContext()))
<< "cannot bind the first extra top-level argument to both "
"operations and parameters";
return signalPassFailure();
}
if (!bindSecondExtraToOps.empty() && !bindSecondExtraToParams.empty()) {
emitError(UnknownLoc::get(&getContext()))
<< "cannot bind the second extra top-level argument to both "
"operations and parameters";
return signalPassFailure();
}
if ((!bindSecondExtraToOps.empty() || !bindSecondExtraToParams.empty()) &&
bindFirstExtraToOps.empty() && bindFirstExtraToParams.empty()) {
emitError(UnknownLoc::get(&getContext()))
<< "cannot bind the second extra top-level argument without binding "
"the first";
return signalPassFailure();
}
SmallVector<transform::MappedValue> extraMappingStorage;
SmallVector<ArrayRef<transform::MappedValue>> extraMapping;
if (!bindFirstExtraToOps.empty()) {
extraMapping.push_back(findOperationsByName(
getOperation(), bindFirstExtraToOps.getValue(), extraMappingStorage));
} else if (!bindFirstExtraToParams.empty()) {
extraMapping.push_back(createParameterMapping(
getContext(), bindFirstExtraToParams, extraMappingStorage));
}
if (!bindSecondExtraToOps.empty()) {
extraMapping.push_back(findOperationsByName(
getOperation(), bindSecondExtraToOps, extraMappingStorage));
} else if (!bindSecondExtraToParams.empty()) {
extraMapping.push_back(createParameterMapping(
getContext(), bindSecondExtraToParams, extraMappingStorage));
}
ModuleOp module = getOperation();
for (auto op :
module.getBody()->getOps<transform::TransformOpInterface>()) {
if (failed(transform::applyTransforms(
module, op,
module, op, extraMapping,
transform::TransformOptions().enableExpensiveChecks(
enableExpensiveChecks))))
return signalPassFailure();
@@ -55,6 +116,24 @@ public:
*this, "enable-expensive-checks", llvm::cl::init(false),
llvm::cl::desc("perform expensive checks to better report errors in the "
"transform IR")};
Option<std::string> bindFirstExtraToOps{
*this, "bind-first-extra-to-ops",
llvm::cl::desc("bind the first extra argument of the top-level op to "
"payload operations of the given kind")};
ListOption<int> bindFirstExtraToParams{
*this, "bind-first-extra-to-params",
llvm::cl::desc("bind the first extra argument of the top-level op to "
"the given integer parameters")};
Option<std::string> bindSecondExtraToOps{
*this, "bind-second-extra-to-ops",
llvm::cl::desc("bind the second extra argument of the top-level op to "
"payload operations of the given kind")};
ListOption<int> bindSecondExtraToParams{
*this, "bind-second-extra-to-params",
llvm::cl::desc("bind the second extra argument of the top-level op to "
"the given integer parameters")};
};
struct TestTransformDialectEraseSchedulePass

View File

@@ -69,6 +69,38 @@ def testNestedSequenceOp():
# CHECK: }
@run
def testSequenceOpWithExtras():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get(),
[transform.AnyOpType.get(),
transform.OperationType.get("foo.bar")])
with InsertionPoint(sequence.body):
transform.YieldOp()
# CHECK-LABEL: TEST: testSequenceOpWithExtras
# CHECK: transform.sequence failures(propagate)
# CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
@run
def testNestedSequenceOpWithExtras():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get(),
[transform.AnyOpType.get(),
transform.OperationType.get("foo.bar")])
with InsertionPoint(sequence.body):
nested = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
[], sequence.bodyTarget,
sequence.bodyExtraArgs)
with InsertionPoint(nested.body):
transform.YieldOp()
transform.YieldOp()
# CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
# CHECK: transform.sequence failures(propagate)
# CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
# CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
@run
def testTransformPDLOps():
withPdl = transform.WithPDLPatternsOp(pdl.OperationType.get())