Files
intel-graphics-compiler/IGC/Compiler/Optimizer/OpenCLPasses/SpvSubgroupMMAResolution/SpvSubgroupMMAResolution.cpp
Kwasniewski, Patryk b717c7c181 fix test
igc_opt returns non-zero on failed assert
2025-08-26 15:56:46 +02:00

517 lines
20 KiB
C++

/*========================== begin_copyright_notice ============================
Copyright (C) 2024 Intel Corporation
SPDX-License-Identifier: MIT
============================= end_copyright_notice ===========================*/
#include "SpvSubgroupMMAResolution.hpp"
#include <cmath> // for ceil
#include "common/LLVMWarningsPush.hpp"
#include <llvm/ADT/SmallVector.h>
#include "llvm/ADT/StringExtras.h"
#include <llvm/ADT/StringRef.h>
#include <llvm/ADT/Twine.h>
#include <llvm/IR/Constants.h>
#include <llvm/IR/Instruction.h>
#include <llvm/IR/Type.h>
#include <llvm/Support/raw_ostream.h>
#include "llvmWrapper/IR/Instructions.h"
#include "common/LLVMWarningsPop.hpp"
#include "Compiler/CodeGenPublic.h"
#include "Compiler/IGCPassSupport.h"
using namespace llvm;
using namespace IGC;
// Static pass ID; it is handed to the ModulePass base-class constructor.
char SpvSubgroupMMAResolution::ID = 0;
// Out-of-class definitions of the static lookup tables. They start empty and are
// filled on first use by getSupportedTable().
SpvSubgroupMMAResolution::SupportedTable SpvSubgroupMMAResolution::m_Simd8Table;
SpvSubgroupMMAResolution::SupportedTable SpvSubgroupMMAResolution::m_Simd16Table;
// Parameters consumed by the IGC_INITIALIZE_PASS_* macros below.
#define PASS_FLAG "igc-spv-subgroup-mma-resolution"
#define PASS_DESC "Lowering of SPIR-V INTEL subgroup_matrix_multiply_accumulate instructions"
#define PASS_CFG_ONLY false
#define PASS_ANALYSIS false
#define DEBUG_TYPE "spv-subgroup-mma-resolution"
// Pass initialization; CodeGenContextWrapper and MetaDataUtilsWrapper are declared
// as dependencies (both are queried through getAnalysis<> in this file).
IGC_INITIALIZE_PASS_BEGIN(SpvSubgroupMMAResolution, PASS_FLAG, PASS_DESC, PASS_CFG_ONLY, PASS_ANALYSIS)
IGC_INITIALIZE_PASS_DEPENDENCY(CodeGenContextWrapper)
IGC_INITIALIZE_PASS_DEPENDENCY(MetaDataUtilsWrapper)
IGC_INITIALIZE_PASS_END(SpvSubgroupMMAResolution, PASS_FLAG, PASS_DESC, PASS_CFG_ONLY, PASS_ANALYSIS)
// Constructs the pass and runs its initialization routine against the global
// PassRegistry.
SpvSubgroupMMAResolution::SpvSubgroupMMAResolution() : ModulePass(ID) {
  PassRegistry &Registry = *PassRegistry::getPassRegistry();
  initializeSpvSubgroupMMAResolutionPass(Registry);
}
// Pass entry point. Visits the whole module (the lowering itself happens in
// visitCallInst), then erases every builtin that was queued for removal.
// Returns true when the module was modified.
bool SpvSubgroupMMAResolution::runOnModule(Module &M) {
  // Reset per-run state before visiting.
  m_Changed = false;
  m_BuiltinsToRemove.clear();
  m_Module = &M;
  m_Ctx = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();

  visit(M);

  // Functions are erased only after the visit, not while calls are being rewritten.
  for (auto *Builtin : m_BuiltinsToRemove)
    Builtin->eraseFromParent();

  return m_Changed;
}
// Bit flags that can be set in the Operands argument of
// __spirv_SubgroupMatrixMultiplyAccumulateINTEL. Each flag occupies one bit so
// that combinations can be OR-ed together and tested with '&'.
enum {
  None = 0x0000,
  MatrixASignedComponentsINTEL = 0x0001,
  MatrixBSignedComponentsINTEL = 0x0002,
  MatrixCBFloat16INTEL = 0x0004,
  MatrixResultBFloat16INTEL = 0x0008,
  MatrixAPackedInt8INTEL = 0x0010,
  MatrixBPackedInt8INTEL = 0x0020,
  MatrixAPackedInt4INTEL = 0x0040,
  MatrixBPackedInt4INTEL = 0x0080,
  MatrixATF32INTEL = 0x0100,
  MatrixBTF32INTEL = 0x0200,
  MatrixAPackedFloat16INTEL = 0x0400,
  MatrixBPackedFloat16INTEL = 0x0800,
  MatrixAPackedBFloat16INTEL = 0x1000,
  MatrixBPackedBFloat16INTEL = 0x2000,
};
static std::string GetHumanReadableOperand(uint32_t operand) {
SmallVector<std::string, 8> operands;
if (operand & MatrixASignedComponentsINTEL)
operands.push_back("MatrixASignedComponentsINTEL");
if (operand & MatrixBSignedComponentsINTEL)
operands.push_back("MatrixBSignedComponentsINTEL");
if (operand & MatrixCBFloat16INTEL)
operands.push_back("MatrixCBFloat16INTEL");
if (operand & MatrixResultBFloat16INTEL)
operands.push_back("MatrixResultBFloat16INTEL");
if (operand & MatrixAPackedInt8INTEL)
operands.push_back("MatrixAPackedInt8INTEL");
if (operand & MatrixBPackedInt8INTEL)
operands.push_back("MatrixBPackedInt8INTEL");
if (operand & MatrixAPackedInt4INTEL)
operands.push_back("MatrixAPackedInt4INTEL");
if (operand & MatrixBPackedInt4INTEL)
operands.push_back("MatrixBPackedInt4INTEL");
if (operand & MatrixATF32INTEL)
operands.push_back("MatrixATF32INTEL");
if (operand & MatrixBTF32INTEL)
operands.push_back("MatrixBTF32INTEL");
if (operand & MatrixAPackedFloat16INTEL)
operands.push_back("MatrixAPackedFloat16INTEL");
if (operand & MatrixBPackedFloat16INTEL)
operands.push_back("MatrixBPackedFloat16INTEL");
if (operand & MatrixAPackedBFloat16INTEL)
operands.push_back("MatrixAPackedBFloat16INTEL");
if (operand & MatrixBPackedBFloat16INTEL)
operands.push_back("MatrixBPackedBFloat16INTEL");
if (operands.empty())
return "None";
return llvm::join(operands, " | ");
}
void SpvSubgroupMMAResolution::populateSimd8Table() {
// 8-bit integer matrix sources (signed and unsigned), 32-bit integer accumulator:
m_Simd8Table[32][ElType::I32][ElType::I32][ElType::I32][MatrixAPackedInt8INTEL | MatrixBPackedInt8INTEL] = "u8_u8_";
m_Simd8Table[32][ElType::I32][ElType::I32][ElType::I32]
[MatrixAPackedInt8INTEL | MatrixBPackedInt8INTEL | MatrixASignedComponentsINTEL] = "s8_u8_";
m_Simd8Table[32][ElType::I32][ElType::I32][ElType::I32]
[MatrixAPackedInt8INTEL | MatrixBPackedInt8INTEL | MatrixBSignedComponentsINTEL] = "u8_s8_";
m_Simd8Table[32][ElType::I32][ElType::I32][ElType::I32][MatrixAPackedInt8INTEL | MatrixBPackedInt8INTEL |
MatrixASignedComponentsINTEL | MatrixBSignedComponentsINTEL] =
"s8_s8_";
// 4-bit integer matrix sources (signed and unsigned), 32-bit integer accumulator:
m_Simd8Table[64][ElType::I32][ElType::I32][ElType::I32][MatrixAPackedInt4INTEL | MatrixBPackedInt4INTEL] = "u4_u4_";
m_Simd8Table[64][ElType::I32][ElType::I32][ElType::I32]
[MatrixAPackedInt4INTEL | MatrixBPackedInt4INTEL | MatrixASignedComponentsINTEL] = "s4_u4_";
m_Simd8Table[64][ElType::I32][ElType::I32][ElType::I32]
[MatrixAPackedInt4INTEL | MatrixBPackedInt4INTEL | MatrixBSignedComponentsINTEL] = "u4_s4_";
m_Simd8Table[64][ElType::I32][ElType::I32][ElType::I32][MatrixAPackedInt4INTEL | MatrixBPackedInt4INTEL |
MatrixASignedComponentsINTEL | MatrixBSignedComponentsINTEL] =
"s4_s4_";
// fp16 matrix sources, fp32 accumulator:
m_Simd8Table[16][ElType::F32][ElType::I32][ElType::I32][MatrixAPackedFloat16INTEL | MatrixBPackedFloat16INTEL] =
"hf_hf_";
// bf16 matrix sources, fp32 accumulator:
m_Simd8Table[16][ElType::F32][ElType::I32][ElType::I32][MatrixAPackedBFloat16INTEL | MatrixBPackedBFloat16INTEL] =
"bf_bf_";
}
void SpvSubgroupMMAResolution::populateSimd16Table() {
// 8-bit integer matrix sources (signed and unsigned), 32-bit integer accumulator:
m_Simd16Table[32][ElType::I32][ElType::I16][ElType::I32][MatrixAPackedInt8INTEL | MatrixBPackedInt8INTEL] = "u8_u8_";
m_Simd16Table[32][ElType::I32][ElType::I16][ElType::I32]
[MatrixAPackedInt8INTEL | MatrixBPackedInt8INTEL | MatrixASignedComponentsINTEL] = "s8_u8_";
m_Simd16Table[32][ElType::I32][ElType::I16][ElType::I32]
[MatrixAPackedInt8INTEL | MatrixBPackedInt8INTEL | MatrixBSignedComponentsINTEL] = "u8_s8_";
m_Simd16Table[32][ElType::I32][ElType::I16][ElType::I32][MatrixAPackedInt8INTEL | MatrixBPackedInt8INTEL |
MatrixASignedComponentsINTEL |
MatrixBSignedComponentsINTEL] = "s8_s8_";
// 4-bit integer matrix sources (signed and unsigned), 32-bit integer accumulator:
m_Simd16Table[64][ElType::I32][ElType::I16][ElType::I32][MatrixAPackedInt4INTEL | MatrixBPackedInt4INTEL] = "u4_u4_";
m_Simd16Table[64][ElType::I32][ElType::I16][ElType::I32]
[MatrixAPackedInt4INTEL | MatrixBPackedInt4INTEL | MatrixASignedComponentsINTEL] = "s4_u4_";
m_Simd16Table[64][ElType::I32][ElType::I16][ElType::I32]
[MatrixAPackedInt4INTEL | MatrixBPackedInt4INTEL | MatrixBSignedComponentsINTEL] = "u4_s4_";
m_Simd16Table[64][ElType::I32][ElType::I16][ElType::I32][MatrixAPackedInt4INTEL | MatrixBPackedInt4INTEL |
MatrixASignedComponentsINTEL |
MatrixBSignedComponentsINTEL] = "s4_s4_";
// fp16 matrix sources, fp32 accumulator:
m_Simd16Table[16][ElType::F32][ElType::I16][ElType::I32][MatrixAPackedFloat16INTEL | MatrixBPackedFloat16INTEL] =
"f_f_hf_hf_";
// bf16 matrix sources, fp32 accumulator:
m_Simd16Table[16][ElType::F32][ElType::I16][ElType::I32][MatrixAPackedBFloat16INTEL | MatrixBPackedBFloat16INTEL] =
"f_f_bf_bf_";
// fp16 matrix sources, fp16 accumulator:
m_Simd16Table[16][ElType::F16][ElType::I16][ElType::I32][MatrixAPackedFloat16INTEL | MatrixBPackedFloat16INTEL] =
"hf_hf_hf_hf_";
// bf16 matrix sources, bf16 accumulator:
m_Simd16Table[16][ElType::I16][ElType::I16][ElType::I32][MatrixResultBFloat16INTEL | MatrixAPackedBFloat16INTEL |
MatrixBPackedBFloat16INTEL | MatrixCBFloat16INTEL] =
"bf_bf_bf_bf_";
// tf32 matrix sources, fp32 accumulator:
m_Simd16Table[8][ElType::F32][ElType::F32][ElType::F32][MatrixATF32INTEL | MatrixBTF32INTEL] = "f_f_tf32_tf32_";
}
// Reports `message` through the code-gen context, attached to the call `CI`.
void SpvSubgroupMMAResolution::emitError(const Twine &message, const CallInst &CI) {
  // Materialize the Twine first; the C string must stay valid during the call.
  const std::string Text = message.str();
  m_Ctx->EmitError(Text.c_str(), &CI);
}
// Maps an LLVM type to the pass' element-type enum. Only i32, i16, float and half
// are recognized; every other type yields Unknown.
SpvSubgroupMMAResolution::ElType SpvSubgroupMMAResolution::getElType(const Type *Ty) const {
  ElType Kind = Unknown;
  if (Ty->isHalfTy())
    Kind = F16;
  else if (Ty->isFloatTy())
    Kind = F32;
  else if (Ty->isIntegerTy(16))
    Kind = I16;
  else if (Ty->isIntegerTy(32))
    Kind = I32;
  return Kind;
}
// Returns the name used in diagnostics for an element type. Any value other than
// I32, I16, F32 or F16 triggers an assertion and yields "Unknown".
StringRef SpvSubgroupMMAResolution::getElTypeStr(const SpvSubgroupMMAResolution::ElType Ty) const {
  if (Ty == I32)
    return "int32_t";
  if (Ty == I16)
    return "int16_t";
  if (Ty == F32)
    return "float32_t";
  if (Ty == F16)
    return "float16_t";

  IGC_ASSERT_MESSAGE(0, "unexpected element type");
  return "Unknown";
}
// Classifies the element type of a matrix operand. Accepts a scalar or a fixed
// vector of integer/floating-point elements; anything else yields Unknown.
SpvSubgroupMMAResolution::ElType SpvSubgroupMMAResolution::getValidMatrixType(const Type *Ty) const {
  // Look through one level of fixed vector to reach the element type.
  const Type *Scalar = Ty;
  if (auto *Vec = dyn_cast<FixedVectorType>(Ty))
    Scalar = Vec->getElementType();

  const bool IsNumeric = Scalar->isFloatingPointTy() || Scalar->isIntegerTy();
  return IsNumeric ? getElType(Scalar) : Unknown;
}
// Checks that `V` is a ConstantInt of type i32. On failure emits an error naming
// the argument `ParamName` and returns false.
bool SpvSubgroupMMAResolution::validateI32Constant(const Value *V, const Twine &ParamName, const CallInst &CI) {
  const bool IsI32Constant = isa<ConstantInt>(V) && V->getType()->isIntegerTy(32);
  if (IsI32Constant)
    return true;

  emitError(Twine("__spirv_SubgroupMatrixMultiplyAccumulateINTEL: ") + ParamName +
                " argument must be a constant scalar 32-bit integer",
            CI);
  return false;
}
// Checks that the Result type and the Matrix C type are the same LLVM type.
// On mismatch emits an error that prints both types and returns false.
bool SpvSubgroupMMAResolution::validateCType(const Type *ResultTy, const Type *CType, const CallInst &CI) {
  if (ResultTy == CType)
    return true;

  std::string msg;
  raw_string_ostream rso(msg);
  rso << "__spirv_SubgroupMatrixMultiplyAccumulateINTEL: expected Result type to match type of Matrix C for targeted "
         "HW. Result type: ";
  ResultTy->print(rso);
  rso << ", Matrix C type: ";
  CType->print(rso);
  // Read the text through str(): it flushes the stream first, so the message is
  // complete even when raw_string_ostream buffers its output.
  emitError(rso.str(), CI);
  return false;
}
// Checks that an operand's element type was recognized (is not Unknown).
// Otherwise emits an error naming the operand `ParamName` and returns false.
bool SpvSubgroupMMAResolution::validateElementType(const ElType ElemTy, StringRef ParamName, const CallInst &CI) {
  const bool Recognized = ElemTy != Unknown;
  if (!Recognized) {
    emitError(Twine("__spirv_SubgroupMatrixMultiplyAccumulateINTEL: expected ") + ParamName +
                  " to be a scalar or vector of int32_t, int16_t, float32_t, or float16_t for targeted HW",
              CI);
  }
  return Recognized;
}
// Returns the number of elements of a fixed vector type, or 1 for any other type.
int SpvSubgroupMMAResolution::getElemCount(const Type *Ty) const {
  const auto *Vec = dyn_cast<FixedVectorType>(Ty);
  if (!Vec)
    return 1;
  return Vec->getNumElements();
}
// Validates the component counts of the matrix operands.
//   M          - M dimension (component count of the Result type); must be 1, 2, 4 or 8.
//   AElemCount - component count of Matrix A; must equal M, or ceil(M/2) when
//                MatrixATF32INTEL is set in Operands.
//   BElemCount - component count of Matrix B; must be 4 when isDoubleSubgroup(CI)
//                holds and 8 otherwise.
// Emits an error on the first violated rule and returns false; returns true when
// all checks pass.
bool SpvSubgroupMMAResolution::validateElemCounts(int M, int AElemCount, int BElemCount, uint32_t Operands,
                                                  CallInst &CI) {
  if (M != 1 && M != 2 && M != 4 && M != 8) {
    emitError(
        "__spirv_SubgroupMatrixMultiplyAccumulateINTEL: M dimension must be 1, 2, 4 or 8 for targeted HW. Actual: " +
            std::to_string(M),
        CI);
    return false;
  }

  if (Operands & MatrixATF32INTEL) {
    // ceil(M / 2) in integer arithmetic; M is known to be positive here.
    const int expected = (M + 1) / 2;
    if (AElemCount != expected) {
      // Report the actual component count of Matrix A (not M) as "Actual".
      emitError("__spirv_SubgroupMatrixMultiplyAccumulateINTEL: Matrix A argument must have ceil(M/2) components "
                "when MatrixATF32INTEL operand is set for targeted HW. Expected " +
                    std::to_string(expected) + ". Actual " + std::to_string(AElemCount),
                CI);
      return false;
    }
  } else if (AElemCount != M) {
    emitError("__spirv_SubgroupMatrixMultiplyAccumulateINTEL: Matrix A argument must have size " + std::to_string(M) +
                  " to match M defined by Result type for targeted HW. Actual: " + std::to_string(AElemCount),
              CI);
    return false;
  }

  const int expectedBCount = isDoubleSubgroup(CI) ? 4 : 8;
  if (BElemCount != expectedBCount) {
    emitError("__spirv_SubgroupMatrixMultiplyAccumulateINTEL: Matrix B argument must have " +
                  std::to_string(expectedBCount) +
                  " components for targeted HW. Actual: " + std::to_string(BElemCount),
              CI);
    return false;
  }
  return true;
}
// The N dimension is platform specific and tied to the platform's minimum
// subgroup size. When a DPAS with the same M, N, K dimensions runs in a subgroup
// twice that size, each work item holds half of the data it would hold at the
// minimum subgroup size.
//
// Returns true only when the platform reports hasExecSize16DPAS() and the SIMD
// size of the function containing `CI` is 32.
bool SpvSubgroupMMAResolution::isDoubleSubgroup(CallInst &CI) {
  const bool HasSimd16Dpas = m_Ctx->platform.hasExecSize16DPAS();
  if (HasSimd16Dpas) {
    Function *Caller = CI.getParent()->getParent();
    auto MDUtils = getAnalysis<MetaDataUtilsWrapper>().getMetaDataUtils();
    return IGC::getSIMDSize(MDUtils, Caller) == 32;
  }
  return false;
}
// Returns the table of supported combinations for the current platform:
// m_Simd16Table when the platform reports hasExecSize16DPAS(), m_Simd8Table
// otherwise. The selected table is populated the first time it is requested.
SpvSubgroupMMAResolution::SupportedTable *SpvSubgroupMMAResolution::getSupportedTable() {
  const bool UseSimd16 = m_Ctx->platform.hasExecSize16DPAS();
  SupportedTable &Table = UseSimd16 ? m_Simd16Table : m_Simd8Table;

  if (Table.empty()) {
    if (UseSimd16)
      populateSimd16Table();
    else
      populateSimd8Table();
  }
  return &Table;
}
template <typename T>
bool SpvSubgroupMMAResolution::validateKDimInTable(const T KIt, int K, const SupportedTable *table,
const CallInst &CI) {
if (KIt != table->end())
return true;
SmallVector<std::string, 8> validKDims;
for (const auto &it : *table)
validKDims.push_back(std::to_string(it.first));
emitError(Twine("__spirv_SubgroupMatrixMultiplyAccumulateINTEL: expected K Dim = ") + llvm::join(validKDims, " or ") +
" for targeted HW. Actual: " + Twine(K),
CI);
return false;
}
template <typename TableType> std::string SpvSubgroupMMAResolution::getValidTypesStr(const TableType &table) const {
SmallVector<std::string, 8> validTypes;
for (const auto &it : table)
validTypes.push_back(getElTypeStr(it.first).str());
return llvm::join(validTypes, " or ");
}
template <typename T>
bool SpvSubgroupMMAResolution::validateResultElementInTable(const T RIt, int K, ElType ResultElemTy,
const RTable &table, const CallInst &CI) {
if (RIt != table.end())
return true;
emitError(Twine("__spirv_SubgroupMatrixMultiplyAccumulateINTEL: expected Result element type to be ") +
getValidTypesStr(table) + " for K Dim = " + Twine(K) +
" for targeted HW. Actual: " + getElTypeStr(ResultElemTy),
CI);
return false;
}
template <typename T>
bool SpvSubgroupMMAResolution::validateAElementInTable(const T AIt, int K, ElType ResultElemTy, ElType AElemTy,
const ATable &table, const CallInst &CI) {
if (AIt != table.end())
return true;
emitError(Twine("__spirv_SubgroupMatrixMultiplyAccumulateINTEL: expected A element type to be ") +
getValidTypesStr(table) + " for K Dim = " + Twine(K) + ", for Result element type " +
getElTypeStr(ResultElemTy) + ", for targeted HW. Actual: " + getElTypeStr(AElemTy),
CI);
return false;
}
template <typename T>
bool SpvSubgroupMMAResolution::validateBElementInTable(const T BIt, int K, ElType ResultElemTy, ElType AElemTy,
ElType BElemTy, const BTable &table, const CallInst &CI) {
if (BIt != table.end())
return true;
emitError(Twine("__spirv_SubgroupMatrixMultiplyAccumulateINTEL: expected B element type to be ") +
getValidTypesStr(table) + " for K Dim = " + Twine(K) + ", for Result element type " +
getElTypeStr(ResultElemTy) + ", for A element type " + getElTypeStr(AElemTy) +
", for targeted HW. Actual: " + getElTypeStr(BElemTy),
CI);
return false;
}
// Checks the lookup of the Operands bitmask for the given K and element types.
// When `OpIt` is the end iterator, emits an error that lists every accepted
// bitmask (numeric value and flag names) followed by the actual one, and returns
// false.
template <typename T>
bool SpvSubgroupMMAResolution::validateOperands(const T OpIt, int K, ElType ResultElemTy, ElType AElemTy,
                                                ElType BElemTy, uint32_t Operands, const OperandsTable &operandMap,
                                                const CallInst &CI) {
  const bool Found = OpIt != operandMap.end();
  if (Found)
    return true;

  std::stringstream report;
  report << "__spirv_SubgroupMatrixMultiplyAccumulateINTEL: expected Operands to be one of these combinations:\n";
  for (const auto &Entry : operandMap) {
    const auto &Bits = Entry.first;
    report << Bits << ": " << GetHumanReadableOperand(Bits) << "\n";
  }
  report << "for K Dim = " << K;
  report << ", for Result element type " << getElTypeStr(ResultElemTy).str();
  report << ", for A element type " << getElTypeStr(AElemTy).str();
  report << ", for B element type " << getElTypeStr(BElemTy).str();
  report << ", for targeted HW.\n";
  report << "Actual: " << Operands << ": " << GetHumanReadableOperand(Operands);

  emitError(report.str(), CI);
  return false;
}
// Lowers one call to __spirv_SubgroupMatrixMultiplyAccumulateINTEL into a call to
// the matching __builtin_IB_sub_group*dpas* builtin.
//
// Calls whose callee name does not contain the SPIR-V builtin name, and indirect
// calls, are ignored. The call is expected to have five arguments:
// (K Dim, Matrix A, Matrix B, Matrix C, Operands). Every validation step emits an
// error and leaves the call untouched on failure. On success the call is replaced
// and erased, m_Changed is set, and the old callee is queued for removal once it
// has no uses left.
void SpvSubgroupMMAResolution::visitCallInst(CallInst &CI) {
  Function *F = CI.getCalledFunction();
  if (!F)
    return;
  StringRef funcName = F->getName();
  if (!funcName.contains("__spirv_SubgroupMatrixMultiplyAccumulateINTEL"))
    return;

  int numArgs = IGCLLVM::getNumArgOperands(&CI);
  if (numArgs != 5) {
    emitError("__spirv_SubgroupMatrixMultiplyAccumulateINTEL: invalid number of arguments. Expected 5. Actual " +
                  std::to_string(numArgs),
              CI);
    return;
  }

  // Get arguments
  Type *ResultTy = CI.getType();
  Value *kDim = CI.getArgOperand(0);
  Value *a = CI.getArgOperand(1);
  Value *b = CI.getArgOperand(2);
  Value *c = CI.getArgOperand(3);
  Value *OpValue = CI.getArgOperand(4);

  if (!validateI32Constant(OpValue, "Operands", CI))
    return;
  uint32_t Operands = cast<ConstantInt>(OpValue)->getZExtValue();

  if (!validateCType(ResultTy, c->getType(), CI))
    return;

  ElType ResultElemTy = getValidMatrixType(ResultTy);
  ElType AElemTy = getValidMatrixType(a->getType());
  ElType BElemTy = getValidMatrixType(b->getType());
  if (!validateElementType(ResultElemTy, "Result", CI))
    return;
  if (!validateElementType(AElemTy, "Matrix A", CI))
    return;
  if (!validateElementType(BElemTy, "Matrix B", CI))
    return;

  // The number of components in Result Type defines the M dimension.
  // If Result Type is a scalar type, the M dimension is one.
  int M = getElemCount(ResultTy);
  int AElemCount = getElemCount(a->getType());
  int BElemCount = getElemCount(b->getType());
  if (!validateElemCounts(M, AElemCount, BElemCount, Operands, CI))
    return;

  if (!validateI32Constant(kDim, "K Dim", CI))
    return;
  int K = static_cast<int>(cast<ConstantInt>(kDim)->getZExtValue());

  // Walk the nested table: K -> Result type -> A type -> B type -> Operands.
  SupportedTable *table = getSupportedTable();
  auto KIt = table->find(K);
  if (!validateKDimInTable(KIt, K, table, CI))
    return;
  auto ResultIt = KIt->second.find(ResultElemTy);
  if (!validateResultElementInTable(ResultIt, K, ResultElemTy, KIt->second, CI))
    return;
  auto AIt = ResultIt->second.find(AElemTy);
  if (!validateAElementInTable(AIt, K, ResultElemTy, AElemTy, ResultIt->second, CI))
    return;
  auto BIt = AIt->second.find(BElemTy);
  if (!validateBElementInTable(BIt, K, ResultElemTy, AElemTy, BElemTy, AIt->second, CI))
    return;
  auto OperandsIt = BIt->second.find(Operands);
  if (!validateOperands(OperandsIt, K, ResultElemTy, AElemTy, BElemTy, Operands, BIt->second, CI))
    return;

  // creating IB built-in; note the argument order changes to (C, A, B)
  SmallVector<Value *, 3> args({c, a, b});
  SmallVector<Type *, 3> argTypes({c->getType(), a->getType(), b->getType()});
  FunctionType *FT = FunctionType::get(CI.getType(), argTypes, false);

  std::string subgroupSize;
  if (isDoubleSubgroup(CI)) {
    subgroupSize = "32n16";
    // The builtin name carries twice the M derived from the Result type here.
    M *= 2;
  } else {
    subgroupSize = m_Ctx->platform.hasExecSize16DPAS() ? "16" : "";
  }

  std::stringstream newFuncName;
  newFuncName << "__builtin_IB_sub_group" << subgroupSize;
  newFuncName << "_" << (ResultElemTy == I32 ? "i" : "f");
  newFuncName << "dpas_" << OperandsIt->second.str() << "8_" << M;

  auto newFunc = m_Module->getOrInsertFunction(newFuncName.str(), FT);
  auto newCall = CallInst::Create(newFunc, args, "", &CI);
  // Keep the source location of the original call on its replacement so later
  // diagnostics and debug info still point at the right place.
  newCall->setDebugLoc(CI.getDebugLoc());
  CI.replaceAllUsesWith(newCall);
  CI.eraseFromParent();
  m_Changed = true;

  // The old declaration cannot be erased while the module is being visited;
  // runOnModule erases the queued functions afterwards.
  if (F->use_empty())
    m_BuiltinsToRemove.insert(F);
}