mirror of
https://github.com/intel/llvm.git
synced 2026-01-31 07:27:33 +08:00
[mlir][linalg][python] Add exp and log to the OpDSL.
Introduce the exp and log function in OpDSL. Add the soft plus operator to test the emitted IR in Python and C++. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D105420
This commit is contained in:
@@ -7,6 +7,7 @@ from typing import Dict, Sequence
|
||||
from mlir.ir import *
|
||||
from mlir.dialects import linalg
|
||||
from mlir.dialects import std
|
||||
from mlir.dialects import math
|
||||
# TODO: resolve name collision for Linalg functionality that is injected inside
|
||||
# the _mlir.dialects.linalg directly via pybind.
|
||||
from _mlir.dialects.linalg import fill_builtin_region
|
||||
@@ -293,6 +294,16 @@ class _BodyBuilder:
|
||||
return std.AddIOp(lhs.type, lhs, rhs).result
|
||||
raise NotImplementedError("Unsupported 'add' operand: {lhs}")
|
||||
|
||||
def _eval_exp(self, x: Value) -> Value:
|
||||
if _is_floating_point_type(x.type):
|
||||
return math.ExpOp(x.type, x).result
|
||||
raise NotImplementedError("Unsupported 'exp' operand: {x}")
|
||||
|
||||
def _eval_log(self, x: Value) -> Value:
|
||||
if _is_floating_point_type(x.type):
|
||||
return math.LogOp(x.type, x).result
|
||||
raise NotImplementedError("Unsupported 'log' operand: {x}")
|
||||
|
||||
def _eval_sub(self, lhs: Value, rhs: Value) -> Value:
|
||||
if _is_floating_point_type(lhs.type):
|
||||
return std.SubFOp(lhs.type, lhs, rhs).result
|
||||
|
||||
@@ -209,3 +209,16 @@ def fill_rng_2d(
|
||||
offset = cast(F64, const(2147483647))
|
||||
scaling = (max - min) * inv_range
|
||||
O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) * scaling + min)
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
def soft_plus_2d(
|
||||
I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)):
|
||||
"""Implements the soft plus operator.
|
||||
|
||||
Numeric casting is performed on the input operand, promoting it to the same
|
||||
data type as the accumulator/output.
|
||||
"""
|
||||
domain(D.m, D.n)
|
||||
O[D.m, D.n] = \
|
||||
PrimFn.log(cast(U, const(1.0)) + PrimFn.exp(cast(U, I[D.m, D.n])))
|
||||
|
||||
Reference in New Issue
Block a user