diff --git a/level_zero/core/source/cmdlist/cmdlist_additional_args.cpp b/level_zero/core/source/cmdlist/cmdlist_additional_args.cpp index 18acb7c9d8..7a1b036b9e 100644 --- a/level_zero/core/source/cmdlist/cmdlist_additional_args.cpp +++ b/level_zero/core/source/cmdlist/cmdlist_additional_args.cpp @@ -66,6 +66,12 @@ void CommandList::setAdditionalBlitPropertiesFromMemoryCopyParams(NEO::BlitPrope } ze_result_t CommandList::obtainMemoryCopyParamsFromExtensions(const ze_base_desc_t *desc, CmdListMemoryCopyParams &memoryCopyParams) const { + if (desc) { + PRINT_DEBUG_STRING(NEO::debugManager.flags.PrintDebugMessages.get(), stderr, "Could not recognize provided extension, stype: 0x%x.\n", + desc->stype); + return ZE_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + return ZE_RESULT_SUCCESS; } diff --git a/level_zero/core/source/cmdlist/cmdlist_hw.inl b/level_zero/core/source/cmdlist/cmdlist_hw.inl index 1082c91dce..3c51ef4de0 100644 --- a/level_zero/core/source/cmdlist/cmdlist_hw.inl +++ b/level_zero/core/source/cmdlist/cmdlist_hw.inl @@ -2103,7 +2103,10 @@ ze_result_t CommandListCoreFamily::appendMemoryCopyWithParameters uint32_t numWaitEvents, ze_event_handle_t *phWaitEvents) { CmdListMemoryCopyParams memoryCopyParams{}; - obtainMemoryCopyParamsFromExtensions(static_cast(pNext), memoryCopyParams); + ze_result_t ret = obtainMemoryCopyParamsFromExtensions(static_cast(pNext), memoryCopyParams); + if (ret) { + return ret; + } return appendMemoryCopy(dstptr, srcptr, size, hSignalEvent, numWaitEvents, phWaitEvents, memoryCopyParams); } @@ -2753,7 +2756,10 @@ ze_result_t CommandListCoreFamily::appendMemoryFillWithParameters uint32_t numWaitEvents, ze_event_handle_t *phWaitEvents) { CmdListMemoryCopyParams memoryCopyParams{}; - obtainMemoryCopyParamsFromExtensions(static_cast(pNext), memoryCopyParams); + ze_result_t ret = obtainMemoryCopyParamsFromExtensions(static_cast(pNext), memoryCopyParams); + if (ret) { + return ret; + } return appendMemoryFill(ptr, pattern, patternSize, size, hSignalEvent, numWaitEvents, phWaitEvents, memoryCopyParams); } diff --git a/level_zero/core/test/unit_tests/sources/cmdlist/test_cmdlist_6.cpp b/level_zero/core/test/unit_tests/sources/cmdlist/test_cmdlist_6.cpp index 22778940e2..6d17764bf2 100644 --- a/level_zero/core/test/unit_tests/sources/cmdlist/test_cmdlist_6.cpp +++ b/level_zero/core/test/unit_tests/sources/cmdlist/test_cmdlist_6.cpp @@ -484,15 +484,35 @@ HWTEST_F(CommandListTest, givenUnrecognizedDescriptorWhenObtainLaunchParamsFromE EXPECT_EQ(std::string("Could not recognize provided extension, stype: 0x12.\n"), output); } -HWTEST_F(CommandListTest, WhenObtainMemoryCopyParamsFromExtensionsIsCalledThenSuccessIsReturned) { +HWTEST_F(CommandListTest, givenEmptyExtWhenObtainMemoryCopyParamsFromExtensionsIsCalledThenSuccessIsReturned) { CmdListMemoryCopyParams memoryCopyParams{}; - ze_base_desc_t desc{}; auto commandList = std::make_unique>>(); - commandList->initialize(device, NEO::EngineGroupType::renderCompute, 0u); + commandList->initialize(device, NEO::EngineGroupType::copy, 0u); + + ze_result_t result = commandList->obtainMemoryCopyParamsFromExtensions(nullptr, memoryCopyParams); + EXPECT_EQ(ZE_RESULT_SUCCESS, result); +} + +HWTEST_F(CommandListTest, givenUnrecognizedExtWhenObtainMemoryCopyParamsFromExtensionsIsCalledThenErrorIsReturned) { + CmdListMemoryCopyParams memoryCopyParams{}; + ze_base_desc_t desc{ + .stype = static_cast(10), + }; + + StreamCapture capture; + capture.captureStderr(); + DebugManagerStateRestore restorer{}; + debugManager.flags.PrintDebugMessages.set(true); + + auto commandList = std::make_unique>>(); + commandList->initialize(device, NEO::EngineGroupType::copy, 0u); ze_result_t result = commandList->obtainMemoryCopyParamsFromExtensions(&desc, memoryCopyParams); - EXPECT_EQ(ZE_RESULT_SUCCESS, result); + EXPECT_EQ(ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, result); + + std::string output = capture.getCapturedStderr(); + EXPECT_EQ(std::string("Could not recognize provided extension, stype: 0xa.\n"), output); } HWTEST_F(CommandListTest, givenComputeCommandListAnd2dRegionWhenMemoryCopyRegionInUsmHostAllocationCalledThenBuiltinFlagAndDestinationAllocSystemIsSet) { diff --git a/level_zero/core/test/unit_tests/sources/cmdlist/test_cmdlist_append_memory.cpp b/level_zero/core/test/unit_tests/sources/cmdlist/test_cmdlist_append_memory.cpp index 908545962e..b52be73cc8 100644 --- a/level_zero/core/test/unit_tests/sources/cmdlist/test_cmdlist_append_memory.cpp +++ b/level_zero/core/test/unit_tests/sources/cmdlist/test_cmdlist_append_memory.cpp @@ -1396,6 +1396,18 @@ HWTEST2_F(AppendMemoryCopyTests, givenCopyCommandListImmediateWithDummyBlitWaWhe context->freeMem(buffer); } +HWTEST_F(AppendMemoryCopyTests, givenInvalidExtWhenAppendMemoryCopyWithParametersCalledThenErrorIsReturned) { + MockCommandListCoreFamily cmdList; + cmdList.initialize(device, NEO::EngineGroupType::copy, 0u); + + uint32_t srcBuffer = 1; + uint32_t dstBuffer = 0; + ze_base_desc_t desc{}; + + ze_result_t result = cmdList.appendMemoryCopyWithParameters(&dstBuffer, &srcBuffer, sizeof(srcBuffer), &desc, nullptr, 0, nullptr); + EXPECT_NE(ZE_RESULT_SUCCESS, result); +} + struct StagingBuffersFixture : public AppendMemoryCopyTests { void SetUp() override { debugManager.flags.EnableCopyWithStagingBuffers.set(1); diff --git a/level_zero/core/test/unit_tests/sources/cmdlist/test_cmdlist_fill.cpp b/level_zero/core/test/unit_tests/sources/cmdlist/test_cmdlist_fill.cpp index 46e6fa695d..146e8004a9 100644 --- a/level_zero/core/test/unit_tests/sources/cmdlist/test_cmdlist_fill.cpp +++ b/level_zero/core/test/unit_tests/sources/cmdlist/test_cmdlist_fill.cpp @@ -536,5 +536,17 @@ HWTEST2_F(AppendFillTest, true); } +HWTEST_F(AppendFillTest, givenInvalidExtWhenAppendMemoryFillWithParametersCalledThenErrorIsReturned) { + MockCommandList commandList; + commandList.initialize(device, NEO::EngineGroupType::copy, 0u); + + uint32_t dstBuffer = 0; + uint8_t pattern = 1; + ze_base_desc_t desc{}; + + ze_result_t result = commandList.appendMemoryFillWithParameters(&dstBuffer, &pattern, sizeof(pattern), sizeof(dstBuffer), &desc, nullptr, 0, nullptr); + EXPECT_NE(ZE_RESULT_SUCCESS, result); +} + } // namespace ult } // namespace L0