compute-runtime/runtime/helpers/base_object.h

258 lines
6.1 KiB
C++

/*
* Copyright (C) 2017-2019 Intel Corporation
*
* SPDX-License-Identifier: MIT
*
*/
#pragma once
#include "runtime/api/dispatch.h"
#include "runtime/helpers/abort.h"
#include "runtime/helpers/debug_helpers.h"
#include "runtime/utilities/reference_tracked_object.h"
#include "CL/cl.h"
#include <atomic>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>
#include <type_traits>
namespace NEO {
// NO_SANITIZE opts a function out of clang's UndefinedBehaviorSanitizer checks
// (no_sanitize("undefined")). Other compilers get an empty expansion.
#ifndef __clang__
#define NO_SANITIZE
#else
#define NO_SANITIZE __attribute__((no_sanitize("undefined")))
#endif
/// Maps a CL handle type to the runtime class that implements it.
/// The primary template is intentionally empty. Mappings are presumably provided by
/// specializations elsewhere, each declaring a nested `DerivedType` alias (which is
/// what DerivedType_t below requires).
template <typename HandleType>
struct OpenCLObjectMapper {};

/// Shorthand for OpenCLObjectMapper<T>::DerivedType.
template <typename T>
using DerivedType_t = typename OpenCLObjectMapper<T>::DerivedType;
/// Validates a CL handle and returns it as the runtime object type.
///
/// The handle is static_cast to DerivedType* first. Its magic value, masked with
/// DerivedType::maskMagic, is then compared with DerivedType::objectMagic. Because
/// the cast happens before the check, the pointer may not really refer to a
/// DerivedType, hence NO_SANITIZE.
///
/// @param object handle to validate; may be nullptr.
/// @return the object, or nullptr when @p object is null or its magic does not match.
template <typename DerivedType>
NO_SANITIZE inline DerivedType *castToObject(typename DerivedType::BaseType *object) {
    if (!object) {
        return nullptr;
    }
    auto *candidate = static_cast<DerivedType *>(object);
    const bool hasExpectedMagic =
        (candidate->getMagic() & DerivedType::maskMagic) == DerivedType::objectMagic;
    if (!hasExpectedMagic) {
        return nullptr;
    }
    DEBUG_BREAK_IF(candidate->dispatch.icdDispatch != &icdGlobalDispatchTable);
    return candidate;
}
/// Like castToObject(), but an invalid or null handle terminates via abortExecution()
/// instead of being reported to the caller.
///
/// @param object handle to validate.
/// @return the validated object. Only if abortExecution() returns (it is expected not
///         to) can the result be nullptr.
template <typename DerivedType>
inline DerivedType *castToObjectOrAbort(typename DerivedType::BaseType *object) {
    auto derivedObject = castToObject<DerivedType>(object);
    if (derivedObject == nullptr) {
        abortExecution();
    }
    // Every path ends in a return. Without one, the abort branch would fall off the
    // end of a non-void function if abortExecution() is not [[noreturn]] (e.g. a test
    // double that returns).
    return derivedObject;
}
/// Const-handle overload of castToObject(): validates @p object and returns it as a
/// const DerivedType*, or nullptr when validation fails.
template <typename DerivedType>
inline const DerivedType *castToObject(const typename DerivedType::BaseType *object) {
    // const is dropped only to reuse the mutable overload. The result is returned as const.
    using MutableBase = typename DerivedType::BaseType;
    MutableBase *mutableObject = const_cast<MutableBase *>(object);
    return castToObject<DerivedType>(mutableObject);
}
template <typename DerivedType>
inline DerivedType *castToObject(const void *object) {
cl_mem clMem = const_cast<cl_mem>(static_cast<const _cl_mem *>(object));
return castToObject<DerivedType>(clMem);
}
// Sentinel thread id meaning "no owner". BaseObject::takeOwnership() and
// releaseOwnership() compare `owner` against it and assign it.
// Declared extern here; the definition lives in a source file outside this header.
extern std::thread::id invalidThreadID;
/// Thin wrapper around std::condition_variable that tracks how many threads are
/// currently inside wait().
class ConditionVariableWithCounter {
  public:
    ConditionVariableWithCounter()
        : waitersCount(0) {
    }

    /// Forwards all arguments to std::condition_variable::wait: a lock, optionally
    /// followed by a predicate.
    /// The waiter count is raised for the duration of the call and lowered on every
    /// exit path, including when a user-supplied predicate throws.
    template <typename... Args>
    void wait(Args &&... args) {
        WaiterScope scope(waitersCount);
        cond.wait(std::forward<Args>(args)...);
    }

    void notify_one() { // NOLINT
        cond.notify_one();
    }

    /// Snapshot of the number of threads currently inside wait().
    uint32_t peekNumWaiters() const {
        return waitersCount.load();
    }

  private:
    /// Increments the counter on construction and decrements it on destruction.
    class WaiterScope {
      public:
        explicit WaiterScope(std::atomic_uint &target)
            : counter(target) {
            ++counter;
        }
        ~WaiterScope() {
            --counter;
        }
        WaiterScope(const WaiterScope &) = delete;
        WaiterScope &operator=(const WaiterScope &) = delete;

      private:
        std::atomic_uint &counter;
    };

    std::atomic_uint waitersCount;
    std::condition_variable cond;
};
/// Scope guard for objects exposing takeOwnership() / releaseOwnership().
///
/// Ownership is taken in the constructor (or deferred, see the two-argument
/// constructor) and released in the destructor if still held. lock() and unlock()
/// toggle ownership explicitly and are no-ops when already in the requested state.
/// Each wrapper therefore holds at most one takeOwnership() at a time.
///
/// Non-copyable: a copy would carry the same `locked` flag, and both destructors
/// would call releaseOwnership() for a single takeOwnership().
template <typename T>
class TakeOwnershipWrapper {
  public:
    /// Takes ownership of @p obj immediately.
    TakeOwnershipWrapper(T &obj)
        : TakeOwnershipWrapper(obj, true) {
    }

    /// Takes ownership of @p obj only when @p lockImmediately is true.
    TakeOwnershipWrapper(T &obj, bool lockImmediately)
        : obj(obj) {
        if (lockImmediately) {
            lock();
        }
    }

    ~TakeOwnershipWrapper() {
        unlock();
    }

    TakeOwnershipWrapper(const TakeOwnershipWrapper &) = delete;
    TakeOwnershipWrapper &operator=(const TakeOwnershipWrapper &) = delete;

    /// Releases ownership if this wrapper currently holds it.
    void unlock() {
        if (locked) {
            obj.releaseOwnership();
            locked = false;
        }
    }

    /// Takes ownership if this wrapper does not already hold it.
    void lock() {
        if (!locked) {
            obj.takeOwnership();
            locked = true;
        }
    }

  private:
    T &obj;
    bool locked = false;
};
// Base class for CL objects. B is the CL handle struct the object is exposed as
// (presumably e.g. _cl_mem; confirm via the OpenCLObjectMapper specializations).
// It provides:
//  - a "magic" tag: DerivedType::objectMagic while alive, deadMagic after destruction.
//    castToObject() and isValid() compare it to validate handles;
//  - API/internal reference counting, delegated to ReferenceTrackedObject. The
//    constructor takes one API reference;
//  - re-entrant per-thread ownership (takeOwnership/releaseOwnership) coordinated
//    through mtx and cond.
template <typename B>
class BaseObject : public B, public ReferenceTrackedObject<DerivedType_t<B>> {
public:
typedef BaseObject<B> ThisType;
// The CL handle type; castToObject() accepts a BaseType *.
typedef B BaseType;
// The concrete runtime class, looked up through OpenCLObjectMapper<B>.
typedef DerivedType_t<B> DerivedType;
// Default mask: every bit of the magic is compared. The checks read
// DerivedType::maskMagic, so a derived class can presumably shadow this with its own mask.
const static cl_ulong maskMagic = 0xFFFFFFFFFFFFFFFFLL;
// Written to `magic` by the destructor. NOTE(review): this only invalidates the
// object if no DerivedType::objectMagic equals all-ones under its mask; confirm.
const static cl_ulong deadMagic = 0xFFFFFFFFFFFFFFFFLL;
// Non-copyable.
BaseObject(const BaseObject &) = delete;
BaseObject &operator=(const BaseObject &) = delete;
protected:
// NOTE(review): stored as signed cl_long while getMagic() and the constants are
// cl_ulong; assigning deadMagic (all ones) converts it to a negative value. Consider
// aligning the types.
cl_long magic;
// Held by takeOwnership()/releaseOwnership() while they read and modify `owner`
// and `recursiveOwnageCounter`.
mutable std::mutex mtx;
// takeOwnership() waits here for `owner` to become invalidThreadID;
// releaseOwnership() notifies one waiter.
mutable ConditionVariableWithCounter cond;
// Thread currently owning the object, or invalidThreadID when unowned.
// NOTE(review): starts default-constructed; it is treated as unowned only if
// invalidThreadID is also a default-constructed id. Confirm at its definition.
mutable std::thread::id owner;
// Number of nested takeOwnership() calls made by the current owner beyond the first.
mutable uint32_t recursiveOwnageCounter = 0;
// Marks the object live (objectMagic) and takes the initial API reference.
BaseObject()
: magic(DerivedType::objectMagic) {
this->incRefApi();
}
// Overwrites the magic so later castToObject()/isValid() checks on this memory
// should fail (see the deadMagic note).
~BaseObject() override {
magic = deadMagic;
}
// True while the masked magic equals DerivedType::objectMagic.
bool isValid() const {
return (magic & DerivedType::maskMagic) == DerivedType::objectMagic;
}
// Turns the constructor's API reference into an internal one. The internal
// reference is taken first, presumably so the total count never reaches zero
// in between.
void convertToInternalObject() {
this->incRefInternal();
this->decRefApi();
}
public:
// Raw magic value. NO_SANITIZE because castToObject() calls this through a pointer
// that was static_cast to DerivedType* before it is known to be one.
NO_SANITIZE
cl_ulong getMagic() const {
return this->magic;
}
// Adds one API reference.
virtual void retain() {
DEBUG_BREAK_IF(!isValid());
this->incRefApi();
}
// Drops one API reference and returns decRefApi()'s result
// (presumably owning the object when it became unused; see ReferenceTrackedObject).
virtual unique_ptr_if_unused<DerivedType> release() {
DEBUG_BREAK_IF(!isValid());
return this->decRefApi();
}
// Current API reference count.
cl_int getReference() const {
DEBUG_BREAK_IF(!isValid());
return this->getRefApiCount();
}
// Blocks until the calling thread owns this object.
// - Unowned: the caller becomes the owner immediately.
// - Already owned by the caller: recursiveOwnageCounter is bumped. Each such call
//   needs a matching releaseOwnership().
// - Owned by another thread: waits on `cond` (mtx is released while waiting) until
//   `owner` is invalidThreadID, then takes ownership with the counter reset to 0.
MOCKABLE_VIRTUAL void takeOwnership() const {
DEBUG_BREAK_IF(!isValid());
std::unique_lock<std::mutex> theLock(mtx);
std::thread::id self = std::this_thread::get_id();
if (owner == invalidThreadID) {
owner = self;
return;
}
if (owner == self) {
++recursiveOwnageCounter;
return;
}
cond.wait(theLock, [&] { return owner == invalidThreadID; });
owner = self;
recursiveOwnageCounter = 0;
}
// Undoes one takeOwnership() by the calling thread.
// A non-owner triggers DEBUG_BREAK_IF(true) and returns without changing state.
// Nested acquisitions only decrement recursiveOwnageCounter. The outermost release
// clears `owner` and wakes one waiter.
// hasOwnership() is called here with mtx already held, so it must not lock mtx itself.
MOCKABLE_VIRTUAL void releaseOwnership() const {
DEBUG_BREAK_IF(!isValid());
std::unique_lock<std::mutex> theLock(mtx);
if (hasOwnership() == false) {
DEBUG_BREAK_IF(true);
return;
}
if (recursiveOwnageCounter > 0) {
--recursiveOwnageCounter;
return;
}
owner = invalidThreadID;
cond.notify_one();
}
// checks whether current thread owns object mutex
// NOTE(review): reads `owner` without taking mtx, while other threads write it under
// mtx in takeOwnership()/releaseOwnership(). Callers outside releaseOwnership() may
// race with those writes; confirm whether this is acceptable.
bool hasOwnership() const {
DEBUG_BREAK_IF(!isValid());
return (owner == std::this_thread::get_id());
}
// Exposes the condition variable used by ownership waiters, e.g. so callers can
// query peekNumWaiters().
ConditionVariableWithCounter &getCond() {
return this->cond;
}
// Custom allocators for memory tracking CL objects
// (declared here; the definitions are not in this header).
static void *operator new(size_t sz);
static void *operator new(size_t sz, const std::nothrow_t &) noexcept;
static void operator delete(void *ptr, size_t allocationSize);
static void operator delete(void *ptr, const std::nothrow_t &)noexcept;
};
// Hook invoked by the global factory enabler for a given Type.
// Only declared here; definitions/specializations are provided elsewhere.
template <typename T>
void populateFactoryTable();
} // namespace NEO