Update getKernelInfo method

add root device index parameter to return proper kernel info

Related-To: NEO-5001
Signed-off-by: Mateusz Jablonski <mateusz.jablonski@intel.com>
This commit is contained in:
Mateusz Jablonski
2020-12-07 14:41:52 +00:00
committed by Compute-Runtime-Automation
parent 864f069b8f
commit c8d1e082dd
69 changed files with 381 additions and 337 deletions

View File

@@ -27,15 +27,16 @@ class VmeBuiltinDispatchInfoBuilder : public BuiltinDispatchInfoBuilder {
populate(builtinOp,
mediaKernelsBuildOptions,
kernelName, vmeKernel);
widthArgNum = vmeKernel->getKernelInfo().getArgNumByName("width");
heightArgNum = vmeKernel->getKernelInfo().getArgNumByName("height");
strideArgNum = vmeKernel->getKernelInfo().getArgNumByName("stride");
acceleratorArgNum = vmeKernel->getKernelInfo().getArgNumByName("accelerator");
srcImgArgNum = vmeKernel->getKernelInfo().getArgNumByName("srcImg");
refImgArgNum = vmeKernel->getKernelInfo().getArgNumByName("refImg");
motionVectorBufferArgNum = vmeKernel->getKernelInfo().getArgNumByName("motion_vector_buffer");
predictionMotionVectorBufferArgNum = vmeKernel->getKernelInfo().getArgNumByName("prediction_motion_vector_buffer");
residualsArgNum = vmeKernel->getKernelInfo().getArgNumByName("residuals");
auto rootDeviceIndex = clDevice.getRootDeviceIndex();
widthArgNum = vmeKernel->getKernelInfo(rootDeviceIndex).getArgNumByName("width");
heightArgNum = vmeKernel->getKernelInfo(rootDeviceIndex).getArgNumByName("height");
strideArgNum = vmeKernel->getKernelInfo(rootDeviceIndex).getArgNumByName("stride");
acceleratorArgNum = vmeKernel->getKernelInfo(rootDeviceIndex).getArgNumByName("accelerator");
srcImgArgNum = vmeKernel->getKernelInfo(rootDeviceIndex).getArgNumByName("srcImg");
refImgArgNum = vmeKernel->getKernelInfo(rootDeviceIndex).getArgNumByName("refImg");
motionVectorBufferArgNum = vmeKernel->getKernelInfo(rootDeviceIndex).getArgNumByName("motion_vector_buffer");
predictionMotionVectorBufferArgNum = vmeKernel->getKernelInfo(rootDeviceIndex).getArgNumByName("prediction_motion_vector_buffer");
residualsArgNum = vmeKernel->getKernelInfo(rootDeviceIndex).getArgNumByName("residuals");
}
void getBlkTraits(const Vec3<size_t> &inGws, size_t &gwWidthInBlk, size_t &gwHeightInBlk) const {
@@ -51,6 +52,8 @@ class VmeBuiltinDispatchInfoBuilder : public BuiltinDispatchInfoBuilder {
return false;
}
auto rootDeviceIndex = clDevice.getRootDeviceIndex();
size_t gwWidthInBlk = 0;
size_t gwHeightInBlk = 0;
getBlkTraits(inGws, gwWidthInBlk, gwHeightInBlk);
@@ -59,7 +62,7 @@ class VmeBuiltinDispatchInfoBuilder : public BuiltinDispatchInfoBuilder {
cl_int width = (cl_int)gwWidthInBlk;
cl_int stride = height;
size_t numThreadsX = gwWidthInBlk;
const size_t simdWidth = vmeKernel->getKernelInfo().getMaxSimdSize();
const size_t simdWidth = vmeKernel->getKernelInfo(rootDeviceIndex).getMaxSimdSize();
stride = static_cast<cl_int>(Math::divideAndRoundUp(height * width, numThreadsX));
// update implicit args
@@ -69,7 +72,7 @@ class VmeBuiltinDispatchInfoBuilder : public BuiltinDispatchInfoBuilder {
// Update global work size to force macro-block to HW thread execution model
Vec3<size_t> gws = {numThreadsX * simdWidth, 1, 1};
Vec3<size_t> lws = {vmeKernel->getKernelInfo().kernelDescriptor.kernelAttributes.requiredWorkgroupSize[0], 1, 1};
Vec3<size_t> lws = {vmeKernel->getKernelInfo(rootDeviceIndex).kernelDescriptor.kernelAttributes.requiredWorkgroupSize[0], 1, 1};
DispatchInfoBuilder<SplitDispatch::Dim::d2D, SplitDispatch::SplitMode::NoSplit> builder(clDevice);
builder.setDispatchGeometry(gws, lws, inOffset, gws, lws);
@@ -164,9 +167,10 @@ class VmeBuiltinDispatchInfoBuilder : public BuiltinDispatchInfoBuilder {
template <typename RetType>
RetType getKernelArgByValValue(uint32_t argNum) const {
auto &kai = vmeKernel->getKernelInfo().kernelArgInfo[argNum];
DEBUG_BREAK_IF(kai.kernelArgPatchInfoVector.size() != 1);
const KernelArgPatchInfo &patchInfo = kai.kernelArgPatchInfoVector[0];
auto rootDeviceIndex = clDevice.getRootDeviceIndex();
auto &kernelArgInfo = vmeKernel->getKernelInfo(rootDeviceIndex).kernelArgInfo[argNum];
DEBUG_BREAK_IF(kernelArgInfo.kernelArgPatchInfoVector.size() != 1);
const KernelArgPatchInfo &patchInfo = kernelArgInfo.kernelArgPatchInfoVector[0];
DEBUG_BREAK_IF(sizeof(RetType) > patchInfo.size);
return *(RetType *)(vmeKernel->getCrossThreadData(clDevice.getRootDeviceIndex()) + patchInfo.crossthreadOffset);
}
@@ -255,18 +259,19 @@ class AdvancedVmeBuiltinDispatchInfoBuilder : public VmeBuiltinDispatchInfoBuild
const char *kernelName)
: VmeBuiltinDispatchInfoBuilder(kernelsLib, device, builtinOp,
kernelName) {
flagsArgNum = this->vmeKernel->getKernelInfo().getArgNumByName("flags");
intraSrcImgArgNum = this->vmeKernel->getKernelInfo().getArgNumByName("intraSrcImg");
skipBlockTypeArgNum = this->vmeKernel->getKernelInfo().getArgNumByName("skip_block_type");
searchCostPenaltyArgNum = this->vmeKernel->getKernelInfo().getArgNumByName("search_cost_penalty");
searchCostPrecisionArgNum = this->vmeKernel->getKernelInfo().getArgNumByName("search_cost_precision");
bidirWeightArgNum = this->vmeKernel->getKernelInfo().getArgNumByName("bidir_weight");
predictorsBufferArgNum = this->vmeKernel->getKernelInfo().getArgNumByName("predictors_buffer");
countMotionVectorBufferArgNum = this->vmeKernel->getKernelInfo().getArgNumByName("count_motion_vector_buffer");
skipMotionVectorBufferArgNum = this->vmeKernel->getKernelInfo().getArgNumByName("skip_motion_vector_buffer");
intraSearchPredictorModesArgNum = this->vmeKernel->getKernelInfo().getArgNumByName("intra_search_predictor_modes");
skipResidualsArgNum = this->vmeKernel->getKernelInfo().getArgNumByName("skip_residuals");
intraResidualsArgNum = this->vmeKernel->getKernelInfo().getArgNumByName("intra_residuals");
auto rootDeviceIndex = clDevice.getRootDeviceIndex();
flagsArgNum = this->vmeKernel->getKernelInfo(rootDeviceIndex).getArgNumByName("flags");
intraSrcImgArgNum = this->vmeKernel->getKernelInfo(rootDeviceIndex).getArgNumByName("intraSrcImg");
skipBlockTypeArgNum = this->vmeKernel->getKernelInfo(rootDeviceIndex).getArgNumByName("skip_block_type");
searchCostPenaltyArgNum = this->vmeKernel->getKernelInfo(rootDeviceIndex).getArgNumByName("search_cost_penalty");
searchCostPrecisionArgNum = this->vmeKernel->getKernelInfo(rootDeviceIndex).getArgNumByName("search_cost_precision");
bidirWeightArgNum = this->vmeKernel->getKernelInfo(rootDeviceIndex).getArgNumByName("bidir_weight");
predictorsBufferArgNum = this->vmeKernel->getKernelInfo(rootDeviceIndex).getArgNumByName("predictors_buffer");
countMotionVectorBufferArgNum = this->vmeKernel->getKernelInfo(rootDeviceIndex).getArgNumByName("count_motion_vector_buffer");
skipMotionVectorBufferArgNum = this->vmeKernel->getKernelInfo(rootDeviceIndex).getArgNumByName("skip_motion_vector_buffer");
intraSearchPredictorModesArgNum = this->vmeKernel->getKernelInfo(rootDeviceIndex).getArgNumByName("intra_search_predictor_modes");
skipResidualsArgNum = this->vmeKernel->getKernelInfo(rootDeviceIndex).getArgNumByName("skip_residuals");
intraResidualsArgNum = this->vmeKernel->getKernelInfo(rootDeviceIndex).getArgNumByName("intra_residuals");
}
bool setExplicitArg(uint32_t argIndex, size_t argSize, const void *argVal, cl_int &err) const override {