mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 10:55:58 +08:00
[mlir][Tensor] Avoid dropping attributes for tensor.pad operations during canonicalization.
Reviewed By: hanchung Differential Revision: https://reviews.llvm.org/D146440
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include <optional>
|
||||
|
||||
@@ -461,18 +462,10 @@ struct GenerateLoopNest {
|
||||
/// Returns an attribute list that excludes pre-defined attributes.
|
||||
template <typename OpTy>
|
||||
SmallVector<NamedAttribute> getPrunedAttributeList(OpTy op) {
|
||||
llvm::StringSet<> elidedAttrs;
|
||||
elidedAttrs.insert(op.getAttributeNames().begin(),
|
||||
op.getAttributeNames().end());
|
||||
auto elidedAttrs = llvm::to_vector(op.getAttributeNames());
|
||||
if (isa<linalg::LinalgOp>(op.getOperation()))
|
||||
elidedAttrs.insert(LinalgDialect::kMemoizedIndexingMapsAttrName);
|
||||
SmallVector<NamedAttribute> attrs;
|
||||
for (auto attr : op->getAttrs()) {
|
||||
if (elidedAttrs.count(attr.getName()))
|
||||
continue;
|
||||
attrs.push_back(attr);
|
||||
}
|
||||
return attrs;
|
||||
elidedAttrs.push_back(LinalgDialect::kMemoizedIndexingMapsAttrName);
|
||||
return getPrunedAttributeList(op, elidedAttrs);
|
||||
}
|
||||
|
||||
} // namespace linalg
|
||||
|
||||
@@ -1295,13 +1295,13 @@ def Tensor_PadOp : Tensor_Op<"pad", [
|
||||
|
||||
let builders = [
|
||||
// Build a PadOp with mixed static and dynamic entries.
|
||||
OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$staticLow,
|
||||
"ArrayRef<int64_t>":$staticHigh, "ValueRange":$low, "ValueRange":$high,
|
||||
CArg<"bool", "false">:$nofold,
|
||||
OpBuilder<(ins "Type":$resultType, "Value":$source,
|
||||
"ArrayRef<int64_t>":$staticLow, "ArrayRef<int64_t>":$staticHigh,
|
||||
"ValueRange":$low, "ValueRange":$high, CArg<"bool", "false">:$nofold,
|
||||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
|
||||
// Build a PadOp with all dynamic entries.
|
||||
OpBuilder<(ins "Value":$source, "ValueRange":$low, "ValueRange":$high,
|
||||
CArg<"bool", "false">:$nofold,
|
||||
OpBuilder<(ins "Type":$resultType, "Value":$source, "ValueRange":$low,
|
||||
"ValueRange":$high, CArg<"bool", "false">:$nofold,
|
||||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
|
||||
// Build a PadOp with mixed static and dynamic entries and custom
|
||||
// result type. If the type passed is nullptr, it is inferred.
|
||||
|
||||
@@ -123,6 +123,11 @@ Operation *cloneWithoutRegions(OpBuilder &b, Operation *op,
|
||||
TypeRange newResultTypes,
|
||||
ValueRange newOperands);
|
||||
|
||||
// Get the list of attributes associated with the op, ignoring
|
||||
// those with the provided name.
|
||||
SmallVector<NamedAttribute>
|
||||
getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
|
||||
|
||||
@@ -2518,26 +2518,27 @@ RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
|
||||
return RankedTensorType::get(inferredShape, sourceType.getElementType());
|
||||
}
|
||||
|
||||
void PadOp::build(OpBuilder &b, OperationState &result, Value source,
|
||||
ArrayRef<int64_t> staticLow, ArrayRef<int64_t> staticHigh,
|
||||
ValueRange low, ValueRange high, bool nofold,
|
||||
ArrayRef<NamedAttribute> attrs) {
|
||||
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
|
||||
Value source, ArrayRef<int64_t> staticLow,
|
||||
ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
|
||||
bool nofold, ArrayRef<NamedAttribute> attrs) {
|
||||
auto sourceType = source.getType().cast<RankedTensorType>();
|
||||
auto resultType = inferResultType(sourceType, staticLow, staticHigh);
|
||||
if (!resultType)
|
||||
resultType = inferResultType(sourceType, staticLow, staticHigh);
|
||||
build(b, result, resultType, source, low, high,
|
||||
b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
|
||||
nofold ? b.getUnitAttr() : UnitAttr());
|
||||
result.addAttributes(attrs);
|
||||
}
|
||||
|
||||
void PadOp::build(OpBuilder &b, OperationState &result, Value source,
|
||||
ValueRange low, ValueRange high, bool nofold,
|
||||
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
|
||||
Value source, ValueRange low, ValueRange high, bool nofold,
|
||||
ArrayRef<NamedAttribute> attrs) {
|
||||
auto sourceType = source.getType().cast<RankedTensorType>();
|
||||
unsigned rank = sourceType.getRank();
|
||||
SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
|
||||
build(b, result, source, staticVector, staticVector, low, high, nofold,
|
||||
attrs);
|
||||
build(b, result, resultType, source, staticVector, staticVector, low, high,
|
||||
nofold, attrs);
|
||||
}
|
||||
|
||||
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
|
||||
@@ -2635,9 +2636,9 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
|
||||
} else {
|
||||
auto newOp = rewriter.create<PadOp>(
|
||||
padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
|
||||
padTensorOp.getLow(), padTensorOp.getHigh(),
|
||||
padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
|
||||
padTensorOp.getNofold());
|
||||
padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
|
||||
getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
|
||||
IRMapping mapper;
|
||||
padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
|
||||
|
||||
@@ -2667,9 +2668,10 @@ struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
|
||||
|
||||
auto replacementOp = rewriter.create<PadOp>(
|
||||
padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
|
||||
padTensorOp.getSource(), padTensorOp.getLow(), padTensorOp.getHigh(),
|
||||
padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
|
||||
padTensorOp.getNofold());
|
||||
padTensorOp.getSource(), padTensorOp.getStaticLow(),
|
||||
padTensorOp.getStaticHigh(), padTensorOp.getLow(),
|
||||
padTensorOp.getHigh(), padTensorOp.getNofold(),
|
||||
getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
|
||||
replacementOp.getRegion().takeBody(padTensorOp.getRegion());
|
||||
|
||||
rewriter.replaceOp(padTensorOp, replacementOp.getResult());
|
||||
@@ -2827,7 +2829,8 @@ struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
|
||||
innerSliceOp.getMixedStrides());
|
||||
auto newPadOp = rewriter.create<PadOp>(
|
||||
padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
|
||||
padOp.getMixedLowPad(), newHighPad, padOp.getNofold());
|
||||
padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
|
||||
getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
|
||||
rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
|
||||
newPadOp.getRegion().begin());
|
||||
rewriter.replaceOp(padOp, newPadOp.getResult());
|
||||
@@ -2916,8 +2919,9 @@ struct FoldStaticPadding : public OpRewritePattern<PadOp> {
|
||||
auto newResultType = RankedTensorType::get(
|
||||
newOutDims, padTensorOp.getType().getElementType());
|
||||
auto newOp = rewriter.create<PadOp>(
|
||||
padTensorOp->getLoc(), newResultType, input, padTensorOp.getLow(),
|
||||
padTensorOp.getHigh(), staticLow, staticHigh, padTensorOp.getNofold());
|
||||
padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
|
||||
padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
|
||||
getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
|
||||
|
||||
IRMapping mapper;
|
||||
padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
|
||||
#include "mlir/Dialect/Utils/DialectUtilsEnums.cpp.inc"
|
||||
|
||||
@@ -114,3 +115,16 @@ Operation *mlir::cloneWithoutRegions(OpBuilder &b, Operation *op,
|
||||
state.addRegion();
|
||||
return b.create(state);
|
||||
}
|
||||
|
||||
SmallVector<NamedAttribute>
|
||||
mlir::getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs) {
|
||||
llvm::StringSet elidedAttrsSet;
|
||||
elidedAttrsSet.insert(elidedAttrs.begin(), elidedAttrs.end());
|
||||
SmallVector<NamedAttribute> attrs;
|
||||
for (auto attr : op->getAttrs()) {
|
||||
if (elidedAttrsSet.count(attr.getName()))
|
||||
continue;
|
||||
attrs.push_back(attr);
|
||||
}
|
||||
return attrs;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user