Fix printf for type BYTE and SHORT

Generated instructions writing to printf buffer require destination
address to be DWORD aligned. Because of that values of type BYTE (1B)
and SHORT (2B) need to be written as 4B value.
This change adds support for this. When trying to read value of type
BYTE or SHORT four bytes are actually read to be aligned with compiler
implementation.

Signed-off-by: Krystian Chmielewski <krystian.chmielewski@intel.com>
This commit is contained in:
Krystian Chmielewski
2022-09-06 11:17:04 +00:00
committed by Compute-Runtime-Automation
parent 3b211c9f8d
commit 2e9574c656
3 changed files with 52 additions and 43 deletions

View File

@ -10,26 +10,20 @@
#include "zello_common.h" #include "zello_common.h"
#include "zello_compile.h" #include "zello_compile.h"
#include <fstream>
#include <iomanip>
#include <iostream> #include <iostream>
#include <numeric>
const char *source = R"===( const char *source = R"===(
__kernel void test_printf(__global char *dst, __global char *src){ __kernel void printf_kernel(char byteValue, short shortValue, int intValue, long longValue){
uint gid = get_global_id(0); printf("byte = %hhd\nshort = %hd\nint = %d\nlong = %ld", byteValue, shortValue, intValue, longValue);
printf("global_id = %d\n", gid);
} }
)==="; )===";
void testPrintfKernel(ze_context_handle_t &context, ze_device_handle_t &device) { void runPrintfKernel(ze_context_handle_t &context, ze_device_handle_t &device) {
ze_module_handle_t module;
ze_kernel_handle_t kernel;
ze_command_queue_handle_t cmdQueue; ze_command_queue_handle_t cmdQueue;
ze_command_list_handle_t cmdList; ze_command_list_handle_t cmdList;
ze_group_count_t dispatchTraits;
ze_command_queue_desc_t cmdQueueDesc = {ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC}; ze_command_queue_desc_t cmdQueueDesc = {ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC};
cmdQueueDesc.ordinal = 0; cmdQueueDesc.ordinal = 0;
cmdQueueDesc.index = 0; cmdQueueDesc.index = 0;
cmdQueueDesc.mode = ZE_COMMAND_QUEUE_MODE_ASYNCHRONOUS; cmdQueueDesc.mode = ZE_COMMAND_QUEUE_MODE_ASYNCHRONOUS;
@ -44,6 +38,7 @@ void testPrintfKernel(ze_context_handle_t &context, ze_device_handle_t &device)
} }
SUCCESS_OR_TERMINATE((0 == spirV.size())); SUCCESS_OR_TERMINATE((0 == spirV.size()));
ze_module_handle_t module;
ze_module_desc_t moduleDesc = {ZE_STRUCTURE_TYPE_MODULE_DESC}; ze_module_desc_t moduleDesc = {ZE_STRUCTURE_TYPE_MODULE_DESC};
moduleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV; moduleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV;
moduleDesc.pInputModule = spirV.data(); moduleDesc.pInputModule = spirV.data();
@ -52,41 +47,35 @@ void testPrintfKernel(ze_context_handle_t &context, ze_device_handle_t &device)
SUCCESS_OR_TERMINATE(zeModuleCreate(context, device, &moduleDesc, &module, nullptr)); SUCCESS_OR_TERMINATE(zeModuleCreate(context, device, &moduleDesc, &module, nullptr));
ze_kernel_handle_t kernel;
ze_kernel_desc_t kernelDesc = {ZE_STRUCTURE_TYPE_KERNEL_DESC}; ze_kernel_desc_t kernelDesc = {ZE_STRUCTURE_TYPE_KERNEL_DESC};
kernelDesc.pKernelName = "test_printf"; kernelDesc.pKernelName = "printf_kernel";
SUCCESS_OR_TERMINATE(zeKernelCreate(module, &kernelDesc, &kernel)); SUCCESS_OR_TERMINATE(zeKernelCreate(module, &kernelDesc, &kernel));
uint32_t groupSizeX = 1; [[maybe_unused]] int8_t byteValue = std::numeric_limits<int8_t>::max();
uint32_t groupSizeY = 1; [[maybe_unused]] int16_t shortValue = std::numeric_limits<int16_t>::max();
uint32_t groupSizeZ = 1; [[maybe_unused]] int32_t intValue = std::numeric_limits<int32_t>::max();
uint32_t globalSizeX = 64; [[maybe_unused]] int64_t longValue = std::numeric_limits<int64_t>::max();
SUCCESS_OR_TERMINATE(zeKernelSuggestGroupSize(kernel, globalSizeX, 1, 1, &groupSizeX, SUCCESS_OR_TERMINATE(zeKernelSetArgumentValue(kernel, 0, sizeof(byteValue), &byteValue));
&groupSizeY, &groupSizeZ)); SUCCESS_OR_TERMINATE(zeKernelSetArgumentValue(kernel, 1, sizeof(shortValue), &shortValue));
SUCCESS_OR_TERMINATE(zeKernelSetArgumentValue(kernel, 2, sizeof(intValue), &intValue));
SUCCESS_OR_TERMINATE(zeKernelSetArgumentValue(kernel, 3, sizeof(longValue), &longValue));
SUCCESS_OR_TERMINATE(zeKernelSetGroupSize(kernel, groupSizeX, groupSizeY, groupSizeZ)); SUCCESS_OR_TERMINATE(zeKernelSetGroupSize(kernel, 1U, 1U, 1U));
dispatchTraits.groupCountX = globalSizeX / groupSizeX;
dispatchTraits.groupCountY = 1;
dispatchTraits.groupCountZ = 1;
if (verbose) {
std::cout << "Number of groups : (" << dispatchTraits.groupCountX << ", "
<< dispatchTraits.groupCountY << ", " << dispatchTraits.groupCountZ << ")"
<< std::endl;
}
SUCCESS_OR_TERMINATE(zeKernelSetArgumentValue(kernel, 0, sizeof(size_t), nullptr));
SUCCESS_OR_TERMINATE(zeKernelSetArgumentValue(kernel, 1, sizeof(size_t), nullptr));
ze_group_count_t dispatchTraits;
dispatchTraits.groupCountX = 1u;
dispatchTraits.groupCountY = 1u;
dispatchTraits.groupCountZ = 1u;
SUCCESS_OR_TERMINATE(zeCommandListAppendLaunchKernel(cmdList, kernel, &dispatchTraits, nullptr, 0, nullptr)); SUCCESS_OR_TERMINATE(zeCommandListAppendLaunchKernel(cmdList, kernel, &dispatchTraits, nullptr, 0, nullptr));
SUCCESS_OR_TERMINATE(zeCommandListClose(cmdList)); SUCCESS_OR_TERMINATE(zeCommandListClose(cmdList));
SUCCESS_OR_TERMINATE(zeCommandQueueExecuteCommandLists(cmdQueue, 1, &cmdList, nullptr)); SUCCESS_OR_TERMINATE(zeCommandQueueExecuteCommandLists(cmdQueue, 1, &cmdList, nullptr));
SUCCESS_OR_TERMINATE(zeCommandQueueSynchronize(cmdQueue, std::numeric_limits<uint64_t>::max())); SUCCESS_OR_TERMINATE(zeCommandQueueSynchronize(cmdQueue, std::numeric_limits<uint64_t>::max()));
SUCCESS_OR_TERMINATE(zeKernelDestroy(kernel)); SUCCESS_OR_TERMINATE(zeKernelDestroy(kernel));
SUCCESS_OR_TERMINATE(zeModuleDestroy(module)); SUCCESS_OR_TERMINATE(zeModuleDestroy(module));
SUCCESS_OR_TERMINATE(zeCommandListDestroy(cmdList)); SUCCESS_OR_TERMINATE(zeCommandListDestroy(cmdList));
SUCCESS_OR_TERMINATE(zeCommandQueueDestroy(cmdQueue)); SUCCESS_OR_TERMINATE(zeCommandQueueDestroy(cmdQueue));
} }
@ -102,7 +91,7 @@ int main(int argc, char *argv[]) {
SUCCESS_OR_TERMINATE(zeDeviceGetProperties(device, &deviceProperties)); SUCCESS_OR_TERMINATE(zeDeviceGetProperties(device, &deviceProperties));
printDeviceProperties(deviceProperties); printDeviceProperties(deviceProperties);
testPrintfKernel(context, device); runPrintfKernel(context, device);
SUCCESS_OR_TERMINATE(zeContextDestroy(context)); SUCCESS_OR_TERMINATE(zeContextDestroy(context));

View File

@ -33,8 +33,6 @@ class PrintFormatterTest : public testing::Test {
uint8_t buffer; uint8_t buffer;
MockGraphicsAllocation *data; MockGraphicsAllocation *data;
MockKernel *kernel;
std::unique_ptr<MockProgram> program;
std::unique_ptr<MockKernelInfo> kernelInfo; std::unique_ptr<MockKernelInfo> kernelInfo;
ClDevice *device; ClDevice *device;
@ -50,9 +48,6 @@ class PrintFormatterTest : public testing::Test {
data = new MockGraphicsAllocation(underlyingBuffer, maxPrintfOutputLength); data = new MockGraphicsAllocation(underlyingBuffer, maxPrintfOutputLength);
kernelInfo = std::make_unique<MockKernelInfo>(); kernelInfo = std::make_unique<MockKernelInfo>();
device = new MockClDevice{MockDevice::createWithNewExecutionEnvironment<MockDevice>(nullptr)};
program = std::make_unique<MockProgram>(toClDeviceVector(*device));
kernel = new MockKernel(program.get(), *kernelInfo, *device);
printFormatter = std::unique_ptr<PrintFormatter>(new PrintFormatter(static_cast<uint8_t *>(data->getUnderlyingBuffer()), printfBufferSize, is32bit, &kernelInfo->kernelDescriptor.kernelMetadata.printfStringsMap)); printFormatter = std::unique_ptr<PrintFormatter>(new PrintFormatter(static_cast<uint8_t *>(data->getUnderlyingBuffer()), printfBufferSize, is32bit, &kernelInfo->kernelDescriptor.kernelMetadata.printfStringsMap));
@ -64,8 +59,6 @@ class PrintFormatterTest : public testing::Test {
void TearDown() override { void TearDown() override {
delete data; delete data;
delete kernel;
delete device;
} }
enum class PRINTF_DATA_TYPE : int { enum class PRINTF_DATA_TYPE : int {
@ -86,6 +79,7 @@ class PrintFormatterTest : public testing::Test {
VECTOR_DOUBLE VECTOR_DOUBLE
}; };
PRINTF_DATA_TYPE getPrintfDataType(char value) { return PRINTF_DATA_TYPE::BYTE; };
PRINTF_DATA_TYPE getPrintfDataType(int8_t value) { return PRINTF_DATA_TYPE::BYTE; }; PRINTF_DATA_TYPE getPrintfDataType(int8_t value) { return PRINTF_DATA_TYPE::BYTE; };
PRINTF_DATA_TYPE getPrintfDataType(uint8_t value) { return PRINTF_DATA_TYPE::BYTE; }; PRINTF_DATA_TYPE getPrintfDataType(uint8_t value) { return PRINTF_DATA_TYPE::BYTE; };
PRINTF_DATA_TYPE getPrintfDataType(int16_t value) { return PRINTF_DATA_TYPE::SHORT; }; PRINTF_DATA_TYPE getPrintfDataType(int16_t value) { return PRINTF_DATA_TYPE::SHORT; };
@ -100,9 +94,15 @@ class PrintFormatterTest : public testing::Test {
template <class T> template <class T>
void injectValue(T value) { void injectValue(T value) {
storeData(getPrintfDataType(value)); auto dataType = getPrintfDataType(value);
storeData(dataType);
if (dataType == PRINTF_DATA_TYPE::BYTE ||
dataType == PRINTF_DATA_TYPE::SHORT) {
storeData(static_cast<int>(value));
} else {
storeData(value); storeData(value);
} }
}
void injectStringValue(int value) { void injectStringValue(int value) {
storeData(PRINTF_DATA_TYPE::STRING); storeData(PRINTF_DATA_TYPE::STRING);
@ -906,6 +906,24 @@ TEST_F(PrintFormatterTest, GivenNoStringMapAndBufferWithFormatStringAnd2StringsT
EXPECT_STREQ(expectedOutput, output); EXPECT_STREQ(expectedOutput, output);
} }
TEST_F(PrintFormatterTest, GivenTypeSmallerThan4BThenItIsReadAs4BValue) {
printFormatter.reset(new PrintFormatter(static_cast<uint8_t *>(data->getUnderlyingBuffer()), printfBufferSize, true));
const char *formatString = "%c %hd %d";
storeData(formatString);
char byteValue = 'a';
injectValue(byteValue);
short shortValue = 123;
injectValue(shortValue);
int intValue = 456;
injectValue(intValue);
const char *expectedOutput = "a 123 456";
char output[maxPrintfOutputLength];
printFormatter->printKernelOutput([&output](char *str) { strncpy_s(output, maxPrintfOutputLength, str, maxPrintfOutputLength - 1); });
EXPECT_STREQ(expectedOutput, output);
}
TEST(printToSTDOUTTest, GivenStringWhenPrintingToStdoutThenOutputOccurs) { TEST(printToSTDOUTTest, GivenStringWhenPrintingToStdoutThenOutputOccurs) {
testing::internal::CaptureStdout(); testing::internal::CaptureStdout();
printToSTDOUT("test"); printToSTDOUT("test");

View File

@ -82,8 +82,10 @@ class PrintFormatter {
template <class T> template <class T>
size_t typedPrintToken(char *output, size_t size, const char *formatString) { size_t typedPrintToken(char *output, size_t size, const char *formatString) {
T value = {0}; T value{0};
read(&value); read(&value);
constexpr auto offsetToBeDwordAligned = static_cast<uint32_t>(std::max(int64_t(sizeof(int) - sizeof(T)), int64_t(0)));
currentOffset += offsetToBeDwordAligned;
return simpleSprintf(output, size, formatString, value); return simpleSprintf(output, size, formatString, value);
} }