mirror of
https://github.com/intel/intel-graphics-compiler.git
synced 2025-10-30 08:18:26 +08:00
975 lines
36 KiB
C++
975 lines
36 KiB
C++
/*========================== begin_copyright_notice ============================
|
|
|
|
Copyright (C) 2018-2021 Intel Corporation
|
|
|
|
SPDX-License-Identifier: MIT
|
|
|
|
============================= end_copyright_notice ===========================*/
|
|
|
|
#include "Compiler/Optimizer/OpenCLPasses/DpasFuncs/DpasFuncsResolution.hpp"
|
|
#include "Compiler/Optimizer/OCLBIUtils.h"
|
|
#include "Compiler/IGCPassSupport.h"
|
|
#include "common/LLVMWarningsPush.hpp"
|
|
#include <llvm/Pass.h>
|
|
#include <llvm/IR/InstVisitor.h>
|
|
#include <llvm/IR/Function.h>
|
|
#include <llvm/IR/Instructions.h>
|
|
#include "llvmWrapper/IR/DerivedTypes.h"
|
|
#include "common/LLVMWarningsPop.hpp"
|
|
#include "Probe/Assertion.h"
|
|
|
|
using namespace llvm;
|
|
using namespace IGC;
|
|
using IGCLLVM::FixedVectorType;
|
|
|
|
namespace {
|
|
// Types for destination and accumulate.
|
|
enum DstAccType {
|
|
DSTACC_UNUSED,
|
|
DSTACC_FLOAT,
|
|
DSTACC_FP16,
|
|
DSTACC_BF16,
|
|
DSTACC_INT32
|
|
};
|
|
|
|
/// @brief DpasFuncsTranslation pass : tranlate dpas builtin (__builtin_IB_*dpas*) into igc intrinsic.
|
|
/// It also may combine several dpas intrinsics into a single one.
|
|
class DpasFuncsResolution : public FunctionPass, public InstVisitor<DpasFuncsResolution> {
|
|
public:
|
|
// Pass identification, replacement for typeid
|
|
static char ID;
|
|
|
|
DpasFuncsResolution();
|
|
|
|
~DpasFuncsResolution() {}
|
|
|
|
/// @brief Provides name of pass
|
|
virtual StringRef getPassName() const override {
|
|
// This string was changed from "DpasFuncsTranslation" due to IP leaks concers.
|
|
return "ArithmeticFuncsTranslation";
|
|
}
|
|
|
|
void getAnalysisUsage(AnalysisUsage &AU) const override {
|
|
AU.addRequired<CodeGenContextWrapper>();
|
|
AU.addRequired<MetaDataUtilsWrapper>();
|
|
}
|
|
|
|
virtual bool runOnFunction(Function &F) override;
|
|
|
|
void visitCallInst(CallInst &CI);
|
|
|
|
private:
|
|
/// Demangle the suffix of dpas. Return true if sucessful; false otherwise.
|
|
/// Suffix's format: [w_][<DstTy>_<AccTy>_][<PA>_<PB>_]<SD>_<RC> (see below)
|
|
bool demangleSuffix(StringRef FN, int StartPos, bool HasDstAcc, bool IsIDpas, int &DstTy, int &AccTy, int &PA,
|
|
int &PB, int &SD, int &RC, bool *IsDpasw);
|
|
|
|
/// Demangle the suffix of BFCvt. Return true if sucessful; false otherwise.
|
|
/// Suffix's format: [<rm>_]<1|2|4|8|16> (see description below)
|
|
bool demangleFCvtSuffix(StringRef FN, int StartPos, int *pRM, int *pVecLen, bool *pIsSat);
|
|
|
|
/// Indicates if the pass changed the processed function
|
|
bool m_changed{};
|
|
|
|
CodeGenContext *m_pCtx = nullptr;
|
|
std::string m_ErrorMsg{};
|
|
|
|
/// XeHP_SDV's simd8 intrinsics
|
|
///
|
|
/// The dpas builtin function's name has the suffix format as
|
|
/// <a's precision>_<b's precision>_<systolicDepth>_<repeatCount>
|
|
/// They are divided into four groups:
|
|
/// 1. Sub group versions (using other simd-lane's data):
|
|
/// 1.1 __builtin_IB_sub_group_idpas[w]_<s|u><2|4|8>_<s|u><2|4|8>_8_<1-8> (acc, a, b)
|
|
/// 1.2 __builtin_IB_sub_group_fdpas[w]_bf_bf_8_<1-8> (acc, a, b)
|
|
/// __builtin_IB_sub_group_fdpas[w]_hf_hf_8_<1-8> (acc, a, b)
|
|
/// 2. Work-item versions (using its own data, not using cross-lane data)
|
|
/// 2.1 __builtin_IB_idpas[w]_<s|u><2|4|8>_<s|u><2|4|8>_8_<1-8> (acc, a, b)
|
|
/// 2.2 __builtin_IB_fdpas[w]_bf_bf_8_<1-8> (acc, a, b)
|
|
/// __builtin_IB_fdpas[w]_hf_hf_8_<1-8> (acc, a, b)
|
|
///
|
|
/// Note that <a|b|c> denotes one of a, b, or c. "1-8" denotes 1, 2, ..., up to 8.
|
|
/// And for dpasw, repeat count = 2|4|8 are supported only for now.
|
|
static const StringRef SG_PREFIX_IDPAS;
|
|
static const StringRef SG_PREFIX_FDPAS;
|
|
static const StringRef WI_PREFIX_IDPAS;
|
|
static const StringRef WI_PREFIX_FDPAS;
|
|
/// The following are intrinsic for PVC simd16 only.
|
|
/// __builtin_IB_sub_group16_idpas<suffix>
|
|
/// <suffix> : _<a's precision>_<b's precision>_<depth>_<rcount>
|
|
/// ie. _<u|s><2|4|8>_<u|s><2|4|8>_8_<1-8>
|
|
/// the same as XeHP_SDV simd8 intrinsic.
|
|
/// __builtin_IB_sub_group16_fdpas<suffux>
|
|
/// <suffix> : _<retty>_<accty>_<aty>_<bty>_<depth>_<rcount>
|
|
/// 1. _<f|x>_<f|x>_<x>_<x>_8_<1-8>
|
|
/// x: <hf | bf>
|
|
/// 2. _f_f_tf32_tf32_8_<1-8>
|
|
///
|
|
static const StringRef SG_PREFIX_IDPAS16;
|
|
static const StringRef SG_PREFIX_FDPAS16;
|
|
static const StringRef SG_PREFIX_IDPAS32N16;
|
|
static const StringRef SG_PREFIX_FDPAS32N16;
|
|
// PVC+: pure hf/bf dpas builtins
|
|
static const StringRef WI_PREFIX_HFDPAS;
|
|
static const StringRef WI_PREFIX_BFDPAS;
|
|
static const StringRef SG_PREFIX_HFDPAS;
|
|
static const StringRef SG_PREFIX_BFDPAS;
|
|
static const StringRef SG_PREFIX_SDPAS16;
|
|
|
|
/// The bf conversion builtin function's name has the format as
|
|
/// __builtin_IB_<srcType>to<dstType>[_<rm>]_<1|2|3|4|8|16>
|
|
/// where
|
|
/// srcType/dstType : bf(as short) or f(float).
|
|
/// Note that 2bf (as int) and 2f are packed cvt from two float to a
|
|
/// pair of bf.
|
|
/// <rm> : rtz/rte/rtp/rtn
|
|
/// If rm is not present, it is default (rte).
|
|
/// <1|2|3|4|8|16> : vector size of its argument. "1" is for scalar.
|
|
///
|
|
/// **Note that [_<rm>] denotes _<rm> is optional.**
|
|
///
|
|
/// Currently, support builtin are:
|
|
/// __builtin_IB_ftobf[_<rm>]_<1|2|3|4|8|16>
|
|
/// __builtin_IB_bftof_<1|2|3|4|8|16> // no RM as it is precise
|
|
/// __builtin_IB_2fto2bf[_<rm>]_<1|2|3|4|8|16>
|
|
bool processCvt(CallInst &CI);
|
|
|
|
/// Naming convertion of Stochastic rounding builtin
|
|
/// __builtin_IB_srnd_ftohf_<1|2|3|4|8|16> (a, r)
|
|
/// __builtin_IB_srnd_hftobf8_<1|2|3|4|8|16>(a, r)
|
|
bool processSrnd(CallInst &CI);
|
|
/// Naming convertion of sdpas builtin
|
|
/// __builtin_IB_sub_group16_sdpas__<retty>_<accty>_<aty>_<bty>_<depth>_<rcount>
|
|
/// retty/accty : f|hf|bf|d
|
|
/// aty/bty : u8|s8|bf|bf8|hf8|hf|tf32
|
|
/// depth : 16
|
|
/// (b is compressed and is the half of the depth).
|
|
/// rcount : 7-8
|
|
/// Only a limited set of combination of retty/accty/aty/bty is allowed.
|
|
bool processSdpas(CallInst &CI);
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
/// StringRef parsing functions' common arguments
|
|
/// StrRef: string to be parsed
|
|
/// StrPos: the starting position of string.
|
|
/// StrRem: the remaining number of chars at StrPos of StrRef.
|
|
///
|
|
/// Each function will parse particular patterns. Once found,
|
|
/// adjust StrPos to point to the next field, and StrRem to
|
|
/// the number of chars remained unparsed.
|
|
///
|
|
/// The suffix patterns will be parsed by a sequence of parsing
|
|
/// functions. If one parsing function fails, the parsing functions
|
|
/// following the failing one in the sequence will definitely fails.
|
|
/// With this, we can just check the status of the last parsing function
|
|
/// to see if the entire sequence of parsing functions fail or not.
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
// Parse type string for destination or accumulate operands
|
|
// Pattern: "_f" | "_hf" | "_bf"
|
|
DstAccType parseDstAccType(StringRef StrRef, size_t &StrPos, size_t &StrRem);
|
|
|
|
// parse wide version of dpas : "[w]"
|
|
// If it is "w", return true; otherwise, return false.
|
|
// (As 'w' is optional, this function never fails.)
|
|
bool parseW(StringRef StrRef, size_t &StrPos, size_t &StrRem);
|
|
|
|
//
|
|
// Find the following patterns:
|
|
//
|
|
// "_bf" | "_hf" | "_<s|u><2|4|8>" | tf32
|
|
//
|
|
// If success, return the type denoted by this string pattern;
|
|
// otherwise, return PrecisionType::PRECISION_UNUSED.
|
|
//
|
|
PrecisionType parsePrecision(StringRef StrRef, size_t &StrPos, size_t &StrRem);
|
|
|
|
// Pattern: '_8'
|
|
// Return depth if valid, return -1 otherwise.
|
|
int parseDepth(StringRef StrRef, size_t &StrPos, size_t &StrRem);
|
|
|
|
// Pattern: '_<1-8>'
|
|
// Return repeat count if valid, return -1 otherwise.
|
|
int parseRCount(StringRef StrRef, size_t &StrPos, size_t &StrRem);
|
|
|
|
};
|
|
} // namespace
|
|
|
|
char DpasFuncsResolution::ID = 0;
|
|
const StringRef DpasFuncsResolution::SG_PREFIX_IDPAS = "__builtin_IB_sub_group_idpas";
|
|
const StringRef DpasFuncsResolution::SG_PREFIX_FDPAS = "__builtin_IB_sub_group_fdpas";
|
|
const StringRef DpasFuncsResolution::WI_PREFIX_IDPAS = "__builtin_IB_idpas";
|
|
const StringRef DpasFuncsResolution::WI_PREFIX_FDPAS = "__builtin_IB_fdpas";
|
|
const StringRef DpasFuncsResolution::SG_PREFIX_IDPAS16 = "__builtin_IB_sub_group16_idpas";
|
|
const StringRef DpasFuncsResolution::SG_PREFIX_FDPAS16 = "__builtin_IB_sub_group16_fdpas";
|
|
const StringRef DpasFuncsResolution::SG_PREFIX_IDPAS32N16 = "__builtin_IB_sub_group32n16_idpas";
|
|
const StringRef DpasFuncsResolution::SG_PREFIX_FDPAS32N16 = "__builtin_IB_sub_group32n16_fdpas";
|
|
// PVC+: pure hf/bf dpas builtins
|
|
const StringRef DpasFuncsResolution::WI_PREFIX_HFDPAS = "__builtin_IB_hfdpas";
|
|
const StringRef DpasFuncsResolution::WI_PREFIX_BFDPAS = "__builtin_IB_bfdpas";
|
|
const StringRef DpasFuncsResolution::SG_PREFIX_HFDPAS = "__builtin_IB_sub_group_hfdpas";
|
|
const StringRef DpasFuncsResolution::SG_PREFIX_BFDPAS = "__builtin_IB_sub_group_bfdpas";
|
|
const StringRef DpasFuncsResolution::SG_PREFIX_SDPAS16 = "__builtin_IB_sub_group16_sdpas";
|
|
|
|
// Register pass to igc-opt
|
|
#define PASS_FLAG \
|
|
"igc-arith-funcs-translation" // This string was changed from "igc-dpas-funcs-translation" due to IP leaks concers.
|
|
#define PASS_DESCRIPTION \
|
|
"Translate arithmetic builtin functions into igc intrinsics" // This string was changed from "Translate dpas builtin
|
|
// functions into igc intrinsics" due to IP leaks
|
|
// concers.
|
|
#define PASS_CFG_ONLY false
|
|
#define PASS_ANALYSIS false
|
|
IGC_INITIALIZE_PASS_BEGIN(DpasFuncsResolution, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
|
|
IGC_INITIALIZE_PASS_DEPENDENCY(CodeGenContextWrapper)
|
|
IGC_INITIALIZE_PASS_DEPENDENCY(MetaDataUtilsWrapper)
|
|
IGC_INITIALIZE_PASS_END(DpasFuncsResolution, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
|
|
|
|
DpasFuncsResolution::DpasFuncsResolution(void) : FunctionPass(ID) {
|
|
initializeDpasFuncsResolutionPass(*PassRegistry::getPassRegistry());
|
|
}
|
|
|
|
bool DpasFuncsResolution::runOnFunction(Function &F) {
|
|
m_pCtx = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
|
|
m_changed = false;
|
|
|
|
visit(F);
|
|
|
|
if (!m_ErrorMsg.empty()) {
|
|
m_pCtx->EmitError(m_ErrorMsg.c_str(), &F);
|
|
m_ErrorMsg.clear();
|
|
}
|
|
return m_changed;
|
|
}
|
|
|
|
void DpasFuncsResolution::visitCallInst(CallInst &CI) {
|
|
// Skip if there is any error
|
|
if (!m_ErrorMsg.empty()) {
|
|
return;
|
|
}
|
|
|
|
if (processSrnd(CI)) {
|
|
return;
|
|
}
|
|
// Handle bf cvt if it is.
|
|
if (processCvt(CI)) {
|
|
return;
|
|
}
|
|
|
|
/// Process DPAS intrinsics
|
|
Function *func = CI.getCalledFunction();
|
|
if (!func)
|
|
return;
|
|
StringRef funcName = func->getName();
|
|
LLVMContext &Ctx = CI.getContext();
|
|
Type *intTy = Type::getInt32Ty(Ctx);
|
|
Type *boolTy = Type::getInt1Ty(Ctx);
|
|
|
|
bool IsDpasw = false;
|
|
bool IsIDpas = false;
|
|
// Dimension N is platform specific and is directly correlated to minimum subgroup-size for
|
|
// given platform. If DPAS with the same M, N, K dimensions is executed within a subgroup
|
|
// twice the size of minimum subgroup-size, each work item must contain half of the data
|
|
// compared to the minimum subgroup-size.
|
|
bool IsDoubleSubgroup = false;
|
|
int DstTy, AccTy, PA, PB, SD, RC;
|
|
GenISAIntrinsic::ID iid = GenISAIntrinsic::no_intrinsic;
|
|
bool doVerify = false;
|
|
#if defined(_DEBUG)
|
|
doVerify = true;
|
|
#endif
|
|
if (m_pCtx->platform.hasExecSize16DPAS()) {
|
|
// PVC
|
|
if (funcName.startswith(DpasFuncsResolution::SG_PREFIX_IDPAS16)) {
|
|
const int SG_PREFIX_LEN = DpasFuncsResolution::SG_PREFIX_IDPAS16.size();
|
|
IsIDpas = true;
|
|
if (!demangleSuffix(funcName, SG_PREFIX_LEN, false, IsIDpas, DstTy, AccTy, PA, PB, SD, RC, nullptr))
|
|
return;
|
|
iid = GenISAIntrinsic::GenISA_sub_group_dpas;
|
|
} else if (funcName.startswith(DpasFuncsResolution::SG_PREFIX_IDPAS32N16)) {
|
|
const int SG_PREFIX_LEN = DpasFuncsResolution::SG_PREFIX_IDPAS32N16.size();
|
|
IsIDpas = true;
|
|
IsDoubleSubgroup = true;
|
|
if (!demangleSuffix(funcName, SG_PREFIX_LEN, false, IsIDpas, DstTy, AccTy, PA, PB, SD, RC, nullptr))
|
|
return;
|
|
iid = GenISAIntrinsic::GenISA_sub_group_dpas;
|
|
} else if (funcName.startswith(DpasFuncsResolution::SG_PREFIX_FDPAS16)) {
|
|
const int SG_PREFIX_LEN = DpasFuncsResolution::SG_PREFIX_FDPAS16.size();
|
|
IsIDpas = false;
|
|
if (!demangleSuffix(funcName, SG_PREFIX_LEN, true, IsIDpas, DstTy, AccTy, PA, PB, SD, RC, nullptr))
|
|
return;
|
|
iid = GenISAIntrinsic::GenISA_sub_group_dpas;
|
|
} else if (funcName.startswith(DpasFuncsResolution::SG_PREFIX_FDPAS32N16)) {
|
|
const int SG_PREFIX_LEN = DpasFuncsResolution::SG_PREFIX_FDPAS32N16.size();
|
|
IsIDpas = false;
|
|
IsDoubleSubgroup = true;
|
|
if (!demangleSuffix(funcName, SG_PREFIX_LEN, true, IsIDpas, DstTy, AccTy, PA, PB, SD, RC, nullptr))
|
|
return;
|
|
iid = GenISAIntrinsic::GenISA_sub_group_dpas;
|
|
}
|
|
else {
|
|
return;
|
|
}
|
|
} else {
|
|
if (funcName.startswith(DpasFuncsResolution::SG_PREFIX_IDPAS)) {
|
|
const int SG_PREFIX_LEN = DpasFuncsResolution::SG_PREFIX_IDPAS.size();
|
|
IsIDpas = true;
|
|
if (!demangleSuffix(funcName, SG_PREFIX_LEN, false, IsIDpas, DstTy, AccTy, PA, PB, SD, RC, &IsDpasw))
|
|
return;
|
|
iid = GenISAIntrinsic::GenISA_sub_group_dpas;
|
|
} else if (funcName.startswith(DpasFuncsResolution::SG_PREFIX_FDPAS)) {
|
|
const int SG_PREFIX_LEN = DpasFuncsResolution::SG_PREFIX_FDPAS.size();
|
|
IsIDpas = false;
|
|
if (!demangleSuffix(funcName, SG_PREFIX_LEN, false, IsIDpas, DstTy, AccTy, PA, PB, SD, RC, &IsDpasw))
|
|
return;
|
|
iid = GenISAIntrinsic::GenISA_sub_group_dpas;
|
|
} else if (funcName.startswith(DpasFuncsResolution::WI_PREFIX_IDPAS)) {
|
|
const int WI_PREFIX_LEN = DpasFuncsResolution::WI_PREFIX_IDPAS.size();
|
|
IsIDpas = true;
|
|
if (!demangleSuffix(funcName, WI_PREFIX_LEN, false, IsIDpas, DstTy, AccTy, PA, PB, SD, RC, &IsDpasw))
|
|
return;
|
|
iid = GenISAIntrinsic::GenISA_dpas;
|
|
} else if (funcName.startswith(DpasFuncsResolution::WI_PREFIX_FDPAS)) {
|
|
const int WI_PREFIX_LEN = DpasFuncsResolution::WI_PREFIX_FDPAS.size();
|
|
IsIDpas = false;
|
|
if (!demangleSuffix(funcName, WI_PREFIX_LEN, false, IsIDpas, DstTy, AccTy, PA, PB, SD, RC, &IsDpasw))
|
|
return;
|
|
iid = GenISAIntrinsic::GenISA_dpas;
|
|
} else if (funcName.startswith(DpasFuncsResolution::SG_PREFIX_HFDPAS) ||
|
|
funcName.startswith(DpasFuncsResolution::SG_PREFIX_BFDPAS)) {
|
|
const int SG_PREFIX_HF_LEN = DpasFuncsResolution::SG_PREFIX_HFDPAS.size();
|
|
IsIDpas = false;
|
|
if (!demangleSuffix(funcName, SG_PREFIX_HF_LEN, false, IsIDpas, DstTy, AccTy, PA, PB, SD, RC, &IsDpasw))
|
|
return;
|
|
iid = GenISAIntrinsic::GenISA_sub_group_dpas;
|
|
} else if (funcName.startswith(DpasFuncsResolution::WI_PREFIX_HFDPAS) ||
|
|
funcName.startswith(DpasFuncsResolution::WI_PREFIX_BFDPAS)) {
|
|
const int WI_PREFIX_HF_LEN = DpasFuncsResolution::WI_PREFIX_HFDPAS.size();
|
|
IsIDpas = false;
|
|
if (!demangleSuffix(funcName, WI_PREFIX_HF_LEN, false, IsIDpas, DstTy, AccTy, PA, PB, SD, RC, &IsDpasw))
|
|
return;
|
|
iid = GenISAIntrinsic::GenISA_dpas;
|
|
} else {
|
|
return;
|
|
}
|
|
}
|
|
|
|
#if defined(_DEBUG) || defined(_INTERNAL)
|
|
// verify that intrinsic is valid
|
|
if (!IsDpasw && !m_pCtx->platform.supportDpasInstruction()) {
|
|
m_ErrorMsg = "Dpas instruction not supported!";
|
|
IGC_ASSERT_MESSAGE(0, "Dpas instruction not supported!");
|
|
return;
|
|
}
|
|
if (IsDpasw && !m_pCtx->platform.supportDpaswInstruction()) {
|
|
m_ErrorMsg = "Dpasw instruction not supported!";
|
|
IGC_ASSERT_MESSAGE(0, "Dpasw instruction not supported!");
|
|
return;
|
|
}
|
|
|
|
if (doVerify) {
|
|
// Additional intrinsic checks
|
|
Value *ACC = CI.getArgOperand(0);
|
|
Value *A = CI.getArgOperand(1);
|
|
Value *B = CI.getArgOperand(2);
|
|
|
|
Type *DTy = CI.getType();
|
|
Type *ACCTy = ACC->getType();
|
|
Type *ATy = A->getType();
|
|
Type *BTy = B->getType();
|
|
int D_nelts = DTy->isVectorTy() ? (int)cast<FixedVectorType>(DTy)->getNumElements() : 1;
|
|
int ACC_nelts = ACCTy->isVectorTy() ? (int)cast<FixedVectorType>(ACCTy)->getNumElements() : 1;
|
|
int A_nelts = ATy->isVectorTy() ? (int)cast<FixedVectorType>(ATy)->getNumElements() : 1;
|
|
int B_nelts = BTy->isVectorTy() ? (int)cast<FixedVectorType>(BTy)->getNumElements() : 1;
|
|
Type *D_BaseTy = DTy->getScalarType();
|
|
Type *ACC_BaseTy = ACCTy->getScalarType();
|
|
Type *A_BaseTy = ATy->getScalarType();
|
|
Type *B_BaseTy = BTy->getScalarType();
|
|
|
|
if (IsDoubleSubgroup) {
|
|
IGC_ASSERT_MESSAGE(RC >= 2, "ICE: repeat count of DPAS for double subgroup-size must be >= 2!");
|
|
D_nelts *= 2;
|
|
ACC_nelts *= 2;
|
|
A_nelts *= 2;
|
|
B_nelts *= 2;
|
|
}
|
|
|
|
if (IsIDpas) {
|
|
uint32_t Abits = getPrecisionInBits((PrecisionType)PA);
|
|
uint32_t Bbits = getPrecisionInBits((PrecisionType)PB);
|
|
bool is_2xint8 = (Abits != 8 && Bbits != 8);
|
|
uint32_t AbitsPerDepth = Abits * (is_2xint8 ? 8 : 4);
|
|
uint32_t BbitsPerDepth = Bbits * (is_2xint8 ? 8 : 4);
|
|
uint32_t B_nDW = (BbitsPerDepth * SD) / 32;
|
|
if (m_pCtx->platform.hasExecSize16DPAS()) {
|
|
// depth is still 8, the subgroup intrinsic will get
|
|
// one-depth data from two work-items.
|
|
AbitsPerDepth = AbitsPerDepth / 2;
|
|
}
|
|
|
|
if (DstTy != DSTACC_INT32 || AccTy != DSTACC_INT32 || D_nelts != RC || ACC_nelts != RC || B_nelts != B_nDW ||
|
|
RC != (IsDpasw ? 2 * A_nelts : A_nelts)) {
|
|
IGC_ASSERT_MESSAGE(0, "ICE: invalid integer dpas instructions!");
|
|
}
|
|
IGC_ASSERT_MESSAGE(A_BaseTy->isIntegerTy(AbitsPerDepth), "ICE: type of dpas[w]'s A wrong!");
|
|
IGC_ASSERT_MESSAGE(B_BaseTy->isIntegerTy(32), "ICE: type of dpas[w]'s B should be int32!");
|
|
IGC_ASSERT_MESSAGE(D_BaseTy->isIntegerTy(32), "ICE: type of dpas[w]'s D should int32!");
|
|
IGC_ASSERT_MESSAGE(ACC_BaseTy->isIntegerTy(32), "ICE: type of dpas[w]'s ACC should int32!");
|
|
}
|
|
else { // fdpas
|
|
bool precOk = (PA == PB);
|
|
IGC_ASSERT_MESSAGE(D_nelts == RC, "ICE: dpas intrinsic has mismatched vector sizes of arguments!");
|
|
IGC_ASSERT_MESSAGE(ACC_nelts == RC, "ICE: dpas intrinsic has mismatched vector sizes of arguments!");
|
|
IGC_ASSERT_MESSAGE(B_nelts == SD, "ICE: dpas intrinsic has mismatched vector sizes of arguments!");
|
|
IGC_ASSERT_MESSAGE(precOk, "ICE: dpas's A and B have illegal type combination!");
|
|
IGC_ASSERT_MESSAGE(B_BaseTy->isIntegerTy(32) || (PB == PrecisionType::TF32 && B_BaseTy->isFloatTy()),
|
|
"ICE: dpas's arg B shall have base type int32 or float!");
|
|
IGC_ASSERT_MESSAGE(
|
|
(RC == (IsDpasw ? 2 * A_nelts : A_nelts) || (PA == PrecisionType::TF32 && (RC == 2 * A_nelts))),
|
|
"ICE: dpas's arg A has wrong element size!");
|
|
|
|
uint32_t AbitsPerDepth = 32;
|
|
if (m_pCtx->platform.hasExecSize16DPAS()) {
|
|
AbitsPerDepth = AbitsPerDepth / 2;
|
|
}
|
|
|
|
IGC_ASSERT_MESSAGE(A_BaseTy->isIntegerTy(AbitsPerDepth) || (PA == PrecisionType::TF32 && A_BaseTy->isFloatTy()),
|
|
"ICE: dpas intrinsic's A has wrong base type!");
|
|
if (PA == PrecisionType::TF32) {
|
|
if (!(DstTy == DSTACC_FLOAT && AccTy == DSTACC_FLOAT)) {
|
|
IGC_ASSERT_MESSAGE(false, "ICE: wrong type of dst/acc for TF32 dpas!");
|
|
}
|
|
}
|
|
|
|
bool typeOK = false;
|
|
if (DstTy == DSTACC_BF16 || AccTy == DSTACC_BF16) {
|
|
typeOK = (typeOK || PA == PrecisionType::BF16);
|
|
IGC_ASSERT_MESSAGE(typeOK, "ICE: wrong type of dpas dst/acc!");
|
|
} else if (DstTy == DSTACC_FP16 || AccTy == DSTACC_FP16) {
|
|
typeOK = (typeOK || PA == PrecisionType::FP16);
|
|
IGC_ASSERT_MESSAGE(typeOK, "ICE: wrong type of dpas dst/acc!");
|
|
}
|
|
}
|
|
}
|
|
|
|
#endif
|
|
|
|
Value *args[8];
|
|
args[0] = CI.getArgOperand(0);
|
|
args[1] = CI.getArgOperand(1);
|
|
|
|
Value *B = CI.getArgOperand(2);
|
|
Type *BTy = B->getType();
|
|
if (FixedVectorType *BVecTy = dyn_cast<FixedVectorType>(BTy); BVecTy && BTy->getScalarType()->isFloatTy()) {
|
|
B = CastInst::Create(Instruction::CastOps::BitCast, B,
|
|
FixedVectorType::get(intTy, (unsigned)BVecTy->getNumElements()), B->getName() + ".cast", &CI);
|
|
}
|
|
args[2] = B;
|
|
|
|
args[3] = ConstantInt::get(intTy, PA);
|
|
args[4] = ConstantInt::get(intTy, PB);
|
|
args[5] = ConstantInt::get(intTy, SD);
|
|
args[6] = ConstantInt::get(intTy, RC);
|
|
args[7] = ConstantInt::get(boolTy, IsDpasw);
|
|
|
|
// ITys: overload types for this intrinsic
|
|
Type *ITys[4] = {func->getReturnType(), args[0]->getType(), args[1]->getType(), args[2]->getType()};
|
|
Function *dpasFunc = GenISAIntrinsic::getDeclaration(func->getParent(), iid, ITys);
|
|
|
|
Instruction *dpasCall = CallInst::Create(dpasFunc, args, VALUE_NAME("dpas"), &CI);
|
|
|
|
updateDebugLoc(&CI, dpasCall);
|
|
|
|
CI.replaceAllUsesWith(dpasCall);
|
|
CI.eraseFromParent();
|
|
|
|
m_changed = true;
|
|
}
|
|
|
|
bool DpasFuncsResolution::processCvt(CallInst &CI) {
|
|
Function *func = CI.getCalledFunction();
|
|
if (!func)
|
|
return false;
|
|
StringRef funcName = func->getName();
|
|
LLVMContext &Ctx = CI.getContext();
|
|
Type *intTy = Type::getInt32Ty(Ctx);
|
|
Type *boolTy = Type::getInt1Ty(Ctx);
|
|
|
|
int FP_RM = ROUND_TO_NEAREST_EVEN; // default
|
|
int VecLen;
|
|
bool isSat;
|
|
GenISAIntrinsic::ID iid;
|
|
Value *args[3];
|
|
uint32_t argslen;
|
|
if (funcName.startswith("__builtin_IB_ftobf_")) {
|
|
if (!demangleFCvtSuffix(funcName, (int)sizeof("__builtin_IB_ftobf_") - 1, &FP_RM, &VecLen, nullptr))
|
|
return false;
|
|
|
|
iid = GenISAIntrinsic::GenISA_ftobf;
|
|
args[0] = CI.getArgOperand(0); // value to be converted
|
|
args[1] = ConstantInt::get(intTy, FP_RM); // rounding mode
|
|
argslen = 2;
|
|
} else if (funcName.startswith("__builtin_IB_bftof_")) {
|
|
// It is a precise conversion, no RM needed!
|
|
// Note that sizeof() includes the ending '\0', so need to do -1!
|
|
if (!demangleFCvtSuffix(funcName, (int)sizeof("__builtin_IB_bftof_") - 1, nullptr, &VecLen, nullptr))
|
|
return false;
|
|
|
|
iid = GenISAIntrinsic::GenISA_bftof;
|
|
args[0] = CI.getArgOperand(0);
|
|
argslen = 1;
|
|
} else if (funcName.startswith("__builtin_IB_2fto2bf_")) {
|
|
if (!demangleFCvtSuffix(funcName, (int)sizeof("__builtin_IB_2fto2bf_") - 1, &FP_RM, &VecLen, nullptr))
|
|
return false;
|
|
|
|
iid = GenISAIntrinsic::GenISA_2fto2bf;
|
|
args[0] = CI.getArgOperand(0); // value to be converted
|
|
args[1] = CI.getArgOperand(1); // value to be converted
|
|
args[2] = ConstantInt::get(intTy, FP_RM); // rounding mode
|
|
argslen = 3;
|
|
} else if (funcName.startswith("__builtin_IB_hftobf8_")) {
|
|
int sz = (int)sizeof("__builtin_IB_hftobf8_");
|
|
if (!demangleFCvtSuffix(funcName, sz - 1, nullptr, &VecLen, &isSat))
|
|
return false;
|
|
|
|
iid = GenISAIntrinsic::GenISA_hftobf8;
|
|
args[0] = CI.getArgOperand(0); // value to be converted
|
|
args[1] = ConstantInt::get(intTy, FP_RM); // rounding mode
|
|
args[2] = ConstantInt::get(boolTy, isSat); // saturation
|
|
argslen = 3;
|
|
} else if (funcName.startswith("__builtin_IB_bf8tohf_")) {
|
|
int sz = (int)sizeof("__builtin_IB_bf8tohf_");
|
|
// It is a precise conversion, no RM needed!
|
|
// Note that sizeof() includes the ending '\0', so need to do -1!
|
|
if (!demangleFCvtSuffix(funcName, sz - 1, nullptr, &VecLen, nullptr))
|
|
return false;
|
|
|
|
iid = GenISAIntrinsic::GenISA_bf8tohf;
|
|
args[0] = CI.getArgOperand(0);
|
|
argslen = 1;
|
|
} else if (funcName.startswith("__builtin_IB_hftohf8_")) {
|
|
int sz = (int)sizeof("__builtin_IB_hftohf8_");
|
|
if (!demangleFCvtSuffix(funcName, sz - 1, nullptr, &VecLen, &isSat))
|
|
return false;
|
|
|
|
iid = GenISAIntrinsic::GenISA_hftohf8;
|
|
args[0] = CI.getArgOperand(0); // value to be converted
|
|
args[1] = ConstantInt::get(intTy, FP_RM); // rounding mode
|
|
args[2] = ConstantInt::get(boolTy, isSat); // saturation
|
|
argslen = 3;
|
|
} else if (funcName.startswith("__builtin_IB_hf8tohf_")) {
|
|
int sz = (int)sizeof("__builtin_IB_hf8tohf_");
|
|
// It is a precise conversion, no RM needed!
|
|
// Note that sizeof() includes the ending '\0', so need to do -1!
|
|
if (!demangleFCvtSuffix(funcName, sz - 1, nullptr, &VecLen, nullptr))
|
|
return false;
|
|
|
|
iid = GenISAIntrinsic::GenISA_hf8tohf;
|
|
args[0] = CI.getArgOperand(0);
|
|
argslen = 1;
|
|
} else if (funcName.startswith("__builtin_IB_ftotf32_")) {
|
|
if (!demangleFCvtSuffix(funcName, (int)sizeof("__builtin_IB_ftotf32_") - 1, nullptr, &VecLen, nullptr))
|
|
return false;
|
|
|
|
iid = GenISAIntrinsic::GenISA_ftotf32;
|
|
args[0] = CI.getArgOperand(0); // value to be converted
|
|
args[1] = ConstantInt::get(intTy, FP_RM); // rounding mode
|
|
argslen = 2;
|
|
} else {
|
|
return false;
|
|
}
|
|
|
|
// Sanity check
|
|
if (!m_pCtx->platform.supportDpasInstruction()) {
|
|
m_ErrorMsg = "bf conversion instruction not supported!";
|
|
IGC_ASSERT_MESSAGE(0, "bf conversion instruction not supported!");
|
|
return true;
|
|
}
|
|
Type *Ty = CI.getType();
|
|
FixedVectorType *VTy = dyn_cast<FixedVectorType>(Ty);
|
|
Type *ETy = VTy ? VTy->getElementType() : Ty;
|
|
Type *Opnd0Ty = CI.getArgOperand(0)->getType();
|
|
FixedVectorType *VOpnd0Ty = dyn_cast<FixedVectorType>(Opnd0Ty);
|
|
Type *EOpnd0Ty = VOpnd0Ty ? VOpnd0Ty->getElementType() : Opnd0Ty;
|
|
uint32_t n = VTy ? (uint32_t)VTy->getNumElements() : 1;
|
|
uint32_t n0 = VOpnd0Ty ? (uint32_t)VOpnd0Ty->getNumElements() : 1;
|
|
switch (iid) {
|
|
case GenISAIntrinsic::GenISA_ftobf:
|
|
case GenISAIntrinsic::GenISA_2fto2bf:
|
|
case GenISAIntrinsic::GenISA_bftof: {
|
|
if ((n != n0 || n != VecLen) ||
|
|
(iid == GenISAIntrinsic::GenISA_ftobf && !(EOpnd0Ty->isFloatTy() && ETy->isIntegerTy(16))) ||
|
|
(iid == GenISAIntrinsic::GenISA_2fto2bf && !(EOpnd0Ty->isFloatTy() && ETy->isIntegerTy(32))) ||
|
|
(iid == GenISAIntrinsic::GenISA_bftof && !(EOpnd0Ty->isIntegerTy(16) && ETy->isFloatTy()))) {
|
|
m_ErrorMsg = "Wrong argument types in bf conversion functions!";
|
|
IGC_ASSERT_MESSAGE(0, "Wrong argument types in bf conversion functions!");
|
|
return true;
|
|
}
|
|
break;
|
|
}
|
|
case GenISAIntrinsic::GenISA_hftobf8:
|
|
case GenISAIntrinsic::GenISA_bf8tohf: {
|
|
if ((n != n0 || n != VecLen) ||
|
|
(iid == GenISAIntrinsic::GenISA_hftobf8 && !(EOpnd0Ty->isHalfTy() && ETy->isIntegerTy(8))) ||
|
|
(iid == GenISAIntrinsic::GenISA_bf8tohf && !(EOpnd0Ty->isIntegerTy(8) && ETy->isHalfTy()))) {
|
|
m_ErrorMsg = "Wrong argument types in bf8 conversion functions!";
|
|
IGC_ASSERT_MESSAGE(0, "Wrong argument types in bf8 conversion functions!");
|
|
return true;
|
|
}
|
|
break;
|
|
}
|
|
case GenISAIntrinsic::GenISA_ftotf32: {
|
|
if ((n != n0 || n != VecLen) ||
|
|
(iid == GenISAIntrinsic::GenISA_ftotf32 && !(EOpnd0Ty->isFloatTy() && ETy->isFloatTy()))) {
|
|
m_ErrorMsg = "Wrong argument types in tf32 conversion functions!";
|
|
IGC_ASSERT_MESSAGE(0, "Wrong argument types in tf32 conversion functions!");
|
|
return true;
|
|
}
|
|
break;
|
|
}
|
|
default:
|
|
break;
|
|
}
|
|
|
|
ArrayRef<Value *> ii_args(args, argslen);
|
|
|
|
// Only need to specify retType and 1st arg's type.
|
|
Type *ITys[2] = {func->getReturnType(), args[0]->getType()};
|
|
Function *cvtFunc = GenISAIntrinsic::getDeclaration(func->getParent(), iid, ITys);
|
|
const char *cvt = "bf_cvt";
|
|
if (iid == GenISAIntrinsic::GenISA_hftobf8 || iid == GenISAIntrinsic::GenISA_bf8tohf) {
|
|
cvt = "bf8_cvt";
|
|
} else if (iid == GenISAIntrinsic::GenISA_hftohf8 || iid == GenISAIntrinsic::GenISA_hf8tohf) {
|
|
cvt = "hf8_cvt";
|
|
} else if (iid == GenISAIntrinsic::GenISA_ftotf32) {
|
|
cvt = "tf32_cvt";
|
|
}
|
|
Instruction *cvtCall = CallInst::Create(cvtFunc, ii_args, cvt, &CI);
|
|
|
|
updateDebugLoc(&CI, cvtCall);
|
|
|
|
CI.replaceAllUsesWith(cvtCall);
|
|
CI.eraseFromParent();
|
|
|
|
m_changed = true;
|
|
return true;
|
|
}
|
|
|
|
bool DpasFuncsResolution::processSrnd(CallInst &CI) {
|
|
Function *func = CI.getCalledFunction();
|
|
if (!func)
|
|
return false;
|
|
|
|
StringRef funcName = func->getName();
|
|
|
|
int VecLen;
|
|
bool isSat = false;
|
|
GenISAIntrinsic::ID iid;
|
|
if (funcName.consume_front("__builtin_IB_srnd_ftohf_")) {
|
|
if (!demangleFCvtSuffix(funcName, 0, nullptr, &VecLen, nullptr))
|
|
return false;
|
|
iid = GenISAIntrinsic::GenISA_srnd_ftohf;
|
|
} else if (funcName.consume_front("__builtin_IB_srnd_hftobf8_")) {
|
|
if (!demangleFCvtSuffix(funcName, 0, nullptr, &VecLen, &isSat))
|
|
return false;
|
|
iid = GenISAIntrinsic::GenISA_srnd_hftobf8;
|
|
}
|
|
else {
|
|
return false;
|
|
}
|
|
|
|
Type *boolTy = Type::getInt1Ty(CI.getContext());
|
|
Value *args[3] = {CI.getArgOperand(0), CI.getArgOperand(1), ConstantInt::get(boolTy, isSat)};
|
|
ArrayRef<Value *> ii_args(args, 3);
|
|
|
|
Type *ITys[4] = {func->getReturnType(), args[0]->getType(), args[1]->getType(), boolTy};
|
|
Function *srndFunc = GenISAIntrinsic::getDeclaration(func->getParent(), iid, ITys);
|
|
Instruction *srndCall = CallInst::Create(srndFunc, ii_args, VALUE_NAME("srnd"), &CI);
|
|
|
|
#if defined(_DEBUG)
|
|
{ // Verify arguments
|
|
Type *Ty = CI.getType();
|
|
FixedVectorType *VTy = dyn_cast<FixedVectorType>(Ty);
|
|
Type *ETy = VTy ? VTy->getElementType() : Ty;
|
|
Type *Opnd0Ty = CI.getArgOperand(0)->getType();
|
|
Type *Opnd1Ty = CI.getArgOperand(1)->getType();
|
|
FixedVectorType *VOpnd1Ty = dyn_cast<FixedVectorType>(Opnd1Ty);
|
|
Type *EOpnd1Ty = VOpnd1Ty ? VOpnd1Ty->getElementType() : Opnd1Ty;
|
|
FixedVectorType *VOpnd0Ty = dyn_cast<FixedVectorType>(Opnd0Ty);
|
|
Type *EOpnd0Ty = VOpnd0Ty ? VOpnd0Ty->getElementType() : Opnd0Ty;
|
|
uint32_t n = VTy ? (uint32_t)VTy->getNumElements() : 1;
|
|
uint32_t n0 = VOpnd0Ty ? (uint32_t)VOpnd0Ty->getNumElements() : 1;
|
|
|
|
bool supported = false;
|
|
supported |= (ETy->isHalfTy() && EOpnd0Ty->isFloatTy() && EOpnd1Ty->isIntegerTy(16));
|
|
supported |= (ETy->isIntegerTy(8) && EOpnd0Ty->isHalfTy() && EOpnd1Ty->isIntegerTy(8));
|
|
supported |= (ETy->isIntegerTy(8) && EOpnd0Ty->isIntegerTy(16) && EOpnd1Ty->isIntegerTy(8));
|
|
|
|
if (n != n0 || n != VecLen || !supported) {
|
|
m_ErrorMsg = "Wrong argument types in srnd builtin!";
|
|
IGC_ASSERT_MESSAGE(0, "Wrong argument types in srnd builtin!");
|
|
return true;
|
|
}
|
|
}
|
|
#endif
|
|
updateDebugLoc(&CI, srndCall);
|
|
CI.replaceAllUsesWith(srndCall);
|
|
CI.eraseFromParent();
|
|
|
|
m_changed = true;
|
|
return true;
|
|
}
|
|
|
|
|
|
//
|
|
// FN pattern:
|
|
// [w]_<dstty>_<accty>_<a's precision>_<b's precision>_<depth>_<rcount>
|
|
// <a's precision>
|
|
// <b's precision>
|
|
// 1. float version: <bf|hf>_
|
|
// 2. integer version: <u|s><2|4|8>_
|
|
// dstty/accty:
|
|
// 1. float version: f
|
|
// 2. integer version: int32
|
|
// If [w] is present, it is dpasw.
|
|
//
|
|
// PVC supports:
|
|
// additional dstty/accty: bf|hf
|
|
// additional precision : tf32
|
|
//
|
|
bool DpasFuncsResolution::demangleSuffix(StringRef FN, int StartPos, bool HasDstAcc, bool IsIDpas, int &DstTy,
|
|
int &AccTy, int &PA, int &PB, int &SD, int &RC, bool *IsDpasw) {
|
|
size_t sz = FN.size();
|
|
size_t rem = sz - StartPos;
|
|
size_t i = StartPos;
|
|
|
|
// Check if it is wide version of dpas
|
|
if (IsDpasw != nullptr) {
|
|
*IsDpasw = parseW(FN, i, rem);
|
|
}
|
|
|
|
if (HasDstAcc) {
|
|
DstTy = parseDstAccType(FN, i, rem);
|
|
AccTy = parseDstAccType(FN, i, rem);
|
|
} else {
|
|
DstTy = IsIDpas ? DstAccType::DSTACC_INT32 : DstAccType::DSTACC_FLOAT;
|
|
AccTy = DstTy;
|
|
}
|
|
|
|
bool supportDeprecated = true;
|
|
if (!IsIDpas && !HasDstAcc && supportDeprecated && rem == 4) {
|
|
// deprecated format _8_<1-8>
|
|
PA = PrecisionType::BF16;
|
|
PB = PA;
|
|
} else {
|
|
// parse precisions
|
|
PA = parsePrecision(FN, i, rem);
|
|
PB = parsePrecision(FN, i, rem);
|
|
}
|
|
|
|
// depth and repeat count
|
|
SD = parseDepth(FN, i, rem);
|
|
RC = parseRCount(FN, i, rem);
|
|
|
|
if (RC == -1) {
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool DpasFuncsResolution::demangleFCvtSuffix(StringRef FN, int StartPos, int *pRM, int *pVecLen, bool *pIsSat) {
|
|
int sz = (int)FN.size();
|
|
int rem = sz - StartPos;
|
|
int RM = ROUND_TO_NEAREST_EVEN;
|
|
int VecLen = 1;
|
|
bool isSat = false;
|
|
|
|
int i = StartPos;
|
|
if (rem >= 5 && pRM != nullptr) {
|
|
// if it is a valid intrinsic, it must be <rm>_<1|2|4|8|16>[_sat]
|
|
// <rm> is rte|rtp|rtn|rtz.
|
|
if (FN[i] != 'r' || FN[i + 1] != 't' || FN[i + 3] != '_') {
|
|
return false;
|
|
}
|
|
switch (FN[i + 2]) {
|
|
default:
|
|
return false;
|
|
case 'e':
|
|
RM = ROUND_TO_NEAREST_EVEN;
|
|
break;
|
|
case 'p':
|
|
RM = ROUND_TO_POSITIVE;
|
|
break;
|
|
case 'n':
|
|
RM = ROUND_TO_NEGATIVE;
|
|
break;
|
|
case 'z':
|
|
RM = ROUND_TO_ZERO;
|
|
break;
|
|
}
|
|
|
|
i += 4;
|
|
rem -= 4;
|
|
}
|
|
|
|
int c = (FN[i] - '0');
|
|
int c1 = (rem >= 2 ? (FN[i + 1] - '0') : 0);
|
|
|
|
// relax vector size to be 1-16 here.
|
|
if (rem >= 2 && c == 1 && c1 >= 0 && c1 <= 6) {
|
|
VecLen = 10 + c1;
|
|
i += 2;
|
|
rem -= 2;
|
|
} else if (rem >= 1 && c >= 0 && c <= 9) {
|
|
VecLen = c;
|
|
i += 1;
|
|
rem -= 1;
|
|
} else {
|
|
// missing veclen
|
|
return false;
|
|
}
|
|
|
|
// saturation
|
|
if (pIsSat) {
|
|
if (rem >= 1 && FN[i] == '_') {
|
|
++i;
|
|
--rem;
|
|
}
|
|
|
|
if (rem == 3 && FN[i] == 's' && FN[i + 1] == 'a' && FN[i + 2] == 't') {
|
|
i += 3;
|
|
rem -= 3;
|
|
isSat = true;
|
|
}
|
|
}
|
|
|
|
if (rem != 0) {
|
|
return false;
|
|
}
|
|
|
|
if (pRM) {
|
|
*pRM = RM;
|
|
}
|
|
*pVecLen = VecLen;
|
|
if (pIsSat) {
|
|
*pIsSat = isSat;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
DstAccType DpasFuncsResolution::parseDstAccType(StringRef StrRef, size_t &StrPos, size_t &StrRem) {
|
|
DstAccType ty = DSTACC_UNUSED;
|
|
if (StrPos != StringRef::npos && StrRem >= 2) {
|
|
char c0 = StrRef[StrPos];
|
|
char c1 = StrRef[StrPos + 1];
|
|
char c2 = StrRem >= 3 ? StrRef[StrPos + 2] : 0;
|
|
if (c0 == '_' && c1 == 'd') { // "_d"
|
|
ty = DSTACC_INT32;
|
|
StrPos += 2;
|
|
StrRem -= 2;
|
|
} else if (c0 == '_' && c1 == 'f') { // "_f"
|
|
ty = DSTACC_FLOAT;
|
|
StrPos += 2;
|
|
StrRem -= 2;
|
|
} else if (c0 == '_' && (c1 == 'b' || c1 == 'h') && c2 == 'f') { // "_bf" or "_hf"
|
|
ty = (c1 == 'b' ? DSTACC_BF16 : DSTACC_FP16);
|
|
StrPos += 3;
|
|
StrRem -= 3;
|
|
}
|
|
}
|
|
if (ty == DSTACC_UNUSED) {
|
|
// Not valid type
|
|
StrPos = StringRef::npos;
|
|
StrRem = 0;
|
|
}
|
|
return ty;
|
|
}
|
|
|
|
bool DpasFuncsResolution::parseW(StringRef StrRef, size_t &StrPos, size_t &StrRem) {
|
|
if (StrPos != StringRef::npos && StrRem >= 1) {
|
|
char c0 = StrRef[StrPos];
|
|
if (c0 == 'w') {
|
|
StrPos += 1;
|
|
StrRem -= 1;
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
PrecisionType DpasFuncsResolution::parsePrecision(StringRef StrRef, size_t &StrPos, size_t &StrRem) {
|
|
PrecisionType ty = PrecisionType::PRECISION_UNUSED;
|
|
if (StrPos != StringRef::npos && StrRem >= 3) {
|
|
char c0 = StrRef[StrPos];
|
|
char c1 = StrRef[StrPos + 1];
|
|
char c2 = StrRef[StrPos + 2];
|
|
char c3 = StrRem >= 4 ? StrRef[StrPos + 3] : 0;
|
|
char c4 = StrRem >= 5 ? StrRef[StrPos + 4] : 0;
|
|
if (c0 == '_' && c1 == 't' && c2 == 'f' && c3 == '3' && c4 == '2') { // "_tf32"
|
|
ty = PrecisionType::TF32;
|
|
StrPos += 5;
|
|
StrRem -= 5;
|
|
} else if (c0 == '_' && c1 == 'b' && c2 == 'f') { // "_bf"
|
|
ty = PrecisionType::BF16;
|
|
StrPos += 3;
|
|
StrRem -= 3;
|
|
} else if (c0 == '_' && c1 == 'h' && c2 == 'f') { // "_hf"
|
|
ty = PrecisionType::FP16;
|
|
StrPos += 3;
|
|
StrRem -= 3;
|
|
} else if (c0 == '_' && c1 == 'u' && (c2 == '2' || c2 == '4' || c2 == '8')) { // "_u<2|4|8>"
|
|
ty = (c2 == '2' ? PrecisionType::U2 : (c2 == '4' ? PrecisionType::U4 : PrecisionType::U8));
|
|
StrPos += 3;
|
|
StrRem -= 3;
|
|
} else if (c0 == '_' && c1 == 's' && (c2 == '2' || c2 == '4' || c2 == '8')) { // "s<2|4|8>_"
|
|
ty = (c2 == '2' ? PrecisionType::S2 : (c2 == '4' ? PrecisionType::S4 : PrecisionType::S8));
|
|
StrPos += 3;
|
|
StrRem -= 3;
|
|
}
|
|
}
|
|
if (ty == PRECISION_UNUSED) {
|
|
// Not a valid precision
|
|
StrPos = StringRef::npos;
|
|
StrRem = 0;
|
|
}
|
|
return ty;
|
|
};
|
|
|
|
int DpasFuncsResolution::parseDepth(StringRef StrRef, size_t &StrPos, size_t &StrRem) {
|
|
if (StrPos != StringRef::npos && StrRem >= 2) {
|
|
char c0 = StrRef[StrPos];
|
|
char c1 = StrRef[StrPos + 1];
|
|
if (c0 == '_' && c1 == '8') {
|
|
StrPos += 2;
|
|
StrRem -= 2;
|
|
return 8;
|
|
}
|
|
if (StrRem >= 3 && c0 == '_' && c1 == '1' && StrRef[StrPos + 2] == '6') {
|
|
StrPos += 3;
|
|
StrRem -= 3;
|
|
return 16;
|
|
}
|
|
}
|
|
StrPos = StringRef::npos;
|
|
StrRem = 0;
|
|
return -1;
|
|
}
|
|
|
|
int DpasFuncsResolution::parseRCount(StringRef StrRef, size_t &StrPos, size_t &StrRem) {
|
|
if (StrPos != StringRef::npos && StrRem >= 2) {
|
|
char c0 = StrRef[StrPos];
|
|
char c1 = StrRef[StrPos + 1];
|
|
int rc = c1 - '0';
|
|
if (c0 == '_' && rc >= 1 && rc <= 8) {
|
|
StrPos += 2;
|
|
StrRem -= 2;
|
|
return rc;
|
|
}
|
|
}
|
|
StrPos = StringRef::npos;
|
|
StrRem = 0;
|
|
return -1;
|
|
}
|
|
|
|
FunctionPass *IGC::createDpasFuncsResolutionPass() { return new DpasFuncsResolution(); }
|