[mlir][vector] Add constant folding for fp16 to fp32 bitcast

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D96041
This commit is contained in:
Lei Zhang
2021-02-05 09:12:24 -05:00
parent 9f622b3d5d
commit 8dae90997a
2 changed files with 39 additions and 0 deletions

View File

@@ -25,6 +25,7 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/bit.h"
#include <numeric>
using namespace mlir;
@@ -2804,6 +2805,30 @@ OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
if (result().getType() == otherOp.source().getType())
return otherOp.source();
Attribute sourceConstant = operands.front();
if (!sourceConstant)
return {};
Type srcElemType = getSourceVectorType().getElementType();
Type dstElemType = getResultVectorType().getElementType();
if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) {
if (floatPack.isSplat()) {
auto splat = floatPack.getSplatValue<FloatAttr>();
// Casting fp16 into fp32.
if (srcElemType.isF16() && dstElemType.isF32()) {
uint32_t bits = static_cast<uint32_t>(
splat.getValue().bitcastToAPInt().getZExtValue());
// Duplicate the 16-bit pattern.
bits = (bits << 16) | (bits & 0xffff);
APInt intBits(32, bits);
APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
return DenseElementsAttr::get(getResultVectorType(), floatBits);
}
}
}
return {};
}

View File

@@ -556,6 +556,20 @@ func @bitcast_folding(%I1: vector<4x8xf32>, %I2: vector<2xi32>) -> (vector<4x8xf
return %0, %2 : vector<4x8xf32>, vector<2xi32>
}
// Test: vector.bitcast of a splat f16 constant to an f32 vector is expected to
// fold into a splat f32 constant. Going from vector<8xf16> to vector<4xf32>,
// each f32 element spans two f16 elements, so its 32-bit pattern is the 16-bit
// splat pattern repeated twice.
// CHECK-LABEL: func @bitcast_f16_to_f32
// Expected f32 bit pattern: 0x00000000 (0x0000 repeated), i.e. 0.0.
// CHECK: %[[CST0:.+]] = constant dense<0.000000e+00> : vector<4xf32>
// Expected f32 bit pattern: 0x40004000 (0x4000 repeated), i.e. 2.00390625.
// CHECK: %[[CST1:.+]] = constant dense<2.00390625> : vector<4xf32>
// Both results must be the folded constants; no bitcast should remain.
// CHECK: return %[[CST0]], %[[CST1]]
func @bitcast_f16_to_f32() -> (vector<4xf32>, vector<4xf32>) {
// Splat f16 sources; the trailing comments give each element's 16-bit pattern.
%cst0 = constant dense<0.0> : vector<8xf16> // bit pattern: 0x0000
%cst1 = constant dense<2.0> : vector<8xf16> // bit pattern: 0x4000
// Reinterpret 8 x f16 as 4 x f32 (same total bit width).
%cast0 = vector.bitcast %cst0: vector<8xf16> to vector<4xf32>
%cast1 = vector.bitcast %cst1: vector<8xf16> to vector<4xf32>
return %cast0, %cast1: vector<4xf32>, vector<4xf32>
}
// -----
// CHECK-LABEL: broadcast_folding1