diff --git a/opencl/source/api/api.cpp b/opencl/source/api/api.cpp index 21c335fd15..4770e89c16 100644 --- a/opencl/source/api/api.cpp +++ b/opencl/source/api/api.cpp @@ -1650,7 +1650,7 @@ cl_program CL_API_CALL clLinkProgram(cl_context context, ErrorCodeHelper err(errcodeRet, CL_SUCCESS); Context *pContext = nullptr; - Program *pProgram = nullptr; + cl_program clProgram = nullptr; retVal = validateObjects(withCastToInternal(context, &pContext), Program::isValidCallback(funcNotify, userData)); @@ -1661,17 +1661,16 @@ cl_program CL_API_CALL clLinkProgram(cl_context context, } if (CL_SUCCESS == retVal) { - - pProgram = new Program(pContext, false, *deviceVectorPtr); + clProgram = new Program(pContext, false, *deviceVectorPtr); + auto pProgram = castToObject(clProgram); retVal = pProgram->link(*deviceVectorPtr, options, numInputPrograms, inputPrograms); pProgram->invokeCallback(funcNotify, userData); } err.set(retVal); - - TRACING_EXIT(ClLinkProgram, (cl_program *)&pProgram); - return pProgram; + TRACING_EXIT(ClLinkProgram, &clProgram); + return clProgram; } cl_int CL_API_CALL clUnloadPlatformCompiler(cl_platform_id platform) { diff --git a/opencl/test/unit_test/api/cl_intel_tracing_tests.inl b/opencl/test/unit_test/api/cl_intel_tracing_tests.inl index 9daf0c294e..03021bd2d5 100644 --- a/opencl/test/unit_test/api/cl_intel_tracing_tests.inl +++ b/opencl/test/unit_test/api/cl_intel_tracing_tests.inl @@ -999,4 +999,68 @@ TEST_F(IntelClCreateContextFromTypeTracingTest, givenCreateContextFromTypeCallTr clReleaseContext(context); } +struct IntelClLinkProgramTracingTest : public IntelTracingTest, PlatformFixture { + public: + void SetUp() override { + PlatformFixture::setUp(); + IntelTracingTest::setUp(); + + status = clCreateTracingHandleINTEL(devices[0], callback, this, &handle); + ASSERT_NE(nullptr, handle); + ASSERT_EQ(CL_SUCCESS, status); + + status = clSetTracingPointINTEL(handle, CL_FUNCTION_clLinkProgram, CL_TRUE); + ASSERT_EQ(CL_SUCCESS, status); + + status = clEnableTracingINTEL(handle); + ASSERT_EQ(CL_SUCCESS, status); + + pProgram->irBinary = makeCopy(ir, sizeof(ir)); + pProgram->irBinarySize = sizeof(ir); + } + void TearDown() override { + status = clDisableTracingINTEL(handle); + ASSERT_EQ(CL_SUCCESS, status); + + status = clDestroyTracingHandleINTEL(handle); + ASSERT_EQ(CL_SUCCESS, status); + IntelTracingTest::tearDown(); + PlatformFixture::tearDown(); + } + + protected: + void call() { + cl_program inputPrograms = pProgram; + programReturned = clLinkProgram(pContext, 1, &testedClDevice, nullptr, 1, &inputPrograms, nullptr, nullptr, &retVal); + ASSERT_EQ(CL_SUCCESS, retVal); + } + + void vcallback(cl_function_id fid, cl_callback_data *callbackData, void *userData) override { + ASSERT_EQ(CL_FUNCTION_clLinkProgram, fid); + if (callbackData->site == CL_CALLBACK_SITE_ENTER) { + ++enterCount; + } else if (callbackData->site == CL_CALLBACK_SITE_EXIT) { + obtainedProgramCallback = *reinterpret_cast(callbackData->functionReturnValue); + ++exitCount; + } + } + + uint8_t ir[10] = {'m', 'o', 'c', 'k', 'i', 'r', 'd', 'a', 't', 'a'}; + cl_program programReturned = nullptr; + cl_program obtainedProgramCallback = nullptr; + uint16_t enterCount = 0; + uint16_t exitCount = 0; +}; + +TEST_F(IntelClLinkProgramTracingTest, givenLinkProgramCallTracingWhenInvokingCallbackThenPointersFromCallAndCallbackPointToTheSameAddress) { + call(); + EXPECT_EQ(1u, enterCount); + EXPECT_EQ(1u, exitCount); + + EXPECT_NE(nullptr, programReturned); + EXPECT_NE(nullptr, obtainedProgramCallback); + EXPECT_EQ(programReturned, obtainedProgramCallback); + clReleaseProgram(programReturned); +} + } // namespace ULT