diff --git a/opencl/source/api/api.cpp b/opencl/source/api/api.cpp index 89b4096f7f..a42504d6a8 100644 --- a/opencl/source/api/api.cpp +++ b/opencl/source/api/api.cpp @@ -4056,7 +4056,6 @@ void *CL_API_CALL clSVMAlloc(cl_context context, TRACING_EXIT(clSVMAlloc, &pAlloc); return pAlloc; } - auto pDevice = pContext->getDevice(0); if (flags == 0) { flags = CL_MEM_READ_WRITE; @@ -4077,6 +4076,7 @@ void *CL_API_CALL clSVMAlloc(cl_context context, return pAlloc; } + auto pDevice = pContext->getDevice(0); if ((size == 0) || (size > pDevice->getDeviceInfo().maxMemAllocSize)) { TRACING_EXIT(clSVMAlloc, &pAlloc); return pAlloc; @@ -4112,10 +4112,23 @@ void CL_API_CALL clSVMFree(cl_context context, TRACING_ENTER(clSVMFree, &context, &svmPointer); DBG_LOG_INPUTS("context", context, "svmPointer", svmPointer); + Context *pContext = nullptr; - if (validateObject(WithCastToInternal(context, &pContext)) == CL_SUCCESS) { - pContext->getSVMAllocsManager()->freeSVMAlloc(svmPointer); + cl_int retVal = validateObjects( + WithCastToInternal(context, &pContext)); + + if (retVal != CL_SUCCESS) { + TRACING_EXIT(clSVMFree, nullptr); + return; } + + auto pClDevice = pContext->getDevice(0); + if (!pClDevice->getHardwareInfo().capabilityTable.ftrSvm) { + TRACING_EXIT(clSVMFree, nullptr); + return; + } + + pContext->getSVMAllocsManager()->freeSVMAlloc(svmPointer); TRACING_EXIT(clSVMFree, nullptr); } diff --git a/opencl/test/unit_test/api/cl_svm_free_tests.inl b/opencl/test/unit_test/api/cl_svm_free_tests.inl index c01061530f..1030716875 100644 --- a/opencl/test/unit_test/api/cl_svm_free_tests.inl +++ b/opencl/test/unit_test/api/cl_svm_free_tests.inl @@ -19,4 +19,19 @@ TEST_F(clSVMFreeTests, GivenNullPtrWhenFreeingSvmThenNoAction) { nullptr // void *svm_pointer ); } + +TEST_F(clSVMFreeTests, GivenContextWithDeviceNotSupportingSvmWhenFreeingSvmThenNoAction) { + HardwareInfo hwInfo = *platformDevices[0]; + hwInfo.capabilityTable.ftrSvm = false; + auto clDevice = std::make_unique(MockDevice::createWithNewExecutionEnvironment(&hwInfo)); + + cl_device_id deviceId = clDevice.get(); + auto context = clUniquePtr(Context::create(nullptr, ClDeviceVector(&deviceId, 1), nullptr, nullptr, retVal)); + EXPECT_EQ(retVal, CL_SUCCESS); + + clSVMFree( + context.get(), + reinterpret_cast(0x1234)); +} + } // namespace ULT