mirror of
https://github.com/intel/llvm.git
synced 2026-02-02 18:18:09 +08:00
[mlir] Type erase inputs to select statements in shape.broadcast lowering.
This is required or broadcasting with operands of different ranks will lead to failures as the select op requires both possible outputs and its output type to be the same. Differential Revision: https://reviews.llvm.org/D89134
This commit is contained in:
@@ -99,10 +99,16 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
|
||||
rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
|
||||
Value greaterRank =
|
||||
rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
|
||||
auto erasedRankType =
|
||||
RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
|
||||
Value rankErasedLhs =
|
||||
rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
|
||||
Value rankErasedRhs =
|
||||
rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
|
||||
Value lesserRankOperand =
|
||||
rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
|
||||
rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
|
||||
Value greaterRankOperand =
|
||||
rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());
|
||||
rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
|
||||
|
||||
// Allocate stack memory for the broadcasted extent tensor.
|
||||
Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
|
||||
|
||||
@@ -305,9 +305,9 @@ func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @broadcast
|
||||
// CHECK-LABEL: @broadcast_unknown_extents
|
||||
// CHECK-SAME: (%[[LHS:.*]]: tensor<?xindex>, %[[RHS:.*]]: tensor<?xindex>)
|
||||
func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
|
||||
func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
|
||||
@@ -315,8 +315,10 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
|
||||
// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
|
||||
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[LHS]], %[[RHS]] : tensor<?xindex>
|
||||
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[RHS]], %[[LHS]] : tensor<?xindex>
|
||||
// CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
|
||||
// CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
|
||||
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
|
||||
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
|
||||
// CHECK: %[[MEM:.*]] = alloca(%[[GREATER_RANK]]) : memref<?xindex>
|
||||
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
|
||||
// CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
|
||||
@@ -340,3 +342,43 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
|
||||
: tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @broadcast_known_different_extents
|
||||
// CHECK-SAME: (%[[LHS:.*]]: tensor<2xindex>, %[[RHS:.*]]: tensor<3xindex>)
|
||||
func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xindex>) {
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<2xindex>
|
||||
// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<3xindex>
|
||||
// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
|
||||
// CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
|
||||
// CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
|
||||
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
|
||||
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
|
||||
// CHECK: %[[MEM:.*]] = alloca(%[[GREATER_RANK]]) : memref<?xindex>
|
||||
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
|
||||
// CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
|
||||
// CHECK: %[[EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
|
||||
// CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
|
||||
// CHECK: }
|
||||
// CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
|
||||
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
|
||||
// CHECK: %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
|
||||
// CHECK: %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
|
||||
// CHECK: %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
|
||||
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[IV_SHIFTED]]] : tensor<?xindex>
|
||||
// CHECK: scf.yield %[[LESSER_RANK_OPERAND_EXTENT]] : index
|
||||
// CHECK: } else {
|
||||
// CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
|
||||
// CHECK: }
|
||||
// CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
|
||||
// CHECK: }
|
||||
// CHECK: %[[BROADCASTED:.*]] = tensor_load %[[MEM]] : memref<?xindex>
|
||||
%0 = shape.broadcast %a, %b
|
||||
: tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user