diff --git a/runtime/command_queue/local_id_gen.cpp b/runtime/command_queue/local_id_gen.cpp index fa655cf572..6cc7ddd784 100644 --- a/runtime/command_queue/local_id_gen.cpp +++ b/runtime/command_queue/local_id_gen.cpp @@ -86,36 +86,42 @@ inline void generateLocalIDsWithLayoutForImages(void *b, const std::array xDelta; - auto buffer = reinterpret_cast(b); uint16_t offset = 0u; auto numGrfs = (localWorkgroupSize.at(0) * localWorkgroupSize.at(1) * localWorkgroupSize.at(2) + (simd - 1)) / simd; + uint8_t xMask = simd == 8u ? 0b1 : 0b11; uint16_t x = 0u; uint16_t y = 0u; for (auto grfId = 0; grfId < numGrfs; grfId++) { auto rowX = buffer + offset; auto rowY = buffer + offset + rowWidth; auto rowZ = buffer + offset + 2 * rowWidth; + uint16_t extraX = 0u; + uint16_t extraY = 0u; for (uint8_t i = 0u; i < simd; i++) { - if (i == yDelta * xDelta && earlyGrowX) { - x += xDelta; + if (i > 0) { + extraX++; + if (extraX == xDelta) { + extraX = 0u; + } + if ((i & xMask) == 0) { + extraY++; + if (y + extraY == localWorkgroupSize.at(1)) { + extraY = 0; + x += xDelta; + } + } } if (x == localWorkgroupSize.at(0)) { x = 0u; y += yDelta; - if (y == localWorkgroupSize.at(1)) { + if (y >= localWorkgroupSize.at(1)) { y = 0u; } } - rowX[i] = (x + (i & (xDelta - 1))); - rowY[i] = (y + i / xDelta); - if (rowY[i] >= localWorkgroupSize.at(1)) { - rowY[i] -= localWorkgroupSize.at(1); - } + rowX[i] = x + extraX; + rowY[i] = y + extraY; rowZ[i] = 0u; } x += xDelta; diff --git a/unit_tests/command_queue/local_id_tests.cpp b/unit_tests/command_queue/local_id_tests.cpp index d3c54e893c..56325cc015 100644 --- a/unit_tests/command_queue/local_id_tests.cpp +++ b/unit_tests/command_queue/local_id_tests.cpp @@ -320,15 +320,19 @@ struct LocalIdsLayoutForImagesTest : ::testing::TestWithParam localWorkSize.at(1) && j == 16u) { baseX += xDelta; if (baseX == localWorkSize.at(0)) { baseX = 0; @@ -339,17 +343,20 @@ struct LocalIdsLayoutForImagesTest : ::testing::TestWithParam= localWorkSize.at(1)) { - expectedY -= localWorkSize.at(1); + expectedY -= (localWorkSize.at(1) - baseY); } EXPECT_EQ(buffer[i * 3 * rowWidth + rowWidth + j], expectedY); } // validate Z row for (int j = 0; j < simd; j++) { + if (simd * i + j == totalLocalIds) + break; EXPECT_EQ(buffer[i * 3 * rowWidth + 2 * rowWidth + j], 0u); } }