[mlir][Tensor] Avoid dropping attributes for tensor.pad operations during canonicalization.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D146440
Mahesh Ravishankar
2023-03-20 20:56:41 +00:00
parent fb1b9945be
commit c21e88cc02
5 changed files with 49 additions and 33 deletions
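For context: the tensor.pad canonicalization patterns (FoldSourceTensorCast, FoldTargetTensorCast, FoldOrthogonalPaddings, FoldStaticPadding) recreated the op from scratch and silently dropped any discardable attributes attached to it. The fix threads a pruned attribute list through each rebuilt op. A minimal sketch of the resulting idiom follows; the my.tag attribute is a hypothetical example, not from this commit:

// Sketch: rebuild a tensor::PadOp without losing discardable attributes.
// getPrunedAttributeList copies every attribute except the op's inherent
// ones (static_low, static_high, nofold, ...), which the builder re-derives
// from the new operands.
auto newPad = rewriter.create<tensor::PadOp>(
    padOp.getLoc(), newResultType, padOp.getSource(), padOp.getStaticLow(),
    padOp.getStaticHigh(), padOp.getLow(), padOp.getHigh(), padOp.getNofold(),
    getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
// A discardable attribute such as my.tag = "fused" set on padOp now
// survives canonicalization on newPad.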

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "llvm/ADT/StringSet.h"
 #include <optional>
@@ -461,18 +462,10 @@ struct GenerateLoopNest {
 /// Returns an attribute list that excludes pre-defined attributes.
 template <typename OpTy>
 SmallVector<NamedAttribute> getPrunedAttributeList(OpTy op) {
-  llvm::StringSet<> elidedAttrs;
-  elidedAttrs.insert(op.getAttributeNames().begin(),
-                     op.getAttributeNames().end());
+  auto elidedAttrs = llvm::to_vector(op.getAttributeNames());
   if (isa<linalg::LinalgOp>(op.getOperation()))
-    elidedAttrs.insert(LinalgDialect::kMemoizedIndexingMapsAttrName);
-  SmallVector<NamedAttribute> attrs;
-  for (auto attr : op->getAttrs()) {
-    if (elidedAttrs.count(attr.getName()))
-      continue;
-    attrs.push_back(attr);
-  }
-  return attrs;
+    elidedAttrs.push_back(LinalgDialect::kMemoizedIndexingMapsAttrName);
+  return getPrunedAttributeList(op, elidedAttrs);
 }
 } // namespace linalg

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

@@ -1295,13 +1295,13 @@ def Tensor_PadOp : Tensor_Op<"pad", [
   let builders = [
     // Build a PadOp with mixed static and dynamic entries.
-    OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$staticLow,
-      "ArrayRef<int64_t>":$staticHigh, "ValueRange":$low, "ValueRange":$high,
-      CArg<"bool", "false">:$nofold,
+    OpBuilder<(ins "Type":$resultType, "Value":$source,
+      "ArrayRef<int64_t>":$staticLow, "ArrayRef<int64_t>":$staticHigh,
+      "ValueRange":$low, "ValueRange":$high, CArg<"bool", "false">:$nofold,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     // Build a PadOp with all dynamic entries.
-    OpBuilder<(ins "Value":$source, "ValueRange":$low, "ValueRange":$high,
-      CArg<"bool", "false">:$nofold,
+    OpBuilder<(ins "Type":$resultType, "Value":$source, "ValueRange":$low,
+      "ValueRange":$high, CArg<"bool", "false">:$nofold,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     // Build a PadOp with mixed static and dynamic entries and custom
     // result type. If the type passed is nullptr, it is inferred.
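A hedged usage sketch of the updated builders, with placeholder variable names: the result type now comes first, and passing a null Type falls back to inference, as the TensorOps.cpp change below shows. The builder still leaves the pad body region to be populated separately.

// Sketch: a null resultType is inferred from the source type and the static
// padding amounts; the region body (the yield of the padding value) must
// still be constructed afterwards.
b.create<tensor::PadOp>(loc, /*resultType=*/Type(), source, staticLow,
                        staticHigh, lowValues, highValues, /*nofold=*/false);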

mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h

@@ -123,6 +123,11 @@ Operation *cloneWithoutRegions(OpBuilder &b, Operation *op,
                                TypeRange newResultTypes,
                                ValueRange newOperands);
 
+// Get the list of attributes associated with the op, ignoring
+// those with the provided names.
+SmallVector<NamedAttribute>
+getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
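The helper is deliberately dialect-neutral: it takes a plain Operation*, so any dialect can reuse it. A hedged sketch of reuse outside tensor.pad, assuming an ODS-generated op with the static getAttributeNames() accessor:

// Sketch: collect linalg.generic's user-attached attributes, eliding its
// inherent attribute names, before copying them onto a replacement op.
SmallVector<NamedAttribute> kept = mlir::getPrunedAttributeList(
    genericOp, linalg::GenericOp::getAttributeNames());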

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

@@ -2518,26 +2518,27 @@ RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
   return RankedTensorType::get(inferredShape, sourceType.getElementType());
 }
 
-void PadOp::build(OpBuilder &b, OperationState &result, Value source,
-                  ArrayRef<int64_t> staticLow, ArrayRef<int64_t> staticHigh,
-                  ValueRange low, ValueRange high, bool nofold,
-                  ArrayRef<NamedAttribute> attrs) {
+void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                  Value source, ArrayRef<int64_t> staticLow,
+                  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
+                  bool nofold, ArrayRef<NamedAttribute> attrs) {
   auto sourceType = source.getType().cast<RankedTensorType>();
-  auto resultType = inferResultType(sourceType, staticLow, staticHigh);
+  if (!resultType)
+    resultType = inferResultType(sourceType, staticLow, staticHigh);
   build(b, result, resultType, source, low, high,
         b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
         nofold ? b.getUnitAttr() : UnitAttr());
   result.addAttributes(attrs);
 }
 
-void PadOp::build(OpBuilder &b, OperationState &result, Value source,
-                  ValueRange low, ValueRange high, bool nofold,
+void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                  Value source, ValueRange low, ValueRange high, bool nofold,
                   ArrayRef<NamedAttribute> attrs) {
   auto sourceType = source.getType().cast<RankedTensorType>();
   unsigned rank = sourceType.getRank();
   SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
-  build(b, result, source, staticVector, staticVector, low, high, nofold,
-        attrs);
+  build(b, result, resultType, source, staticVector, staticVector, low, high,
+        nofold, attrs);
 }
 
 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
@@ -2635,9 +2636,9 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
     } else {
       auto newOp = rewriter.create<PadOp>(
           padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
-          padTensorOp.getLow(), padTensorOp.getHigh(),
           padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
-          padTensorOp.getNofold());
+          padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
+          getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
       IRMapping mapper;
       padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
@@ -2667,9 +2668,10 @@ struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
     auto replacementOp = rewriter.create<PadOp>(
         padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
-        padTensorOp.getSource(), padTensorOp.getLow(), padTensorOp.getHigh(),
-        padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
-        padTensorOp.getNofold());
+        padTensorOp.getSource(), padTensorOp.getStaticLow(),
+        padTensorOp.getStaticHigh(), padTensorOp.getLow(),
+        padTensorOp.getHigh(), padTensorOp.getNofold(),
+        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
     replacementOp.getRegion().takeBody(padTensorOp.getRegion());
 
     rewriter.replaceOp(padTensorOp, replacementOp.getResult());
@@ -2827,7 +2829,8 @@ struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
         innerSliceOp.getMixedStrides());
     auto newPadOp = rewriter.create<PadOp>(
         padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
-        padOp.getMixedLowPad(), newHighPad, padOp.getNofold());
+        padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
+        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
     rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                 newPadOp.getRegion().begin());
     rewriter.replaceOp(padOp, newPadOp.getResult());
@@ -2916,8 +2919,9 @@ struct FoldStaticPadding : public OpRewritePattern<PadOp> {
     auto newResultType = RankedTensorType::get(
        newOutDims, padTensorOp.getType().getElementType());
     auto newOp = rewriter.create<PadOp>(
-        padTensorOp->getLoc(), newResultType, input, padTensorOp.getLow(),
-        padTensorOp.getHigh(), staticLow, staticHigh, padTensorOp.getNofold());
+        padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
+        padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
+        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
     IRMapping mapper;
     padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp

@@ -11,6 +11,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/IRMapping.h"
+#include "llvm/ADT/StringSet.h"
 
 #include "mlir/Dialect/Utils/DialectUtilsEnums.cpp.inc"
@@ -114,3 +115,16 @@ Operation *mlir::cloneWithoutRegions(OpBuilder &b, Operation *op,
   state.addRegion();
   return b.create(state);
 }
+
+SmallVector<NamedAttribute>
+mlir::getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs) {
+  llvm::StringSet<> elidedAttrsSet;
+  elidedAttrsSet.insert(elidedAttrs.begin(), elidedAttrs.end());
+  SmallVector<NamedAttribute> attrs;
+  for (auto attr : op->getAttrs()) {
+    if (elidedAttrsSet.count(attr.getName()))
+      continue;
+    attrs.push_back(attr);
+  }
+  return attrs;
+}
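For intuition, a self-contained sketch of the same elide-by-name filtering on plain strings, with no MLIR dependencies; the names in the comment are illustrative:

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"

// Sketch: keep the names not present in the elided set, mirroring
// mlir::getPrunedAttributeList above.
static llvm::SmallVector<llvm::StringRef>
prune(llvm::ArrayRef<llvm::StringRef> names,
      llvm::ArrayRef<llvm::StringRef> elided) {
  llvm::StringSet<> elidedSet;
  elidedSet.insert(elided.begin(), elided.end());
  llvm::SmallVector<llvm::StringRef> kept;
  for (llvm::StringRef name : names)
    if (!elidedSet.count(name))
      kept.push_back(name);
  return kept;
}
// prune({"static_low", "my.tag"}, {"static_low"}) returns {"my.tag"}.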