[mlir] Use double format when parsing bfloat16 hexadecimal values

Summary: bfloat16 doesn't have a valid APFloat format, so we have to use double semantics when storing it. This change makes sure that hexadecimal values can be round-tripped properly given this fact.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D72667
This commit is contained in:
River Riddle
2020-01-14 13:47:21 -08:00
parent a3490e3e3d
commit 1bd14ce392
3 changed files with 34 additions and 14 deletions

View File

@@ -1709,8 +1709,14 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
/// Construct a float attribute bitwise equivalent to the integer literal.
static FloatAttr buildHexadecimalFloatLiteral(Parser *p, FloatType type,
uint64_t value) {
int width = type.getIntOrFloatBitWidth();
APInt apInt(width, value);
// FIXME: bfloat is currently stored as a double internally because it doesn't
// have valid APFloat semantics.
if (type.isF64() || type.isBF16()) {
APFloat apFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value));
return p->builder.getFloatAttr(type, apFloat);
}
APInt apInt(type.getWidth(), value);
if (apInt != value) {
p->emitError("hexadecimal float constant out of range for type");
return nullptr;
@@ -1741,11 +1747,6 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
}
if (auto floatType = type.dyn_cast<FloatType>()) {
// TODO(zinenko): Update once hex format for bfloat16 is supported.
if (type.isBF16())
return emitError(loc,
"hexadecimal float literal not supported for bfloat16"),
nullptr;
if (isNegative)
return emitError(
loc,

View File

@@ -1120,13 +1120,6 @@ func @invalid_region_dominance() {
// -----
func @hexadecimal_bf16() {
// expected-error @+1 {{hexadecimal float literal not supported for bfloat16}}
"foo"() {value = 0xffff : bf16} : () -> ()
}
// -----
func @hexadecimal_float_leading_minus() {
// expected-error @+1 {{hexadecimal float literal should not have a leading minus}}
"foo"() {value = -0x7fff : f16} : () -> ()

View File

@@ -1030,6 +1030,32 @@ func @f64_special_values() {
return
}
// FIXME: bfloat16 currently uses f64 as a storage format. This test should be
// changed when that gets fixed.
// CHECK-LABEL: @bfloat16_special_values
func @bfloat16_special_values() {
// bfloat16 signaling NaNs.
// CHECK: constant 0x7FF0000000000001 : bf16
%0 = constant 0x7FF0000000000001 : bf16
// CHECK: constant 0x7FF8000000000000 : bf16
%1 = constant 0x7FF8000000000000 : bf16
// bfloat16 quiet NaNs.
// CHECK: constant 0x7FF0000001000000 : bf16
%2 = constant 0x7FF0000001000000 : bf16
// CHECK: constant 0xFFF0000001000000 : bf16
%3 = constant 0xFFF0000001000000 : bf16
// bfloat16 positive infinity.
// CHECK: constant 0x7FF0000000000000 : bf16
%4 = constant 0x7FF0000000000000 : bf16
// bfloat16 negative infinity.
// CHECK: constant 0xFFF0000000000000 : bf16
%5 = constant 0xFFF0000000000000 : bf16
return
}
// We want to print floats in exponential notation with 6 significant digits,
// but it may lead to precision loss when parsing back, in which case we print
// the decimal form instead.