Unify the 'constantFold' and 'fold' hooks on an operation into just 'fold'. This new unified fold hook will take constant attributes as operands, and may return an existing 'Value *' or a constant 'Attribute' when folding. This removes the awkward situation where a simple canonicalization like "sub(x,x)->0" had to be written as a canonicalization pattern as opposed to a fold.

--

PiperOrigin-RevId: 248582024
This commit is contained in:
River Riddle
2019-05-16 12:51:45 -07:00
committed by Mehdi Amini
parent 13dbad87f6
commit 1982afb145
21 changed files with 257 additions and 451 deletions

View File

@@ -371,11 +371,6 @@ This boolean field indicate whether canonicalization patterns have been defined
for this operation. If it is `1`, then `::getCanonicalizationPatterns()` should
be defined.
### `hasConstantFolder`
This boolean field indicate whether constant folding rules have been defined
for this operation. If it is `1`, then `::constantFold()` should be defined.
### `hasFolder`
This boolean field indicate whether general folding rules have been defined

View File

@@ -74,11 +74,10 @@ in either way.
Operations can also have custom parser, printer, builder, verifier, constant
folder, or canonicalizer. These require specifying additional C++ methods to
invoke for additional functionality. For example, if an operation is marked to
have a constant folder, the constant folder also needs to be added, e.g.,:
have a folder, the constant folder also needs to be added, e.g.,:
```c++
Attribute SpecificOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) {
OpFoldResult SpecificOp::fold(ArrayRef<Attribute> constOperands) {
if (unable_to_fold)
return {};
....

View File

@@ -83,7 +83,7 @@ public:
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
OpFoldResult fold(ArrayRef<Attribute> operands);
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);

View File

@@ -761,7 +761,7 @@ private:
namespace llvm {
// Attribute hash just like pointers
// Attribute hash just like pointers.
template <> struct DenseMapInfo<mlir::Attribute> {
static mlir::Attribute getEmptyKey() {
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
@@ -779,6 +779,18 @@ template <> struct DenseMapInfo<mlir::Attribute> {
}
};
/// Allow LLVM to steal the low bits of Attributes.
template <> struct PointerLikeTypeTraits<mlir::Attribute> {
public:
static inline void *getAsVoidPointer(mlir::Attribute attr) {
return const_cast<void *>(attr.getAsOpaquePointer());
}
static inline mlir::Attribute getFromVoidPointer(void *ptr) {
return mlir::Attribute::getFromOpaquePointer(ptr);
}
enum { NumLowBitsAvailable = 3 };
};
} // namespace llvm
#endif

View File

@@ -24,7 +24,7 @@
#ifndef MLIR_MATCHERS_H
#define MLIR_MATCHERS_H
#include "mlir/IR/Operation.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include <type_traits>
@@ -73,9 +73,9 @@ struct constant_op_binder {
if (!op->hasNoSideEffect())
return false;
SmallVector<Attribute, 1> foldedAttr;
if (succeeded(op->constantFold(/*operands=*/llvm::None, foldedAttr))) {
*bind_value = foldedAttr.front();
SmallVector<OpFoldResult, 1> foldedOp;
if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) {
*bind_value = foldedOp.front().dyn_cast<Attribute>();
return true;
}
return false;

View File

@@ -925,9 +925,6 @@ class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
// and C++ implementations.
bit hasCanonicalizer = 0;
// Whether this op has a constant folder.
bit hasConstantFolder = 0;
// Whether this op has a folder.
bit hasFolder = 0;

View File

@@ -170,47 +170,28 @@ inline bool operator!=(OpState lhs, OpState rhs) {
return lhs.getOperation() != rhs.getOperation();
}
/// This template defines the constantFoldHook and foldHook as used by
/// AbstractOperation.
/// This class represents a single result from folding an operation.
class OpFoldResult : public llvm::PointerUnion<Attribute, Value *> {
using llvm::PointerUnion<Attribute, Value *>::PointerUnion;
};
/// This template defines the foldHook as used by AbstractOperation.
///
/// The default implementation uses a general constantFold/fold method that can
/// be defined on custom ops which can return multiple results.
/// The default implementation uses a general fold method that can be defined on
/// custom ops which can return multiple results.
template <typename ConcreteType, bool isSingleResult, typename = void>
class FoldingHook {
public:
/// This is an implementation detail of the constant folder hook for
/// AbstractOperation.
static LogicalResult constantFoldHook(Operation *op,
ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results) {
return cast<ConcreteType>(op).constantFold(operands, results,
op->getContext());
}
/// Op implementations can implement this hook. It should attempt to constant
/// fold this operation with the specified constant operand values - the
/// elements in "operands" will correspond directly to the operands of the
/// operation, but may be null if non-constant. If constant folding is
/// successful, this fills in the `results` vector. If not, `results` is
/// unspecified.
///
/// If not overridden, this fallback implementation always fails to fold.
///
LogicalResult constantFold(ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results,
MLIRContext *context) {
return failure();
}
/// This is an implementation detail of the folder hook for AbstractOperation.
static LogicalResult foldHook(Operation *op,
SmallVectorImpl<Value *> &results) {
return cast<ConcreteType>(op).fold(results);
static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
return cast<ConcreteType>(op).fold(operands, results);
}
/// This hook implements a generalized folder for this operation. Operations
/// can implement this to provide simplifications rules that are applied by
/// the FuncBuilder::foldOrCreate API and the canonicalization pass.
/// the Builder::foldOrCreate API and the canonicalization pass.
///
/// This is an intentionally limited interface - implementations of this hook
/// can only perform the following changes to the operation:
@@ -225,23 +206,25 @@ public:
/// instead.
///
/// This allows expression of some simple in-place canonicalizations (e.g.
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), but does
/// not allow for canonicalizations that need to introduce new operations, not
/// even constants (e.g. "x-x -> 0" cannot be expressed).
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
/// generalized constant folding.
///
/// If not overridden, this fallback implementation always fails to fold.
///
LogicalResult fold(SmallVectorImpl<Value *> &results) { return failure(); }
LogicalResult fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
return failure();
}
};
/// This template specialization defines the constantFoldHook and foldHook as
/// used by AbstractOperation for single-result operations. This gives the hook
/// a nicer signature that is easier to implement.
/// This template specialization defines the foldHook as used by
/// AbstractOperation for single-result operations. This gives the hook a nicer
/// signature that is easier to implement.
template <typename ConcreteType, bool isSingleResult>
class FoldingHook<ConcreteType, isSingleResult,
typename std::enable_if<isSingleResult>::type> {
public:
/// If the operation returns a single value, then the Op can be implicitly
/// If the operation returns a single value, then the Op can be implicitly
/// converted to an Value*. This yields the value of the only result.
operator Value *() {
return static_cast<ConcreteType *>(this)->getOperation()->getResult(0);
@@ -249,11 +232,9 @@ public:
/// This is an implementation detail of the constant folder hook for
/// AbstractOperation.
static LogicalResult constantFoldHook(Operation *op,
ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results) {
auto result =
cast<ConcreteType>(op).constantFold(operands, op->getContext());
static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
auto result = cast<ConcreteType>(op).fold(operands);
if (!result)
return failure();
@@ -261,33 +242,9 @@ public:
return success();
}
/// Op implementations can implement this hook. It should attempt to constant
/// fold this operation with the specified constant operand values - the
/// elements in "operands" will correspond directly to the operands of the
/// operation, but may be null if non-constant. If constant folding is
/// successful, this returns a non-null attribute, otherwise it returns null
/// on failure.
///
/// If not overridden, this fallback implementation always fails to fold.
///
Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context) {
return nullptr;
}
/// This is an implementation detail of the folder hook for AbstractOperation.
static LogicalResult foldHook(Operation *op,
SmallVectorImpl<Value *> &results) {
auto *result = cast<ConcreteType>(op).fold();
if (!result)
return failure();
if (result != op->getResult(0))
results.push_back(result);
return success();
}
/// This hook implements a generalized folder for this operation. Operations
/// can implement this to provide simplifications rules that are applied by
/// the FuncBuilder::foldOrCreate API and the canonicalization pass.
/// the Builder::foldOrCreate API and the canonicalization pass.
///
/// This is an intentionally limited interface - implementations of this hook
/// can only perform the following changes to the operation:
@@ -301,13 +258,12 @@ public:
/// remove the operation and use that result instead.
///
/// This allows expression of some simple in-place canonicalizations (e.g.
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), but does
/// not allow for canonicalizations that need to introduce new operations, not
/// even constants (e.g. "x-x -> 0" cannot be expressed).
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
/// generalized constant folding.
///
/// If not overridden, this fallback implementation always fails to fold.
///
Value *fold() { return nullptr; }
OpFoldResult fold(ArrayRef<Attribute> operands) { return {}; }
};
//===----------------------------------------------------------------------===//

View File

@@ -374,16 +374,12 @@ public:
return getTerminatorStatus() == TerminatorStatus::NonTerminator;
}
/// Attempt to constant fold this operation with the specified constant
/// operand values - the elements in "operands" will correspond directly to
/// the operands of the operation, but may be null if non-constant. If
/// constant folding is successful, this fills in the `results` vector. If
/// not, `results` is unspecified.
LogicalResult constantFold(ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results);
/// Attempt to fold this operation using the Op's registered foldHook.
LogicalResult fold(SmallVectorImpl<Value *> &results);
/// Attempt to fold this operation with the specified constant operand values
/// - the elements in "operands" will correspond directly to the operands of
/// the operation, but may be null if non-constant. If folding is successful,
/// this fills in the `results` vector. If not, `results` is unspecified.
LogicalResult fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results);
//===--------------------------------------------------------------------===//
// Operation Walkers

View File

@@ -41,6 +41,7 @@ struct OperationState;
class OpAsmParser;
class OpAsmParserResult;
class OpAsmPrinter;
class OpFoldResult;
class ParseResult;
class Pattern;
class Region;
@@ -95,14 +96,9 @@ public:
/// success if everything is ok.
LogicalResult (&verifyInvariants)(Operation *op);
/// This hook implements a constant folder for this operation. It fills in
/// `results` on success.
LogicalResult (&constantFoldHook)(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results);
/// This hook implements a generalized folder for this operation. Operations
/// can implement this to provide simplifications rules that are applied by
/// the FuncBuilder::foldOrCreate API and the canonicalization pass.
/// the Builder::foldOrCreate API and the canonicalization pass.
///
/// This is an intentionally limited interface - implementations of this hook
/// can only perform the following changes to the operation:
@@ -117,10 +113,10 @@ public:
/// instead.
///
/// This allows expression of some simple in-place canonicalizations (e.g.
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), but does
/// not allow for canonicalizations that need to introduce new operations, not
/// even constants (e.g. "x-x -> 0" cannot be expressed).
LogicalResult (&foldHook)(Operation *op, SmallVectorImpl<Value *> &results);
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
/// generalized constant folding.
LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results);
/// This hook returns any canonicalization pattern rewrites that the operation
/// supports, for use by the canonicalization pass.
@@ -142,8 +138,8 @@ public:
template <typename T> static AbstractOperation get(Dialect &dialect) {
return AbstractOperation(
T::getOperationName(), dialect, T::getOperationProperties(), T::classof,
T::parseAssembly, T::printAssembly, T::verifyInvariants,
T::constantFoldHook, T::foldHook, T::getCanonicalizationPatterns);
T::parseAssembly, T::printAssembly, T::verifyInvariants, T::foldHook,
T::getCanonicalizationPatterns);
}
private:
@@ -153,17 +149,13 @@ private:
ParseResult (&parseAssembly)(OpAsmParser *parser, OperationState *result),
void (&printAssembly)(Operation *op, OpAsmPrinter *p),
LogicalResult (&verifyInvariants)(Operation *op),
LogicalResult (&constantFoldHook)(Operation *op,
ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results),
LogicalResult (&foldHook)(Operation *op,
SmallVectorImpl<Value *> &results),
LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results),
void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
MLIRContext *context))
: name(name), dialect(dialect), classof(classof),
parseAssembly(parseAssembly), printAssembly(printAssembly),
verifyInvariants(verifyInvariants), constantFoldHook(constantFoldHook),
foldHook(foldHook),
verifyInvariants(verifyInvariants), foldHook(foldHook),
getCanonicalizationPatterns(getCanonicalizationPatterns),
opProperties(opProperties) {}

View File

@@ -101,7 +101,7 @@ public:
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
OpFoldResult fold(ArrayRef<Attribute> operands);
};
/// The predicate indicates the type of the comparison to perform:
@@ -176,7 +176,7 @@ public:
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
OpFoldResult fold(ArrayRef<Attribute> operands);
};
/// The "cond_br" operation represents a conditional branch operation in a
@@ -600,7 +600,7 @@ public:
Value *getTrueValue() { return getOperand(1); }
Value *getFalseValue() { return getOperand(2); }
Value *fold();
OpFoldResult fold(ArrayRef<Attribute> operands);
};
/// The "store" op writes an element to a memref specified by an index list.

View File

@@ -112,13 +112,12 @@ class FloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
def AddFOp : FloatArithmeticOp<"addf"> {
let summary = "floating point addition operation";
let hasConstantFolder = 1;
let hasFolder = 1;
}
def AddIOp : IntArithmeticOp<"addi", [Commutative]> {
let summary = "integer addition operation";
let hasFolder = 1;
let hasConstantFolder = 1;
}
def AllocOp : Std_Op<"alloc"> {
@@ -163,7 +162,6 @@ def AllocOp : Std_Op<"alloc"> {
def AndOp : IntArithmeticOp<"and", [Commutative]> {
let summary = "integer binary and";
let hasConstantFolder = 1;
let hasFolder = 1;
}
@@ -288,7 +286,7 @@ def ConstantOp : Std_Op<"constant", [NoSideEffect]> {
Attribute getValue() { return getAttr("value"); }
}];
let hasConstantFolder = 1;
let hasFolder = 1;
}
def DeallocOp : Std_Op<"dealloc"> {
@@ -338,7 +336,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
}
}];
let hasConstantFolder = 1;
let hasFolder = 1;
}
def DivFOp : FloatArithmeticOp<"divf"> {
@@ -347,12 +345,12 @@ def DivFOp : FloatArithmeticOp<"divf"> {
def DivISOp : IntArithmeticOp<"divis"> {
let summary = "signed integer division operation";
let hasConstantFolder = 1;
let hasFolder = 1;
}
def DivIUOp : IntArithmeticOp<"diviu"> {
let summary = "unsigned integer division operation";
let hasConstantFolder = 1;
let hasFolder = 1;
}
def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
@@ -389,7 +387,7 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
}
}];
let hasConstantFolder = 1;
let hasFolder = 1;
}
def MemRefCastOp : CastOp<"memref_cast"> {
@@ -426,18 +424,16 @@ def MemRefCastOp : CastOp<"memref_cast"> {
def MulFOp : FloatArithmeticOp<"mulf"> {
let summary = "foating point multiplication operation";
let hasConstantFolder = 1;
let hasFolder = 1;
}
def MulIOp : IntArithmeticOp<"muli", [Commutative]> {
let summary = "integer multiplication operation";
let hasConstantFolder = 1;
let hasFolder = 1;
}
def OrOp : IntArithmeticOp<"or", [Commutative]> {
let summary = "integer binary or";
let hasConstantFolder = 1;
let hasFolder = 1;
}
@@ -447,12 +443,12 @@ def RemFOp : FloatArithmeticOp<"remf"> {
def RemISOp : IntArithmeticOp<"remis"> {
let summary = "signed integer division remainder operation";
let hasConstantFolder = 1;
let hasFolder = 1;
}
def RemIUOp : IntArithmeticOp<"remiu"> {
let summary = "unsigned integer division remainder operation";
let hasConstantFolder = 1;
let hasFolder = 1;
}
def ReturnOp : Std_Op<"return", [Terminator]> {
@@ -481,13 +477,12 @@ def ShlISOp : IntArithmeticOp<"shlis"> {
def SubFOp : FloatArithmeticOp<"subf"> {
let summary = "floating point subtraction operation";
let hasConstantFolder = 1;
let hasFolder = 1;
}
def SubIOp : IntArithmeticOp<"subi"> {
let summary = "integer subtraction operation";
let hasConstantFolder = 1;
let hasCanonicalizer = 1;
let hasFolder = 1;
}
def TensorCastOp : CastOp<"tensor_cast"> {
@@ -518,8 +513,6 @@ def TensorCastOp : CastOp<"tensor_cast"> {
def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
let summary = "integer binary xor";
let hasConstantFolder = 1;
let hasCanonicalizer = 1;
let hasFolder = 1;
}

View File

@@ -1,4 +1,4 @@
//===- ConstantFoldUtils.h - Constant Fold Utilities ------------*- C++ -*-===//
//===- FoldUtils.h - Operation Fold Utilities -------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
@@ -15,13 +15,13 @@
// limitations under the License.
// =============================================================================
//
// This header file declares various constant fold utilities. These utilities
// are intended to be used by passes to unify and simply their logic.
// This header file declares various operation folding utilities. These
// utilities are intended to be used by passes to unify and simply their logic.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TRANSFORMS_CONSTANT_UTILS_H
#define MLIR_TRANSFORMS_CONSTANT_UTILS_H
#ifndef MLIR_TRANSFORMS_FOLDUTILS_H
#define MLIR_TRANSFORMS_FOLDUTILS_H
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Types.h"
@@ -32,13 +32,13 @@ namespace mlir {
class Function;
class Operation;
/// A helper class for constant folding operations, and unifying duplicated
/// constants along the way.
/// A helper class for folding operations, and unifying duplicated constants
/// generated along the way.
///
/// To make sure constants' proper dominance of all their uses, constants are
/// To make sure constants properly dominate all their uses, constants are
/// moved to the beginning of the entry block of the function when tracked by
/// this class.
class ConstantFoldHelper {
class FoldHelper {
public:
/// Constructs an instance for managing constants in the given function `f`.
/// Constants tracked by this instance will be moved to the entry block of
@@ -47,32 +47,30 @@ public:
/// This instance does not proactively walk the operations inside `f`;
/// instead, users must invoke the following methods to manually handle each
/// operation of interest.
ConstantFoldHelper(Function *f);
FoldHelper(Function *f);
/// Tries to perform constant folding on the given `op`, including unifying
/// deplicated constants. If successful, calls `preReplaceAction` (if
/// Tries to perform folding on the given `op`, including unifying
/// deduplicated constants. If successful, calls `preReplaceAction` (if
/// provided) by passing in `op`, then replaces `op`'s uses with folded
/// constants, and returns true.
///
/// Note: `op` will *not* be erased to avoid invalidating potential walkers in
/// the caller.
bool
tryToConstantFold(Operation *op,
std::function<void(Operation *)> preReplaceAction = {});
/// results, and returns success. If the op was completely folded it is
/// erased.
LogicalResult
tryToFold(Operation *op,
std::function<void(Operation *)> preReplaceAction = {});
/// Notifies that the given constant `op` should be remove from this
/// ConstantFoldHelper's internal bookkeeping.
/// FoldHelper's internal bookkeeping.
///
/// Note: this method must be called if a constant op is to be deleted
/// externally to this ConstantFoldHelper. `op` must be a constant op.
/// externally to this FoldHelper. `op` must be a constant op.
void notifyRemoval(Operation *op);
private:
/// Tries to deduplicate the given constant and returns true if that can be
/// Tries to deduplicate the given constant and returns success if that can be
/// done. This moves the given constant to the top of the entry block if it
/// is first seen. If there is already an existing constant that is the same,
/// this does *not* erases the given constant.
bool tryToUnify(Operation *op);
LogicalResult tryToUnify(Operation *op);
/// Moves the given constant `op` to entry block to guarantee dominance.
void moveConstantToEntryBlock(Operation *op);
@@ -86,4 +84,4 @@ private:
} // end namespace mlir
#endif // MLIR_TRANSFORMS_CONSTANT_UTILS_H
#endif // MLIR_TRANSFORMS_FOLDUTILS_H

View File

@@ -201,12 +201,11 @@ bool AffineApplyOp::isValidSymbol() {
[](Value *op) { return mlir::isValidSymbol(op); });
}
Attribute AffineApplyOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) {
OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
auto map = getAffineMap();
SmallVector<Attribute, 1> result;
if (failed(map.constantFold(operands, result)))
return Attribute();
return {};
return result[0];
}

View File

@@ -509,40 +509,32 @@ auto Operation::getSuccessorOperands(unsigned index) -> operand_range {
succOperandIndex + getNumSuccessorOperands(index))};
}
/// Attempt to constant fold this operation with the specified constant
/// operand values. If successful, this fills in the results vector. If not,
/// results is unspecified.
LogicalResult Operation::constantFold(ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results) {
if (auto *abstractOp = getAbstractOperation()) {
// If we have a registered operation definition matching this one, use it to
// try to constant fold the operation.
if (succeeded(abstractOp->constantFoldHook(this, operands, results)))
return success();
// Otherwise, fall back on the dialect hook to handle it.
return abstractOp->dialect.constantFoldHook(this, operands, results);
}
// If this operation hasn't been registered or doesn't have abstract
// operation, fall back to a dialect which matches the prefix.
auto opName = getName().getStringRef();
auto dialectPrefix = opName.split('.').first;
if (auto *dialect = getContext()->getRegisteredDialect(dialectPrefix))
return dialect->constantFoldHook(this, operands, results);
return failure();
}
/// Attempt to fold this operation using the Op's registered foldHook.
LogicalResult Operation::fold(SmallVectorImpl<Value *> &results) {
if (auto *abstractOp = getAbstractOperation()) {
// If we have a registered operation definition matching this one, use it to
// try to constant fold the operation.
if (succeeded(abstractOp->foldHook(this, results)))
return success();
LogicalResult Operation::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
// If we have a registered operation definition matching this one, use it to
// try to constant fold the operation.
auto *abstractOp = getAbstractOperation();
if (abstractOp && succeeded(abstractOp->foldHook(this, operands, results)))
return success();
// Otherwise, fall back on the dialect hook to handle it.
Dialect *dialect;
if (abstractOp) {
dialect = &abstractOp->dialect;
} else {
// If this operation hasn't been registered, lookup the parent dialect.
auto opName = getName().getStringRef();
auto dialectPrefix = opName.split('.').first;
if (!(dialect = getContext()->getRegisteredDialect(dialectPrefix)))
return failure();
}
return failure();
SmallVector<Attribute, 8> constants;
if (failed(dialect->constantFoldHook(this, operands, constants)))
return failure();
results.assign(constants.begin(), constants.end());
return success();
}
/// Emit an error with the op name prefixed, like "'dim' op " which is

View File

@@ -196,8 +196,7 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
// AddFOp
//===----------------------------------------------------------------------===//
Attribute AddFOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) {
OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) {
return constFoldBinaryOp<FloatAttr>(
operands, [](APFloat a, APFloat b) { return a + b; });
}
@@ -206,20 +205,15 @@ Attribute AddFOp::constantFold(ArrayRef<Attribute> operands,
// AddIOp
//===----------------------------------------------------------------------===//
Attribute AddIOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) {
OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) {
/// addi(x, 0) -> x
if (matchPattern(rhs(), m_Zero()))
return lhs();
return constFoldBinaryOp<IntegerAttr>(operands,
[](APInt a, APInt b) { return a + b; });
}
Value *AddIOp::fold() {
/// addi(x, 0) -> x
if (matchPattern(getOperand(1), m_Zero()))
return getOperand(0);
return nullptr;
}
//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//
@@ -770,8 +764,7 @@ static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
}
// Constant folding hook for comparisons.
Attribute CmpIOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) {
OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "cmpi takes two arguments");
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
@@ -780,7 +773,7 @@ Attribute CmpIOp::constantFold(ArrayRef<Attribute> operands,
return {};
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
return IntegerAttr::get(IntegerType::get(1, context), APInt(1, val));
return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
}
//===----------------------------------------------------------------------===//
@@ -967,8 +960,7 @@ static bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
}
// Constant folding hook for comparisons.
Attribute CmpFOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) {
OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "cmpf takes two arguments");
auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
@@ -980,7 +972,7 @@ Attribute CmpFOp::constantFold(ArrayRef<Attribute> operands,
return {};
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
return IntegerAttr::get(IntegerType::get(1, context), APInt(1, val));
return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
}
//===----------------------------------------------------------------------===//
@@ -1179,8 +1171,7 @@ static LogicalResult verify(ConstantOp &op) {
"requires a result type that aligns with the 'value' attribute");
}
Attribute ConstantOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) {
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "constant has no operands");
return getValue();
}
@@ -1337,8 +1328,7 @@ static LogicalResult verify(DimOp op) {
return success();
}
Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) {
OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
// Constant fold dim when the size along the index referred to is a constant.
auto opType = getOperand()->getType();
int64_t indexSize = -1;
@@ -1348,19 +1338,17 @@ Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
indexSize = memrefType.getShape()[getIndex()];
if (indexSize >= 0)
return IntegerAttr::get(IndexType::get(context), indexSize);
return IntegerAttr::get(IndexType::get(getContext()), indexSize);
return nullptr;
return {};
}
//===----------------------------------------------------------------------===//
// DivISOp
//===----------------------------------------------------------------------===//
Attribute DivISOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) {
OpFoldResult DivISOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "binary operation takes two operands");
(void)context;
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
@@ -1368,9 +1356,8 @@ Attribute DivISOp::constantFold(ArrayRef<Attribute> operands,
return {};
// Don't fold if it requires division by zero.
if (rhs.getValue().isNullValue()) {
if (rhs.getValue().isNullValue())
return {};
}
// Don't fold if it would overflow.
bool overflow;
@@ -1382,10 +1369,8 @@ Attribute DivISOp::constantFold(ArrayRef<Attribute> operands,
// DivIUOp
//===----------------------------------------------------------------------===//
Attribute DivIUOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) {
OpFoldResult DivIUOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "binary operation takes two operands");
(void)context;
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
@@ -1675,14 +1660,13 @@ static LogicalResult verify(ExtractElementOp op) {
return success();
}
Attribute ExtractElementOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) {
OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
assert(!operands.empty() && "extract_element takes atleast one operand");
// The aggregate operand must be a known constant.
Attribute aggregate = operands.front();
if (!aggregate)
return Attribute();
return {};
// If this is a splat elements attribute, simply return the value. All of the
// elements of a splat attribute are the same.
@@ -1693,14 +1677,14 @@ Attribute ExtractElementOp::constantFold(ArrayRef<Attribute> operands,
SmallVector<uint64_t, 8> indices;
for (Attribute indice : llvm::drop_begin(operands, 1)) {
if (!indice || !indice.isa<IntegerAttr>())
return Attribute();
return {};
indices.push_back(indice.cast<IntegerAttr>().getInt());
}
// If this is an elements attribute, query the value at the given indices.
if (auto elementsAttr = aggregate.dyn_cast<ElementsAttr>())
return elementsAttr.getValue(indices);
return Attribute();
return {};
}
//===----------------------------------------------------------------------===//
@@ -1801,14 +1785,15 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) {
return true;
}
Value *MemRefCastOp::fold() { return impl::foldCastOp(*this); }
OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
return impl::foldCastOp(*this);
}
//===----------------------------------------------------------------------===//
// MulFOp
//===----------------------------------------------------------------------===//
Attribute MulFOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) {
OpFoldResult MulFOp::fold(ArrayRef<Attribute> operands) {
return constFoldBinaryOp<FloatAttr>(
operands, [](APFloat a, APFloat b) { return a * b; });
}
@@ -1817,29 +1802,24 @@ Attribute MulFOp::constantFold(ArrayRef<Attribute> operands,
// MulIOp
//===----------------------------------------------------------------------===//
Attribute MulIOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) {
OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
/// muli(x, 0) -> 0
if (matchPattern(rhs(), m_Zero()))
return rhs();
/// muli(x, 1) -> x
if (matchPattern(rhs(), m_One()))
return getOperand(0);
// TODO: Handle the overflow case.
return constFoldBinaryOp<IntegerAttr>(operands,
[](APInt a, APInt b) { return a * b; });
}
Value *MulIOp::fold() {
/// muli(x, 0) -> 0
if (matchPattern(getOperand(1), m_Zero()))
return getOperand(1);
/// muli(x, 1) -> x
if (matchPattern(getOperand(1), m_One()))
return getOperand(0);
return nullptr;
}
//===----------------------------------------------------------------------===//
// RemISOp
//===----------------------------------------------------------------------===//
Attribute RemISOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) {
OpFoldResult RemISOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "remis takes two operands");
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
@@ -1852,9 +1832,8 @@ Attribute RemISOp::constantFold(ArrayRef<Attribute> operands,
APInt(rhs.getValue().getBitWidth(), 0));
// Don't fold if it requires division by zero.
if (rhs.getValue().isNullValue()) {
if (rhs.getValue().isNullValue())
return {};
}
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
if (!lhs)
@@ -1867,8 +1846,7 @@ Attribute RemISOp::constantFold(ArrayRef<Attribute> operands,
// RemIUOp
//===----------------------------------------------------------------------===//
Attribute RemIUOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) {
OpFoldResult RemIUOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "remiu takes two operands");
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
@@ -1881,9 +1859,8 @@ Attribute RemIUOp::constantFold(ArrayRef<Attribute> operands,
APInt(rhs.getValue().getBitWidth(), 0));
// Don't fold if it requires division by zero.
if (rhs.getValue().isNullValue()) {
if (rhs.getValue().isNullValue())
return {};
}
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
if (!lhs)
@@ -1990,7 +1967,7 @@ LogicalResult SelectOp::verify() {
return success();
}
Value *SelectOp::fold() {
OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
auto *condition = getCondition();
// select true, %0, %1 => %0
@@ -2081,8 +2058,7 @@ void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
// SubFOp
//===----------------------------------------------------------------------===//
Attribute SubFOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) {
OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) {
return constFoldBinaryOp<FloatAttr>(
operands, [](APFloat a, APFloat b) { return a - b; });
}
@@ -2091,48 +2067,20 @@ Attribute SubFOp::constantFold(ArrayRef<Attribute> operands,
// SubIOp
//===----------------------------------------------------------------------===//
Attribute SubIOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) {
OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
// subi(x,x) -> 0
if (getOperand(0) == getOperand(1))
return Builder(getContext()).getZeroAttr(getType());
return constFoldBinaryOp<IntegerAttr>(operands,
[](APInt a, APInt b) { return a - b; });
}
namespace {
/// subi(x,x) -> 0
///
struct SimplifyXMinusX : public RewritePattern {
SimplifyXMinusX(MLIRContext *context)
: RewritePattern(SubIOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto subi = cast<SubIOp>(op);
if (subi.getOperand(0) != subi.getOperand(1))
return matchFailure();
rewriter.replaceOpWithNewOp<ConstantOp>(
op, subi.getType(), rewriter.getZeroAttr(subi.getType()));
return matchSuccess();
}
};
} // end anonymous namespace.
void SubIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.push_back(llvm::make_unique<SimplifyXMinusX>(context));
}
//===----------------------------------------------------------------------===//
// AndOp
//===----------------------------------------------------------------------===//
Attribute AndOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) {
return constFoldBinaryOp<IntegerAttr>(operands,
[](APInt a, APInt b) { return a & b; });
}
Value *AndOp::fold() {
OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
/// and(x, 0) -> 0
if (matchPattern(rhs(), m_Zero()))
return rhs();
@@ -2140,20 +2088,15 @@ Value *AndOp::fold() {
if (lhs() == rhs())
return rhs();
return nullptr;
return constFoldBinaryOp<IntegerAttr>(operands,
[](APInt a, APInt b) { return a & b; });
}
//===----------------------------------------------------------------------===//
// OrOp
//===----------------------------------------------------------------------===//
Attribute OrOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) {
return constFoldBinaryOp<IntegerAttr>(operands,
[](APInt a, APInt b) { return a | b; });
}
Value *OrOp::fold() {
OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
/// or(x, 0) -> x
if (matchPattern(rhs(), m_Zero()))
return lhs();
@@ -2161,51 +2104,26 @@ Value *OrOp::fold() {
if (lhs() == rhs())
return rhs();
return nullptr;
return constFoldBinaryOp<IntegerAttr>(operands,
[](APInt a, APInt b) { return a | b; });
}
//===----------------------------------------------------------------------===//
// XOrOp
//===----------------------------------------------------------------------===//
Attribute XOrOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) {
OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
/// xor(x, 0) -> x
if (matchPattern(rhs(), m_Zero()))
return lhs();
/// xor(x,x) -> 0
if (lhs() == rhs())
return Builder(getContext()).getZeroAttr(getType());
return constFoldBinaryOp<IntegerAttr>(operands,
[](APInt a, APInt b) { return a ^ b; });
}
Value *XOrOp::fold() {
/// xor(x, 0) -> x
if (matchPattern(rhs(), m_Zero()))
return lhs();
return nullptr;
}
namespace {
/// xor(x,x) -> 0
///
struct SimplifyXXOrX : public RewritePattern {
SimplifyXXOrX(MLIRContext *context)
: RewritePattern(XOrOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto xorOp = cast<XOrOp>(op);
if (xorOp.lhs() != xorOp.rhs())
return matchFailure();
rewriter.replaceOpWithNewOp<ConstantOp>(
op, xorOp.getType(), rewriter.getZeroAttr(xorOp.getType()));
return matchSuccess();
}
};
} // end anonymous namespace.
void XOrOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.push_back(llvm::make_unique<SimplifyXXOrX>(context));
}
//===----------------------------------------------------------------------===//
// TensorCastOp
//===----------------------------------------------------------------------===//
@@ -2239,7 +2157,9 @@ bool TensorCastOp::areCastCompatible(Type a, Type b) {
return true;
}
Value *TensorCastOp::fold() { return impl::foldCastOp(*this); }
OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
return impl::foldCastOp(*this);
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions

View File

@@ -17,7 +17,7 @@ add_llvm_library(MLIRTransforms
SimplifyAffineStructures.cpp
StripDebugInfo.cpp
TestConstantFold.cpp
Utils/ConstantFoldUtils.cpp
Utils/FoldUtils.cpp
Utils/GreedyPatternRewriteDriver.cpp
Utils/LoopUtils.cpp
Utils/Utils.cpp

View File

@@ -20,7 +20,7 @@
#include "mlir/IR/Function.h"
#include "mlir/Pass/Pass.h"
#include "mlir/StandardOps/Ops.h"
#include "mlir/Transforms/ConstantFoldUtils.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
@@ -31,26 +31,22 @@ namespace {
struct TestConstantFold : public FunctionPass<TestConstantFold> {
// All constants in the function post folding.
SmallVector<Operation *, 8> existingConstants;
// Operations that were folded and that need to be erased.
std::vector<Operation *> opsToErase;
void foldOperation(Operation *op, ConstantFoldHelper &helper);
void foldOperation(Operation *op, FoldHelper &helper);
void runOnFunction() override;
};
} // end anonymous namespace
void TestConstantFold::foldOperation(Operation *op,
ConstantFoldHelper &helper) {
void TestConstantFold::foldOperation(Operation *op, FoldHelper &helper) {
// Attempt to fold the specified operation, including handling unused or
// duplicated constants.
if (helper.tryToConstantFold(op)) {
opsToErase.push_back(op);
}
if (succeeded(helper.tryToFold(op)))
return;
// If this op is a constant that are used and cannot be de-duplicated,
// remember it for cleanup later.
else if (auto constant = dyn_cast<ConstantOp>(op)) {
if (auto constant = dyn_cast<ConstantOp>(op))
existingConstants.push_back(op);
}
}
// For now, we do a simple top-down pass over a function folding constants. We
@@ -58,10 +54,9 @@ void TestConstantFold::foldOperation(Operation *op,
// branches, or anything else fancy.
void TestConstantFold::runOnFunction() {
existingConstants.clear();
opsToErase.clear();
auto &f = getFunction();
ConstantFoldHelper helper(&f);
FoldHelper helper(&f);
// Collect and fold the operations within the function.
SmallVector<Operation *, 8> ops;
@@ -74,12 +69,6 @@ void TestConstantFold::runOnFunction() {
for (Operation *op : llvm::reverse(ops))
foldOperation(op, helper);
// At this point, these operations are dead, remove them.
for (auto *op : opsToErase) {
assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
op->erase();
}
// By the time we are done, we may have simplified a bunch of code, leaving
// around dead constants. Check for them now and remove them.
for (auto *cst : existingConstants) {

View File

@@ -1,4 +1,4 @@
//===- ConstantFoldUtils.cpp ---- Constant Fold Utilities -----------------===//
//===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
@@ -15,12 +15,12 @@
// limitations under the License.
// =============================================================================
//
// This file defines various constant fold utilities. These utilities are
// This file defines various operation fold utilities. These utilities are
// intended to be used by passes to unify and simply their logic.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/ConstantFoldUtils.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
@@ -29,10 +29,11 @@
using namespace mlir;
ConstantFoldHelper::ConstantFoldHelper(Function *f) : function(f) {}
FoldHelper::FoldHelper(Function *f) : function(f) {}
bool ConstantFoldHelper::tryToConstantFold(
Operation *op, std::function<void(Operation *)> preReplaceAction) {
LogicalResult
FoldHelper::tryToFold(Operation *op,
std::function<void(Operation *)> preReplaceAction) {
assert(op->getFunction() == function &&
"cannot constant fold op from another function");
@@ -44,13 +45,15 @@ bool ConstantFoldHelper::tryToConstantFold(
// If this constant is dead, update bookkeeping and signal the caller.
if (constant.use_empty()) {
notifyRemoval(op);
return true;
op->erase();
return success();
}
// Otherwise, try to see if we can de-duplicate it.
return tryToUnify(op);
}
SmallVector<Attribute, 8> operandConstants, resultConstants;
SmallVector<Attribute, 8> operandConstants;
SmallVector<OpFoldResult, 8> results;
// Check to see if any operands to the operation is constant and whether
// the operation knows how to constant fold itself.
@@ -67,8 +70,8 @@ bool ConstantFoldHelper::tryToConstantFold(
}
// Attempt to constant fold the operation.
if (failed(op->constantFold(operandConstants, resultConstants)))
return false;
if (failed(op->fold(operandConstants, results)))
return failure();
// Constant folding succeeded. We will start replacing this op's uses and
// eventually erase this op. Invoke the callback provided by the caller to
@@ -76,21 +79,35 @@ bool ConstantFoldHelper::tryToConstantFold(
if (preReplaceAction)
preReplaceAction(op);
// Check to see if the operation was just updated in place.
if (results.empty())
return success();
assert(results.size() == op->getNumResults());
// Create the result constants and replace the results.
FuncBuilder builder(op);
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
auto *res = op->getResult(i);
if (res->use_empty()) // Ignore dead uses.
continue;
assert(!results[i].isNull() && "expected valid OpFoldResult");
// Check if the result was an SSA value.
if (auto *repl = results[i].dyn_cast<Value *>()) {
if (repl != res)
res->replaceAllUsesWith(repl);
continue;
}
// If we already have a canonicalized version of this constant, just reuse
// it. Otherwise create a new one.
Attribute attrRepl = results[i].get<Attribute>();
auto &constInst =
uniquedConstants[std::make_pair(resultConstants[i], res->getType())];
uniquedConstants[std::make_pair(attrRepl, res->getType())];
if (!constInst) {
// TODO: Extend to support dialect-specific constant ops.
auto newOp = builder.create<ConstantOp>(op->getLoc(), res->getType(),
resultConstants[i]);
auto newOp =
builder.create<ConstantOp>(op->getLoc(), res->getType(), attrRepl);
// Register to the constant map and also move up to entry block to
// guarantee dominance.
constInst = newOp.getOperation();
@@ -98,17 +115,18 @@ bool ConstantFoldHelper::tryToConstantFold(
}
res->replaceAllUsesWith(constInst->getResult(0));
}
op->erase();
return true;
return success();
}
void ConstantFoldHelper::notifyRemoval(Operation *op) {
void FoldHelper::notifyRemoval(Operation *op) {
assert(op->getFunction() == function &&
"cannot remove constant from another function");
Attribute constValue;
matchPattern(op, m_Constant(&constValue));
assert(constValue);
if (!matchPattern(op, m_Constant(&constValue)))
return;
// This constant is dead. keep uniquedConstants up to date.
auto it = uniquedConstants.find({constValue, op->getResult(0)->getType()});
@@ -116,7 +134,7 @@ void ConstantFoldHelper::notifyRemoval(Operation *op) {
uniquedConstants.erase(it);
}
bool ConstantFoldHelper::tryToUnify(Operation *op) {
LogicalResult FoldHelper::tryToUnify(Operation *op) {
Attribute constValue;
matchPattern(op, m_Constant(&constValue));
assert(constValue);
@@ -127,13 +145,14 @@ bool ConstantFoldHelper::tryToUnify(Operation *op) {
if (constInst) {
// If this constant is already our uniqued one, then leave it alone.
if (constInst == op)
return false;
return failure();
// Otherwise replace this redundant constant with the uniqued one. We know
// this is safe because we move constants to the top of the function when
// they are uniqued, so we know they dominate all uses.
op->getResult(0)->replaceAllUsesWith(constInst->getResult(0));
return true;
op->erase();
return success();
}
// If we have no entry, then we should unique this constant as the
@@ -141,10 +160,10 @@ bool ConstantFoldHelper::tryToUnify(Operation *op) {
// entry block of the function.
constInst = op;
moveConstantToEntryBlock(op);
return false;
return failure();
}
void ConstantFoldHelper::moveConstantToEntryBlock(Operation *op) {
void FoldHelper::moveConstantToEntryBlock(Operation *op) {
// Insert at the very top of the entry block.
auto &entryBB = function->front();
op->moveBefore(&entryBB, entryBB.begin());

View File

@@ -22,7 +22,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/StandardOps/Ops.h"
#include "mlir/Transforms/ConstantFoldUtils.h"
#include "mlir/Transforms/FoldUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -148,7 +148,7 @@ private:
/// Perform the rewrites.
bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
Function *fn = builder.getFunction();
ConstantFoldHelper helper(fn);
FoldHelper helper(fn);
bool changed = false;
int i = 0;
@@ -171,67 +171,31 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
// If the operation has no side effects, and no users, then it is
// trivially dead - remove it.
if (op->hasNoSideEffect() && op->use_empty()) {
// Be careful to update bookkeeping in ConstantHelper to keep
// consistency if this is a constant op.
if (isa<ConstantOp>(op))
helper.notifyRemoval(op);
// Be careful to update bookkeeping in FoldHelper to keep consistency if
// this is a constant op.
helper.notifyRemoval(op);
op->erase();
continue;
}
// Collects all the operands and result uses of the given `op` into work
// list.
auto collectOperandsAndUses = [this](Operation *op) {
originalOperands.assign(op->operand_begin(), op->operand_end());
auto collectOperandsAndUses = [&](Operation *op) {
// Add the operands to the worklist for visitation.
addToWorklist(op->getOperands());
addToWorklist(originalOperands);
// Add all the users of the result to the worklist so we make sure
// to revisit them.
//
// TODO: Add a result->getUsers() iterator.
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i)
for (auto &operand : op->getResult(i)->getUses())
addToWorklist(operand.getOwner());
}
};
// Try to constant fold this op.
if (helper.tryToConstantFold(op, collectOperandsAndUses)) {
assert(op->hasNoSideEffect() &&
"Constant folded op with side effects?");
op->erase();
changed |= true;
continue;
}
// Otherwise see if we can use the generic folder API to simplify the
// operation.
originalOperands.assign(op->operand_begin(), op->operand_end());
resultValues.clear();
if (succeeded(op->fold(resultValues))) {
// If the result was an in-place simplification (e.g. max(x,x,y) ->
// max(x,y)) then add the original operands to the worklist so we can
// make sure to revisit them.
if (resultValues.empty()) {
// Add the operands back to the worklist as there may be more
// canonicalization opportunities now.
addToWorklist(originalOperands);
} else {
// Otherwise, the operation is simplified away completely.
assert(resultValues.size() == op->getNumResults());
// Notify that we are replacing this operation.
notifyRootReplaced(op);
// Replace the result values and erase the operation.
for (unsigned i = 0, e = resultValues.size(); i != e; ++i) {
auto *res = op->getResult(i);
if (!res->use_empty())
res->replaceAllUsesWith(resultValues[i]);
}
notifyOperationRemoved(op);
op->erase();
}
// Try to fold this op.
if (succeeded(helper.tryToFold(op, collectOperandsAndUses))) {
changed |= true;
continue;
}

View File

@@ -28,7 +28,6 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect]> {
let verifier = [{ baz }];
let hasCanonicalizer = 1;
let hasConstantFolder = 1;
let hasFolder = 1;
let extraClassDeclaration = [{
@@ -55,8 +54,7 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect]> {
// CHECK: void print(OpAsmPrinter *p);
// CHECK: LogicalResult verify();
// CHECK: static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context);
// CHECK: LogicalResult constantFold(ArrayRef<Attribute> operands, SmallVectorImpl<Attribute> &results, MLIRContext *context);
// CHECK: bool fold(SmallVectorImpl<Value *> &results);
// CHECK: LogicalResult fold(ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results);
// CHECK: // Display a graph for debugging purposes.
// CHECK: void displayGraph();
// CHECK: };

View File

@@ -834,28 +834,15 @@ void OpEmitter::genCanonicalizerDecls() {
void OpEmitter::genFolderDecls() {
bool hasSingleResult = op.getNumResults() == 1;
if (def.getValueAsBit("hasConstantFolder")) {
if (hasSingleResult) {
const char *const params =
"ArrayRef<Attribute> operands, MLIRContext *context";
opClass.newMethod("Attribute", "constantFold", params, OpMethod::MP_None,
/*declOnly=*/true);
} else {
const char *const params =
"ArrayRef<Attribute> operands, SmallVectorImpl<Attribute> &results, "
"MLIRContext *context";
opClass.newMethod("LogicalResult", "constantFold", params,
OpMethod::MP_None, /*declOnly=*/true);
}
}
if (def.getValueAsBit("hasFolder")) {
if (hasSingleResult) {
opClass.newMethod("Value *", "fold", /*params=*/"", OpMethod::MP_None,
const char *const params = "ArrayRef<Attribute> operands";
opClass.newMethod("OpFoldResult", "fold", params, OpMethod::MP_None,
/*declOnly=*/true);
} else {
opClass.newMethod("bool", "fold", "SmallVectorImpl<Value *> &results",
OpMethod::MP_None,
const char *const params = "ArrayRef<Attribute> operands, "
"SmallVectorImpl<OpFoldResult> &results";
opClass.newMethod("LogicalResult", "fold", params, OpMethod::MP_None,
/*declOnly=*/true);
}
}