mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 01:07:04 +08:00
[mlir][Vector] Support 0-D vectors in BitCastOp
The implementation only allows to bit-cast between two 0-D vectors. We could probably support casting from/to vectors like `vector<1xf32>`, but I wasn't convinced that this would be important and it would require breaking the invariant that `BitCastOp` works only on vectors with equal rank. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D114854
This commit is contained in:
committed by
Nicolas Vasilache
parent
8e2b373396
commit
1423e8bf5d
@@ -675,7 +675,7 @@ def Vector_InsertElementOp :
|
||||
position and inserts the source into the destination at the proper position.
|
||||
|
||||
Note that this instruction resembles vector.insert, but is restricted to 0-D
|
||||
and 1-D vectors and relaxed to dynamic indices.
|
||||
and 1-D vectors and relaxed to dynamic indices.
|
||||
|
||||
It is meant to be closer to LLVM's version:
|
||||
https://llvm.org/docs/LangRef.html#insertelement-instruction
|
||||
@@ -2025,13 +2025,14 @@ def Vector_ShapeCastOp :
|
||||
|
||||
def Vector_BitCastOp :
|
||||
Vector_Op<"bitcast", [NoSideEffect, AllRanksMatch<["source", "result"]>]>,
|
||||
Arguments<(ins AnyVector:$source)>,
|
||||
Results<(outs AnyVector:$result)>{
|
||||
Arguments<(ins AnyVectorOfAnyRank:$source)>,
|
||||
Results<(outs AnyVectorOfAnyRank:$result)>{
|
||||
let summary = "bitcast casts between vectors";
|
||||
let description = [{
|
||||
The bitcast operation casts between vectors of the same rank, the minor 1-D
|
||||
vector size is casted to a vector with a different element type but same
|
||||
bitwidth.
|
||||
bitwidth. In case of 0-D vectors, the bitwidth of element types must be
|
||||
equal.
|
||||
|
||||
Example:
|
||||
|
||||
@@ -2044,6 +2045,9 @@ def Vector_BitCastOp :
|
||||
|
||||
// Example casting to an element type of the same size.
|
||||
%5 = vector.bitcast %4 : vector<5x1x4x3xf32> to vector<5x1x4x3xi32>
|
||||
|
||||
// Example casting of 0-D vectors.
|
||||
%7 = vector.bitcast %6 : vector<f32> to vector<i32>
|
||||
```
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
|
||||
@@ -121,9 +121,9 @@ public:
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Only 1-D vectors can be lowered to LLVM.
|
||||
VectorType resultTy = bitCastOp.getType();
|
||||
if (resultTy.getRank() != 1)
|
||||
// Only 0-D and 1-D vectors can be lowered to LLVM.
|
||||
VectorType resultTy = bitCastOp.getResultVectorType();
|
||||
if (resultTy.getRank() > 1)
|
||||
return failure();
|
||||
Type newResultTy = typeConverter->convertType(resultTy);
|
||||
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
|
||||
|
||||
@@ -3702,12 +3702,20 @@ static LogicalResult verify(BitCastOp op) {
|
||||
}
|
||||
|
||||
DataLayout dataLayout = DataLayout::closest(op);
|
||||
if (dataLayout.getTypeSizeInBits(sourceVectorType.getElementType()) *
|
||||
sourceVectorType.getShape().back() !=
|
||||
dataLayout.getTypeSizeInBits(resultVectorType.getElementType()) *
|
||||
resultVectorType.getShape().back())
|
||||
auto sourceElementBits =
|
||||
dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
|
||||
auto resultElementBits =
|
||||
dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
|
||||
|
||||
if (sourceVectorType.getRank() == 0) {
|
||||
if (sourceElementBits != resultElementBits)
|
||||
return op.emitOpError("source/result bitwidth of the 0-D vector element "
|
||||
"types must be equal");
|
||||
} else if (sourceElementBits * sourceVectorType.getShape().back() !=
|
||||
resultElementBits * resultVectorType.getShape().back()) {
|
||||
return op.emitOpError(
|
||||
"source/result bitwidth of the minor 1-D vectors must be equal");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -1,6 +1,20 @@
|
||||
// RUN: mlir-opt %s -convert-vector-to-llvm -split-input-file | FileCheck %s
|
||||
|
||||
|
||||
func @bitcast_f32_to_i32_vector_0d(%input: vector<f32>) -> vector<i32> {
|
||||
%0 = vector.bitcast %input : vector<f32> to vector<i32>
|
||||
return %0 : vector<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @bitcast_f32_to_i32_vector_0d
|
||||
// CHECK-SAME: %[[input:.*]]: vector<f32>
|
||||
// CHECK: %[[vec_f32_1d:.*]] = builtin.unrealized_conversion_cast %[[input]] : vector<f32> to vector<1xf32>
|
||||
// CHECK: %[[vec_i32_1d:.*]] = llvm.bitcast %[[vec_f32_1d]] : vector<1xf32> to vector<1xi32>
|
||||
// CHECK: %[[vec_i32_0d:.*]] = builtin.unrealized_conversion_cast %[[vec_i32_1d]] : vector<1xi32> to vector<i32>
|
||||
// CHECK: return %[[vec_i32_0d]] : vector<i32>
|
||||
|
||||
// -----
|
||||
|
||||
func @bitcast_f32_to_i32_vector(%input: vector<16xf32>) -> vector<16xi32> {
|
||||
%0 = vector.bitcast %input : vector<16xf32> to vector<16xi32>
|
||||
return %0 : vector<16xi32>
|
||||
|
||||
@@ -1014,6 +1014,20 @@ func @bitcast_not_vector(%arg0 : vector<5x1x3x2xf32>) {
|
||||
|
||||
// -----
|
||||
|
||||
func @bitcast_rank_mismatch_to_0d(%arg0 : vector<1xf32>) {
|
||||
// expected-error@+1 {{op failed to verify that all of {source, result} have same rank}}
|
||||
%0 = vector.bitcast %arg0 : vector<1xf32> to vector<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @bitcast_rank_mismatch_from_0d(%arg0 : vector<f32>) {
|
||||
// expected-error@+1 {{op failed to verify that all of {source, result} have same rank}}
|
||||
%0 = vector.bitcast %arg0 : vector<f32> to vector<1xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @bitcast_rank_mismatch(%arg0 : vector<5x1x3x2xf32>) {
|
||||
// expected-error@+1 {{op failed to verify that all of {source, result} have same rank}}
|
||||
%0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x3x2xf32>
|
||||
|
||||
@@ -432,8 +432,9 @@ func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
|
||||
func @bitcast(%arg0 : vector<5x1x3x2xf32>,
|
||||
%arg1 : vector<8x1xi32>,
|
||||
%arg2 : vector<16x1x8xi8>,
|
||||
%arg3 : vector<8x2x1xindex>)
|
||||
-> (vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>, vector<16x1x1xindex>, vector<8x2x2xf32>) {
|
||||
%arg3 : vector<8x2x1xindex>,
|
||||
%arg4 : vector<f32>)
|
||||
-> (vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>, vector<16x1x1xindex>, vector<8x2x2xf32>, vector<i32>) {
|
||||
|
||||
// CHECK: vector.bitcast %{{.*}} : vector<5x1x3x2xf32> to vector<5x1x3x4xf16>
|
||||
%0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x1x3x4xf16>
|
||||
@@ -459,7 +460,10 @@ func @bitcast(%arg0 : vector<5x1x3x2xf32>,
|
||||
// CHECK-NEXT: vector.bitcast %{{.*}} : vector<8x2x1xindex> to vector<8x2x2xf32>
|
||||
%7 = vector.bitcast %arg3 : vector<8x2x1xindex> to vector<8x2x2xf32>
|
||||
|
||||
return %0, %1, %2, %3, %4, %5, %6, %7 : vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>, vector<16x1x1xindex>, vector<8x2x2xf32>
|
||||
// CHECK: vector.bitcast %{{.*}} : vector<f32> to vector<i32>
|
||||
%8 = vector.bitcast %arg4 : vector<f32> to vector<i32>
|
||||
|
||||
return %0, %1, %2, %3, %4, %5, %6, %7, %8 : vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>, vector<16x1x1xindex>, vector<8x2x2xf32>, vector<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @vector_fma
|
||||
|
||||
@@ -55,6 +55,19 @@ func @broadcast_0d(%a: f32) {
|
||||
return
|
||||
}
|
||||
|
||||
func @bitcast_0d() {
|
||||
%0 = arith.constant 42 : i32
|
||||
%1 = arith.constant dense<0> : vector<i32>
|
||||
%2 = vector.insertelement %0, %1[] : vector<i32>
|
||||
%3 = vector.bitcast %2 : vector<i32> to vector<f32>
|
||||
%4 = vector.extractelement %3[] : vector<f32>
|
||||
%5 = arith.bitcast %4 : f32 to i32
|
||||
// CHECK: 42
|
||||
vector.print %5: i32
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
func @entry() {
|
||||
%0 = arith.constant 42.0 : f32
|
||||
%1 = arith.constant dense<0.0> : vector<f32>
|
||||
@@ -68,5 +81,7 @@ func @entry() {
|
||||
call @splat_0d(%4) : (f32) -> ()
|
||||
call @broadcast_0d(%4) : (f32) -> ()
|
||||
|
||||
call @bitcast_0d() : () -> ()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user