[HLSL][DXIL] implement sqrt intrinsic (#86560)

completes #86187
- fix hlsl_intrinsic to cover the correct cases
- move to using `__builtin_elementwise_sqrt`
- add lowering of `Intrinsic::sqrt` to dxilop 24.
This commit is contained in:
Farzon Lotfi
2024-03-25 18:02:30 -04:00
committed by GitHub
parent 060df78cdb
commit 4cea2d049f
5 changed files with 104 additions and 34 deletions

View File

@@ -1366,14 +1366,26 @@ float4 sin(float4);
/// \param Val The input value.
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_sqrtf16)
half sqrt(half In);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
half sqrt(half);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
half2 sqrt(half2);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
half3 sqrt(half3);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
half4 sqrt(half4);
_HLSL_BUILTIN_ALIAS(__builtin_sqrtf)
float sqrt(float In);
_HLSL_BUILTIN_ALIAS(__builtin_sqrt)
double sqrt(double In);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
float sqrt(float);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
float2 sqrt(float2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
float3 sqrt(float3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
float4 sqrt(float4);
//===----------------------------------------------------------------------===//
// trunc builtins

View File

@@ -1,29 +1,53 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.2-library %s -fnative-half-type \
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
// RUN: --check-prefixes=CHECK,NATIVE_HALF
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
using hlsl::sqrt;
// NATIVE_HALF: define noundef half @
// NATIVE_HALF: %{{.*}} = call half @llvm.sqrt.f16(
// NATIVE_HALF: ret half %{{.*}}
// NO_HALF: define noundef float @"?test_sqrt_half@@YA$halff@$halff@@Z"(
// NO_HALF: %{{.*}} = call float @llvm.sqrt.f32(
// NO_HALF: ret float %{{.*}}
half test_sqrt_half(half p0) { return sqrt(p0); }
// NATIVE_HALF: define noundef <2 x half> @
// NATIVE_HALF: %{{.*}} = call <2 x half> @llvm.sqrt.v2f16
// NATIVE_HALF: ret <2 x half> %{{.*}}
// NO_HALF: define noundef <2 x float> @
// NO_HALF: %{{.*}} = call <2 x float> @llvm.sqrt.v2f32(
// NO_HALF: ret <2 x float> %{{.*}}
half2 test_sqrt_half2(half2 p0) { return sqrt(p0); }
// NATIVE_HALF: define noundef <3 x half> @
// NATIVE_HALF: %{{.*}} = call <3 x half> @llvm.sqrt.v3f16
// NATIVE_HALF: ret <3 x half> %{{.*}}
// NO_HALF: define noundef <3 x float> @
// NO_HALF: %{{.*}} = call <3 x float> @llvm.sqrt.v3f32(
// NO_HALF: ret <3 x float> %{{.*}}
half3 test_sqrt_half3(half3 p0) { return sqrt(p0); }
// NATIVE_HALF: define noundef <4 x half> @
// NATIVE_HALF: %{{.*}} = call <4 x half> @llvm.sqrt.v4f16
// NATIVE_HALF: ret <4 x half> %{{.*}}
// NO_HALF: define noundef <4 x float> @
// NO_HALF: %{{.*}} = call <4 x float> @llvm.sqrt.v4f32(
// NO_HALF: ret <4 x float> %{{.*}}
half4 test_sqrt_half4(half4 p0) { return sqrt(p0); }
double sqrt_d(double x)
{
return sqrt(x);
}
// CHECK: define noundef double @"?sqrt_d@@YANN@Z"(
// CHECK: call double @llvm.sqrt.f64(double %0)
float sqrt_f(float x)
{
return sqrt(x);
}
// CHECK: define noundef float @"?sqrt_f@@YAMM@Z"(
// CHECK: call float @llvm.sqrt.f32(float %0)
half sqrt_h(half x)
{
return sqrt(x);
}
// CHECK: define noundef half @"?sqrt_h@@YA$f16@$f16@@Z"(
// CHECK: call half @llvm.sqrt.f16(half %0)
// CHECK: define noundef float @
// CHECK: %{{.*}} = call float @llvm.sqrt.f32(
// CHECK: ret float %{{.*}}
float test_sqrt_float(float p0) { return sqrt(p0); }
// CHECK: define noundef <2 x float> @
// CHECK: %{{.*}} = call <2 x float> @llvm.sqrt.v2f32
// CHECK: ret <2 x float> %{{.*}}
float2 test_sqrt_float2(float2 p0) { return sqrt(p0); }
// CHECK: define noundef <3 x float> @
// CHECK: %{{.*}} = call <3 x float> @llvm.sqrt.v3f32
// CHECK: ret <3 x float> %{{.*}}
float3 test_sqrt_float3(float3 p0) { return sqrt(p0); }
// CHECK: define noundef <4 x float> @
// CHECK: %{{.*}} = call <4 x float> @llvm.sqrt.v4f32
// CHECK: ret <4 x float> %{{.*}}
float4 test_sqrt_float4(float4 p0) { return sqrt(p0); }

View File

@@ -274,6 +274,10 @@ def Frac : DXILOpMapping<22, unary, int_dx_frac,
"Returns a fraction from 0 to 1 that represents the "
"decimal part of the input.",
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
// DXIL opcode 24 (unary op class): lowering target for the generic llvm.sqrt
// intrinsic. The overload type is restricted to half or float; the result
// type matches the operand type.
def Sqrt : DXILOpMapping<24, unary, int_sqrt,
"Returns the square root of the specified floating-point "
"value, per component.",
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
def RSqrt : DXILOpMapping<25, unary, int_dx_rsqrt,
"Returns the reciprocal of the square root of the specified value."
"rsqrt(x) = 1 / sqrt(x).",

View File

@@ -0,0 +1,20 @@
; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
; Make sure dxil operation function calls for sqrt are generated for float and half.
; Scalar f32 case: after -dxil-op-lower the llvm.sqrt.f32 call below is expected
; to appear as a call to @dx.op.unary.f32 whose first operand is the opcode 24.
; NOTE(review): attribute group #0 is referenced here but no `attributes #0`
; definition is visible in this file - confirm that is intentional.
define noundef float @sqrt_float(float noundef %a) #0 {
entry:
; CHECK:call float @dx.op.unary.f32(i32 24, float %{{.*}})
%elt.sqrt = call float @llvm.sqrt.f32(float %a)
ret float %elt.sqrt
}
; Scalar f16 case: after -dxil-op-lower the llvm.sqrt.f16 call below is expected
; to appear as a call to @dx.op.unary.f16 whose first operand is the opcode 24.
; NOTE(review): attribute group #0 is referenced here but no `attributes #0`
; definition is visible in this file - confirm that is intentional.
define noundef half @sqrt_half(half noundef %a) #0 {
entry:
; CHECK:call half @dx.op.unary.f16(i32 24, half %{{.*}})
%elt.sqrt = call half @llvm.sqrt.f16(half %a)
ret half %elt.sqrt
}
declare half @llvm.sqrt.f16(half)
declare float @llvm.sqrt.f32(float)

View File

@@ -0,0 +1,10 @@
; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
; DXIL operation sqrt does not support double overload type
; CHECK: LLVM ERROR: Invalid Overload Type
; Negative case: the body calls llvm.sqrt.f64. The RUN line wraps opt in `not`,
; so the pass is expected to fail on this function and print the diagnostic
; matched by the check directive at the top of the file.
define noundef double @sqrt_double(double noundef %a) {
entry:
%elt.sqrt = call double @llvm.sqrt.f64(double %a)
ret double %elt.sqrt
}