mirror of
https://github.com/intel/llvm.git
synced 2026-01-31 07:04:56 +08:00
[mlir][vector] Add constant folding for fp16 to fp32 bitcast
Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D96041
This commit is contained in:
@@ -25,6 +25,7 @@
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/ADT/bit.h"
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir;
|
||||
@@ -2804,6 +2805,30 @@ OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (result().getType() == otherOp.source().getType())
|
||||
return otherOp.source();
|
||||
|
||||
Attribute sourceConstant = operands.front();
|
||||
if (!sourceConstant)
|
||||
return {};
|
||||
|
||||
Type srcElemType = getSourceVectorType().getElementType();
|
||||
Type dstElemType = getResultVectorType().getElementType();
|
||||
|
||||
if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) {
|
||||
if (floatPack.isSplat()) {
|
||||
auto splat = floatPack.getSplatValue<FloatAttr>();
|
||||
|
||||
// Casting fp16 into fp32.
|
||||
if (srcElemType.isF16() && dstElemType.isF32()) {
|
||||
uint32_t bits = static_cast<uint32_t>(
|
||||
splat.getValue().bitcastToAPInt().getZExtValue());
|
||||
// Duplicate the 16-bit pattern.
|
||||
bits = (bits << 16) | (bits & 0xffff);
|
||||
APInt intBits(32, bits);
|
||||
APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
|
||||
return DenseElementsAttr::get(getResultVectorType(), floatBits);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
|
||||
@@ -556,6 +556,20 @@ func @bitcast_folding(%I1: vector<4x8xf32>, %I2: vector<2xi32>) -> (vector<4x8xf
|
||||
return %0, %2 : vector<4x8xf32>, vector<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @bitcast_f16_to_f32
|
||||
// bit pattern: 0x00000000
|
||||
// CHECK: %[[CST0:.+]] = constant dense<0.000000e+00> : vector<4xf32>
|
||||
// bit pattern: 0x40004000
|
||||
// CHECK: %[[CST1:.+]] = constant dense<2.00390625> : vector<4xf32>
|
||||
// CHECK: return %[[CST0]], %[[CST1]]
|
||||
func @bitcast_f16_to_f32() -> (vector<4xf32>, vector<4xf32>) {
|
||||
%cst0 = constant dense<0.0> : vector<8xf16> // bit pattern: 0x0000
|
||||
%cst1 = constant dense<2.0> : vector<8xf16> // bit pattern: 0x4000
|
||||
%cast0 = vector.bitcast %cst0: vector<8xf16> to vector<4xf32>
|
||||
%cast1 = vector.bitcast %cst1: vector<8xf16> to vector<4xf32>
|
||||
return %cast0, %cast1: vector<4xf32>, vector<4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: broadcast_folding1
|
||||
|
||||
Reference in New Issue
Block a user