From 087e599a3f4e8ded50798480a3f2d42e7a10b118 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 14 Mar 2019 14:14:14 -0700 Subject: [PATCH] Rename allocator to identifierAllocator and add an identifierMutex to make identifier uniquing thread safe. This also adds a general purpose 'contextMutex' to protect access to the rest of the miscellaneous parts of the MLIRContext, e.g. diagnostics, dialect registration, etc. This is step 5/5 of making the MLIRContext thread-safe. PiperOrigin-RevId: 238516697 --- mlir/lib/IR/MLIRContext.cpp | 77 ++++++++++++++++++++++++++++++------- 1 file changed, 64 insertions(+), 13 deletions(-) diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 5610d5878c7d..140dfa6b3eb6 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -477,12 +477,21 @@ public: using FusedLocations = DenseSet; FusedLocations fusedLocs; + //===--------------------------------------------------------------------===// + // Identifier uniquing + //===--------------------------------------------------------------------===// + + // Identifier allocator and mutex for thread safety. + llvm::BumpPtrAllocator identifierAllocator; + llvm::sys::SmartRWMutex identifierMutex; + //===--------------------------------------------------------------------===// // Other //===--------------------------------------------------------------------===// - /// We put immortal objects into this allocator. - llvm::BumpPtrAllocator allocator; + /// A general purpose mutex to lock access to parts of the context that do not + /// have a more specific mutex, e.g. registry operations, diagnostics, etc. + llvm::sys::SmartRWMutex contextMutex; /// This is the handler to use to report diagnostics, or null if not /// registered. @@ -569,7 +578,8 @@ public: sparseElementsAttrs; public: - MLIRContextImpl() : filenames(locationAllocator), identifiers(allocator) {} + MLIRContextImpl() + : filenames(locationAllocator), identifiers(identifierAllocator) {} }; } // end namespace mlir @@ -599,11 +609,15 @@ static ArrayRef copyArrayRefInto(llvm::BumpPtrAllocator &allocator, /// value that indicates the type of the diagnostic (e.g., Warning, Error). void MLIRContext::registerDiagnosticHandler( const DiagnosticHandlerTy &handler) { + // Lock access to the context diagnostic handler. + llvm::sys::SmartScopedWriter contextLock(getImpl().contextMutex); getImpl().diagnosticHandler = handler; } /// Return the current diagnostic handler, or null if none is present. auto MLIRContext::getDiagnosticHandler() const -> DiagnosticHandlerTy { + // Lock access to the context diagnostic handler. + llvm::sys::SmartScopedReader contextLock(getImpl().contextMutex); return getImpl().diagnosticHandler; } @@ -625,6 +639,10 @@ void MLIRContext::emitDiagnostic(Location location, const llvm::Twine &message, return; } + // Lock access to the context so that no other threads emit diagnostics at + // the same time. + llvm::sys::SmartScopedWriter contextLock(getImpl().contextMutex); + // If we had a handler registered, emit the diagnostic using it. auto handler = getImpl().diagnosticHandler; if (handler) @@ -658,6 +676,9 @@ bool MLIRContext::emitError(Location location, /// Return information about all registered IR dialects. std::vector MLIRContext::getRegisteredDialects() const { + // Lock access to the context registry. + llvm::sys::SmartScopedReader registryLock(getImpl().contextMutex); + std::vector result; result.reserve(getImpl().dialects.size()); for (auto &dialect : getImpl().dialects) @@ -668,6 +689,8 @@ std::vector MLIRContext::getRegisteredDialects() const { /// Get a registered IR dialect with the given namespace. If none is found, /// then return nullptr. Dialect *MLIRContext::getRegisteredDialect(StringRef name) const { + // Lock access to the context registry. + llvm::sys::SmartScopedReader registryLock(getImpl().contextMutex); for (auto &dialect : getImpl().dialects) if (name == dialect->getNamespace()) return dialect.get(); @@ -677,22 +700,32 @@ Dialect *MLIRContext::getRegisteredDialect(StringRef name) const { /// Register this dialect object with the specified context. The context /// takes ownership of the heap allocated dialect. void Dialect::registerDialect(MLIRContext *context) { - context->getImpl().dialects.push_back(std::unique_ptr(this)); + auto &impl = context->getImpl(); + + // Lock access to the context registry. + llvm::sys::SmartScopedWriter registryLock(impl.contextMutex); + impl.dialects.push_back(std::unique_ptr(this)); } /// Return information about all registered operations. This isn't very /// efficient, typically you should ask the operations about their properties /// directly. std::vector MLIRContext::getRegisteredOperations() const { - // We just have the operations in a non-deterministic hash table order. Dump - // into a temporary array, then sort it by operation name to get a stable - // ordering. - StringMap ®isteredOps = getImpl().registeredOperations; - std::vector> opsToSort; - opsToSort.reserve(registeredOps.size()); - for (auto &elt : registeredOps) - opsToSort.push_back({elt.first(), &elt.second}); + + { // Lock access to the context registry. + llvm::sys::SmartScopedReader registryLock(getImpl().contextMutex); + + // We just have the operations in a non-deterministic hash table order. Dump + // into a temporary array, then sort it by operation name to get a stable + // ordering. + StringMap ®isteredOps = + getImpl().registeredOperations; + + opsToSort.reserve(registeredOps.size()); + for (auto &elt : registeredOps) + opsToSort.push_back({elt.first(), &elt.second}); + } llvm::array_pod_sort(opsToSort.begin(), opsToSort.end()); @@ -707,8 +740,10 @@ void Dialect::addOperation(AbstractOperation opInfo) { assert((namePrefix.empty() || (opInfo.name.split('.').first == namePrefix)) && "op name doesn't start with dialect prefix"); assert(&opInfo.dialect == this && "Dialect object mismatch"); - auto &impl = context->getImpl(); + + // Lock access to the context registry. + llvm::sys::SmartScopedWriter registryLock(impl.contextMutex); if (!impl.registeredOperations.insert({opInfo.name, opInfo}).second) { llvm::errs() << "error: ops named '" << opInfo.name << "' is already registered.\n"; @@ -719,6 +754,9 @@ void Dialect::addOperation(AbstractOperation opInfo) { /// Register a dialect-specific type with the current context. void Dialect::addType(const TypeID *const typeID) { auto &impl = context->getImpl(); + + // Lock access to the context registry. + llvm::sys::SmartScopedWriter registryLock(impl.contextMutex); if (!impl.registeredTypes.insert({typeID, this}).second) { llvm::errs() << "error: type already registered.\n"; abort(); @@ -730,6 +768,9 @@ void Dialect::addType(const TypeID *const typeID) { const AbstractOperation *AbstractOperation::lookup(StringRef opName, MLIRContext *context) { auto &impl = context->getImpl(); + + // Lock access to the context registry. + llvm::sys::SmartScopedReader registryLock(impl.contextMutex); auto it = impl.registeredOperations.find(opName); if (it != impl.registeredOperations.end()) return &it->second; @@ -747,6 +788,16 @@ Identifier Identifier::get(StringRef str, const MLIRContext *context) { "Cannot create an identifier with a nul character"); auto &impl = context->getImpl(); + + { // Check for an existing identifier in read-only mode. + llvm::sys::SmartScopedReader contextLock(impl.identifierMutex); + auto it = impl.identifiers.find(str); + if (it != impl.identifiers.end()) + return Identifier(it->getKeyData()); + } + + // Aquire a writer-lock so that we can safely create the new instance. + llvm::sys::SmartScopedWriter contextLock(impl.identifierMutex); auto it = impl.identifiers.insert({str, char()}).first; return Identifier(it->getKeyData()); }