mirror of
https://github.com/intel/llvm.git
synced 2026-01-27 06:06:34 +08:00
[mlir][tensor][NFC] Code cleanup around shape inference support for tensor.concat op (#140616)
Addresses some code review on https://github.com/llvm/llvm-project/pull/140168 that came in after merge.
This commit is contained in:
@@ -800,23 +800,22 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
|
||||
|
||||
LogicalResult matchAndRewrite(ConcatOp concatOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto operandTensorTypes =
|
||||
llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
|
||||
return llvm::cast<RankedTensorType>(type);
|
||||
});
|
||||
|
||||
int64_t dim = concatOp.getDim();
|
||||
ArrayRef<int64_t> inferredResultShape =
|
||||
ConcatOp::inferResultType(dim, concatOp->getOperandTypes()).getShape();
|
||||
RankedTensorType inferredResultType =
|
||||
ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
|
||||
|
||||
// Find operands for which a more static shape can be inferred.
|
||||
LogicalResult matched = failure();
|
||||
for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
|
||||
// Inferred operand shapes are identical in every dimension except the
|
||||
// concatenation dimension.
|
||||
SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
|
||||
for (auto [operandIdx, operandType] :
|
||||
llvm::enumerate(concatOp->getOperandTypes())) {
|
||||
// Compute inferred type for operand.
|
||||
SmallVector<int64_t> inferredOperandShape(inferredResultShape);
|
||||
inferredOperandShape[dim] = operandType.getDimSize(dim);
|
||||
inferredOperandShape[dim] =
|
||||
cast<RankedTensorType>(operandType).getDimSize(dim);
|
||||
auto inferredOperandType = RankedTensorType::get(
|
||||
inferredOperandShape, operandType.getElementType());
|
||||
inferredOperandShape, inferredResultType.getElementType());
|
||||
|
||||
// Check if inferred type is more static.
|
||||
if (!preservesStaticInformation(inferredOperandType, operandType)) {
|
||||
|
||||
Reference in New Issue
Block a user