From a702628843befd93e7ac074c299254ff90b9de63 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 3 Feb 2023 14:00:33 +0000 Subject: [PATCH] [mlir] add support for transform dialect value handles Introduce support for the third kind of values in the transform dialect: value handles. Similarly to operation handles, value handles are pointing to a set of values in the payload IR. This enables transformation to be targeted at specific values, such as individual results of a multi-result payload operation without indirecting through the producing op or block arguments that previously could not be easily addressed. This is expected to support a broad class of memory-oriented transformations such as selective bufferization, buffer assignment, and memory transfer management. Value handles are functionally similar to operation handles and require similar implementation logic. The most important change concerns the handle invalidation mechanism where operation and value handles can affect each other. This patch includes two cleanups that make it easier to introduce value handles: - `RaggedArray` structure that encapsulates the SmallVector of ArrayRef backed by flat SmallVector logic, frequently used in the transform interfaces implementation; - rewrite the tests that associated payload handles with an integer value `reinterpret_cast`ed as a pointer, which were a frequent source of confusion and crashes when adding more debugging facilities that can inspect the payload. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D143385 --- mlir/docs/Dialects/Transform.md | 179 +++-- .../Transform/IR/TransformInterfaces.h | 201 ++++-- .../Transform/IR/TransformInterfaces.td | 22 +- .../Dialect/Transform/IR/TransformTypes.td | 9 + .../Transforms/TransformInterpreterPassBase.h | 2 +- .../Dialect/Transform/Utils/RaggedArray.h | 92 +++ .../Dialect/Transform/IR/TransformDialect.cpp | 14 +- .../Transform/IR/TransformInterfaces.cpp | 662 ++++++++++++++---- .../Dialect/Transform/IR/TransformTypes.cpp | 10 + .../TransformInterpreterPassBase.cpp | 2 +- .../Dialect/Linalg/transform-op-match.mlir | 8 +- .../Transform/check-use-after-free.mlir | 20 +- .../Dialect/Transform/expensive-checks.mlir | 232 +++++- .../Transform/multi-arg-top-level-ops.mlir | 11 + .../Transform/multi-arg-top-level-values.mlir | 45 ++ mlir/test/Dialect/Transform/ops-invalid.mlir | 32 +- .../Transform/test-dialect-injection.mlir | 8 +- .../Dialect/Transform/test-interpreter.mlir | 115 ++- .../Transform/transform-state-extension.mlir | 12 + .../TestTransformDialectExtension.cpp | 111 ++- .../TestTransformDialectExtension.td | 92 ++- .../TestTransformDialectInterpreter.cpp | 115 +-- 22 files changed, 1606 insertions(+), 388 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Transform/Utils/RaggedArray.h create mode 100644 mlir/test/Dialect/Transform/multi-arg-top-level-values.mlir diff --git a/mlir/docs/Dialects/Transform.md b/mlir/docs/Dialects/Transform.md index e33604c57778..123d66136d39 100644 --- a/mlir/docs/Dialects/Transform.md +++ b/mlir/docs/Dialects/Transform.md @@ -25,20 +25,19 @@ of the IR using a different portion of the IR. It refers to the IR being transformed as payload IR, and to the IR guiding the transformation as transform IR. -The main use case for this dialect is orchestrating fine-grain -transformations on individual operations or sets thereof. For example, it -may involve finding loop-like operations with specific properties (e.g., -large size) in the payload IR, applying loop tiling to those and only those -operations, and then applying loop unrolling to the inner loops produced -by the previous transformations. As such, it is not intended as a -replacement for the pass infrastructure, nor for the pattern rewriting -infrastructure. In the most common case, the transform IR will be processed -and applied to the payload IR by a pass. Transformations expressed by the -transform dialect may be implemented using the pattern infrastructure or any -other relevant MLIR component. +The main use case for this dialect is orchestrating fine-grain transformations +on individual IR objects (operations or values) or sets thereof. For example, it +may involve finding loop-like operations with specific properties (e.g., large +size) in the payload IR, applying loop tiling to those and only those +operations, and then applying loop unrolling to the inner loops produced by the +previous transformations. As such, it is not intended as a replacement for the +pass infrastructure, nor for the pattern rewriting infrastructure. In the most +common case, the transform IR will be processed and applied to the payload IR by +a pass. Transformations expressed by the transform dialect may be implemented +using the pattern infrastructure or any other relevant MLIR component. The following IR gives a rough idea of what the operations in this dialect -may look like: +may look like without using actually existing operations: ```mlir %0 = transform.loop.find { size > 42 } : !transform.interface @@ -46,57 +45,70 @@ may look like: %2:2 = transform.loop.tile %0 tile_sizes(1, 4, %1) : (!transform.interface) -> (!transform.op, !transform.op) +%3 = transform.get_op_result [0] %2#0 : !transform.any_value +transform.assign_to_fast_memory %3 transform.loop.unroll %1#1 : !transform.op ``` -The values used in the Transform dialect may correspond to either: +The values used in the Transform dialect may correspond to: * sets of operations in the payload IR; + * sets of values in the payload IR; + * sets of parameters (attributes) known at the execution time of the transform dialect. -The former kind of values is also referred to as *handles*. In the example -above, `%0` corresponds to the set of loops found in the payload IR that -satisfy the condition, and `%2` correspond to groups of outer and inner -loops, respectively, produced by the tiling transformation, whereas `%1` -corresponds to a list of tile sizes selected for each of the operations -that `%0` corresponds to. +The former two kinds of values are also referred to as operation and value +*handles*, respectively. In the example above, `%0` corresponds to the set of +loops found in the payload IR that satisfy the condition, and `%2` correspond to +groups of outer and inner loops, respectively, produced by the tiling +transformation. `%3` corresponds to a set of values that are produced by the +outer loops after tiling. `%1` corresponds to a list of tile sizes selected for +each of the operations that `%0` corresponds to. -A transform handle such as `%0` may be associated with multiple payload +An operation handle such as `%0` may be associated with multiple payload operations. This is conceptually a set of operations and no assumptions should be made about the order of ops unless specified otherwise by the operation. -Operations may take as operands and produce an arbitrary combination of values -representing handles and parameters. Most Transform IR ops support operand -values that are mapped to multiple operations. They usually apply the respective -transformation for every mapped op ("batched execution"). Deviations from this -convention are described in the documentation of Transform IR ops. +Similarly, a value handle such as `%3` may be associated with a set of payload +IR values. Transform dialect operations may take as operands and produce an +arbitrary combination of values representing handles and parameters. Most +Transform IR ops support operand values that are mapped to multiple payload +objects. They usually apply the respective transformation for every mapped +object ("batched execution"). Deviations from this convention are described in +the documentation of Transform IR ops. -The transform IR values have transform IR types, which implement either -[TransformHandleTypeInterface](Transform.md#transformhandletypeinterface-transformhandletypeinterface) -or -[TransformParamTypeInterface](Transform.md##transformparamtypeinterface-transformparamtypeinterface). -The former interface verifiers properties of payload IR operations associated -with the value that are known to the transform dialect, for example, all -associated payload operations implement a "TileableOp" interface, or have a -specific "loop" kind. Similarly, the latter interface verifies properties of -attributes associated with the parameter value. These properties are used to -statically indicate pre- and post-conditions of a transformation connected to a -Transform dialect operation. The conditions are verified when attributes or -payload IR operations are first associated with a transform handle. By -convention, Transform dialect operations are expected to indicate narrow -preconditions for their operands by enforcing operand type constraints in the -their definitions and verifiers. On the contrary, operations are expected to -have few constraints on their results. Specific instances of a transform -operation can then be created with a more restricted result type than the -constraint in the operation (e.g., the "find" operation only constrains the -result type to be a transform IR type while its concrete instance can have a -type with stricter constraints such as implementing the "tilable" interface). -The verification will then happen at transform execution time. This approach -allows one to capture payload IR operation properties in the transform IR -without resorting to excessive use of type casts or coupling dialect extensions -between themselves. It is a trade-off between verbosity/complexity and static -hardening, which can be revised in the future. +The transform IR values have transform IR types, which should implement exactly one of: + + * [TransformHandleTypeInterface](Transform.md#transformhandletypeinterface-transformhandletypeinterface), + + * [TransformValueHandleTypeInterface](Transform.md#transformvaluehandletypeinterface-transformvaluehandletypeinterface), + + * [TransformParamTypeInterface](Transform.md##transformparamtypeinterface-transformparamtypeinterface). + +The goal of these type interfaces, beyond providing a common base for accepted +types, is to verify the properties of the associated objects. For example, a +handle type interface implementation may check whether all associated payload IR +operations implement the "TileableOp" interface or have a specific "loop" kind. +Similarly, a value handle type interface implementation may check if the +associated payload IR values are block arguments or have a specific type, or a +parameter type interface may check whether the associated attributes contain +non-negative integer values. These properties are used to statically indicate + pre- and post-conditions of a transformation connected to a Transform dialect +operation. The conditions are verified when payload objects operations are first +associated with a transform handle. By convention, Transform dialect operations +are expected to indicate narrow preconditions for their operands by enforcing +operand type constraints in the their definitions and verifiers. On the +contrary, operations are expected to have few constraints on their results. +Specific instances of a transform operation can then be created with a more +restricted result type than the constraint in the operation (e.g., the "find" +operation only constrains the result type to be a transform IR type while its +concrete instance can have a type with stricter constraints such as implementing +the "tilable" interface). The verification will then happen at transform +execution time. This approach allows one to capture payload IR operation +properties in the transform IR without resorting to excessive use of type casts +or coupling dialect extensions between themselves. It is a trade-off between +verbosity/complexity and static hardening, which can be revised in the future. Overall, Transform IR ops are expected to be contained in a single top-level op. Such top-level ops specify how to apply the transformations described @@ -111,7 +123,7 @@ programmatically triggered by calling: ```c++ LogicalResult transform::applyTransforms( Operation *payloadRoot, - ArrayRef> extraMappings, + const RaggedArray &extraMappings, TransformOpInterface transform, const TransformOptions &options); ``` @@ -163,7 +175,7 @@ Similarly to operations, additional types can be injected into the dialect using the same extension mechanism. The types must: * Implement exactly one of `TransformHandleTypeInterface`, - `TransformParamTypeInterface`. + `TransformValueHandleTypeInterface`, `TransformParamTypeInterface`. ## Side Effects @@ -255,18 +267,57 @@ operation lists. ## Handle Invalidation -The execution model of the transform dialect allows a payload IR operation -to be associated with _multiple_ handles as well as nested payload IR -operations to be associated with different handles. A transform IR operation -that consumes a handle automatically _invalidates_ all the other handles -associated with the same payload IR operations, or with any of their -descendants, as the consumed handle. Note that the _entire_ handle is -invalidated, even if some of the payload IR operations associated with it -or their ancestors were not associated with the consumed handle. Any use of -the invalidated handle results in undefined behavior since the payload IR -operations associated with it are likely to have been mutated or erased. The -mere fact of the handle being invalidated does _not_ trigger undefined -behavior, only its appearance as an operand does. +The execution model of the transform dialect allows a payload IR operation to be +associated with _multiple_ handles as well as nested payload IR operations to be +associated with different handles. Similarly, a payload IR value may be +associated with multiple transform IR value handles. When a transform IR +operation consumes a handle, it usually indicates that the corresponding payload +IR object was destroyed and should no longer be referenced. Transform IR handles +that _may_ be pointing to an erased payload IR object are _invalidated_. The +mere presence of an invalidated handle in the transform IR is not a problem, but +_using_ it results in undefined behavior. Invalidated handles can be thought of +as dangling pointers. Note that the _entire_ handle is invalidated, even if some +of the payload IR objects associated with it remain live. + +The following handle invalidation rules apply. + + * When an operation handle is consumed, are invalidated: + + - operation handles associated with one of the payload operations that the + consumed handle is associated with; + + - operation handles associated with one of the operations _nested_ in the + payload operations described above; + + - value handles associated with any result of any operation described above; + + - value handles associated with any argument of a block contained in a + region attached to any operation described above. + + * When a value handle is consumed, are invalidated: + + - operation handles associated with payload operations that produce as + result any value associated with the consumed handle (when the associated + is an operation result); + + - operation handles associated with payload operations _nested_ in the + payload operations described above; + + - operation handles associated with payload operations (recursively) + _contained_ in the block that defines as argument any value associated + with the consumed handle (when the associated value is a block argument); + note that the adjacent blocks are not affected; + + - value handles associated with any result of any operation described above, + including all results of the operation defining as result the value + associated with the consumed handle; + + - value handles associated with any argument of a block contained in a + region attached to any operation described above. + +More intuitively, consuming a handle invalidates any handle that may be pointing +to an object defined or contained in the payload IR subtree rooted at the +closest operation or block. The Transform dialect infrastructure has the capability of checking whether the transform IR op operand is invalidated before applying the diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index bae5d4568a55..dc9612c9679b 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -11,8 +11,8 @@ #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h" +#include "mlir/Dialect/Transform/Utils/RaggedArray.h" #include "mlir/IR/OpDefinition.h" - #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LogicalResult.h" @@ -45,7 +45,7 @@ private: }; using Param = Attribute; -using MappedValue = llvm::PointerUnion; +using MappedValue = llvm::PointerUnion; /// Entry point to the Transform dialect infrastructure. Applies the /// transformation specified by `transform` to payload IR contained in @@ -55,7 +55,7 @@ using MappedValue = llvm::PointerUnion; /// This function internally keeps track of the transformation state. LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, - ArrayRef> extraMapping = {}, + const RaggedArray &extraMapping = {}, const TransformOptions &options = TransformOptions()); /// The state maintained across applications of various ops implementing the @@ -107,16 +107,22 @@ private: /// parameters. using ParamMapping = DenseMap>; + /// Mapping between a Value in the transform IR and the corrsponding list of + /// values in the payload IR. Also works for reverse mappings. + using ValueMapping = DenseMap>; + /// The bidirectional mappings between transform IR values and payload IR /// operations, and the mapping between transform IR values and parameters. struct Mappings { TransformOpMapping direct; TransformOpReverseMapping reverse; ParamMapping params; + ValueMapping values; + ValueMapping reverseValues; }; friend LogicalResult applyTransforms(Operation *, TransformOpInterface, - ArrayRef>, + const RaggedArray &, const TransformOptions &); public: @@ -140,11 +146,21 @@ public: /// corresponds to. ArrayRef getParams(Value value) const; + /// Returns the list of payload IR values that the given transform IR value + /// corresponds to. + ArrayRef getPayloadValues(Value handleValue) const; + /// Populates `handles` with all handles pointing to the given Payload IR op. /// Returns success if such handles exist, failure otherwise. LogicalResult getHandlesForPayloadOp(Operation *op, SmallVectorImpl &handles) const; + /// Populates `handles` with all handles pointing to the given payload IR + /// value. Returns success if such handles exist, failure otherwise. + LogicalResult + getHandlesForPayloadValue(Value payloadValue, + SmallVectorImpl &handles) const; + /// Applies the transformation specified by the given transform op and updates /// the state accordingly. DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform); @@ -319,10 +335,10 @@ private: /// which may or may not contain the region with transform ops. Additional /// options can be provided through the trailing configuration object. TransformState(Region *region, Operation *payloadRoot, - ArrayRef> extraMappings = {}, + const RaggedArray &extraMappings = {}, const TransformOptions &options = TransformOptions()); - /// Returns the mappings frame for the reigon in which the value is defined. + /// Returns the mappings frame for the region in which the value is defined. const Mappings &getMapping(Value value) const { return const_cast(this)->getMapping(value); } @@ -344,10 +360,6 @@ private: return it->second; } - /// Removes the mapping between the given payload IR operation and the given - /// transform IR value. - void dropReverseMapping(Mappings &mappings, Operation *op, Value value); - /// Sets the payload IR ops associated with the given transform IR value /// (handle). A payload op may be associated multiple handles as long as /// at most one of them gets consumed by further transformations. @@ -367,40 +379,111 @@ private: /// by side effects. Practically, a transformation consuming a handle means /// that the associated payload operation may no longer exist. /// + /// Similarly, operation handles may be invalidate and should not be used + /// after a transform that consumed a value handle pointing to a payload value + /// defined by the operation as either block argument or op result. For + /// example, in the following sequence, the last transform operation rewrites + /// the callee to not return a specified result: + /// + /// %0 = transform.find_call "myfunc" + /// %1 = transform.find_results_of_calling "myfunc" + /// transform.drop_call_result_from_signature %1[0] + /// + /// which requires the call operations to be recreated. Therefore, the handle + /// %0 becomes associated with a dangling pointer and should not be used. + /// /// Returns failure if the payload does not satisfy the conditions associated /// with the type of the handle value. The value is expected to have a type /// implementing TransformHandleTypeInterface. LogicalResult setPayloadOps(Value value, ArrayRef targets); + /// Sets the payload IR values association with the given transform IR value + /// (handle). A payload value may be associated with multiple handles as long + /// as at most one of them is consumed by further transformations. For + /// example, a hypothetical "get results of calls to function with the given + /// name" transform may be performed twice in a row producing handles pointing + /// to the same values: + /// + /// %0 = transform.find_results_of_calling "myfunc" + /// %1 = transform.find_results_of_calling "myfunc" + /// + /// which is valid by itself. However, calling a hypothetical "erase value + /// producer" transform on both handles: + /// + /// transform.erase_value_produce %0 + /// transform.erase_value_produce %1 + /// + /// is invalid provided the transformation "consumes" the handle as expressed + /// by side effects (which themselves reflect the semantics of the transform + /// erasing the producer and making the handle dangling). Practically, a + /// transformation consuming a handle means the associated payload value may + /// no longer exist. + /// + /// Similarly, value handles are invalidated and should not be used after a + /// transform that consumed an operation handle pointing to the payload IR + /// operation defining the values associated the value handle, as either block + /// arguments or op results, or any ancestor operation. For example, + /// + /// %0 = transform.find_call "myfunc" + /// %1 = transform.find_results_of_calling "myfunc" + /// transform.rewrite_and_rename %0 { new_name = "func" } + /// + /// makes %1 unusable after the last transformation if it consumes %0. When an + /// operation handle is consumed, it usually indicates that the operation was + /// destroyed or heavily modified, meaning that the values it defines may no + /// longer exist. + /// + /// Returns failure if the payload values do not satisfy the conditions + /// associated with the type of the handle value. The value is expected to + /// have a type implementing TransformValueHandleTypeInterface. + LogicalResult setPayloadValues(Value handle, ValueRange payloadValues); + /// Sets the parameters associated with the given transform IR value. Returns /// failure if the parameters do not satisfy the conditions associated with /// the type of the value. The value is expected to have a type implementing /// TransformParamTypeInterface. LogicalResult setParams(Value value, ArrayRef params); - /// Forgets the payload IR ops associated with the given transform IR value. - void removePayloadOps(Value value); + /// Forgets the payload IR ops associated with the given transform IR value, + /// as well as any association between value handles and the results of said + /// payload IR op. + void forgetMapping(Value opHandle, ValueRange origOpFlatResults); + + void forgetValueMapping(Value valueHandle, + ArrayRef payloadOperations); /// Updates the payload IR ops associated with the given transform IR value. /// The callback function is called once per associated operation and is /// expected to return the modified operation or nullptr. In the latter case, /// the corresponding operation is no longer associated with the transform IR - /// value. + /// value. Value handles associated with the results of the operation are + /// also updated to be associated with the results of the new operation. For + /// this reason, the new operation must have the same number of results. /// /// Returns failure if the payload does not satisfy the conditions associated /// with the type of the handle value. - LogicalResult - updatePayloadOps(Value value, - function_ref callback); + LogicalResult replacePayloadOp(Operation *op, Operation *replacement); /// If the operand is a handle consumed by the operation, i.e. has the "free" /// memory effect associated with it, identifies other handles that are /// pointing to payload IR operations nested in the operations pointed to by /// the consumed handle. Marks all such handles as invalidated to trigger - /// errors if they are used. - void recordHandleInvalidation(OpOperand &handle); - void recordHandleInvalidationOne(OpOperand &handle, Operation *payloadOp, - Value otherHandle); + /// errors if they are used. If `throughValue` is passed, record the fact that + /// an op handle was invalidated because a value handle associated with + /// results of the payload op or its block arguments was invalidated. + void recordOpHandleInvalidation(OpOperand &consumingHandle, + ArrayRef potentialAncestors, + Value throughValue = nullptr); + void recordOpHandleInvalidationOne(OpOperand &handle, + ArrayRef potentialAncestors, + Operation *payloadOp, Value otherHandle, + Value throughValue = nullptr); + + void recordValueHandleInvalidationByOpHandleOne( + OpOperand &opHandle, ArrayRef potentialAncestors, + Value payloadValue, Value valueHandle); + + void recordValueHandleInvalidation(OpOperand &valueHandle); /// Checks that the operation does not use invalidated handles as operands. /// Reports errors and returns failure if it does. Otherwise, invalidates the @@ -421,14 +504,10 @@ private: /// The top-level operation that contains all payload IR, typically a module. Operation *topLevel; - /// Storage for extra mapped values (payload operations or parameters) to be + /// Extra mapped values (payload operations, values or parameters) to be /// associated with additional entry block arguments of the top-level - /// transform operation. Each entry in `topLevelMappedValues` is a reference - /// to a contiguous block in `topLevelMappedValueStorage`. - // TODO: turn this into a proper named data structure, there are several more - // below. - SmallVector> topLevelMappedValues; - SmallVector topLevelMappedValueStorage; + /// transform operation. + RaggedArray topLevelMappedValues; /// Additional options controlling the transformation state behavior. TransformOptions options; @@ -455,16 +534,23 @@ class TransformResults { public: /// Indicates that the result of the transform IR op at the given position /// corresponds to the given list of payload IR ops. Each result must be set - /// by the transformation exactly once. The value must have a type - /// implementing TransformHandleTypeInterface. + /// by the transformation exactly once in case of transformation succeeding. + /// The value must have a type implementing TransformHandleTypeInterface. void set(OpResult value, ArrayRef ops); /// Indicates that the result of the transform IR op at the given position /// corresponds to the given list of parameters. Each result must be set by - /// the transformation exactly once. The value must have a type implementing - /// TransformParamTypeInterface. + /// the transformation exactly once in case of transformation succeeding. The + /// value must have a type implementing TransformParamTypeInterface. void setParams(OpResult value, ArrayRef params); + /// Indicates that the result of the transform IR op at the given position + /// corresponds to the given range of payload IR values. Each result must be + /// set by the transformation exactly once in case of transformation + /// succeeding. The value must have a type implementing + /// TransformValueHandleTypeInterface. + void setValues(OpResult handle, ValueRange values); + private: /// Creates an instance of TransformResults that expects mappings for /// `numSegments` values, which may be associated with payload operations or @@ -481,34 +567,34 @@ private: /// be associated with parameters. ArrayRef getParams(unsigned resultNumber) const; + /// Gets the list of payload IR values associated with the result identified + /// by its number in the list of operation results. The result must have been + /// set to be associated with payload IR values. + ArrayRef getValues(unsigned resultNumber) const; + /// Returns `true` if the result identified by its number in the list of - /// operation results is associated with a list of parameters, `false` if it - /// is associated with the list of payload IR operations. + /// operation results is associated with a list of parameters, `false` + /// otherwise. bool isParam(unsigned resultNumber) const; + /// Returns `true` if the result identified by its number in the list of + /// operation results is associated with a list of payload IR value, `false` + /// otherwise. + bool isValue(unsigned resultNumber) const; + /// Returns `true` if the result identified by its number in the list of /// operation results is associated with something. bool isSet(unsigned resultNumber) const; - /// Storage for pointers to payload IR ops that are associated with results of - /// a transform IR op. `segments` contains as many entries as the transform IR - /// op has results, even if some of them are not associated with payload IR - /// operations. Each entry is a reference to a contiguous segment in the - /// `operations` list that contains the pointers to operations. This allows - /// for operations to be stored contiguously without nested vectors and for - /// different segments to be set in any order. - SmallVector, 2> segments; - SmallVector operations; + /// Pointers to payload IR ops that are associated with results of a transform + /// IR op. + RaggedArray operations; - /// Storage for parameters that are associated with results of the transform - /// IR op. `paramSegments` contains as many entries as the transform IR op has - /// results, even if some of them are not associated with parameters. Each - /// entry is a reference to a contiguous segment in the `params` list that - /// contains the actual parameters. This allows for parameters to be stored - /// contiguously without nested vectors and for different segments to be set - /// in any order. - SmallVector, 2> paramSegments; - SmallVector params; + /// Parameters that are associated with results of the transform IR op. + RaggedArray params; + + /// Payload IR values that are associated with results of a transform IR op. + RaggedArray values; }; TransformState::RegionScope TransformState::make_region_scope(Region ®ion) { @@ -625,14 +711,14 @@ public: /// Side effect resource corresponding to the mapping between Transform IR /// values and Payload IR operations. An Allocate effect from this resource /// means creating a new mapping entry, it is always accompanied by a Write -/// effet. A Read effect from this resource means accessing the mapping. A Free +/// effect. A Read effect from this resource means accessing the mapping. A Free /// effect on this resource indicates the removal of the mapping entry, /// typically after a transformation that modifies the Payload IR operations /// associated with one of the Transform IR operation's operands. It is always /// accompanied by a Read effect. Read-after-Free and double-Free are not /// allowed (they would be problematic with "regular" memory effects too) as /// they indicate an attempt to access Payload IR operations that have been -/// modified, potentially erased, by the previous tranfsormations. +/// modified, potentially erased, by the previous transformations. // TODO: consider custom effects if these are not enabling generic passes such // as CSE/DCE to work. struct TransformMappingResource @@ -769,7 +855,7 @@ namespace transform { /// A single result of applying a transform op with `ApplyEachOpTrait` to a /// single payload operation. -using ApplyToEachResult = llvm::PointerUnion; +using ApplyToEachResult = MappedValue; /// A list of results of applying a transform op with `ApplyEachOpTrait` to a /// single payload operation, co-indexed with the results of the transform op. @@ -793,6 +879,9 @@ public: if constexpr (std::is_convertible_v) { results.push_back(static_cast(element)); + } else if constexpr (std::is_convertible_v) { + results.push_back(element.template get()); } else { results.push_back(static_cast(element)); } @@ -800,8 +889,12 @@ public: } /// Appends an element to the list. + // Using ApplyToEachResult that can be implicitly constructed from a Value but + // not from a concrete Op that is implicitly convertible to a Value to avoid + // ambiguity. void push_back(Operation *op) { results.push_back(op); } void push_back(Attribute attr) { results.push_back(attr); } + void push_back(ApplyToEachResult r) { results.push_back(r); } /// Reserves space for `size` elements in the list. void reserve(unsigned size) { results.reserve(size); } diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td index 22c0c94b2760..f4439616b947 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -137,10 +137,10 @@ def TransformHandleTypeInterface : TransformTypeInterfaceBase<"TransformHandleTypeInterface", "::mlir::Operation *"> { let description = [{ - Types that can be used for the Transform dialect handle values. Such types - define the properties of Payload IR operations associated with the handle. - A user of such a handle can assume that these properties have been verified - for any Payload IR operation associated with it. + Types that can be used for the Transform dialect operation handle values. + Such types define the properties of Payload IR operations associated with + the handle. A user of such a handle can assume that these properties have + been verified for any Payload IR operation associated with it. }]; } @@ -155,9 +155,21 @@ def TransformParamTypeInterface }]; } +def TransformValueHandleTypeInterface + : TransformTypeInterfaceBase<"TransformValueHandleTypeInterface", + "::mlir::Value"> { + let description = [{ + Types that can be used for the Transform dialect handle values pointing to + Payload IR values. Such types define the properties of Payload IR values + associated with the handle. Users of such a handle can assume that these + properties have been verified for any Payload IR value associated with it. + }]; +} + def Transform_AnyHandleOrParamType : Type, + TransformHandleTypeInterface.predicate, + TransformValueHandleTypeInterface.predicate]>, "any transform handle or parameter">; def FunctionalStyleTransformOpTrait diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td index ebaf576451d7..9eece0fb3dcf 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td @@ -52,6 +52,15 @@ def Transform_ParamType : TypeDef]> { + let description = [{ + Transform IR value that can be associated with a list of Payload IR values. + }]; + let mnemonic = "any_value"; + let assemblyFormat = ""; +} + class Transform_ConcreteOpType : Type()" diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h index 1b1ad6a74ecd..0a60b4c644f2 100644 --- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h +++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h @@ -40,7 +40,7 @@ interpreterBaseInitializeImpl(MLIRContext *context, StringRef transformFileName, LogicalResult interpreterBaseRunOnOperationImpl( Operation *target, StringRef passName, const std::shared_ptr> &sharedTransformModule, - ArrayRef> extraMappings, + const RaggedArray &extraMappings, const TransformOptions &options, const Pass::Option &transformFileName, const Pass::Option &debugPayloadRootTag, diff --git a/mlir/include/mlir/Dialect/Transform/Utils/RaggedArray.h b/mlir/include/mlir/Dialect/Transform/Utils/RaggedArray.h new file mode 100644 index 000000000000..c3c38d61bef3 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/Utils/RaggedArray.h @@ -0,0 +1,92 @@ +//===- RaggedArray.h - 2D array with different inner lengths ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +/// A 2D array where each row may have different length. Elements of each row +/// are stored contiguously, but rows don't have a fixed order in the storage. +template +class RaggedArray { +public: + /// Returns the number of rows in the 2D array. + size_t size() const { return slices.size(); } + + /// Returns true if the are no rows in the 2D array. Note that an array with a + /// non-zero number of empty rows is *NOT* empty. + bool empty() const { return slices.empty(); } + + /// Accesses `pos`-th row. + ArrayRef operator[](size_t pos) const { return at(pos); } + ArrayRef at(size_t pos) const { return slices[pos]; } + MutableArrayRef operator[](size_t pos) { return at(pos); } + MutableArrayRef at(size_t pos) { return slices[pos]; } + + /// Iterator over rows. + auto begin() { return slices.begin(); } + auto begin() const { return slices.begin(); } + auto end() { return slices.end(); } + auto end() const { return slices.end(); } + + /// Reserve space to store `size` rows with `nestedSize` elements each. + void reserve(size_t size, size_t nestedSize = 0) { + slices.reserve(size); + storage.reserve(size * nestedSize); + } + + /// Appends the given range of elements as a new row to the 2D array. May + /// invalidate the end iterator. + template + void push_back(Range &&elements) { + slices.push_back(appendToStorage(std::forward(elements))); + } + + /// Replaces the `pos`-th row in the 2D array with the given range of + /// elements. Invalidates iterators and references to `pos`-th and all + /// succeeding rows. + template + void replace(size_t pos, Range &&elements) { + auto from = slices[pos].data(); + if (from != nullptr) { + auto to = std::next(from, slices[pos].size()); + auto newFrom = storage.erase(from, to); + // Update the array refs after the underlying storage was shifted. + for (size_t i = pos + 1, e = size(); i < e; ++i) { + slices[i] = MutableArrayRef(newFrom, slices[i].size()); + std::advance(newFrom, slices[i].size()); + } + } + slices[pos] = appendToStorage(std::forward(elements)); + } + + /// Appends `num` empty rows to the array. + void appendEmptyRows(size_t num) { slices.resize(slices.size() + num); } + +private: + /// Appends the given elements to the storage and returns an ArrayRef pointing + /// to them in the storage. + template + MutableArrayRef appendToStorage(Range &&elements) { + size_t start = storage.size(); + llvm::append_range(storage, std::forward(elements)); + return MutableArrayRef(storage).drop_front(start); + } + + /// Outer elements of the ragged array. Each entry is a reference to a + /// contiguous segment in the `storage` list that contains the actual + /// elements. This allows for elements to be stored contiguously without + /// nested vectors and for different segments to be set or replaced in any + /// order. + SmallVector> slices; + + /// Dense storage for ragged array elements. + SmallVector storage; +}; +} // namespace mlir diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp index fadc9ce6ed0a..1f61ecd97162 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -38,12 +38,14 @@ void transform::detail::checkImplementsTransformOpInterface( void transform::detail::checkImplementsTransformHandleTypeInterface( TypeID typeID, MLIRContext *context) { const auto &abstractType = AbstractType::lookup(typeID, context); - assert( - (abstractType.hasInterface( - TransformHandleTypeInterface::getInterfaceID()) || - abstractType.hasInterface( - TransformParamTypeInterface::getInterfaceID())) && - "expected Transform dialect type to implement one of the two interfaces"); + assert((abstractType.hasInterface( + TransformHandleTypeInterface::getInterfaceID()) || + abstractType.hasInterface( + TransformParamTypeInterface::getInterfaceID()) || + abstractType.hasInterface( + TransformValueHandleTypeInterface::getInterfaceID())) && + "expected Transform dialect type to implement one of the three " + "interfaces"); } #endif // NDEBUG diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index e14fca25af77..6b0f59da9101 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Operation.h" +#include "mlir/Support/LogicalResult.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Debug.h" @@ -29,17 +30,12 @@ constexpr const Value transform::TransformState::kTopLevelValue; transform::TransformState::TransformState( Region *region, Operation *payloadRoot, - ArrayRef> extraMappings, + const RaggedArray &extraMappings, const TransformOptions &options) : topLevel(payloadRoot), options(options) { topLevelMappedValues.reserve(extraMappings.size()); - for (ArrayRef mapping : extraMappings) { - size_t start = topLevelMappedValueStorage.size(); - llvm::append_range(topLevelMappedValueStorage, mapping); - topLevelMappedValues.push_back( - ArrayRef(topLevelMappedValueStorage) - .slice(start, mapping.size())); - } + for (ArrayRef mapping : extraMappings) + topLevelMappedValues.push_back(mapping); auto result = mappings.try_emplace(region); assert(result.second && "the region scope is already present"); @@ -55,16 +51,26 @@ ArrayRef transform::TransformState::getPayloadOps(Value value) const { const TransformOpMapping &operationMapping = getMapping(value).direct; auto iter = operationMapping.find(value); - assert(iter != operationMapping.end() && - "cannot find mapping for payload handle (param handle provided?)"); + assert( + iter != operationMapping.end() && + "cannot find mapping for payload handle (param/value handle provided?)"); return iter->getSecond(); } ArrayRef transform::TransformState::getParams(Value value) const { const ParamMapping &mapping = getMapping(value).params; auto iter = mapping.find(value); - assert(iter != mapping.end() && - "cannot find mapping for param handle (payload handle provided?)"); + assert(iter != mapping.end() && "cannot find mapping for param handle " + "(operation/value handle provided?)"); + return iter->getSecond(); +} + +ArrayRef +transform::TransformState::getPayloadValues(Value handleValue) const { + const ValueMapping &mapping = getMapping(handleValue).values; + auto iter = mapping.find(handleValue); + assert(iter != mapping.end() && "cannot find mapping for value handle " + "(param/operation handle provided?)"); return iter->getSecond(); } @@ -82,6 +88,20 @@ LogicalResult transform::TransformState::getHandlesForPayloadOp( return success(found); } +LogicalResult transform::TransformState::getHandlesForPayloadValue( + Value payloadValue, SmallVectorImpl &handles) const { + bool found = false; + for (const Mappings &mapping : llvm::make_second_range(mappings)) { + auto iterator = mapping.reverseValues.find(payloadValue); + if (iterator != mapping.reverseValues.end()) { + llvm::append_range(handles, iterator->getSecond()); + found = true; + } + } + + return success(found); +} + LogicalResult transform::TransformState::mapBlockArgument(BlockArgument argument, ArrayRef values) { @@ -99,6 +119,20 @@ transform::TransformState::mapBlockArgument(BlockArgument argument, return setPayloadOps(argument, operations); } + if (argument.getType().isa()) { + SmallVector payloadValues; + payloadValues.reserve(values.size()); + for (MappedValue value : values) { + if (auto v = value.dyn_cast()) { + payloadValues.push_back(v); + continue; + } + return emitError(argument.getLoc()) + << "wrong kind of value provided for the top-level value handle"; + } + return setPayloadValues(argument, payloadValues); + } + assert(argument.getType().isa() && "unsupported kind of block argument"); SmallVector parameters; @@ -119,8 +153,8 @@ transform::TransformState::setPayloadOps(Value value, ArrayRef targets) { assert(value != kTopLevelValue && "attempting to reset the transformation root"); - assert(!value.getType().isa() && - "cannot associate payload ops with a value of parameter type"); + assert(value.getType().isa() && + "wrong handle type"); for (Operation *target : targets) { if (target) @@ -150,6 +184,41 @@ transform::TransformState::setPayloadOps(Value value, return success(); } +LogicalResult +transform::TransformState::setPayloadValues(Value handle, + ValueRange payloadValues) { + assert(handle != nullptr && "attempting to set params for a null value"); + assert(handle.getType().isa() && + "wrong handle type"); + + for (Value payload : payloadValues) { + if (payload) + continue; + return emitError(handle.getLoc()) << "attempting to assign a null payload " + "value to this transform handle"; + } + + auto iface = handle.getType().cast(); + SmallVector payloadValueVector = llvm::to_vector(payloadValues); + DiagnosedSilenceableFailure result = + iface.checkPayload(handle.getLoc(), payloadValueVector); + if (failed(result.checkAndReport())) + return failure(); + + Mappings &mappings = getMapping(handle); + bool inserted = + mappings.values.insert({handle, std::move(payloadValueVector)}).second; + assert( + inserted && + "value handle is already associated with another list of payload values"); + (void)inserted; + + for (Value payload : payloadValues) + mappings.reverseValues[payload].push_back(handle); + + return success(); +} + LogicalResult transform::TransformState::setParams(Value value, ArrayRef params) { assert(value != nullptr && "attempting to set params for a null value"); @@ -177,54 +246,146 @@ LogicalResult transform::TransformState::setParams(Value value, return success(); } -void transform::TransformState::dropReverseMapping(Mappings &mappings, - Operation *op, Value value) { - auto it = mappings.reverse.find(op); - if (it == mappings.reverse.end()) +template +void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) { + auto it = mapping.find(key); + if (it == mapping.end()) return; - llvm::erase_value(it->getSecond(), value); + llvm::erase_value(it->getSecond(), mapped); if (it->getSecond().empty()) - mappings.reverse.erase(it); + mapping.erase(it); } -void transform::TransformState::removePayloadOps(Value value) { - Mappings &mappings = getMapping(value); - for (Operation *op : mappings.direct[value]) - dropReverseMapping(mappings, op, value); - mappings.direct.erase(value); +void transform::TransformState::forgetMapping(Value opHandle, + ValueRange origOpFlatResults) { + Mappings &mappings = getMapping(opHandle); + for (Operation *op : mappings.direct[opHandle]) + dropMappingEntry(mappings.reverse, op, opHandle); + mappings.direct.erase(opHandle); + + for (Value opResult : origOpFlatResults) { + SmallVector resultHandles; + (void)getHandlesForPayloadValue(opResult, resultHandles); + for (Value resultHandle : resultHandles) { + Mappings &localMappings = getMapping(resultHandle); + dropMappingEntry(localMappings.values, resultHandle, opResult); + dropMappingEntry(localMappings.reverseValues, opResult, resultHandle); + } + } } -LogicalResult transform::TransformState::updatePayloadOps( - Value value, function_ref callback) { - Mappings &mappings = getMapping(value); - auto it = mappings.direct.find(value); - assert(it != mappings.direct.end() && "unknown handle"); - SmallVector &association = it->getSecond(); - SmallVector updated; - updated.reserve(association.size()); +void transform::TransformState::forgetValueMapping( + Value valueHandle, ArrayRef payloadOperations) { + Mappings &mappings = getMapping(valueHandle); + for (Value payloadValue : mappings.reverseValues[valueHandle]) + dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle); + mappings.values.erase(valueHandle); - for (Operation *op : association) { - dropReverseMapping(mappings, op, value); - if (Operation *updatedOp = callback(op)) { - updated.push_back(updatedOp); - mappings.reverse[updatedOp].push_back(value); + for (Operation *payloadOp : payloadOperations) { + SmallVector opHandles; + (void)getHandlesForPayloadOp(payloadOp, opHandles); + for (Value opHandle : opHandles) { + Mappings &localMappings = getMapping(opHandle); + dropMappingEntry(localMappings.direct, opHandle, payloadOp); + dropMappingEntry(localMappings.reverse, payloadOp, opHandle); + } + } +} + +LogicalResult +transform::TransformState::replacePayloadOp(Operation *op, + Operation *replacement) { + // Drop the mapping between the op and all handles that point to it. Don't + // care if there are on such handles. + SmallVector opHandles; + (void)getHandlesForPayloadOp(op, opHandles); + for (Value handle : opHandles) { + Mappings &mappings = getMapping(handle); + dropMappingEntry(mappings.reverse, op, handle); + } + + // Drop the mapping between the op results and all value handles that point to + // them. Don't care if there are no such handles. + RaggedArray resultValueHandles; + for (Value opResult : op->getResults()) { + SmallVector valueHandles; + (void)getHandlesForPayloadValue(opResult, valueHandles); + for (Value handle : valueHandles) { + Mappings &localMappings = getMapping(handle); + dropMappingEntry(localMappings.reverseValues, opResult, handle); + } + resultValueHandles.push_back(std::move(valueHandles)); + } + + // TODO: consider invalidating the handles to nested objects here. + + // If replacing with null, that is erasing the mapping, drop the mapping + // between the handles and the IR objects and return. + if (!replacement) { + for (Value handle : opHandles) { + Mappings &mappings = getMapping(handle); + dropMappingEntry(mappings.direct, handle, op); + } + for (Value opResult : op->getResults()) { + SmallVector valueHandles; + (void)getHandlesForPayloadValue(opResult, valueHandles); + for (Value handle : valueHandles) { + Mappings &localMappings = getMapping(handle); + dropMappingEntry(localMappings.values, handle, opResult); + } + } + return success(); + } + + // Otherwise, replace the pointed-to object of all handles while preserving + // their relative order. + if (op->getNumResults() != replacement->getNumResults()) { + return emitError(op->getLoc()) + << "cannot replace an op with another op producing a different " + "number of results while tracking handles"; + } + + // Replace the mapped operation if present. + for (Value handle : opHandles) { + Mappings &mappings = getMapping(handle); + auto it = mappings.direct.find(handle); + if (it == mappings.direct.end()) + continue; + + SmallVector &association = it->getSecond(); + // Note that an operation may be associated with the handle more than once. + for (Operation *&mapped : association) { + if (mapped == op) + mapped = replacement; + } + mappings.reverse[replacement].push_back(handle); + } + + // Replace the mapped results of the operation. + for (auto [origResult, replacementResult, handleList] : llvm::zip( + op->getResults(), replacement->getResults(), resultValueHandles)) { + for (Value resultHandle : handleList) { + Mappings &mappings = getMapping(resultHandle); + auto it = mappings.values.find(resultHandle); + if (it == mappings.values.end()) + continue; + + SmallVector &association = it->getSecond(); + for (Value &mapped : association) { + if (mapped == origResult) + mapped = replacementResult; + } + mappings.reverseValues[replacementResult].push_back(resultHandle); } } - auto iface = value.getType().cast(); - DiagnosedSilenceableFailure result = - iface.checkPayload(value.getLoc(), updated); - if (failed(result.checkAndReport())) - return failure(); - - it->second = updated; return success(); } -void transform::TransformState::recordHandleInvalidationOne( - OpOperand &handle, Operation *payloadOp, Value otherHandle) { - ArrayRef potentialAncestors = getPayloadOps(handle.get()); +void transform::TransformState::recordOpHandleInvalidationOne( + OpOperand &consumingHandle, ArrayRef potentialAncestors, + Operation *payloadOp, Value otherHandle, Value throughValue) { // If the op is associated with invalidated handle, skip the check as it // may be reading invalid IR. if (invalidatedHandles.count(otherHandle)) @@ -240,10 +401,13 @@ void transform::TransformState::recordHandleInvalidationOne( // deleted before the lambda gets called. Location ancestorLoc = ancestor->getLoc(); Location opLoc = payloadOp->getLoc(); - Operation *owner = handle.getOwner(); - unsigned operandNo = handle.getOperandNumber(); + Operation *owner = consumingHandle.getOwner(); + unsigned operandNo = consumingHandle.getOperandNumber(); + std::optional throughValueLoc = + throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt; invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo, - otherHandle](Location currentLoc) { + otherHandle, + throughValueLoc](Location currentLoc) { InFlightDiagnostic diag = emitError(currentLoc) << "op uses a handle invalidated by a " "previously executed transform op"; @@ -251,19 +415,144 @@ void transform::TransformState::recordHandleInvalidationOne( diag.attachNote(owner->getLoc()) << "invalidated by this transform op that consumes its operand #" << operandNo - << " and invalidates handles to payload ops nested in payload " - "ops associated with the consumed handle"; + << " and invalidates all handles to payload IR entities associated " + "with this operand and entities nested in them"; diag.attachNote(ancestorLoc) << "ancestor payload op"; diag.attachNote(opLoc) << "nested payload op"; + if (throughValueLoc) { + diag.attachNote(*throughValueLoc) + << "consumed handle points to this payload value"; + } }; } } -void transform::TransformState::recordHandleInvalidation(OpOperand &handle) { - for (const Mappings &mapping : llvm::make_second_range(mappings)) - for (const auto &[payloadOp, otherHandles] : mapping.reverse) +void transform::TransformState::recordValueHandleInvalidationByOpHandleOne( + OpOperand &consumingHandle, ArrayRef potentialAncestors, + Value payloadValue, Value valueHandle) { + // If the op is associated with invalidated handle, skip the check as it + // may be reading invalid IR. + if (invalidatedHandles.count(valueHandle)) + return; + + for (Operation *ancestor : potentialAncestors) { + Operation *definingOp; + std::optional resultNo = std::nullopt; + unsigned argumentNo, blockNo, regionNo; + if (auto opResult = payloadValue.dyn_cast()) { + definingOp = opResult.getOwner(); + resultNo = opResult.getResultNumber(); + } else { + auto arg = payloadValue.cast(); + definingOp = arg.getParentBlock()->getParentOp(); + argumentNo = arg.getArgNumber(); + blockNo = std::distance(arg.getOwner()->getParent()->begin(), + arg.getOwner()->getIterator()); + regionNo = arg.getOwner()->getParent()->getRegionNumber(); + } + assert(definingOp && "expected the value to be defined by an op as result " + "or block argument"); + if (!ancestor->isAncestor(definingOp)) + continue; + + Operation *owner = consumingHandle.getOwner(); + unsigned operandNo = consumingHandle.getOperandNumber(); + Location ancestorLoc = ancestor->getLoc(); + Location opLoc = definingOp->getLoc(); + Location valueLoc = payloadValue.getLoc(); + invalidatedHandles[valueHandle] = + [valueHandle, owner, operandNo, resultNo, argumentNo, blockNo, regionNo, + ancestorLoc, opLoc, valueLoc](Location currentLoc) { + InFlightDiagnostic diag = emitError(currentLoc) + << "op uses a handle invalidated by a " + "previously executed transform op"; + diag.attachNote(valueHandle.getLoc()) << "invalidated handle"; + diag.attachNote(owner->getLoc()) + << "invalidated by this transform op that consumes its operand #" + << operandNo + << " and invalidates all handles to payload IR entities " + "associated with this operand and entities nested in them"; + diag.attachNote(ancestorLoc) + << "ancestor op associated with the consumed handle"; + if (resultNo) { + diag.attachNote(opLoc) + << "op defining the value as result #" << *resultNo; + } else { + diag.attachNote(opLoc) + << "op defining the value as block argument #" << argumentNo + << " of block #" << blockNo << " in region #" << regionNo; + } + diag.attachNote(valueLoc) << "payload value"; + }; + } +} + +void transform::TransformState::recordOpHandleInvalidation( + OpOperand &handle, ArrayRef potentialAncestors, + Value throughValue) { + // Iterate over the mapping and invalidate aliasing handles. This is quite + // expensive and only necessary for error reporting in case of transform + // dialect misuse with dangling handles. Iteration over the handles is based + // on the assumption that the number of handles is significantly less than the + // number of IR objects (operations and values). Alternatively, we could walk + // the IR nested in each payload op associated with the given handle and look + // for handles associated with each operation and value. + for (const Mappings &mapping : llvm::make_second_range(mappings)) { + // Go over all op handle mappings and mark as invalidated any handle + // pointing to any of the payload ops associated with the given handle or + // any op nested in them. + for (const auto &[payloadOp, otherHandles] : mapping.reverse) { for (Value otherHandle : otherHandles) - recordHandleInvalidationOne(handle, payloadOp, otherHandle); + recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp, + otherHandle, throughValue); + } + // Go over all value handle mappings and mark as invalidated any handle + // pointing to any result of the payload op associated with the given handle + // or any op nested in them. Similarly invalidate handles to argument of + // blocks belonging to any region of any payload op associated with the + // given handle or any op nested in them. + for (const auto &[payloadValue, valueHandles] : mapping.reverseValues) { + for (Value valueHandle : valueHandles) + recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors, + payloadValue, valueHandle); + } + } +} + +void transform::TransformState::recordValueHandleInvalidation( + OpOperand &valueHandle) { + // Invalidate other handles to the same value. + for (Value payloadValue : getPayloadValues(valueHandle.get())) { + SmallVector otherValueHandles; + (void)getHandlesForPayloadValue(payloadValue, otherValueHandles); + for (Value otherHandle : otherValueHandles) { + Operation *owner = valueHandle.getOwner(); + unsigned operandNo = valueHandle.getOperandNumber(); + Location valueLoc = payloadValue.getLoc(); + invalidatedHandles[otherHandle] = [otherHandle, owner, operandNo, + valueLoc](Location currentLoc) { + InFlightDiagnostic diag = emitError(currentLoc) + << "op uses a handle invalidated by a " + "previously executed transform op"; + diag.attachNote(otherHandle.getLoc()) << "invalidated handle"; + diag.attachNote(owner->getLoc()) + << "invalidated by this transform op that consumes its operand #" + << operandNo + << " and invalidates handles to the same values as associated with " + "it"; + diag.attachNote(valueLoc) << "payload value"; + }; + } + + if (auto opResult = payloadValue.dyn_cast()) { + Operation *payloadOp = opResult.getOwner(); + recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue); + } else { + auto arg = payloadValue.dyn_cast(); + for (Operation &payloadOp : *arg.getOwner()) + recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue); + } + } } LogicalResult transform::TransformState::checkAndRecordHandleInvalidation( @@ -287,13 +576,44 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation( return isa(effect.getEffect()) && effect.getValue() == target.get(); }; - if (llvm::any_of(effects, consumesTarget)) - recordHandleInvalidation(target); + if (llvm::any_of(effects, consumesTarget)) { + if (target.get().getType().isa()) { + ArrayRef payloadOps = getPayloadOps(target.get()); + recordOpHandleInvalidation(target, payloadOps); + } else if (target.get() + .getType() + .isa()) { + recordValueHandleInvalidation(target); + } + } } return success(); } +template +DiagnosedSilenceableFailure +checkRepeatedConsumptionInOperand(ArrayRef payload, + transform::TransformOpInterface transform, + unsigned operandNumber) { + DenseSet seen; + for (T p : payload) { + if (!seen.insert(p).second) { + DiagnosedSilenceableFailure diag = + transform.emitSilenceableError() + << "a handle passed as operand #" << operandNumber + << " and consumed by this operation points to a payload " + "entity more than once"; + if constexpr (std::is_pointer_v) + diag.attachNote(p->getLoc()) << "repeated target op"; + else + diag.attachNote(p.getLoc()) << "repeated target value"; + return diag; + } + } + return DiagnosedSilenceableFailure::success(); +} + DiagnosedSilenceableFailure transform::TransformState::applyTransform(TransformOpInterface transform) { LLVM_DEBUG(DBGS() << "applying: " << transform << "\n"); @@ -313,25 +633,82 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { if (!isHandleConsumed(operand.get(), transform)) continue; - DenseSet seen; - for (Operation *op : getPayloadOps(operand.get())) { - if (!seen.insert(op).second) { - DiagnosedSilenceableFailure diag = - transform.emitSilenceableError() - << "a handle passed as operand #" << operand.getOperandNumber() - << " and consumed by this operation points to a payload " - "operation more than once"; - diag.attachNote(op->getLoc()) << "repeated target op"; - return diag; - } + Type operandType = operand.get().getType(); + if (operandType.isa()) { + DiagnosedSilenceableFailure check = + checkRepeatedConsumptionInOperand( + getPayloadOps(operand.get()), transform, + operand.getOperandNumber()); + if (!check.succeeded()) + return check; + } else if (operandType.isa()) { + DiagnosedSilenceableFailure check = + checkRepeatedConsumptionInOperand( + getPayloadValues(operand.get()), transform, + operand.getOperandNumber()); + if (!check.succeeded()) + return check; } } } - transform::TransformResults results(transform->getNumResults()); + // Find which operands are consumed. + DenseSet consumedOperands; + auto memEffectInterface = + cast(transform.getOperation()); + SmallVector effects; + for (OpOperand &target : transform->getOpOperands()) { + effects.clear(); + memEffectInterface.getEffectsOnValue(target.get(), effects); + if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { + return isa( + effect.getResource()) && + isa(effect.getEffect()); + })) { + consumedOperands.insert(target.getOperandNumber()); + } + } + + // Remember the results of the payload ops associated with the consumed + // op handles or the ops defining the value handles so we can drop the + // association with them later. This must happen here because the + // transformation may destroy or mutate them so we cannot traverse the payload + // IR after that. + SmallVector origOpFlatResults; + SmallVector origAssociatedOps; + for (unsigned index : consumedOperands) { + Value operand = transform->getOperand(index); + if (operand.getType().isa()) { + for (Operation *payloadOp : getPayloadOps(operand)) + llvm::append_range(origOpFlatResults, payloadOp->getResults()); + continue; + } + if (operand.getType().isa()) { + for (Value payloadValue : getPayloadValues(operand)) { + if (payloadValue.isa()) { + origAssociatedOps.push_back(payloadValue.getDefiningOp()); + continue; + } + llvm::append_range( + origAssociatedOps, + llvm::map_range(*payloadValue.cast().getOwner(), + [](Operation &op) { return &op; })); + } + continue; + } + DiagnosedDefiniteFailure diag = + emitDefiniteFailure(transform->getLoc()) + << "unexpectedly consumed a value that is not a handle as operand #" + << index; + diag.attachNote(operand.getLoc()) + << "value defined here with type " << operand.getType(); + return diag; + } + // Compute the result but do not short-circuit the silenceable failure case as // we still want the handles to propagate properly so the "suppress" mode can // proceed on a best effort basis. + transform::TransformResults results(transform->getNumResults()); DiagnosedSilenceableFailure result(transform.apply(results, *this)); if (result.isDefiniteFailure()) return result; @@ -352,18 +729,12 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { // Remove the mapping for the operand if it is consumed by the operation. This // allows us to catch use-after-free with assertions later on. - auto memEffectInterface = - cast(transform.getOperation()); - SmallVector effects; - for (OpOperand &target : transform->getOpOperands()) { - effects.clear(); - memEffectInterface.getEffectsOnValue(target.get(), effects); - if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { - return isa( - effect.getResource()) && - isa(effect.getEffect()); - })) { - removePayloadOps(target.get()); + for (unsigned index : consumedOperands) { + Value operand = transform->getOperand(index); + if (operand.getType().isa()) { + forgetMapping(operand, origOpFlatResults); + } else if (operand.getType().isa()) { + forgetValueMapping(operand, origAssociatedOps); } } @@ -378,6 +749,13 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { setParams(result, results.getParams(result.getResultNumber())))) { return DiagnosedSilenceableFailure::definiteFailure(); } + } else if (result.getType().isa()) { + assert(results.isValue(result.getResultNumber()) && + "expected values for value-type-result"); + if (failed(setPayloadValues( + result, results.getValues(result.getResultNumber())))) { + return DiagnosedSilenceableFailure::definiteFailure(); + } } else { assert(!results.isParam(result.getResultNumber()) && "expected payload ops for the non-parameter typed result"); @@ -409,15 +787,9 @@ transform::TransformState::Extension::replacePayloadOp(Operation *op, if (failed(state.getHandlesForPayloadOp(op, handles))) return failure(); - for (Value handle : handles) { - LogicalResult result = - state.updatePayloadOps(handle, [&](Operation *current) { - return current == op ? replacement : current; - }); - if (failed(result)) - return failure(); - } - return success(); + // TODO: we may need to invalidate handles to operations and values nested in + // the operation being replaced. + return state.replacePayloadOp(op, replacement); } //===----------------------------------------------------------------------===// @@ -425,63 +797,95 @@ transform::TransformState::Extension::replacePayloadOp(Operation *op, //===----------------------------------------------------------------------===// transform::TransformResults::TransformResults(unsigned numSegments) { - segments.resize(numSegments, - ArrayRef(nullptr, static_cast(0))); - paramSegments.resize(numSegments, ArrayRef( - nullptr, static_cast(0))); + operations.appendEmptyRows(numSegments); + params.appendEmptyRows(numSegments); + values.appendEmptyRows(numSegments); } void transform::TransformResults::set(OpResult value, ArrayRef ops) { int64_t position = value.getResultNumber(); - assert(position < static_cast(segments.size()) && + assert(position < static_cast(operations.size()) && "setting results for a non-existent handle"); - assert(segments[position].data() == nullptr && "results already set"); - int64_t start = operations.size(); - llvm::append_range(operations, ops); - segments[position] = ArrayRef(operations).drop_front(start); + assert(operations[position].data() == nullptr && "results already set"); + assert(params[position].data() == nullptr && + "another kind of results already set"); + assert(values[position].data() == nullptr && + "another kind of results already set"); + operations.replace(position, ops); } void transform::TransformResults::setParams( OpResult value, ArrayRef params) { int64_t position = value.getResultNumber(); - assert(position < static_cast(paramSegments.size()) && + assert(position < static_cast(this->params.size()) && "setting params for a non-existent handle"); - assert(paramSegments[position].data() == nullptr && "params already set"); - size_t start = this->params.size(); - llvm::append_range(this->params, params); - paramSegments[position] = ArrayRef(this->params).drop_front(start); + assert(this->params[position].data() == nullptr && "params already set"); + assert(operations[position].data() == nullptr && + "another kind of results already set"); + assert(values[position].data() == nullptr && + "another kind of results already set"); + this->params.replace(position, params); +} + +void transform::TransformResults::setValues(OpResult handle, + ValueRange values) { + int64_t position = handle.getResultNumber(); + assert(position < static_cast(values.size()) && + "setting values for a non-existent handle"); + assert(this->values[position].data() == nullptr && "values already set"); + assert(operations[position].data() == nullptr && + "another kind of results already set"); + assert(params[position].data() == nullptr && + "another kind of results already set"); + this->values.replace(position, values); } ArrayRef transform::TransformResults::get(unsigned resultNumber) const { - assert(resultNumber < segments.size() && + assert(resultNumber < operations.size() && "querying results for a non-existent handle"); - assert(segments[resultNumber].data() != nullptr && - "querying unset results (param expected?)"); - return segments[resultNumber]; + assert(operations[resultNumber].data() != nullptr && + "querying unset results (values or params expected?)"); + return operations[resultNumber]; } ArrayRef transform::TransformResults::getParams(unsigned resultNumber) const { - assert(resultNumber < paramSegments.size() && + assert(resultNumber < params.size() && "querying params for a non-existent handle"); - assert(paramSegments[resultNumber].data() != nullptr && - "querying unset params (payload ops expected?)"); - return paramSegments[resultNumber]; + assert(params[resultNumber].data() != nullptr && + "querying unset params (ops or values expected?)"); + return params[resultNumber]; +} + +ArrayRef +transform::TransformResults::getValues(unsigned resultNumber) const { + assert(resultNumber < params.size() && + "querying params for a non-existent handle"); + assert(values[resultNumber].data() != nullptr && + "querying unset values (ops or params expected?)"); + return values[resultNumber]; } bool transform::TransformResults::isParam(unsigned resultNumber) const { - assert(resultNumber < paramSegments.size() && + assert(resultNumber < params.size() && "querying association for a non-existent handle"); - return paramSegments[resultNumber].data() != nullptr; + return params[resultNumber].data() != nullptr; +} + +bool transform::TransformResults::isValue(unsigned resultNumber) const { + assert(resultNumber < values.size() && + "querying association for a non-existent handle"); + return values[resultNumber].data() != nullptr; } bool transform::TransformResults::isSet(unsigned resultNumber) const { - assert(resultNumber < paramSegments.size() && + assert(resultNumber < params.size() && "querying association for a non-existent handle"); - return paramSegments[resultNumber].data() != nullptr || - segments[resultNumber].data() != nullptr; + return params[resultNumber].data() != nullptr || + operations[resultNumber].data() != nullptr || + values[resultNumber].data() != nullptr; } //===----------------------------------------------------------------------===// @@ -547,6 +951,12 @@ void transform::detail::setApplyToOneResults( return oneResult[r.getResultNumber()].get(); })); transformResults.setParams(r, params); + } else if (r.getType().isa()) { + auto values = llvm::to_vector( + llvm::map_range(results, [r](const ApplyToEachResultList &oneResult) { + return oneResult[r.getResultNumber()].get(); + })); + transformResults.setValues(r, values); } else { auto payloads = llvm::to_vector( llvm::map_range(results, [r](const ApplyToEachResultList &oneResult) { @@ -571,6 +981,8 @@ LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( SmallVector &mapped = extraMappings.emplace_back(); if (operand.getType().isa()) { llvm::append_range(mapped, state.getPayloadOps(operand)); + } else if (operand.getType().isa()) { + llvm::append_range(mapped, state.getPayloadValues(operand)); } else { assert(operand.getType().isa() && "unsupported kind of transform dialect value"); @@ -639,13 +1051,15 @@ transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) { } for (BlockArgument arg : body->getArguments().drop_front()) { if (arg.getType() - .isa()) + .isa()) continue; InFlightDiagnostic diag = op->emitOpError() << "expects trailing entry block arguments to be of type implementing " - "TransformHandleTypeInterface or TransformParamTypeInterface"; + "TransformHandleTypeInterface, TransformValueHandleTypeInterface or " + "TransformParamTypeInterface"; diag.attachNote() << "argument #" << arg.getArgNumber() << " does not"; return diag; } @@ -675,7 +1089,9 @@ void transform::detail::getParamProducerTransformOpTraitEffects( bool hasPayloadOperands = false; for (Value operand : op->getOperands()) { onlyReadsHandle(operand, effects); - if (operand.getType().isa()) + if (operand.getType() + .isa()) hasPayloadOperands = true; } if (hasPayloadOperands) @@ -841,7 +1257,7 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) { LogicalResult transform::applyTransforms(Operation *payloadRoot, TransformOpInterface transform, - ArrayRef> extraMapping, + const RaggedArray &extraMapping, const TransformOptions &options) { #ifndef NDEBUG if (!transform->hasTrait() || diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp index 4e4fb2da91df..579e1065e952 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp @@ -99,3 +99,13 @@ transform::ParamType::checkPayload(Location loc, } return DiagnosedSilenceableFailure::success(); } + +//===----------------------------------------------------------------------===// +// transform::AnyValueType +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::AnyValueType::checkPayload(Location loc, + ArrayRef payload) const { + return DiagnosedSilenceableFailure::success(); +} diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp index a7456e3bcc3a..3d6ee21478e5 100644 --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -279,7 +279,7 @@ static void performOptionalDebugActions( LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( Operation *target, StringRef passName, const std::shared_ptr> &sharedTransformModule, - ArrayRef> extraMappings, + const RaggedArray &extraMappings, const TransformOptions &options, const Pass::Option &transformFileName, const Pass::Option &debugPayloadRootTag, diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir index e41ffec038ed..cbc62a6a078e 100644 --- a/mlir/test/Dialect/Linalg/transform-op-match.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir @@ -13,11 +13,11 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %match_name = transform.structured.match ops{["arith.constant"]} in %arg1 : (!pdl.operation) -> !pdl.operation transform.test_print_remark_at_operand %match_name, "matched op name" : !pdl.operation - transform.test_consume_operand %match_name + transform.test_consume_operand %match_name : !pdl.operation %match_attr = transform.structured.match ops{["arith.constant"]} attributes{my_attr} in %arg1 : (!pdl.operation) -> !pdl.operation transform.test_print_remark_at_operand %match_attr, "matched attr name" : !pdl.operation - transform.test_consume_operand %match_attr + transform.test_consume_operand %match_attr : !pdl.operation } // ----- @@ -34,7 +34,7 @@ transform.sequence failures(propagate) { %match_name = transform.structured.match ops{["arith.constant"]} filter_result_type = f32 in %arg1 : (!pdl.operation) -> !pdl.operation transform.test_print_remark_at_operand %match_name, "matched op name" : !pdl.operation - transform.test_consume_operand %match_name + transform.test_consume_operand %match_name : !pdl.operation } // ----- @@ -65,7 +65,7 @@ transform.sequence failures(propagate) { #linalg.iterator_type]} in %arg1 : (!pdl.operation) -> !pdl.operation transform.test_print_remark_at_operand %match_attr, "matched complex attr" : !pdl.operation - transform.test_consume_operand %match_attr + transform.test_consume_operand %match_attr : !pdl.operation %no_match = transform.structured.match attributes{iterator_types = [ diff --git a/mlir/test/Dialect/Transform/check-use-after-free.mlir b/mlir/test/Dialect/Transform/check-use-after-free.mlir index 80761639edd1..e2c0f0703250 100644 --- a/mlir/test/Dialect/Transform/check-use-after-free.mlir +++ b/mlir/test/Dialect/Transform/check-use-after-free.mlir @@ -2,7 +2,7 @@ func.func @use_after_free_branching_control_flow() { // expected-note @below {{allocated here}} - %0 = transform.test_produce_param_or_forward_operand 42 + %0 = transform.test_produce_self_handle_or_forward_operand transform.test_transform_op_with_regions { "transform.test_branching_transform_op_terminator"() : () -> () }, @@ -11,7 +11,7 @@ func.func @use_after_free_branching_control_flow() { "transform.test_branching_transform_op_terminator"()[^bb1, ^bb2] : () -> () ^bb1: // expected-note @below {{freed here}} - transform.test_consume_operand_if_matches_param_or_fail %0[42] + transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" "transform.test_branching_transform_op_terminator"()[^bb3] : () -> () ^bb2: "transform.test_branching_transform_op_terminator"()[^bb3] : () -> () @@ -29,7 +29,7 @@ func.func @use_after_free_branching_control_flow() { func.func @use_after_free_in_nested_op() { // expected-note @below {{allocated here}} - %0 = transform.test_produce_param_or_forward_operand 42 + %0 = transform.test_produce_self_handle_or_forward_operand // expected-note @below {{freed here}} transform.test_transform_op_with_regions { "transform.test_branching_transform_op_terminator"() : () -> () @@ -38,7 +38,7 @@ func.func @use_after_free_in_nested_op() { ^bb0: "transform.test_branching_transform_op_terminator"()[^bb1, ^bb2] : () -> () ^bb1: - transform.test_consume_operand_if_matches_param_or_fail %0[42] + transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" "transform.test_branching_transform_op_terminator"()[^bb3] : () -> () ^bb2: "transform.test_branching_transform_op_terminator"()[^bb3] : () -> () @@ -74,7 +74,7 @@ func.func @use_after_free_recursive_side_effects() { // expected-note @below {{freed here}} transform.sequence %0 : !pdl.operation failures(propagate) attributes { ord = 4 } { ^bb4(%arg4: !pdl.operation): - test_consume_operand_if_matches_param_or_fail %0[42] + test_consume_operand_of_op_kind_or_fail %0, "transform.sequence" } // expected-warning @below {{operand #0 may be used after free}} transform.sequence %0 : !pdl.operation failures(propagate) attributes { ord = 5 } { @@ -102,7 +102,7 @@ func.func @use_after_free() { } // expected-note @below {{freed here}} - test_consume_operand_if_matches_param_or_fail %0[42] + test_consume_operand_of_op_kind_or_fail %0, "transform.sequence" // expected-warning @below {{operand #0 may be used after free}} transform.sequence %0 : !pdl.operation failures(propagate) attributes { ord = 5 } { ^bb3(%arg3: !pdl.operation): @@ -118,7 +118,7 @@ func.func @use_after_free() { // be reported as use-after-free. func.func @use_after_free_self_cycle() { // expected-note @below {{allocated here}} - %0 = transform.test_produce_param_or_forward_operand 42 + %0 = transform.test_produce_self_handle_or_forward_operand transform.test_transform_op_with_regions { "transform.test_branching_transform_op_terminator"() : () -> () }, @@ -132,7 +132,7 @@ func.func @use_after_free_self_cycle() { } // expected-warning @below {{operand #0 may be used after free}} // expected-note @below {{freed here}} - transform.test_consume_operand_if_matches_param_or_fail %0[42] + transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" "transform.test_branching_transform_op_terminator"()[^bb1, ^bb2] : () -> () ^bb2: "transform.test_branching_transform_op_terminator"() : () -> () @@ -147,7 +147,7 @@ func.func @use_after_free_self_cycle() { // use-after-free. func.func @use_after_free_cycle() { // expected-note @below {{allocated here}} - %0 = transform.test_produce_param_or_forward_operand 42 + %0 = transform.test_produce_self_handle_or_forward_operand transform.test_transform_op_with_regions { "transform.test_branching_transform_op_terminator"() : () -> () }, @@ -157,7 +157,7 @@ func.func @use_after_free_cycle() { ^bb1: // expected-warning @below {{operand #0 may be used after free}} // expected-note @below {{freed here}} - transform.test_consume_operand_if_matches_param_or_fail %0[42] + transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" "transform.test_branching_transform_op_terminator"()[^bb2, ^bb3] : () -> () ^bb2: "transform.test_branching_transform_op_terminator"()[^bb1] : () -> () diff --git a/mlir/test/Dialect/Transform/expensive-checks.mlir b/mlir/test/Dialect/Transform/expensive-checks.mlir index abc09d799ee2..de5355443ede 100644 --- a/mlir/test/Dialect/Transform/expensive-checks.mlir +++ b/mlir/test/Dialect/Transform/expensive-checks.mlir @@ -21,7 +21,7 @@ transform.with_pdl_patterns { %0 = pdl_match @return in %arg1 : (!pdl.operation) -> !pdl.operation %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation // expected-note @below {{invalidated by this transform op that consumes its operand #0}} - test_consume_operand %1 + test_consume_operand %1 : !pdl.operation // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} test_print_remark_at_operand %0, "remark" : !pdl.operation } @@ -55,8 +55,8 @@ transform.with_pdl_patterns { %0 = pdl_match @func in %arg1 : (!pdl.operation) -> !pdl.operation %1 = pdl_match @return in %arg1 : (!pdl.operation) -> !pdl.operation %2 = replicate num(%0) %1 : !pdl.operation, !pdl.operation - // expected-error @below {{a handle passed as operand #0 and consumed by this operation points to a payload operation more than once}} - test_consume_operand %2 + // expected-error @below {{a handle passed as operand #0 and consumed by this operation points to a payload entity more than once}} + test_consume_operand %2 : !pdl.operation test_print_remark_at_operand %0, "remark" : !pdl.operation } } @@ -74,9 +74,9 @@ module { // expected-note @below {{handle to invalidated ops}} %2 = transform.test_copy_payload %0 // expected-note @below {{invalidated by this transform op that consumes its operand #0}} - transform.test_consume_operand %1 + transform.test_consume_operand %1 : !pdl.operation // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} - transform.test_consume_operand %2 + transform.test_consume_operand %2 : !pdl.operation } } @@ -95,8 +95,8 @@ module { // to overlapping sets of payload IR ops. // // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} - // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates handles}} - transform.test_consume_operand %1, %2 + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities}} + transform.test_consume_operand %1, %2 : !pdl.operation } } @@ -113,3 +113,221 @@ module { transform.merge_handles %1, %2 { deduplicate } : !pdl.operation } } +// ----- + +// expected-note @below {{payload value}} +%0 = "test.match_anchor"() : () -> (i32) + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %2 = transform.structured.match ops{["test.match_anchor"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %3 = test_produce_value_handle_to_result %2, 0 : (!transform.any_op) -> !transform.any_value + // expected-note @below {{invalidated handle}} + %4 = test_produce_value_handle_to_result %2, 0 : (!transform.any_op) -> !transform.any_value + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates handles to the same values as associated with it}} + test_consume_operand %3 : !transform.any_value + // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + test_consume_operand %4 : !transform.any_value +} + +// ----- + +// expected-note @below {{ancestor op associated with the consumed handle}} +// expected-note @below {{payload value}} +// expected-note @below {{op defining the value as result #0}} +%0 = "test.match_anchor"() : () -> (i32) + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %2 = transform.structured.match ops{["test.match_anchor"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-note @below {{invalidated handle}} + %3 = test_produce_value_handle_to_result %2, 0 : (!transform.any_op) -> !transform.any_value + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + test_consume_operand %2 : !transform.any_op + // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + test_consume_operand %3 : !transform.any_value +} + +// ----- + +// expected-note @below {{ancestor op associated with the consumed handle}} +"test.match_anchor_1"() ({ +^bb0: + // expected-note @below {{op defining the value as result #0}} + // expected-note @below {{payload value}} + %0 = "test.match_anchor_2"() : () -> (i32) + "test.region_terminator"() : () -> () +}) : () -> () + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %1 = transform.structured.match ops{["test.match_anchor_1"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.match ops{["test.match_anchor_2"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-note @below {{invalidated handle}} + %3 = test_produce_value_handle_to_result %2, 0 : (!transform.any_op) -> !transform.any_value + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + test_consume_operand %1 : !transform.any_op + // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + test_consume_operand %3 : !transform.any_value +} + +// ----- + +// expected-note @below {{ancestor op associated with the consumed handle}} +// expected-note @below {{op defining the value as block argument #0 of block #0 in region #0}} +"test.match_anchor_1"() ({ +// expected-note @below {{payload value}} +^bb0(%arg0: i32): + %0 = "test.match_anchor_2"() : () -> (i32) + "test.region_terminator"() : () -> () +}) : () -> () + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %1 = transform.structured.match ops{["test.match_anchor_1"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.match ops{["test.match_anchor_2"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-note @below {{invalidated handle}} + %3 = test_produce_value_handle_to_argument_of_parent_block %2, 0 : (!transform.any_op) -> !transform.any_value + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + test_consume_operand %1 : !transform.any_op + // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + test_consume_operand %3 : !transform.any_value +} + +// ----- + +// expected-note @below {{ancestor op associated with the consumed handle}} +"test.match_anchor_1"() ({ +^bb: + // expected-note @below {{op defining the value as block argument #0 of block #0 in region #0}} + "test.op_with_regions"() ({ + // expected-note @below {{payload value}} + ^bb0(%arg0: i32): + %0 = "test.match_anchor_2"() : () -> (i32) + "test.region_terminator"() : () -> () + }): () -> () +}) : () -> () + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %1 = transform.structured.match ops{["test.match_anchor_1"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.match ops{["test.match_anchor_2"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-note @below {{invalidated handle}} + %3 = test_produce_value_handle_to_argument_of_parent_block %2, 0 : (!transform.any_op) -> !transform.any_value + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + test_consume_operand %1 : !transform.any_op + // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + test_consume_operand %3 : !transform.any_value +} + +// ----- + +// expected-note @below {{ancestor payload op}} +// expected-note @below {{nested payload op}} +// expected-note @below {{consumed handle points to this payload value}} +%0 = "test.match_anchor"() : () -> (i32) + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + // expected-note @below {{handle to invalidated ops}} + %2 = transform.structured.match ops{["test.match_anchor"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %3 = test_produce_value_handle_to_result %2, 0 : (!transform.any_op) -> !transform.any_value + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + test_consume_operand %3 : !transform.any_value + // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + test_consume_operand %2 : !transform.any_op +} + +// ----- + +// expected-note @below {{ancestor payload op}} +// expected-note @below {{consumed handle points to this payload value}} +%0 = "test.match_anchor_1"() ({ +^bb0: + // expected-note @below {{nested payload op}} + "test.match_anchor_2"() : () -> () + "test.region_terminator"() : () -> () +}) : () -> (i32) + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %1 = transform.structured.match ops{["test.match_anchor_1"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-note @below {{handle to invalidated ops}} + %2 = transform.structured.match ops{["test.match_anchor_2"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %3 = test_produce_value_handle_to_result %1, 0 : (!transform.any_op) -> !transform.any_value + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + test_consume_operand %3 : !transform.any_value + // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + test_consume_operand %2 : !transform.any_op +} + + +// ----- + +"test.match_anchor_1"() ({ +// expected-note @below {{consumed handle points to this payload value}} +^bb0(%arg0: f32): + // expected-note @below {{ancestor payload op}} + // expected-note @below {{nested payload op}} + "test.match_anchor_2"() : () -> () + "test.region_terminator"() : () -> () +}) : () -> () + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + // expected-note @below {{handle to invalidated ops}} + %2 = transform.structured.match ops{["test.match_anchor_2"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %3 = test_produce_value_handle_to_argument_of_parent_block %2, 0 : (!transform.any_op) -> !transform.any_value + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + test_consume_operand %3 : !transform.any_value + // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + test_consume_operand %2 : !transform.any_op +} + +// ----- + +"test.op_with_regions"() ({ +// expected-note @below {{consumed handle points to this payload value}} +^bb(%arg0: i32): + // expected-note @below {{ancestor payload op}} + "test.op_with_regions"() ({ + ^bb0: + // expected-note @below {{nested payload op}} + "test.match_anchor_2"() : () -> () + "test.region_terminator"() : () -> () + }): () -> () + "test.match_anchor_1"() : () -> () +}) : () -> () + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %1 = transform.structured.match ops{["test.match_anchor_1"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-note @below {{handle to invalidated ops}} + %2 = transform.structured.match ops{["test.match_anchor_2"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %3 = test_produce_value_handle_to_argument_of_parent_block %1, 0 : (!transform.any_op) -> !transform.any_value + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + test_consume_operand %3 : !transform.any_value + // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + test_consume_operand %2 : !transform.any_op +} + +// ----- + +// Removing a block argument does not invalidate handles to operations in another block. +// Not expecting an error here. + +"test.op_with_regions"() ({ +^bb1(%arg0: i32): + "test.match_anchor_1"() : () -> () +^bb2: + "test.match_anchor_2"() : () -> () +}) : () -> () + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %1 = transform.structured.match ops{["test.match_anchor_1"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.match ops{["test.match_anchor_2"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %3 = test_produce_value_handle_to_argument_of_parent_block %1, 0 : (!transform.any_op) -> !transform.any_value + test_consume_operand %3 : !transform.any_value + test_consume_operand %2 : !transform.any_op +} diff --git a/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir b/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir index 447c6b4be925..51478366df3a 100644 --- a/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir +++ b/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir @@ -37,6 +37,17 @@ func.func @foo() { // ----- +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_value): + // expected-error @above {{wrong kind of value provided for the top-level value handle}} +} + +func.func @foo() { + return +} + +// ----- + // expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}} transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op): diff --git a/mlir/test/Dialect/Transform/multi-arg-top-level-values.mlir b/mlir/test/Dialect/Transform/multi-arg-top-level-values.mlir new file mode 100644 index 000000000000..431b0c524cb1 --- /dev/null +++ b/mlir/test/Dialect/Transform/multi-arg-top-level-values.mlir @@ -0,0 +1,45 @@ +// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-results-of-ops=test.some_returning_op bind-second-extra-to-results-of-ops=test.some_other_returning_op})' \ +// RUN: --split-input-file --verify-diagnostics + +// Note that diagnostic checker will merge two diagnostics with the same message +// at the same location, so only check the remark once. +// +// expected-remark @below {{first extra}} +// expected-note @below {{value handle points to an op result #0}} +// expected-note @below {{value handle points to an op result #1}} +%0:2 = "test.some_returning_op"() : () -> (i32, i64) + +// expected-remark @below {{first extra}} +// expected-note @below {{value handle points to an op result #0}} +%1 = "test.some_returning_op"() : () -> index + +// Note that diagnostic checker will merge two diagnostics with the same message +// at the same location, so only check the remark once. +// +// expected-remark @below {{second extra}} +// expected-note @below {{value handle points to an op result #0}} +// expected-note @below {{value handle points to an op result #1}} +%2:2 = "test.some_other_returning_op"() : () -> (f32, f64) + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, %arg1: !transform.any_value, %arg2: !transform.any_value): + test_print_remark_at_operand_value %arg1, "first extra" : !transform.any_value + test_print_remark_at_operand_value %arg2, "second extra" : !transform.any_value +} + +// ----- + +%0:2 = "test.some_returning_op"() : () -> (i32, i64) +%1 = "test.some_returning_op"() : () -> index + +transform.sequence failures(propagate) { +// expected-error @below {{wrong kind of value provided for top-level operation handle}} +^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_value): +} + +// ----- + +// expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}} +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, %arg1: !transform.any_value): +} diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir index d2142db6386c..4abaa233a5d1 100644 --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -24,7 +24,7 @@ transform.sequence failures(propagate) { // ----- -// expected-error @below {{'transform.sequence' op expects trailing entry block arguments to be of type implementing TransformHandleTypeInterface or TransformParamTypeInterface}} +// expected-error @below {{'transform.sequence' op expects trailing entry block arguments to be of type implementing TransformHandleTypeInterface, TransformValueHandleTypeInterface or TransformParamTypeInterface}} // expected-note @below {{argument #1 does not}} transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op, %arg1: i64): @@ -166,11 +166,11 @@ transform.sequence failures(propagate) { transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): // expected-error @below {{result #0 has more than one potential consumer}} - %0 = test_produce_param_or_forward_operand 42 + %0 = test_produce_self_handle_or_forward_operand // expected-note @below {{used here as operand #0}} - test_consume_operand_if_matches_param_or_fail %0[42] + test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" // expected-note @below {{used here as operand #0}} - test_consume_operand_if_matches_param_or_fail %0[42] + test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" } // ----- @@ -178,13 +178,13 @@ transform.sequence failures(propagate) { transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): // expected-error @below {{result #0 has more than one potential consumer}} - %0 = test_produce_param_or_forward_operand 42 + %0 = test_produce_self_handle_or_forward_operand // expected-note @below {{used here as operand #0}} - test_consume_operand_if_matches_param_or_fail %0[42] + test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" // expected-note @below {{used here as operand #0}} transform.sequence %0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): - test_consume_operand_if_matches_param_or_fail %arg1[42] + test_consume_operand_of_op_kind_or_fail %arg1, "transform.test_produce_self_handle_or_forward_operand" } } @@ -193,13 +193,13 @@ transform.sequence failures(propagate) { transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): // expected-error @below {{result #0 has more than one potential consumer}} - %0 = test_produce_param_or_forward_operand 42 + %0 = test_produce_self_handle_or_forward_operand // expected-note @below {{used here as operand #0}} - test_consume_operand_if_matches_param_or_fail %0[42] + test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" transform.sequence %0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): // expected-note @below {{used here as operand #0}} - test_consume_operand_if_matches_param_or_fail %0[42] + test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" } } @@ -208,15 +208,15 @@ transform.sequence failures(propagate) { transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): // expected-error @below {{result #0 has more than one potential consumer}} - %0 = test_produce_param_or_forward_operand 42 + %0 = test_produce_self_handle_or_forward_operand // expected-note @below {{used here as operand #0}} - test_consume_operand_if_matches_param_or_fail %0[42] + test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" // expected-note @below {{used here as operand #0}} transform.sequence %0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): transform.sequence %arg1 : !pdl.operation failures(propagate) { ^bb2(%arg2: !pdl.operation): - test_consume_operand_if_matches_param_or_fail %arg2[42] + test_consume_operand_of_op_kind_or_fail %arg2, "transform.test_produce_self_handle_or_forward_operand" } } } @@ -257,14 +257,14 @@ transform.alternatives { transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): // expected-error @below {{result #0 has more than one potential consumer}} - %0 = test_produce_param_or_forward_operand 42 + %0 = test_produce_self_handle_or_forward_operand // expected-note @below {{used here as operand #0}} transform.foreach %0 : !pdl.operation { ^bb1(%arg1: !pdl.operation): - transform.test_consume_operand %arg1 + transform.test_consume_operand %arg1 : !pdl.operation } // expected-note @below {{used here as operand #0}} - transform.test_consume_operand %0 + transform.test_consume_operand %0 : !pdl.operation } // ----- diff --git a/mlir/test/Dialect/Transform/test-dialect-injection.mlir b/mlir/test/Dialect/Transform/test-dialect-injection.mlir index f1c276254485..5bbda8e1fe99 100644 --- a/mlir/test/Dialect/Transform/test-dialect-injection.mlir +++ b/mlir/test/Dialect/Transform/test-dialect-injection.mlir @@ -6,11 +6,11 @@ // CHECK: transform.test_transform_op transform.test_transform_op -// CHECK: = transform.test_produce_param_or_forward_operand 42 {foo = "bar"} -%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" } +// CHECK: = transform.test_produce_self_handle_or_forward_operand {foo = "bar"} +%0 = transform.test_produce_self_handle_or_forward_operand { foo = "bar" } -// CHECK: transform.test_consume_operand_if_matches_param_or_fail %{{.*}}[42] -transform.test_consume_operand_if_matches_param_or_fail %0[42] +// CHECK: transform.test_consume_operand_of_op_kind_or_fail %{{.*}}, +transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" // Ensure that the extension type is roundtripped correctly. // CHECK: transform.cast %{{.*}} : !pdl.operation to !transform.test_dialect_op diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir index f470606c78cf..e8bc530f6a54 100644 --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -10,18 +10,18 @@ transform.sequence failures(propagate) { transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): - %0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" } + %0 = transform.test_produce_self_handle_or_forward_operand { foo = "bar" } // expected-remark @below {{succeeded}} - transform.test_consume_operand_if_matches_param_or_fail %0[42] + transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" } // ----- transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): - %0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" } - // expected-error @below {{expected the operand to be associated with 21 got 42}} - transform.test_consume_operand_if_matches_param_or_fail %0[21] + %0 = transform.test_produce_self_handle_or_forward_operand { foo = "bar" } + // expected-error @below {{expected the operand to be associated a payload op of kind transform.sequence got transform.test_produce_self_handle_or_forward_operand}} + transform.test_consume_operand_of_op_kind_or_fail %0, "transform.sequence" } // ----- @@ -31,10 +31,10 @@ transform.sequence failures(propagate) { // to detect double-consumption. transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): - %0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" } + %0 = transform.test_produce_self_handle_or_forward_operand { foo = "bar" } %1 = transform.test_copy_payload %0 // expected-remark @below {{succeeded}} - transform.test_consume_operand_if_matches_param_or_fail %0[42] + transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" } // ----- @@ -60,11 +60,11 @@ transform.sequence failures(propagate) { transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): - %0 = test_produce_param_or_forward_operand 42 + %0 = test_produce_self_handle_or_forward_operand sequence %0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): // expected-remark @below {{succeeded}} - test_consume_operand_if_matches_param_or_fail %arg1[42] + test_consume_operand_of_op_kind_or_fail %arg1, "transform.test_produce_self_handle_or_forward_operand" } } @@ -74,11 +74,11 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): %0 = sequence %arg0 : !pdl.operation -> !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): - %1 = test_produce_param_or_forward_operand 42 + %1 = test_produce_self_handle_or_forward_operand yield %1 : !pdl.operation } // expected-remark @below {{succeeded}} - test_consume_operand_if_matches_param_or_fail %0[42] + test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" } // ----- @@ -163,15 +163,15 @@ transform.with_pdl_patterns { %0 = pdl_match @match_func in %arg1 : (!pdl.operation) -> !pdl.operation transform.alternatives %0 : !pdl.operation { ^bb2(%arg2: !pdl.operation): - %1 = transform.test_produce_param_or_forward_operand 42 + %1 = transform.test_produce_self_handle_or_forward_operand // This operation fails, which triggers the next alternative without // reporting the error. - transform.test_consume_operand_if_matches_param_or_fail %1[43] + transform.test_consume_operand_of_op_kind_or_fail %1, "transform.sequence" }, { ^bb2(%arg2: !pdl.operation): - %1 = transform.test_produce_param_or_forward_operand 42 + %1 = transform.test_produce_self_handle_or_forward_operand // expected-remark @below {{succeeded}} - transform.test_consume_operand_if_matches_param_or_fail %1[42] + transform.test_consume_operand_of_op_kind_or_fail %1, "transform.test_produce_self_handle_or_forward_operand" } } } @@ -315,17 +315,18 @@ transform.with_pdl_patterns { %3 = transform.pdl_match @match_call in %arg2 : (!pdl.operation) -> !pdl.operation // expected-remark @below {{applying}} transform.test_emit_remark_and_erase_operand %3, "applying" {fail_after_erase} - %4 = transform.test_produce_param_or_forward_operand 43 + %4 = transform.test_produce_self_handle_or_forward_operand %3 transform.yield %4 : !pdl.operation }, { ^bb2(%arg2: !pdl.operation): - %4 = transform.test_produce_param_or_forward_operand 42 + %4 = transform.test_produce_self_handle_or_forward_operand transform.yield %4 : !pdl.operation } // The first alternative failed, so the returned value is taken from the - // second alternative. + // second alternative, associated test_produce_self_handle_or_forward_operand rather + // than pdl_match. // expected-remark @below {{succeeded}} - transform.test_consume_operand_if_matches_param_or_fail %2[42] + transform.test_consume_operand_of_op_kind_or_fail %2, "transform.test_produce_self_handle_or_forward_operand" } } @@ -349,12 +350,12 @@ module { // expected-error @below {{scope must not contain the transforms being applied}} transform.alternatives %arg1 : !pdl.operation { ^bb2(%arg2: !pdl.operation): - %0 = transform.test_produce_param_or_forward_operand 42 - transform.test_consume_operand_if_matches_param_or_fail %0[43] + %0 = transform.test_produce_self_handle_or_forward_operand + transform.test_consume_operand_of_op_kind_or_fail %0, "transform.sequence" }, { ^bb2(%arg2: !pdl.operation): - %0 = transform.test_produce_param_or_forward_operand 42 - transform.test_consume_operand_if_matches_param_or_fail %0[42] + %0 = transform.test_produce_self_handle_or_forward_operand + transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" } } } @@ -1094,6 +1095,14 @@ transform.sequence failures(propagate) { // ----- +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + // expected-error @below {{attempting to assign a null payload value to this transform handle}} + %0 = transform.test_produce_null_value : !transform.any_value +} + +// ----- + // expected-error @below {{could not find a nested top-level transform op}} // expected-note @below {{use the 'transform-file-name' option to provide transform as external file}} module { @@ -1106,7 +1115,65 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): } -// expected-error @below {{ore than one top-level transform op}} +// expected-error @below {{more than one top-level transform op}} transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): } + +// ----- + +transform.sequence failures(propagate) { +// expected-remark @below {{value handle}} +// expected-note @below {{value handle points to a block argument #0 in block #0 in region #0}} +^bb1(%arg0: !transform.any_op): + %0 = test_produce_value_handle_to_self_operand %arg0 : (!transform.any_op) -> !transform.any_value + test_print_remark_at_operand_value %0, "value handle" : !transform.any_value +} + +// ----- + +// expected-remark @below {{result handle}} +// expected-note @below {{value handle points to an op result #1}} +%0:2 = "test.get_two_results"() : () -> (i32, i32) +// expected-remark @below {{result handle}} +// expected-note @below {{value handle points to an op result #1}} +%1:3 = "test.get_three_results"() : () -> (i32, i32, f32) + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %2 = transform.structured.match ops{["test.get_two_results", "test.get_three_results"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %3 = test_produce_value_handle_to_result %2, 1 : (!transform.any_op) -> !transform.any_value + test_print_remark_at_operand_value %3, "result handle" : !transform.any_value +} + +// ----- + +"test.op_with_regions"() ({ +^bb0: + "test.regon_terminator"() : () -> () +}, { +^bb1: + "test.regon_terminator"() : () -> () +// expected-remark @below {{block argument handle}} +// expected-note @below {{value handle points to a block argument #2 in block #1 in region #1}} +^bb2(%arg0: i32, %arg1: f64, %arg3: index): + "test.match_anchor"() : () -> () + "test.regon_terminator"() : () -> () +}) : () -> () + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %2 = transform.structured.match ops{["test.match_anchor"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %3 = test_produce_value_handle_to_argument_of_parent_block %2, 2 : (!transform.any_op) -> !transform.any_value + test_print_remark_at_operand_value %3, "block argument handle" : !transform.any_value +} + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + // expected-note @below {{value defined here with type '!transform.test_dialect_param'}} + %0 = test_produce_param_with_number_of_test_ops %arg0 : !transform.any_op + // expected-error @below {{unexpectedly consumed a value that is not a handle as operand #0}} + test_consume_operand %0 : !transform.test_dialect_param +} diff --git a/mlir/test/Dialect/Transform/transform-state-extension.mlir b/mlir/test/Dialect/Transform/transform-state-extension.mlir index 1f29684b35e3..054ee077496c 100644 --- a/mlir/test/Dialect/Transform/transform-state-extension.mlir +++ b/mlir/test/Dialect/Transform/transform-state-extension.mlir @@ -45,6 +45,18 @@ module { } } +// ----- + +// expected-error @below {{cannot replace an op with another op producing a different number of results while tracking handles}} +module { + transform.sequence failures(propagate) { + ^bb0(%arg0: !pdl.operation): + test_add_test_extension "A" + %dummy = test_remap_operand_to_self %arg0 : !transform.any_op + } +} + + // ----- module { diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index 0bd3031df913..e32cb3ec891a 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -106,29 +106,73 @@ public: } // namespace DiagnosedSilenceableFailure -mlir::test::TestProduceParamOrForwardOperandOp::apply( +mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply( transform::TransformResults &results, transform::TransformState &state) { if (getOperation()->getNumOperands() != 0) { results.set(getResult().cast(), getOperation()->getOperand(0).getDefiningOp()); } else { - results.set(getResult().cast(), - reinterpret_cast(*getParameter())); + results.set(getResult().cast(), getOperation()); } return DiagnosedSilenceableFailure::success(); } -void mlir::test::TestProduceParamOrForwardOperandOp::getEffects( +void mlir::test::TestProduceSelfHandleOrForwardOperandOp::getEffects( SmallVectorImpl &effects) { if (getOperand()) transform::onlyReadsHandle(getOperand(), effects); transform::producesHandle(getRes(), effects); } -LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() { - if (getParameter().has_value() ^ (getNumOperands() != 1)) - return emitOpError() << "expects either a parameter or an operand"; - return success(); +DiagnosedSilenceableFailure +mlir::test::TestProduceValueHandleToSelfOperand::apply( + transform::TransformResults &results, transform::TransformState &state) { + results.setValues(getOut().cast(), getIn()); + return DiagnosedSilenceableFailure::success(); +} + +void mlir::test::TestProduceValueHandleToSelfOperand::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getIn(), effects); + transform::producesHandle(getOut(), effects); + transform::onlyReadsPayload(effects); +} + +DiagnosedSilenceableFailure +mlir::test::TestProduceValueHandleToResult::applyToOne( + Operation *target, transform::ApplyToEachResultList &results, + transform::TransformState &state) { + if (target->getNumResults() <= getNumber()) + return emitSilenceableError() << "payload has no result #" << getNumber(); + results.push_back(target->getResult(getNumber())); + return DiagnosedSilenceableFailure::success(); +} + +void mlir::test::TestProduceValueHandleToResult::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getIn(), effects); + transform::producesHandle(getOut(), effects); + transform::onlyReadsPayload(effects); +} + +DiagnosedSilenceableFailure +mlir::test::TestProduceValueHandleToArgumentOfParentBlock::applyToOne( + Operation *target, transform::ApplyToEachResultList &results, + transform::TransformState &state) { + if (!target->getBlock()) + return emitSilenceableError() << "payload has no parent block"; + if (target->getBlock()->getNumArguments() <= getNumber()) + return emitSilenceableError() + << "parent of the payload has no argument #" << getNumber(); + results.push_back(target->getBlock()->getArgument(getNumber())); + return DiagnosedSilenceableFailure::success(); +} + +void mlir::test::TestProduceValueHandleToArgumentOfParentBlock::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getIn(), effects); + transform::producesHandle(getOut(), effects); + transform::onlyReadsPayload(effects); } DiagnosedSilenceableFailure @@ -145,23 +189,21 @@ void mlir::test::TestConsumeOperand::getEffects( transform::modifiesPayload(effects); } -DiagnosedSilenceableFailure -mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply( +DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply( transform::TransformResults &results, transform::TransformState &state) { ArrayRef payload = state.getPayloadOps(getOperand()); assert(payload.size() == 1 && "expected a single target op"); - auto value = reinterpret_cast(payload[0]); - if (static_cast(value) != getParameter()) { + if (payload[0]->getName().getStringRef() != getOpKind()) { return emitSilenceableError() - << "op expected the operand to be associated with " << getParameter() - << " got " << value; + << "op expected the operand to be associated a payload op of kind " + << getOpKind() << " got " << payload[0]->getName().getStringRef(); } emitRemark() << "succeeded"; return DiagnosedSilenceableFailure::success(); } -void mlir::test::TestConsumeOperandIfMatchesParamOrFail::getEffects( +void mlir::test::TestConsumeOperandOfOpKindOrFail::getEffects( SmallVectorImpl &effects) { transform::consumesHandle(getOperand(), effects); transform::modifiesPayload(effects); @@ -182,6 +224,32 @@ void mlir::test::TestPrintRemarkAtOperandOp::getEffects( transform::onlyReadsPayload(effects); } +DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandValue::apply( + transform::TransformResults &results, transform::TransformState &state) { + ArrayRef values = state.getPayloadValues(getIn()); + for (Value value : values) { + std::string note; + llvm::raw_string_ostream os(note); + if (auto arg = value.dyn_cast()) { + os << "a block argument #" << arg.getArgNumber() << " in block #" + << std::distance(arg.getOwner()->getParent()->begin(), + arg.getOwner()->getIterator()) + << " in region #" << arg.getOwner()->getParent()->getRegionNumber(); + } else { + os << "an op result #" << value.cast().getResultNumber(); + } + InFlightDiagnostic diag = ::emitRemark(value.getLoc()) << getMessage(); + diag.attachNote() << "value handle points to " << os.str(); + } + return DiagnosedSilenceableFailure::success(); +} + +void mlir::test::TestPrintRemarkAtOperandValue::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getIn(), effects); + transform::onlyReadsPayload(effects); +} + DiagnosedSilenceableFailure mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results, transform::TransformState &state) { @@ -235,6 +303,7 @@ DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply( void mlir::test::TestRemapOperandPayloadToSelfOp::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getOperand(), effects); + transform::producesHandle(getOut(), effects); transform::onlyReadsPayload(effects); } @@ -528,6 +597,18 @@ mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results, return DiagnosedSilenceableFailure::success(); } +void mlir::test::TestProduceNullValueOp::getEffects( + SmallVectorImpl &effects) { + transform::producesHandle(getOut(), effects); +} + +DiagnosedSilenceableFailure +mlir::test::TestProduceNullValueOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + results.setValues(getOut().cast(), Value()); + return DiagnosedSilenceableFailure::success(); +} + void mlir::test::TestRequiredMemoryEffectsOp::getEffects( SmallVectorImpl &effects) { if (getHasOperandEffect()) diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td index cc67c2ac2afd..4c9b3d58ffcb 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -39,37 +39,79 @@ def TestTransformTestDialectParamType let assemblyFormat = ""; } -def TestProduceParamOrForwardOperandOp - : Op, DeclareOpInterfaceMethods]> { - let arguments = (ins - Optional:$operand, - OptionalAttr:$parameter); + let arguments = (ins Optional:$operand); let results = (outs PDL_Operation:$res); - let assemblyFormat = "(`from` $operand^)? ($parameter^)? attr-dict"; + let assemblyFormat = "($operand^)? attr-dict"; let cppNamespace = "::mlir::test"; - let hasVerifier = 1; +} + +def TestProduceValueHandleToSelfOperand + : Op, + DeclareOpInterfaceMethods]> { + let arguments = (ins TransformHandleTypeInterface:$in); + let results = (outs TransformValueHandleTypeInterface:$out); + let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)"; + let cppNamespace = "::mlir::test"; + +} + +def TestProduceValueHandleToResult + : Op]> { + let arguments = (ins TransformHandleTypeInterface:$in, + I64Attr:$number); + let results = (outs TransformValueHandleTypeInterface:$out); + let assemblyFormat = "$in `,` $number attr-dict `:` functional-type(operands, results)"; + let cppNamespace = "::mlir::test"; + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +def TestProduceValueHandleToArgumentOfParentBlock + : Op]> { + let arguments = (ins TransformHandleTypeInterface:$in, + I64Attr:$number); + let results = (outs TransformValueHandleTypeInterface:$out); + let assemblyFormat = "$in `,` $number attr-dict `:` functional-type(operands, results)"; + let cppNamespace = "::mlir::test"; + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; } def TestConsumeOperand : Op, DeclareOpInterfaceMethods]> { let arguments = (ins - PDL_Operation:$operand, + Transform_AnyHandleOrParamType:$operand, Optional:$second_operand); - let assemblyFormat = "$operand (`,` $second_operand^)? attr-dict"; + let assemblyFormat = "$operand (`,` $second_operand^)? attr-dict `:` type($operand)"; let cppNamespace = "::mlir::test"; } -def TestConsumeOperandIfMatchesParamOrFail - : Op, DeclareOpInterfaceMethods]> { let arguments = (ins PDL_Operation:$operand, - I64Attr:$parameter); - let assemblyFormat = "$operand `[` $parameter `]` attr-dict"; + StrAttr:$op_kind); + let assemblyFormat = "$operand `,` $op_kind attr-dict"; let cppNamespace = "::mlir::test"; } @@ -85,6 +127,16 @@ def TestPrintRemarkAtOperandOp let cppNamespace = "::mlir::test"; } +def TestPrintRemarkAtOperandValue + : Op, + DeclareOpInterfaceMethods]> { + let arguments = (ins TransformValueHandleTypeInterface:$in, + StrAttr:$message); + let assemblyFormat = "$in `,` $message attr-dict `:` type($in)"; + let cppNamespace = "::mlir::test"; +} + def TestAddTestExtensionOp : Op, @@ -107,8 +159,9 @@ def TestRemapOperandPayloadToSelfOp : Op, DeclareOpInterfaceMethods]> { - let arguments = (ins PDL_Operation:$operand); - let assemblyFormat = "$operand attr-dict"; + let arguments = (ins PDL_Operation:$operand); + let results = (outs Optional:$out); + let assemblyFormat = "$operand attr-dict (`:` type($out)^)?"; let cppNamespace = "::mlir::test"; } @@ -349,6 +402,15 @@ def TestProduceNullParamOp let cppNamespace = "::mlir::test"; } +def TestProduceNullValueOp + : Op, + DeclareOpInterfaceMethods]> { + let results = (outs TransformValueHandleTypeInterface:$out); + let assemblyFormat = "attr-dict `:` type($out)"; + let cppNamespace = "::mlir::test"; +} + def TestRequiredMemoryEffectsOp : Op, diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp index f9823e146f50..7beae91bb800 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp @@ -46,65 +46,93 @@ public: return "apply transform dialect operations one by one"; } - ArrayRef - findOperationsByName(Operation *root, StringRef name, - SmallVectorImpl &storage) { - size_t start = storage.size(); + void findOperationsByName(Operation *root, StringRef name, + SmallVectorImpl &operations) { root->walk([&](Operation *op) { if (op->getName().getStringRef() == name) { - storage.push_back(op); + operations.push_back(op); } }); - return ArrayRef(storage).drop_front(start); } - ArrayRef - createParameterMapping(MLIRContext &context, ArrayRef values, - SmallVectorImpl &storage) { - size_t start = storage.size(); - llvm::append_range(storage, llvm::map_range(values, [&](int v) { - Builder b(&context); - return transform::MappedValue(b.getI64IntegerAttr(v)); - })); - return ArrayRef(storage).drop_front(start); + void createParameterMapping(MLIRContext &context, ArrayRef values, + RaggedArray &result) { + SmallVector storage = + llvm::to_vector(llvm::map_range(values, [&](int v) { + Builder b(&context); + return transform::MappedValue(b.getI64IntegerAttr(v)); + })); + result.push_back(std::move(storage)); + } + + void + createOpResultMapping(Operation *root, StringRef name, + RaggedArray &extraMapping) { + SmallVector operations; + findOperationsByName(root, name, operations); + SmallVector results; + for (Operation *op : operations) + llvm::append_range(results, op->getResults()); + extraMapping.push_back(results); + } + + unsigned numberOfSetOptions(const Option &ops, + const ListOption ¶ms, + const Option &values) { + unsigned numSetValues = 0; + numSetValues += !ops.empty(); + numSetValues += !params.empty(); + numSetValues += !values.empty(); + return numSetValues; } void runOnOperation() override { - if (!bindFirstExtraToOps.empty() && !bindFirstExtraToParams.empty()) { - emitError(UnknownLoc::get(&getContext())) - << "cannot bind the first extra top-level argument to both " - "operations and parameters"; + unsigned firstSetOptions = + numberOfSetOptions(bindFirstExtraToOps, bindFirstExtraToParams, + bindFirstExtraToResultsOfOps); + unsigned secondSetOptions = + numberOfSetOptions(bindSecondExtraToOps, bindSecondExtraToParams, + bindSecondExtraToResultsOfOps); + auto loc = UnknownLoc::get(&getContext()); + if (firstSetOptions > 1) { + emitError(loc) << "cannot bind the first extra top-level argument to " + "multiple entities"; return signalPassFailure(); } - if (!bindSecondExtraToOps.empty() && !bindSecondExtraToParams.empty()) { - emitError(UnknownLoc::get(&getContext())) - << "cannot bind the second extra top-level argument to both " - "operations and parameters"; + if (secondSetOptions > 1) { + emitError(loc) << "cannot bind the second extra top-level argument to " + "multiple entities"; return signalPassFailure(); } - if ((!bindSecondExtraToOps.empty() || !bindSecondExtraToParams.empty()) && - bindFirstExtraToOps.empty() && bindFirstExtraToParams.empty()) { - emitError(UnknownLoc::get(&getContext())) - << "cannot bind the second extra top-level argument without binding " - "the first"; - return signalPassFailure(); + if (firstSetOptions == 0 && secondSetOptions != 0) { + emitError(loc) << "cannot bind the second extra top-level argument " + "without bindings the first"; } - SmallVector extraMappingStorage; - SmallVector> extraMapping; + RaggedArray extraMapping; if (!bindFirstExtraToOps.empty()) { - extraMapping.push_back(findOperationsByName( - getOperation(), bindFirstExtraToOps.getValue(), extraMappingStorage)); + SmallVector operations; + findOperationsByName(getOperation(), bindFirstExtraToOps.getValue(), + operations); + extraMapping.push_back(operations); } else if (!bindFirstExtraToParams.empty()) { - extraMapping.push_back(createParameterMapping( - getContext(), bindFirstExtraToParams, extraMappingStorage)); + createParameterMapping(getContext(), bindFirstExtraToParams, + extraMapping); + } else if (!bindFirstExtraToResultsOfOps.empty()) { + createOpResultMapping(getOperation(), bindFirstExtraToResultsOfOps, + extraMapping); } + if (!bindSecondExtraToOps.empty()) { - extraMapping.push_back(findOperationsByName( - getOperation(), bindSecondExtraToOps, extraMappingStorage)); + SmallVector operations; + findOperationsByName(getOperation(), bindSecondExtraToOps, operations); + extraMapping.push_back(operations); } else if (!bindSecondExtraToParams.empty()) { - extraMapping.push_back(createParameterMapping( - getContext(), bindSecondExtraToParams, extraMappingStorage)); + createParameterMapping(getContext(), bindSecondExtraToParams, + extraMapping); + } else if (!bindSecondExtraToResultsOfOps.empty()) { + createOpResultMapping(getOperation(), bindSecondExtraToResultsOfOps, + extraMapping); } options = options.enableExpensiveChecks(enableExpensiveChecks); @@ -128,6 +156,10 @@ public: *this, "bind-first-extra-to-params", llvm::cl::desc("bind the first extra argument of the top-level op to " "the given integer parameters")}; + Option bindFirstExtraToResultsOfOps{ + *this, "bind-first-extra-to-results-of-ops", + llvm::cl::desc("bind the first extra argument of the top-level op to " + "results of payload operations of the given kind")}; Option bindSecondExtraToOps{ *this, "bind-second-extra-to-ops", @@ -137,6 +169,11 @@ public: *this, "bind-second-extra-to-params", llvm::cl::desc("bind the second extra argument of the top-level op to " "the given integer parameters")}; + Option bindSecondExtraToResultsOfOps{ + *this, "bind-second-extra-to-results-of-ops", + llvm::cl::desc("bind the second extra argument of the top-level op to " + "results of payload operations of the given kind")}; + Option transformFileName{ *this, "transform-file-name", llvm::cl::init(""), llvm::cl::desc(