diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 7ae27407a952..f50e3464867b 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -121,6 +121,70 @@ def Tensor_CastOp : Tensor_Op<"cast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_ConcatOp : Tensor_Op<"concat",
+    [Pure,
+     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+  let summary = "tensor concatenation operation";
+  let description = [{
+    The "concat" operation constructs a tensor out of a variadic list of input
+    tensors, concatenated along a static dimension number. All inputs and the
+    result type must share the same rank.
+
+    `dim` specifies the dimension along which to concatenate. The size of the
+    concatenated dimension in the result must be equal to the sum of the sizes
+    of the inputs along that dimension. All other dimensions in both the inputs
+    and the result must be the same size.
+
+    Example:
+
+    ```mlir
+    %0 = tensor.concat dim(0) %1, %2, %3 :
+        (tensor<3x6xf32>, tensor<3x6xf32>, tensor<1x6xf32>) -> tensor<7x6xf32>
+
+    // Dynamic + dynamic -> static
+    %4 = tensor.concat dim(1) %5, %6, %7 :
+        (tensor<3x?xf32>, tensor<3x2xf32>, tensor<3x?xf32>) -> tensor<3x10xf32>
+    ```
+  }];
+  let arguments = (ins I64Attr:$dim,
+                       Variadic<AnyRankedTensor>:$inputs);
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    `dim` `(` $dim `)` $inputs attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let builders = [
+    // Builder with an inferred result type.
+    OpBuilder<(ins "int64_t":$dim, "ValueRange":$inputs)>,
+  ];
+
+  let extraClassDeclaration = [{
+    // Helper to infer the concatenated result type for the given list of input
+    // types, being concatenated along `dim`. Because concatenation can specify
+    // more static information than can automatically be inferred,
+    // InferTypeOpInterface is not used.
+    static RankedTensorType inferResultType(int64_t dim, TypeRange inputTypes);
+
+    RankedTensorType getResultType() {
+      return ::llvm::cast<RankedTensorType>(getResult().getType());
+    }
+
+    int64_t getRank() {
+      return ::llvm::cast<RankedTensorType>(getResult().getType()).getRank();
+    }
+  }];
+
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // DimOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index 66c6021418b4..8556d9570fd1 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -15,6 +15,18 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def ApplyDecomposeTensorConcatPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.tensor.decompose_concat",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that tensor.concat ops should be decomposed into a chain of
+    tensor.insert_slice operations inserting into a materialized destination.
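+
+    For example, a two-way concatenation along dimension 1 decomposes roughly
+    as follows (the shapes here are illustrative only, mirroring the example
+    documented with the decomposition pattern):
+
+    ```mlir
+    %concat = tensor.concat dim(1) %0, %1 :
+        (tensor<2x3xf32>, tensor<2x4xf32>) -> tensor<2x7xf32>
+
+    // ... becomes ...
+
+    %empty = tensor.empty() : tensor<2x7xf32>
+    %slice0 = tensor.insert_slice %0 into %empty[0, 0] [2, 3] [1, 1]
+        : tensor<2x3xf32> into tensor<2x7xf32>
+    %concat = tensor.insert_slice %1 into %slice0[0, 3] [2, 4] [1, 1]
+        : tensor<2x4xf32> into tensor<2x7xf32>
+    ```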
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+
 def ApplyDropRedundantInsertSliceRankExpansionPatternsOp : Op<Transform_Dialect,
     "apply_patterns.tensor.drop_redundant_insert_slice_rank_expansion",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 705b30e7ded4..44b8377bd6aa 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -67,6 +67,13 @@ void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
 void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
                                      bool foldSingleUseOnly = false);
 
+/// Populates `patterns` with patterns that decompose `tensor.concat` into
+/// `tensor.empty` of a tensor of the concatenated size, followed by a chain
+/// of `tensor.insert_slice` operations on the inputs. This is intended to be
+/// used as a fallback tensor -> tensor lowering that decomposes concat such
+/// that it can be bufferized into a sequence of copies.
+void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
+
 /// Populates `patterns` with patterns that fold operations like `tensor.pad`
 /// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
 /// respectively.
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index c2fbaea726ab..502ab93ddbfa 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -151,6 +151,39 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
 std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
                                          OpFoldResult step);
 
+/// Idiomatic saturated operations on values like offsets, sizes, and strides.
+struct SaturatedInteger {
+  static SaturatedInteger wrap(int64_t v) {
+    return (ShapedType::isDynamic(v)) ? SaturatedInteger{true, 0}
+                                      : SaturatedInteger{false, v};
+  }
+  int64_t asInteger() { return saturated ? ShapedType::kDynamic : v; }
+  FailureOr<SaturatedInteger> desaturate(SaturatedInteger other) {
+    if (saturated && !other.saturated)
+      return other;
+    if (!saturated && !other.saturated && v != other.v)
+      return failure();
+    return *this;
+  }
+  bool operator==(SaturatedInteger other) {
+    return (saturated && other.saturated) ||
+           (!saturated && !other.saturated && v == other.v);
+  }
+  bool operator!=(SaturatedInteger other) { return !(*this == other); }
+  SaturatedInteger operator+(SaturatedInteger other) {
+    if (saturated || other.saturated)
+      return SaturatedInteger{true, 0};
+    return SaturatedInteger{false, other.v + v};
+  }
+  SaturatedInteger operator*(SaturatedInteger other) {
+    if (saturated || other.saturated)
+      return SaturatedInteger{true, 0};
+    return SaturatedInteger{false, other.v * v};
+  }
+  bool saturated = true;
+  int64_t v = 0;
+};
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a2fc954ad07f..dce96cca016f 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -26,43 +26,6 @@
 using namespace mlir;
 using namespace mlir::memref;
 
-namespace {
-/// Idiomatic saturated operations on offsets, sizes and strides.
-namespace saturated_arith {
-struct Wrapper {
-  static Wrapper stride(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
-  }
-  static Wrapper offset(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
-  }
-  static Wrapper size(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
-  }
-  int64_t asOffset() { return saturated ? ShapedType::kDynamic : v; }
-  int64_t asSize() { return saturated ? ShapedType::kDynamic : v; }
-  int64_t asStride() { return saturated ? ShapedType::kDynamic : v; }
-  bool operator==(Wrapper other) {
-    return (saturated && other.saturated) ||
-           (!saturated && !other.saturated && v == other.v);
-  }
-  bool operator!=(Wrapper other) { return !(*this == other); }
-  Wrapper operator+(Wrapper other) {
-    if (saturated || other.saturated)
-      return Wrapper{true, 0};
-    return Wrapper{false, other.v + v};
-  }
-  Wrapper operator*(Wrapper other) {
-    if (saturated || other.saturated)
-      return Wrapper{true, 0};
-    return Wrapper{false, other.v * v};
-  }
-  bool saturated;
-  int64_t v;
-};
-} // namespace saturated_arith
-} // namespace
-
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
@@ -2208,11 +2171,11 @@ computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
     ReassociationIndices reassoc = std::get<0>(it);
     int64_t currentStrideToExpand = std::get<1>(it);
     for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
-      using saturated_arith::Wrapper;
       reverseResultStrides.push_back(currentStrideToExpand);
-      currentStrideToExpand = (Wrapper::stride(currentStrideToExpand) *
-                               Wrapper::size(resultShape[shapeIndex--]))
-                                  .asStride();
+      currentStrideToExpand =
+          (SaturatedInteger::wrap(currentStrideToExpand) *
+           SaturatedInteger::wrap(resultShape[shapeIndex--]))
+              .asInteger();
     }
   }
   auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
@@ -2332,10 +2295,9 @@ computeCollapsedLayoutMap(MemRefType srcType,
   unsigned resultStrideIndex = resultStrides.size() - 1;
   for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
     auto trailingReassocs = ArrayRef(reassoc).drop_front();
-    using saturated_arith::Wrapper;
-    auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]);
+    auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
     for (int64_t idx : llvm::reverse(trailingReassocs)) {
-      stride = stride * Wrapper::size(srcShape[idx]);
+      stride = stride * SaturatedInteger::wrap(srcShape[idx]);
 
       // Both source and result stride must have the same static value. In that
       // case, we can be sure, that the dimensions are collapsible (because they
       // are contiguous).
       // If `strict = false` (default during op verification), we accept cases
       // where one of the strides is dynamic. This is best effort: we reject
       // ops where obviously non-contiguous dims are collapsed, but accept ops
       // where we cannot be sure statically. Such ops may fail at runtime. See
       // the op documentation for details.
-      auto srcStride = Wrapper::stride(srcStrides[idx - 1]);
+      auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
       if (strict && (stride.saturated || srcStride.saturated))
         return failure();
 
@@ -2371,11 +2333,11 @@ MemRefType CollapseShapeOp::computeCollapsedType(
   SmallVector<int64_t> resultShape;
   resultShape.reserve(reassociation.size());
   for (const ReassociationIndices &group : reassociation) {
-    using saturated_arith::Wrapper;
-    auto groupSize = Wrapper::size(1);
+    auto groupSize = SaturatedInteger::wrap(1);
     for (int64_t srcDim : group)
-      groupSize = groupSize * Wrapper::size(srcType.getDimSize(srcDim));
-    resultShape.push_back(groupSize.asSize());
+      groupSize =
+          groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
+    resultShape.push_back(groupSize.asInteger());
   }
 
   if (srcType.getLayout().isIdentity()) {
@@ -2586,11 +2548,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   int64_t targetOffset = sourceOffset;
   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
     auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
-    using saturated_arith::Wrapper;
-    targetOffset =
-        (Wrapper::offset(targetOffset) +
-         Wrapper::offset(staticOffset) * Wrapper::stride(targetStride))
-            .asOffset();
+    targetOffset = (SaturatedInteger::wrap(targetOffset) +
+                    SaturatedInteger::wrap(staticOffset) *
+                        SaturatedInteger::wrap(targetStride))
+                       .asInteger();
   }
 
   // Compute target stride whose value is:
@@ -2599,10 +2560,9 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   targetStrides.reserve(staticOffsets.size());
   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
-    using saturated_arith::Wrapper;
-    targetStrides.push_back(
-        (Wrapper::stride(sourceStride) * Wrapper::stride(staticStride))
-            .asStride());
+    targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
+                             SaturatedInteger::wrap(staticStride))
+                                .asInteger());
   }
 
   // The type is now known.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index cd9b82d2c553..02146e8257b3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -472,6 +472,192 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
+  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
+  auto tensorTypes =
+      llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
+        return llvm::cast<RankedTensorType>(type);
+      }));
+  int64_t concatRank = tensorTypes[0].getRank();
+
+  // The concatenation dim must be in the range [0, rank).
+  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
+
+  SmallVector<int64_t> sizes(concatRank);
+  for (int64_t i = 0, e = concatRank; i < e; ++i) {
+    if (i == dim)
+      continue;
+    SaturatedInteger size;
+    for (auto tensorType : tensorTypes)
+      size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
+    sizes[i] = size.asInteger();
+  }
+  auto concatSize = SaturatedInteger::wrap(0);
+  for (auto tensorType : tensorTypes)
+    concatSize =
+        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
+  sizes[dim] = concatSize.asInteger();
+  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
+}
+
+void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
+                     ValueRange inputs) {
+  FailureOr<RankedTensorType> resultType =
+      inferResultType(dim, inputs.getTypes());
+  assert(succeeded(resultType) && "failed to infer concatenation result type");
+  build(builder, result, *resultType, dim, inputs);
+}
+
+LogicalResult ConcatOp::verify() {
+  if (getInputs().size() < 1)
+    return emitOpError("requires at least one input");
+
+  SmallVector<RankedTensorType> inputTypes;
+  for (auto input : getInputs())
+    inputTypes.push_back(cast<RankedTensorType>(input.getType()));
+
+  RankedTensorType resultType = getResultType();
+  int64_t resultRank = getRank();
+  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
+        return type.getRank() != resultRank;
+      }))
+    return emitOpError("rank of concatenated inputs must match result rank");
+
+  Type resultElementType = resultType.getElementType();
+  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
+        return type.getElementType() != resultElementType;
+      }))
+    return emitOpError("inputs and result element type must match");
+
+  int64_t dim = getDim();
+  if (dim >= resultRank)
+    return emitOpError("concatenation dim must be less than the tensor rank");
+
+  SmallVector<int64_t> sizes(resultRank);
+  for (int64_t i = 0, e = resultRank; i < e; ++i) {
+    if (i == dim)
+      continue;
+    SaturatedInteger size;
+    for (auto tensorType : inputTypes) {
+      FailureOr<SaturatedInteger> maybeSize =
+          size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
+      if (failed(maybeSize))
+        return emitOpError("static concatenation size mismatch along ")
+               << "non-concatenated dimension " << i;
+      size = *maybeSize;
+    }
+    sizes[i] = size.asInteger();
+  }
+  auto concatSize = SaturatedInteger::wrap(0);
+  for (auto tensorType : inputTypes)
+    concatSize =
+        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
+  sizes[dim] = concatSize.asInteger();
+  auto inferredResultType =
+      RankedTensorType::get(sizes, inputTypes[0].getElementType());
+
+  for (auto [inferredSize, actualSize] :
+       llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
+    bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
+                      ShapedType::isDynamic(actualSize);
+    if (!hasDynamic && inferredSize != actualSize)
+      return emitOpError("result type ")
+             << resultType << "does not match inferred shape "
+             << inferredResultType << " static sizes";
+  }
+
+  return success();
+}
+
+LogicalResult
+ConcatOp::reifyResultShapes(OpBuilder &builder,
+                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  ValueRange inputs = getInputs();
+  int64_t dim = getDim();
+  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
+
+  Value init = inputs[0];
+  int64_t rank = getType().getRank();
+
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
+
+  // Pre-populate the result sizes with as much static information as possible
+  // from the given result type, as well as the inferred result type, otherwise
+  // use the dim sizes from the first input.
+  for (int64_t i = 0; i < rank; ++i) {
+    if (i == dim)
+      continue;
+    if (!getType().isDynamicDim(i)) {
+      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
+    } else if (!inferredResultType.isDynamicDim(i)) {
+      reifiedReturnShapes[0][i] =
+          builder.getIndexAttr(inferredResultType.getDimSize(i));
+    } else {
+      reifiedReturnShapes[0][i] =
+          builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
+    }
+  }
+
+  // Take the sum of the input sizes along the concatenated dim.
+  AffineExpr sum = builder.getAffineDimExpr(0);
+  SmallVector<OpFoldResult> sizes = {
+      builder.create<tensor::DimOp>(init.getLoc(), init, 0).getResult()};
+  for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
+    sum = sum + builder.getAffineDimExpr(idx + 1);
+    sizes.push_back(
+        builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
+  }
+  reifiedReturnShapes[0][dim] =
+      affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes);
+
+  // ReifyRankedShapedTypeOpInterface requires that reifyResultShapes
+  // returns a Value for dynamic dimensions.
+  for (int64_t i = 0; i < rank; ++i) {
+    if (getType().isDynamicDim(i)) {
+      reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
+          builder, getLoc(), reifiedReturnShapes[0][i]);
+    }
+  }
+  return success();
+}
+
+void ConcatOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "concat");
+}
+
+OpFoldResult ConcatOp::fold(FoldAdaptor) {
+  ValueRange inputs = getInputs();
+  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
+    return inputs[0];
+  return {};
+}
+
+namespace {
+/// Fold a concat op with a single input to a cast.
+struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    if (concatOp.getInputs().size() != 1)
+      return failure();
+    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
+                                        concatOp.getInputs()[0]);
+    return success();
+  }
+};
+} // namespace
+
+void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<SingleInputConcatOp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // DimOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index 3cec91389392..ed2742387047 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -83,6 +83,11 @@ void tensor::registerFindPayloadReplacementOpInterfaceExternalModels(
 // Apply...PatternsOp
 //===----------------------------------------------------------------------===//
 
+void transform::ApplyDecomposeTensorConcatPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  tensor::populateDecomposeTensorConcatPatterns(patterns);
+}
+
 void transform::ApplyDropRedundantInsertSliceRankExpansionPatternsOp::
     populatePatterns(RewritePatternSet &patterns) {
   tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index c5fd4e65bbf7..d233ab7a0e89 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRTensorTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  ConcatOpPatterns.cpp
   EmptyOpPatterns.cpp
   ExtractSliceFromReshapeUtils.cpp
   FoldIntoPackAndUnpackPatterns.cpp
@@ -23,6 +24,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
   MLIRAffineTransforms
   MLIRAffineUtils
   MLIRArithDialect
+  MLIRArithUtils
   MLIRBufferizationDialect
   MLIRBufferizationTransforms
   MLIRIR
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
new file mode 100644
index 000000000000..2108fc591055
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
@@ -0,0 +1,93 @@
+//===- ConcatOpPatterns.cpp - Patterns related to tensor.concat lowering --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
+/// Decompose `tensor.concat` into `tensor.empty` and a chain of slice inserts.
+///
+/// %concat = tensor.concat dim(1) %0, %1 :
+///         (tensor<2x3xf32>, tensor<2x4xf32>) -> tensor<2x7xf32>
+///
+/// Becomes
+///
+/// %empty = tensor.empty() : tensor<2x7xf32>
+/// %insert0 = tensor.insert_slice %0 into %empty[0, 0][2, 3][1, 1]
+/// %concat = tensor.insert_slice %1 into %insert0[0, 3][2, 4][1, 1]
+struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = concatOp.getLoc();
+    FailureOr<Value> dest =
+        tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
+    if (failed(dest))
+      return failure();
+
+    auto empty = dest->getDefiningOp<tensor::EmptyOp>();
+    if (!empty)
+      return failure();
+
+    int64_t dim = concatOp.getDim();
+    Value dimValue =
+        rewriter.createOrFold<arith::ConstantIndexOp>(loc, dim);
+
+    int64_t rank = concatOp.getResultType().getRank();
+    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+
+    // Compute the partial sums for the slice offsets.
+    AffineExpr sum = rewriter.getAffineDimExpr(0);
+    SmallVector<AffineExpr> partialSums = {sum};
+    SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
+    for (auto [idx, input] :
+         llvm::enumerate(concatOp.getInputs().drop_back())) {
+      sum = sum + rewriter.getAffineDimExpr(idx + 1);
+      partialSums.push_back(sum);
+      offsetStrides.push_back(
+          rewriter.createOrFold<tensor::DimOp>(loc, input, dimValue));
+    }
+    auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
+                                        partialSums, rewriter.getContext());
+    SmallVector<OpFoldResult> dimOffsets =
+        affine::makeComposedFoldedMultiResultAffineApply(
+            rewriter, loc, partialSumMap, offsetStrides);
+
+    // Construct the chain of insert_slice ops into the destination.
+    Value result = *dest;
+    for (auto [input, offset] :
+         llvm::zip_equal(concatOp.getInputs(), dimOffsets)) {
+      SmallVector<OpFoldResult> sizes =
+          tensor::getMixedSizes(rewriter, loc, input);
+      offsets[dim] = offset;
+      result = rewriter.createOrFold<tensor::InsertSliceOp>(
+          loc, input, result, offsets, sizes, strides);
+    }
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        concatOp, concatOp.getResultType(), result);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tensor::populateDecomposeTensorConcatPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<DecomposeTensorConcatOp>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 580c1db60702..84c44a09aa3d 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -87,6 +87,18 @@ func.func @tensor.cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32>
 
 // -----
 
+// CHECK-LABEL: fold_concat
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x2x?xi32>
+func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1x2x?xi32>) {
+  %0 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x3xi32>
+  // CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<1x2x?xi32> to tensor<1x2x3xi32>
+  %1 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x?xi32>
+  // CHECK-NEXT: return %[[CAST]], %[[ARG0]] : tensor<1x2x3xi32>, tensor<1x2x?xi32>
+  return %0, %1 : tensor<1x2x3xi32>, tensor<1x2x?xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_extract
 func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
   %const_0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Tensor/decompose-concat.mlir b/mlir/test/Dialect/Tensor/decompose-concat.mlir
new file mode 100644
index 000000000000..5712c77a743d
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/decompose-concat.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter -cse %s | FileCheck %s
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.tensor.decompose_concat
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
+
+func.func @decompose_dynamic_concat(%arg0 : tensor<8x4xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = tensor.concat dim(1) %arg0, %arg1 : (tensor<8x4xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @decompose_dynamic_concat(
+// CHECK-SAME:  %[[ARG0:.+]]: tensor<8x4xf32>
+// CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+
+// CHECK-DAG:   %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+// CHECK:       %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+// CHECK:       %[[CONCAT_SIZE:.+]] = affine.apply #[[$MAP]]()[%[[C8]], %[[DIM]]]
+// CHECK:       %[[EMPTY:.+]] = tensor.empty(%[[C8]], %[[CONCAT_SIZE]]) : tensor<?x?xf32>
+// CHECK:       %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<?x?xf32>
+// CHECK:       %[[OFFSET:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+// CHECK:       %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[OFFSET]], %[[DIM]]] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+// CHECK:       return %[[CONCAT]] : tensor<?x?xf32>
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.tensor.decompose_concat
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
+
+func.func @decompose_1d_concat(%arg0 : tensor<1xf32>,
+                               %arg1 : tensor<2xf32>,
+                               %arg2 : tensor<3xf32>,
+                               %arg3: tensor<4xf32>) -> tensor<10xf32> {
+  %0 = tensor.concat dim(0) %arg0, %arg1, %arg2, %arg3
+      : (tensor<1xf32>, tensor<2xf32>, tensor<3xf32>, tensor<4xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+// CHECK-LABEL: func @decompose_1d_concat
+// CHECK:       tensor.empty() : tensor<10xf32>
+// CHECK:       tensor.insert_slice %{{.*}}[0] [1] [1] : tensor<1xf32> into tensor<10xf32>
+// CHECK:       tensor.insert_slice %{{.*}}[1] [2] [1] : tensor<2xf32> into tensor<10xf32>
+// CHECK:       tensor.insert_slice %{{.*}}[3] [3] [1] : tensor<3xf32> into tensor<10xf32>
+// CHECK:       %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[6] [4] [1] : tensor<4xf32> into tensor<10xf32>
+// CHECK:       return %[[CONCAT]] : tensor<10xf32>
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 389e7e675c0e..9b6c2327879c 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -16,6 +16,54 @@ func.func @tensor.cast_mismatching_constants(%arg0: tensor<1xf32>) {
 
 // -----
 
+func.func @concat_empty() {
+  // expected-error@+1 {{requires at least one input}}
+  %0 = tensor.concat dim(0) : () -> tensor<1x2x3xf32>
+  return
+}
+
+// -----
+
+func.func @concat_rank_mismatch(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) {
+  // expected-error@+1 {{rank of concatenated inputs must match result rank}}
+  %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<2x1xf32>
+  return
+}
+
+// -----
+
+func.func @concat_dim_out_of_range(%arg0: tensor<3xf32>) {
+  // expected-error@+1 {{concatenation dim must be less than the tensor rank}}
+  %0 = tensor.concat dim(1) %arg0 : (tensor<3xf32>) -> tensor<3xf32>
+  return
+}
+
+// -----
+
+func.func @concat_element_type_mismatch(%arg0: tensor<3xf32>, %arg1: tensor<3xi32>) {
+  // expected-error@+1 {{inputs and result element type must match}}
+  %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<3xf32>, tensor<3xi32>) -> tensor<3xf32>
+  return
+}
+
+// -----
+
+func.func @concat_incompatible_input_types(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) {
+  // expected-error@+1 {{static concatenation size mismatch along non-concatenated dimension 1}}
+  %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<7x5xf32>
+  return
+}
+
+// -----
+
+func.func @concat_static_shape_mismatch(%arg0: tensor<3xf32>) {
+  // expected-error@+1 {{result type 'tensor<7xf32>'does not match inferred shape 'tensor<6xf32>' static sizes}}
+  %0 = tensor.concat dim(0) %arg0, %arg0 : (tensor<3xf32>, tensor<3xf32>) -> tensor<7xf32>
+  return
+}
+
+// -----
+
 func.func @extract_too_many_indices(%arg0: tensor<?xf32>) {
   // expected-error@+1 {{incorrect number of indices for extract_element}}
   %0 = tensor.extract %arg0[] : tensor<?xf32>
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 71a0489b23f5..2282da38803a 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -15,6 +15,23 @@ func.func @cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @concat(
+func.func @concat(%arg0: tensor<4x7x3xf32>, %arg1 : tensor<4x4x3xf32>, %arg2: tensor<?x?x?xf32>) {
+  // CHECK: tensor.concat dim(0) %{{.*}} : (tensor<4x7x3xf32>) -> tensor<4x7x3xf32>
+  %0 = tensor.concat dim(0) %arg0 : (tensor<4x7x3xf32>) -> tensor<4x7x3xf32>
+  // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+  %1 = tensor.concat dim(1) %arg0, %arg1 : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+  // CHECK: tensor.concat dim(2) %{{.*}} : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<4x7x?xf32>
+  %2 = tensor.concat dim(2) %arg0, %arg2 : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<4x7x?xf32>
+  // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  %3 = tensor.concat dim(1) %arg2, %arg2 : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<?x?x?xf32>, tensor<4x4x3xf32>, tensor<4x7x3xf32>) -> tensor<4x?x3xf32>
+  %4 = tensor.concat dim(1) %arg2, %arg1, %arg0 : (tensor<?x?x?xf32>, tensor<4x4x3xf32>, tensor<4x7x3xf32>) -> tensor<4x?x3xf32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @empty(
 // CHECK-SAME:   %[[sz:.*]]: index
 func.func @empty(%sz: index) -> tensor<5x?x6xf32> {