mirror of
https://github.com/intel/llvm.git
synced 2026-01-16 21:55:39 +08:00
[ROCDL] added math instructions to the ROCDL dialect (#169672)
Exposed llvm amdgcn math intrinsic calls through ROCDL
This commit is contained in:
committed by
GitHub
parent
c6e23ab807
commit
7305ed7e15
@@ -1913,6 +1913,33 @@ def ROCDL_FMed3Op : ROCDL_IntrOp<"fmed3", [0], [], [Pure, AllTypesMatch<["res",
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Math operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Shared base class for the one-operand ROCDL math ops defined from it.
// Each op takes a single LLVM_AnyFloat operand ($arg), produces a single
// LLVM_AnyFloat result ($res), and prints as
//   rocdl.<mnemonic> $arg <arg-type> attr-dict -> <res-type>
// `traits` defaults to [Pure]. The meaning of the remaining ROCDL_IntrOp
// template arguments ([0], [], 1) is defined by ROCDL_IntrOp, which is not
// visible here -- presumably overloaded results, overloaded operands and the
// result count; confirm against its definition.
// NOTE(review): nothing visible in this class constrains $res to have the same
// type as $arg (the default trait list is only [Pure], and the assembly format
// spells both types independently). The neighbouring ROCDL_FMed3Op uses
// AllTypesMatch; confirm whether a same-type constraint is wanted here too.
class ROCDL_Math_IntrOp<string mnemonic, list<Trait> traits = [Pure]> :
|
||||
ROCDL_IntrOp<mnemonic, [0], [], traits, 1>,
|
||||
Arguments<(ins LLVM_AnyFloat:$arg)> {
|
||||
let results = (outs LLVM_AnyFloat:$res);
|
||||
let description = [{
|
||||
Note: In the general case, prefer the conventional `arith`, `math`, or `llvm` ops over this.
|
||||
Use this ROCDL-specific operation only when you fully understand its implication and
|
||||
when it is strictly necessary. This op is usually chosen when a small loss in precision is
|
||||
acceptable in exchange for higher execution speed.
|
||||
}];
|
||||
let assemblyFormat =
|
||||
"$arg qualified(type($arg)) attr-dict `->` qualified(type($res))";
|
||||
}
|
||||
|
||||
// Concrete math ops. Each def instantiates ROCDL_Math_IntrOp with the op
// mnemonic; the translation test in this change shows the mapping
// rocdl.<mnemonic> -> llvm.amdgcn.<mnemonic>.<type> for the ops it covers.
def ROCDLTanh : ROCDL_Math_IntrOp<"tanh">;
|
||||
def ROCDLSin : ROCDL_Math_IntrOp<"sin">;
|
||||
def ROCDLCos : ROCDL_Math_IntrOp<"cos">;
|
||||
def ROCDLRcp : ROCDL_Math_IntrOp<"rcp">;
|
||||
// NOTE(review): by the mapping above this op would bind to `llvm.amdgcn.exp`,
// which I believe is the AMDGPU *export* intrinsic rather than an exponential
// (the exponential appears to be `llvm.amdgcn.exp2`, covered by ROCDLExp2).
// Confirm against IntrinsicsAMDGPU.td. No test in this change exercises
// `rocdl.exp`, so a mismatch would not be caught.
def ROCDLExp : ROCDL_Math_IntrOp<"exp">;
|
||||
def ROCDLExp2 : ROCDL_Math_IntrOp<"exp2">;
|
||||
def ROCDLLog : ROCDL_Math_IntrOp<"log">;
|
||||
def ROCDLSqrt : ROCDL_Math_IntrOp<"sqrt">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ROCDL target attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -49,6 +49,59 @@ func.func @rocdl.fmed3.vector(%a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4
|
||||
llvm.return %0 : vector<4xf16>
|
||||
}
|
||||
|
||||
// Dialect-level test for the ROCDL math ops: each op is written once per
// operand type (f32, f16, bf16) and the CHECK lines expect the same textual
// form back. Presumably a parse/print round-trip; the RUN line is not visible
// in this hunk.
// Ops covered: tanh, sin, cos, rcp, exp2, log, sqrt.
// NOTE(review): `rocdl.exp` is defined alongside these ops but is not
// exercised here.
func.func @rocdl.math.ops(%a: f32, %b: f16, %c: bf16) {
|
||||
// CHECK-LABEL: rocdl.math.ops
|
||||
// CHECK: %{{.*}} = rocdl.tanh %{{.*}} f32 -> f32
|
||||
// CHECK: %{{.*}} = rocdl.tanh %{{.*}} f16 -> f16
|
||||
// CHECK: %{{.*}} = rocdl.tanh %{{.*}} bf16 -> bf16
|
||||
%tanh0 = rocdl.tanh %a f32 -> f32
|
||||
%tanh1 = rocdl.tanh %b f16 -> f16
|
||||
%tanh2 = rocdl.tanh %c bf16 -> bf16
|
||||
|
||||
// CHECK: %{{.*}} = rocdl.sin %{{.*}} f32 -> f32
|
||||
// CHECK: %{{.*}} = rocdl.sin %{{.*}} f16 -> f16
|
||||
// CHECK: %{{.*}} = rocdl.sin %{{.*}} bf16 -> bf16
|
||||
%sin0 = rocdl.sin %a f32 -> f32
|
||||
%sin1 = rocdl.sin %b f16 -> f16
|
||||
%sin2 = rocdl.sin %c bf16 -> bf16
|
||||
|
||||
// CHECK: %{{.*}} = rocdl.cos %{{.*}} f32 -> f32
|
||||
// CHECK: %{{.*}} = rocdl.cos %{{.*}} f16 -> f16
|
||||
// CHECK: %{{.*}} = rocdl.cos %{{.*}} bf16 -> bf16
|
||||
%cos0 = rocdl.cos %a f32 -> f32
|
||||
%cos1 = rocdl.cos %b f16 -> f16
|
||||
%cos2 = rocdl.cos %c bf16 -> bf16
|
||||
|
||||
// CHECK: %{{.*}} = rocdl.rcp %{{.*}} f32 -> f32
|
||||
// CHECK: %{{.*}} = rocdl.rcp %{{.*}} f16 -> f16
|
||||
// CHECK: %{{.*}} = rocdl.rcp %{{.*}} bf16 -> bf16
|
||||
%rcp0 = rocdl.rcp %a f32 -> f32
|
||||
%rcp1 = rocdl.rcp %b f16 -> f16
|
||||
%rcp2 = rocdl.rcp %c bf16 -> bf16
|
||||
|
||||
// CHECK: %{{.*}} = rocdl.exp2 %{{.*}} f32 -> f32
|
||||
// CHECK: %{{.*}} = rocdl.exp2 %{{.*}} f16 -> f16
|
||||
// CHECK: %{{.*}} = rocdl.exp2 %{{.*}} bf16 -> bf16
|
||||
%exp2_0 = rocdl.exp2 %a f32 -> f32
|
||||
%exp2_1 = rocdl.exp2 %b f16 -> f16
|
||||
%exp2_2 = rocdl.exp2 %c bf16 -> bf16
|
||||
|
||||
// CHECK: %{{.*}} = rocdl.log %{{.*}} f32 -> f32
|
||||
// CHECK: %{{.*}} = rocdl.log %{{.*}} f16 -> f16
|
||||
// CHECK: %{{.*}} = rocdl.log %{{.*}} bf16 -> bf16
|
||||
%log0 = rocdl.log %a f32 -> f32
|
||||
%log1 = rocdl.log %b f16 -> f16
|
||||
%log2 = rocdl.log %c bf16 -> bf16
|
||||
|
||||
// CHECK: %{{.*}} = rocdl.sqrt %{{.*}} f32 -> f32
|
||||
// CHECK: %{{.*}} = rocdl.sqrt %{{.*}} f16 -> f16
|
||||
// CHECK: %{{.*}} = rocdl.sqrt %{{.*}} bf16 -> bf16
|
||||
%sqrt0 = rocdl.sqrt %a f32 -> f32
|
||||
%sqrt1 = rocdl.sqrt %b f16 -> f16
|
||||
%sqrt2 = rocdl.sqrt %c bf16 -> bf16
|
||||
llvm.return
|
||||
}
|
||||
|
||||
func.func @rocdl.barrier() {
|
||||
// CHECK: rocdl.barrier
|
||||
rocdl.barrier
|
||||
|
||||
@@ -61,6 +61,59 @@ llvm.func @kernel_func_workgroups()
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// Translation test for the ROCDL math ops: for each op and each operand type
// (f32, f16, bf16) the CHECK lines expect a call to the matching
// `llvm.amdgcn.<op>.<type>` intrinsic in the emitted LLVM IR.
// Ops covered: tanh, sin, cos, rcp, exp2, log, sqrt.
// NOTE(review): `rocdl.exp` is defined alongside these ops but is not
// exercised here, so the intrinsic it lowers to is unverified by this test.
llvm.func @kernel_math_ops(%a: f32, %b: f16, %c: bf16) {
|
||||
// CHECK-LABEL: kernel_math_ops
|
||||
// CHECK: call float @llvm.amdgcn.tanh.f32(float %{{.*}})
|
||||
// CHECK: call half @llvm.amdgcn.tanh.f16(half %{{.*}})
|
||||
// CHECK: call bfloat @llvm.amdgcn.tanh.bf16(bfloat %{{.*}})
|
||||
%tanh0 = rocdl.tanh %a f32 -> f32
|
||||
%tanh1 = rocdl.tanh %b f16 -> f16
|
||||
%tanh2 = rocdl.tanh %c bf16 -> bf16
|
||||
|
||||
// CHECK: call float @llvm.amdgcn.sin.f32(float %{{.*}})
|
||||
// CHECK: call half @llvm.amdgcn.sin.f16(half %{{.*}})
|
||||
// CHECK: call bfloat @llvm.amdgcn.sin.bf16(bfloat %{{.*}})
|
||||
%sin0 = rocdl.sin %a f32 -> f32
|
||||
%sin1 = rocdl.sin %b f16 -> f16
|
||||
%sin2 = rocdl.sin %c bf16 -> bf16
|
||||
|
||||
// CHECK: call float @llvm.amdgcn.cos.f32(float %{{.*}})
|
||||
// CHECK: call half @llvm.amdgcn.cos.f16(half %{{.*}})
|
||||
// CHECK: call bfloat @llvm.amdgcn.cos.bf16(bfloat %{{.*}})
|
||||
%cos0 = rocdl.cos %a f32 -> f32
|
||||
%cos1 = rocdl.cos %b f16 -> f16
|
||||
%cos2 = rocdl.cos %c bf16 -> bf16
|
||||
|
||||
// CHECK: call float @llvm.amdgcn.rcp.f32(float %{{.*}})
|
||||
// CHECK: call half @llvm.amdgcn.rcp.f16(half %{{.*}})
|
||||
// CHECK: call bfloat @llvm.amdgcn.rcp.bf16(bfloat %{{.*}})
|
||||
%rcp0 = rocdl.rcp %a f32 -> f32
|
||||
%rcp1 = rocdl.rcp %b f16 -> f16
|
||||
%rcp2 = rocdl.rcp %c bf16 -> bf16
|
||||
|
||||
// CHECK: call float @llvm.amdgcn.exp2.f32(float %{{.*}})
|
||||
// CHECK: call half @llvm.amdgcn.exp2.f16(half %{{.*}})
|
||||
// CHECK: call bfloat @llvm.amdgcn.exp2.bf16(bfloat %{{.*}})
|
||||
%exp2_0 = rocdl.exp2 %a f32 -> f32
|
||||
%exp2_1 = rocdl.exp2 %b f16 -> f16
|
||||
%exp2_2 = rocdl.exp2 %c bf16 -> bf16
|
||||
|
||||
// CHECK: call float @llvm.amdgcn.log.f32(float %{{.*}})
|
||||
// CHECK: call half @llvm.amdgcn.log.f16(half %{{.*}})
|
||||
// CHECK: call bfloat @llvm.amdgcn.log.bf16(bfloat %{{.*}})
|
||||
%log0 = rocdl.log %a f32 -> f32
|
||||
%log1 = rocdl.log %b f16 -> f16
|
||||
%log2 = rocdl.log %c bf16 -> bf16
|
||||
|
||||
// CHECK: call float @llvm.amdgcn.sqrt.f32(float %{{.*}})
|
||||
// CHECK: call half @llvm.amdgcn.sqrt.f16(half %{{.*}})
|
||||
// CHECK: call bfloat @llvm.amdgcn.sqrt.bf16(bfloat %{{.*}})
|
||||
%sqrt0 = rocdl.sqrt %a f32 -> f32
|
||||
%sqrt1 = rocdl.sqrt %b f16 -> f16
|
||||
%sqrt2 = rocdl.sqrt %c bf16 -> bf16
|
||||
llvm.return
|
||||
}
|
||||
|
||||
llvm.func @known_block_sizes()
|
||||
attributes {rocdl.kernel,
|
||||
rocdl.flat_work_group_size = "128,128",
|
||||
|
||||
Reference in New Issue
Block a user