|
|
|
|
@@ -97,7 +97,7 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
|
|
|
|
|
// Can fold if the source of cast has at least as much static information as
|
|
|
|
|
// its results.
|
|
|
|
|
return preservesStaticInformation(castOp.getType(),
|
|
|
|
|
castOp.source().getType());
|
|
|
|
|
castOp.getSource().getType());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Determines whether the tensor::CastOp casts to a more static version of the
|
|
|
|
|
@@ -123,7 +123,7 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
|
|
|
|
|
bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
|
|
|
|
|
if (!castOp)
|
|
|
|
|
return false;
|
|
|
|
|
return preservesStaticInformation(castOp.source().getType(),
|
|
|
|
|
return preservesStaticInformation(castOp.getSource().getType(),
|
|
|
|
|
castOp.getType());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -250,13 +250,15 @@ struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
|
|
|
|
|
tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
|
|
|
|
|
|
|
|
|
|
if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
|
|
|
|
|
tensorCast.getType().getShape() ==
|
|
|
|
|
tensorCast.source().getType().cast<RankedTensorType>().getShape())
|
|
|
|
|
tensorCast.getType().getShape() == tensorCast.getSource()
|
|
|
|
|
.getType()
|
|
|
|
|
.cast<RankedTensorType>()
|
|
|
|
|
.getShape())
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
|
|
|
|
|
auto dimMask = computeRankReductionMask(
|
|
|
|
|
extractFromI64ArrayAttr(extractOperand.static_sizes()),
|
|
|
|
|
extractFromI64ArrayAttr(extractOperand.getStaticSizes()),
|
|
|
|
|
extractOperand.getType().getShape());
|
|
|
|
|
size_t dimIndex = 0;
|
|
|
|
|
for (size_t i = 0, e = sizes.size(); i < e; i++) {
|
|
|
|
|
@@ -270,7 +272,7 @@ struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<ExtractSliceOp>(
|
|
|
|
|
tensorCast, tensorCast.getType().cast<RankedTensorType>(),
|
|
|
|
|
extractOperand.source(), extractOperand.getMixedOffsets(), sizes,
|
|
|
|
|
extractOperand.getSource(), extractOperand.getMixedOffsets(), sizes,
|
|
|
|
|
extractOperand.getMixedStrides());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
@@ -295,7 +297,7 @@ void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Optional<int64_t> DimOp::getConstantIndex() {
|
|
|
|
|
if (auto constantOp = index().getDefiningOp<arith::ConstantOp>())
|
|
|
|
|
if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
|
|
|
|
|
return constantOp.getValue().cast<IntegerAttr>().getInt();
|
|
|
|
|
return {};
|
|
|
|
|
}
|
|
|
|
|
@@ -307,7 +309,7 @@ LogicalResult DimOp::verify() {
|
|
|
|
|
return success();
|
|
|
|
|
|
|
|
|
|
// Check that constant index is not knowingly out of range.
|
|
|
|
|
auto type = source().getType();
|
|
|
|
|
auto type = getSource().getType();
|
|
|
|
|
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
|
|
|
|
if (*index >= tensorType.getRank())
|
|
|
|
|
return emitOpError("index is out of range");
|
|
|
|
|
@@ -326,7 +328,7 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
|
|
|
|
|
return {};
|
|
|
|
|
|
|
|
|
|
// Folding for unranked types (UnrankedTensorType) is not supported.
|
|
|
|
|
auto tensorType = source().getType().dyn_cast<RankedTensorType>();
|
|
|
|
|
auto tensorType = getSource().getType().dyn_cast<RankedTensorType>();
|
|
|
|
|
if (!tensorType)
|
|
|
|
|
return {};
|
|
|
|
|
|
|
|
|
|
@@ -336,7 +338,7 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
|
|
|
|
|
return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Operation *definingOp = source().getDefiningOp();
|
|
|
|
|
Operation *definingOp = getSource().getDefiningOp();
|
|
|
|
|
|
|
|
|
|
// Fold dim to the operand of tensor.generate.
|
|
|
|
|
if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
|
|
|
|
|
@@ -347,7 +349,7 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
|
|
|
|
|
assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
|
|
|
|
|
|
|
|
|
|
// Find the operand of the fromElements that corresponds to this index.
|
|
|
|
|
auto dynExtents = fromElements.dynamicExtents().begin();
|
|
|
|
|
auto dynExtents = fromElements.getDynamicExtents().begin();
|
|
|
|
|
for (auto dim : resultType.getShape().take_front(index.getInt()))
|
|
|
|
|
if (ShapedType::isDynamic(dim))
|
|
|
|
|
dynExtents++;
|
|
|
|
|
@@ -381,11 +383,11 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {
|
|
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(DimOp dimOp,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
auto castOp = dimOp.source().getDefiningOp<CastOp>();
|
|
|
|
|
auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
|
|
|
|
|
if (!castOp)
|
|
|
|
|
return failure();
|
|
|
|
|
Value newSource = castOp.getOperand();
|
|
|
|
|
rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
|
|
|
|
|
rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@@ -402,8 +404,8 @@ void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|
|
|
|
|
|
|
|
|
LogicalResult ExtractOp::verify() {
|
|
|
|
|
// Verify the # indices match if we have a ranked type.
|
|
|
|
|
if (auto tensorType = tensor().getType().dyn_cast<RankedTensorType>())
|
|
|
|
|
if (tensorType.getRank() != static_cast<int64_t>(indices().size()))
|
|
|
|
|
if (auto tensorType = getTensor().getType().dyn_cast<RankedTensorType>())
|
|
|
|
|
if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
|
|
|
|
|
return emitOpError("incorrect number of indices for extract_element");
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
@@ -425,7 +427,7 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Fold extract(from_elements(...)).
|
|
|
|
|
if (auto fromElementsOp = tensor().getDefiningOp<FromElementsOp>()) {
|
|
|
|
|
if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
|
|
|
|
|
auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
|
|
|
|
|
auto rank = tensorType.getRank();
|
|
|
|
|
assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
|
|
|
|
|
@@ -439,10 +441,10 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
|
|
|
|
|
}
|
|
|
|
|
// Prevent out of bounds accesses. This can happen in invalid code that will
|
|
|
|
|
// never execute.
|
|
|
|
|
if (static_cast<int>(fromElementsOp.elements().size()) <= flatIndex ||
|
|
|
|
|
if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
|
|
|
|
|
flatIndex < 0)
|
|
|
|
|
return {};
|
|
|
|
|
return fromElementsOp.elements()[flatIndex];
|
|
|
|
|
return fromElementsOp.getElements()[flatIndex];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// If this is an elements attribute, query the value at the given indices.
|
|
|
|
|
@@ -503,14 +505,14 @@ struct ExtractElementFromIndexCast
|
|
|
|
|
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
|
|
|
|
|
PatternRewriter &rewriter) const final {
|
|
|
|
|
Location loc = extract.getLoc();
|
|
|
|
|
auto indexCast = extract.tensor().getDefiningOp<arith::IndexCastOp>();
|
|
|
|
|
auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
|
|
|
|
|
if (!indexCast)
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
Type elementTy = getElementTypeOrSelf(indexCast.getIn());
|
|
|
|
|
|
|
|
|
|
auto newExtract = rewriter.create<tensor::ExtractOp>(
|
|
|
|
|
loc, elementTy, indexCast.getIn(), extract.indices());
|
|
|
|
|
loc, elementTy, indexCast.getIn(), extract.getIndices());
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
|
|
|
|
|
newExtract);
|
|
|
|
|
@@ -532,8 +534,8 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|
|
|
|
|
|
|
|
|
LogicalResult InsertOp::verify() {
|
|
|
|
|
// Verify the # indices match if we have a ranked type.
|
|
|
|
|
if (auto destType = dest().getType().dyn_cast<RankedTensorType>())
|
|
|
|
|
if (destType.getRank() != static_cast<int64_t>(indices().size()))
|
|
|
|
|
if (auto destType = getDest().getType().dyn_cast<RankedTensorType>())
|
|
|
|
|
if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
|
|
|
|
|
return emitOpError("incorrect number of indices");
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
@@ -581,16 +583,16 @@ LogicalResult GenerateOp::verify() {
|
|
|
|
|
LogicalResult GenerateOp::verifyRegions() {
|
|
|
|
|
RankedTensorType resultTy = getType().cast<RankedTensorType>();
|
|
|
|
|
// Ensure that region arguments span the index space.
|
|
|
|
|
if (!llvm::all_of(body().getArgumentTypes(),
|
|
|
|
|
if (!llvm::all_of(getBody().getArgumentTypes(),
|
|
|
|
|
[](Type ty) { return ty.isIndex(); }))
|
|
|
|
|
return emitError("all body arguments must be index");
|
|
|
|
|
if (body().getNumArguments() != resultTy.getRank())
|
|
|
|
|
if (getBody().getNumArguments() != resultTy.getRank())
|
|
|
|
|
return emitError("must have one body argument per input dimension");
|
|
|
|
|
|
|
|
|
|
// Ensure that the region yields an element of the right type.
|
|
|
|
|
auto yieldOp = cast<YieldOp>(body().getBlocks().front().getTerminator());
|
|
|
|
|
auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
|
|
|
|
|
|
|
|
|
|
if (yieldOp.value().getType() != resultTy.getElementType())
|
|
|
|
|
if (yieldOp.getValue().getType() != resultTy.getElementType())
|
|
|
|
|
return emitOpError(
|
|
|
|
|
"body must be terminated with a `yield` operation of the tensor "
|
|
|
|
|
"element type");
|
|
|
|
|
@@ -634,7 +636,7 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
|
|
|
|
|
|
|
|
|
|
SmallVector<Value, 4> newOperands;
|
|
|
|
|
SmallVector<int64_t, 4> newShape;
|
|
|
|
|
auto operandsIt = tensorFromElements.dynamicExtents().begin();
|
|
|
|
|
auto operandsIt = tensorFromElements.getDynamicExtents().begin();
|
|
|
|
|
|
|
|
|
|
for (int64_t dim : resultType.getShape()) {
|
|
|
|
|
if (!ShapedType::isDynamic(dim)) {
|
|
|
|
|
@@ -651,15 +653,15 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
|
|
|
|
|
operandsIt++;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (newOperands.size() == tensorFromElements.dynamicExtents().size())
|
|
|
|
|
if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
auto loc = tensorFromElements.getLoc();
|
|
|
|
|
auto newOp = rewriter.create<GenerateOp>(
|
|
|
|
|
loc, RankedTensorType::get(newShape, resultType.getElementType()),
|
|
|
|
|
newOperands);
|
|
|
|
|
rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
|
|
|
|
|
newOp.body().begin());
|
|
|
|
|
rewriter.inlineRegionBefore(tensorFromElements.getBody(), newOp.getBody(),
|
|
|
|
|
newOp.getBody().begin());
|
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
|
|
|
|
|
newOp);
|
|
|
|
|
return success();
|
|
|
|
|
@@ -682,19 +684,19 @@ struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
|
|
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
|
|
|
|
|
PatternRewriter &rewriter) const final {
|
|
|
|
|
auto tensorFromElements = extract.tensor().getDefiningOp<GenerateOp>();
|
|
|
|
|
auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
|
|
|
|
|
if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
BlockAndValueMapping mapping;
|
|
|
|
|
Block *body = &tensorFromElements.getBody().front();
|
|
|
|
|
mapping.map(body->getArguments(), extract.indices());
|
|
|
|
|
mapping.map(body->getArguments(), extract.getIndices());
|
|
|
|
|
for (auto &op : body->without_terminator())
|
|
|
|
|
rewriter.clone(op, mapping);
|
|
|
|
|
|
|
|
|
|
auto yield = cast<YieldOp>(body->getTerminator());
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
|
|
|
|
|
rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@@ -712,12 +714,12 @@ struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
|
|
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
|
|
|
|
|
PatternRewriter &rewriter) const final {
|
|
|
|
|
auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
|
|
|
|
|
auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
|
|
|
|
|
if (!tensorCast)
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
|
|
|
|
|
extract.indices());
|
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
|
|
|
|
|
extract, tensorCast.getSource(), extract.getIndices());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@@ -756,14 +758,15 @@ static int64_t getNumElements(ShapedType type) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
LogicalResult ReshapeOp::verify() {
|
|
|
|
|
TensorType operandType = source().getType().cast<TensorType>();
|
|
|
|
|
TensorType resultType = result().getType().cast<TensorType>();
|
|
|
|
|
TensorType operandType = getSource().getType().cast<TensorType>();
|
|
|
|
|
TensorType resultType = getResult().getType().cast<TensorType>();
|
|
|
|
|
|
|
|
|
|
if (operandType.getElementType() != resultType.getElementType())
|
|
|
|
|
return emitOpError("element types of source and destination tensor "
|
|
|
|
|
"types should be the same");
|
|
|
|
|
|
|
|
|
|
int64_t shapeSize = shape().getType().cast<RankedTensorType>().getDimSize(0);
|
|
|
|
|
int64_t shapeSize =
|
|
|
|
|
getShape().getType().cast<RankedTensorType>().getDimSize(0);
|
|
|
|
|
auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
|
|
|
|
|
auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
|
|
|
|
|
|
|
|
|
|
@@ -891,7 +894,7 @@ struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
|
|
|
|
|
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
DenseElementsAttr attr;
|
|
|
|
|
if (!matchPattern(reshapeOp.src(), m_Constant(&attr)))
|
|
|
|
|
if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
|
|
|
|
|
return failure();
|
|
|
|
|
if (!attr || !attr.isSplat())
|
|
|
|
|
return failure();
|
|
|
|
|
@@ -910,7 +913,7 @@ struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
|
|
|
|
|
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
auto fromElements =
|
|
|
|
|
reshapeOp.src().template getDefiningOp<FromElementsOp>();
|
|
|
|
|
reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
|
|
|
|
|
if (!fromElements)
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
@@ -920,7 +923,7 @@ struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
|
|
|
|
|
fromElements.elements());
|
|
|
|
|
fromElements.getElements());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@@ -1208,7 +1211,7 @@ public:
|
|
|
|
|
}))
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
auto castOp = sliceOp.source().getDefiningOp<tensor::CastOp>();
|
|
|
|
|
auto castOp = sliceOp.getSource().getDefiningOp<tensor::CastOp>();
|
|
|
|
|
if (!castOp)
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
@@ -1221,9 +1224,9 @@ public:
|
|
|
|
|
sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
|
|
|
|
|
sliceOp.getMixedStrides());
|
|
|
|
|
Value newSlice = rewriter.create<ExtractSliceOp>(
|
|
|
|
|
sliceOp.getLoc(), resultType, castOp.source(), sliceOp.offsets(),
|
|
|
|
|
sliceOp.sizes(), sliceOp.strides(), sliceOp.static_offsets(),
|
|
|
|
|
sliceOp.static_sizes(), sliceOp.static_strides());
|
|
|
|
|
sliceOp.getLoc(), resultType, castOp.getSource(), sliceOp.getOffsets(),
|
|
|
|
|
sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
|
|
|
|
|
sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
|
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
|
|
|
|
|
newSlice);
|
|
|
|
|
return success();
|
|
|
|
|
@@ -1277,7 +1280,7 @@ public:
|
|
|
|
|
LogicalResult matchAndRewrite(ExtractSliceOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
DenseElementsAttr attr;
|
|
|
|
|
if (!matchPattern(op.source(), m_Constant(&attr)))
|
|
|
|
|
if (!matchPattern(op.getSource(), m_Constant(&attr)))
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
// A constant splat is handled by fold().
|
|
|
|
|
@@ -1285,8 +1288,8 @@ public:
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
// Dynamic result shape is not supported.
|
|
|
|
|
auto sourceType = op.source().getType().cast<ShapedType>();
|
|
|
|
|
auto resultType = op.result().getType().cast<ShapedType>();
|
|
|
|
|
auto sourceType = op.getSource().getType().cast<ShapedType>();
|
|
|
|
|
auto resultType = op.getResult().getType().cast<ShapedType>();
|
|
|
|
|
if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
@@ -1299,13 +1302,13 @@ public:
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
// Check if there are any dynamic parts, which are not supported.
|
|
|
|
|
auto offsets = extractFromI64ArrayAttr(op.static_offsets());
|
|
|
|
|
auto offsets = extractFromI64ArrayAttr(op.getStaticOffsets());
|
|
|
|
|
if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
|
|
|
|
|
return failure();
|
|
|
|
|
auto sizes = extractFromI64ArrayAttr(op.static_sizes());
|
|
|
|
|
auto sizes = extractFromI64ArrayAttr(op.getStaticSizes());
|
|
|
|
|
if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
|
|
|
|
|
return failure();
|
|
|
|
|
auto strides = extractFromI64ArrayAttr(op.static_strides());
|
|
|
|
|
auto strides = extractFromI64ArrayAttr(op.getStaticStrides());
|
|
|
|
|
if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
@@ -1414,25 +1417,25 @@ foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
|
|
|
|
|
// TODO: This only checks the immediate producer; extend to go up the
|
|
|
|
|
// insert/extract chain if the slices are disjoint.
|
|
|
|
|
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
|
|
|
|
|
auto insertOp = extractOp.source().getDefiningOp<InsertSliceOp>();
|
|
|
|
|
auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
|
|
|
|
|
|
|
|
|
|
auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
|
|
|
|
|
if (insertOp && insertOp.source().getType() == extractOp.getType() &&
|
|
|
|
|
if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
|
|
|
|
|
insertOp.isSameAs(extractOp, isSame))
|
|
|
|
|
return insertOp.source();
|
|
|
|
|
return insertOp.getSource();
|
|
|
|
|
|
|
|
|
|
return {};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
|
|
|
|
|
if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
|
|
|
|
|
auto resultType = result().getType().cast<ShapedType>();
|
|
|
|
|
auto resultType = getResult().getType().cast<ShapedType>();
|
|
|
|
|
if (resultType.hasStaticShape())
|
|
|
|
|
return splat.resizeSplat(resultType);
|
|
|
|
|
}
|
|
|
|
|
if (getSourceType() == getType() &&
|
|
|
|
|
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
|
|
|
|
|
return this->source();
|
|
|
|
|
return this->getSource();
|
|
|
|
|
if (Value slice = foldExtractAfterInsertSlice(*this))
|
|
|
|
|
return slice;
|
|
|
|
|
|
|
|
|
|
@@ -1518,8 +1521,8 @@ verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
|
|
|
|
|
LogicalResult InsertSliceOp::verify() {
|
|
|
|
|
ShapedType expectedType;
|
|
|
|
|
auto result =
|
|
|
|
|
verifyInsertSliceOp(getSourceType(), getType(), static_offsets(),
|
|
|
|
|
static_sizes(), static_strides(), &expectedType);
|
|
|
|
|
verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
|
|
|
|
|
getStaticSizes(), getStaticStrides(), &expectedType);
|
|
|
|
|
return produceSliceErrorMsg(result, *this, expectedType);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1539,15 +1542,15 @@ LogicalResult InsertSliceOp::verify() {
|
|
|
|
|
/// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
|
|
|
|
|
/// ```
|
|
|
|
|
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
|
|
|
|
|
auto prevInsertOp = insertOp.dest().getDefiningOp<InsertSliceOp>();
|
|
|
|
|
auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
|
|
|
|
|
|
|
|
|
|
auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
|
|
|
|
|
if (!prevInsertOp ||
|
|
|
|
|
prevInsertOp.source().getType() != insertOp.source().getType() ||
|
|
|
|
|
prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
|
|
|
|
|
!prevInsertOp.isSameAs(insertOp, isSame))
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
insertOp.destMutable().assign(prevInsertOp.dest());
|
|
|
|
|
insertOp.getDestMutable().assign(prevInsertOp.getDest());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1555,7 +1558,7 @@ OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
|
|
|
|
|
if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
|
|
|
|
|
getSourceType() == getType() &&
|
|
|
|
|
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
|
|
|
|
|
return this->source();
|
|
|
|
|
return this->getSource();
|
|
|
|
|
if (succeeded(foldInsertAfterInsertSlice(*this)))
|
|
|
|
|
return getResult();
|
|
|
|
|
return OpFoldResult();
|
|
|
|
|
@@ -1566,7 +1569,7 @@ LogicalResult InsertSliceOp::reifyResultShapes(
|
|
|
|
|
reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
|
|
|
|
|
for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
|
|
|
|
|
reifiedReturnShapes[0][dim] =
|
|
|
|
|
builder.createOrFold<tensor::DimOp>(getLoc(), dest(), dim);
|
|
|
|
|
builder.createOrFold<tensor::DimOp>(getLoc(), getDest(), dim);
|
|
|
|
|
}
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
@@ -1600,13 +1603,13 @@ public:
|
|
|
|
|
auto sourceType = ExtractSliceOp::inferRankReducedResultType(
|
|
|
|
|
insertSliceOp.getSourceType().getRank(), insertSliceOp.getType(),
|
|
|
|
|
mixedOffsets, mixedSizes, mixedStrides);
|
|
|
|
|
Value toInsert = insertSliceOp.source();
|
|
|
|
|
Value toInsert = insertSliceOp.getSource();
|
|
|
|
|
if (sourceType != insertSliceOp.getSourceType())
|
|
|
|
|
toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
|
|
|
|
|
sourceType, toInsert);
|
|
|
|
|
rewriter.replaceOpWithNewOp<InsertSliceOp>(
|
|
|
|
|
insertSliceOp, toInsert, insertSliceOp.dest(), mixedOffsets, mixedSizes,
|
|
|
|
|
mixedStrides);
|
|
|
|
|
insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
|
|
|
|
|
mixedSizes, mixedStrides);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@@ -1643,22 +1646,23 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
|
|
|
|
|
auto castOp = v.getDefiningOp<tensor::CastOp>();
|
|
|
|
|
if (!castOp || !canFoldIntoConsumerOp(castOp))
|
|
|
|
|
return llvm::None;
|
|
|
|
|
return castOp.source();
|
|
|
|
|
return castOp.getSource();
|
|
|
|
|
};
|
|
|
|
|
Optional<Value> sourceCastSource =
|
|
|
|
|
getSourceOfCastOp(insertSliceOp.source());
|
|
|
|
|
Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.dest());
|
|
|
|
|
getSourceOfCastOp(insertSliceOp.getSource());
|
|
|
|
|
Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.getDest());
|
|
|
|
|
if (!sourceCastSource && !destCastSource)
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
auto src = (sourceCastSource ? *sourceCastSource : insertSliceOp.source());
|
|
|
|
|
auto dst = (destCastSource ? *destCastSource : insertSliceOp.dest());
|
|
|
|
|
auto src =
|
|
|
|
|
(sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
|
|
|
|
|
auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
|
|
|
|
|
|
|
|
|
|
auto srcType = src.getType().cast<ShapedType>();
|
|
|
|
|
auto dstType = dst.getType().cast<ShapedType>();
|
|
|
|
|
if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.static_offsets(),
|
|
|
|
|
insertSliceOp.static_sizes(),
|
|
|
|
|
insertSliceOp.static_strides()) !=
|
|
|
|
|
if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
|
|
|
|
|
insertSliceOp.getStaticSizes(),
|
|
|
|
|
insertSliceOp.getStaticStrides()) !=
|
|
|
|
|
SliceVerificationResult::Success)
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
@@ -1724,9 +1728,9 @@ struct InsertSliceOpSourceCastInserter final
|
|
|
|
|
// 3) Cast-compatible with srcType.
|
|
|
|
|
// Insert the cast.
|
|
|
|
|
Value cast = rewriter.create<tensor::CastOp>(
|
|
|
|
|
insertSliceOp.getLoc(), newSrcType, insertSliceOp.source());
|
|
|
|
|
insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
|
|
|
|
|
rewriter.replaceOpWithNewOp<InsertSliceOp>(
|
|
|
|
|
insertSliceOp, cast, insertSliceOp.dest(),
|
|
|
|
|
insertSliceOp, cast, insertSliceOp.getDest(),
|
|
|
|
|
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
|
|
|
|
|
insertSliceOp.getMixedStrides());
|
|
|
|
|
return success();
|
|
|
|
|
@@ -1781,11 +1785,11 @@ ParseResult parseInferType(OpAsmParser &parser,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
LogicalResult PadOp::verify() {
|
|
|
|
|
auto sourceType = source().getType().cast<RankedTensorType>();
|
|
|
|
|
auto resultType = result().getType().cast<RankedTensorType>();
|
|
|
|
|
auto expectedType =
|
|
|
|
|
PadOp::inferResultType(sourceType, extractFromI64ArrayAttr(static_low()),
|
|
|
|
|
extractFromI64ArrayAttr(static_high()));
|
|
|
|
|
auto sourceType = getSource().getType().cast<RankedTensorType>();
|
|
|
|
|
auto resultType = getResult().getType().cast<RankedTensorType>();
|
|
|
|
|
auto expectedType = PadOp::inferResultType(
|
|
|
|
|
sourceType, extractFromI64ArrayAttr(getStaticLow()),
|
|
|
|
|
extractFromI64ArrayAttr(getStaticHigh()));
|
|
|
|
|
for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
|
|
|
|
|
if (resultType.getDimSize(i) == expectedType.getDimSize(i))
|
|
|
|
|
continue;
|
|
|
|
|
@@ -1801,7 +1805,7 @@ LogicalResult PadOp::verify() {
|
|
|
|
|
|
|
|
|
|
LogicalResult PadOp::verifyRegions() {
|
|
|
|
|
auto ®ion = getRegion();
|
|
|
|
|
unsigned rank = result().getType().cast<RankedTensorType>().getRank();
|
|
|
|
|
unsigned rank = getResult().getType().cast<RankedTensorType>().getRank();
|
|
|
|
|
Block &block = region.front();
|
|
|
|
|
if (block.getNumArguments() != rank)
|
|
|
|
|
return emitError("expected the block to have ") << rank << " arguments";
|
|
|
|
|
@@ -1815,7 +1819,7 @@ LogicalResult PadOp::verifyRegions() {
|
|
|
|
|
|
|
|
|
|
// Ensure that the region yields an element of the right type.
|
|
|
|
|
auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
|
|
|
|
|
if (yieldOp.value().getType() !=
|
|
|
|
|
if (yieldOp.getValue().getType() !=
|
|
|
|
|
getType().cast<ShapedType>().getElementType())
|
|
|
|
|
return emitOpError("expected yield type to match shape element type");
|
|
|
|
|
|
|
|
|
|
@@ -1919,10 +1923,11 @@ struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
|
|
|
|
|
return failure();
|
|
|
|
|
if (padTensorOp.nofold())
|
|
|
|
|
if (padTensorOp.getNofold())
|
|
|
|
|
return failure();
|
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
|
|
|
|
padTensorOp, padTensorOp.result().getType(), padTensorOp.source());
|
|
|
|
|
padTensorOp, padTensorOp.getResult().getType(),
|
|
|
|
|
padTensorOp.getSource());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@@ -1933,25 +1938,26 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
|
|
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(PadOp padTensorOp,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
auto castOp = padTensorOp.source().getDefiningOp<tensor::CastOp>();
|
|
|
|
|
auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
|
|
|
|
|
if (!tensor::canFoldIntoConsumerOp(castOp))
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
auto newResultType = PadOp::inferResultType(
|
|
|
|
|
castOp.source().getType().cast<RankedTensorType>(),
|
|
|
|
|
extractFromI64ArrayAttr(padTensorOp.static_low()),
|
|
|
|
|
extractFromI64ArrayAttr(padTensorOp.static_high()),
|
|
|
|
|
castOp.getSource().getType().cast<RankedTensorType>(),
|
|
|
|
|
extractFromI64ArrayAttr(padTensorOp.getStaticLow()),
|
|
|
|
|
extractFromI64ArrayAttr(padTensorOp.getStaticHigh()),
|
|
|
|
|
padTensorOp.getResultType().getShape());
|
|
|
|
|
|
|
|
|
|
if (newResultType == padTensorOp.getResultType()) {
|
|
|
|
|
rewriter.updateRootInPlace(padTensorOp, [&]() {
|
|
|
|
|
padTensorOp.sourceMutable().assign(castOp.source());
|
|
|
|
|
padTensorOp.getSourceMutable().assign(castOp.getSource());
|
|
|
|
|
});
|
|
|
|
|
} else {
|
|
|
|
|
auto newOp = rewriter.create<PadOp>(
|
|
|
|
|
padTensorOp->getLoc(), newResultType, padTensorOp.source(),
|
|
|
|
|
padTensorOp.low(), padTensorOp.high(), padTensorOp.static_low(),
|
|
|
|
|
padTensorOp.static_high(), padTensorOp.nofold());
|
|
|
|
|
padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
|
|
|
|
|
padTensorOp.getLow(), padTensorOp.getHigh(),
|
|
|
|
|
padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
|
|
|
|
|
padTensorOp.getNofold());
|
|
|
|
|
BlockAndValueMapping mapper;
|
|
|
|
|
padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
|
|
|
|
|
|
|
|
|
|
@@ -1969,25 +1975,25 @@ struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
|
|
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(PadOp padTensorOp,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
if (!padTensorOp.result().hasOneUse())
|
|
|
|
|
if (!padTensorOp.getResult().hasOneUse())
|
|
|
|
|
return failure();
|
|
|
|
|
auto tensorCastOp =
|
|
|
|
|
dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
|
|
|
|
|
if (!tensorCastOp)
|
|
|
|
|
return failure();
|
|
|
|
|
if (!tensor::preservesStaticInformation(padTensorOp.result().getType(),
|
|
|
|
|
tensorCastOp.dest().getType()))
|
|
|
|
|
if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
|
|
|
|
|
tensorCastOp.getDest().getType()))
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
auto replacementOp = rewriter.create<PadOp>(
|
|
|
|
|
padTensorOp.getLoc(), tensorCastOp.dest().getType(),
|
|
|
|
|
padTensorOp.source(), padTensorOp.low(), padTensorOp.high(),
|
|
|
|
|
padTensorOp.static_low(), padTensorOp.static_high(),
|
|
|
|
|
padTensorOp.nofold());
|
|
|
|
|
replacementOp.region().takeBody(padTensorOp.region());
|
|
|
|
|
padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
|
|
|
|
|
padTensorOp.getSource(), padTensorOp.getLow(), padTensorOp.getHigh(),
|
|
|
|
|
padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
|
|
|
|
|
padTensorOp.getNofold());
|
|
|
|
|
replacementOp.getRegion().takeBody(padTensorOp.getRegion());
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(padTensorOp, replacementOp.result());
|
|
|
|
|
rewriter.replaceOp(tensorCastOp, replacementOp.result());
|
|
|
|
|
rewriter.replaceOp(padTensorOp, replacementOp.getResult());
|
|
|
|
|
rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@@ -2031,13 +2037,13 @@ struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
|
|
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(PadOp padOp,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
auto innerSliceOp = padOp.source().getDefiningOp<ExtractSliceOp>();
|
|
|
|
|
auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
|
|
|
|
|
if (!innerSliceOp)
|
|
|
|
|
return failure();
|
|
|
|
|
auto outerPadOp = innerSliceOp.source().getDefiningOp<PadOp>();
|
|
|
|
|
if (!outerPadOp || outerPadOp.nofold())
|
|
|
|
|
auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
|
|
|
|
|
if (!outerPadOp || outerPadOp.getNofold())
|
|
|
|
|
return failure();
|
|
|
|
|
auto outerSliceOp = outerPadOp.source().getDefiningOp<ExtractSliceOp>();
|
|
|
|
|
auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
|
|
|
|
|
if (!outerSliceOp)
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
@@ -2136,11 +2142,11 @@ struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
|
|
|
|
|
// Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs the
|
|
|
|
|
// two paddings in one step.
|
|
|
|
|
auto newSliceOp = rewriter.create<ExtractSliceOp>(
|
|
|
|
|
padOp.getLoc(), outerSliceOp.source(), newOffsets, newSizes,
|
|
|
|
|
padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
|
|
|
|
|
innerSliceOp.getMixedStrides());
|
|
|
|
|
auto newPadOp = rewriter.create<PadOp>(
|
|
|
|
|
padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
|
|
|
|
|
padOp.getMixedLowPad(), newHighPad, padOp.nofold());
|
|
|
|
|
padOp.getMixedLowPad(), newHighPad, padOp.getNofold());
|
|
|
|
|
rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
|
|
|
|
|
newPadOp.getRegion().begin());
|
|
|
|
|
rewriter.replaceOp(padOp, newPadOp.getResult());
|
|
|
|
|
@@ -2169,7 +2175,7 @@ Value PadOp::getConstantPaddingValue() {
|
|
|
|
|
auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
|
|
|
|
|
if (!yieldOp)
|
|
|
|
|
return {};
|
|
|
|
|
Value padValue = yieldOp.value();
|
|
|
|
|
Value padValue = yieldOp.getValue();
|
|
|
|
|
// Check if yield value is a constant.
|
|
|
|
|
if (matchPattern(padValue, m_Constant()))
|
|
|
|
|
return padValue;
|
|
|
|
|
@@ -2182,8 +2188,8 @@ Value PadOp::getConstantPaddingValue() {
|
|
|
|
|
|
|
|
|
|
OpFoldResult PadOp::fold(ArrayRef<Attribute>) {
|
|
|
|
|
if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
|
|
|
|
|
!nofold())
|
|
|
|
|
return source();
|
|
|
|
|
!getNofold())
|
|
|
|
|
return getSource();
|
|
|
|
|
return {};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|