[mlir][tensor][NFC] Code cleanup around shape inference support for tensor.concat op (#140616)

Addresses some code review on
https://github.com/llvm/llvm-project/pull/140168 that came in after
merge.
This commit is contained in:
Aaron St George
2025-05-19 18:54:13 -07:00
committed by GitHub
parent df0358f36b
commit a0a55df385

View File

@@ -800,23 +800,22 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
LogicalResult matchAndRewrite(ConcatOp concatOp,
PatternRewriter &rewriter) const override {
auto operandTensorTypes =
llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
return llvm::cast<RankedTensorType>(type);
});
int64_t dim = concatOp.getDim();
ArrayRef<int64_t> inferredResultShape =
ConcatOp::inferResultType(dim, concatOp->getOperandTypes()).getShape();
RankedTensorType inferredResultType =
ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
// Find operands for which a more static shape can be inferred.
LogicalResult matched = failure();
for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
// Inferred operand shapes are identical in every dimension except the
// concatenation dimension.
SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
for (auto [operandIdx, operandType] :
llvm::enumerate(concatOp->getOperandTypes())) {
// Compute inferred type for operand.
SmallVector<int64_t> inferredOperandShape(inferredResultShape);
inferredOperandShape[dim] = operandType.getDimSize(dim);
inferredOperandShape[dim] =
cast<RankedTensorType>(operandType).getDimSize(dim);
auto inferredOperandType = RankedTensorType::get(
inferredOperandShape, operandType.getElementType());
inferredOperandShape, inferredResultType.getElementType());
// Check if inferred type is more static.
if (!preservesStaticInformation(inferredOperandType, operandType)) {