Simplify Device classes

Signed-off-by: Bartosz Dunajski <bartosz.dunajski@intel.com>
This commit is contained in:
Bartosz Dunajski
2021-04-07 17:00:33 +00:00
committed by Compute-Runtime-Automation
parent a2cbb4f859
commit 9cf7651643
10 changed files with 65 additions and 83 deletions

View File

@ -86,7 +86,7 @@ MetricContextImp::MetricContextImp(Device &deviceInput)
metricGroupDomains(*this) {
auto deviceNeo = deviceInput.getNEODevice();
bool isSubDevice = deviceNeo->getParentDevice() != nullptr;
bool isSubDevice = deviceNeo->isSubDevice();
subDeviceIndex = isSubDevice
? static_cast<NEO::SubDevice *>(deviceNeo)->getSubDeviceIndex()

View File

@ -203,7 +203,7 @@ void MetricsLibrary::getSubDeviceClientOptions(
ClientOptionsData_1_0 &subDeviceIndex,
ClientOptionsData_1_0 &subDeviceCount) {
if (neoDevice.getParentDevice() == nullptr) {
if (!neoDevice.isSubDevice()) {
// Root device.
subDevice.Type = ClientOptionsType::SubDevice;
@ -225,7 +225,7 @@ void MetricsLibrary::getSubDeviceClientOptions(
subDeviceIndex.SubDeviceIndex.Index = static_cast<NEO::SubDevice *>(&neoDevice)->getSubDeviceIndex();
subDeviceCount.Type = ClientOptionsType::SubDeviceCount;
subDeviceCount.SubDeviceCount.Count = neoDevice.getParentDevice()->getNumAvailableDevices();
subDeviceCount.SubDeviceCount.Count = neoDevice.getRootDevice()->getNumAvailableDevices();
}
}

View File

@ -25,7 +25,7 @@ std::unique_ptr<PerformanceCounters> PerformanceCounters::create(Device *device)
auto &hwHelper = HwHelper::get(gen);
UNRECOVERABLE_IF(counter == nullptr);
if (device->getParentDevice() == nullptr) {
if (!device->isSubDevice()) {
// Root device.
counter->subDevice.Enabled = false;
@ -36,7 +36,7 @@ std::unique_ptr<PerformanceCounters> PerformanceCounters::create(Device *device)
// Sub device.
counter->subDevice.Enabled = true;
counter->subDeviceIndex.Index = static_cast<NEO::SubDevice *>(device)->getSubDeviceIndex();
counter->subDeviceCount.Count = device->getParentDevice()->getNumAvailableDevices();
counter->subDeviceCount.Count = device->getRootDevice()->getNumAvailableDevices();
}
// Adapter data.

View File

@ -401,13 +401,13 @@ TEST(DeviceCreation, givenDeviceWhenCheckingEnginesCountThenNumberGreaterThanZer
TEST(DeviceCreation, givenDeviceWhenCheckingParentDeviceThenCorrectValueIsReturned) {
UltDeviceFactory deviceFactory{2, 2};
EXPECT_EQ(nullptr, deviceFactory.rootDevices[0]->getParentDevice());
EXPECT_EQ(deviceFactory.rootDevices[0], deviceFactory.subDevices[0]->getParentDevice());
EXPECT_EQ(deviceFactory.rootDevices[0], deviceFactory.subDevices[1]->getParentDevice());
EXPECT_EQ(deviceFactory.rootDevices[0], deviceFactory.rootDevices[0]->getRootDevice());
EXPECT_EQ(deviceFactory.rootDevices[0], deviceFactory.subDevices[0]->getRootDevice());
EXPECT_EQ(deviceFactory.rootDevices[0], deviceFactory.subDevices[1]->getRootDevice());
EXPECT_EQ(nullptr, deviceFactory.rootDevices[1]->getParentDevice());
EXPECT_EQ(deviceFactory.rootDevices[1], deviceFactory.subDevices[2]->getParentDevice());
EXPECT_EQ(deviceFactory.rootDevices[1], deviceFactory.subDevices[3]->getParentDevice());
EXPECT_EQ(deviceFactory.rootDevices[1], deviceFactory.rootDevices[1]->getRootDevice());
EXPECT_EQ(deviceFactory.rootDevices[1], deviceFactory.subDevices[2]->getRootDevice());
EXPECT_EQ(deviceFactory.rootDevices[1], deviceFactory.subDevices[3]->getRootDevice());
}
TEST(DeviceCreation, givenRootDeviceWithSubDevicesWhenCheckingEngineGroupsThenItHasOneNonEmptyGroup) {

View File

@ -44,6 +44,12 @@ Device::~Device() {
engine.commandStreamReceiver->flushBatchedSubmissions();
}
for (auto subdevice : subdevices) {
if (subdevice) {
delete subdevice;
}
}
syncBufferHandler.reset();
commandStreamReceivers.clear();
executionEnvironment->memoryManager->waitForDeletions();
@ -309,6 +315,26 @@ bool Device::getHostTimer(uint64_t *hostTimestamp) const {
return getOSTime()->getCpuTime(hostTimestamp);
}
uint32_t Device::getNumAvailableDevices() const {
if (subdevices.empty()) {
return 1u;
}
return getNumSubDevices();
}
Device *Device::getDeviceById(uint32_t deviceId) const {
if (subdevices.empty()) {
UNRECOVERABLE_IF(deviceId > 0);
return const_cast<Device *>(this);
}
UNRECOVERABLE_IF(deviceId >= subdevices.size());
return subdevices[deviceId];
}
BindlessHeapsHelper *Device::getBindlessHeapsHelper() const {
return getRootDeviceEnvironment().getBindlessHeapsHelper();
}
GmmClientContext *Device::getGmmClientContext() const {
return getGmmHelper()->getClientContext();
}

View File

@ -24,6 +24,7 @@
namespace NEO {
class OSTime;
class SourceLevelDebugger;
class SubDevice;
class Device : public ReferenceTrackedObject<Device> {
public:
@ -98,11 +99,14 @@ class Device : public ReferenceTrackedObject<Device> {
void allocateSyncBufferHandler();
virtual uint32_t getRootDeviceIndex() const = 0;
virtual uint32_t getNumAvailableDevices() const = 0;
virtual Device *getDeviceById(uint32_t deviceId) const = 0;
virtual Device *getParentDevice() const = 0;
virtual DeviceBitfield getDeviceBitfield() const = 0;
virtual BindlessHeapsHelper *getBindlessHeapsHelper() const = 0;
uint32_t getNumAvailableDevices() const;
virtual Device *getDeviceById(uint32_t deviceId) const;
virtual Device *getRootDevice() const = 0;
DeviceBitfield getDeviceBitfield() const { return deviceBitfield; };
uint32_t getNumSubDevices() const { return numSubDevices; }
virtual bool isSubDevice() const = 0;
BindlessHeapsHelper *getBindlessHeapsHelper() const;
static decltype(&PerformanceCounters::create) createPerformanceCountersFunc;
std::unique_ptr<SyncBufferHandler> syncBufferHandler;
@ -137,11 +141,17 @@ class Device : public ReferenceTrackedObject<Device> {
std::vector<std::unique_ptr<CommandStreamReceiver>> commandStreamReceivers;
std::vector<EngineControl> engines;
std::vector<std::vector<EngineControl>> engineGroups;
std::vector<SubDevice *> subdevices;
PreemptionMode preemptionMode;
ExecutionEnvironment *executionEnvironment = nullptr;
uint32_t defaultEngineIndex = 0;
uint32_t numSubDevices = 0;
std::atomic<uint32_t> selectorCopyEngine{0};
DeviceBitfield deviceBitfield = 1;
uintptr_t specializedDevice = reinterpret_cast<uintptr_t>(nullptr);
};

View File

@ -28,42 +28,14 @@ RootDevice::~RootDevice() {
if (getRootDeviceEnvironment().tagsManager) {
getRootDeviceEnvironment().tagsManager->shutdown();
}
for (auto subdevice : subdevices) {
if (subdevice) {
delete subdevice;
}
}
}
uint32_t RootDevice::getNumSubDevices() const {
return this->numSubDevices;
}
BindlessHeapsHelper *RootDevice::getBindlessHeapsHelper() const {
return this->getRootDeviceEnvironment().getBindlessHeapsHelper();
}
uint32_t RootDevice::getRootDeviceIndex() const {
return rootDeviceIndex;
}
uint32_t RootDevice::getNumAvailableDevices() const {
if (subdevices.empty()) {
return 1u;
}
return getNumSubDevices();
}
Device *RootDevice::getDeviceById(uint32_t deviceId) const {
if (subdevices.empty()) {
return const_cast<RootDevice *>(this);
}
UNRECOVERABLE_IF(deviceId >= subdevices.size());
return subdevices[deviceId];
}
Device *RootDevice::getParentDevice() const {
return nullptr;
Device *RootDevice::getRootDevice() const {
return const_cast<RootDevice *>(this);
}
SubDevice *RootDevice::createSubDevice(uint32_t subDeviceIndex) {
@ -103,10 +75,6 @@ bool RootDevice::createDeviceImpl() {
return true;
}
DeviceBitfield RootDevice::getDeviceBitfield() const {
return deviceBitfield;
}
bool RootDevice::createEngines() {
if (getNumSubDevices() < 2) {
return Device::createEngines();

View File

@ -17,23 +17,17 @@ class RootDevice : public Device {
RootDevice(ExecutionEnvironment *executionEnvironment, uint32_t rootDeviceIndex);
~RootDevice() override;
bool createDeviceImpl() override;
uint32_t getNumAvailableDevices() const override;
uint32_t getRootDeviceIndex() const override;
Device *getDeviceById(uint32_t deviceId) const override;
Device *getParentDevice() const override;
uint32_t getNumSubDevices() const;
BindlessHeapsHelper *getBindlessHeapsHelper() const override;
Device *getRootDevice() const override;
bool isSubDevice() const override { return false; }
protected:
DeviceBitfield getDeviceBitfield() const override;
bool createEngines() override;
void initializeRootCommandStreamReceiver();
MOCKABLE_VIRTUAL SubDevice *createSubDevice(uint32_t subDeviceIndex);
std::vector<SubDevice *> subdevices;
const uint32_t rootDeviceIndex;
DeviceBitfield deviceBitfield = DeviceBitfield{1u};
uint32_t numSubDevices = 0;
};
} // namespace NEO

View File

@ -1,5 +1,5 @@
/*
* Copyright (C) 2019-2020 Intel Corporation
* Copyright (C) 2019-2021 Intel Corporation
*
* SPDX-License-Identifier: MIT
*
@ -13,6 +13,8 @@ namespace NEO {
SubDevice::SubDevice(ExecutionEnvironment *executionEnvironment, uint32_t subDeviceIndex, RootDevice &rootDevice)
: Device(executionEnvironment), subDeviceIndex(subDeviceIndex), rootDevice(rootDevice) {
deviceBitfield = 0;
deviceBitfield.set(subDeviceIndex);
}
void SubDevice::incRefInternal() {
@ -22,14 +24,6 @@ unique_ptr_if_unused<Device> SubDevice::decRefInternal() {
return rootDevice.decRefInternal();
}
DeviceBitfield SubDevice::getDeviceBitfield() const {
DeviceBitfield deviceBitfield;
deviceBitfield.set(subDeviceIndex);
return deviceBitfield;
}
uint32_t SubDevice::getNumAvailableDevices() const {
return 1u;
}
uint32_t SubDevice::getRootDeviceIndex() const {
return this->rootDevice.getRootDeviceIndex();
}
@ -38,17 +32,9 @@ uint32_t SubDevice::getSubDeviceIndex() const {
return subDeviceIndex;
}
Device *SubDevice::getDeviceById(uint32_t deviceId) const {
UNRECOVERABLE_IF(deviceId >= getNumAvailableDevices());
return const_cast<SubDevice *>(this);
}
Device *SubDevice::getParentDevice() const {
Device *SubDevice::getRootDevice() const {
return &rootDevice;
}
BindlessHeapsHelper *SubDevice::getBindlessHeapsHelper() const {
return rootDevice.getBindlessHeapsHelper();
}
uint64_t SubDevice::getGlobalMemorySize(uint32_t deviceBitfield) const {
auto globalMemorySize = Device::getGlobalMemorySize(static_cast<uint32_t>(maxNBitValue(rootDevice.getNumSubDevices())));

View File

@ -1,5 +1,5 @@
/*
* Copyright (C) 2019-2020 Intel Corporation
* Copyright (C) 2019-2021 Intel Corporation
*
* SPDX-License-Identifier: MIT
*
@ -16,16 +16,14 @@ class SubDevice : public Device {
void incRefInternal() override;
unique_ptr_if_unused<Device> decRefInternal() override;
uint32_t getNumAvailableDevices() const override;
uint32_t getRootDeviceIndex() const override;
Device *getDeviceById(uint32_t deviceId) const override;
Device *getParentDevice() const override;
BindlessHeapsHelper *getBindlessHeapsHelper() const override;
Device *getRootDevice() const override;
uint32_t getSubDeviceIndex() const;
bool isSubDevice() const override { return true; }
protected:
DeviceBitfield getDeviceBitfield() const override;
uint64_t getGlobalMemorySize(uint32_t deviceBitfield) const override;
const uint32_t subDeviceIndex;
RootDevice &rootDevice;