diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 6c32476d8656..9a0d5d7e1696 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -800,23 +800,22 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
 
   LogicalResult matchAndRewrite(ConcatOp concatOp,
                                 PatternRewriter &rewriter) const override {
-    auto operandTensorTypes =
-        llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
-          return llvm::cast<RankedTensorType>(type);
-        });
-
     int64_t dim = concatOp.getDim();
-    ArrayRef<int64_t> inferredResultShape =
-        ConcatOp::inferResultType(dim, concatOp->getOperandTypes()).getShape();
+    RankedTensorType inferredResultType =
+        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
 
     // Find operands for which a more static shape can be inferred.
     LogicalResult matched = failure();
-    for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
+    // Inferred operand shapes are identical in every dimension except the
+    // concatenation dimension.
+    SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
+    for (auto [operandIdx, operandType] :
+         llvm::enumerate(concatOp->getOperandTypes())) {
       // Compute inferred type for operand.
-      SmallVector<int64_t> inferredOperandShape(inferredResultShape);
-      inferredOperandShape[dim] = operandType.getDimSize(dim);
+      inferredOperandShape[dim] =
+          cast<RankedTensorType>(operandType).getDimSize(dim);
       auto inferredOperandType = RankedTensorType::get(
-          inferredOperandShape, operandType.getElementType());
+          inferredOperandShape, inferredResultType.getElementType());
 
       // Check if inferred type is more static.
       if (!preservesStaticInformation(inferredOperandType, operandType)) {