[MLIR][Math] Add erf to math dialect

Add math.erf lowering to libm call.
Add math.erf polynomial approximation.

Reviewed By: silvas, ezhulenev

Differential Revision: https://reviews.llvm.org/D112200
This commit is contained in:
Boian Petkantchin
2021-10-25 18:15:13 +00:00
committed by Sean Silva
parent b283d55c90
commit f1b922188e
9 changed files with 395 additions and 2 deletions

View File

@@ -285,6 +285,39 @@ def Math_SinOp : Math_FloatUnaryOp<"sin"> {
}];
}
//===----------------------------------------------------------------------===//
// ErfOp
//===----------------------------------------------------------------------===//
def Math_ErfOp : Math_FloatUnaryOp<"erf"> {
let summary = "error function of the specified value";
let description = [{
Syntax:
```
operation ::= ssa-id `=` `math.erf` ssa-use `:` type
```
The `erf` operation computes the error function. It takes one operand
and returns one result of the same type. This type may be a float scalar
type, a vector whose element type is float, or a tensor of floats. It has
no standard attributes.
Example:
```mlir
// Scalar error function value.
%a = math.erf %b : f64
// SIMD vector element-wise error function value.
%f = math.erf %g : vector<4xf32>
// Tensor element-wise error function value.
%x = math.erf %y : tensor<4x?xf8>
```
}];
}
//===----------------------------------------------------------------------===//
// ExpOp

View File

@@ -0,0 +1,29 @@
//===- Approximation.h - Math dialect -----------------------------*- C++-*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_MATH_TRANSFORMATIONS_APPROXIMATION_H_
#define MLIR_DIALECT_MATH_TRANSFORMATIONS_APPROXIMATION_H_
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
namespace math {
struct ErfPolynomialApproximation : public OpRewritePattern<math::ErfOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(math::ErfOp op,
PatternRewriter &rewriter) const final;
};
} // namespace math
} // namespace mlir
#endif // MLIR_DIALECT_MATH_TRANSFORMATIONS_APPROXIMATION_H_

View File

@@ -116,6 +116,8 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit);
patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
"atan2f", "atan2", benefit);
patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff",
"erf", benefit);
patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(),
"expm1f", "expm1", benefit);
patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",

View File

@@ -13,6 +13,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Approximation.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
@@ -21,9 +22,12 @@
#include "mlir/Transforms/Bufferize.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ArrayRef.h"
#include <climits>
#include <cstddef>
using namespace mlir;
using namespace mlir::math;
using namespace mlir::vector;
using TypePredicate = llvm::function_ref<bool(Type)>;
@@ -183,6 +187,24 @@ static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
return exp2ValueF32;
}
namespace {
Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
llvm::ArrayRef<Value> coeffs, Value x) {
auto width = vectorWidth(x.getType(), isF32);
if (coeffs.size() == 0) {
return broadcast(builder, f32Cst(builder, 0.0f), *width);
} else if (coeffs.size() == 1) {
return coeffs[0];
}
Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
coeffs[coeffs.size() - 2]);
for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
res = builder.create<math::FmaOp>(x, res, coeffs[i]);
}
return res;
}
} // namespace
//----------------------------------------------------------------------------//
// TanhOp approximation.
//----------------------------------------------------------------------------//
@@ -465,6 +487,122 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
return success();
}
//----------------------------------------------------------------------------//
// Erf approximation.
//----------------------------------------------------------------------------//
// Approximates erf(x) with
// a - P(x)/Q(x)
// where P and Q are polynomials of degree 4.
// Different coefficients are chosen based on the value of x.
// The approximation error is ~2.5e-07.
// Boost's minimax tool that utilizes the Remez method was used to find the
// coefficients.
LogicalResult
ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
PatternRewriter &rewriter) const {
auto width = vectorWidth(op.operand().getType(), isF32);
if (!width.hasValue())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
return broadcast(builder, value, *width);
};
const int intervalsCount = 3;
const int polyDegree = 4;
Value zero = bcast(f32Cst(builder, 0));
Value one = bcast(f32Cst(builder, 1));
Value pp[intervalsCount][polyDegree + 1];
pp[0][0] = bcast(f32Cst(builder, +0.00000000000000000e+00));
pp[0][1] = bcast(f32Cst(builder, +1.12837916222975858e+00));
pp[0][2] = bcast(f32Cst(builder, -5.23018562988006470e-01));
pp[0][3] = bcast(f32Cst(builder, +2.09741709609267072e-01));
pp[0][4] = bcast(f32Cst(builder, +2.58146801602987875e-02));
pp[1][0] = bcast(f32Cst(builder, +0.00000000000000000e+00));
pp[1][1] = bcast(f32Cst(builder, +1.12750687816789140e+00));
pp[1][2] = bcast(f32Cst(builder, -3.64721408487825775e-01));
pp[1][3] = bcast(f32Cst(builder, +1.18407396425136952e-01));
pp[1][4] = bcast(f32Cst(builder, +3.70645533056476558e-02));
pp[2][0] = bcast(f32Cst(builder, -3.30093071049483172e-03));
pp[2][1] = bcast(f32Cst(builder, +3.51961938357697011e-03));
pp[2][2] = bcast(f32Cst(builder, -1.41373622814988039e-03));
pp[2][3] = bcast(f32Cst(builder, +2.53447094961941348e-04));
pp[2][4] = bcast(f32Cst(builder, -1.71048029455037401e-05));
Value qq[intervalsCount][polyDegree + 1];
qq[0][0] = bcast(f32Cst(builder, +1.000000000000000000e+00));
qq[0][1] = bcast(f32Cst(builder, -4.635138185962547255e-01));
qq[0][2] = bcast(f32Cst(builder, +5.192301327279782447e-01));
qq[0][3] = bcast(f32Cst(builder, -1.318089722204810087e-01));
qq[0][4] = bcast(f32Cst(builder, +7.397964654672315005e-02));
qq[1][0] = bcast(f32Cst(builder, +1.00000000000000000e+00));
qq[1][1] = bcast(f32Cst(builder, -3.27607011824493086e-01));
qq[1][2] = bcast(f32Cst(builder, +4.48369090658821977e-01));
qq[1][3] = bcast(f32Cst(builder, -8.83462621207857930e-02));
qq[1][4] = bcast(f32Cst(builder, +5.72442770283176093e-02));
qq[2][0] = bcast(f32Cst(builder, +1.00000000000000000e+00));
qq[2][1] = bcast(f32Cst(builder, -2.06069165953913769e+00));
qq[2][2] = bcast(f32Cst(builder, +1.62705939945477759e+00));
qq[2][3] = bcast(f32Cst(builder, -5.83389859211130017e-01));
qq[2][4] = bcast(f32Cst(builder, +8.21908939856640930e-02));
Value offsets[intervalsCount];
offsets[0] = bcast(f32Cst(builder, 0));
offsets[1] = bcast(f32Cst(builder, 0));
offsets[2] = bcast(f32Cst(builder, 1));
Value bounds[intervalsCount];
bounds[0] = bcast(f32Cst(builder, 0.8));
bounds[1] = bcast(f32Cst(builder, 2));
bounds[2] = bcast(f32Cst(builder, 3.75));
Value isNegativeArg = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT,
op.operand(), zero);
Value negArg = builder.create<arith::NegFOp>(op.operand());
Value x = builder.create<SelectOp>(isNegativeArg, negArg, op.operand());
Value offset = offsets[0];
Value p[polyDegree + 1];
Value q[polyDegree + 1];
for (int i = 0; i <= polyDegree; ++i) {
p[i] = pp[0][i];
q[i] = qq[0][i];
}
// TODO: maybe use vector stacking to reduce the number of selects.
Value isLessThanBound[intervalsCount];
for (int j = 0; j < intervalsCount - 1; ++j) {
isLessThanBound[j] =
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]);
for (int i = 0; i <= polyDegree; ++i) {
p[i] = builder.create<SelectOp>(isLessThanBound[j], p[i], pp[j + 1][i]);
q[i] = builder.create<SelectOp>(isLessThanBound[j], q[i], qq[j + 1][i]);
}
offset =
builder.create<SelectOp>(isLessThanBound[j], offset, offsets[j + 1]);
}
isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>(
arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
Value pPoly = makePolynomialCalculation(builder, p, x);
Value qPoly = makePolynomialCalculation(builder, q, x);
Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly);
Value formula = builder.create<arith::AddFOp>(offset, rationalPoly);
formula = builder.create<SelectOp>(isLessThanBound[intervalsCount - 1],
formula, one);
// erf is odd function: erf(x) = -erf(-x).
Value negFormula = builder.create<arith::NegFOp>(formula);
Value res = builder.create<SelectOp>(isNegativeArg, negFormula, formula);
rewriter.replaceOp(op, res);
return success();
}
//----------------------------------------------------------------------------//
// Exp approximation.
//----------------------------------------------------------------------------//
@@ -848,8 +986,8 @@ void mlir::populateMathPolynomialApproximationPatterns(
RewritePatternSet &patterns,
const MathPolynomialApproximationOptions &options) {
patterns.add<TanhApproximation, LogApproximation, Log2Approximation,
Log1pApproximation, ExpApproximation, ExpM1Approximation,
SinAndCosApproximation<true, math::SinOp>,
Log1pApproximation, ErfPolynomialApproximation, ExpApproximation,
ExpM1Approximation, SinAndCosApproximation<true, math::SinOp>,
SinAndCosApproximation<false, math::CosOp>>(
patterns.getContext());
if (options.enableAvx2)

View File

@@ -1,5 +1,7 @@
// RUN: mlir-opt %s -convert-math-to-libm -canonicalize | FileCheck %s
// CHECK-DAG: @erf(f64) -> f64
// CHECK-DAG: @erff(f32) -> f32
// CHECK-DAG: @expm1(f64) -> f64
// CHECK-DAG: @expm1f(f32) -> f32
// CHECK-DAG: @atan2(f64, f64) -> f64
@@ -32,6 +34,18 @@ func @atan2_caller(%float: f32, %double: f64) -> (f32, f64) {
return %float_result, %double_result : f32, f64
}
// CHECK-LABEL: func @erf_caller
// CHECK-SAME: %[[FLOAT:.*]]: f32
// CHECK-SAME: %[[DOUBLE:.*]]: f64
func @erf_caller(%float: f32, %double: f64) -> (f32, f64) {
// CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @erff(%[[FLOAT]]) : (f32) -> f32
%float_result = math.erf %float : f32
// CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @erf(%[[DOUBLE]]) : (f64) -> f64
%double_result = math.erf %double : f64
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
return %float_result, %double_result : f32, f64
}
// CHECK-LABEL: func @expm1_caller
// CHECK-SAME: %[[FLOAT:.*]]: f32
// CHECK-SAME: %[[DOUBLE:.*]]: f64

View File

@@ -50,6 +50,18 @@ func @sin(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
return
}
// CHECK-LABEL: func @erf(
// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
func @erf(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
// CHECK: %{{.*}} = math.erf %[[F]] : f32
%0 = math.erf %f : f32
// CHECK: %{{.*}} = math.erf %[[V]] : vector<4xf32>
%1 = math.erf %v : vector<4xf32>
// CHECK: %{{.*}} = math.erf %[[T]] : tensor<4x4x?xf32>
%2 = math.erf %t : tensor<4x4x?xf32>
return
}
// CHECK-LABEL: func @exp(
// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
func @exp(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {

View File

@@ -5,6 +5,95 @@
// Check that all math functions lowered to approximations built from
// standard operations (add, mul, fma, shift, etc...).
// CHECK-LABEL: func @erf_scalar(
// CHECK-SAME: %[[val_arg0:.*]]: f32) -> f32 {
// CHECK-DAG: %[[val_cst:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[val_cst_0:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[val_cst_1:.*]] = arith.constant 1.12837911 : f32
// CHECK-DAG: %[[val_cst_2:.*]] = arith.constant -0.523018539 : f32
// CHECK-DAG: %[[val_cst_3:.*]] = arith.constant 0.209741712 : f32
// CHECK-DAG: %[[val_cst_4:.*]] = arith.constant 0.0258146804 : f32
// CHECK-DAG: %[[val_cst_5:.*]] = arith.constant 1.12750685 : f32
// CHECK-DAG: %[[val_cst_6:.*]] = arith.constant -0.364721417 : f32
// CHECK-DAG: %[[val_cst_7:.*]] = arith.constant 0.118407398 : f32
// CHECK-DAG: %[[val_cst_8:.*]] = arith.constant 0.0370645523 : f32
// CHECK-DAG: %[[val_cst_9:.*]] = arith.constant -0.00330093061 : f32
// CHECK-DAG: %[[val_cst_10:.*]] = arith.constant 0.00351961935 : f32
// CHECK-DAG: %[[val_cst_11:.*]] = arith.constant -0.00141373626 : f32
// CHECK-DAG: %[[val_cst_12:.*]] = arith.constant 2.53447099E-4 : f32
// CHECK-DAG: %[[val_cst_13:.*]] = arith.constant -1.71048032E-5 : f32
// CHECK-DAG: %[[val_cst_14:.*]] = arith.constant -0.463513821 : f32
// CHECK-DAG: %[[val_cst_15:.*]] = arith.constant 0.519230127 : f32
// CHECK-DAG: %[[val_cst_16:.*]] = arith.constant -0.131808966 : f32
// CHECK-DAG: %[[val_cst_17:.*]] = arith.constant 0.0739796459 : f32
// CHECK-DAG: %[[val_cst_18:.*]] = arith.constant -3.276070e-01 : f32
// CHECK-DAG: %[[val_cst_19:.*]] = arith.constant 0.448369086 : f32
// CHECK-DAG: %[[val_cst_20:.*]] = arith.constant -0.0883462652 : f32
// CHECK-DAG: %[[val_cst_21:.*]] = arith.constant 0.0572442785 : f32
// CHECK-DAG: %[[val_cst_22:.*]] = arith.constant -2.0606916 : f32
// CHECK-DAG: %[[val_cst_23:.*]] = arith.constant 1.62705934 : f32
// CHECK-DAG: %[[val_cst_24:.*]] = arith.constant -0.583389878 : f32
// CHECK-DAG: %[[val_cst_25:.*]] = arith.constant 0.0821908935 : f32
// CHECK-DAG: %[[val_cst_26:.*]] = arith.constant 8.000000e-01 : f32
// CHECK-DAG: %[[val_cst_27:.*]] = arith.constant 2.000000e+00 : f32
// CHECK-DAG: %[[val_cst_28:.*]] = arith.constant 3.750000e+00 : f32
// CHECK: %[[val_0:.*]] = arith.cmpf olt, %[[val_arg0]], %[[val_cst]] : f32
// CHECK: %[[val_1:.*]] = arith.negf %[[val_arg0]] : f32
// CHECK: %[[val_2:.*]] = select %[[val_0]], %[[val_1]], %[[val_arg0]] : f32
// CHECK: %[[val_3:.*]] = arith.cmpf olt, %[[val_2]], %[[val_cst_26]] : f32
// CHECK: %[[val_4:.*]] = select %[[val_3]], %[[val_cst_1]], %[[val_cst_5]] : f32
// CHECK: %[[val_5:.*]] = select %[[val_3]], %[[val_cst_14]], %[[val_cst_18]] : f32
// CHECK: %[[val_6:.*]] = select %[[val_3]], %[[val_cst_2]], %[[val_cst_6]] : f32
// CHECK: %[[val_7:.*]] = select %[[val_3]], %[[val_cst_15]], %[[val_cst_19]] : f32
// CHECK: %[[val_8:.*]] = select %[[val_3]], %[[val_cst_3]], %[[val_cst_7]] : f32
// CHECK: %[[val_9:.*]] = select %[[val_3]], %[[val_cst_16]], %[[val_cst_20]] : f32
// CHECK: %[[val_10:.*]] = select %[[val_3]], %[[val_cst_4]], %[[val_cst_8]] : f32
// CHECK: %[[val_11:.*]] = select %[[val_3]], %[[val_cst_17]], %[[val_cst_21]] : f32
// CHECK: %[[val_12:.*]] = arith.cmpf olt, %[[val_2]], %[[val_cst_27]] : f32
// CHECK: %[[val_13:.*]] = select %[[val_12]], %[[val_cst]], %[[val_cst_9]] : f32
// CHECK: %[[val_14:.*]] = select %[[val_12]], %[[val_4]], %[[val_cst_10]] : f32
// CHECK: %[[val_15:.*]] = select %[[val_12]], %[[val_5]], %[[val_cst_22]] : f32
// CHECK: %[[val_16:.*]] = select %[[val_12]], %[[val_6]], %[[val_cst_11]] : f32
// CHECK: %[[val_17:.*]] = select %[[val_12]], %[[val_7]], %[[val_cst_23]] : f32
// CHECK: %[[val_18:.*]] = select %[[val_12]], %[[val_8]], %[[val_cst_12]] : f32
// CHECK: %[[val_19:.*]] = select %[[val_12]], %[[val_9]], %[[val_cst_24]] : f32
// CHECK: %[[val_20:.*]] = select %[[val_12]], %[[val_10]], %[[val_cst_13]] : f32
// CHECK: %[[val_21:.*]] = select %[[val_12]], %[[val_11]], %[[val_cst_25]] : f32
// CHECK: %[[val_22:.*]] = select %[[val_12]], %[[val_cst]], %[[val_cst_0]] : f32
// CHECK: %[[val_23:.*]] = arith.cmpf ult, %[[val_2]], %[[val_cst_28]] : f32
// CHECK: %[[val_24:.*]] = math.fma %[[val_2]], %[[val_20]], %[[val_18]] : f32
// CHECK: %[[val_25:.*]] = math.fma %[[val_2]], %[[val_24]], %[[val_16]] : f32
// CHECK: %[[val_26:.*]] = math.fma %[[val_2]], %[[val_25]], %[[val_14]] : f32
// CHECK: %[[val_27:.*]] = math.fma %[[val_2]], %[[val_26]], %[[val_13]] : f32
// CHECK: %[[val_28:.*]] = math.fma %[[val_2]], %[[val_21]], %[[val_19]] : f32
// CHECK: %[[val_29:.*]] = math.fma %[[val_2]], %[[val_28]], %[[val_17]] : f32
// CHECK: %[[val_30:.*]] = math.fma %[[val_2]], %[[val_29]], %[[val_15]] : f32
// CHECK: %[[val_31:.*]] = math.fma %[[val_2]], %[[val_30]], %[[val_cst_0]] : f32
// CHECK: %[[val_32:.*]] = arith.divf %[[val_27]], %[[val_31]] : f32
// CHECK: %[[val_33:.*]] = arith.addf %[[val_22]], %[[val_32]] : f32
// CHECK: %[[val_34:.*]] = select %[[val_23]], %[[val_33]], %[[val_cst_0]] : f32
// CHECK: %[[val_35:.*]] = arith.negf %[[val_34]] : f32
// CHECK: %[[val_36:.*]] = select %[[val_0]], %[[val_35]], %[[val_34]] : f32
// CHECK: return %[[val_36]] : f32
// CHECK: }
func @erf_scalar(%arg0: f32) -> f32 {
%0 = math.erf %arg0 : f32
return %0 : f32
}
// CHECK-LABEL: func @erf_vector(
// CHECK-SAME: %[[arg0:.*]]: vector<8xf32>) -> vector<8xf32> {
// CHECK: %[[zero:.*]] = arith.constant dense<0.000000e+00> : vector<8xf32>
// CHECK-NOT: erf
// CHECK-COUNT-20: select
// CHECK: %[[res:.*]] = select
// CHECK: return %[[res]] : vector<8xf32>
// CHECK: }
func @erf_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
%0 = math.erf %arg0 : vector<8xf32>
return %0 : vector<8xf32>
}
// CHECK-LABEL: func @exp_scalar(
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0.693147182 : f32

View File

@@ -152,6 +152,78 @@ func @log1p() {
return
}
// -------------------------------------------------------------------------- //
// Erf.
// -------------------------------------------------------------------------- //
func @erf() {
// CHECK: -0.000274406
%val1 = arith.constant -2.431864e-4 : f32
%erfVal1 = math.erf %val1 : f32
vector.print %erfVal1 : f32
// CHECK: 0.742095
%val2 = arith.constant 0.79999 : f32
%erfVal2 = math.erf %val2 : f32
vector.print %erfVal2 : f32
// CHECK: 0.742101
%val3 = arith.constant 0.8 : f32
%erfVal3 = math.erf %val3 : f32
vector.print %erfVal3 : f32
// CHECK: 0.995322
%val4 = arith.constant 1.99999 : f32
%erfVal4 = math.erf %val4 : f32
vector.print %erfVal4 : f32
// CHECK: 0.995322
%val5 = arith.constant 2.0 : f32
%erfVal5 = math.erf %val5 : f32
vector.print %erfVal5 : f32
// CHECK: 1
%val6 = arith.constant 3.74999 : f32
%erfVal6 = math.erf %val6 : f32
vector.print %erfVal6 : f32
// CHECK: 1
%val7 = arith.constant 3.75 : f32
%erfVal7 = math.erf %val7 : f32
vector.print %erfVal7 : f32
// CHECK: -1
%negativeInf = arith.constant 0xff800000 : f32
%erfNegativeInf = math.erf %negativeInf : f32
vector.print %erfNegativeInf : f32
// CHECK: -1, -1, -0.913759, -0.731446
%vecVals1 = arith.constant dense<[-3.4028235e+38, -4.54318, -1.2130899, -7.8234202e-01]> : vector<4xf32>
%erfVecVals1 = math.erf %vecVals1 : vector<4xf32>
vector.print %erfVecVals1 : vector<4xf32>
// CHECK: -1.3264e-38, 0, 1.3264e-38, 0.121319
%vecVals2 = arith.constant dense<[-1.1754944e-38, 0.0, 1.1754944e-38, 1.0793410e-01]> : vector<4xf32>
%erfVecVals2 = math.erf %vecVals2 : vector<4xf32>
vector.print %erfVecVals2 : vector<4xf32>
// CHECK: 0.919477, 0.999069, 1, 1
%vecVals3 = arith.constant dense<[1.23578, 2.34093, 3.82342, 3.4028235e+38]> : vector<4xf32>
%erfVecVals3 = math.erf %vecVals3 : vector<4xf32>
vector.print %erfVecVals3 : vector<4xf32>
// CHECK: 1
%inf = arith.constant 0x7f800000 : f32
%erfInf = math.erf %inf : f32
vector.print %erfInf : f32
// CHECK: nan
%nan = arith.constant 0x7fc00000 : f32
%erfNan = math.erf %nan : f32
vector.print %erfNan : f32
return
}
// -------------------------------------------------------------------------- //
// Exp.
// -------------------------------------------------------------------------- //
@@ -305,6 +377,7 @@ func @main() {
call @log(): () -> ()
call @log2(): () -> ()
call @log1p(): () -> ()
call @erf(): () -> ()
call @exp(): () -> ()
call @expm1(): () -> ()
call @sin(): () -> ()

View File

@@ -43,6 +43,9 @@ syn keyword mlirOps memref_shape_cast mulf muli negf powf prefetch rsqrt sitofp
syn keyword mlirOps splat store select sqrt subf subi subview tanh
syn keyword mlirOps view
" Math ops.
syn match mlirOps /\<math\.erf\>/
" Affine ops.
syn match mlirOps /\<affine\.apply\>/
syn match mlirOps /\<affine\.dma_start\>/