diff --git a/runtime/command_queue/command_queue.cpp b/runtime/command_queue/command_queue.cpp index 87ca0aa547..6b7c70aef6 100644 --- a/runtime/command_queue/command_queue.cpp +++ b/runtime/command_queue/command_queue.cpp @@ -332,9 +332,16 @@ LinearStream &CommandQueue::getCS(size_t minRequiredSize) { } cl_int CommandQueue::enqueueAcquireSharedObjects(cl_uint numObjects, const cl_mem *memObjects, cl_uint numEventsInWaitList, const cl_event *eventWaitList, cl_event *oclEvent, cl_uint cmdType) { + if ((memObjects == nullptr && numObjects != 0) || (memObjects != nullptr && numObjects == 0)) { + return CL_INVALID_VALUE; + } for (unsigned int object = 0; object < numObjects; object++) { - auto memObject = castToObjectOrAbort(memObjects[object]); + auto memObject = castToObject(memObjects[object]); + if (memObject == nullptr || memObject->peekSharingHandler() == nullptr) { + return CL_INVALID_MEM_OBJECT; + } + memObject->peekSharingHandler()->acquire(memObject); memObject->acquireCount++; } @@ -351,8 +358,16 @@ cl_int CommandQueue::enqueueAcquireSharedObjects(cl_uint numObjects, const cl_me } cl_int CommandQueue::enqueueReleaseSharedObjects(cl_uint numObjects, const cl_mem *memObjects, cl_uint numEventsInWaitList, const cl_event *eventWaitList, cl_event *oclEvent, cl_uint cmdType) { + if ((memObjects == nullptr && numObjects != 0) || (memObjects != nullptr && numObjects == 0)) { + return CL_INVALID_VALUE; + } + for (unsigned int object = 0; object < numObjects; object++) { - auto memObject = castToObjectOrAbort(memObjects[object]); + auto memObject = castToObject(memObjects[object]); + if (memObject == nullptr || memObject->peekSharingHandler() == nullptr) { + return CL_INVALID_MEM_OBJECT; + } + memObject->peekSharingHandler()->release(memObject); DEBUG_BREAK_IF(memObject->acquireCount <= 0); memObject->acquireCount--; diff --git a/unit_tests/command_queue/command_queue_tests.cpp b/unit_tests/command_queue/command_queue_tests.cpp index d4d83657a9..964dcfa77f 100644 --- a/unit_tests/command_queue/command_queue_tests.cpp +++ b/unit_tests/command_queue/command_queue_tests.cpp @@ -34,6 +34,7 @@ #include "unit_tests/fixtures/context_fixture.h" #include "unit_tests/fixtures/device_fixture.h" #include "unit_tests/fixtures/memory_management_fixture.h" +#include "unit_tests/fixtures/buffer_fixture.h" #include "unit_tests/helpers/debug_manager_state_restore.h" #include "unit_tests/libult/ult_command_stream_receiver.h" #include "unit_tests/mocks/mock_memory_manager.h" @@ -777,3 +778,105 @@ TEST(CommandQueueGetIndirectHeap, whenCheckingForCsrInstructionHeapReservedBlock EXPECT_GE(alignedPatternSize, csr->getInstructionHeapCmdStreamReceiverReservedSize()); EXPECT_EQ(alignedPatternSize, cmdQ.getInstructionHeapReservedBlockSize()); } + +TEST(CommandQueue, givenEnqueueAcquireSharedObjectsWhenNoObjectsThenReturnSuccess) { + MockContext context; + CommandQueue cmdQ(&context, nullptr, 0); + + cl_uint numObjects = 0; + cl_mem *memObjects = nullptr; + + cl_int result = cmdQ.enqueueAcquireSharedObjects(numObjects, memObjects, 0, nullptr, nullptr, 0); + EXPECT_EQ(result, CL_SUCCESS); +} + +TEST(CommandQueue, givenEnqueueAcquireSharedObjectsWhenIncorrectArgumentsThenReturnProperError) { + MockContext context; + CommandQueue cmdQ(&context, nullptr, 0); + + cl_uint numObjects = 1; + cl_mem *memObjects = nullptr; + + cl_int result = cmdQ.enqueueAcquireSharedObjects(numObjects, memObjects, 0, nullptr, nullptr, 0); + EXPECT_EQ(result, CL_INVALID_VALUE); + + numObjects = 0; + memObjects = (cl_mem *)1; + + result = cmdQ.enqueueAcquireSharedObjects(numObjects, memObjects, 0, nullptr, nullptr, 0); + EXPECT_EQ(result, CL_INVALID_VALUE); + + numObjects = 0; + memObjects = (cl_mem *)1; + + result = cmdQ.enqueueAcquireSharedObjects(numObjects, memObjects, 0, nullptr, nullptr, 0); + EXPECT_EQ(result, CL_INVALID_VALUE); + + cl_mem memObject = nullptr; + + numObjects = 1; + memObjects = &memObject; + + result = cmdQ.enqueueAcquireSharedObjects(numObjects, memObjects, 0, nullptr, nullptr, 0); + EXPECT_EQ(result, CL_INVALID_MEM_OBJECT); + + auto buffer = std::unique_ptr(BufferHelper<>::create(&context)); + memObject = buffer.get(); + + numObjects = 1; + memObjects = &memObject; + + result = cmdQ.enqueueAcquireSharedObjects(numObjects, memObjects, 0, nullptr, nullptr, 0); + EXPECT_EQ(result, CL_INVALID_MEM_OBJECT); +} + +TEST(CommandQueue, givenEnqueueReleaseSharedObjectsWhenNoObjectsThenReturnSuccess) { + MockContext context; + CommandQueue cmdQ(&context, nullptr, 0); + + cl_uint numObjects = 0; + cl_mem *memObjects = nullptr; + + cl_int result = cmdQ.enqueueReleaseSharedObjects(numObjects, memObjects, 0, nullptr, nullptr, 0); + EXPECT_EQ(result, CL_SUCCESS); +} + +TEST(CommandQueue, givenEnqueueReleaseSharedObjectsWhenIncorrectArgumentsThenReturnProperError) { + MockContext context; + CommandQueue cmdQ(&context, nullptr, 0); + + cl_uint numObjects = 1; + cl_mem *memObjects = nullptr; + + cl_int result = cmdQ.enqueueReleaseSharedObjects(numObjects, memObjects, 0, nullptr, nullptr, 0); + EXPECT_EQ(result, CL_INVALID_VALUE); + + numObjects = 0; + memObjects = (cl_mem *)1; + + result = cmdQ.enqueueReleaseSharedObjects(numObjects, memObjects, 0, nullptr, nullptr, 0); + EXPECT_EQ(result, CL_INVALID_VALUE); + + numObjects = 0; + memObjects = (cl_mem *)1; + + result = cmdQ.enqueueReleaseSharedObjects(numObjects, memObjects, 0, nullptr, nullptr, 0); + EXPECT_EQ(result, CL_INVALID_VALUE); + + cl_mem memObject = nullptr; + + numObjects = 1; + memObjects = &memObject; + + result = cmdQ.enqueueReleaseSharedObjects(numObjects, memObjects, 0, nullptr, nullptr, 0); + EXPECT_EQ(result, CL_INVALID_MEM_OBJECT); + + auto buffer = std::unique_ptr(BufferHelper<>::create(&context)); + memObject = buffer.get(); + + numObjects = 1; + memObjects = &memObject; + + result = cmdQ.enqueueReleaseSharedObjects(numObjects, memObjects, 0, nullptr, nullptr, 0); + EXPECT_EQ(result, CL_INVALID_MEM_OBJECT); +}