mirror of
https://github.com/intel/llvm.git
synced 2026-01-15 12:25:46 +08:00
Unify the 'constantFold' and 'fold' hooks on an operation into just 'fold'. This new unified fold hook will take constant attributes as operands, and may return an existing 'Value *' or a constant 'Attribute' when folding. This removes the awkward situation where a simple canonicalization like "sub(x,x)->0" had to be written as a canonicalization pattern as opposed to a fold.
-- PiperOrigin-RevId: 248582024
This commit is contained in:
committed by
Mehdi Amini
parent
13dbad87f6
commit
1982afb145
@@ -371,11 +371,6 @@ This boolean field indicate whether canonicalization patterns have been defined
|
||||
for this operation. If it is `1`, then `::getCanonicalizationPatterns()` should
|
||||
be defined.
|
||||
|
||||
### `hasConstantFolder`
|
||||
|
||||
This boolean field indicate whether constant folding rules have been defined
|
||||
for this operation. If it is `1`, then `::constantFold()` should be defined.
|
||||
|
||||
### `hasFolder`
|
||||
|
||||
This boolean field indicate whether general folding rules have been defined
|
||||
|
||||
@@ -74,11 +74,10 @@ in either way.
|
||||
Operations can also have custom parser, printer, builder, verifier, constant
|
||||
folder, or canonicalizer. These require specifying additional C++ methods to
|
||||
invoke for additional functionality. For example, if an operation is marked to
|
||||
have a constant folder, the constant folder also needs to be added, e.g.,:
|
||||
have a folder, the constant folder also needs to be added, e.g.,:
|
||||
|
||||
```c++
|
||||
Attribute SpecificOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) {
|
||||
OpFoldResult SpecificOp::fold(ArrayRef<Attribute> constOperands) {
|
||||
if (unable_to_fold)
|
||||
return {};
|
||||
....
|
||||
|
||||
@@ -83,7 +83,7 @@ public:
|
||||
static ParseResult parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p);
|
||||
LogicalResult verify();
|
||||
Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
|
||||
OpFoldResult fold(ArrayRef<Attribute> operands);
|
||||
|
||||
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context);
|
||||
|
||||
@@ -761,7 +761,7 @@ private:
|
||||
|
||||
namespace llvm {
|
||||
|
||||
// Attribute hash just like pointers
|
||||
// Attribute hash just like pointers.
|
||||
template <> struct DenseMapInfo<mlir::Attribute> {
|
||||
static mlir::Attribute getEmptyKey() {
|
||||
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
@@ -779,6 +779,18 @@ template <> struct DenseMapInfo<mlir::Attribute> {
|
||||
}
|
||||
};
|
||||
|
||||
/// Allow LLVM to steal the low bits of Attributes.
|
||||
template <> struct PointerLikeTypeTraits<mlir::Attribute> {
|
||||
public:
|
||||
static inline void *getAsVoidPointer(mlir::Attribute attr) {
|
||||
return const_cast<void *>(attr.getAsOpaquePointer());
|
||||
}
|
||||
static inline mlir::Attribute getFromVoidPointer(void *ptr) {
|
||||
return mlir::Attribute::getFromOpaquePointer(ptr);
|
||||
}
|
||||
enum { NumLowBitsAvailable = 3 };
|
||||
};
|
||||
|
||||
} // namespace llvm
|
||||
|
||||
#endif
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
#ifndef MLIR_MATCHERS_H
|
||||
#define MLIR_MATCHERS_H
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include <type_traits>
|
||||
|
||||
@@ -73,9 +73,9 @@ struct constant_op_binder {
|
||||
if (!op->hasNoSideEffect())
|
||||
return false;
|
||||
|
||||
SmallVector<Attribute, 1> foldedAttr;
|
||||
if (succeeded(op->constantFold(/*operands=*/llvm::None, foldedAttr))) {
|
||||
*bind_value = foldedAttr.front();
|
||||
SmallVector<OpFoldResult, 1> foldedOp;
|
||||
if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) {
|
||||
*bind_value = foldedOp.front().dyn_cast<Attribute>();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
||||
@@ -925,9 +925,6 @@ class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
|
||||
// and C++ implementations.
|
||||
bit hasCanonicalizer = 0;
|
||||
|
||||
// Whether this op has a constant folder.
|
||||
bit hasConstantFolder = 0;
|
||||
|
||||
// Whether this op has a folder.
|
||||
bit hasFolder = 0;
|
||||
|
||||
|
||||
@@ -170,47 +170,28 @@ inline bool operator!=(OpState lhs, OpState rhs) {
|
||||
return lhs.getOperation() != rhs.getOperation();
|
||||
}
|
||||
|
||||
/// This template defines the constantFoldHook and foldHook as used by
|
||||
/// AbstractOperation.
|
||||
/// This class represents a single result from folding an operation.
|
||||
class OpFoldResult : public llvm::PointerUnion<Attribute, Value *> {
|
||||
using llvm::PointerUnion<Attribute, Value *>::PointerUnion;
|
||||
};
|
||||
|
||||
/// This template defines the foldHook as used by AbstractOperation.
|
||||
///
|
||||
/// The default implementation uses a general constantFold/fold method that can
|
||||
/// be defined on custom ops which can return multiple results.
|
||||
/// The default implementation uses a general fold method that can be defined on
|
||||
/// custom ops which can return multiple results.
|
||||
template <typename ConcreteType, bool isSingleResult, typename = void>
|
||||
class FoldingHook {
|
||||
public:
|
||||
/// This is an implementation detail of the constant folder hook for
|
||||
/// AbstractOperation.
|
||||
static LogicalResult constantFoldHook(Operation *op,
|
||||
ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results) {
|
||||
return cast<ConcreteType>(op).constantFold(operands, results,
|
||||
op->getContext());
|
||||
}
|
||||
|
||||
/// Op implementations can implement this hook. It should attempt to constant
|
||||
/// fold this operation with the specified constant operand values - the
|
||||
/// elements in "operands" will correspond directly to the operands of the
|
||||
/// operation, but may be null if non-constant. If constant folding is
|
||||
/// successful, this fills in the `results` vector. If not, `results` is
|
||||
/// unspecified.
|
||||
///
|
||||
/// If not overridden, this fallback implementation always fails to fold.
|
||||
///
|
||||
LogicalResult constantFold(ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results,
|
||||
MLIRContext *context) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
/// This is an implementation detail of the folder hook for AbstractOperation.
|
||||
static LogicalResult foldHook(Operation *op,
|
||||
SmallVectorImpl<Value *> &results) {
|
||||
return cast<ConcreteType>(op).fold(results);
|
||||
static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<OpFoldResult> &results) {
|
||||
return cast<ConcreteType>(op).fold(operands, results);
|
||||
}
|
||||
|
||||
/// This hook implements a generalized folder for this operation. Operations
|
||||
/// can implement this to provide simplifications rules that are applied by
|
||||
/// the FuncBuilder::foldOrCreate API and the canonicalization pass.
|
||||
/// the Builder::foldOrCreate API and the canonicalization pass.
|
||||
///
|
||||
/// This is an intentionally limited interface - implementations of this hook
|
||||
/// can only perform the following changes to the operation:
|
||||
@@ -225,23 +206,25 @@ public:
|
||||
/// instead.
|
||||
///
|
||||
/// This allows expression of some simple in-place canonicalizations (e.g.
|
||||
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), but does
|
||||
/// not allow for canonicalizations that need to introduce new operations, not
|
||||
/// even constants (e.g. "x-x -> 0" cannot be expressed).
|
||||
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
|
||||
/// generalized constant folding.
|
||||
///
|
||||
/// If not overridden, this fallback implementation always fails to fold.
|
||||
///
|
||||
LogicalResult fold(SmallVectorImpl<Value *> &results) { return failure(); }
|
||||
LogicalResult fold(ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<OpFoldResult> &results) {
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
/// This template specialization defines the constantFoldHook and foldHook as
|
||||
/// used by AbstractOperation for single-result operations. This gives the hook
|
||||
/// a nicer signature that is easier to implement.
|
||||
/// This template specialization defines the foldHook as used by
|
||||
/// AbstractOperation for single-result operations. This gives the hook a nicer
|
||||
/// signature that is easier to implement.
|
||||
template <typename ConcreteType, bool isSingleResult>
|
||||
class FoldingHook<ConcreteType, isSingleResult,
|
||||
typename std::enable_if<isSingleResult>::type> {
|
||||
public:
|
||||
/// If the operation returns a single value, then the Op can be implicitly
|
||||
/// If the operation returns a single value, then the Op can be implicitly
|
||||
/// converted to an Value*. This yields the value of the only result.
|
||||
operator Value *() {
|
||||
return static_cast<ConcreteType *>(this)->getOperation()->getResult(0);
|
||||
@@ -249,11 +232,9 @@ public:
|
||||
|
||||
/// This is an implementation detail of the constant folder hook for
|
||||
/// AbstractOperation.
|
||||
static LogicalResult constantFoldHook(Operation *op,
|
||||
ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results) {
|
||||
auto result =
|
||||
cast<ConcreteType>(op).constantFold(operands, op->getContext());
|
||||
static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<OpFoldResult> &results) {
|
||||
auto result = cast<ConcreteType>(op).fold(operands);
|
||||
if (!result)
|
||||
return failure();
|
||||
|
||||
@@ -261,33 +242,9 @@ public:
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Op implementations can implement this hook. It should attempt to constant
|
||||
/// fold this operation with the specified constant operand values - the
|
||||
/// elements in "operands" will correspond directly to the operands of the
|
||||
/// operation, but may be null if non-constant. If constant folding is
|
||||
/// successful, this returns a non-null attribute, otherwise it returns null
|
||||
/// on failure.
|
||||
///
|
||||
/// If not overridden, this fallback implementation always fails to fold.
|
||||
///
|
||||
Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// This is an implementation detail of the folder hook for AbstractOperation.
|
||||
static LogicalResult foldHook(Operation *op,
|
||||
SmallVectorImpl<Value *> &results) {
|
||||
auto *result = cast<ConcreteType>(op).fold();
|
||||
if (!result)
|
||||
return failure();
|
||||
if (result != op->getResult(0))
|
||||
results.push_back(result);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// This hook implements a generalized folder for this operation. Operations
|
||||
/// can implement this to provide simplifications rules that are applied by
|
||||
/// the FuncBuilder::foldOrCreate API and the canonicalization pass.
|
||||
/// the Builder::foldOrCreate API and the canonicalization pass.
|
||||
///
|
||||
/// This is an intentionally limited interface - implementations of this hook
|
||||
/// can only perform the following changes to the operation:
|
||||
@@ -301,13 +258,12 @@ public:
|
||||
/// remove the operation and use that result instead.
|
||||
///
|
||||
/// This allows expression of some simple in-place canonicalizations (e.g.
|
||||
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), but does
|
||||
/// not allow for canonicalizations that need to introduce new operations, not
|
||||
/// even constants (e.g. "x-x -> 0" cannot be expressed).
|
||||
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
|
||||
/// generalized constant folding.
|
||||
///
|
||||
/// If not overridden, this fallback implementation always fails to fold.
|
||||
///
|
||||
Value *fold() { return nullptr; }
|
||||
OpFoldResult fold(ArrayRef<Attribute> operands) { return {}; }
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -374,16 +374,12 @@ public:
|
||||
return getTerminatorStatus() == TerminatorStatus::NonTerminator;
|
||||
}
|
||||
|
||||
/// Attempt to constant fold this operation with the specified constant
|
||||
/// operand values - the elements in "operands" will correspond directly to
|
||||
/// the operands of the operation, but may be null if non-constant. If
|
||||
/// constant folding is successful, this fills in the `results` vector. If
|
||||
/// not, `results` is unspecified.
|
||||
LogicalResult constantFold(ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results);
|
||||
|
||||
/// Attempt to fold this operation using the Op's registered foldHook.
|
||||
LogicalResult fold(SmallVectorImpl<Value *> &results);
|
||||
/// Attempt to fold this operation with the specified constant operand values
|
||||
/// - the elements in "operands" will correspond directly to the operands of
|
||||
/// the operation, but may be null if non-constant. If folding is successful,
|
||||
/// this fills in the `results` vector. If not, `results` is unspecified.
|
||||
LogicalResult fold(ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<OpFoldResult> &results);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Operation Walkers
|
||||
|
||||
@@ -41,6 +41,7 @@ struct OperationState;
|
||||
class OpAsmParser;
|
||||
class OpAsmParserResult;
|
||||
class OpAsmPrinter;
|
||||
class OpFoldResult;
|
||||
class ParseResult;
|
||||
class Pattern;
|
||||
class Region;
|
||||
@@ -95,14 +96,9 @@ public:
|
||||
/// success if everything is ok.
|
||||
LogicalResult (&verifyInvariants)(Operation *op);
|
||||
|
||||
/// This hook implements a constant folder for this operation. It fills in
|
||||
/// `results` on success.
|
||||
LogicalResult (&constantFoldHook)(Operation *op, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results);
|
||||
|
||||
/// This hook implements a generalized folder for this operation. Operations
|
||||
/// can implement this to provide simplifications rules that are applied by
|
||||
/// the FuncBuilder::foldOrCreate API and the canonicalization pass.
|
||||
/// the Builder::foldOrCreate API and the canonicalization pass.
|
||||
///
|
||||
/// This is an intentionally limited interface - implementations of this hook
|
||||
/// can only perform the following changes to the operation:
|
||||
@@ -117,10 +113,10 @@ public:
|
||||
/// instead.
|
||||
///
|
||||
/// This allows expression of some simple in-place canonicalizations (e.g.
|
||||
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), but does
|
||||
/// not allow for canonicalizations that need to introduce new operations, not
|
||||
/// even constants (e.g. "x-x -> 0" cannot be expressed).
|
||||
LogicalResult (&foldHook)(Operation *op, SmallVectorImpl<Value *> &results);
|
||||
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
|
||||
/// generalized constant folding.
|
||||
LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<OpFoldResult> &results);
|
||||
|
||||
/// This hook returns any canonicalization pattern rewrites that the operation
|
||||
/// supports, for use by the canonicalization pass.
|
||||
@@ -142,8 +138,8 @@ public:
|
||||
template <typename T> static AbstractOperation get(Dialect &dialect) {
|
||||
return AbstractOperation(
|
||||
T::getOperationName(), dialect, T::getOperationProperties(), T::classof,
|
||||
T::parseAssembly, T::printAssembly, T::verifyInvariants,
|
||||
T::constantFoldHook, T::foldHook, T::getCanonicalizationPatterns);
|
||||
T::parseAssembly, T::printAssembly, T::verifyInvariants, T::foldHook,
|
||||
T::getCanonicalizationPatterns);
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -153,17 +149,13 @@ private:
|
||||
ParseResult (&parseAssembly)(OpAsmParser *parser, OperationState *result),
|
||||
void (&printAssembly)(Operation *op, OpAsmPrinter *p),
|
||||
LogicalResult (&verifyInvariants)(Operation *op),
|
||||
LogicalResult (&constantFoldHook)(Operation *op,
|
||||
ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results),
|
||||
LogicalResult (&foldHook)(Operation *op,
|
||||
SmallVectorImpl<Value *> &results),
|
||||
LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<OpFoldResult> &results),
|
||||
void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
|
||||
MLIRContext *context))
|
||||
: name(name), dialect(dialect), classof(classof),
|
||||
parseAssembly(parseAssembly), printAssembly(printAssembly),
|
||||
verifyInvariants(verifyInvariants), constantFoldHook(constantFoldHook),
|
||||
foldHook(foldHook),
|
||||
verifyInvariants(verifyInvariants), foldHook(foldHook),
|
||||
getCanonicalizationPatterns(getCanonicalizationPatterns),
|
||||
opProperties(opProperties) {}
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ public:
|
||||
static ParseResult parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p);
|
||||
LogicalResult verify();
|
||||
Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
|
||||
OpFoldResult fold(ArrayRef<Attribute> operands);
|
||||
};
|
||||
|
||||
/// The predicate indicates the type of the comparison to perform:
|
||||
@@ -176,7 +176,7 @@ public:
|
||||
static ParseResult parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p);
|
||||
LogicalResult verify();
|
||||
Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
|
||||
OpFoldResult fold(ArrayRef<Attribute> operands);
|
||||
};
|
||||
|
||||
/// The "cond_br" operation represents a conditional branch operation in a
|
||||
@@ -600,7 +600,7 @@ public:
|
||||
Value *getTrueValue() { return getOperand(1); }
|
||||
Value *getFalseValue() { return getOperand(2); }
|
||||
|
||||
Value *fold();
|
||||
OpFoldResult fold(ArrayRef<Attribute> operands);
|
||||
};
|
||||
|
||||
/// The "store" op writes an element to a memref specified by an index list.
|
||||
|
||||
@@ -112,13 +112,12 @@ class FloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
|
||||
def AddFOp : FloatArithmeticOp<"addf"> {
|
||||
let summary = "floating point addition operation";
|
||||
let hasConstantFolder = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def AddIOp : IntArithmeticOp<"addi", [Commutative]> {
|
||||
let summary = "integer addition operation";
|
||||
let hasFolder = 1;
|
||||
let hasConstantFolder = 1;
|
||||
}
|
||||
|
||||
def AllocOp : Std_Op<"alloc"> {
|
||||
@@ -163,7 +162,6 @@ def AllocOp : Std_Op<"alloc"> {
|
||||
|
||||
def AndOp : IntArithmeticOp<"and", [Commutative]> {
|
||||
let summary = "integer binary and";
|
||||
let hasConstantFolder = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
@@ -288,7 +286,7 @@ def ConstantOp : Std_Op<"constant", [NoSideEffect]> {
|
||||
Attribute getValue() { return getAttr("value"); }
|
||||
}];
|
||||
|
||||
let hasConstantFolder = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def DeallocOp : Std_Op<"dealloc"> {
|
||||
@@ -338,7 +336,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
|
||||
}
|
||||
}];
|
||||
|
||||
let hasConstantFolder = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def DivFOp : FloatArithmeticOp<"divf"> {
|
||||
@@ -347,12 +345,12 @@ def DivFOp : FloatArithmeticOp<"divf"> {
|
||||
|
||||
def DivISOp : IntArithmeticOp<"divis"> {
|
||||
let summary = "signed integer division operation";
|
||||
let hasConstantFolder = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def DivIUOp : IntArithmeticOp<"diviu"> {
|
||||
let summary = "unsigned integer division operation";
|
||||
let hasConstantFolder = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
|
||||
@@ -389,7 +387,7 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
|
||||
}
|
||||
}];
|
||||
|
||||
let hasConstantFolder = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def MemRefCastOp : CastOp<"memref_cast"> {
|
||||
@@ -426,18 +424,16 @@ def MemRefCastOp : CastOp<"memref_cast"> {
|
||||
|
||||
def MulFOp : FloatArithmeticOp<"mulf"> {
|
||||
let summary = "foating point multiplication operation";
|
||||
let hasConstantFolder = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def MulIOp : IntArithmeticOp<"muli", [Commutative]> {
|
||||
let summary = "integer multiplication operation";
|
||||
let hasConstantFolder = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def OrOp : IntArithmeticOp<"or", [Commutative]> {
|
||||
let summary = "integer binary or";
|
||||
let hasConstantFolder = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
@@ -447,12 +443,12 @@ def RemFOp : FloatArithmeticOp<"remf"> {
|
||||
|
||||
def RemISOp : IntArithmeticOp<"remis"> {
|
||||
let summary = "signed integer division remainder operation";
|
||||
let hasConstantFolder = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def RemIUOp : IntArithmeticOp<"remiu"> {
|
||||
let summary = "unsigned integer division remainder operation";
|
||||
let hasConstantFolder = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def ReturnOp : Std_Op<"return", [Terminator]> {
|
||||
@@ -481,13 +477,12 @@ def ShlISOp : IntArithmeticOp<"shlis"> {
|
||||
|
||||
def SubFOp : FloatArithmeticOp<"subf"> {
|
||||
let summary = "floating point subtraction operation";
|
||||
let hasConstantFolder = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def SubIOp : IntArithmeticOp<"subi"> {
|
||||
let summary = "integer subtraction operation";
|
||||
let hasConstantFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TensorCastOp : CastOp<"tensor_cast"> {
|
||||
@@ -518,8 +513,6 @@ def TensorCastOp : CastOp<"tensor_cast"> {
|
||||
|
||||
def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
|
||||
let summary = "integer binary xor";
|
||||
let hasConstantFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//===- ConstantFoldUtils.h - Constant Fold Utilities ------------*- C++ -*-===//
|
||||
//===- FoldUtils.h - Operation Fold Utilities -------------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
@@ -15,13 +15,13 @@
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This header file declares various constant fold utilities. These utilities
|
||||
// are intended to be used by passes to unify and simply their logic.
|
||||
// This header file declares various operation folding utilities. These
|
||||
// utilities are intended to be used by passes to unify and simply their logic.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TRANSFORMS_CONSTANT_UTILS_H
|
||||
#define MLIR_TRANSFORMS_CONSTANT_UTILS_H
|
||||
#ifndef MLIR_TRANSFORMS_FOLDUTILS_H
|
||||
#define MLIR_TRANSFORMS_FOLDUTILS_H
|
||||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
@@ -32,13 +32,13 @@ namespace mlir {
|
||||
class Function;
|
||||
class Operation;
|
||||
|
||||
/// A helper class for constant folding operations, and unifying duplicated
|
||||
/// constants along the way.
|
||||
/// A helper class for folding operations, and unifying duplicated constants
|
||||
/// generated along the way.
|
||||
///
|
||||
/// To make sure constants' proper dominance of all their uses, constants are
|
||||
/// To make sure constants properly dominate all their uses, constants are
|
||||
/// moved to the beginning of the entry block of the function when tracked by
|
||||
/// this class.
|
||||
class ConstantFoldHelper {
|
||||
class FoldHelper {
|
||||
public:
|
||||
/// Constructs an instance for managing constants in the given function `f`.
|
||||
/// Constants tracked by this instance will be moved to the entry block of
|
||||
@@ -47,32 +47,30 @@ public:
|
||||
/// This instance does not proactively walk the operations inside `f`;
|
||||
/// instead, users must invoke the following methods to manually handle each
|
||||
/// operation of interest.
|
||||
ConstantFoldHelper(Function *f);
|
||||
FoldHelper(Function *f);
|
||||
|
||||
/// Tries to perform constant folding on the given `op`, including unifying
|
||||
/// deplicated constants. If successful, calls `preReplaceAction` (if
|
||||
/// Tries to perform folding on the given `op`, including unifying
|
||||
/// deduplicated constants. If successful, calls `preReplaceAction` (if
|
||||
/// provided) by passing in `op`, then replaces `op`'s uses with folded
|
||||
/// constants, and returns true.
|
||||
///
|
||||
/// Note: `op` will *not* be erased to avoid invalidating potential walkers in
|
||||
/// the caller.
|
||||
bool
|
||||
tryToConstantFold(Operation *op,
|
||||
std::function<void(Operation *)> preReplaceAction = {});
|
||||
/// results, and returns success. If the op was completely folded it is
|
||||
/// erased.
|
||||
LogicalResult
|
||||
tryToFold(Operation *op,
|
||||
std::function<void(Operation *)> preReplaceAction = {});
|
||||
|
||||
/// Notifies that the given constant `op` should be remove from this
|
||||
/// ConstantFoldHelper's internal bookkeeping.
|
||||
/// FoldHelper's internal bookkeeping.
|
||||
///
|
||||
/// Note: this method must be called if a constant op is to be deleted
|
||||
/// externally to this ConstantFoldHelper. `op` must be a constant op.
|
||||
/// externally to this FoldHelper. `op` must be a constant op.
|
||||
void notifyRemoval(Operation *op);
|
||||
|
||||
private:
|
||||
/// Tries to deduplicate the given constant and returns true if that can be
|
||||
/// Tries to deduplicate the given constant and returns success if that can be
|
||||
/// done. This moves the given constant to the top of the entry block if it
|
||||
/// is first seen. If there is already an existing constant that is the same,
|
||||
/// this does *not* erases the given constant.
|
||||
bool tryToUnify(Operation *op);
|
||||
LogicalResult tryToUnify(Operation *op);
|
||||
|
||||
/// Moves the given constant `op` to entry block to guarantee dominance.
|
||||
void moveConstantToEntryBlock(Operation *op);
|
||||
@@ -86,4 +84,4 @@ private:
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TRANSFORMS_CONSTANT_UTILS_H
|
||||
#endif // MLIR_TRANSFORMS_FOLDUTILS_H
|
||||
@@ -201,12 +201,11 @@ bool AffineApplyOp::isValidSymbol() {
|
||||
[](Value *op) { return mlir::isValidSymbol(op); });
|
||||
}
|
||||
|
||||
Attribute AffineApplyOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) {
|
||||
OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto map = getAffineMap();
|
||||
SmallVector<Attribute, 1> result;
|
||||
if (failed(map.constantFold(operands, result)))
|
||||
return Attribute();
|
||||
return {};
|
||||
return result[0];
|
||||
}
|
||||
|
||||
|
||||
@@ -509,40 +509,32 @@ auto Operation::getSuccessorOperands(unsigned index) -> operand_range {
|
||||
succOperandIndex + getNumSuccessorOperands(index))};
|
||||
}
|
||||
|
||||
/// Attempt to constant fold this operation with the specified constant
|
||||
/// operand values. If successful, this fills in the results vector. If not,
|
||||
/// results is unspecified.
|
||||
LogicalResult Operation::constantFold(ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results) {
|
||||
if (auto *abstractOp = getAbstractOperation()) {
|
||||
// If we have a registered operation definition matching this one, use it to
|
||||
// try to constant fold the operation.
|
||||
if (succeeded(abstractOp->constantFoldHook(this, operands, results)))
|
||||
return success();
|
||||
|
||||
// Otherwise, fall back on the dialect hook to handle it.
|
||||
return abstractOp->dialect.constantFoldHook(this, operands, results);
|
||||
}
|
||||
|
||||
// If this operation hasn't been registered or doesn't have abstract
|
||||
// operation, fall back to a dialect which matches the prefix.
|
||||
auto opName = getName().getStringRef();
|
||||
auto dialectPrefix = opName.split('.').first;
|
||||
if (auto *dialect = getContext()->getRegisteredDialect(dialectPrefix))
|
||||
return dialect->constantFoldHook(this, operands, results);
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
/// Attempt to fold this operation using the Op's registered foldHook.
|
||||
LogicalResult Operation::fold(SmallVectorImpl<Value *> &results) {
|
||||
if (auto *abstractOp = getAbstractOperation()) {
|
||||
// If we have a registered operation definition matching this one, use it to
|
||||
// try to constant fold the operation.
|
||||
if (succeeded(abstractOp->foldHook(this, results)))
|
||||
return success();
|
||||
LogicalResult Operation::fold(ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<OpFoldResult> &results) {
|
||||
// If we have a registered operation definition matching this one, use it to
|
||||
// try to constant fold the operation.
|
||||
auto *abstractOp = getAbstractOperation();
|
||||
if (abstractOp && succeeded(abstractOp->foldHook(this, operands, results)))
|
||||
return success();
|
||||
|
||||
// Otherwise, fall back on the dialect hook to handle it.
|
||||
Dialect *dialect;
|
||||
if (abstractOp) {
|
||||
dialect = &abstractOp->dialect;
|
||||
} else {
|
||||
// If this operation hasn't been registered, lookup the parent dialect.
|
||||
auto opName = getName().getStringRef();
|
||||
auto dialectPrefix = opName.split('.').first;
|
||||
if (!(dialect = getContext()->getRegisteredDialect(dialectPrefix)))
|
||||
return failure();
|
||||
}
|
||||
return failure();
|
||||
|
||||
SmallVector<Attribute, 8> constants;
|
||||
if (failed(dialect->constantFoldHook(this, operands, constants)))
|
||||
return failure();
|
||||
results.assign(constants.begin(), constants.end());
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Emit an error with the op name prefixed, like "'dim' op " which is
|
||||
|
||||
@@ -196,8 +196,7 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
|
||||
// AddFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute AddFOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) {
|
||||
OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) {
|
||||
return constFoldBinaryOp<FloatAttr>(
|
||||
operands, [](APFloat a, APFloat b) { return a + b; });
|
||||
}
|
||||
@@ -206,20 +205,15 @@ Attribute AddFOp::constantFold(ArrayRef<Attribute> operands,
|
||||
// AddIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute AddIOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) {
|
||||
OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) {
|
||||
/// addi(x, 0) -> x
|
||||
if (matchPattern(rhs(), m_Zero()))
|
||||
return lhs();
|
||||
|
||||
return constFoldBinaryOp<IntegerAttr>(operands,
|
||||
[](APInt a, APInt b) { return a + b; });
|
||||
}
|
||||
|
||||
Value *AddIOp::fold() {
|
||||
/// addi(x, 0) -> x
|
||||
if (matchPattern(getOperand(1), m_Zero()))
|
||||
return getOperand(0);
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AllocOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -770,8 +764,7 @@ static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
|
||||
}
|
||||
|
||||
// Constant folding hook for comparisons.
|
||||
Attribute CmpIOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) {
|
||||
OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2 && "cmpi takes two arguments");
|
||||
|
||||
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
|
||||
@@ -780,7 +773,7 @@ Attribute CmpIOp::constantFold(ArrayRef<Attribute> operands,
|
||||
return {};
|
||||
|
||||
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
|
||||
return IntegerAttr::get(IntegerType::get(1, context), APInt(1, val));
|
||||
return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -967,8 +960,7 @@ static bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
|
||||
}
|
||||
|
||||
// Constant folding hook for comparisons.
|
||||
Attribute CmpFOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) {
|
||||
OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2 && "cmpf takes two arguments");
|
||||
|
||||
auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
|
||||
@@ -980,7 +972,7 @@ Attribute CmpFOp::constantFold(ArrayRef<Attribute> operands,
|
||||
return {};
|
||||
|
||||
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
|
||||
return IntegerAttr::get(IntegerType::get(1, context), APInt(1, val));
|
||||
return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -1179,8 +1171,7 @@ static LogicalResult verify(ConstantOp &op) {
|
||||
"requires a result type that aligns with the 'value' attribute");
|
||||
}
|
||||
|
||||
Attribute ConstantOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) {
|
||||
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.empty() && "constant has no operands");
|
||||
return getValue();
|
||||
}
|
||||
@@ -1337,8 +1328,7 @@ static LogicalResult verify(DimOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) {
|
||||
OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Constant fold dim when the size along the index referred to is a constant.
|
||||
auto opType = getOperand()->getType();
|
||||
int64_t indexSize = -1;
|
||||
@@ -1348,19 +1338,17 @@ Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
|
||||
indexSize = memrefType.getShape()[getIndex()];
|
||||
|
||||
if (indexSize >= 0)
|
||||
return IntegerAttr::get(IndexType::get(context), indexSize);
|
||||
return IntegerAttr::get(IndexType::get(getContext()), indexSize);
|
||||
|
||||
return nullptr;
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DivISOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute DivISOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) {
|
||||
OpFoldResult DivISOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2 && "binary operation takes two operands");
|
||||
(void)context;
|
||||
|
||||
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
|
||||
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
|
||||
@@ -1368,9 +1356,8 @@ Attribute DivISOp::constantFold(ArrayRef<Attribute> operands,
|
||||
return {};
|
||||
|
||||
// Don't fold if it requires division by zero.
|
||||
if (rhs.getValue().isNullValue()) {
|
||||
if (rhs.getValue().isNullValue())
|
||||
return {};
|
||||
}
|
||||
|
||||
// Don't fold if it would overflow.
|
||||
bool overflow;
|
||||
@@ -1382,10 +1369,8 @@ Attribute DivISOp::constantFold(ArrayRef<Attribute> operands,
|
||||
// DivIUOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute DivIUOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) {
|
||||
OpFoldResult DivIUOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2 && "binary operation takes two operands");
|
||||
(void)context;
|
||||
|
||||
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
|
||||
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
|
||||
@@ -1675,14 +1660,13 @@ static LogicalResult verify(ExtractElementOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
Attribute ExtractElementOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) {
|
||||
OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(!operands.empty() && "extract_element takes atleast one operand");
|
||||
|
||||
// The aggregate operand must be a known constant.
|
||||
Attribute aggregate = operands.front();
|
||||
if (!aggregate)
|
||||
return Attribute();
|
||||
return {};
|
||||
|
||||
// If this is a splat elements attribute, simply return the value. All of the
|
||||
// elements of a splat attribute are the same.
|
||||
@@ -1693,14 +1677,14 @@ Attribute ExtractElementOp::constantFold(ArrayRef<Attribute> operands,
|
||||
SmallVector<uint64_t, 8> indices;
|
||||
for (Attribute indice : llvm::drop_begin(operands, 1)) {
|
||||
if (!indice || !indice.isa<IntegerAttr>())
|
||||
return Attribute();
|
||||
return {};
|
||||
indices.push_back(indice.cast<IntegerAttr>().getInt());
|
||||
}
|
||||
|
||||
// If this is an elements attribute, query the value at the given indices.
|
||||
if (auto elementsAttr = aggregate.dyn_cast<ElementsAttr>())
|
||||
return elementsAttr.getValue(indices);
|
||||
return Attribute();
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -1801,14 +1785,15 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) {
|
||||
return true;
|
||||
}
|
||||
|
||||
Value *MemRefCastOp::fold() { return impl::foldCastOp(*this); }
|
||||
OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
|
||||
return impl::foldCastOp(*this);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MulFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute MulFOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) {
|
||||
OpFoldResult MulFOp::fold(ArrayRef<Attribute> operands) {
|
||||
return constFoldBinaryOp<FloatAttr>(
|
||||
operands, [](APFloat a, APFloat b) { return a * b; });
|
||||
}
|
||||
@@ -1817,29 +1802,24 @@ Attribute MulFOp::constantFold(ArrayRef<Attribute> operands,
|
||||
// MulIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute MulIOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) {
|
||||
OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
|
||||
/// muli(x, 0) -> 0
|
||||
if (matchPattern(rhs(), m_Zero()))
|
||||
return rhs();
|
||||
/// muli(x, 1) -> x
|
||||
if (matchPattern(rhs(), m_One()))
|
||||
return getOperand(0);
|
||||
|
||||
// TODO: Handle the overflow case.
|
||||
return constFoldBinaryOp<IntegerAttr>(operands,
|
||||
[](APInt a, APInt b) { return a * b; });
|
||||
}
|
||||
|
||||
Value *MulIOp::fold() {
|
||||
/// muli(x, 0) -> 0
|
||||
if (matchPattern(getOperand(1), m_Zero()))
|
||||
return getOperand(1);
|
||||
/// muli(x, 1) -> x
|
||||
if (matchPattern(getOperand(1), m_One()))
|
||||
return getOperand(0);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RemISOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute RemISOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) {
|
||||
OpFoldResult RemISOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2 && "remis takes two operands");
|
||||
|
||||
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
|
||||
@@ -1852,9 +1832,8 @@ Attribute RemISOp::constantFold(ArrayRef<Attribute> operands,
|
||||
APInt(rhs.getValue().getBitWidth(), 0));
|
||||
|
||||
// Don't fold if it requires division by zero.
|
||||
if (rhs.getValue().isNullValue()) {
|
||||
if (rhs.getValue().isNullValue())
|
||||
return {};
|
||||
}
|
||||
|
||||
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
|
||||
if (!lhs)
|
||||
@@ -1867,8 +1846,7 @@ Attribute RemISOp::constantFold(ArrayRef<Attribute> operands,
|
||||
// RemIUOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute RemIUOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) {
|
||||
OpFoldResult RemIUOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2 && "remiu takes two operands");
|
||||
|
||||
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
|
||||
@@ -1881,9 +1859,8 @@ Attribute RemIUOp::constantFold(ArrayRef<Attribute> operands,
|
||||
APInt(rhs.getValue().getBitWidth(), 0));
|
||||
|
||||
// Don't fold if it requires division by zero.
|
||||
if (rhs.getValue().isNullValue()) {
|
||||
if (rhs.getValue().isNullValue())
|
||||
return {};
|
||||
}
|
||||
|
||||
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
|
||||
if (!lhs)
|
||||
@@ -1990,7 +1967,7 @@ LogicalResult SelectOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
Value *SelectOp::fold() {
|
||||
OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto *condition = getCondition();
|
||||
|
||||
// select true, %0, %1 => %0
|
||||
@@ -2081,8 +2058,7 @@ void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
// SubFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute SubFOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) {
|
||||
OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) {
|
||||
return constFoldBinaryOp<FloatAttr>(
|
||||
operands, [](APFloat a, APFloat b) { return a - b; });
|
||||
}
|
||||
@@ -2091,48 +2067,20 @@ Attribute SubFOp::constantFold(ArrayRef<Attribute> operands,
|
||||
// SubIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute SubIOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) {
|
||||
OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
|
||||
// subi(x,x) -> 0
|
||||
if (getOperand(0) == getOperand(1))
|
||||
return Builder(getContext()).getZeroAttr(getType());
|
||||
|
||||
return constFoldBinaryOp<IntegerAttr>(operands,
|
||||
[](APInt a, APInt b) { return a - b; });
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// subi(x,x) -> 0
|
||||
///
|
||||
struct SimplifyXMinusX : public RewritePattern {
|
||||
SimplifyXMinusX(MLIRContext *context)
|
||||
: RewritePattern(SubIOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto subi = cast<SubIOp>(op);
|
||||
if (subi.getOperand(0) != subi.getOperand(1))
|
||||
return matchFailure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(
|
||||
op, subi.getType(), rewriter.getZeroAttr(subi.getType()));
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
||||
void SubIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.push_back(llvm::make_unique<SimplifyXMinusX>(context));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AndOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute AndOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) {
|
||||
return constFoldBinaryOp<IntegerAttr>(operands,
|
||||
[](APInt a, APInt b) { return a & b; });
|
||||
}
|
||||
|
||||
Value *AndOp::fold() {
|
||||
OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
|
||||
/// and(x, 0) -> 0
|
||||
if (matchPattern(rhs(), m_Zero()))
|
||||
return rhs();
|
||||
@@ -2140,20 +2088,15 @@ Value *AndOp::fold() {
|
||||
if (lhs() == rhs())
|
||||
return rhs();
|
||||
|
||||
return nullptr;
|
||||
return constFoldBinaryOp<IntegerAttr>(operands,
|
||||
[](APInt a, APInt b) { return a & b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute OrOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) {
|
||||
return constFoldBinaryOp<IntegerAttr>(operands,
|
||||
[](APInt a, APInt b) { return a | b; });
|
||||
}
|
||||
|
||||
Value *OrOp::fold() {
|
||||
OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
|
||||
/// or(x, 0) -> x
|
||||
if (matchPattern(rhs(), m_Zero()))
|
||||
return lhs();
|
||||
@@ -2161,51 +2104,26 @@ Value *OrOp::fold() {
|
||||
if (lhs() == rhs())
|
||||
return rhs();
|
||||
|
||||
return nullptr;
|
||||
return constFoldBinaryOp<IntegerAttr>(operands,
|
||||
[](APInt a, APInt b) { return a | b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// XOrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute XOrOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) {
|
||||
OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
|
||||
/// xor(x, 0) -> x
|
||||
if (matchPattern(rhs(), m_Zero()))
|
||||
return lhs();
|
||||
/// xor(x,x) -> 0
|
||||
if (lhs() == rhs())
|
||||
return Builder(getContext()).getZeroAttr(getType());
|
||||
|
||||
return constFoldBinaryOp<IntegerAttr>(operands,
|
||||
[](APInt a, APInt b) { return a ^ b; });
|
||||
}
|
||||
|
||||
Value *XOrOp::fold() {
|
||||
/// xor(x, 0) -> x
|
||||
if (matchPattern(rhs(), m_Zero()))
|
||||
return lhs();
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// xor(x,x) -> 0
|
||||
///
|
||||
struct SimplifyXXOrX : public RewritePattern {
|
||||
SimplifyXXOrX(MLIRContext *context)
|
||||
: RewritePattern(XOrOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto xorOp = cast<XOrOp>(op);
|
||||
if (xorOp.lhs() != xorOp.rhs())
|
||||
return matchFailure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(
|
||||
op, xorOp.getType(), rewriter.getZeroAttr(xorOp.getType()));
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
||||
void XOrOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.push_back(llvm::make_unique<SimplifyXXOrX>(context));
|
||||
}
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -2239,7 +2157,9 @@ bool TensorCastOp::areCastCompatible(Type a, Type b) {
|
||||
return true;
|
||||
}
|
||||
|
||||
Value *TensorCastOp::fold() { return impl::foldCastOp(*this); }
|
||||
OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
|
||||
return impl::foldCastOp(*this);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
|
||||
@@ -17,7 +17,7 @@ add_llvm_library(MLIRTransforms
|
||||
SimplifyAffineStructures.cpp
|
||||
StripDebugInfo.cpp
|
||||
TestConstantFold.cpp
|
||||
Utils/ConstantFoldUtils.cpp
|
||||
Utils/FoldUtils.cpp
|
||||
Utils/GreedyPatternRewriteDriver.cpp
|
||||
Utils/LoopUtils.cpp
|
||||
Utils/Utils.cpp
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/StandardOps/Ops.h"
|
||||
#include "mlir/Transforms/ConstantFoldUtils.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/Utils.h"
|
||||
|
||||
@@ -31,26 +31,22 @@ namespace {
|
||||
struct TestConstantFold : public FunctionPass<TestConstantFold> {
|
||||
// All constants in the function post folding.
|
||||
SmallVector<Operation *, 8> existingConstants;
|
||||
// Operations that were folded and that need to be erased.
|
||||
std::vector<Operation *> opsToErase;
|
||||
|
||||
void foldOperation(Operation *op, ConstantFoldHelper &helper);
|
||||
void foldOperation(Operation *op, FoldHelper &helper);
|
||||
void runOnFunction() override;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
void TestConstantFold::foldOperation(Operation *op,
|
||||
ConstantFoldHelper &helper) {
|
||||
void TestConstantFold::foldOperation(Operation *op, FoldHelper &helper) {
|
||||
// Attempt to fold the specified operation, including handling unused or
|
||||
// duplicated constants.
|
||||
if (helper.tryToConstantFold(op)) {
|
||||
opsToErase.push_back(op);
|
||||
}
|
||||
if (succeeded(helper.tryToFold(op)))
|
||||
return;
|
||||
|
||||
// If this op is a constant that are used and cannot be de-duplicated,
|
||||
// remember it for cleanup later.
|
||||
else if (auto constant = dyn_cast<ConstantOp>(op)) {
|
||||
if (auto constant = dyn_cast<ConstantOp>(op))
|
||||
existingConstants.push_back(op);
|
||||
}
|
||||
}
|
||||
|
||||
// For now, we do a simple top-down pass over a function folding constants. We
|
||||
@@ -58,10 +54,9 @@ void TestConstantFold::foldOperation(Operation *op,
|
||||
// branches, or anything else fancy.
|
||||
void TestConstantFold::runOnFunction() {
|
||||
existingConstants.clear();
|
||||
opsToErase.clear();
|
||||
|
||||
auto &f = getFunction();
|
||||
ConstantFoldHelper helper(&f);
|
||||
FoldHelper helper(&f);
|
||||
|
||||
// Collect and fold the operations within the function.
|
||||
SmallVector<Operation *, 8> ops;
|
||||
@@ -74,12 +69,6 @@ void TestConstantFold::runOnFunction() {
|
||||
for (Operation *op : llvm::reverse(ops))
|
||||
foldOperation(op, helper);
|
||||
|
||||
// At this point, these operations are dead, remove them.
|
||||
for (auto *op : opsToErase) {
|
||||
assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
|
||||
op->erase();
|
||||
}
|
||||
|
||||
// By the time we are done, we may have simplified a bunch of code, leaving
|
||||
// around dead constants. Check for them now and remove them.
|
||||
for (auto *cst : existingConstants) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//===- ConstantFoldUtils.cpp ---- Constant Fold Utilities -----------------===//
|
||||
//===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
@@ -15,12 +15,12 @@
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file defines various constant fold utilities. These utilities are
|
||||
// This file defines various operation fold utilities. These utilities are
|
||||
// intended to be used by passes to unify and simply their logic.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Transforms/ConstantFoldUtils.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
@@ -29,10 +29,11 @@
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
ConstantFoldHelper::ConstantFoldHelper(Function *f) : function(f) {}
|
||||
FoldHelper::FoldHelper(Function *f) : function(f) {}
|
||||
|
||||
bool ConstantFoldHelper::tryToConstantFold(
|
||||
Operation *op, std::function<void(Operation *)> preReplaceAction) {
|
||||
LogicalResult
|
||||
FoldHelper::tryToFold(Operation *op,
|
||||
std::function<void(Operation *)> preReplaceAction) {
|
||||
assert(op->getFunction() == function &&
|
||||
"cannot constant fold op from another function");
|
||||
|
||||
@@ -44,13 +45,15 @@ bool ConstantFoldHelper::tryToConstantFold(
|
||||
// If this constant is dead, update bookkeeping and signal the caller.
|
||||
if (constant.use_empty()) {
|
||||
notifyRemoval(op);
|
||||
return true;
|
||||
op->erase();
|
||||
return success();
|
||||
}
|
||||
// Otherwise, try to see if we can de-duplicate it.
|
||||
return tryToUnify(op);
|
||||
}
|
||||
|
||||
SmallVector<Attribute, 8> operandConstants, resultConstants;
|
||||
SmallVector<Attribute, 8> operandConstants;
|
||||
SmallVector<OpFoldResult, 8> results;
|
||||
|
||||
// Check to see if any operands to the operation is constant and whether
|
||||
// the operation knows how to constant fold itself.
|
||||
@@ -67,8 +70,8 @@ bool ConstantFoldHelper::tryToConstantFold(
|
||||
}
|
||||
|
||||
// Attempt to constant fold the operation.
|
||||
if (failed(op->constantFold(operandConstants, resultConstants)))
|
||||
return false;
|
||||
if (failed(op->fold(operandConstants, results)))
|
||||
return failure();
|
||||
|
||||
// Constant folding succeeded. We will start replacing this op's uses and
|
||||
// eventually erase this op. Invoke the callback provided by the caller to
|
||||
@@ -76,21 +79,35 @@ bool ConstantFoldHelper::tryToConstantFold(
|
||||
if (preReplaceAction)
|
||||
preReplaceAction(op);
|
||||
|
||||
// Check to see if the operation was just updated in place.
|
||||
if (results.empty())
|
||||
return success();
|
||||
assert(results.size() == op->getNumResults());
|
||||
|
||||
// Create the result constants and replace the results.
|
||||
FuncBuilder builder(op);
|
||||
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
|
||||
auto *res = op->getResult(i);
|
||||
if (res->use_empty()) // Ignore dead uses.
|
||||
continue;
|
||||
assert(!results[i].isNull() && "expected valid OpFoldResult");
|
||||
|
||||
// Check if the result was an SSA value.
|
||||
if (auto *repl = results[i].dyn_cast<Value *>()) {
|
||||
if (repl != res)
|
||||
res->replaceAllUsesWith(repl);
|
||||
continue;
|
||||
}
|
||||
|
||||
// If we already have a canonicalized version of this constant, just reuse
|
||||
// it. Otherwise create a new one.
|
||||
Attribute attrRepl = results[i].get<Attribute>();
|
||||
auto &constInst =
|
||||
uniquedConstants[std::make_pair(resultConstants[i], res->getType())];
|
||||
uniquedConstants[std::make_pair(attrRepl, res->getType())];
|
||||
if (!constInst) {
|
||||
// TODO: Extend to support dialect-specific constant ops.
|
||||
auto newOp = builder.create<ConstantOp>(op->getLoc(), res->getType(),
|
||||
resultConstants[i]);
|
||||
auto newOp =
|
||||
builder.create<ConstantOp>(op->getLoc(), res->getType(), attrRepl);
|
||||
// Register to the constant map and also move up to entry block to
|
||||
// guarantee dominance.
|
||||
constInst = newOp.getOperation();
|
||||
@@ -98,17 +115,18 @@ bool ConstantFoldHelper::tryToConstantFold(
|
||||
}
|
||||
res->replaceAllUsesWith(constInst->getResult(0));
|
||||
}
|
||||
op->erase();
|
||||
|
||||
return true;
|
||||
return success();
|
||||
}
|
||||
|
||||
void ConstantFoldHelper::notifyRemoval(Operation *op) {
|
||||
void FoldHelper::notifyRemoval(Operation *op) {
|
||||
assert(op->getFunction() == function &&
|
||||
"cannot remove constant from another function");
|
||||
|
||||
Attribute constValue;
|
||||
matchPattern(op, m_Constant(&constValue));
|
||||
assert(constValue);
|
||||
if (!matchPattern(op, m_Constant(&constValue)))
|
||||
return;
|
||||
|
||||
// This constant is dead. keep uniquedConstants up to date.
|
||||
auto it = uniquedConstants.find({constValue, op->getResult(0)->getType()});
|
||||
@@ -116,7 +134,7 @@ void ConstantFoldHelper::notifyRemoval(Operation *op) {
|
||||
uniquedConstants.erase(it);
|
||||
}
|
||||
|
||||
bool ConstantFoldHelper::tryToUnify(Operation *op) {
|
||||
LogicalResult FoldHelper::tryToUnify(Operation *op) {
|
||||
Attribute constValue;
|
||||
matchPattern(op, m_Constant(&constValue));
|
||||
assert(constValue);
|
||||
@@ -127,13 +145,14 @@ bool ConstantFoldHelper::tryToUnify(Operation *op) {
|
||||
if (constInst) {
|
||||
// If this constant is already our uniqued one, then leave it alone.
|
||||
if (constInst == op)
|
||||
return false;
|
||||
return failure();
|
||||
|
||||
// Otherwise replace this redundant constant with the uniqued one. We know
|
||||
// this is safe because we move constants to the top of the function when
|
||||
// they are uniqued, so we know they dominate all uses.
|
||||
op->getResult(0)->replaceAllUsesWith(constInst->getResult(0));
|
||||
return true;
|
||||
op->erase();
|
||||
return success();
|
||||
}
|
||||
|
||||
// If we have no entry, then we should unique this constant as the
|
||||
@@ -141,10 +160,10 @@ bool ConstantFoldHelper::tryToUnify(Operation *op) {
|
||||
// entry block of the function.
|
||||
constInst = op;
|
||||
moveConstantToEntryBlock(op);
|
||||
return false;
|
||||
return failure();
|
||||
}
|
||||
|
||||
void ConstantFoldHelper::moveConstantToEntryBlock(Operation *op) {
|
||||
void FoldHelper::moveConstantToEntryBlock(Operation *op) {
|
||||
// Insert at the very top of the entry block.
|
||||
auto &entryBB = function->front();
|
||||
op->moveBefore(&entryBB, entryBB.begin());
|
||||
@@ -22,7 +22,7 @@
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/StandardOps/Ops.h"
|
||||
#include "mlir/Transforms/ConstantFoldUtils.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
@@ -148,7 +148,7 @@ private:
|
||||
/// Perform the rewrites.
|
||||
bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
|
||||
Function *fn = builder.getFunction();
|
||||
ConstantFoldHelper helper(fn);
|
||||
FoldHelper helper(fn);
|
||||
|
||||
bool changed = false;
|
||||
int i = 0;
|
||||
@@ -171,67 +171,31 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
|
||||
// If the operation has no side effects, and no users, then it is
|
||||
// trivially dead - remove it.
|
||||
if (op->hasNoSideEffect() && op->use_empty()) {
|
||||
// Be careful to update bookkeeping in ConstantHelper to keep
|
||||
// consistency if this is a constant op.
|
||||
if (isa<ConstantOp>(op))
|
||||
helper.notifyRemoval(op);
|
||||
// Be careful to update bookkeeping in FoldHelper to keep consistency if
|
||||
// this is a constant op.
|
||||
helper.notifyRemoval(op);
|
||||
op->erase();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Collects all the operands and result uses of the given `op` into work
|
||||
// list.
|
||||
auto collectOperandsAndUses = [this](Operation *op) {
|
||||
originalOperands.assign(op->operand_begin(), op->operand_end());
|
||||
auto collectOperandsAndUses = [&](Operation *op) {
|
||||
// Add the operands to the worklist for visitation.
|
||||
addToWorklist(op->getOperands());
|
||||
addToWorklist(originalOperands);
|
||||
|
||||
// Add all the users of the result to the worklist so we make sure
|
||||
// to revisit them.
|
||||
//
|
||||
// TODO: Add a result->getUsers() iterator.
|
||||
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
|
||||
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i)
|
||||
for (auto &operand : op->getResult(i)->getUses())
|
||||
addToWorklist(operand.getOwner());
|
||||
}
|
||||
};
|
||||
|
||||
// Try to constant fold this op.
|
||||
if (helper.tryToConstantFold(op, collectOperandsAndUses)) {
|
||||
assert(op->hasNoSideEffect() &&
|
||||
"Constant folded op with side effects?");
|
||||
op->erase();
|
||||
changed |= true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Otherwise see if we can use the generic folder API to simplify the
|
||||
// operation.
|
||||
originalOperands.assign(op->operand_begin(), op->operand_end());
|
||||
resultValues.clear();
|
||||
if (succeeded(op->fold(resultValues))) {
|
||||
// If the result was an in-place simplification (e.g. max(x,x,y) ->
|
||||
// max(x,y)) then add the original operands to the worklist so we can
|
||||
// make sure to revisit them.
|
||||
if (resultValues.empty()) {
|
||||
// Add the operands back to the worklist as there may be more
|
||||
// canonicalization opportunities now.
|
||||
addToWorklist(originalOperands);
|
||||
} else {
|
||||
// Otherwise, the operation is simplified away completely.
|
||||
assert(resultValues.size() == op->getNumResults());
|
||||
|
||||
// Notify that we are replacing this operation.
|
||||
notifyRootReplaced(op);
|
||||
|
||||
// Replace the result values and erase the operation.
|
||||
for (unsigned i = 0, e = resultValues.size(); i != e; ++i) {
|
||||
auto *res = op->getResult(i);
|
||||
if (!res->use_empty())
|
||||
res->replaceAllUsesWith(resultValues[i]);
|
||||
}
|
||||
|
||||
notifyOperationRemoved(op);
|
||||
op->erase();
|
||||
}
|
||||
// Try to fold this op.
|
||||
if (succeeded(helper.tryToFold(op, collectOperandsAndUses))) {
|
||||
changed |= true;
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -28,7 +28,6 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect]> {
|
||||
let verifier = [{ baz }];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
let hasConstantFolder = 1;
|
||||
let hasFolder = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
@@ -55,8 +54,7 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect]> {
|
||||
// CHECK: void print(OpAsmPrinter *p);
|
||||
// CHECK: LogicalResult verify();
|
||||
// CHECK: static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context);
|
||||
// CHECK: LogicalResult constantFold(ArrayRef<Attribute> operands, SmallVectorImpl<Attribute> &results, MLIRContext *context);
|
||||
// CHECK: bool fold(SmallVectorImpl<Value *> &results);
|
||||
// CHECK: LogicalResult fold(ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results);
|
||||
// CHECK: // Display a graph for debugging purposes.
|
||||
// CHECK: void displayGraph();
|
||||
// CHECK: };
|
||||
|
||||
@@ -834,28 +834,15 @@ void OpEmitter::genCanonicalizerDecls() {
|
||||
void OpEmitter::genFolderDecls() {
|
||||
bool hasSingleResult = op.getNumResults() == 1;
|
||||
|
||||
if (def.getValueAsBit("hasConstantFolder")) {
|
||||
if (hasSingleResult) {
|
||||
const char *const params =
|
||||
"ArrayRef<Attribute> operands, MLIRContext *context";
|
||||
opClass.newMethod("Attribute", "constantFold", params, OpMethod::MP_None,
|
||||
/*declOnly=*/true);
|
||||
} else {
|
||||
const char *const params =
|
||||
"ArrayRef<Attribute> operands, SmallVectorImpl<Attribute> &results, "
|
||||
"MLIRContext *context";
|
||||
opClass.newMethod("LogicalResult", "constantFold", params,
|
||||
OpMethod::MP_None, /*declOnly=*/true);
|
||||
}
|
||||
}
|
||||
|
||||
if (def.getValueAsBit("hasFolder")) {
|
||||
if (hasSingleResult) {
|
||||
opClass.newMethod("Value *", "fold", /*params=*/"", OpMethod::MP_None,
|
||||
const char *const params = "ArrayRef<Attribute> operands";
|
||||
opClass.newMethod("OpFoldResult", "fold", params, OpMethod::MP_None,
|
||||
/*declOnly=*/true);
|
||||
} else {
|
||||
opClass.newMethod("bool", "fold", "SmallVectorImpl<Value *> &results",
|
||||
OpMethod::MP_None,
|
||||
const char *const params = "ArrayRef<Attribute> operands, "
|
||||
"SmallVectorImpl<OpFoldResult> &results";
|
||||
opClass.newMethod("LogicalResult", "fold", params, OpMethod::MP_None,
|
||||
/*declOnly=*/true);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user