[mlir][vector] Add extra check on distribute types to avoid crashes (#102952)

This PR addresses the issue detailed in
https://github.com/iree-org/iree/issues/17948.

The problem occurs when distributed types are set to NULL, leading to
compilation crashes.

---------

Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
This commit is contained in:
Bangtian Liu
2024-08-14 11:47:38 -04:00
committed by GitHub
parent abc1acf8df
commit b5e47d2e40
2 changed files with 37 additions and 0 deletions

View File

@@ -1689,6 +1689,9 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
});
if (llvm::is_contained(distTypes, Type{}))
return failure();
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, escapingValues.getArrayRef(), distTypes,

View File

@@ -620,6 +620,40 @@ func.func @vector_reduction(%laneid: index) -> (f32) {
// -----
// CHECK-PROP-LABEL: func @warp_distribute(
// CHECK-PROP-SAME: %[[ID:[a-zA-Z0-9]+]]
// CHECK-PROP-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-PROP-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK-PROP: vector.warp_execute_on_lane_0(%[[ID]])[32]
// CHECK-PROP-NEXT: "some_def"() : () -> vector<4096xf32>
// CHECK-PROP-NEXT: %{{.*}} = vector.reduction
// CHECK-PROP: %[[DEF:.*]] = arith.divf %{{.*}}, %{{.*}} : vector<1xf32>
// CHECK-PROP-NOT: vector.warp_execute_on_lane_0
// CHECK-PROP: scf.for
// CHECK-PROP: %{{.*}} = arith.subf %{{.*}}, %[[DEF]] : vector<1xf32>
func.func @warp_distribute(%arg0: index, %src: memref<128xf32>, %dest: memref<128xf32>){
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c128 = arith.constant 128 : index
%f0 = arith.constant 0.000000e+00 : f32
vector.warp_execute_on_lane_0(%arg0)[32]{
%cst_1 = arith.constant dense<2.621440e+05> : vector<1xf32>
%0 = "some_def"() : () -> (vector<4096xf32>)
%1 = vector.reduction <add>, %0, %cst : vector<4096xf32> into f32
%2 = vector.broadcast %1 : f32 to vector<1xf32>
%3 = arith.divf %2, %cst_1 : vector<1xf32>
scf.for %arg1 = %c0 to %c128 step %c1 {
%4 = vector.transfer_read %src[%arg1], %f0 {in_bounds = [true]} : memref<128xf32>, vector<1xf32>
%5 = arith.subf %4, %3 : vector<1xf32>
vector.transfer_write %5, %dest[%arg1] : vector<1xf32>, memref<128xf32>
}
}
return
}
// -----
func.func @vector_reduction(%laneid: index, %m0: memref<4x2x32xf32>, %m1: memref<f32>) {
%c0 = arith.constant 0: index
%f0 = arith.constant 0.0: f32