mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 01:07:04 +08:00
Move the parser extensions for aliases currently on Dialect to a new OpAsmDialectInterface.
This will allow for adding more hooks for controlling parser behavior without bloating Dialect in the common case. This cl also adds iteration support to the DialectInterfaceCollection. PiperOrigin-RevId: 264627846
This commit is contained in:
committed by
A. Unique TensorFlower
parent
8d18fdf2d3
commit
7e1af594d2
@@ -133,22 +133,6 @@ public:
|
||||
llvm_unreachable("dialect has no registered type printing hook");
|
||||
}
|
||||
|
||||
/// Registered hooks for getting identifier aliases for symbols. The
|
||||
/// identifier is used in place of the symbol when printing textual IR.
|
||||
///
|
||||
/// Hook for defining Attribute kind aliases. This will generate an alias for
|
||||
/// all attributes of the given kind in the form : <alias>[0-9]+. These
|
||||
/// aliases must not contain `.`.
|
||||
virtual void getAttributeKindAliases(
|
||||
SmallVectorImpl<std::pair<unsigned, StringRef>> &aliases) {}
|
||||
/// Hook for defining Attribute aliases. These aliases must not contain `.` or
|
||||
/// end with a numeric digit([0-9]+).
|
||||
virtual void getAttributeAliases(
|
||||
SmallVectorImpl<std::pair<Attribute, StringRef>> &aliases) {}
|
||||
/// Hook for defining Type aliases.
|
||||
virtual void
|
||||
getTypeAliases(SmallVectorImpl<std::pair<Type, StringRef>> &aliases) {}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Verification Hooks
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
@@ -99,6 +99,7 @@ class DialectInterfaceCollectionBase {
|
||||
|
||||
/// A set of registered dialect interface instances.
|
||||
using InterfaceSetT = DenseSet<const DialectInterface *, InterfaceKeyInfo>;
|
||||
using InterfaceVectorT = std::vector<const DialectInterface *>;
|
||||
|
||||
public:
|
||||
DialectInterfaceCollectionBase(MLIRContext *ctx, ClassID *interfaceKind);
|
||||
@@ -115,9 +116,40 @@ protected:
|
||||
return it == interfaces.end() ? nullptr : *it;
|
||||
}
|
||||
|
||||
/// An iterator class that iterates the held interface objects of the given
|
||||
/// derived interface type.
|
||||
template <typename InterfaceT>
|
||||
class iterator : public llvm::mapped_iterator<
|
||||
InterfaceVectorT::const_iterator,
|
||||
const InterfaceT &(*)(const DialectInterface *)> {
|
||||
static const InterfaceT &remapIt(const DialectInterface *interface) {
|
||||
return *static_cast<const InterfaceT *>(interface);
|
||||
}
|
||||
|
||||
iterator(InterfaceVectorT::const_iterator it)
|
||||
: llvm::mapped_iterator<
|
||||
InterfaceVectorT::const_iterator,
|
||||
const InterfaceT &(*)(const DialectInterface *)>(it, &remapIt) {}
|
||||
|
||||
/// Allow access to the constructor.
|
||||
friend DialectInterfaceCollectionBase;
|
||||
};
|
||||
|
||||
/// Iterator access to the held interfaces.
|
||||
template <typename InterfaceT> iterator<InterfaceT> interface_begin() const {
|
||||
return iterator<InterfaceT>(orderedInterfaces.begin());
|
||||
}
|
||||
template <typename InterfaceT> iterator<InterfaceT> interface_end() const {
|
||||
return iterator<InterfaceT>(orderedInterfaces.end());
|
||||
}
|
||||
|
||||
private:
|
||||
/// A set of registered dialect interface instances.
|
||||
InterfaceSetT interfaces;
|
||||
/// An ordered list of the registered interface instances, necessary for
|
||||
/// deterministic iteration.
|
||||
// NOTE: SetVector does not provide find access, so it can't be used here.
|
||||
InterfaceVectorT orderedInterfaces;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
@@ -141,6 +173,16 @@ public:
|
||||
return static_cast<const InterfaceType *>(
|
||||
detail::DialectInterfaceCollectionBase::getInterfaceFor(obj));
|
||||
}
|
||||
|
||||
/// Iterator access to the held interfaces.
|
||||
using iterator =
|
||||
detail::DialectInterfaceCollectionBase::iterator<InterfaceType>;
|
||||
iterator begin() const { return interface_begin<InterfaceType>(); }
|
||||
iterator end() const { return interface_end<InterfaceType>(); }
|
||||
|
||||
private:
|
||||
using detail::DialectInterfaceCollectionBase::interface_begin;
|
||||
using detail::DialectInterfaceCollectionBase::interface_end;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#ifndef MLIR_IR_OPIMPLEMENTATION_H
|
||||
#define MLIR_IR_OPIMPLEMENTATION_H
|
||||
|
||||
#include "mlir/IR/DialectInterface.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/SMLoc.h"
|
||||
@@ -528,6 +529,32 @@ private:
|
||||
Delimiter delimiter);
|
||||
};
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Dialect OpAsm interface.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
class OpAsmDialectInterface
|
||||
: public DialectInterface::Base<OpAsmDialectInterface> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
/// Hooks for getting identifier aliases for symbols. The identifier is used
|
||||
/// in place of the symbol when printing textual IR.
|
||||
///
|
||||
/// Hook for defining Attribute kind aliases. This will generate an alias for
|
||||
/// all attributes of the given kind in the form : <alias>[0-9]+. These
|
||||
/// aliases must not contain `.`.
|
||||
virtual void getAttributeKindAliases(
|
||||
SmallVectorImpl<std::pair<unsigned, StringRef>> &aliases) const {}
|
||||
/// Hook for defining Attribute aliases. These aliases must not contain `.` or
|
||||
/// end with a numeric digit([0-9]+).
|
||||
virtual void getAttributeAliases(
|
||||
SmallVectorImpl<std::pair<Attribute, StringRef>> &aliases) const {}
|
||||
/// Hook for defining Type aliases.
|
||||
virtual void
|
||||
getTypeAliases(SmallVectorImpl<std::pair<Type, StringRef>> &aliases) const {}
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif
|
||||
|
||||
@@ -88,7 +88,8 @@ public:
|
||||
/// This is the current context if it is knowable, otherwise this is null.
|
||||
MLIRContext *const context;
|
||||
|
||||
explicit ModuleState(MLIRContext *context) : context(context) {}
|
||||
explicit ModuleState(MLIRContext *context)
|
||||
: context(context), interfaces(context) {}
|
||||
|
||||
// Initializes module state, populating affine map state.
|
||||
void initialize(Operation *op);
|
||||
@@ -185,6 +186,9 @@ private:
|
||||
|
||||
/// A mapping between a type and a given alias.
|
||||
DenseMap<Type, StringRef> typeToAlias;
|
||||
|
||||
/// Collection of OpAsm interfaces implemented in the context.
|
||||
DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
@@ -251,9 +255,6 @@ void ModuleState::initializeSymbolAliases() {
|
||||
// isn't used twice.
|
||||
llvm::StringSet<> usedAliases;
|
||||
|
||||
// Get the currently registered dialects.
|
||||
auto dialects = context->getRegisteredDialects();
|
||||
|
||||
// Collect the set of aliases from each dialect.
|
||||
SmallVector<std::pair<unsigned, StringRef>, 8> attributeKindAliases;
|
||||
SmallVector<std::pair<Attribute, StringRef>, 8> attributeAliases;
|
||||
@@ -263,10 +264,10 @@ void ModuleState::initializeSymbolAliases() {
|
||||
attributeKindAliases.emplace_back(StandardAttributes::AffineMap, "map");
|
||||
attributeKindAliases.emplace_back(StandardAttributes::IntegerSet, "set");
|
||||
|
||||
for (auto *dialect : dialects) {
|
||||
dialect->getAttributeKindAliases(attributeKindAliases);
|
||||
dialect->getAttributeAliases(attributeAliases);
|
||||
dialect->getTypeAliases(typeAliases);
|
||||
for (auto &interface : interfaces) {
|
||||
interface.getAttributeKindAliases(attributeKindAliases);
|
||||
interface.getAttributeAliases(attributeAliases);
|
||||
interface.getTypeAliases(typeAliases);
|
||||
}
|
||||
|
||||
// Setup the attribute kind aliases.
|
||||
@@ -1635,7 +1636,7 @@ void ModulePrinter::print(ModuleOp module) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void Attribute::print(raw_ostream &os) const {
|
||||
ModuleState state(/*no context is known*/ nullptr);
|
||||
ModuleState state(getContext());
|
||||
ModulePrinter(os, state).printAttribute(*this);
|
||||
}
|
||||
|
||||
@@ -1685,7 +1686,7 @@ void AffineMap::print(raw_ostream &os) const {
|
||||
}
|
||||
|
||||
void IntegerSet::print(raw_ostream &os) const {
|
||||
ModuleState state(/*no context is known*/ nullptr);
|
||||
ModuleState state(getContext());
|
||||
ModulePrinter(os, state).printIntegerSet(*this);
|
||||
}
|
||||
|
||||
|
||||
@@ -135,9 +135,12 @@ DialectInterface::~DialectInterface() {}
|
||||
|
||||
DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
|
||||
MLIRContext *ctx, ClassID *interfaceKind) {
|
||||
for (auto *dialect : ctx->getRegisteredDialects())
|
||||
if (auto *interface = dialect->getRegisteredInterface(interfaceKind))
|
||||
for (auto *dialect : ctx->getRegisteredDialects()) {
|
||||
if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
|
||||
interfaces.insert(interface);
|
||||
orderedInterfaces.push_back(interface);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {}
|
||||
|
||||
Reference in New Issue
Block a user