Move the parser extensions for aliases currently on Dialect to a new OpAsmDialectInterface.

This will allow for adding more hooks for controlling parser behavior without bloating Dialect in the common case. This cl also adds iteration support to the DialectInterfaceCollection.

PiperOrigin-RevId: 264627846
This commit is contained in:
River Riddle
2019-08-21 09:41:37 -07:00
committed by A. Unique TensorFlower
parent 8d18fdf2d3
commit 7e1af594d2
5 changed files with 85 additions and 28 deletions

View File

@@ -133,22 +133,6 @@ public:
llvm_unreachable("dialect has no registered type printing hook");
}
/// Registered hooks for getting identifier aliases for symbols. The
/// identifier is used in place of the symbol when printing textual IR.
///
/// Hook for defining Attribute kind aliases. This will generate an alias for
/// all attributes of the given kind in the form : <alias>[0-9]+. These
/// aliases must not contain `.`.
virtual void getAttributeKindAliases(
SmallVectorImpl<std::pair<unsigned, StringRef>> &aliases) {}
/// Hook for defining Attribute aliases. These aliases must not contain `.` or
/// end with a numeric digit([0-9]+).
virtual void getAttributeAliases(
SmallVectorImpl<std::pair<Attribute, StringRef>> &aliases) {}
/// Hook for defining Type aliases.
virtual void
getTypeAliases(SmallVectorImpl<std::pair<Type, StringRef>> &aliases) {}
//===--------------------------------------------------------------------===//
// Verification Hooks
//===--------------------------------------------------------------------===//

View File

@@ -99,6 +99,7 @@ class DialectInterfaceCollectionBase {
/// A set of registered dialect interface instances.
using InterfaceSetT = DenseSet<const DialectInterface *, InterfaceKeyInfo>;
using InterfaceVectorT = std::vector<const DialectInterface *>;
public:
DialectInterfaceCollectionBase(MLIRContext *ctx, ClassID *interfaceKind);
@@ -115,9 +116,40 @@ protected:
return it == interfaces.end() ? nullptr : *it;
}
/// An iterator class that iterates the held interface objects of the given
/// derived interface type.
template <typename InterfaceT>
class iterator : public llvm::mapped_iterator<
InterfaceVectorT::const_iterator,
const InterfaceT &(*)(const DialectInterface *)> {
static const InterfaceT &remapIt(const DialectInterface *interface) {
return *static_cast<const InterfaceT *>(interface);
}
iterator(InterfaceVectorT::const_iterator it)
: llvm::mapped_iterator<
InterfaceVectorT::const_iterator,
const InterfaceT &(*)(const DialectInterface *)>(it, &remapIt) {}
/// Allow access to the constructor.
friend DialectInterfaceCollectionBase;
};
/// Iterator access to the held interfaces.
template <typename InterfaceT> iterator<InterfaceT> interface_begin() const {
return iterator<InterfaceT>(orderedInterfaces.begin());
}
template <typename InterfaceT> iterator<InterfaceT> interface_end() const {
return iterator<InterfaceT>(orderedInterfaces.end());
}
private:
/// A set of registered dialect interface instances.
InterfaceSetT interfaces;
/// An ordered list of the registered interface instances, necessary for
/// deterministic iteration.
// NOTE: SetVector does not provide find access, so it can't be used here.
InterfaceVectorT orderedInterfaces;
};
} // namespace detail
@@ -141,6 +173,16 @@ public:
return static_cast<const InterfaceType *>(
detail::DialectInterfaceCollectionBase::getInterfaceFor(obj));
}
/// Iterator access to the held interfaces.
using iterator =
detail::DialectInterfaceCollectionBase::iterator<InterfaceType>;
iterator begin() const { return interface_begin<InterfaceType>(); }
iterator end() const { return interface_end<InterfaceType>(); }
private:
using detail::DialectInterfaceCollectionBase::interface_begin;
using detail::DialectInterfaceCollectionBase::interface_end;
};
} // namespace mlir

View File

@@ -22,6 +22,7 @@
#ifndef MLIR_IR_OPIMPLEMENTATION_H
#define MLIR_IR_OPIMPLEMENTATION_H
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/SMLoc.h"
@@ -528,6 +529,32 @@ private:
Delimiter delimiter);
};
//===--------------------------------------------------------------------===//
// Dialect OpAsm interface.
//===--------------------------------------------------------------------===//
class OpAsmDialectInterface
: public DialectInterface::Base<OpAsmDialectInterface> {
public:
using Base::Base;
/// Hooks for getting identifier aliases for symbols. The identifier is used
/// in place of the symbol when printing textual IR.
///
/// Hook for defining Attribute kind aliases. This will generate an alias for
/// all attributes of the given kind in the form : <alias>[0-9]+. These
/// aliases must not contain `.`.
virtual void getAttributeKindAliases(
SmallVectorImpl<std::pair<unsigned, StringRef>> &aliases) const {}
/// Hook for defining Attribute aliases. These aliases must not contain `.` or
/// end with a numeric digit([0-9]+).
virtual void getAttributeAliases(
SmallVectorImpl<std::pair<Attribute, StringRef>> &aliases) const {}
/// Hook for defining Type aliases.
virtual void
getTypeAliases(SmallVectorImpl<std::pair<Type, StringRef>> &aliases) const {}
};
} // end namespace mlir
#endif

View File

@@ -88,7 +88,8 @@ public:
/// This is the current context if it is knowable, otherwise this is null.
MLIRContext *const context;
explicit ModuleState(MLIRContext *context) : context(context) {}
explicit ModuleState(MLIRContext *context)
: context(context), interfaces(context) {}
// Initializes module state, populating affine map state.
void initialize(Operation *op);
@@ -185,6 +186,9 @@ private:
/// A mapping between a type and a given alias.
DenseMap<Type, StringRef> typeToAlias;
/// Collection of OpAsm interfaces implemented in the context.
DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
};
} // end anonymous namespace
@@ -251,9 +255,6 @@ void ModuleState::initializeSymbolAliases() {
// isn't used twice.
llvm::StringSet<> usedAliases;
// Get the currently registered dialects.
auto dialects = context->getRegisteredDialects();
// Collect the set of aliases from each dialect.
SmallVector<std::pair<unsigned, StringRef>, 8> attributeKindAliases;
SmallVector<std::pair<Attribute, StringRef>, 8> attributeAliases;
@@ -263,10 +264,10 @@ void ModuleState::initializeSymbolAliases() {
attributeKindAliases.emplace_back(StandardAttributes::AffineMap, "map");
attributeKindAliases.emplace_back(StandardAttributes::IntegerSet, "set");
for (auto *dialect : dialects) {
dialect->getAttributeKindAliases(attributeKindAliases);
dialect->getAttributeAliases(attributeAliases);
dialect->getTypeAliases(typeAliases);
for (auto &interface : interfaces) {
interface.getAttributeKindAliases(attributeKindAliases);
interface.getAttributeAliases(attributeAliases);
interface.getTypeAliases(typeAliases);
}
// Setup the attribute kind aliases.
@@ -1635,7 +1636,7 @@ void ModulePrinter::print(ModuleOp module) {
//===----------------------------------------------------------------------===//
void Attribute::print(raw_ostream &os) const {
ModuleState state(/*no context is known*/ nullptr);
ModuleState state(getContext());
ModulePrinter(os, state).printAttribute(*this);
}
@@ -1685,7 +1686,7 @@ void AffineMap::print(raw_ostream &os) const {
}
void IntegerSet::print(raw_ostream &os) const {
ModuleState state(/*no context is known*/ nullptr);
ModuleState state(getContext());
ModulePrinter(os, state).printIntegerSet(*this);
}

View File

@@ -135,9 +135,12 @@ DialectInterface::~DialectInterface() {}
DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
MLIRContext *ctx, ClassID *interfaceKind) {
for (auto *dialect : ctx->getRegisteredDialects())
if (auto *interface = dialect->getRegisteredInterface(interfaceKind))
for (auto *dialect : ctx->getRegisteredDialects()) {
if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
interfaces.insert(interface);
orderedInterfaces.push_back(interface);
}
}
}
DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {}