mirror of
https://github.com/intel/llvm.git
synced 2026-01-23 16:06:39 +08:00
[mlir] [VectorOps] Improve vector.constant_mask lowering
Use direct vector constants for the 1-D case. This approach scales much better than generating elaborate insertion operations that are eventually folded into a constant. We could of course generalize the 1-D case to higher ranks, but this simplification already helps in scaling some microbenchmarks that would formerly crash on the intermediate IR length. Reviewed By: reidtatge Differential Revision: https://reviews.llvm.org/D82144
This commit is contained in:
@@ -120,6 +120,7 @@ public:
|
||||
IntegerAttr getUI32IntegerAttr(uint32_t value);
|
||||
|
||||
/// Vector-typed DenseIntElementsAttr getters. `values` must not be empty.
|
||||
DenseIntElementsAttr getBoolVectorAttr(ArrayRef<bool> values);
|
||||
DenseIntElementsAttr getI32VectorAttr(ArrayRef<int32_t> values);
|
||||
DenseIntElementsAttr getI64VectorAttr(ArrayRef<int64_t> values);
|
||||
|
||||
|
||||
@@ -1311,7 +1311,8 @@ public:
|
||||
/// %4 = vector.insert %l, %z[0]
|
||||
/// ..
|
||||
/// %x = vector.insert %l, %..[a-1]
|
||||
/// which will be folded at LLVM IR level.
|
||||
/// until a one-dimensional vector is reached. All these operations
|
||||
/// will be folded at LLVM IR level.
|
||||
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
|
||||
public:
|
||||
using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
|
||||
@@ -1325,20 +1326,22 @@ public:
|
||||
int64_t rank = dimSizes.size();
|
||||
int64_t trueDim = dimSizes[0].cast<IntegerAttr>().getInt();
|
||||
|
||||
Value trueVal;
|
||||
if (rank == 1) {
|
||||
trueVal = rewriter.create<ConstantOp>(
|
||||
loc, eltType, rewriter.getIntegerAttr(eltType, 1));
|
||||
} else {
|
||||
VectorType lowType =
|
||||
VectorType::get(dstType.getShape().drop_front(), eltType);
|
||||
SmallVector<int64_t, 4> newDimSizes;
|
||||
for (int64_t r = 1; r < rank; r++)
|
||||
newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
|
||||
trueVal = rewriter.create<vector::ConstantMaskOp>(
|
||||
loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
|
||||
SmallVector<bool, 4> values(dstType.getDimSize(0));
|
||||
for (int64_t d = 0; d < trueDim; d++)
|
||||
values[d] = true;
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(
|
||||
op, dstType, rewriter.getBoolVectorAttr(values));
|
||||
return success();
|
||||
}
|
||||
|
||||
VectorType lowType =
|
||||
VectorType::get(dstType.getShape().drop_front(), eltType);
|
||||
SmallVector<int64_t, 4> newDimSizes;
|
||||
for (int64_t r = 1; r < rank; r++)
|
||||
newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
|
||||
Value trueVal = rewriter.create<vector::ConstantMaskOp>(
|
||||
loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
|
||||
Value result = rewriter.create<ConstantOp>(loc, dstType,
|
||||
rewriter.getZeroAttr(dstType));
|
||||
for (int64_t d = 0; d < trueDim; d++) {
|
||||
|
||||
@@ -104,6 +104,12 @@ IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
|
||||
return IntegerAttr::get(getIntegerType(64), APInt(64, value));
|
||||
}
|
||||
|
||||
DenseIntElementsAttr Builder::getBoolVectorAttr(ArrayRef<bool> values) {
|
||||
return DenseIntElementsAttr::get(
|
||||
VectorType::get(static_cast<int64_t>(values.size()), getI1Type()),
|
||||
values);
|
||||
}
|
||||
|
||||
DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) {
|
||||
return DenseIntElementsAttr::get(
|
||||
VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(32)),
|
||||
|
||||
@@ -944,17 +944,26 @@ func @genbool_1d() -> vector<8xi1> {
|
||||
return %0 : vector<8xi1>
|
||||
}
|
||||
// CHECK-LABEL: func @genbool_1d
|
||||
// CHECK: %[[T0:.*]] = llvm.mlir.constant(true) : !llvm.i1
|
||||
// CHECK: %[[T1:.*]] = llvm.mlir.constant(dense<false> : vector<8xi1>) : !llvm<"<8 x i1>">
|
||||
// CHECK: %[[T2:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64
|
||||
// CHECK: %[[T3:.*]] = llvm.insertelement %[[T0]], %[[T1]][%[[T2]] : !llvm.i64] : !llvm<"<8 x i1>">
|
||||
// CHECK: %[[T4:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64
|
||||
// CHECK: %[[T5:.*]] = llvm.insertelement %[[T0]], %[[T3]][%[[T4]] : !llvm.i64] : !llvm<"<8 x i1>">
|
||||
// CHECK: %[[T6:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
|
||||
// CHECK: %[[T7:.*]] = llvm.insertelement %[[T0]], %[[T5]][%[[T6]] : !llvm.i64] : !llvm<"<8 x i1>">
|
||||
// CHECK: %[[T8:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i64
|
||||
// CHECK: %[[T9:.*]] = llvm.insertelement %[[T0]], %[[T7]][%[[T8]] : !llvm.i64] : !llvm<"<8 x i1>">
|
||||
// CHECK: llvm.return %9 : !llvm<"<8 x i1>">
|
||||
// CHECK: %[[C1:.*]] = llvm.mlir.constant(dense<[true, true, true, true, false, false, false, false]> : vector<8xi1>) : !llvm<"<8 x i1>">
|
||||
// CHECK: llvm.return %[[C1]] : !llvm<"<8 x i1>">
|
||||
|
||||
func @genbool_2d() -> vector<4x4xi1> {
|
||||
%v = vector.constant_mask [2, 2] : vector<4x4xi1>
|
||||
return %v: vector<4x4xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @genbool_2d
|
||||
// CHECK: %[[C1:.*]] = llvm.mlir.constant(dense<[true, true, false, false]> : vector<4xi1>) : !llvm<"<4 x i1>">
|
||||
// CHECK: %[[C2:.*]] = llvm.mlir.constant(dense<false> : vector<4x4xi1>) : !llvm<"[4 x <4 x i1>]">
|
||||
// CHECK: %[[T0:.*]] = llvm.insertvalue %[[C1]], %[[C2]][0] : !llvm<"[4 x <4 x i1>]">
|
||||
// CHECK: %[[T1:.*]] = llvm.insertvalue %[[C1]], %[[T0]][1] : !llvm<"[4 x <4 x i1>]">
|
||||
// CHECK: llvm.return %[[T1]] : !llvm<"[4 x <4 x i1>]">
|
||||
|
||||
func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
|
||||
%0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
|
||||
: vector<16xf32> -> vector<16xf32>
|
||||
return %0 : vector<16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @flat_transpose
|
||||
// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>">
|
||||
@@ -962,8 +971,3 @@ func @genbool_1d() -> vector<8xi1> {
|
||||
// CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} :
|
||||
// CHECK-SAME: !llvm<"<16 x float>"> into !llvm<"<16 x float>">
|
||||
// CHECK: llvm.return %[[T]] : !llvm<"<16 x float>">
|
||||
func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
|
||||
%0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
|
||||
: vector<16xf32> -> vector<16xf32>
|
||||
return %0 : vector<16xf32>
|
||||
}
|
||||
|
||||
@@ -676,13 +676,8 @@ func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @genbool_1d
|
||||
// CHECK: %[[TT:.*]] = constant true
|
||||
// CHECK: %[[C1:.*]] = constant dense<false> : vector<8xi1>
|
||||
// CHECK: %[[T0:.*]] = vector.insert %[[TT]], %[[C1]] [0] : i1 into vector<8xi1>
|
||||
// CHECK: %[[T1:.*]] = vector.insert %[[TT]], %[[T0]] [1] : i1 into vector<8xi1>
|
||||
// CHECK: %[[T2:.*]] = vector.insert %[[TT]], %[[T1]] [2] : i1 into vector<8xi1>
|
||||
// CHECK: %[[T3:.*]] = vector.insert %[[TT]], %[[T2]] [3] : i1 into vector<8xi1>
|
||||
// CHECK: return %[[T3]] : vector<8xi1>
|
||||
// CHECK: %[[T0:.*]] = constant dense<[true, true, true, true, false, false, false, false]> : vector<8xi1>
|
||||
// CHECK: return %[[T0]] : vector<8xi1>
|
||||
|
||||
func @genbool_1d() -> vector<8xi1> {
|
||||
%0 = vector.constant_mask [4] : vector<8xi1>
|
||||
@@ -690,14 +685,11 @@ func @genbool_1d() -> vector<8xi1> {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @genbool_2d
|
||||
// CHECK: %[[TT:.*]] = constant true
|
||||
// CHECK: %[[C1:.*]] = constant dense<false> : vector<4xi1>
|
||||
// CHECK: %[[C1:.*]] = constant dense<[true, true, false, false]> : vector<4xi1>
|
||||
// CHECK: %[[C2:.*]] = constant dense<false> : vector<4x4xi1>
|
||||
// CHECK: %[[T0:.*]] = vector.insert %[[TT]], %[[C1]] [0] : i1 into vector<4xi1>
|
||||
// CHECK: %[[T1:.*]] = vector.insert %[[TT]], %[[T0]] [1] : i1 into vector<4xi1>
|
||||
// CHECK: %[[T2:.*]] = vector.insert %[[T1]], %[[C2]] [0] : vector<4xi1> into vector<4x4xi1>
|
||||
// CHECK: %[[T3:.*]] = vector.insert %[[T1]], %[[T2]] [1] : vector<4xi1> into vector<4x4xi1>
|
||||
// CHECK: return %[[T3]] : vector<4x4xi1>
|
||||
// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<4x4xi1>
|
||||
// CHECK: %[[T1:.*]] = vector.insert %[[C1]], %[[T0]] [1] : vector<4xi1> into vector<4x4xi1>
|
||||
// CHECK: return %[[T1]] : vector<4x4xi1>
|
||||
|
||||
func @genbool_2d() -> vector<4x4xi1> {
|
||||
%v = vector.constant_mask [2, 2] : vector<4x4xi1>
|
||||
@@ -705,16 +697,12 @@ func @genbool_2d() -> vector<4x4xi1> {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @genbool_3d
|
||||
// CHECK: %[[TT:.*]] = constant true
|
||||
// CHECK: %[[C1:.*]] = constant dense<false> : vector<4xi1>
|
||||
// CHECK: %[[C1:.*]] = constant dense<[true, true, true, false]> : vector<4xi1>
|
||||
// CHECK: %[[C2:.*]] = constant dense<false> : vector<3x4xi1>
|
||||
// CHECK: %[[C3:.*]] = constant dense<false> : vector<2x3x4xi1>
|
||||
// CHECK: %[[T0:.*]] = vector.insert %[[TT]], %[[C1]] [0] : i1 into vector<4xi1>
|
||||
// CHECK: %[[T1:.*]] = vector.insert %[[TT]], %[[T0]] [1] : i1 into vector<4xi1>
|
||||
// CHECK: %[[T2:.*]] = vector.insert %[[TT]], %[[T1]] [2] : i1 into vector<4xi1>
|
||||
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<4xi1> into vector<3x4xi1>
|
||||
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C3]] [0] : vector<3x4xi1> into vector<2x3x4xi1>
|
||||
// CHECK: return %[[T4]] : vector<2x3x4xi1>
|
||||
// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<3x4xi1>
|
||||
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C3]] [0] : vector<3x4xi1> into vector<2x3x4xi1>
|
||||
// CHECK: return %[[T1]] : vector<2x3x4xi1>
|
||||
|
||||
func @genbool_3d() -> vector<2x3x4xi1> {
|
||||
%v = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1>
|
||||
|
||||
Reference in New Issue
Block a user