[mlir] Type erase inputs to select statements in shape.broadcast lowering.

This is required; otherwise, broadcasting with operands of different ranks will
lead to failures, as the select op requires its two candidate values and its
result to all have the same type.

Differential Revision: https://reviews.llvm.org/D89134
This commit is contained in:
Tres Popp
2020-10-09 16:45:50 +02:00
parent f82346fd73
commit 8178e41dc1
2 changed files with 54 additions and 6 deletions

View File

@@ -99,10 +99,16 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
Value greaterRank =
rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
auto erasedRankType =
RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
Value rankErasedLhs =
rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
Value rankErasedRhs =
rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
Value lesserRankOperand =
rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
Value greaterRankOperand =
rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());
rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
// Allocate stack memory for the broadcasted extent tensor.
Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);

View File

@@ -305,9 +305,9 @@ func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
// -----
// CHECK-LABEL: @broadcast
// CHECK-LABEL: @broadcast_unknown_extents
// CHECK-SAME: (%[[LHS:.*]]: tensor<?xindex>, %[[RHS:.*]]: tensor<?xindex>)
func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
@@ -315,8 +315,10 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[LHS]], %[[RHS]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[RHS]], %[[LHS]] : tensor<?xindex>
// CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
// CHECK: %[[MEM:.*]] = alloca(%[[GREATER_RANK]]) : memref<?xindex>
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
// CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
@@ -340,3 +342,43 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
: tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
return
}
// -----
// Lower `shape.broadcast` on extent tensors with different known static
// lengths (tensor<2xindex> vs. tensor<3xindex>, i.e. shapes of rank 2 and 3).
// The lowering must first `tensor_cast` both operands to tensor<?xindex> so
// the two `select` ops below have matching operand/result types.
// CHECK-LABEL: @broadcast_known_different_extents
// CHECK-SAME: (%[[LHS:.*]]: tensor<2xindex>, %[[RHS:.*]]: tensor<3xindex>)
func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xindex>) {
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<2xindex>
// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<3xindex>
// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
// The casts erase the static extents so both select arms are tensor<?xindex>.
// CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
// CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
// CHECK: %[[MEM:.*]] = alloca(%[[GREATER_RANK]]) : memref<?xindex>
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
// First loop: the leading RANK_DIFF extents come solely from the
// greater-ranked operand.
// CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
// CHECK: %[[EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
// CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
// CHECK: }
// Second loop: for the overlapping dimensions, take the lesser-ranked
// operand's extent where the greater-ranked operand's extent is 1.
// CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
// CHECK: %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
// CHECK: %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
// CHECK: %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[IV_SHIFTED]]] : tensor<?xindex>
// CHECK: scf.yield %[[LESSER_RANK_OPERAND_EXTENT]] : index
// CHECK: } else {
// CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
// CHECK: }
// CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
// CHECK: }
// CHECK: %[[BROADCASTED:.*]] = tensor_load %[[MEM]] : memref<?xindex>
%0 = shape.broadcast %a, %b
    : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
return
}