From 0b4fe9a0df7389975c4ae56bc82f72582513902f Mon Sep 17 00:00:00 2001 From: Szymon Morek Date: Tue, 14 Jan 2025 10:44:54 +0000 Subject: [PATCH] performance: add staging transfers for cl buffers Related-To: NEO-13529 Signed-off-by: Szymon Morek --- opencl/source/api/api.cpp | 30 ++++-- opencl/source/command_queue/command_queue.h | 5 + .../source/command_queue/command_queue_hw.h | 12 ++- .../command_queue/command_queue_staging.cpp | 21 ++++ .../command_queue/enqueue_write_buffer.h | 20 +++- .../enqueue_write_buffer_tests.cpp | 101 ++++++++++++++++++ .../test/unit_test/mocks/mock_command_queue.h | 18 +++- .../utilities/staging_buffer_manager.cpp | 27 +++++ .../source/utilities/staging_buffer_manager.h | 3 +- .../staging_buffer_manager_tests.cpp | 79 ++++++++++++++ 10 files changed, 299 insertions(+), 17 deletions(-) diff --git a/opencl/source/api/api.cpp b/opencl/source/api/api.cpp index 93d8c21bfb..edcf16e9d0 100644 --- a/opencl/source/api/api.cpp +++ b/opencl/source/api/api.cpp @@ -2545,16 +2545,26 @@ cl_int CL_API_CALL clEnqueueWriteBuffer(cl_command_queue commandQueue, return retVal; } - retVal = pCommandQueue->enqueueWriteBuffer( - pBuffer, - blockingWrite, - offset, - cb, - ptr, - nullptr, - numEventsInWaitList, - eventWaitList, - event); + if (pCommandQueue->isValidForStagingTransfer(pBuffer, ptr, numEventsInWaitList > 0)) { + retVal = pCommandQueue->enqueueStagingWriteBuffer( + pBuffer, + blockingWrite, + offset, + cb, + ptr, + event); + } else { + retVal = pCommandQueue->enqueueWriteBuffer( + pBuffer, + blockingWrite, + offset, + cb, + ptr, + nullptr, + numEventsInWaitList, + eventWaitList, + event); + } } DBG_LOG_INPUTS("event", getClFileLogger().getEvents(reinterpret_cast(event), 1u)); diff --git a/opencl/source/command_queue/command_queue.h b/opencl/source/command_queue/command_queue.h index 983b808090..725eb872e4 100644 --- a/opencl/source/command_queue/command_queue.h +++ b/opencl/source/command_queue/command_queue.h @@ -161,6 +161,10 @@ class CommandQueue : public BaseObject<_cl_command_queue> { const void *ptr, GraphicsAllocation *mapAllocation, cl_uint numEventsInWaitList, const cl_event *eventWaitList, cl_event *event) = 0; + virtual cl_int enqueueWriteBufferImpl(Buffer *buffer, cl_bool blockingWrite, size_t offset, size_t cb, + const void *ptr, GraphicsAllocation *mapAllocation, cl_uint numEventsInWaitList, + const cl_event *eventWaitList, cl_event *event, CommandStreamReceiver &csr) = 0; + virtual cl_int enqueueWriteImageImpl(Image *dstImage, cl_bool blockingWrite, const size_t *origin, const size_t *region, size_t inputRowPitch, size_t inputSlicePitch, const void *ptr, GraphicsAllocation *mapAllocation, cl_uint numEventsInWaitList, @@ -403,6 +407,7 @@ class CommandQueue : public BaseObject<_cl_command_queue> { size_t inputRowPitch, size_t inputSlicePitch, const void *ptr, cl_event *event); cl_int enqueueStagingReadImage(Image *dstImage, cl_bool blockingCopy, const size_t *globalOrigin, const size_t *globalRegion, size_t inputRowPitch, size_t inputSlicePitch, const void *ptr, cl_event *event); + cl_int enqueueStagingWriteBuffer(Buffer *buffer, cl_bool blockingCopy, size_t offset, size_t size, const void *ptr, cl_event *event); bool isValidForStagingBufferCopy(Device &device, void *dstPtr, const void *srcPtr, size_t size, bool hasDependencies); bool isValidForStagingTransfer(MemObj *memObj, const void *ptr, bool hasDependencies); diff --git a/opencl/source/command_queue/command_queue_hw.h b/opencl/source/command_queue/command_queue_hw.h index dd9e5c3319..5fad5911a6 100644 --- a/opencl/source/command_queue/command_queue_hw.h +++ b/opencl/source/command_queue/command_queue_hw.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2018-2024 Intel Corporation + * Copyright (C) 2018-2025 Intel Corporation * * SPDX-License-Identifier: MIT * @@ -291,6 +291,16 @@ class CommandQueueHw : public CommandQueue { const cl_event *eventWaitList, cl_event *event) override; + cl_int enqueueWriteBufferImpl(Buffer *buffer, + cl_bool blockingWrite, + size_t offset, + size_t cb, + const void *ptr, + GraphicsAllocation *mapAllocation, + cl_uint numEventsInWaitList, + const cl_event *eventWaitList, + cl_event *event, CommandStreamReceiver &csr) override; + cl_int enqueueWriteBufferRect(Buffer *buffer, cl_bool blockingWrite, const size_t *bufferOrigin, diff --git a/opencl/source/command_queue/command_queue_staging.cpp b/opencl/source/command_queue/command_queue_staging.cpp index 9508e23d74..66d08320df 100644 --- a/opencl/source/command_queue/command_queue_staging.cpp +++ b/opencl/source/command_queue/command_queue_staging.cpp @@ -94,6 +94,27 @@ cl_int CommandQueue::enqueueStagingReadImage(Image *srcImage, cl_bool blockingCo return postStagingTransferSync(ret, event, profilingEvent, isSingleTransfer, blockingCopy, CL_COMMAND_READ_IMAGE); } +cl_int CommandQueue::enqueueStagingWriteBuffer(Buffer *buffer, cl_bool blockingCopy, size_t offset, size_t size, const void *ptr, cl_event *event) { + CsrSelectionArgs csrSelectionArgs{CL_COMMAND_WRITE_BUFFER, {}, buffer, this->getDevice().getRootDeviceIndex(), &size}; + CommandStreamReceiver &csr = selectCsrForBuiltinOperation(csrSelectionArgs); + cl_event profilingEvent = nullptr; + + bool isSingleTransfer = false; + ChunkTransferBufferFunc chunkWrite = [&](void *stagingBuffer, size_t chunkOffset, size_t chunkSize) -> int32_t { + auto isFirstTransfer = (chunkOffset == offset); + auto isLastTransfer = (offset + size == chunkOffset + chunkSize); + isSingleTransfer = isFirstTransfer && isLastTransfer; + cl_event *outEvent = assignEventForStaging(event, &profilingEvent, isFirstTransfer, isLastTransfer); + + auto ret = this->enqueueWriteBufferImpl(buffer, false, chunkOffset, chunkSize, stagingBuffer, nullptr, 0, nullptr, outEvent, csr); + ret |= this->flush(); + return ret; + }; + auto stagingBufferManager = this->context->getStagingBufferManager(); + auto ret = stagingBufferManager->performBufferTransfer(ptr, offset, size, chunkWrite, &csr, false); + return postStagingTransferSync(ret, event, profilingEvent, isSingleTransfer, blockingCopy, CL_COMMAND_WRITE_BUFFER); +} + /* * If there's single transfer, use user event. * Otherwise, first transfer uses profiling event to obtain queue/submit/start timestamps. diff --git a/opencl/source/command_queue/enqueue_write_buffer.h b/opencl/source/command_queue/enqueue_write_buffer.h index 8203c8cee0..0458a6fe9f 100644 --- a/opencl/source/command_queue/enqueue_write_buffer.h +++ b/opencl/source/command_queue/enqueue_write_buffer.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2018-2024 Intel Corporation + * Copyright (C) 2018-2025 Intel Corporation * * SPDX-License-Identifier: MIT * @@ -31,6 +31,24 @@ cl_int CommandQueueHw::enqueueWriteBuffer( CsrSelectionArgs csrSelectionArgs{cmdType, {}, buffer, device->getRootDeviceIndex(), &size}; CommandStreamReceiver &csr = selectCsrForBuiltinOperation(csrSelectionArgs); + return enqueueWriteBufferImpl(buffer, blockingWrite, offset, size, ptr, mapAllocation, numEventsInWaitList, eventWaitList, event, csr); +} + +template +cl_int CommandQueueHw::enqueueWriteBufferImpl( + Buffer *buffer, + cl_bool blockingWrite, + size_t offset, + size_t size, + const void *ptr, + GraphicsAllocation *mapAllocation, + cl_uint numEventsInWaitList, + const cl_event *eventWaitList, + cl_event *event, + CommandStreamReceiver &csr) { + const cl_command_type cmdType = CL_COMMAND_WRITE_BUFFER; + + CsrSelectionArgs csrSelectionArgs{cmdType, {}, buffer, device->getRootDeviceIndex(), &size}; auto rootDeviceIndex = getDevice().getRootDeviceIndex(); auto isMemTransferNeeded = buffer->isMemObjZeroCopy() ? buffer->checkIfMemoryTransferIsRequired(offset, 0, ptr, cmdType) : true; diff --git a/opencl/test/unit_test/command_queue/enqueue_write_buffer_tests.cpp b/opencl/test/unit_test/command_queue/enqueue_write_buffer_tests.cpp index a059dfb4ca..beae821812 100644 --- a/opencl/test/unit_test/command_queue/enqueue_write_buffer_tests.cpp +++ b/opencl/test/unit_test/command_queue/enqueue_write_buffer_tests.cpp @@ -630,3 +630,104 @@ HWTEST_F(EnqueueWriteBufferHw, givenHostPtrIsFromMappedBufferWhenWriteBufferIsCa EXPECT_EQ(CL_SUCCESS, retVal); EXPECT_EQ(1u, csr.createAllocationForHostSurfaceCalled); } + +struct WriteBufferStagingBufferTest : public EnqueueWriteBufferHw { + void SetUp() override { + REQUIRE_SVM_OR_SKIP(defaultHwInfo); + EnqueueWriteBufferHw::SetUp(); + } + + void TearDown() override { + if (defaultHwInfo->capabilityTable.ftrSvm == false) { + return; + } + EnqueueWriteBufferHw::TearDown(); + } + constexpr static size_t chunkSize = MemoryConstants::megaByte * 2; + + unsigned char ptr[MemoryConstants::cacheLineSize]; + MockBuffer buffer; + cl_queue_properties props = {}; +}; + +HWTEST_F(WriteBufferStagingBufferTest, whenEnqueueStagingWriteBufferCalledThenReturnSuccess) { + MockCommandQueueHw mockCommandQueueHw(context.get(), device.get(), &props); + auto res = mockCommandQueueHw.enqueueStagingWriteBuffer(&buffer, false, 0, buffer.getSize(), ptr, nullptr); + EXPECT_TRUE(mockCommandQueueHw.flushCalled); + EXPECT_EQ(res, CL_SUCCESS); + EXPECT_EQ(1ul, mockCommandQueueHw.enqueueWriteBufferCounter); + auto &csr = device->getUltCommandStreamReceiver(); + EXPECT_EQ(0u, csr.createAllocationForHostSurfaceCalled); +} + +HWTEST_F(WriteBufferStagingBufferTest, whenEnqueueStagingWriteBufferCalledWithLargeSizeThenSplitTransfer) { + auto hostPtr = new unsigned char[chunkSize * 4]; + MockCommandQueueHw mockCommandQueueHw(context.get(), device.get(), &props); + auto retVal = CL_SUCCESS; + std::unique_ptr buffer = std::unique_ptr(Buffer::create(context.get(), + 0, + chunkSize * 4, + nullptr, + retVal)); + auto res = mockCommandQueueHw.enqueueStagingWriteBuffer(buffer.get(), false, 0, chunkSize * 4, hostPtr, nullptr); + EXPECT_TRUE(mockCommandQueueHw.flushCalled); + EXPECT_EQ(retVal, CL_SUCCESS); + EXPECT_EQ(res, CL_SUCCESS); + EXPECT_EQ(4ul, mockCommandQueueHw.enqueueWriteBufferCounter); + auto &csr = device->getUltCommandStreamReceiver(); + EXPECT_EQ(0u, csr.createAllocationForHostSurfaceCalled); + + delete[] hostPtr; +} + +HWTEST_F(WriteBufferStagingBufferTest, whenEnqueueStagingWriteBufferCalledWithEventThenReturnValidEvent) { + constexpr cl_command_type expectedLastCmd = CL_COMMAND_WRITE_BUFFER; + MockCommandQueueHw mockCommandQueueHw(context.get(), device.get(), &props); + cl_event event; + auto res = mockCommandQueueHw.enqueueStagingWriteBuffer(&buffer, false, 0, MemoryConstants::cacheLineSize, ptr, &event); + EXPECT_EQ(res, CL_SUCCESS); + + auto pEvent = (Event *)event; + EXPECT_EQ(expectedLastCmd, mockCommandQueueHw.lastCommandType); + EXPECT_EQ(expectedLastCmd, pEvent->getCommandType()); + + clReleaseEvent(event); +} + +HWTEST_F(WriteBufferStagingBufferTest, givenOutOfOrderQueueWhenEnqueueStagingWriteBufferCalledWithSingleTransferThenNoBarrierEnqueued) { + constexpr cl_command_type expectedLastCmd = CL_COMMAND_WRITE_BUFFER; + MockCommandQueueHw mockCommandQueueHw(context.get(), device.get(), &props); + mockCommandQueueHw.setOoqEnabled(); + cl_event event; + auto res = mockCommandQueueHw.enqueueStagingWriteBuffer(&buffer, false, 0, MemoryConstants::cacheLineSize, ptr, &event); + EXPECT_EQ(res, CL_SUCCESS); + + auto pEvent = (Event *)event; + EXPECT_EQ(expectedLastCmd, mockCommandQueueHw.lastCommandType); + EXPECT_EQ(expectedLastCmd, pEvent->getCommandType()); + + clReleaseEvent(event); +} + +HWTEST_F(WriteBufferStagingBufferTest, givenCmdQueueWithProfilingWhenEnqueueStagingWriteBufferThenTimestampsSetCorrectly) { + cl_event event; + MockCommandQueueHw mockCommandQueueHw(context.get(), device.get(), &props); + mockCommandQueueHw.setProfilingEnabled(); + auto res = mockCommandQueueHw.enqueueStagingWriteBuffer(&buffer, false, 0, MemoryConstants::cacheLineSize, ptr, &event); + EXPECT_EQ(res, CL_SUCCESS); + + auto pEvent = (Event *)event; + EXPECT_FALSE(pEvent->isCPUProfilingPath()); + EXPECT_TRUE(pEvent->isProfilingEnabled()); + + clReleaseEvent(event); +} + +HWTEST_F(WriteBufferStagingBufferTest, whenEnqueueStagingWriteBufferFailedThenPropagateErrorCode) { + MockCommandQueueHw mockCommandQueueHw(context.get(), device.get(), &props); + mockCommandQueueHw.enqueueWriteBufferCallBase = false; + auto res = mockCommandQueueHw.enqueueStagingWriteBuffer(&buffer, false, 0, MemoryConstants::cacheLineSize, ptr, nullptr); + + EXPECT_EQ(res, CL_INVALID_OPERATION); + EXPECT_EQ(1ul, mockCommandQueueHw.enqueueWriteBufferCounter); +} \ No newline at end of file diff --git a/opencl/test/unit_test/mocks/mock_command_queue.h b/opencl/test/unit_test/mocks/mock_command_queue.h index 677090414b..2d621cb936 100644 --- a/opencl/test/unit_test/mocks/mock_command_queue.h +++ b/opencl/test/unit_test/mocks/mock_command_queue.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2018-2024 Intel Corporation + * Copyright (C) 2018-2025 Intel Corporation * * SPDX-License-Identifier: MIT * @@ -106,6 +106,12 @@ class MockCommandQueue : public CommandQueue { return writeBufferRetValue; } + cl_int enqueueWriteBufferImpl(Buffer *buffer, cl_bool blockingWrite, size_t offset, size_t cb, + const void *ptr, GraphicsAllocation *mapAllocation, cl_uint numEventsInWaitList, + const cl_event *eventWaitList, cl_event *event, CommandStreamReceiver &csr) override { + return CL_SUCCESS; + } + WaitStatus waitUntilComplete(TaskCountType gpgpuTaskCountToWait, Range copyEnginesToWait, FlushStamp flushStampToWait, bool useQuickKmdSleep, bool cleanTemporaryAllocationList, bool skipWait) override { latestTaskCountWaited = gpgpuTaskCountToWait; @@ -416,11 +422,14 @@ class MockCommandQueueHw : public CommandQueueHw { cpuDataTransferHandlerCalled = true; return BaseClass::cpuDataTransferHandler(transferProperties, eventsRequest, retVal); } - cl_int enqueueWriteBuffer(Buffer *buffer, cl_bool blockingWrite, size_t offset, size_t size, - const void *ptr, GraphicsAllocation *mapAllocation, cl_uint numEventsInWaitList, const cl_event *eventWaitList, cl_event *event) override { + cl_int enqueueWriteBufferImpl(Buffer *buffer, cl_bool blockingWrite, size_t offset, size_t size, const void *ptr, GraphicsAllocation *mapAllocation, + cl_uint numEventsInWaitList, const cl_event *eventWaitList, cl_event *event, CommandStreamReceiver &csr) override { enqueueWriteBufferCounter++; blockingWriteBuffer = blockingWrite == CL_TRUE; - return BaseClass::enqueueWriteBuffer(buffer, blockingWrite, offset, size, ptr, mapAllocation, numEventsInWaitList, eventWaitList, event); + if (enqueueWriteBufferCallBase) { + return BaseClass::enqueueWriteBufferImpl(buffer, blockingWrite, offset, size, ptr, mapAllocation, numEventsInWaitList, eventWaitList, event, csr); + } + return CL_INVALID_OPERATION; } void enqueueHandlerHook(const unsigned int commandType, const MultiDispatchInfo &dispatchInfo) override { @@ -529,6 +538,7 @@ class MockCommandQueueHw : public CommandQueueHw { size_t enqueueReadImageCounter = 0; bool enqueueReadImageCallBase = true; size_t enqueueWriteBufferCounter = 0; + bool enqueueWriteBufferCallBase = true; size_t requestedCmdStreamSize = 0; bool blockingWriteBuffer = false; bool storeMultiDispatchInfo = false; diff --git a/shared/source/utilities/staging_buffer_manager.cpp b/shared/source/utilities/staging_buffer_manager.cpp index 5a84652e3e..b37c758c27 100644 --- a/shared/source/utilities/staging_buffer_manager.cpp +++ b/shared/source/utilities/staging_buffer_manager.cpp @@ -163,6 +163,33 @@ StagingTransferStatus StagingBufferManager::performImageTransfer(const void *ptr return result; } +StagingTransferStatus StagingBufferManager::performBufferTransfer(const void *ptr, size_t globalOffset, size_t globalSize, ChunkTransferBufferFunc &chunkTransferBufferFunc, CommandStreamReceiver *csr, bool isRead) { + StagingQueue stagingQueue; + auto copiesNum = globalSize / chunkSize; + auto remainder = globalSize % chunkSize; + auto chunkOffset = globalOffset; + StagingTransferStatus result{}; + for (auto i = 0u; i < copiesNum; i++) { + auto chunkPtr = ptrOffset(ptr, i * chunkSize); + result = performChunkTransfer(isRead, const_cast(chunkPtr), chunkSize, stagingQueue, csr, chunkTransferBufferFunc, chunkOffset, chunkSize); + if (result.chunkCopyStatus != 0) { + return result; + } + chunkOffset += chunkSize; + } + + if (remainder != 0) { + auto chunkPtr = ptrOffset(ptr, copiesNum * chunkSize); + result = performChunkTransfer(isRead, const_cast(chunkPtr), remainder, stagingQueue, csr, chunkTransferBufferFunc, chunkOffset, remainder); + if (result.chunkCopyStatus != 0) { + return result; + } + } + + result.waitStatus = drainAndReleaseStagingQueue(stagingQueue); + return result; +} + /* * This method is used for read transfers. It waits for oldest transfer to finish * and copies data associated with that transfer to host allocation. diff --git a/shared/source/utilities/staging_buffer_manager.h b/shared/source/utilities/staging_buffer_manager.h index 059464a1de..1301a9b7b8 100644 --- a/shared/source/utilities/staging_buffer_manager.h +++ b/shared/source/utilities/staging_buffer_manager.h @@ -25,7 +25,7 @@ class HeapAllocator; using ChunkCopyFunction = std::function; using ChunkTransferImageFunc = std::function; - +using ChunkTransferBufferFunc = std::function; class StagingBuffer { public: StagingBuffer(void *baseAddress, size_t size); @@ -83,6 +83,7 @@ class StagingBufferManager { StagingTransferStatus performCopy(void *dstPtr, const void *srcPtr, size_t size, ChunkCopyFunction &chunkCopyFunc, CommandStreamReceiver *csr); StagingTransferStatus performImageTransfer(const void *ptr, const size_t *globalOrigin, const size_t *globalRegion, size_t rowPitch, ChunkTransferImageFunc &chunkTransferImageFunc, CommandStreamReceiver *csr, bool isRead); + StagingTransferStatus performBufferTransfer(const void *ptr, size_t globalOffset, size_t globalSize, ChunkTransferBufferFunc &chunkTransferBufferFunc, CommandStreamReceiver *csr, bool isRead); std::pair requestStagingBuffer(size_t &size); void trackChunk(const StagingBufferTracker &tracker); diff --git a/shared/test/unit_test/utilities/staging_buffer_manager_tests.cpp b/shared/test/unit_test/utilities/staging_buffer_manager_tests.cpp index b6df48670b..ef44dab0ad 100644 --- a/shared/test/unit_test/utilities/staging_buffer_manager_tests.cpp +++ b/shared/test/unit_test/utilities/staging_buffer_manager_tests.cpp @@ -132,6 +132,34 @@ class StagingBufferManagerFixture : public DeviceFixture { delete[] imageData; } + void bufferTransferThroughStagingBuffers(size_t copySize, size_t expectedChunks, size_t expectedAllocations, CommandStreamReceiver *csr) { + auto buffer = new unsigned char[copySize]; + auto nonUsmBuffer = new unsigned char[copySize]; + + size_t chunkCounter = 0; + memset(buffer, 0, copySize); + memset(nonUsmBuffer, 0xFF, copySize); + + ChunkTransferBufferFunc chunkCopy = [&](void *stagingBuffer, size_t offset, size_t size) { + chunkCounter++; + memcpy(buffer + offset, stagingBuffer, size); + reinterpret_cast(csr)->taskCount++; + return 0; + }; + auto initialNumOfUsmAllocations = svmAllocsManager->svmAllocs.getNumAllocs(); + auto ret = stagingBufferManager->performBufferTransfer(nonUsmBuffer, 0, copySize, chunkCopy, csr, false); + auto newUsmAllocations = svmAllocsManager->svmAllocs.getNumAllocs() - initialNumOfUsmAllocations; + + EXPECT_EQ(0, ret.chunkCopyStatus); + EXPECT_EQ(WaitStatus::ready, ret.waitStatus); + EXPECT_EQ(0, memcmp(buffer, nonUsmBuffer, copySize)); + EXPECT_EQ(expectedChunks, chunkCounter); + EXPECT_EQ(expectedAllocations, newUsmAllocations); + + delete[] buffer; + delete[] nonUsmBuffer; + } + constexpr static size_t stagingBufferSize = MemoryConstants::megaByte * 2; DebugManagerStateRestore restorer; std::unique_ptr svmAllocsManager; @@ -574,4 +602,55 @@ TEST_F(StagingBufferManagerTest, givenStagingBufferWhenFailedChunkImageWriteWith EXPECT_EQ(WaitStatus::ready, ret.waitStatus); EXPECT_EQ(remainderCounter, chunkCounter); delete[] ptr; +} + +TEST_F(StagingBufferManagerTest, givenStagingBufferWhenPerformBufferTransferThenCopyData) { + constexpr size_t numOfChunkCopies = 8; + constexpr size_t remainder = 1024; + constexpr size_t totalCopySize = stagingBufferSize * numOfChunkCopies + remainder; + bufferTransferThroughStagingBuffers(totalCopySize, numOfChunkCopies + 1, 1, csr); +} + +TEST_F(StagingBufferManagerTest, givenStagingBufferWhenPerformBufferTransferWithoutRemainderThenNoRemainderCalled) { + constexpr size_t numOfChunkCopies = 8; + constexpr size_t totalCopySize = stagingBufferSize * numOfChunkCopies; + bufferTransferThroughStagingBuffers(totalCopySize, numOfChunkCopies, 1, csr); +} + +TEST_F(StagingBufferManagerTest, givenStagingBufferWhenFailedChunkBufferWriteThenEarlyReturnWithFailure) { + size_t expectedChunks = 4; + constexpr int expectedErrorCode = 1; + auto ptr = new unsigned char[stagingBufferSize * expectedChunks]; + + size_t chunkCounter = 0; + ChunkTransferBufferFunc chunkWrite = [&](void *stagingBuffer, size_t offset, size_t size) -> int32_t { + ++chunkCounter; + return expectedErrorCode; + }; + auto ret = stagingBufferManager->performBufferTransfer(ptr, 0, stagingBufferSize * expectedChunks, chunkWrite, csr, false); + EXPECT_EQ(expectedErrorCode, ret.chunkCopyStatus); + EXPECT_EQ(WaitStatus::ready, ret.waitStatus); + EXPECT_EQ(1u, chunkCounter); + delete[] ptr; +} + +TEST_F(StagingBufferManagerTest, givenStagingBufferWhenFailedChunkBufferWriteWithRemainderThenReturnWithFailure) { + size_t expectedChunks = 2; + constexpr int expectedErrorCode = 1; + auto ptr = new unsigned char[stagingBufferSize * expectedChunks + 512]; + + size_t chunkCounter = 0; + size_t remainderCounter = expectedChunks + 1; + ChunkTransferBufferFunc chunkWrite = [&](void *stagingBuffer, size_t offset, size_t size) -> int32_t { + ++chunkCounter; + if (chunkCounter == remainderCounter) { + return expectedErrorCode; + } + return 0; + }; + auto ret = stagingBufferManager->performBufferTransfer(ptr, 0, stagingBufferSize * expectedChunks + 512, chunkWrite, csr, false); + EXPECT_EQ(expectedErrorCode, ret.chunkCopyStatus); + EXPECT_EQ(WaitStatus::ready, ret.waitStatus); + EXPECT_EQ(remainderCounter, chunkCounter); + delete[] ptr; } \ No newline at end of file