Add a math.cbrt instruction and lowering to libm.

There's currently no way to get accurate cube roots in the math dialect.
powf(x, 1/3.0) is too inaccurate in some cases.

Reviewed By: akuegel

Differential Revision: https://reviews.llvm.org/D140842
This commit is contained in:
Johannes Reifferscheid
2023-01-02 15:23:12 +01:00
parent 367e618fd6
commit 998a3a3894
4 changed files with 50 additions and 0 deletions

View File

@@ -196,6 +196,28 @@ def Math_Atan2Op : Math_FloatBinaryOp<"atan2">{
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// CbrtOp
//===----------------------------------------------------------------------===//
def Math_CbrtOp : Math_FloatUnaryOp<"cbrt"> {
let summary = "cube root of the specified value";
let description = [{
The `cbrt` operation computes the cube root. It takes one operand of
floating point type (i.e., scalar, tensor or vector) and returns one result
of the same type. It has no standard attributes.
Example:
```mlir
// Scalar cube root value.
%a = math.cbrt %b : f64
```
Note: This op is not equivalent to powf(..., 1/3.0).
}];
}
//===----------------------------------------------------------------------===//
// CeilOp
//===----------------------------------------------------------------------===//

View File

@@ -171,6 +171,8 @@ void mlir::populateMathToLibmConversionPatterns(
"atan", benefit);
patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
"atan2f", "atan2", benefit);
patterns.add<ScalarOpToLibmCall<math::CbrtOp>>(patterns.getContext(), "cbrtf",
"cbrt", benefit);
patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff",
"erf", benefit);
patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(),

View File

@@ -8,6 +8,8 @@
// CHECK-DAG: @expm1f(f32) -> f32 attributes {llvm.readnone}
// CHECK-DAG: @atan2(f64, f64) -> f64 attributes {llvm.readnone}
// CHECK-DAG: @atan2f(f32, f32) -> f32 attributes {llvm.readnone}
// CHECK-DAG: @cbrt(f64) -> f64 attributes {llvm.readnone}
// CHECK-DAG: @cbrtf(f32) -> f32 attributes {llvm.readnone}
// CHECK-DAG: @tan(f64) -> f64 attributes {llvm.readnone}
// CHECK-DAG: @tanf(f32) -> f32 attributes {llvm.readnone}
// CHECK-DAG: @tanh(f64) -> f64 attributes {llvm.readnone}
@@ -241,6 +243,18 @@ func.func @trunc_caller(%float: f32, %double: f64) -> (f32, f64) {
return %float_result, %double_result : f32, f64
}
// CHECK-LABEL: func @cbrt_caller
// CHECK-SAME: %[[FLOAT:.*]]: f32
// CHECK-SAME: %[[DOUBLE:.*]]: f64
func.func @cbrt_caller(%float: f32, %double: f64) -> (f32, f64) {
// CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @cbrtf(%[[FLOAT]]) : (f32) -> f32
%float_result = math.cbrt %float : f32
// CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @cbrt(%[[DOUBLE]]) : (f64) -> f64
%double_result = math.cbrt %double : f64
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
return %float_result, %double_result : f32, f64
}
// CHECK-LABEL: func @cos_caller
// CHECK-SAME: %[[FLOAT:.*]]: f32
// CHECK-SAME: %[[DOUBLE:.*]]: f64

View File

@@ -26,6 +26,18 @@ func.func @atan2(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
return
}
// CHECK-LABEL: func @cbrt(
// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
func.func @cbrt(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
// CHECK: %{{.*}} = math.cbrt %[[F]] : f32
%0 = math.cbrt %f : f32
// CHECK: %{{.*}} = math.cbrt %[[V]] : vector<4xf32>
%1 = math.cbrt %v : vector<4xf32>
// CHECK: %{{.*}} = math.cbrt %[[T]] : tensor<4x4x?xf32>
%2 = math.cbrt %t : tensor<4x4x?xf32>
return
}
// CHECK-LABEL: func @cos(
// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
func.func @cos(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {