mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 21:53:12 +08:00
[HLSL][SPIRV][DXIL] Implement dot4add_i8packed intrinsic (#113623)
- create a clang built-in in Builtins.td
- link dot4add_i8packed in hlsl_intrinsics.h
- add lowering to spirv backend through expansion of operation as OPSDot
is missing up to SPIRV 1.6 in SPIRVInstructionSelector.cpp
- add lowering to spirv backend using OpSDot in applicable SPIRV version
or if SPV_KHR_integer_dot_product is enabled
- add dot4add_i8packed intrinsic to IntrinsicsDirectX.td and mapping to
DXIL.td op Dot4AddI8Packed
- add tests for HLSL intrinsic lowering to dx/spv intrinsic in
dot4add_i8packed.hlsl
- add tests for sema checks in dot4add_i8packed-errors.hlsl
- add test of spir-v lowering in SPIRV/dot4add_i8packed.ll
- add test to dxil lowering in DirectX/dot4add_i8packed.ll
Resolves #99220
This commit is contained in:
@@ -4792,6 +4792,12 @@ def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
|
||||
let Prototype = "void(...)";
|
||||
}
|
||||
|
||||
def HLSLDot4AddI8Packed : LangBuiltin<"HLSL_LANG"> {
|
||||
let Spellings = ["__builtin_hlsl_dot4add_i8packed"];
|
||||
let Attributes = [NoThrow, Const];
|
||||
let Prototype = "int(unsigned int, unsigned int, int)";
|
||||
}
|
||||
|
||||
def HLSLFrac : LangBuiltin<"HLSL_LANG"> {
|
||||
let Spellings = ["__builtin_hlsl_elementwise_frac"];
|
||||
let Attributes = [NoThrow, Const];
|
||||
|
||||
@@ -18855,7 +18855,17 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
|
||||
/*ReturnType=*/T0->getScalarType(),
|
||||
getDotProductIntrinsic(CGM.getHLSLRuntime(), VecTy0->getElementType()),
|
||||
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");
|
||||
} break;
|
||||
}
|
||||
case Builtin::BI__builtin_hlsl_dot4add_i8packed: {
|
||||
Value *A = EmitScalarExpr(E->getArg(0));
|
||||
Value *B = EmitScalarExpr(E->getArg(1));
|
||||
Value *C = EmitScalarExpr(E->getArg(2));
|
||||
|
||||
Intrinsic::ID ID = CGM.getHLSLRuntime().getDot4AddI8PackedIntrinsic();
|
||||
return Builder.CreateIntrinsic(
|
||||
/*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
|
||||
"hlsl.dot4add.i8packed");
|
||||
}
|
||||
case Builtin::BI__builtin_hlsl_lerp: {
|
||||
Value *X = EmitScalarExpr(E->getArg(0));
|
||||
Value *Y = EmitScalarExpr(E->getArg(1));
|
||||
|
||||
@@ -89,6 +89,7 @@ public:
|
||||
GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
|
||||
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
|
||||
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
|
||||
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed)
|
||||
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
|
||||
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
|
||||
|
||||
|
||||
@@ -934,6 +934,16 @@ uint64_t dot(uint64_t3, uint64_t3);
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
|
||||
uint64_t dot(uint64_t4, uint64_t4);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// dot4add builtins
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// \fn int dot4add_i8packed(uint A, uint B, int C)
|
||||
|
||||
_HLSL_AVAILABILITY(shadermodel, 6.4)
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot4add_i8packed)
|
||||
int dot4add_i8packed(unsigned int, unsigned int, int);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// exp builtins
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
17
clang/test/CodeGenHLSL/builtins/dot4add_i8packed.hlsl
Normal file
17
clang/test/CodeGenHLSL/builtins/dot4add_i8packed.hlsl
Normal file
@@ -0,0 +1,17 @@
|
||||
// RUN: %clang_cc1 -finclude-default-header -triple \
|
||||
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
|
||||
// RUN: FileCheck %s -DTARGET=dx
|
||||
// RUN: %clang_cc1 -finclude-default-header -triple \
|
||||
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
|
||||
// RUN: FileCheck %s -DTARGET=spv
|
||||
|
||||
// Test basic lowering to runtime function call.
|
||||
|
||||
// CHECK-LABEL: test
|
||||
int test(uint a, uint b, int c) {
|
||||
// CHECK: %[[RET:.*]] = call [[TY:i32]] @llvm.[[TARGET]].dot4add.i8packed([[TY]] %[[#]], [[TY]] %[[#]], [[TY]] %[[#]])
|
||||
// CHECK: ret [[TY]] %[[RET]]
|
||||
return dot4add_i8packed(a, b, c);
|
||||
}
|
||||
|
||||
// CHECK: declare [[TY]] @llvm.[[TARGET]].dot4add.i8packed([[TY]], [[TY]], [[TY]])
|
||||
28
clang/test/SemaHLSL/BuiltIns/dot4add_i8packed-errors.hlsl
Normal file
28
clang/test/SemaHLSL/BuiltIns/dot4add_i8packed-errors.hlsl
Normal file
@@ -0,0 +1,28 @@
|
||||
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
|
||||
|
||||
int test_too_few_arg0() {
|
||||
return __builtin_hlsl_dot4add_i8packed();
|
||||
// expected-error@-1 {{too few arguments to function call, expected 3, have 0}}
|
||||
}
|
||||
|
||||
int test_too_few_arg1(int p0) {
|
||||
return __builtin_hlsl_dot4add_i8packed(p0);
|
||||
// expected-error@-1 {{too few arguments to function call, expected 3, have 1}}
|
||||
}
|
||||
|
||||
int test_too_few_arg2(int p0) {
|
||||
return __builtin_hlsl_dot4add_i8packed(p0, p0);
|
||||
// expected-error@-1 {{too few arguments to function call, expected 3, have 2}}
|
||||
}
|
||||
|
||||
int test_too_many_arg(int p0) {
|
||||
return __builtin_hlsl_dot4add_i8packed(p0, p0, p0, p0);
|
||||
// expected-error@-1 {{too many arguments to function call, expected 3, have 4}}
|
||||
}
|
||||
|
||||
struct S { float f; };
|
||||
|
||||
int test_expr_struct_type_check(S p0, int p1) {
|
||||
return __builtin_hlsl_dot4add_i8packed(p0, p1, p1);
|
||||
// expected-error@-1 {{no viable conversion from 'S' to 'unsigned int'}}
|
||||
}
|
||||
@@ -179,6 +179,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
|
||||
- Provides additional information to a compiler, similar to the llvm.assume and llvm.expect intrinsics.
|
||||
* - ``SPV_KHR_float_controls``
|
||||
- Provides new execution modes to control floating-point computations by overriding an implementation’s default behavior for rounding modes, denormals, signed zero, and infinities.
|
||||
* - ``SPV_KHR_integer_dot_product``
|
||||
- Adds instructions for dot product operations on integer vectors with optional accumulation. Integer vectors includes 4-component vector of 8 bit integers and 4-component vectors of 8 bit integers packed into 32-bit integers.
|
||||
* - ``SPV_KHR_linkonce_odr``
|
||||
- Allows to use the LinkOnceODR linkage type that lets a function or global variable to be merged with other functions or global variables of the same name when linkage occurs.
|
||||
* - ``SPV_KHR_no_integer_wrap_decoration``
|
||||
|
||||
@@ -69,6 +69,7 @@ def int_dx_udot :
|
||||
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
|
||||
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
|
||||
[IntrNoMem, Commutative] >;
|
||||
def int_dx_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
|
||||
|
||||
def int_dx_frac : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
|
||||
def int_dx_degrees : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
|
||||
|
||||
@@ -83,6 +83,7 @@ let TargetPrefix = "spv" in {
|
||||
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
|
||||
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
|
||||
[IntrNoMem, Commutative] >;
|
||||
def int_spv_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
|
||||
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
|
||||
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
|
||||
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
|
||||
|
||||
@@ -788,6 +788,16 @@ def SplitDouble : DXILOp<102, splitDouble> {
|
||||
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
|
||||
}
|
||||
|
||||
def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> {
|
||||
let Doc = "signed dot product of 4 x i8 vectors packed into i32, with "
|
||||
"accumulate to i32";
|
||||
let LLVMIntrinsic = int_dx_dot4add_i8packed;
|
||||
let arguments = [Int32Ty, Int32Ty, Int32Ty];
|
||||
let result = Int32Ty;
|
||||
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
|
||||
let stages = [Stages<DXIL1_0, [all_stages]>];
|
||||
}
|
||||
|
||||
def AnnotateHandle : DXILOp<216, annotateHandle> {
|
||||
let Doc = "annotate handle with resource properties";
|
||||
let arguments = [HandleTy, ResPropsTy];
|
||||
|
||||
@@ -56,6 +56,8 @@ static const std::map<std::string, SPIRV::Extension::Extension>
|
||||
SPIRV::Extension::Extension::SPV_KHR_expect_assume},
|
||||
{"SPV_KHR_bit_instructions",
|
||||
SPIRV::Extension::Extension::SPV_KHR_bit_instructions},
|
||||
{"SPV_KHR_integer_dot_product",
|
||||
SPIRV::Extension::Extension::SPV_KHR_integer_dot_product},
|
||||
{"SPV_KHR_linkonce_odr",
|
||||
SPIRV::Extension::Extension::SPV_KHR_linkonce_odr},
|
||||
{"SPV_INTEL_inline_assembly",
|
||||
|
||||
@@ -524,6 +524,9 @@ defm OpISubBorrow: BinOpTypedGen<"OpISubBorrow", 150, subc, 0, 1>;
|
||||
def OpUMulExtended: BinOp<"OpUMulExtended", 151>;
|
||||
def OpSMulExtended: BinOp<"OpSMulExtended", 152>;
|
||||
|
||||
def OpSDot: BinOp<"OpSDot", 4450>;
|
||||
def OpUDot: BinOp<"OpUDot", 4451>;
|
||||
|
||||
// 3.42.14 Bit Instructions
|
||||
|
||||
defm OpShiftRightLogical: BinOpTypedGen<"OpShiftRightLogical", 194, srl, 0, 1>;
|
||||
|
||||
@@ -164,6 +164,13 @@ private:
|
||||
bool selectIntegerDot(Register ResVReg, const SPIRVType *ResType,
|
||||
MachineInstr &I) const;
|
||||
|
||||
template <bool Signed>
|
||||
bool selectDot4AddPacked(Register ResVReg, const SPIRVType *ResType,
|
||||
MachineInstr &I) const;
|
||||
template <bool Signed>
|
||||
bool selectDot4AddPackedExpansion(Register ResVReg, const SPIRVType *ResType,
|
||||
MachineInstr &I) const;
|
||||
|
||||
void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
|
||||
int OpIdx) const;
|
||||
void renderFImm64(MachineInstrBuilder &MIB, const MachineInstr &I,
|
||||
@@ -1646,7 +1653,7 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
|
||||
// Multiply the vectors, then sum the results
|
||||
Register Vec0 = I.getOperand(2).getReg();
|
||||
Register Vec1 = I.getOperand(3).getReg();
|
||||
Register TmpVec = MRI->createVirtualRegister(&SPIRV::IDRegClass);
|
||||
Register TmpVec = MRI->createVirtualRegister(GR.getRegClass(ResType));
|
||||
SPIRVType *VecType = GR.getSPIRVTypeForVReg(Vec0);
|
||||
|
||||
bool Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulV))
|
||||
@@ -1660,8 +1667,8 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
|
||||
GR.getScalarOrVectorComponentCount(VecType) > 1 &&
|
||||
"dot product requires a vector of at least 2 components");
|
||||
|
||||
Register Res = MRI->createVirtualRegister(&SPIRV::IDRegClass);
|
||||
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
|
||||
Register Res = MRI->createVirtualRegister(GR.getRegClass(ResType));
|
||||
Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
|
||||
.addDef(Res)
|
||||
.addUse(GR.getSPIRVTypeID(ResType))
|
||||
.addUse(TmpVec)
|
||||
@@ -1669,9 +1676,9 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
|
||||
.constrainAllUses(TII, TRI, RBI);
|
||||
|
||||
for (unsigned i = 1; i < GR.getScalarOrVectorComponentCount(VecType); i++) {
|
||||
Register Elt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
|
||||
Register Elt = MRI->createVirtualRegister(GR.getRegClass(ResType));
|
||||
|
||||
Result |=
|
||||
Result &=
|
||||
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
|
||||
.addDef(Elt)
|
||||
.addUse(GR.getSPIRVTypeID(ResType))
|
||||
@@ -1680,10 +1687,10 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
|
||||
.constrainAllUses(TII, TRI, RBI);
|
||||
|
||||
Register Sum = i < GR.getScalarOrVectorComponentCount(VecType) - 1
|
||||
? MRI->createVirtualRegister(&SPIRV::IDRegClass)
|
||||
? MRI->createVirtualRegister(GR.getRegClass(ResType))
|
||||
: ResVReg;
|
||||
|
||||
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
|
||||
Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
|
||||
.addDef(Sum)
|
||||
.addUse(GR.getSPIRVTypeID(ResType))
|
||||
.addUse(Res)
|
||||
@@ -1695,6 +1702,112 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
|
||||
return Result;
|
||||
}
|
||||
|
||||
template <bool Signed>
|
||||
bool SPIRVInstructionSelector::selectDot4AddPacked(Register ResVReg,
|
||||
const SPIRVType *ResType,
|
||||
MachineInstr &I) const {
|
||||
assert(I.getNumOperands() == 5);
|
||||
assert(I.getOperand(2).isReg());
|
||||
assert(I.getOperand(3).isReg());
|
||||
assert(I.getOperand(4).isReg());
|
||||
MachineBasicBlock &BB = *I.getParent();
|
||||
|
||||
auto DotOp = Signed ? SPIRV::OpSDot : SPIRV::OpUDot;
|
||||
Register Dot = MRI->createVirtualRegister(GR.getRegClass(ResType));
|
||||
bool Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(DotOp))
|
||||
.addDef(Dot)
|
||||
.addUse(GR.getSPIRVTypeID(ResType))
|
||||
.addUse(I.getOperand(2).getReg())
|
||||
.addUse(I.getOperand(3).getReg())
|
||||
.constrainAllUses(TII, TRI, RBI);
|
||||
|
||||
Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
|
||||
.addDef(ResVReg)
|
||||
.addUse(GR.getSPIRVTypeID(ResType))
|
||||
.addUse(Dot)
|
||||
.addUse(I.getOperand(4).getReg())
|
||||
.constrainAllUses(TII, TRI, RBI);
|
||||
|
||||
return Result;
|
||||
}
|
||||
|
||||
// Since pre-1.6 SPIRV has no DotProductInput4x8BitPacked implementation,
|
||||
// extract the elements of the packed inputs, multiply them and add the result
|
||||
// to the accumulator.
|
||||
template <bool Signed>
|
||||
bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
|
||||
Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
|
||||
assert(I.getNumOperands() == 5);
|
||||
assert(I.getOperand(2).isReg());
|
||||
assert(I.getOperand(3).isReg());
|
||||
assert(I.getOperand(4).isReg());
|
||||
MachineBasicBlock &BB = *I.getParent();
|
||||
|
||||
bool Result = false;
|
||||
|
||||
// Acc = C
|
||||
Register Acc = I.getOperand(4).getReg();
|
||||
SPIRVType *EltType = GR.getOrCreateSPIRVIntegerType(8, I, TII);
|
||||
auto ExtractOp =
|
||||
Signed ? SPIRV::OpBitFieldSExtract : SPIRV::OpBitFieldUExtract;
|
||||
|
||||
// Extract the i8 element, multiply and add it to the accumulator
|
||||
for (unsigned i = 0; i < 4; i++) {
|
||||
// A[i]
|
||||
Register AElt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
|
||||
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
|
||||
.addDef(AElt)
|
||||
.addUse(GR.getSPIRVTypeID(ResType))
|
||||
.addUse(I.getOperand(2).getReg())
|
||||
.addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII))
|
||||
.addUse(GR.getOrCreateConstInt(8, I, EltType, TII))
|
||||
.constrainAllUses(TII, TRI, RBI);
|
||||
|
||||
// B[i]
|
||||
Register BElt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
|
||||
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
|
||||
.addDef(BElt)
|
||||
.addUse(GR.getSPIRVTypeID(ResType))
|
||||
.addUse(I.getOperand(3).getReg())
|
||||
.addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII))
|
||||
.addUse(GR.getOrCreateConstInt(8, I, EltType, TII))
|
||||
.constrainAllUses(TII, TRI, RBI);
|
||||
|
||||
// A[i] * B[i]
|
||||
Register Mul = MRI->createVirtualRegister(&SPIRV::IDRegClass);
|
||||
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulS))
|
||||
.addDef(Mul)
|
||||
.addUse(GR.getSPIRVTypeID(ResType))
|
||||
.addUse(AElt)
|
||||
.addUse(BElt)
|
||||
.constrainAllUses(TII, TRI, RBI);
|
||||
|
||||
// Discard 24 highest-bits so that stored i32 register is i8 equivalent
|
||||
Register MaskMul = MRI->createVirtualRegister(&SPIRV::IDRegClass);
|
||||
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
|
||||
.addDef(MaskMul)
|
||||
.addUse(GR.getSPIRVTypeID(ResType))
|
||||
.addUse(Mul)
|
||||
.addUse(GR.getOrCreateConstInt(0, I, EltType, TII))
|
||||
.addUse(GR.getOrCreateConstInt(8, I, EltType, TII))
|
||||
.constrainAllUses(TII, TRI, RBI);
|
||||
|
||||
// Acc = Acc + A[i] * B[i]
|
||||
Register Sum =
|
||||
i < 3 ? MRI->createVirtualRegister(&SPIRV::IDRegClass) : ResVReg;
|
||||
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
|
||||
.addDef(Sum)
|
||||
.addUse(GR.getSPIRVTypeID(ResType))
|
||||
.addUse(Acc)
|
||||
.addUse(MaskMul)
|
||||
.constrainAllUses(TII, TRI, RBI);
|
||||
|
||||
Acc = Sum;
|
||||
}
|
||||
|
||||
return Result;
|
||||
}
|
||||
|
||||
/// Transform saturate(x) to clamp(x, 0.0f, 1.0f) as SPIRV
|
||||
/// does not have a saturate builtin.
|
||||
bool SPIRVInstructionSelector::selectSaturate(Register ResVReg,
|
||||
@@ -2528,6 +2641,11 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
|
||||
case Intrinsic::spv_udot:
|
||||
case Intrinsic::spv_sdot:
|
||||
return selectIntegerDot(ResVReg, ResType, I);
|
||||
case Intrinsic::spv_dot4add_i8packed:
|
||||
if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) ||
|
||||
STI.isAtLeastSPIRVVer(VersionTuple(1, 6)))
|
||||
return selectDot4AddPacked<true>(ResVReg, ResType, I);
|
||||
return selectDot4AddPackedExpansion<true>(ResVReg, ResType, I);
|
||||
case Intrinsic::spv_all:
|
||||
return selectAll(ResVReg, ResType, I);
|
||||
case Intrinsic::spv_any:
|
||||
|
||||
@@ -626,6 +626,22 @@ void SPIRV::RequirementHandler::removeCapabilityIf(
|
||||
namespace llvm {
|
||||
namespace SPIRV {
|
||||
void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
|
||||
// Provided by both all supported Vulkan versions and OpenCl.
|
||||
addAvailableCaps({Capability::Shader, Capability::Linkage, Capability::Int8,
|
||||
Capability::Int16});
|
||||
|
||||
if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
|
||||
addAvailableCaps({Capability::DotProduct, Capability::DotProductInputAll,
|
||||
Capability::DotProductInput4x8Bit,
|
||||
Capability::DotProductInput4x8BitPacked});
|
||||
|
||||
// Add capabilities enabled by extensions.
|
||||
for (auto Extension : ST.getAllAvailableExtensions()) {
|
||||
CapabilityList EnabledCapabilities =
|
||||
getCapabilitiesEnabledByExtension(Extension);
|
||||
addAvailableCaps(EnabledCapabilities);
|
||||
}
|
||||
|
||||
if (ST.isOpenCLEnv()) {
|
||||
initAvailableCapabilitiesForOpenCL(ST);
|
||||
return;
|
||||
@@ -643,10 +659,8 @@ void RequirementHandler::initAvailableCapabilitiesForOpenCL(
|
||||
const SPIRVSubtarget &ST) {
|
||||
// Add the min requirements for different OpenCL and SPIR-V versions.
|
||||
addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
|
||||
Capability::Int16, Capability::Int8, Capability::Kernel,
|
||||
Capability::Linkage, Capability::Vector16,
|
||||
Capability::Groups, Capability::GenericPointer,
|
||||
Capability::Shader});
|
||||
Capability::Kernel, Capability::Vector16,
|
||||
Capability::Groups, Capability::GenericPointer});
|
||||
if (ST.hasOpenCLFullProfile())
|
||||
addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
|
||||
if (ST.hasOpenCLImageSupport()) {
|
||||
@@ -675,25 +689,16 @@ void RequirementHandler::initAvailableCapabilitiesForOpenCL(
|
||||
// TODO: verify if this needs some checks.
|
||||
addAvailableCaps({Capability::Float16, Capability::Float64});
|
||||
|
||||
// Add capabilities enabled by extensions.
|
||||
for (auto Extension : ST.getAllAvailableExtensions()) {
|
||||
CapabilityList EnabledCapabilities =
|
||||
getCapabilitiesEnabledByExtension(Extension);
|
||||
addAvailableCaps(EnabledCapabilities);
|
||||
}
|
||||
|
||||
// TODO: add OpenCL extensions.
|
||||
}
|
||||
|
||||
void RequirementHandler::initAvailableCapabilitiesForVulkan(
|
||||
const SPIRVSubtarget &ST) {
|
||||
addAvailableCaps({Capability::Shader, Capability::Linkage});
|
||||
|
||||
// Core in Vulkan 1.1 and earlier.
|
||||
addAvailableCaps({Capability::Int16, Capability::Int64, Capability::Float16,
|
||||
Capability::Float64, Capability::GroupNonUniform,
|
||||
Capability::Image1D, Capability::SampledBuffer,
|
||||
Capability::ImageBuffer,
|
||||
addAvailableCaps({Capability::Int64, Capability::Float16, Capability::Float64,
|
||||
Capability::GroupNonUniform, Capability::Image1D,
|
||||
Capability::SampledBuffer, Capability::ImageBuffer,
|
||||
Capability::UniformBufferArrayDynamicIndexing,
|
||||
Capability::SampledImageArrayDynamicIndexing,
|
||||
Capability::StorageBufferArrayDynamicIndexing,
|
||||
@@ -1000,6 +1005,32 @@ void addOpAccessChainReqs(const MachineInstr &Instr,
|
||||
}
|
||||
}
|
||||
|
||||
static void AddDotProductRequirements(const MachineInstr &MI,
|
||||
SPIRV::RequirementHandler &Reqs,
|
||||
const SPIRVSubtarget &ST) {
|
||||
if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product))
|
||||
Reqs.addExtension(SPIRV::Extension::SPV_KHR_integer_dot_product);
|
||||
Reqs.addCapability(SPIRV::Capability::DotProduct);
|
||||
|
||||
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
|
||||
const MachineInstr *InstrPtr = &MI;
|
||||
assert(MI.getOperand(1).isReg() && "Unexpected operand in dot");
|
||||
|
||||
Register TypeReg = InstrPtr->getOperand(1).getReg();
|
||||
SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
|
||||
if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
|
||||
assert(TypeDef->getOperand(1).getImm() == 32);
|
||||
Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPacked);
|
||||
} else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
|
||||
SPIRVType *ScalarTypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
|
||||
assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
|
||||
auto Capability = ScalarTypeDef->getOperand(1).getImm() == 8
|
||||
? SPIRV::Capability::DotProductInput4x8Bit
|
||||
: SPIRV::Capability::DotProductInputAll;
|
||||
Reqs.addCapability(Capability);
|
||||
}
|
||||
}
|
||||
|
||||
void addInstrRequirements(const MachineInstr &MI,
|
||||
SPIRV::RequirementHandler &Reqs,
|
||||
const SPIRVSubtarget &ST) {
|
||||
@@ -1376,6 +1407,10 @@ void addInstrRequirements(const MachineInstr &MI,
|
||||
Reqs.addCapability(SPIRV::Capability::SplitBarrierINTEL);
|
||||
}
|
||||
break;
|
||||
case SPIRV::OpSDot:
|
||||
case SPIRV::OpUDot:
|
||||
AddDotProductRequirements(MI, Reqs, ST);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -129,9 +129,6 @@ bool SPIRVSubtarget::canDirectlyComparePointers() const {
|
||||
|
||||
void SPIRVSubtarget::initAvailableExtensions() {
|
||||
AvailableExtensions.clear();
|
||||
if (!isOpenCLEnv())
|
||||
return;
|
||||
|
||||
AvailableExtensions.insert(Extensions.begin(), Extensions.end());
|
||||
}
|
||||
|
||||
|
||||
@@ -467,6 +467,10 @@ defm ExpectAssumeKHR : CapabilityOperand<5629, 0, 0, [SPV_KHR_expect_assume], []
|
||||
defm FunctionPointersINTEL : CapabilityOperand<5603, 0, 0, [SPV_INTEL_function_pointers], []>;
|
||||
defm IndirectReferencesINTEL : CapabilityOperand<5604, 0, 0, [SPV_INTEL_function_pointers], []>;
|
||||
defm AsmINTEL : CapabilityOperand<5606, 0, 0, [SPV_INTEL_inline_assembly], []>;
|
||||
defm DotProductInputAll : CapabilityOperand<6016, 0x10600, 0, [SPV_KHR_integer_dot_product], []>;
|
||||
defm DotProductInput4x8Bit : CapabilityOperand<6017, 0x10600, 0, [SPV_KHR_integer_dot_product], [Int8]>;
|
||||
defm DotProductInput4x8BitPacked : CapabilityOperand<6018, 0x10600, 0, [SPV_KHR_integer_dot_product], []>;
|
||||
defm DotProduct : CapabilityOperand<6019, 0x10600, 0, [SPV_KHR_integer_dot_product], []>;
|
||||
defm GroupNonUniformRotateKHR : CapabilityOperand<6026, 0, 0, [SPV_KHR_subgroup_rotate], [GroupNonUniform]>;
|
||||
defm AtomicFloat32AddEXT : CapabilityOperand<6033, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
|
||||
defm AtomicFloat64AddEXT : CapabilityOperand<6034, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
|
||||
|
||||
10
llvm/test/CodeGen/DirectX/dot4add_i8packed.ll
Normal file
10
llvm/test/CodeGen/DirectX/dot4add_i8packed.ll
Normal file
@@ -0,0 +1,10 @@
|
||||
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
|
||||
|
||||
define void @main(i32 %a, i32 %b, i32 %c) {
|
||||
entry:
|
||||
; CHECK: call i32 @dx.op.dot4AddPacked(i32 163, i32 %a, i32 %b, i32 %c)
|
||||
%0 = call i32 @llvm.dx.dot4add.i8packed(i32 %a, i32 %b, i32 %c)
|
||||
ret void
|
||||
}
|
||||
|
||||
declare i32 @llvm.dx.dot4add.i8packed(i32, i32, i32)
|
||||
65
llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
Normal file
65
llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
Normal file
@@ -0,0 +1,65 @@
|
||||
; RUN: llc -O0 -mtriple=spirv1.5-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXP
|
||||
; RUN: llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
|
||||
; RUN: llc -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-EXT
|
||||
; RUN: %if spirv-tools %{ llc -verify-machineinstrs -O0 -mtriple=spirv1.5-unknown-unknown %s -o - -filetype=obj | spirv-val %}
|
||||
; RUN: %if spirv-tools %{ llc -verify-machineinstrs -O0 -mtriple=spirv1.6-unknown-unknown %s -o - -filetype=obj | spirv-val %}
|
||||
; RUN: %if spirv-tools %{ llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - -filetype=obj | spirv-val %}
|
||||
|
||||
; CHECK-DOT: OpCapability DotProduct
|
||||
; CHECK-DOT: OpCapability DotProductInput4x8BitPacked
|
||||
; CHECK-EXT: OpExtension "SPV_KHR_integer_dot_product"
|
||||
|
||||
; CHECK: %[[#int_32:]] = OpTypeInt 32 0
|
||||
; CHECK-EXP-DAG: %[[#int_8:]] = OpTypeInt 8 0
|
||||
; CHECK-EXP-DAG: %[[#zero:]] = OpConstantNull %[[#int_8]]
|
||||
; CHECK-EXP-DAG: %[[#eight:]] = OpConstant %[[#int_8]] 8
|
||||
; CHECK-EXP-DAG: %[[#sixteen:]] = OpConstant %[[#int_8]] 16
|
||||
; CHECK-EXP-DAG: %[[#twentyfour:]] = OpConstant %[[#int_8]] 24
|
||||
|
||||
; CHECK-LABEL: Begin function test_dot
|
||||
define noundef i32 @test_dot(i32 noundef %a, i32 noundef %b, i32 noundef %c) {
|
||||
entry:
|
||||
; CHECK: %[[#A:]] = OpFunctionParameter %[[#int_32]]
|
||||
; CHECK: %[[#B:]] = OpFunctionParameter %[[#int_32]]
|
||||
; CHECK: %[[#C:]] = OpFunctionParameter %[[#int_32]]
|
||||
|
||||
; Test that we use the dot product op when capabilities allow
|
||||
|
||||
; CHECK-DOT: %[[#DOT:]] = OpSDot %[[#int_32]] %[[#A]] %[[#B]]
|
||||
; CHECK-DOT: %[[#RES:]] = OpIAdd %[[#int_32]] %[[#DOT]] %[[#C]]
|
||||
|
||||
; Test expansion is used when spirv dot product capabilities aren't available:
|
||||
|
||||
; First element of the packed vector
|
||||
; CHECK-EXP: %[[#A0:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#zero]] %[[#eight]]
|
||||
; CHECK-EXP: %[[#B0:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#zero]] %[[#eight]]
|
||||
; CHECK-EXP: %[[#MUL0:]] = OpIMul %[[#int_32]] %[[#A0]] %[[#B0]]
|
||||
; CHECK-EXP: %[[#MASK0:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL0]] %[[#zero]] %[[#eight]]
|
||||
; CHECK-EXP: %[[#ACC0:]] = OpIAdd %[[#int_32]] %[[#C]] %[[#MASK0]]
|
||||
|
||||
; Second element of the packed vector
|
||||
; CHECK-EXP: %[[#A1:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#eight]] %[[#eight]]
|
||||
; CHECK-EXP: %[[#B1:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#eight]] %[[#eight]]
|
||||
; CHECK-EXP: %[[#MUL1:]] = OpIMul %[[#int_32]] %[[#A1]] %[[#B1]]
|
||||
; CHECK-EXP: %[[#MASK1:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL1]] %[[#zero]] %[[#eight]]
|
||||
; CHECK-EXP: %[[#ACC1:]] = OpIAdd %[[#int_32]] %[[#ACC0]] %[[#MASK1]]
|
||||
|
||||
; Third element of the packed vector
|
||||
; CHECK-EXP: %[[#A2:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#sixteen]] %[[#eight]]
|
||||
; CHECK-EXP: %[[#B2:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#sixteen]] %[[#eight]]
|
||||
; CHECK-EXP: %[[#MUL2:]] = OpIMul %[[#int_32]] %[[#A2]] %[[#B2]]
|
||||
; CHECK-EXP: %[[#MASK2:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL2]] %[[#zero]] %[[#eight]]
|
||||
; CHECK-EXP: %[[#ACC2:]] = OpIAdd %[[#int_32]] %[[#ACC1]] %[[#MASK2]]
|
||||
|
||||
; Fourth element of the packed vector
|
||||
; CHECK-EXP: %[[#A3:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#twentyfour]] %[[#eight]]
|
||||
; CHECK-EXP: %[[#B3:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#twentyfour]] %[[#eight]]
|
||||
; CHECK-EXP: %[[#MUL3:]] = OpIMul %[[#int_32]] %[[#A3]] %[[#B3]]
|
||||
; CHECK-EXP: %[[#MASK3:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL3]] %[[#zero]] %[[#eight]]
|
||||
|
||||
; CHECK-EXP: %[[#RES:]] = OpIAdd %[[#int_32]] %[[#ACC2]] %[[#MASK3]]
|
||||
; CHECK: OpReturnValue %[[#RES]]
|
||||
%spv.dot = call i32 @llvm.spv.dot4add.i8packed(i32 %a, i32 %b, i32 %c)
|
||||
|
||||
ret i32 %spv.dot
|
||||
}
|
||||
Reference in New Issue
Block a user