diff --git a/opencl/source/api/api.cpp b/opencl/source/api/api.cpp index 5ba7e8e369..19773f2066 100644 --- a/opencl/source/api/api.cpp +++ b/opencl/source/api/api.cpp @@ -306,7 +306,9 @@ cl_int CL_API_CALL clCreateSubDevices(cl_device_id inDevice, } for (uint32_t i = 0; i < subDevicesCount; i++) { - outDevices[i] = pInDevice->getDeviceById(i); + auto pClDevice = pInDevice->getDeviceById(i); + pClDevice->retainApi(); + outDevices[i] = pClDevice; } return CL_SUCCESS; diff --git a/opencl/source/device/cl_device.cpp b/opencl/source/device/cl_device.cpp index 0a1989c085..8457b569b9 100644 --- a/opencl/source/device/cl_device.cpp +++ b/opencl/source/device/cl_device.cpp @@ -38,7 +38,10 @@ ClDevice::ClDevice(Device &device, Platform *platform) : device(device), platfor deviceInfo.partitionType[1] = CL_DEVICE_AFFINITY_DOMAIN_NUMA; deviceInfo.partitionType[2] = 0; - subDevices.push_back(std::make_unique(coreSubDevice, platform)); + auto pClSubDevice = std::make_unique(coreSubDevice, platform); + pClSubDevice->incRefInternal(); + pClSubDevice->decRefApi(); + subDevices.push_back(std::move(pClSubDevice)); } } if (device.getDeviceInfo().debuggerActive) { @@ -72,18 +75,20 @@ unsigned int ClDevice::getEnabledClVersion() const { return device.getEnabledClV unsigned int ClDevice::getSupportedClVersion() const { return device.getSupportedClVersion(); } void ClDevice::retainApi() { - if (device.isReleasable()) { - auto pPlatform = castToObject(platformId); - pPlatform->getClDevice(device.getRootDeviceIndex())->incRefInternal(); + auto parentDeviceId = device.getDeviceInfo().parentDevice; + if (parentDeviceId) { + auto pParentClDevice = static_cast(parentDeviceId); + pParentClDevice->incRefInternal(); this->incRefApi(); } }; unique_ptr_if_unused ClDevice::releaseApi() { - if (!device.isReleasable()) { + auto parentDeviceId = device.getDeviceInfo().parentDevice; + if (!parentDeviceId) { return unique_ptr_if_unused(this, false); } - auto pPlatform = castToObject(platformId); - pPlatform->getClDevice(device.getRootDeviceIndex())->decRefInternal(); + auto pParentClDevice = static_cast(parentDeviceId); + pParentClDevice->decRefInternal(); return this->decRefApi(); } diff --git a/opencl/test/unit_test/api/cl_create_sub_devices_tests.inl b/opencl/test/unit_test/api/cl_create_sub_devices_tests.inl index e9b3d29e1f..03359977b5 100644 --- a/opencl/test/unit_test/api/cl_create_sub_devices_tests.inl +++ b/opencl/test/unit_test/api/cl_create_sub_devices_tests.inl @@ -104,6 +104,23 @@ TEST_F(clCreateSubDevicesTests, GivenValidInputWhenCreatingSubDevicesThenSubDevi EXPECT_EQ(outDevices[1], outDevices2[1]); } +TEST_F(clCreateSubDevicesTests, GivenValidInputWhenCreatingSubDevicesThenDeviceApiReferenceCountIsIncreasedEveryTime) { + setup(2); + + EXPECT_EQ(0, device->getDeviceById(0)->getRefApiCount()); + EXPECT_EQ(0, device->getDeviceById(1)->getRefApiCount()); + + auto retVal = clCreateSubDevices(device.get(), properties, outDevicesCount, outDevices, nullptr); + EXPECT_EQ(CL_SUCCESS, retVal); + EXPECT_EQ(1, device->getDeviceById(0)->getRefApiCount()); + EXPECT_EQ(1, device->getDeviceById(1)->getRefApiCount()); + + retVal = clCreateSubDevices(device.get(), properties, outDevicesCount, outDevices, nullptr); + EXPECT_EQ(CL_SUCCESS, retVal); + EXPECT_EQ(2, device->getDeviceById(0)->getRefApiCount()); + EXPECT_EQ(2, device->getDeviceById(1)->getRefApiCount()); +} + struct clCreateSubDevicesDeviceInfoTests : clCreateSubDevicesTests { void setup(int numberOfDevices) { clCreateSubDevicesTests::setup(numberOfDevices); diff --git a/shared/source/device/device.h b/shared/source/device/device.h index d6a29b0908..7fb8481b11 100644 --- a/shared/source/device/device.h +++ b/shared/source/device/device.h @@ -35,8 +35,6 @@ class Device : public ReferenceTrackedObject { return createDeviceInternals(device); } - virtual bool isReleasable() = 0; - bool getDeviceAndHostTimer(uint64_t *deviceTimestamp, uint64_t *hostTimestamp) const; bool getHostTimer(uint64_t *hostTimestamp) const; const HardwareInfo &getHardwareInfo() const; diff --git a/shared/source/device/root_device.cpp b/shared/source/device/root_device.cpp index a40828fa22..4b920c337f 100644 --- a/shared/source/device/root_device.cpp +++ b/shared/source/device/root_device.cpp @@ -76,9 +76,6 @@ bool RootDevice::createDeviceImpl() { } return true; } -bool RootDevice::isReleasable() { - return false; -}; DeviceBitfield RootDevice::getDeviceBitfield() const { DeviceBitfield deviceBitfield{static_cast(maxNBitValue(getNumAvailableDevices()))}; return deviceBitfield; diff --git a/shared/source/device/root_device.h b/shared/source/device/root_device.h index 8de3e7cf3c..b8b62e78f8 100644 --- a/shared/source/device/root_device.h +++ b/shared/source/device/root_device.h @@ -20,7 +20,6 @@ class RootDevice : public Device { uint32_t getNumAvailableDevices() const override; uint32_t getRootDeviceIndex() const override; Device *getDeviceById(uint32_t deviceId) const override; - bool isReleasable() override; uint32_t getNumSubDevices() const; diff --git a/shared/source/device/sub_device.cpp b/shared/source/device/sub_device.cpp index 2eca39b840..33cef34d2e 100644 --- a/shared/source/device/sub_device.cpp +++ b/shared/source/device/sub_device.cpp @@ -15,10 +15,6 @@ SubDevice::SubDevice(ExecutionEnvironment *executionEnvironment, uint32_t subDev : Device(executionEnvironment), subDeviceIndex(subDeviceIndex), rootDevice(rootDevice) { } -bool SubDevice::isReleasable() { - return true; -}; - DeviceBitfield SubDevice::getDeviceBitfield() const { DeviceBitfield deviceBitfield; deviceBitfield.set(subDeviceIndex); diff --git a/shared/source/device/sub_device.h b/shared/source/device/sub_device.h index 36fa2d5586..68fe4482dd 100644 --- a/shared/source/device/sub_device.h +++ b/shared/source/device/sub_device.h @@ -15,7 +15,6 @@ class SubDevice : public Device { constexpr static uint32_t unspecifiedSubDeviceIndex = std::numeric_limits::max(); SubDevice(ExecutionEnvironment *executionEnvironment, uint32_t subDeviceIndex, RootDevice &rootDevice); - bool isReleasable() override; uint32_t getNumAvailableDevices() const override; uint32_t getRootDeviceIndex() const override; Device *getDeviceById(uint32_t deviceId) const override;