From 887df5a90dcfade3cfaed83f97c8820f5d04944d Mon Sep 17 00:00:00 2001 From: "Mrozek, Michal" Date: Wed, 27 Jun 2018 15:53:35 +0200 Subject: [PATCH] Add reference on device while creating a context - Context uses device, it needs to make sure it exists. Change-Id: I1aeaabc53b6198965dc8f4e6fa37585490774a3f --- runtime/context/context.cpp | 11 +++++++++-- unit_tests/context/context_tests.cpp | 27 ++++++++++++++++++++++++--- unit_tests/mocks/mock_context.cpp | 7 ++++--- unit_tests/mocks/mock_context.h | 4 ++-- 4 files changed, 39 insertions(+), 10 deletions(-) diff --git a/runtime/context/context.cpp b/runtime/context/context.cpp index 6318467873..d39d10a6a4 100644 --- a/runtime/context/context.cpp +++ b/runtime/context/context.cpp @@ -74,6 +74,9 @@ Context::~Context() { memoryManager->getDeferredDeleter()->removeClient(); } gtpinNotifyContextDestroy((cl_context)this); + for (auto &device : devices) { + device->decRefInternal(); + } } DeviceQueue *Context::getDefaultDeviceQueue() { @@ -99,7 +102,7 @@ void Context::overrideSpecialQueueAndDecrementRefCount(CommandQueue *commandQueu }; bool Context::createImpl(const cl_context_properties *properties, - const DeviceVector &devices, + const DeviceVector &inputDevices, void(CL_CALLBACK *funcNotify)(const char *, const void *, size_t, void *), void *data, cl_int &errcodeRet) { @@ -160,7 +163,6 @@ bool Context::createImpl(const cl_context_properties *properties, this->numProperties = numProperties; this->properties = propertiesNew; - this->devices = devices; this->setInteropUserSyncEnabled(interopUserSync); if (!sharingBuilder->finalizeProperties(*this, errcodeRet)) { @@ -168,6 +170,7 @@ bool Context::createImpl(const cl_context_properties *properties, } this->driverDiagnostics = driverDiagnostics.release(); + this->devices = inputDevices; // We currently assume each device uses the same MemoryManager if (devices.size() > 0) { @@ -178,6 +181,10 @@ bool Context::createImpl(const cl_context_properties *properties, } } + for (auto &device : devices) { + device->incRefInternal(); + } + auto commandQueue = CommandQueue::create(this, devices[0], nullptr, errcodeRet); DEBUG_BREAK_IF(commandQueue == nullptr); overrideSpecialQueueAndDecrementRefCount(commandQueue); diff --git a/unit_tests/context/context_tests.cpp b/unit_tests/context/context_tests.cpp index d1780c64c7..d2bf99ab58 100644 --- a/unit_tests/context/context_tests.cpp +++ b/unit_tests/context/context_tests.cpp @@ -58,9 +58,6 @@ struct ContextTest : public PlatformFixture, using PlatformFixture::SetUp; - ContextTest() { - } - void SetUp() override { PlatformFixture::SetUp(); @@ -215,6 +212,30 @@ TEST_F(ContextTest, givenDefaultDeviceCmdQueueWithContextWhenBeingCreatedNextDel EXPECT_EQ(1, context.getRefInternalCount()); } +TEST_F(ContextTest, givenContextWhenItIsCreatedFromDeviceThenItAddsRefCountToThisDevice) { + auto device = castToObject(devices[0]); + EXPECT_EQ(2, device->getRefInternalCount()); + cl_device_id deviceID = devices[0]; + std::unique_ptr context(Context::create(0, DeviceVector(&deviceID, 1), nullptr, nullptr, retVal)); + EXPECT_EQ(3, device->getRefInternalCount()); + context.reset(nullptr); + EXPECT_EQ(2, device->getRefInternalCount()); +} + +TEST_F(ContextTest, givenContextWhenItIsCreatedFromMultipleDevicesThenItAddsRefCountToThoseDevices) { + auto device = castToObject(devices[0]); + EXPECT_EQ(2, device->getRefInternalCount()); + + DeviceVector devicesVector; + devicesVector.push_back(device); + devicesVector.push_back(device); + + std::unique_ptr context(Context::create(0, devicesVector, nullptr, nullptr, retVal)); + EXPECT_EQ(4, device->getRefInternalCount()); + context.reset(nullptr); + EXPECT_EQ(2, device->getRefInternalCount()); +} + TEST_F(ContextTest, givenSpecialCmdQueueWithContextWhenBeingCreatedNextAutoDeletedThenContextRefCountShouldNeitherBeIncrementedNorNextDecremented) { MockContext context((Device *)devices[0], true); EXPECT_EQ(1, context.getRefInternalCount()); diff --git a/unit_tests/mocks/mock_context.cpp b/unit_tests/mocks/mock_context.cpp index 0733a39794..4e600195ea 100644 --- a/unit_tests/mocks/mock_context.cpp +++ b/unit_tests/mocks/mock_context.cpp @@ -43,6 +43,7 @@ MockContext::MockContext(Device *device, bool noSpecialQueue) { assert(retVal == CL_SUCCESS); overrideSpecialQueueAndDecrementRefCount(commandQueue); } + device->incRefInternal(); } MockContext::MockContext( @@ -70,13 +71,13 @@ MockContext::~MockContext() { } MockContext::MockContext() { - device = std::unique_ptr(DeviceHelper<>::create()); - devices.push_back(device.get()); + device = DeviceHelper<>::create(); + devices.push_back(device); memoryManager = device->getMemoryManager(); svmAllocsManager = new SVMAllocsManager(memoryManager); cl_int retVal; if (!specialQueue) { - auto commandQueue = CommandQueue::create(this, device.get(), nullptr, retVal); + auto commandQueue = CommandQueue::create(this, device, nullptr, retVal); assert(retVal == CL_SUCCESS); overrideSpecialQueueAndDecrementRefCount(commandQueue); } diff --git a/unit_tests/mocks/mock_context.h b/unit_tests/mocks/mock_context.h index 8134cff340..0017fc1ba8 100644 --- a/unit_tests/mocks/mock_context.h +++ b/unit_tests/mocks/mock_context.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017, Intel Corporation + * Copyright (c) 2017 - 2018, Intel Corporation * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -51,6 +51,6 @@ class MockContext : public Context { DriverDiagnostics *getDriverDiagnostics() { return this->driverDiagnostics; } private: - std::unique_ptr device; + Device *device; }; } // namespace OCLRT