Revert "Expose callbacks for encoding of types/attributes"

This reverts commit b299ec1666.

The authorship informations were incorrect.
This commit is contained in:
Mehdi Amini
2023-07-28 16:44:25 -07:00
parent b08d358e8a
commit b86a13211f
20 changed files with 152 additions and 950 deletions

View File

@@ -24,17 +24,6 @@
#include "llvm/ADT/Twine.h"
namespace mlir {
//===--------------------------------------------------------------------===//
// Dialect Version Interface.
//===--------------------------------------------------------------------===//
/// This class is used to represent the version of a dialect, for the purpose
/// of polymorphic destruction.
class DialectVersion {
public:
virtual ~DialectVersion() = default;
};
//===----------------------------------------------------------------------===//
// DialectBytecodeReader
//===----------------------------------------------------------------------===//
@@ -49,14 +38,7 @@ public:
virtual ~DialectBytecodeReader() = default;
/// Emit an error to the reader.
virtual InFlightDiagnostic emitError(const Twine &msg = {}) const = 0;
/// Retrieve the dialect version by name if available.
virtual FailureOr<const DialectVersion *>
getDialectVersion(StringRef dialectName) const = 0;
/// Retrieve the context associated to the reader.
virtual MLIRContext *getContext() const = 0;
virtual InFlightDiagnostic emitError(const Twine &msg = {}) = 0;
/// Return the bytecode version being read.
virtual uint64_t getBytecodeVersion() const = 0;
@@ -402,6 +384,17 @@ public:
virtual int64_t getBytecodeVersion() const = 0;
};
//===--------------------------------------------------------------------===//
// Dialect Version Interface.
//===--------------------------------------------------------------------===//
/// This class is used to represent the version of a dialect, for the purpose
/// of polymorphic destruction.
class DialectVersion {
public:
virtual ~DialectVersion() = default;
};
//===----------------------------------------------------------------------===//
// BytecodeDialectInterface
//===----------------------------------------------------------------------===//
@@ -416,23 +409,47 @@ public:
//===--------------------------------------------------------------------===//
/// Read an attribute belonging to this dialect from the given reader. This
/// method should return null in the case of failure. Optionally, the dialect
/// version can be accessed through the reader.
/// method should return null in the case of failure.
virtual Attribute readAttribute(DialectBytecodeReader &reader) const {
reader.emitError() << "dialect " << getDialect()->getNamespace()
<< " does not support reading attributes from bytecode";
return Attribute();
}
/// Read a versioned attribute encoding belonging to this dialect from the
/// given reader. This method should return null in the case of failure, and
/// falls back to the non-versioned reader in case the dialect implements
/// versioning but it does not support versioned custom encodings for the
/// attributes.
virtual Attribute readAttribute(DialectBytecodeReader &reader,
const DialectVersion &version) const {
reader.emitError()
<< "dialect " << getDialect()->getNamespace()
<< " does not support reading versioned attributes from bytecode";
return Attribute();
}
/// Read a type belonging to this dialect from the given reader. This method
/// should return null in the case of failure. Optionally, the dialect version
/// can be accessed thorugh the reader.
/// should return null in the case of failure.
virtual Type readType(DialectBytecodeReader &reader) const {
reader.emitError() << "dialect " << getDialect()->getNamespace()
<< " does not support reading types from bytecode";
return Type();
}
/// Read a versioned type encoding belonging to this dialect from the given
/// reader. This method should return null in the case of failure, and
/// falls back to the non-versioned reader in case the dialect implements
/// versioning but it does not support versioned custom encodings for the
/// types.
virtual Type readType(DialectBytecodeReader &reader,
const DialectVersion &version) const {
reader.emitError()
<< "dialect " << getDialect()->getNamespace()
<< " does not support reading versioned types from bytecode";
return Type();
}
//===--------------------------------------------------------------------===//
// Writing
//===--------------------------------------------------------------------===//

View File

@@ -25,6 +25,7 @@ class SourceMgr;
} // namespace llvm
namespace mlir {
/// The BytecodeReader allows to load MLIR bytecode files, while keeping the
/// state explicitly available in order to support lazy loading.
/// The `finalize` method must be called before destruction.

View File

@@ -1,120 +0,0 @@
//===- BytecodeReader.h - MLIR Bytecode Reader ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header defines interfaces to read MLIR bytecode files/streams.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_BYTECODE_BYTECODEREADERCONFIG_H
#define MLIR_BYTECODE_BYTECODEREADERCONFIG_H
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
namespace mlir {
class Attribute;
class DialectBytecodeReader;
class Type;
/// A class to interact with the attributes and types parser when parsing MLIR
/// bytecode.
template <class T>
class AttrTypeBytecodeReader {
public:
AttrTypeBytecodeReader() = default;
virtual ~AttrTypeBytecodeReader() = default;
virtual LogicalResult read(DialectBytecodeReader &reader,
StringRef dialectName, T &entry) = 0;
/// Return an Attribute/Type printer implemented via the given callable, whose
/// form should match that of the `parse` function above.
template <typename CallableT,
std::enable_if_t<
std::is_convertible_v<
CallableT, std::function<LogicalResult(
DialectBytecodeReader &, StringRef, T &)>>,
bool> = true>
static std::unique_ptr<AttrTypeBytecodeReader<T>>
fromCallable(CallableT &&readFn) {
struct Processor : public AttrTypeBytecodeReader<T> {
Processor(CallableT &&readFn)
: AttrTypeBytecodeReader(), readFn(std::move(readFn)) {}
LogicalResult read(DialectBytecodeReader &reader, StringRef dialectName,
T &entry) override {
return readFn(reader, dialectName, entry);
}
std::decay_t<CallableT> readFn;
};
return std::make_unique<Processor>(std::forward<CallableT>(readFn));
}
};
//===----------------------------------------------------------------------===//
// BytecodeReaderConfig
//===----------------------------------------------------------------------===//
/// A class containing bytecode-specific configurations of the `ParserConfig`.
class BytecodeReaderConfig {
public:
BytecodeReaderConfig() = default;
/// Returns the callbacks available to the parser.
ArrayRef<std::unique_ptr<AttrTypeBytecodeReader<Attribute>>>
getAttributeCallbacks() const {
return attributeBytecodeParsers;
}
ArrayRef<std::unique_ptr<AttrTypeBytecodeReader<Type>>>
getTypeCallbacks() const {
return typeBytecodeParsers;
}
/// Attach a custom bytecode parser callback to the configuration for parsing
/// of custom type/attributes encodings.
void attachAttributeCallback(
std::unique_ptr<AttrTypeBytecodeReader<Attribute>> parser) {
attributeBytecodeParsers.emplace_back(std::move(parser));
}
void
attachTypeCallback(std::unique_ptr<AttrTypeBytecodeReader<Type>> parser) {
typeBytecodeParsers.emplace_back(std::move(parser));
}
/// Attach a custom bytecode parser callback to the configuration for parsing
/// of custom type/attributes encodings.
template <typename CallableT>
std::enable_if_t<std::is_convertible_v<
CallableT, std::function<LogicalResult(DialectBytecodeReader &, StringRef,
Attribute &)>>>
attachAttributeCallback(CallableT &&parserFn) {
attachAttributeCallback(AttrTypeBytecodeReader<Attribute>::fromCallable(
std::forward<CallableT>(parserFn)));
}
template <typename CallableT>
std::enable_if_t<std::is_convertible_v<
CallableT,
std::function<LogicalResult(DialectBytecodeReader &, StringRef, Type &)>>>
attachTypeCallback(CallableT &&parserFn) {
attachTypeCallback(AttrTypeBytecodeReader<Type>::fromCallable(
std::forward<CallableT>(parserFn)));
}
private:
llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeReader<Attribute>>>
attributeBytecodeParsers;
llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeReader<Type>>>
typeBytecodeParsers;
};
} // namespace mlir
#endif // MLIR_BYTECODE_BYTECODEREADERCONFIG_H

View File

@@ -17,55 +17,6 @@
namespace mlir {
class Operation;
class DialectBytecodeWriter;
/// A class to interact with the attributes and types printer when emitting MLIR
/// bytecode.
template <class T>
class AttrTypeBytecodeWriter {
public:
AttrTypeBytecodeWriter() = default;
virtual ~AttrTypeBytecodeWriter() = default;
/// Callback writer API used in IRNumbering, where groups are created and
/// type/attribute components are numbered. At this stage, writer is expected
/// to be a `NumberingDialectWriter`.
virtual LogicalResult write(T entry, std::optional<StringRef> &name,
DialectBytecodeWriter &writer) = 0;
/// Callback writer API used in BytecodeWriter, where groups are created and
/// type/attribute components are numbered. Here, DialectBytecodeWriter is
/// expected to be an actual writer. The optional stringref specified by
/// the user is ignored, since the group was already specified when numbering
/// the IR.
LogicalResult write(T entry, DialectBytecodeWriter &writer) {
std::optional<StringRef> dummy;
return write(entry, dummy, writer);
}
/// Return an Attribute/Type printer implemented via the given callable, whose
/// form should match that of the `write` function above.
template <typename CallableT,
std::enable_if_t<std::is_convertible_v<
CallableT, std::function<LogicalResult(
T, std::optional<StringRef> &,
DialectBytecodeWriter &)>>,
bool> = true>
static std::unique_ptr<AttrTypeBytecodeWriter<T>>
fromCallable(CallableT &&writeFn) {
struct Processor : public AttrTypeBytecodeWriter<T> {
Processor(CallableT &&writeFn)
: AttrTypeBytecodeWriter(), writeFn(std::move(writeFn)) {}
LogicalResult write(T entry, std::optional<StringRef> &name,
DialectBytecodeWriter &writer) override {
return writeFn(entry, name, writer);
}
std::decay_t<CallableT> writeFn;
};
return std::make_unique<Processor>(std::forward<CallableT>(writeFn));
}
};
/// This class contains the configuration used for the bytecode writer. It
/// controls various aspects of bytecode generation, and contains all of the
@@ -97,43 +48,6 @@ public:
/// Get the set desired bytecode version to emit.
int64_t getDesiredBytecodeVersion() const;
//===--------------------------------------------------------------------===//
// Types and Attributes encoding
//===--------------------------------------------------------------------===//
/// Retrieve the callbacks.
ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>>
getAttributeWriterCallbacks() const;
ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Type>>>
getTypeWriterCallbacks() const;
/// Attach a custom bytecode printer callback to the configuration for the
/// emission of custom type/attributes encodings.
void attachAttributeCallback(
std::unique_ptr<AttrTypeBytecodeWriter<Attribute>> callback);
void
attachTypeCallback(std::unique_ptr<AttrTypeBytecodeWriter<Type>> callback);
/// Attach a custom bytecode printer callback to the configuration for the
/// emission of custom type/attributes encodings.
template <typename CallableT>
std::enable_if_t<std::is_convertible_v<
CallableT,
std::function<LogicalResult(Attribute, std::optional<StringRef> &,
DialectBytecodeWriter &)>>>
attachAttributeCallback(CallableT &&emitFn) {
attachAttributeCallback(AttrTypeBytecodeWriter<Attribute>::fromCallable(
std::forward<CallableT>(emitFn)));
}
template <typename CallableT>
std::enable_if_t<std::is_convertible_v<
CallableT, std::function<LogicalResult(Type, std::optional<StringRef> &,
DialectBytecodeWriter &)>>>
attachTypeCallback(CallableT &&emitFn) {
attachTypeCallback(AttrTypeBytecodeWriter<Type>::fromCallable(
std::forward<CallableT>(emitFn)));
}
//===--------------------------------------------------------------------===//
// Resources
//===--------------------------------------------------------------------===//

View File

@@ -14,7 +14,6 @@
#ifndef MLIR_IR_ASMSTATE_H_
#define MLIR_IR_ASMSTATE_H_
#include "mlir/Bytecode/BytecodeReaderConfig.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/MapVector.h"
@@ -476,11 +475,6 @@ public:
/// Returns if the parser should verify the IR after parsing.
bool shouldVerifyAfterParse() const { return verifyAfterParse; }
/// Returns the parsing configurations associated to the bytecode read.
BytecodeReaderConfig &getBytecodeReaderConfig() const {
return const_cast<BytecodeReaderConfig &>(bytecodeReaderConfig);
}
/// Return the resource parser registered to the given name, or nullptr if no
/// parser with `name` is registered.
AsmResourceParser *getResourceParser(StringRef name) const {
@@ -515,7 +509,6 @@ private:
bool verifyAfterParse;
DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
FallbackAsmResourceMap *fallbackResourceMap;
BytecodeReaderConfig bytecodeReaderConfig;
};
//===----------------------------------------------------------------------===//

View File

@@ -451,7 +451,7 @@ struct BytecodeDialect {
/// Returns failure if the dialect couldn't be loaded *and* the provided
/// context does not allow unregistered dialects. The provided reader is used
/// for error emission if necessary.
LogicalResult load(const DialectReader &reader, MLIRContext *ctx);
LogicalResult load(DialectReader &reader, MLIRContext *ctx);
/// Return the loaded dialect, or nullptr if the dialect is unknown. This can
/// only be called after `load`.
@@ -505,11 +505,10 @@ struct BytecodeOperationName {
/// Parse a single dialect group encoded in the byte stream.
static LogicalResult parseDialectGrouping(
EncodingReader &reader,
MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects,
function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
// Parse the dialect and the number of entries in the group.
std::unique_ptr<BytecodeDialect> *dialect;
BytecodeDialect *dialect;
if (failed(parseEntry(reader, dialects, dialect, "dialect")))
return failure();
uint64_t numEntries;
@@ -517,7 +516,7 @@ static LogicalResult parseDialectGrouping(
return failure();
for (uint64_t i = 0; i < numEntries; ++i)
if (failed(entryCallback(dialect->get())))
if (failed(entryCallback(dialect)))
return failure();
return success();
}
@@ -533,7 +532,7 @@ public:
/// Initialize the resource section reader with the given section data.
LogicalResult
initialize(Location fileLoc, const ParserConfig &config,
MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
MutableArrayRef<BytecodeDialect> dialects,
StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
@@ -683,7 +682,7 @@ parseResourceGroup(Location fileLoc, bool allowEmpty,
LogicalResult ResourceSectionReader::initialize(
Location fileLoc, const ParserConfig &config,
MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
MutableArrayRef<BytecodeDialect> dialects,
StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
@@ -732,19 +731,19 @@ LogicalResult ResourceSectionReader::initialize(
// Read the dialect resources from the bytecode.
MLIRContext *ctx = fileLoc->getContext();
while (!offsetReader.empty()) {
std::unique_ptr<BytecodeDialect> *dialect;
BytecodeDialect *dialect;
if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) ||
failed((*dialect)->load(dialectReader, ctx)))
failed(dialect->load(dialectReader, ctx)))
return failure();
Dialect *loadedDialect = (*dialect)->getLoadedDialect();
Dialect *loadedDialect = dialect->getLoadedDialect();
if (!loadedDialect) {
return resourceReader.emitError()
<< "dialect '" << (*dialect)->name << "' is unknown";
<< "dialect '" << dialect->name << "' is unknown";
}
const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect);
if (!handler) {
return resourceReader.emitError()
<< "unexpected resources for dialect '" << (*dialect)->name << "'";
<< "unexpected resources for dialect '" << dialect->name << "'";
}
// Ensure that each resource is declared before being processed.
@@ -754,7 +753,7 @@ LogicalResult ResourceSectionReader::initialize(
if (failed(handle)) {
return resourceReader.emitError()
<< "unknown 'resource' key '" << key << "' for dialect '"
<< (*dialect)->name << "'";
<< dialect->name << "'";
}
dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle);
dialectResources.push_back(*handle);
@@ -797,19 +796,15 @@ class AttrTypeReader {
public:
AttrTypeReader(StringSectionReader &stringReader,
ResourceSectionReader &resourceReader,
const llvm::StringMap<BytecodeDialect *> &dialectsMap,
uint64_t &bytecodeVersion, Location fileLoc,
const ParserConfig &config)
ResourceSectionReader &resourceReader, Location fileLoc,
uint64_t &bytecodeVersion)
: stringReader(stringReader), resourceReader(resourceReader),
dialectsMap(dialectsMap), fileLoc(fileLoc),
bytecodeVersion(bytecodeVersion), parserConfig(config) {}
fileLoc(fileLoc), bytecodeVersion(bytecodeVersion) {}
/// Initialize the attribute and type information within the reader.
LogicalResult
initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData);
LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects,
ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData);
/// Resolve the attribute or type at the given index. Returns nullptr on
/// failure.
@@ -883,10 +878,6 @@ private:
/// parsing custom encoded attribute/type entries.
ResourceSectionReader &resourceReader;
/// The map of the loaded dialects used to retrieve dialect information, such
/// as the dialect version.
const llvm::StringMap<BytecodeDialect *> &dialectsMap;
/// The set of attribute and type entries.
SmallVector<AttrEntry> attributes;
SmallVector<TypeEntry> types;
@@ -896,48 +887,27 @@ private:
/// Current bytecode version being used.
uint64_t &bytecodeVersion;
/// Reference to the parser configuration.
const ParserConfig &parserConfig;
};
class DialectReader : public DialectBytecodeReader {
public:
DialectReader(AttrTypeReader &attrTypeReader,
StringSectionReader &stringReader,
ResourceSectionReader &resourceReader,
const llvm::StringMap<BytecodeDialect *> &dialectsMap,
EncodingReader &reader, uint64_t &bytecodeVersion)
ResourceSectionReader &resourceReader, EncodingReader &reader,
uint64_t &bytecodeVersion)
: attrTypeReader(attrTypeReader), stringReader(stringReader),
resourceReader(resourceReader), dialectsMap(dialectsMap),
reader(reader), bytecodeVersion(bytecodeVersion) {}
resourceReader(resourceReader), reader(reader),
bytecodeVersion(bytecodeVersion) {}
InFlightDiagnostic emitError(const Twine &msg) const override {
InFlightDiagnostic emitError(const Twine &msg) override {
return reader.emitError(msg);
}
FailureOr<const DialectVersion *>
getDialectVersion(StringRef dialectName) const override {
// First check if the dialect is available in the map.
auto dialectEntry = dialectsMap.find(dialectName);
if (dialectEntry == dialectsMap.end())
return failure();
// If the dialect was found, try to load it. This will trigger reading the
// bytecode version from the version buffer if it wasn't already processed.
// Return failure if either of those two actions could not be completed.
if (failed(dialectEntry->getValue()->load(*this, getLoc().getContext())) ||
dialectEntry->getValue()->loadedVersion.get() == nullptr)
return failure();
return dialectEntry->getValue()->loadedVersion.get();
}
MLIRContext *getContext() const override { return getLoc().getContext(); }
uint64_t getBytecodeVersion() const override { return bytecodeVersion; }
DialectReader withEncodingReader(EncodingReader &encReader) const {
DialectReader withEncodingReader(EncodingReader &encReader) {
return DialectReader(attrTypeReader, stringReader, resourceReader,
dialectsMap, encReader, bytecodeVersion);
encReader, bytecodeVersion);
}
Location getLoc() const { return reader.getLoc(); }
@@ -1040,7 +1010,6 @@ private:
AttrTypeReader &attrTypeReader;
StringSectionReader &stringReader;
ResourceSectionReader &resourceReader;
const llvm::StringMap<BytecodeDialect *> &dialectsMap;
EncodingReader &reader;
uint64_t &bytecodeVersion;
};
@@ -1127,9 +1096,10 @@ private:
};
} // namespace
LogicalResult AttrTypeReader::initialize(
MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) {
LogicalResult
AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects,
ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData) {
EncodingReader offsetReader(offsetSectionData, fileLoc);
// Parse the number of attribute and type entries.
@@ -1181,7 +1151,6 @@ LogicalResult AttrTypeReader::initialize(
return offsetReader.emitError(
"unexpected trailing data in the Attribute/Type offset section");
}
return success();
}
@@ -1247,54 +1216,32 @@ template <typename T>
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
EncodingReader &reader,
StringRef entryType) {
DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
reader, bytecodeVersion);
DialectReader dialectReader(*this, stringReader, resourceReader, reader,
bytecodeVersion);
if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
return failure();
if constexpr (std::is_same_v<T, Type>) {
// Try parsing with callbacks first if available.
for (const auto &callback :
parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) {
if (failed(
callback->read(dialectReader, entry.dialect->name, entry.entry)))
return failure();
// Early return if parsing was successful.
if (!!entry.entry)
return success();
// Reset the reader if we failed to parse, so we can fall through the
// other parsing functions.
reader = EncodingReader(entry.data, reader.getLoc());
}
} else {
// Try parsing with callbacks first if available.
for (const auto &callback :
parserConfig.getBytecodeReaderConfig().getAttributeCallbacks()) {
if (failed(
callback->read(dialectReader, entry.dialect->name, entry.entry)))
return failure();
// Early return if parsing was successful.
if (!!entry.entry)
return success();
// Reset the reader if we failed to parse, so we can fall through the
// other parsing functions.
reader = EncodingReader(entry.data, reader.getLoc());
}
}
// Ensure that the dialect implements the bytecode interface.
if (!entry.dialect->interface) {
return reader.emitError("dialect '", entry.dialect->name,
"' does not implement the bytecode interface");
}
if constexpr (std::is_same_v<T, Type>)
entry.entry = entry.dialect->interface->readType(dialectReader);
else
entry.entry = entry.dialect->interface->readAttribute(dialectReader);
// Ask the dialect to parse the entry. If the dialect is versioned, parse
// using the versioned encoding readers.
if (entry.dialect->loadedVersion.get()) {
if constexpr (std::is_same_v<T, Type>)
entry.entry = entry.dialect->interface->readType(
dialectReader, *entry.dialect->loadedVersion);
else
entry.entry = entry.dialect->interface->readAttribute(
dialectReader, *entry.dialect->loadedVersion);
} else {
if constexpr (std::is_same_v<T, Type>)
entry.entry = entry.dialect->interface->readType(dialectReader);
else
entry.entry = entry.dialect->interface->readAttribute(dialectReader);
}
return success(!!entry.entry);
}
@@ -1315,8 +1262,7 @@ public:
llvm::MemoryBufferRef buffer,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
: config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
attrTypeReader(stringReader, resourceReader, dialectsMap, version,
fileLoc, config),
attrTypeReader(stringReader, resourceReader, fileLoc, version),
// Use the builtin unrealized conversion cast operation to represent
// forward references to values that aren't yet defined.
forwardRefOpState(UnknownLoc::get(config.getContext()),
@@ -1582,8 +1528,7 @@ private:
StringRef producer;
/// The table of IR units referenced within the bytecode file.
SmallVector<std::unique_ptr<BytecodeDialect>> dialects;
llvm::StringMap<BytecodeDialect *> dialectsMap;
SmallVector<BytecodeDialect> dialects;
SmallVector<BytecodeOperationName> opNames;
/// The reader used to process resources within the bytecode.
@@ -1730,8 +1675,7 @@ LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
//===----------------------------------------------------------------------===//
// Dialect Section
LogicalResult BytecodeDialect::load(const DialectReader &reader,
MLIRContext *ctx) {
LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) {
if (dialect)
return success();
Dialect *loadedDialect = ctx->getOrLoadDialect(name);
@@ -1775,15 +1719,13 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
// Parse each of the dialects.
for (uint64_t i = 0; i < numDialects; ++i) {
dialects[i] = std::make_unique<BytecodeDialect>();
/// Before version kDialectVersioning, there wasn't any versioning available
/// for dialects, and the entryIdx represent the string itself.
if (version < bytecode::kDialectVersioning) {
if (failed(stringReader.parseString(sectionReader, dialects[i]->name)))
if (failed(stringReader.parseString(sectionReader, dialects[i].name)))
return failure();
continue;
}
// Parse ID representing dialect and version.
uint64_t dialectNameIdx;
bool versionAvailable;
@@ -1791,19 +1733,18 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
versionAvailable)))
return failure();
if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx,
dialects[i]->name)))
dialects[i].name)))
return failure();
if (versionAvailable) {
bytecode::Section::ID sectionID;
if (failed(sectionReader.parseSection(sectionID,
dialects[i]->versionBuffer)))
if (failed(
sectionReader.parseSection(sectionID, dialects[i].versionBuffer)))
return failure();
if (sectionID != bytecode::Section::kDialectVersions) {
emitError(fileLoc, "expected dialect version section");
return failure();
}
}
dialectsMap[dialects[i]->name] = dialects[i].get();
}
// Parse the operation names, which are grouped by dialect.
@@ -1851,7 +1792,7 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader,
if (!opName->opName) {
// Load the dialect and its version.
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
dialectsMap, reader, version);
reader, version);
if (failed(opName->dialect->load(dialectReader, getContext())))
return failure();
// If the opName is empty, this is because we use to accept names such as
@@ -1894,7 +1835,7 @@ LogicalResult BytecodeReader::Impl::parseResourceSection(
// Initialize the resource reader with the resource sections.
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
dialectsMap, reader, version);
reader, version);
return resourceReader.initialize(fileLoc, config, dialects, stringReader,
*resourceData, *resourceOffsetData,
dialectReader, bufferOwnerRef);
@@ -2095,14 +2036,14 @@ BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
"parsed use-list orders were invalid and could not be applied");
// Resolve dialect version.
for (const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) {
for (const BytecodeDialect &byteCodeDialect : dialects) {
// Parsing is complete, give an opportunity to each dialect to visit the
// IR and perform upgrades.
if (!byteCodeDialect->loadedVersion)
if (!byteCodeDialect.loadedVersion)
continue;
if (byteCodeDialect->interface &&
failed(byteCodeDialect->interface->upgradeFromVersion(
*moduleOp, *byteCodeDialect->loadedVersion)))
if (byteCodeDialect.interface &&
failed(byteCodeDialect.interface->upgradeFromVersion(
*moduleOp, *byteCodeDialect.loadedVersion)))
return failure();
}
@@ -2255,7 +2196,7 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
// interface and control the serialization.
if (wasRegistered) {
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
dialectsMap, reader, version);
reader, version);
if (failed(
propertiesReader.read(fileLoc, dialectReader, &*opName, opState)))
return failure();

View File

@@ -18,10 +18,15 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/CachedHashString.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/Endian.h"
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <optional>
#include <sys/types.h>
#define DEBUG_TYPE "mlir-bytecode-writer"
@@ -42,12 +47,6 @@ struct BytecodeWriterConfig::Impl {
/// The producer of the bytecode.
StringRef producer;
/// Printer callbacks used to emit custom type and attribute encodings.
llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>>
attributeWriterCallbacks;
llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeWriter<Type>>>
typeWriterCallbacks;
/// A collection of non-dialect resource printers.
SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters;
};
@@ -61,26 +60,6 @@ BytecodeWriterConfig::BytecodeWriterConfig(FallbackAsmResourceMap &map,
}
BytecodeWriterConfig::~BytecodeWriterConfig() = default;
ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>>
BytecodeWriterConfig::getAttributeWriterCallbacks() const {
return impl->attributeWriterCallbacks;
}
ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Type>>>
BytecodeWriterConfig::getTypeWriterCallbacks() const {
return impl->typeWriterCallbacks;
}
void BytecodeWriterConfig::attachAttributeCallback(
std::unique_ptr<AttrTypeBytecodeWriter<Attribute>> callback) {
impl->attributeWriterCallbacks.emplace_back(std::move(callback));
}
void BytecodeWriterConfig::attachTypeCallback(
std::unique_ptr<AttrTypeBytecodeWriter<Type>> callback) {
impl->typeWriterCallbacks.emplace_back(std::move(callback));
}
void BytecodeWriterConfig::attachResourcePrinter(
std::unique_ptr<AsmResourcePrinter> printer) {
impl->externalResourcePrinters.emplace_back(std::move(printer));
@@ -795,50 +774,32 @@ void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
auto emitAttrOrType = [&](auto &entry) {
auto entryValue = entry.getValue();
auto emitAttrOrTypeRawImpl = [&]() -> void {
RawEmitterOstream(attrTypeEmitter) << entryValue;
attrTypeEmitter.emitByte(0);
};
auto emitAttrOrTypeImpl = [&]() -> bool {
// TODO: We don't currently support custom encoded mutable types and
// attributes.
if (entryValue.template hasTrait<TypeTrait::IsMutable>() ||
entryValue.template hasTrait<AttributeTrait::IsMutable>()) {
emitAttrOrTypeRawImpl();
return false;
}
// First, try to emit this entry using the dialect bytecode interface.
bool hasCustomEncoding = false;
if (const BytecodeDialectInterface *interface = entry.dialect->interface) {
// The writer used when emitting using a custom bytecode encoding.
DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter,
numberingState, stringSection);
if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) {
for (const auto &callback : config.typeWriterCallbacks) {
if (succeeded(callback->write(entryValue, dialectWriter)))
return true;
}
if (const BytecodeDialectInterface *interface =
entry.dialect->interface) {
if (succeeded(interface->writeType(entryValue, dialectWriter)))
return true;
}
// TODO: We don't currently support custom encoded mutable types.
hasCustomEncoding =
!entryValue.template hasTrait<TypeTrait::IsMutable>() &&
succeeded(interface->writeType(entryValue, dialectWriter));
} else {
for (const auto &callback : config.attributeWriterCallbacks) {
if (succeeded(callback->write(entryValue, dialectWriter)))
return true;
}
if (const BytecodeDialectInterface *interface =
entry.dialect->interface) {
if (succeeded(interface->writeAttribute(entryValue, dialectWriter)))
return true;
}
// TODO: We don't currently support custom encoded mutable attributes.
hasCustomEncoding =
!entryValue.template hasTrait<AttributeTrait::IsMutable>() &&
succeeded(interface->writeAttribute(entryValue, dialectWriter));
}
}
// If the entry was not emitted using a callback or a dialect interface,
// emit it using the textual format.
emitAttrOrTypeRawImpl();
return false;
};
bool hasCustomEncoding = emitAttrOrTypeImpl();
// If the entry was not emitted using the dialect interface, emit it using
// the textual format.
if (!hasCustomEncoding) {
RawEmitterOstream(attrTypeEmitter) << entryValue;
attrTypeEmitter.emitByte(0);
}
// Record the offset of this entry.
uint64_t curOffset = attrTypeEmitter.size();

View File

@@ -314,22 +314,9 @@ void IRNumberingState::number(Attribute attr) {
// If this attribute will be emitted using the bytecode format, perform a
// dummy writing to number any nested components.
// TODO: We don't allow custom encodings for mutable attributes right now.
if (!attr.hasTrait<AttributeTrait::IsMutable>()) {
// Try overriding emission with callbacks.
for (const auto &callback : config.getAttributeWriterCallbacks()) {
NumberingDialectWriter writer(*this);
// The client has the ability to override the group name through the
// callback.
std::optional<StringRef> groupNameOverride;
if (succeeded(callback->write(attr, groupNameOverride, writer))) {
if (groupNameOverride.has_value())
numbering->dialect = &numberDialect(*groupNameOverride);
return;
}
}
if (const auto *interface = numbering->dialect->interface) {
if (const auto *interface = numbering->dialect->interface) {
// TODO: We don't allow custom encodings for mutable attributes right now.
if (!attr.hasTrait<AttributeTrait::IsMutable>()) {
NumberingDialectWriter writer(*this);
if (succeeded(interface->writeAttribute(attr, writer)))
return;
@@ -477,24 +464,9 @@ void IRNumberingState::number(Type type) {
// If this type will be emitted using the bytecode format, perform a dummy
// writing to number any nested components.
// TODO: We don't allow custom encodings for mutable types right now.
if (!type.hasTrait<TypeTrait::IsMutable>()) {
// Try overriding emission with callbacks.
for (const auto &callback : config.getTypeWriterCallbacks()) {
NumberingDialectWriter writer(*this);
// The client has the ability to override the group name through the
// callback.
std::optional<StringRef> groupNameOverride;
if (succeeded(callback->write(type, groupNameOverride, writer))) {
if (groupNameOverride.has_value())
numbering->dialect = &numberDialect(*groupNameOverride);
return;
}
}
// If this attribute will be emitted using the bytecode format, perform a
// dummy writing to number any nested components.
if (const auto *interface = numbering->dialect->interface) {
if (const auto *interface = numbering->dialect->interface) {
// TODO: We don't allow custom encodings for mutable types right now.
if (!type.hasTrait<TypeTrait::IsMutable>()) {
NumberingDialectWriter writer(*this);
if (succeeded(interface->writeType(type, writer)))
return;

View File

@@ -1,14 +0,0 @@
// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=1.2" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_1_2
// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=2.0" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_2_0
func.func @base_test(%arg0 : i32) -> f32 {
%0 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32
%1 = "test.cast"(%0) : (i32) -> f32
return %1 : f32
}
// VERSION_1_2: Overriding IntegerType encoding...
// VERSION_1_2: Overriding parsing of IntegerType encoding...
// VERSION_2_0-NOT: Overriding IntegerType encoding...
// VERSION_2_0-NOT: Overriding parsing of IntegerType encoding...

View File

@@ -1,18 +0,0 @@
// RUN: not mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=5" 2>&1 | FileCheck %s
// CHECK-NOT: failed to read bytecode
func.func @base_test(%arg0 : i32) -> f32 {
%0 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32
%1 = "test.cast"(%0) : (i32) -> f32
return %1 : f32
}
// -----
// CHECK-LABEL: error: unknown attribute code: 99
// CHECK: failed to read bytecode
func.func @base_test(%arg0 : !test.i32) -> f32 {
%0 = "test.addi"(%arg0, %arg0) : (!test.i32, !test.i32) -> !test.i32
%1 = "test.cast"(%0) : (!test.i32) -> f32
return %1 : f32
}

View File

@@ -1,14 +0,0 @@
// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=3" | FileCheck %s --check-prefix=TEST_3
// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=4" | FileCheck %s --check-prefix=TEST_4
"test.versionedC"() <{attribute = #test.attr_params<42, 24>}> : () -> ()
// TEST_3: Overriding TestAttrParamsAttr encoding...
// TEST_3: "test.versionedC"() <{attribute = dense<[42, 24]> : tensor<2xi32>}> : () -> ()
// -----
"test.versionedC"() <{attribute = dense<[42, 24]> : tensor<2xi32>}> : () -> ()
// TEST_4: Overriding parsing of TestAttrParamsAttr encoding...
// TEST_4: "test.versionedC"() <{attribute = #test.attr_params<42, 24>}> : () -> ()

View File

@@ -1,18 +0,0 @@
// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=1" | FileCheck %s --check-prefix=TEST_1
// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=2" | FileCheck %s --check-prefix=TEST_2
func.func @base_test(%arg0: !test.i32, %arg1: f32) {
return
}
// TEST_1: Overriding TestI32Type encoding...
// TEST_1: func.func @base_test([[ARG0:%.+]]: i32, [[ARG1:%.+]]: f32) {
// -----
func.func @base_test(%arg0: i32, %arg1: f32) {
return
}
// TEST_2: Overriding parsing of TestI32Type encoding...
// TEST_2: func.func @base_test([[ARG0:%.+]]: !test.i32, [[ARG1:%.+]]: f32) {

View File

@@ -5,12 +5,12 @@
// Index
//===--------------------------------------------------------------------===//
// RUN: not mlir-opt %S/invalid-attr_type_section-index.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=INDEX
// RUN: not mlir-opt %S/invalid-attr_type_section-index.mlirbc 2>&1 | FileCheck %s --check-prefix=INDEX
// INDEX: invalid Attribute index: 3
//===--------------------------------------------------------------------===//
// Trailing Data
//===--------------------------------------------------------------------===//
// RUN: not mlir-opt %S/invalid-attr_type_section-trailing_data.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=TRAILING_DATA
// RUN: not mlir-opt %S/invalid-attr_type_section-trailing_data.mlirbc 2>&1 | FileCheck %s --check-prefix=TRAILING_DATA
// TRAILING_DATA: trailing characters found after Attribute assembly format: trailing

View File

@@ -14,10 +14,9 @@
#ifndef MLIR_TESTDIALECT_H
#define MLIR_TESTDIALECT_H
#include "TestTypes.h"
#include "TestAttributes.h"
#include "TestInterfaces.h"
#include "TestTypes.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/DLTI/Traits.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -58,19 +57,6 @@ class RewritePatternSet;
#include "TestOpsDialect.h.inc"
namespace test {
//===----------------------------------------------------------------------===//
// TestDialect version utilities
//===----------------------------------------------------------------------===//
struct TestDialectVersion : public mlir::DialectVersion {
TestDialectVersion() = default;
TestDialectVersion(uint32_t _major, uint32_t _minor)
: major(_major), minor(_minor){};
uint32_t major = 2;
uint32_t minor = 0;
};
// Define some classes to exercises the Properties feature.
struct PropertiesWithCustomPrint {

View File

@@ -14,6 +14,15 @@
using namespace mlir;
using namespace test;
//===----------------------------------------------------------------------===//
// TestDialect version utilities
//===----------------------------------------------------------------------===//
struct TestDialectVersion : public DialectVersion {
uint32_t major = 2;
uint32_t minor = 0;
};
//===----------------------------------------------------------------------===//
// TestDialect Interfaces
//===----------------------------------------------------------------------===//
@@ -38,7 +47,7 @@ struct TestResourceBlobManagerInterface
};
namespace {
enum test_encoding { k_attr_params = 0, k_test_i32 = 99 };
enum test_encoding { k_attr_params = 0 };
}
// Test support for interacting with the Bytecode reader/writer.
@@ -47,24 +56,6 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
TestBytecodeDialectInterface(Dialect *dialect)
: BytecodeDialectInterface(dialect) {}
LogicalResult writeType(Type type,
DialectBytecodeWriter &writer) const final {
if (auto concreteType = llvm::dyn_cast<TestI32Type>(type)) {
writer.writeVarInt(test_encoding::k_test_i32);
return success();
}
return failure();
}
Type readType(DialectBytecodeReader &reader) const final {
uint64_t encoding;
if (failed(reader.readVarInt(encoding)))
return Type();
if (encoding == test_encoding::k_test_i32)
return TestI32Type::get(getContext());
return Type();
}
LogicalResult writeAttribute(Attribute attr,
DialectBytecodeWriter &writer) const final {
if (auto concreteAttr = llvm::dyn_cast<TestAttrParamsAttr>(attr)) {
@@ -76,13 +67,9 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
return failure();
}
Attribute readAttribute(DialectBytecodeReader &reader) const final {
auto versionOr = reader.getDialectVersion("test");
// Assume current version if not available through the reader.
const auto version =
(succeeded(versionOr))
? *reinterpret_cast<const TestDialectVersion *>(*versionOr)
: TestDialectVersion();
Attribute readAttribute(DialectBytecodeReader &reader,
const DialectVersion &version_) const final {
const auto &version = static_cast<const TestDialectVersion &>(version_);
if (version.major < 2)
return readAttrOldEncoding(reader);
if (version.major == 2 && version.minor == 0)

View File

@@ -1258,9 +1258,8 @@ def TestOpWithVariadicResultsAndFolder: TEST_Op<"op_with_variadic_results_and_fo
}
def TestAddIOp : TEST_Op<"addi"> {
let arguments = (ins AnyTypeOf<[I32, TestI32]>:$op1,
AnyTypeOf<[I32, TestI32]>:$op2);
let results = (outs AnyTypeOf<[I32, TestI32]>);
let arguments = (ins I32:$op1, I32:$op2);
let results = (outs I32);
}
def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> {
@@ -2621,12 +2620,6 @@ def TestVersionedOpB : TEST_Op<"versionedB"> {
);
}
def TestVersionedOpC : TEST_Op<"versionedC"> {
let arguments = (ins AnyAttrOf<[TestAttrParams,
I32ElementsAttr]>:$attribute
);
}
//===----------------------------------------------------------------------===//
// Test Properties
//===----------------------------------------------------------------------===//

View File

@@ -369,8 +369,4 @@ def TestTypeElseAnchorStruct : Test_Type<"TestTypeElseAnchorStruct"> {
let assemblyFormat = "`<` (`?`) : (struct($a, $b)^)? `>`";
}
def TestI32 : Test_Type<"TestI32"> {
let mnemonic = "i32";
}
#endif // TEST_TYPEDEFS

View File

@@ -1,6 +1,5 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestIR
TestBytecodeCallbacks.cpp
TestBuiltinAttributeInterfaces.cpp
TestBuiltinDistinctAttributes.cpp
TestClone.cpp

View File

@@ -1,372 +0,0 @@
//===- TestBytecodeCallbacks.cpp - Pass to test bytecode callback hooks --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "mlir/Bytecode/BytecodeReader.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/raw_ostream.h"
#include <list>
using namespace mlir;
using namespace llvm;
namespace {
class TestDialectVersionParser : public cl::parser<test::TestDialectVersion> {
public:
TestDialectVersionParser(cl::Option &O)
: cl::parser<test::TestDialectVersion>(O) {}
bool parse(cl::Option &O, StringRef /*argName*/, StringRef arg,
test::TestDialectVersion &v) {
long long major, minor;
if (getAsSignedInteger(arg.split(".").first, 10, major))
return O.error("Invalid argument '" + arg);
if (getAsSignedInteger(arg.split(".").second, 10, minor))
return O.error("Invalid argument '" + arg);
v = test::TestDialectVersion(major, minor);
// Returns true on error.
return false;
}
static void print(raw_ostream &os, const test::TestDialectVersion &v) {
os << v.major << "." << v.minor;
};
};
/// This is a test pass which uses callbacks to encode attributes and types in a
/// custom fashion.
struct TestBytecodeCallbackPass
: public PassWrapper<TestBytecodeCallbackPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeCallbackPass)
StringRef getArgument() const final { return "test-bytecode-callback"; }
StringRef getDescription() const final {
return "Test encoding of a dialect type/attributes with a custom callback";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<test::TestDialect>();
}
TestBytecodeCallbackPass() = default;
TestBytecodeCallbackPass(const TestBytecodeCallbackPass &) {}
void runOnOperation() override {
switch (testKind) {
case (0):
return runTest0(getOperation());
case (1):
return runTest1(getOperation());
case (2):
return runTest2(getOperation());
case (3):
return runTest3(getOperation());
case (4):
return runTest4(getOperation());
case (5):
return runTest5(getOperation());
default:
llvm_unreachable("unhandled test kind for TestBytecodeCallbacks pass");
}
}
mlir::Pass::Option<test::TestDialectVersion, TestDialectVersionParser>
targetVersion{*this, "test-dialect-version",
llvm::cl::desc(
"Specifies the test dialect version to emit and parse"),
cl::init(test::TestDialectVersion())};
mlir::Pass::Option<int> testKind{
*this, "callback-test",
llvm::cl::desc("Specifies the test kind to execute"), cl::init(0)};
private:
void doRoundtripWithConfigs(Operation *op,
const BytecodeWriterConfig &writeConfig,
const ParserConfig &parseConfig) {
std::string bytecode;
llvm::raw_string_ostream os(bytecode);
if (failed(writeBytecodeToFile(op, os, writeConfig))) {
op->emitError() << "failed to write bytecode\n";
signalPassFailure();
return;
}
auto newModuleOp = parseSourceString(StringRef(bytecode), parseConfig);
if (!newModuleOp.get()) {
op->emitError() << "failed to read bytecode\n";
signalPassFailure();
return;
}
// Print the module to the output stream, so that we can filecheck the
// result.
newModuleOp->print(llvm::outs());
return;
}
// Test0: let's assume that versions older than 2.0 were relying on a special
// integer attribute of a deprecated dialect called "funky". Assume that its
// encoding was made by two varInts, the first was the ID (999) and the second
// contained width and signedness info. We can emit it using a callback
// writing a custom encoding for the "funky" dialect group, and parse it back
// with a custom parser reading the same encoding in the same dialect group.
// Note that the ID 999 does not correspond to a valid integer type in the
// current encodings of builtin types.
void runTest0(Operation *op) {
auto newCtx = std::make_shared<MLIRContext>();
test::TestDialectVersion targetEmissionVersion = targetVersion;
BytecodeWriterConfig writeConfig;
writeConfig.attachTypeCallback(
[&](Type entryValue, std::optional<StringRef> &dialectGroupName,
DialectBytecodeWriter &writer) -> LogicalResult {
// Do not override anything if version less than 2.0.
if (targetEmissionVersion.major >= 2)
return failure();
// For version less than 2.0, override the encoding of IntegerType.
if (auto type = llvm::dyn_cast<IntegerType>(entryValue)) {
llvm::outs() << "Overriding IntegerType encoding...\n";
dialectGroupName = StringLiteral("funky");
writer.writeVarInt(/* IntegerType */ 999);
writer.writeVarInt(type.getWidth() << 2 | type.getSignedness());
return success();
}
return failure();
});
newCtx->appendDialectRegistry(op->getContext()->getDialectRegistry());
newCtx->allowUnregisteredDialects();
ParserConfig parseConfig(newCtx.get(), /*verifyAfterParse=*/true);
parseConfig.getBytecodeReaderConfig().attachTypeCallback(
[&](DialectBytecodeReader &reader, StringRef dialectName,
Type &entry) -> LogicalResult {
// Get test dialect version from the version map.
auto versionOr = reader.getDialectVersion("test");
assert(succeeded(versionOr) && "expected reader to be able to access "
"the version for test dialect");
const auto *version =
reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
// TODO: once back-deployment is formally supported,
// `targetEmissionVersion` will be encoded in the bytecode file, and
// exposed through the versionMap. Right now though this is not yet
// supported. For the purpose of the test, just use
// `targetEmissionVersion`.
(void)version;
if (targetEmissionVersion.major >= 2)
return success();
// `dialectName` is the name of the group we have the opportunity to
// override. In this case, override only the dialect group "funky",
// for which does not exist in memory.
if (dialectName != StringLiteral("funky"))
return success();
uint64_t encoding;
if (failed(reader.readVarInt(encoding)) || encoding != 999)
return success();
llvm::outs() << "Overriding parsing of IntegerType encoding...\n";
uint64_t _widthAndSignedness, width;
IntegerType::SignednessSemantics signedness;
if (succeeded(reader.readVarInt(_widthAndSignedness)) &&
((width = _widthAndSignedness >> 2), true) &&
((signedness = static_cast<IntegerType::SignednessSemantics>(
_widthAndSignedness & 0x3)),
true))
entry = IntegerType::get(reader.getContext(), width, signedness);
// Return nullopt to fall through the rest of the parsing code path.
return success();
});
doRoundtripWithConfigs(op, writeConfig, parseConfig);
return;
}
// Test1: When writing bytecode, we override the encoding of TestI32Type with
// the encoding of builtin IntegerType. We can natively parse this without
// the use of a callback, relying on the existing builtin reader mechanism.
void runTest1(Operation *op) {
auto builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
BytecodeDialectInterface *iface =
builtin->getRegisteredInterface<BytecodeDialectInterface>();
BytecodeWriterConfig writeConfig;
writeConfig.attachTypeCallback(
[&](Type entryValue, std::optional<StringRef> &dialectGroupName,
DialectBytecodeWriter &writer) -> LogicalResult {
// Emit TestIntegerType using the builtin dialect encoding.
if (llvm::isa<test::TestI32Type>(entryValue)) {
llvm::outs() << "Overriding TestI32Type encoding...\n";
auto builtinI32Type =
IntegerType::get(op->getContext(), 32,
IntegerType::SignednessSemantics::Signless);
// Specify that this type will need to be written as part of the
// builtin group. This will override the default dialect group of
// the attribute (test).
dialectGroupName = StringLiteral("builtin");
if (succeeded(iface->writeType(builtinI32Type, writer)))
return success();
}
return failure();
});
// We natively parse the attribute as a builtin, so no callback needed.
ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
doRoundtripWithConfigs(op, writeConfig, parseConfig);
return;
}
// Test2: When writing bytecode, we write standard builtin IntegerTypes. At
// parsing, we use the encoding of IntegerType to intercept all i32. Then,
// instead of creating i32s, we assemble TestI32Type and return it.
void runTest2(Operation *op) {
auto builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
BytecodeDialectInterface *iface =
builtin->getRegisteredInterface<BytecodeDialectInterface>();
BytecodeWriterConfig writeConfig;
ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
parseConfig.getBytecodeReaderConfig().attachTypeCallback(
[&](DialectBytecodeReader &reader, StringRef dialectName,
Type &entry) -> LogicalResult {
if (dialectName != StringLiteral("builtin"))
return success();
Type builtinAttr = iface->readType(reader);
if (auto integerType =
llvm::dyn_cast_or_null<IntegerType>(builtinAttr)) {
if (integerType.getWidth() == 32 && integerType.isSignless()) {
llvm::outs() << "Overriding parsing of TestI32Type encoding...\n";
entry = test::TestI32Type::get(reader.getContext());
}
}
return success();
});
doRoundtripWithConfigs(op, writeConfig, parseConfig);
return;
}
// Test3: When writing bytecode, we override the encoding of
// TestAttrParamsAttr with the encoding of builtin DenseIntElementsAttr. We
// can natively parse this without the use of a callback, relying on the
// existing builtin reader mechanism.
void runTest3(Operation *op) {
auto builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
BytecodeDialectInterface *iface =
builtin->getRegisteredInterface<BytecodeDialectInterface>();
auto i32Type = IntegerType::get(op->getContext(), 32,
IntegerType::SignednessSemantics::Signless);
BytecodeWriterConfig writeConfig;
writeConfig.attachAttributeCallback(
[&](Attribute entryValue, std::optional<StringRef> &dialectGroupName,
DialectBytecodeWriter &writer) -> LogicalResult {
// Emit TestIntegerType using the builtin dialect encoding.
if (auto testParamAttrs =
llvm::dyn_cast<test::TestAttrParamsAttr>(entryValue)) {
llvm::outs() << "Overriding TestAttrParamsAttr encoding...\n";
// Specify that this attribute will need to be written as part of
// the builtin group. This will override the default dialect group
// of the attribute (test).
dialectGroupName = StringLiteral("builtin");
auto denseAttr = DenseIntElementsAttr::get(
RankedTensorType::get({2}, i32Type),
{testParamAttrs.getV0(), testParamAttrs.getV1()});
if (succeeded(iface->writeAttribute(denseAttr, writer)))
return success();
}
return failure();
});
// We natively parse the attribute as a builtin, so no callback needed.
ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
doRoundtripWithConfigs(op, writeConfig, parseConfig);
return;
}
// Test4: When writing bytecode, we write standard builtin
// DenseIntElementsAttr. At parsing, we use the encoding of
// DenseIntElementsAttr to intercept all ElementsAttr that have shaped type of
// <2xi32>. Instead of assembling a DenseIntElementsAttr, we assemble
// TestAttrParamsAttr and return it.
void runTest4(Operation *op) {
auto builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
BytecodeDialectInterface *iface =
builtin->getRegisteredInterface<BytecodeDialectInterface>();
auto i32Type = IntegerType::get(op->getContext(), 32,
IntegerType::SignednessSemantics::Signless);
BytecodeWriterConfig writeConfig;
ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
parseConfig.getBytecodeReaderConfig().attachAttributeCallback(
[&](DialectBytecodeReader &reader, StringRef dialectName,
Attribute &entry) -> LogicalResult {
// Override only the case where the return type of the builtin reader
// is an i32 and fall through on all the other cases, since we want to
// still use TestDialect normal codepath to parse the other types.
Attribute builtinAttr = iface->readAttribute(reader);
if (auto denseAttr =
llvm::dyn_cast_or_null<DenseIntElementsAttr>(builtinAttr)) {
if (denseAttr.getType().getShape() == ArrayRef<int64_t>(2) &&
denseAttr.getElementType() == i32Type) {
llvm::outs()
<< "Overriding parsing of TestAttrParamsAttr encoding...\n";
int v0 = denseAttr.getValues<IntegerAttr>()[0].getInt();
int v1 = denseAttr.getValues<IntegerAttr>()[1].getInt();
entry =
test::TestAttrParamsAttr::get(reader.getContext(), v0, v1);
}
}
return success();
});
doRoundtripWithConfigs(op, writeConfig, parseConfig);
return;
}
// Test5: When writing bytecode, we want TestDialect to use nothing else than
// the builtin types and attributes and take full control of the encoding,
// returning failure if any type or attribute is not part of builtin.
void runTest5(Operation *op) {
auto builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
BytecodeDialectInterface *iface =
builtin->getRegisteredInterface<BytecodeDialectInterface>();
BytecodeWriterConfig writeConfig;
writeConfig.attachAttributeCallback(
[&](Attribute attr, std::optional<StringRef> &dialectGroupName,
DialectBytecodeWriter &writer) -> LogicalResult {
return iface->writeAttribute(attr, writer);
});
writeConfig.attachTypeCallback(
[&](Type type, std::optional<StringRef> &dialectGroupName,
DialectBytecodeWriter &writer) -> LogicalResult {
return iface->writeType(type, writer);
});
ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
parseConfig.getBytecodeReaderConfig().attachAttributeCallback(
[&](DialectBytecodeReader &reader, StringRef dialectName,
Attribute &entry) -> LogicalResult {
Attribute builtinAttr = iface->readAttribute(reader);
if (!builtinAttr)
return failure();
entry = builtinAttr;
return success();
});
parseConfig.getBytecodeReaderConfig().attachTypeCallback(
[&](DialectBytecodeReader &reader, StringRef dialectName,
Type &entry) -> LogicalResult {
Type builtinType = iface->readType(reader);
if (!builtinType) {
return failure();
}
entry = builtinType;
return success();
});
doRoundtripWithConfigs(op, writeConfig, parseConfig);
return;
}
};
} // namespace
namespace mlir {
void registerTestBytecodeCallbackPasses() {
PassRegistration<TestBytecodeCallbackPass>();
}
} // namespace mlir

View File

@@ -43,7 +43,6 @@ void registerSymbolTestPasses();
void registerRegionTestPasses();
void registerTestAffineDataCopyPass();
void registerTestAffineReifyValueBoundsPass();
void registerTestBytecodeCallbackPasses();
void registerTestDecomposeAffineOpPass();
void registerTestAffineLoopUnswitchingPass();
void registerTestAllReduceLoweringPass();
@@ -168,7 +167,6 @@ void registerTestPasses() {
registerTestDecomposeAffineOpPass();
registerTestAffineLoopUnswitchingPass();
registerTestAllReduceLoweringPass();
registerTestBytecodeCallbackPasses();
registerTestFunc();
registerTestGpuMemoryPromotionPass();
registerTestLoopPermutationPass();