Add reference on device while creating a context

- Context uses device, it needs to make sure it exists.

Change-Id: I1aeaabc53b6198965dc8f4e6fa37585490774a3f
This commit is contained in:
Mrozek, Michal
2018-06-27 15:53:35 +02:00
committed by sys_ocldev
parent d447c8c951
commit 887df5a90d
4 changed files with 39 additions and 10 deletions

View File

@@ -74,6 +74,9 @@ Context::~Context() {
memoryManager->getDeferredDeleter()->removeClient(); memoryManager->getDeferredDeleter()->removeClient();
} }
gtpinNotifyContextDestroy((cl_context)this); gtpinNotifyContextDestroy((cl_context)this);
for (auto &device : devices) {
device->decRefInternal();
}
} }
DeviceQueue *Context::getDefaultDeviceQueue() { DeviceQueue *Context::getDefaultDeviceQueue() {
@@ -99,7 +102,7 @@ void Context::overrideSpecialQueueAndDecrementRefCount(CommandQueue *commandQueu
}; };
bool Context::createImpl(const cl_context_properties *properties, bool Context::createImpl(const cl_context_properties *properties,
const DeviceVector &devices, const DeviceVector &inputDevices,
void(CL_CALLBACK *funcNotify)(const char *, const void *, size_t, void *), void(CL_CALLBACK *funcNotify)(const char *, const void *, size_t, void *),
void *data, cl_int &errcodeRet) { void *data, cl_int &errcodeRet) {
@@ -160,7 +163,6 @@ bool Context::createImpl(const cl_context_properties *properties,
this->numProperties = numProperties; this->numProperties = numProperties;
this->properties = propertiesNew; this->properties = propertiesNew;
this->devices = devices;
this->setInteropUserSyncEnabled(interopUserSync); this->setInteropUserSyncEnabled(interopUserSync);
if (!sharingBuilder->finalizeProperties(*this, errcodeRet)) { if (!sharingBuilder->finalizeProperties(*this, errcodeRet)) {
@@ -168,6 +170,7 @@ bool Context::createImpl(const cl_context_properties *properties,
} }
this->driverDiagnostics = driverDiagnostics.release(); this->driverDiagnostics = driverDiagnostics.release();
this->devices = inputDevices;
// We currently assume each device uses the same MemoryManager // We currently assume each device uses the same MemoryManager
if (devices.size() > 0) { if (devices.size() > 0) {
@@ -178,6 +181,10 @@ bool Context::createImpl(const cl_context_properties *properties,
} }
} }
for (auto &device : devices) {
device->incRefInternal();
}
auto commandQueue = CommandQueue::create(this, devices[0], nullptr, errcodeRet); auto commandQueue = CommandQueue::create(this, devices[0], nullptr, errcodeRet);
DEBUG_BREAK_IF(commandQueue == nullptr); DEBUG_BREAK_IF(commandQueue == nullptr);
overrideSpecialQueueAndDecrementRefCount(commandQueue); overrideSpecialQueueAndDecrementRefCount(commandQueue);

View File

@@ -58,9 +58,6 @@ struct ContextTest : public PlatformFixture,
using PlatformFixture::SetUp; using PlatformFixture::SetUp;
ContextTest() {
}
void SetUp() override { void SetUp() override {
PlatformFixture::SetUp(); PlatformFixture::SetUp();
@@ -215,6 +212,30 @@ TEST_F(ContextTest, givenDefaultDeviceCmdQueueWithContextWhenBeingCreatedNextDel
EXPECT_EQ(1, context.getRefInternalCount()); EXPECT_EQ(1, context.getRefInternalCount());
} }
TEST_F(ContextTest, givenContextWhenItIsCreatedFromDeviceThenItAddsRefCountToThisDevice) {
auto device = castToObject<Device>(devices[0]);
EXPECT_EQ(2, device->getRefInternalCount());
cl_device_id deviceID = devices[0];
std::unique_ptr<Context> context(Context::create<Context>(0, DeviceVector(&deviceID, 1), nullptr, nullptr, retVal));
EXPECT_EQ(3, device->getRefInternalCount());
context.reset(nullptr);
EXPECT_EQ(2, device->getRefInternalCount());
}
TEST_F(ContextTest, givenContextWhenItIsCreatedFromMultipleDevicesThenItAddsRefCountToThoseDevices) {
auto device = castToObject<Device>(devices[0]);
EXPECT_EQ(2, device->getRefInternalCount());
DeviceVector devicesVector;
devicesVector.push_back(device);
devicesVector.push_back(device);
std::unique_ptr<Context> context(Context::create<Context>(0, devicesVector, nullptr, nullptr, retVal));
EXPECT_EQ(4, device->getRefInternalCount());
context.reset(nullptr);
EXPECT_EQ(2, device->getRefInternalCount());
}
TEST_F(ContextTest, givenSpecialCmdQueueWithContextWhenBeingCreatedNextAutoDeletedThenContextRefCountShouldNeitherBeIncrementedNorNextDecremented) { TEST_F(ContextTest, givenSpecialCmdQueueWithContextWhenBeingCreatedNextAutoDeletedThenContextRefCountShouldNeitherBeIncrementedNorNextDecremented) {
MockContext context((Device *)devices[0], true); MockContext context((Device *)devices[0], true);
EXPECT_EQ(1, context.getRefInternalCount()); EXPECT_EQ(1, context.getRefInternalCount());

View File

@@ -43,6 +43,7 @@ MockContext::MockContext(Device *device, bool noSpecialQueue) {
assert(retVal == CL_SUCCESS); assert(retVal == CL_SUCCESS);
overrideSpecialQueueAndDecrementRefCount(commandQueue); overrideSpecialQueueAndDecrementRefCount(commandQueue);
} }
device->incRefInternal();
} }
MockContext::MockContext( MockContext::MockContext(
@@ -70,13 +71,13 @@ MockContext::~MockContext() {
} }
MockContext::MockContext() { MockContext::MockContext() {
device = std::unique_ptr<Device>(DeviceHelper<>::create()); device = DeviceHelper<>::create();
devices.push_back(device.get()); devices.push_back(device);
memoryManager = device->getMemoryManager(); memoryManager = device->getMemoryManager();
svmAllocsManager = new SVMAllocsManager(memoryManager); svmAllocsManager = new SVMAllocsManager(memoryManager);
cl_int retVal; cl_int retVal;
if (!specialQueue) { if (!specialQueue) {
auto commandQueue = CommandQueue::create(this, device.get(), nullptr, retVal); auto commandQueue = CommandQueue::create(this, device, nullptr, retVal);
assert(retVal == CL_SUCCESS); assert(retVal == CL_SUCCESS);
overrideSpecialQueueAndDecrementRefCount(commandQueue); overrideSpecialQueueAndDecrementRefCount(commandQueue);
} }

View File

@@ -1,5 +1,5 @@
/* /*
* Copyright (c) 2017, Intel Corporation * Copyright (c) 2017 - 2018, Intel Corporation
* *
* Permission is hereby granted, free of charge, to any person obtaining a * Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"), * copy of this software and associated documentation files (the "Software"),
@@ -51,6 +51,6 @@ class MockContext : public Context {
DriverDiagnostics *getDriverDiagnostics() { return this->driverDiagnostics; } DriverDiagnostics *getDriverDiagnostics() { return this->driverDiagnostics; }
private: private:
std::unique_ptr<Device> device; Device *device;
}; };
} // namespace OCLRT } // namespace OCLRT