mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 03:56:16 +08:00
Revert "Expose callbacks for encoding of types/attributes"
This reverts commit b299ec1666.
The authorship informations were incorrect.
This commit is contained in:
@@ -24,17 +24,6 @@
|
||||
#include "llvm/ADT/Twine.h"
|
||||
|
||||
namespace mlir {
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Dialect Version Interface.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// This class is used to represent the version of a dialect, for the purpose
|
||||
/// of polymorphic destruction.
|
||||
class DialectVersion {
|
||||
public:
|
||||
virtual ~DialectVersion() = default;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DialectBytecodeReader
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -49,14 +38,7 @@ public:
|
||||
virtual ~DialectBytecodeReader() = default;
|
||||
|
||||
/// Emit an error to the reader.
|
||||
virtual InFlightDiagnostic emitError(const Twine &msg = {}) const = 0;
|
||||
|
||||
/// Retrieve the dialect version by name if available.
|
||||
virtual FailureOr<const DialectVersion *>
|
||||
getDialectVersion(StringRef dialectName) const = 0;
|
||||
|
||||
/// Retrieve the context associated to the reader.
|
||||
virtual MLIRContext *getContext() const = 0;
|
||||
virtual InFlightDiagnostic emitError(const Twine &msg = {}) = 0;
|
||||
|
||||
/// Return the bytecode version being read.
|
||||
virtual uint64_t getBytecodeVersion() const = 0;
|
||||
@@ -402,6 +384,17 @@ public:
|
||||
virtual int64_t getBytecodeVersion() const = 0;
|
||||
};
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Dialect Version Interface.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// This class is used to represent the version of a dialect, for the purpose
|
||||
/// of polymorphic destruction.
|
||||
class DialectVersion {
|
||||
public:
|
||||
virtual ~DialectVersion() = default;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BytecodeDialectInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -416,23 +409,47 @@ public:
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Read an attribute belonging to this dialect from the given reader. This
|
||||
/// method should return null in the case of failure. Optionally, the dialect
|
||||
/// version can be accessed through the reader.
|
||||
/// method should return null in the case of failure.
|
||||
virtual Attribute readAttribute(DialectBytecodeReader &reader) const {
|
||||
reader.emitError() << "dialect " << getDialect()->getNamespace()
|
||||
<< " does not support reading attributes from bytecode";
|
||||
return Attribute();
|
||||
}
|
||||
|
||||
/// Read a versioned attribute encoding belonging to this dialect from the
|
||||
/// given reader. This method should return null in the case of failure, and
|
||||
/// falls back to the non-versioned reader in case the dialect implements
|
||||
/// versioning but it does not support versioned custom encodings for the
|
||||
/// attributes.
|
||||
virtual Attribute readAttribute(DialectBytecodeReader &reader,
|
||||
const DialectVersion &version) const {
|
||||
reader.emitError()
|
||||
<< "dialect " << getDialect()->getNamespace()
|
||||
<< " does not support reading versioned attributes from bytecode";
|
||||
return Attribute();
|
||||
}
|
||||
|
||||
/// Read a type belonging to this dialect from the given reader. This method
|
||||
/// should return null in the case of failure. Optionally, the dialect version
|
||||
/// can be accessed thorugh the reader.
|
||||
/// should return null in the case of failure.
|
||||
virtual Type readType(DialectBytecodeReader &reader) const {
|
||||
reader.emitError() << "dialect " << getDialect()->getNamespace()
|
||||
<< " does not support reading types from bytecode";
|
||||
return Type();
|
||||
}
|
||||
|
||||
/// Read a versioned type encoding belonging to this dialect from the given
|
||||
/// reader. This method should return null in the case of failure, and
|
||||
/// falls back to the non-versioned reader in case the dialect implements
|
||||
/// versioning but it does not support versioned custom encodings for the
|
||||
/// types.
|
||||
virtual Type readType(DialectBytecodeReader &reader,
|
||||
const DialectVersion &version) const {
|
||||
reader.emitError()
|
||||
<< "dialect " << getDialect()->getNamespace()
|
||||
<< " does not support reading versioned types from bytecode";
|
||||
return Type();
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Writing
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
@@ -25,6 +25,7 @@ class SourceMgr;
|
||||
} // namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// The BytecodeReader allows to load MLIR bytecode files, while keeping the
|
||||
/// state explicitly available in order to support lazy loading.
|
||||
/// The `finalize` method must be called before destruction.
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
//===- BytecodeReader.h - MLIR Bytecode Reader ------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This header defines interfaces to read MLIR bytecode files/streams.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_BYTECODE_BYTECODEREADERCONFIG_H
|
||||
#define MLIR_BYTECODE_BYTECODEREADERCONFIG_H
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
namespace mlir {
|
||||
class Attribute;
|
||||
class DialectBytecodeReader;
|
||||
class Type;
|
||||
|
||||
/// A class to interact with the attributes and types parser when parsing MLIR
|
||||
/// bytecode.
|
||||
template <class T>
|
||||
class AttrTypeBytecodeReader {
|
||||
public:
|
||||
AttrTypeBytecodeReader() = default;
|
||||
virtual ~AttrTypeBytecodeReader() = default;
|
||||
|
||||
virtual LogicalResult read(DialectBytecodeReader &reader,
|
||||
StringRef dialectName, T &entry) = 0;
|
||||
|
||||
/// Return an Attribute/Type printer implemented via the given callable, whose
|
||||
/// form should match that of the `parse` function above.
|
||||
template <typename CallableT,
|
||||
std::enable_if_t<
|
||||
std::is_convertible_v<
|
||||
CallableT, std::function<LogicalResult(
|
||||
DialectBytecodeReader &, StringRef, T &)>>,
|
||||
bool> = true>
|
||||
static std::unique_ptr<AttrTypeBytecodeReader<T>>
|
||||
fromCallable(CallableT &&readFn) {
|
||||
struct Processor : public AttrTypeBytecodeReader<T> {
|
||||
Processor(CallableT &&readFn)
|
||||
: AttrTypeBytecodeReader(), readFn(std::move(readFn)) {}
|
||||
LogicalResult read(DialectBytecodeReader &reader, StringRef dialectName,
|
||||
T &entry) override {
|
||||
return readFn(reader, dialectName, entry);
|
||||
}
|
||||
|
||||
std::decay_t<CallableT> readFn;
|
||||
};
|
||||
return std::make_unique<Processor>(std::forward<CallableT>(readFn));
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BytecodeReaderConfig
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// A class containing bytecode-specific configurations of the `ParserConfig`.
|
||||
class BytecodeReaderConfig {
|
||||
public:
|
||||
BytecodeReaderConfig() = default;
|
||||
|
||||
/// Returns the callbacks available to the parser.
|
||||
ArrayRef<std::unique_ptr<AttrTypeBytecodeReader<Attribute>>>
|
||||
getAttributeCallbacks() const {
|
||||
return attributeBytecodeParsers;
|
||||
}
|
||||
ArrayRef<std::unique_ptr<AttrTypeBytecodeReader<Type>>>
|
||||
getTypeCallbacks() const {
|
||||
return typeBytecodeParsers;
|
||||
}
|
||||
|
||||
/// Attach a custom bytecode parser callback to the configuration for parsing
|
||||
/// of custom type/attributes encodings.
|
||||
void attachAttributeCallback(
|
||||
std::unique_ptr<AttrTypeBytecodeReader<Attribute>> parser) {
|
||||
attributeBytecodeParsers.emplace_back(std::move(parser));
|
||||
}
|
||||
void
|
||||
attachTypeCallback(std::unique_ptr<AttrTypeBytecodeReader<Type>> parser) {
|
||||
typeBytecodeParsers.emplace_back(std::move(parser));
|
||||
}
|
||||
|
||||
/// Attach a custom bytecode parser callback to the configuration for parsing
|
||||
/// of custom type/attributes encodings.
|
||||
template <typename CallableT>
|
||||
std::enable_if_t<std::is_convertible_v<
|
||||
CallableT, std::function<LogicalResult(DialectBytecodeReader &, StringRef,
|
||||
Attribute &)>>>
|
||||
attachAttributeCallback(CallableT &&parserFn) {
|
||||
attachAttributeCallback(AttrTypeBytecodeReader<Attribute>::fromCallable(
|
||||
std::forward<CallableT>(parserFn)));
|
||||
}
|
||||
template <typename CallableT>
|
||||
std::enable_if_t<std::is_convertible_v<
|
||||
CallableT,
|
||||
std::function<LogicalResult(DialectBytecodeReader &, StringRef, Type &)>>>
|
||||
attachTypeCallback(CallableT &&parserFn) {
|
||||
attachTypeCallback(AttrTypeBytecodeReader<Type>::fromCallable(
|
||||
std::forward<CallableT>(parserFn)));
|
||||
}
|
||||
|
||||
private:
|
||||
llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeReader<Attribute>>>
|
||||
attributeBytecodeParsers;
|
||||
llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeReader<Type>>>
|
||||
typeBytecodeParsers;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_BYTECODE_BYTECODEREADERCONFIG_H
|
||||
@@ -17,55 +17,6 @@
|
||||
|
||||
namespace mlir {
|
||||
class Operation;
|
||||
class DialectBytecodeWriter;
|
||||
|
||||
/// A class to interact with the attributes and types printer when emitting MLIR
|
||||
/// bytecode.
|
||||
template <class T>
|
||||
class AttrTypeBytecodeWriter {
|
||||
public:
|
||||
AttrTypeBytecodeWriter() = default;
|
||||
virtual ~AttrTypeBytecodeWriter() = default;
|
||||
|
||||
/// Callback writer API used in IRNumbering, where groups are created and
|
||||
/// type/attribute components are numbered. At this stage, writer is expected
|
||||
/// to be a `NumberingDialectWriter`.
|
||||
virtual LogicalResult write(T entry, std::optional<StringRef> &name,
|
||||
DialectBytecodeWriter &writer) = 0;
|
||||
|
||||
/// Callback writer API used in BytecodeWriter, where groups are created and
|
||||
/// type/attribute components are numbered. Here, DialectBytecodeWriter is
|
||||
/// expected to be an actual writer. The optional stringref specified by
|
||||
/// the user is ignored, since the group was already specified when numbering
|
||||
/// the IR.
|
||||
LogicalResult write(T entry, DialectBytecodeWriter &writer) {
|
||||
std::optional<StringRef> dummy;
|
||||
return write(entry, dummy, writer);
|
||||
}
|
||||
|
||||
/// Return an Attribute/Type printer implemented via the given callable, whose
|
||||
/// form should match that of the `write` function above.
|
||||
template <typename CallableT,
|
||||
std::enable_if_t<std::is_convertible_v<
|
||||
CallableT, std::function<LogicalResult(
|
||||
T, std::optional<StringRef> &,
|
||||
DialectBytecodeWriter &)>>,
|
||||
bool> = true>
|
||||
static std::unique_ptr<AttrTypeBytecodeWriter<T>>
|
||||
fromCallable(CallableT &&writeFn) {
|
||||
struct Processor : public AttrTypeBytecodeWriter<T> {
|
||||
Processor(CallableT &&writeFn)
|
||||
: AttrTypeBytecodeWriter(), writeFn(std::move(writeFn)) {}
|
||||
LogicalResult write(T entry, std::optional<StringRef> &name,
|
||||
DialectBytecodeWriter &writer) override {
|
||||
return writeFn(entry, name, writer);
|
||||
}
|
||||
|
||||
std::decay_t<CallableT> writeFn;
|
||||
};
|
||||
return std::make_unique<Processor>(std::forward<CallableT>(writeFn));
|
||||
}
|
||||
};
|
||||
|
||||
/// This class contains the configuration used for the bytecode writer. It
|
||||
/// controls various aspects of bytecode generation, and contains all of the
|
||||
@@ -97,43 +48,6 @@ public:
|
||||
/// Get the set desired bytecode version to emit.
|
||||
int64_t getDesiredBytecodeVersion() const;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Types and Attributes encoding
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Retrieve the callbacks.
|
||||
ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>>
|
||||
getAttributeWriterCallbacks() const;
|
||||
ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Type>>>
|
||||
getTypeWriterCallbacks() const;
|
||||
|
||||
/// Attach a custom bytecode printer callback to the configuration for the
|
||||
/// emission of custom type/attributes encodings.
|
||||
void attachAttributeCallback(
|
||||
std::unique_ptr<AttrTypeBytecodeWriter<Attribute>> callback);
|
||||
void
|
||||
attachTypeCallback(std::unique_ptr<AttrTypeBytecodeWriter<Type>> callback);
|
||||
|
||||
/// Attach a custom bytecode printer callback to the configuration for the
|
||||
/// emission of custom type/attributes encodings.
|
||||
template <typename CallableT>
|
||||
std::enable_if_t<std::is_convertible_v<
|
||||
CallableT,
|
||||
std::function<LogicalResult(Attribute, std::optional<StringRef> &,
|
||||
DialectBytecodeWriter &)>>>
|
||||
attachAttributeCallback(CallableT &&emitFn) {
|
||||
attachAttributeCallback(AttrTypeBytecodeWriter<Attribute>::fromCallable(
|
||||
std::forward<CallableT>(emitFn)));
|
||||
}
|
||||
template <typename CallableT>
|
||||
std::enable_if_t<std::is_convertible_v<
|
||||
CallableT, std::function<LogicalResult(Type, std::optional<StringRef> &,
|
||||
DialectBytecodeWriter &)>>>
|
||||
attachTypeCallback(CallableT &&emitFn) {
|
||||
attachTypeCallback(AttrTypeBytecodeWriter<Type>::fromCallable(
|
||||
std::forward<CallableT>(emitFn)));
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Resources
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
#ifndef MLIR_IR_ASMSTATE_H_
|
||||
#define MLIR_IR_ASMSTATE_H_
|
||||
|
||||
#include "mlir/Bytecode/BytecodeReaderConfig.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
@@ -476,11 +475,6 @@ public:
|
||||
/// Returns if the parser should verify the IR after parsing.
|
||||
bool shouldVerifyAfterParse() const { return verifyAfterParse; }
|
||||
|
||||
/// Returns the parsing configurations associated to the bytecode read.
|
||||
BytecodeReaderConfig &getBytecodeReaderConfig() const {
|
||||
return const_cast<BytecodeReaderConfig &>(bytecodeReaderConfig);
|
||||
}
|
||||
|
||||
/// Return the resource parser registered to the given name, or nullptr if no
|
||||
/// parser with `name` is registered.
|
||||
AsmResourceParser *getResourceParser(StringRef name) const {
|
||||
@@ -515,7 +509,6 @@ private:
|
||||
bool verifyAfterParse;
|
||||
DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
|
||||
FallbackAsmResourceMap *fallbackResourceMap;
|
||||
BytecodeReaderConfig bytecodeReaderConfig;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -451,7 +451,7 @@ struct BytecodeDialect {
|
||||
/// Returns failure if the dialect couldn't be loaded *and* the provided
|
||||
/// context does not allow unregistered dialects. The provided reader is used
|
||||
/// for error emission if necessary.
|
||||
LogicalResult load(const DialectReader &reader, MLIRContext *ctx);
|
||||
LogicalResult load(DialectReader &reader, MLIRContext *ctx);
|
||||
|
||||
/// Return the loaded dialect, or nullptr if the dialect is unknown. This can
|
||||
/// only be called after `load`.
|
||||
@@ -505,11 +505,10 @@ struct BytecodeOperationName {
|
||||
|
||||
/// Parse a single dialect group encoded in the byte stream.
|
||||
static LogicalResult parseDialectGrouping(
|
||||
EncodingReader &reader,
|
||||
MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
|
||||
EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects,
|
||||
function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
|
||||
// Parse the dialect and the number of entries in the group.
|
||||
std::unique_ptr<BytecodeDialect> *dialect;
|
||||
BytecodeDialect *dialect;
|
||||
if (failed(parseEntry(reader, dialects, dialect, "dialect")))
|
||||
return failure();
|
||||
uint64_t numEntries;
|
||||
@@ -517,7 +516,7 @@ static LogicalResult parseDialectGrouping(
|
||||
return failure();
|
||||
|
||||
for (uint64_t i = 0; i < numEntries; ++i)
|
||||
if (failed(entryCallback(dialect->get())))
|
||||
if (failed(entryCallback(dialect)))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
@@ -533,7 +532,7 @@ public:
|
||||
/// Initialize the resource section reader with the given section data.
|
||||
LogicalResult
|
||||
initialize(Location fileLoc, const ParserConfig &config,
|
||||
MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
|
||||
MutableArrayRef<BytecodeDialect> dialects,
|
||||
StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
|
||||
ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
|
||||
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
|
||||
@@ -683,7 +682,7 @@ parseResourceGroup(Location fileLoc, bool allowEmpty,
|
||||
|
||||
LogicalResult ResourceSectionReader::initialize(
|
||||
Location fileLoc, const ParserConfig &config,
|
||||
MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
|
||||
MutableArrayRef<BytecodeDialect> dialects,
|
||||
StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
|
||||
ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
|
||||
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
|
||||
@@ -732,19 +731,19 @@ LogicalResult ResourceSectionReader::initialize(
|
||||
// Read the dialect resources from the bytecode.
|
||||
MLIRContext *ctx = fileLoc->getContext();
|
||||
while (!offsetReader.empty()) {
|
||||
std::unique_ptr<BytecodeDialect> *dialect;
|
||||
BytecodeDialect *dialect;
|
||||
if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) ||
|
||||
failed((*dialect)->load(dialectReader, ctx)))
|
||||
failed(dialect->load(dialectReader, ctx)))
|
||||
return failure();
|
||||
Dialect *loadedDialect = (*dialect)->getLoadedDialect();
|
||||
Dialect *loadedDialect = dialect->getLoadedDialect();
|
||||
if (!loadedDialect) {
|
||||
return resourceReader.emitError()
|
||||
<< "dialect '" << (*dialect)->name << "' is unknown";
|
||||
<< "dialect '" << dialect->name << "' is unknown";
|
||||
}
|
||||
const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect);
|
||||
if (!handler) {
|
||||
return resourceReader.emitError()
|
||||
<< "unexpected resources for dialect '" << (*dialect)->name << "'";
|
||||
<< "unexpected resources for dialect '" << dialect->name << "'";
|
||||
}
|
||||
|
||||
// Ensure that each resource is declared before being processed.
|
||||
@@ -754,7 +753,7 @@ LogicalResult ResourceSectionReader::initialize(
|
||||
if (failed(handle)) {
|
||||
return resourceReader.emitError()
|
||||
<< "unknown 'resource' key '" << key << "' for dialect '"
|
||||
<< (*dialect)->name << "'";
|
||||
<< dialect->name << "'";
|
||||
}
|
||||
dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle);
|
||||
dialectResources.push_back(*handle);
|
||||
@@ -797,19 +796,15 @@ class AttrTypeReader {
|
||||
|
||||
public:
|
||||
AttrTypeReader(StringSectionReader &stringReader,
|
||||
ResourceSectionReader &resourceReader,
|
||||
const llvm::StringMap<BytecodeDialect *> &dialectsMap,
|
||||
uint64_t &bytecodeVersion, Location fileLoc,
|
||||
const ParserConfig &config)
|
||||
ResourceSectionReader &resourceReader, Location fileLoc,
|
||||
uint64_t &bytecodeVersion)
|
||||
: stringReader(stringReader), resourceReader(resourceReader),
|
||||
dialectsMap(dialectsMap), fileLoc(fileLoc),
|
||||
bytecodeVersion(bytecodeVersion), parserConfig(config) {}
|
||||
fileLoc(fileLoc), bytecodeVersion(bytecodeVersion) {}
|
||||
|
||||
/// Initialize the attribute and type information within the reader.
|
||||
LogicalResult
|
||||
initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
|
||||
ArrayRef<uint8_t> sectionData,
|
||||
ArrayRef<uint8_t> offsetSectionData);
|
||||
LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects,
|
||||
ArrayRef<uint8_t> sectionData,
|
||||
ArrayRef<uint8_t> offsetSectionData);
|
||||
|
||||
/// Resolve the attribute or type at the given index. Returns nullptr on
|
||||
/// failure.
|
||||
@@ -883,10 +878,6 @@ private:
|
||||
/// parsing custom encoded attribute/type entries.
|
||||
ResourceSectionReader &resourceReader;
|
||||
|
||||
/// The map of the loaded dialects used to retrieve dialect information, such
|
||||
/// as the dialect version.
|
||||
const llvm::StringMap<BytecodeDialect *> &dialectsMap;
|
||||
|
||||
/// The set of attribute and type entries.
|
||||
SmallVector<AttrEntry> attributes;
|
||||
SmallVector<TypeEntry> types;
|
||||
@@ -896,48 +887,27 @@ private:
|
||||
|
||||
/// Current bytecode version being used.
|
||||
uint64_t &bytecodeVersion;
|
||||
|
||||
/// Reference to the parser configuration.
|
||||
const ParserConfig &parserConfig;
|
||||
};
|
||||
|
||||
class DialectReader : public DialectBytecodeReader {
|
||||
public:
|
||||
DialectReader(AttrTypeReader &attrTypeReader,
|
||||
StringSectionReader &stringReader,
|
||||
ResourceSectionReader &resourceReader,
|
||||
const llvm::StringMap<BytecodeDialect *> &dialectsMap,
|
||||
EncodingReader &reader, uint64_t &bytecodeVersion)
|
||||
ResourceSectionReader &resourceReader, EncodingReader &reader,
|
||||
uint64_t &bytecodeVersion)
|
||||
: attrTypeReader(attrTypeReader), stringReader(stringReader),
|
||||
resourceReader(resourceReader), dialectsMap(dialectsMap),
|
||||
reader(reader), bytecodeVersion(bytecodeVersion) {}
|
||||
resourceReader(resourceReader), reader(reader),
|
||||
bytecodeVersion(bytecodeVersion) {}
|
||||
|
||||
InFlightDiagnostic emitError(const Twine &msg) const override {
|
||||
InFlightDiagnostic emitError(const Twine &msg) override {
|
||||
return reader.emitError(msg);
|
||||
}
|
||||
|
||||
FailureOr<const DialectVersion *>
|
||||
getDialectVersion(StringRef dialectName) const override {
|
||||
// First check if the dialect is available in the map.
|
||||
auto dialectEntry = dialectsMap.find(dialectName);
|
||||
if (dialectEntry == dialectsMap.end())
|
||||
return failure();
|
||||
// If the dialect was found, try to load it. This will trigger reading the
|
||||
// bytecode version from the version buffer if it wasn't already processed.
|
||||
// Return failure if either of those two actions could not be completed.
|
||||
if (failed(dialectEntry->getValue()->load(*this, getLoc().getContext())) ||
|
||||
dialectEntry->getValue()->loadedVersion.get() == nullptr)
|
||||
return failure();
|
||||
return dialectEntry->getValue()->loadedVersion.get();
|
||||
}
|
||||
|
||||
MLIRContext *getContext() const override { return getLoc().getContext(); }
|
||||
|
||||
uint64_t getBytecodeVersion() const override { return bytecodeVersion; }
|
||||
|
||||
DialectReader withEncodingReader(EncodingReader &encReader) const {
|
||||
DialectReader withEncodingReader(EncodingReader &encReader) {
|
||||
return DialectReader(attrTypeReader, stringReader, resourceReader,
|
||||
dialectsMap, encReader, bytecodeVersion);
|
||||
encReader, bytecodeVersion);
|
||||
}
|
||||
|
||||
Location getLoc() const { return reader.getLoc(); }
|
||||
@@ -1040,7 +1010,6 @@ private:
|
||||
AttrTypeReader &attrTypeReader;
|
||||
StringSectionReader &stringReader;
|
||||
ResourceSectionReader &resourceReader;
|
||||
const llvm::StringMap<BytecodeDialect *> &dialectsMap;
|
||||
EncodingReader &reader;
|
||||
uint64_t &bytecodeVersion;
|
||||
};
|
||||
@@ -1127,9 +1096,10 @@ private:
|
||||
};
|
||||
} // namespace
|
||||
|
||||
LogicalResult AttrTypeReader::initialize(
|
||||
MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
|
||||
ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) {
|
||||
LogicalResult
|
||||
AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects,
|
||||
ArrayRef<uint8_t> sectionData,
|
||||
ArrayRef<uint8_t> offsetSectionData) {
|
||||
EncodingReader offsetReader(offsetSectionData, fileLoc);
|
||||
|
||||
// Parse the number of attribute and type entries.
|
||||
@@ -1181,7 +1151,6 @@ LogicalResult AttrTypeReader::initialize(
|
||||
return offsetReader.emitError(
|
||||
"unexpected trailing data in the Attribute/Type offset section");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -1247,54 +1216,32 @@ template <typename T>
|
||||
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
|
||||
EncodingReader &reader,
|
||||
StringRef entryType) {
|
||||
DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
|
||||
reader, bytecodeVersion);
|
||||
DialectReader dialectReader(*this, stringReader, resourceReader, reader,
|
||||
bytecodeVersion);
|
||||
if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
|
||||
return failure();
|
||||
|
||||
if constexpr (std::is_same_v<T, Type>) {
|
||||
// Try parsing with callbacks first if available.
|
||||
for (const auto &callback :
|
||||
parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) {
|
||||
if (failed(
|
||||
callback->read(dialectReader, entry.dialect->name, entry.entry)))
|
||||
return failure();
|
||||
// Early return if parsing was successful.
|
||||
if (!!entry.entry)
|
||||
return success();
|
||||
|
||||
// Reset the reader if we failed to parse, so we can fall through the
|
||||
// other parsing functions.
|
||||
reader = EncodingReader(entry.data, reader.getLoc());
|
||||
}
|
||||
} else {
|
||||
// Try parsing with callbacks first if available.
|
||||
for (const auto &callback :
|
||||
parserConfig.getBytecodeReaderConfig().getAttributeCallbacks()) {
|
||||
if (failed(
|
||||
callback->read(dialectReader, entry.dialect->name, entry.entry)))
|
||||
return failure();
|
||||
// Early return if parsing was successful.
|
||||
if (!!entry.entry)
|
||||
return success();
|
||||
|
||||
// Reset the reader if we failed to parse, so we can fall through the
|
||||
// other parsing functions.
|
||||
reader = EncodingReader(entry.data, reader.getLoc());
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that the dialect implements the bytecode interface.
|
||||
if (!entry.dialect->interface) {
|
||||
return reader.emitError("dialect '", entry.dialect->name,
|
||||
"' does not implement the bytecode interface");
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<T, Type>)
|
||||
entry.entry = entry.dialect->interface->readType(dialectReader);
|
||||
else
|
||||
entry.entry = entry.dialect->interface->readAttribute(dialectReader);
|
||||
// Ask the dialect to parse the entry. If the dialect is versioned, parse
|
||||
// using the versioned encoding readers.
|
||||
if (entry.dialect->loadedVersion.get()) {
|
||||
if constexpr (std::is_same_v<T, Type>)
|
||||
entry.entry = entry.dialect->interface->readType(
|
||||
dialectReader, *entry.dialect->loadedVersion);
|
||||
else
|
||||
entry.entry = entry.dialect->interface->readAttribute(
|
||||
dialectReader, *entry.dialect->loadedVersion);
|
||||
|
||||
} else {
|
||||
if constexpr (std::is_same_v<T, Type>)
|
||||
entry.entry = entry.dialect->interface->readType(dialectReader);
|
||||
else
|
||||
entry.entry = entry.dialect->interface->readAttribute(dialectReader);
|
||||
}
|
||||
return success(!!entry.entry);
|
||||
}
|
||||
|
||||
@@ -1315,8 +1262,7 @@ public:
|
||||
llvm::MemoryBufferRef buffer,
|
||||
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
|
||||
: config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
|
||||
attrTypeReader(stringReader, resourceReader, dialectsMap, version,
|
||||
fileLoc, config),
|
||||
attrTypeReader(stringReader, resourceReader, fileLoc, version),
|
||||
// Use the builtin unrealized conversion cast operation to represent
|
||||
// forward references to values that aren't yet defined.
|
||||
forwardRefOpState(UnknownLoc::get(config.getContext()),
|
||||
@@ -1582,8 +1528,7 @@ private:
|
||||
StringRef producer;
|
||||
|
||||
/// The table of IR units referenced within the bytecode file.
|
||||
SmallVector<std::unique_ptr<BytecodeDialect>> dialects;
|
||||
llvm::StringMap<BytecodeDialect *> dialectsMap;
|
||||
SmallVector<BytecodeDialect> dialects;
|
||||
SmallVector<BytecodeOperationName> opNames;
|
||||
|
||||
/// The reader used to process resources within the bytecode.
|
||||
@@ -1730,8 +1675,7 @@ LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect Section
|
||||
|
||||
LogicalResult BytecodeDialect::load(const DialectReader &reader,
|
||||
MLIRContext *ctx) {
|
||||
LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) {
|
||||
if (dialect)
|
||||
return success();
|
||||
Dialect *loadedDialect = ctx->getOrLoadDialect(name);
|
||||
@@ -1775,15 +1719,13 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
|
||||
|
||||
// Parse each of the dialects.
|
||||
for (uint64_t i = 0; i < numDialects; ++i) {
|
||||
dialects[i] = std::make_unique<BytecodeDialect>();
|
||||
/// Before version kDialectVersioning, there wasn't any versioning available
|
||||
/// for dialects, and the entryIdx represent the string itself.
|
||||
if (version < bytecode::kDialectVersioning) {
|
||||
if (failed(stringReader.parseString(sectionReader, dialects[i]->name)))
|
||||
if (failed(stringReader.parseString(sectionReader, dialects[i].name)))
|
||||
return failure();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Parse ID representing dialect and version.
|
||||
uint64_t dialectNameIdx;
|
||||
bool versionAvailable;
|
||||
@@ -1791,19 +1733,18 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
|
||||
versionAvailable)))
|
||||
return failure();
|
||||
if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx,
|
||||
dialects[i]->name)))
|
||||
dialects[i].name)))
|
||||
return failure();
|
||||
if (versionAvailable) {
|
||||
bytecode::Section::ID sectionID;
|
||||
if (failed(sectionReader.parseSection(sectionID,
|
||||
dialects[i]->versionBuffer)))
|
||||
if (failed(
|
||||
sectionReader.parseSection(sectionID, dialects[i].versionBuffer)))
|
||||
return failure();
|
||||
if (sectionID != bytecode::Section::kDialectVersions) {
|
||||
emitError(fileLoc, "expected dialect version section");
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
dialectsMap[dialects[i]->name] = dialects[i].get();
|
||||
}
|
||||
|
||||
// Parse the operation names, which are grouped by dialect.
|
||||
@@ -1851,7 +1792,7 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader,
|
||||
if (!opName->opName) {
|
||||
// Load the dialect and its version.
|
||||
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
|
||||
dialectsMap, reader, version);
|
||||
reader, version);
|
||||
if (failed(opName->dialect->load(dialectReader, getContext())))
|
||||
return failure();
|
||||
// If the opName is empty, this is because we use to accept names such as
|
||||
@@ -1894,7 +1835,7 @@ LogicalResult BytecodeReader::Impl::parseResourceSection(
|
||||
|
||||
// Initialize the resource reader with the resource sections.
|
||||
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
|
||||
dialectsMap, reader, version);
|
||||
reader, version);
|
||||
return resourceReader.initialize(fileLoc, config, dialects, stringReader,
|
||||
*resourceData, *resourceOffsetData,
|
||||
dialectReader, bufferOwnerRef);
|
||||
@@ -2095,14 +2036,14 @@ BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
|
||||
"parsed use-list orders were invalid and could not be applied");
|
||||
|
||||
// Resolve dialect version.
|
||||
for (const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) {
|
||||
for (const BytecodeDialect &byteCodeDialect : dialects) {
|
||||
// Parsing is complete, give an opportunity to each dialect to visit the
|
||||
// IR and perform upgrades.
|
||||
if (!byteCodeDialect->loadedVersion)
|
||||
if (!byteCodeDialect.loadedVersion)
|
||||
continue;
|
||||
if (byteCodeDialect->interface &&
|
||||
failed(byteCodeDialect->interface->upgradeFromVersion(
|
||||
*moduleOp, *byteCodeDialect->loadedVersion)))
|
||||
if (byteCodeDialect.interface &&
|
||||
failed(byteCodeDialect.interface->upgradeFromVersion(
|
||||
*moduleOp, *byteCodeDialect.loadedVersion)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
@@ -2255,7 +2196,7 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
|
||||
// interface and control the serialization.
|
||||
if (wasRegistered) {
|
||||
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
|
||||
dialectsMap, reader, version);
|
||||
reader, version);
|
||||
if (failed(
|
||||
propertiesReader.read(fileLoc, dialectReader, &*opName, opState)))
|
||||
return failure();
|
||||
|
||||
@@ -18,10 +18,15 @@
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/CachedHashString.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Endian.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "llvm/Support/Endian.h"
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <optional>
|
||||
#include <sys/types.h>
|
||||
|
||||
#define DEBUG_TYPE "mlir-bytecode-writer"
|
||||
|
||||
@@ -42,12 +47,6 @@ struct BytecodeWriterConfig::Impl {
|
||||
/// The producer of the bytecode.
|
||||
StringRef producer;
|
||||
|
||||
/// Printer callbacks used to emit custom type and attribute encodings.
|
||||
llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>>
|
||||
attributeWriterCallbacks;
|
||||
llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeWriter<Type>>>
|
||||
typeWriterCallbacks;
|
||||
|
||||
/// A collection of non-dialect resource printers.
|
||||
SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters;
|
||||
};
|
||||
@@ -61,26 +60,6 @@ BytecodeWriterConfig::BytecodeWriterConfig(FallbackAsmResourceMap &map,
|
||||
}
|
||||
BytecodeWriterConfig::~BytecodeWriterConfig() = default;
|
||||
|
||||
ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>>
|
||||
BytecodeWriterConfig::getAttributeWriterCallbacks() const {
|
||||
return impl->attributeWriterCallbacks;
|
||||
}
|
||||
|
||||
ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Type>>>
|
||||
BytecodeWriterConfig::getTypeWriterCallbacks() const {
|
||||
return impl->typeWriterCallbacks;
|
||||
}
|
||||
|
||||
void BytecodeWriterConfig::attachAttributeCallback(
|
||||
std::unique_ptr<AttrTypeBytecodeWriter<Attribute>> callback) {
|
||||
impl->attributeWriterCallbacks.emplace_back(std::move(callback));
|
||||
}
|
||||
|
||||
void BytecodeWriterConfig::attachTypeCallback(
|
||||
std::unique_ptr<AttrTypeBytecodeWriter<Type>> callback) {
|
||||
impl->typeWriterCallbacks.emplace_back(std::move(callback));
|
||||
}
|
||||
|
||||
void BytecodeWriterConfig::attachResourcePrinter(
|
||||
std::unique_ptr<AsmResourcePrinter> printer) {
|
||||
impl->externalResourcePrinters.emplace_back(std::move(printer));
|
||||
@@ -795,50 +774,32 @@ void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
|
||||
auto emitAttrOrType = [&](auto &entry) {
|
||||
auto entryValue = entry.getValue();
|
||||
|
||||
auto emitAttrOrTypeRawImpl = [&]() -> void {
|
||||
RawEmitterOstream(attrTypeEmitter) << entryValue;
|
||||
attrTypeEmitter.emitByte(0);
|
||||
};
|
||||
auto emitAttrOrTypeImpl = [&]() -> bool {
|
||||
// TODO: We don't currently support custom encoded mutable types and
|
||||
// attributes.
|
||||
if (entryValue.template hasTrait<TypeTrait::IsMutable>() ||
|
||||
entryValue.template hasTrait<AttributeTrait::IsMutable>()) {
|
||||
emitAttrOrTypeRawImpl();
|
||||
return false;
|
||||
}
|
||||
|
||||
// First, try to emit this entry using the dialect bytecode interface.
|
||||
bool hasCustomEncoding = false;
|
||||
if (const BytecodeDialectInterface *interface = entry.dialect->interface) {
|
||||
// The writer used when emitting using a custom bytecode encoding.
|
||||
DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter,
|
||||
numberingState, stringSection);
|
||||
|
||||
if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) {
|
||||
for (const auto &callback : config.typeWriterCallbacks) {
|
||||
if (succeeded(callback->write(entryValue, dialectWriter)))
|
||||
return true;
|
||||
}
|
||||
if (const BytecodeDialectInterface *interface =
|
||||
entry.dialect->interface) {
|
||||
if (succeeded(interface->writeType(entryValue, dialectWriter)))
|
||||
return true;
|
||||
}
|
||||
// TODO: We don't currently support custom encoded mutable types.
|
||||
hasCustomEncoding =
|
||||
!entryValue.template hasTrait<TypeTrait::IsMutable>() &&
|
||||
succeeded(interface->writeType(entryValue, dialectWriter));
|
||||
} else {
|
||||
for (const auto &callback : config.attributeWriterCallbacks) {
|
||||
if (succeeded(callback->write(entryValue, dialectWriter)))
|
||||
return true;
|
||||
}
|
||||
if (const BytecodeDialectInterface *interface =
|
||||
entry.dialect->interface) {
|
||||
if (succeeded(interface->writeAttribute(entryValue, dialectWriter)))
|
||||
return true;
|
||||
}
|
||||
// TODO: We don't currently support custom encoded mutable attributes.
|
||||
hasCustomEncoding =
|
||||
!entryValue.template hasTrait<AttributeTrait::IsMutable>() &&
|
||||
succeeded(interface->writeAttribute(entryValue, dialectWriter));
|
||||
}
|
||||
}
|
||||
|
||||
// If the entry was not emitted using a callback or a dialect interface,
|
||||
// emit it using the textual format.
|
||||
emitAttrOrTypeRawImpl();
|
||||
return false;
|
||||
};
|
||||
|
||||
bool hasCustomEncoding = emitAttrOrTypeImpl();
|
||||
// If the entry was not emitted using the dialect interface, emit it using
|
||||
// the textual format.
|
||||
if (!hasCustomEncoding) {
|
||||
RawEmitterOstream(attrTypeEmitter) << entryValue;
|
||||
attrTypeEmitter.emitByte(0);
|
||||
}
|
||||
|
||||
// Record the offset of this entry.
|
||||
uint64_t curOffset = attrTypeEmitter.size();
|
||||
|
||||
@@ -314,22 +314,9 @@ void IRNumberingState::number(Attribute attr) {
|
||||
|
||||
// If this attribute will be emitted using the bytecode format, perform a
|
||||
// dummy writing to number any nested components.
|
||||
// TODO: We don't allow custom encodings for mutable attributes right now.
|
||||
if (!attr.hasTrait<AttributeTrait::IsMutable>()) {
|
||||
// Try overriding emission with callbacks.
|
||||
for (const auto &callback : config.getAttributeWriterCallbacks()) {
|
||||
NumberingDialectWriter writer(*this);
|
||||
// The client has the ability to override the group name through the
|
||||
// callback.
|
||||
std::optional<StringRef> groupNameOverride;
|
||||
if (succeeded(callback->write(attr, groupNameOverride, writer))) {
|
||||
if (groupNameOverride.has_value())
|
||||
numbering->dialect = &numberDialect(*groupNameOverride);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (const auto *interface = numbering->dialect->interface) {
|
||||
if (const auto *interface = numbering->dialect->interface) {
|
||||
// TODO: We don't allow custom encodings for mutable attributes right now.
|
||||
if (!attr.hasTrait<AttributeTrait::IsMutable>()) {
|
||||
NumberingDialectWriter writer(*this);
|
||||
if (succeeded(interface->writeAttribute(attr, writer)))
|
||||
return;
|
||||
@@ -477,24 +464,9 @@ void IRNumberingState::number(Type type) {
|
||||
|
||||
// If this type will be emitted using the bytecode format, perform a dummy
|
||||
// writing to number any nested components.
|
||||
// TODO: We don't allow custom encodings for mutable types right now.
|
||||
if (!type.hasTrait<TypeTrait::IsMutable>()) {
|
||||
// Try overriding emission with callbacks.
|
||||
for (const auto &callback : config.getTypeWriterCallbacks()) {
|
||||
NumberingDialectWriter writer(*this);
|
||||
// The client has the ability to override the group name through the
|
||||
// callback.
|
||||
std::optional<StringRef> groupNameOverride;
|
||||
if (succeeded(callback->write(type, groupNameOverride, writer))) {
|
||||
if (groupNameOverride.has_value())
|
||||
numbering->dialect = &numberDialect(*groupNameOverride);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// If this attribute will be emitted using the bytecode format, perform a
|
||||
// dummy writing to number any nested components.
|
||||
if (const auto *interface = numbering->dialect->interface) {
|
||||
if (const auto *interface = numbering->dialect->interface) {
|
||||
// TODO: We don't allow custom encodings for mutable types right now.
|
||||
if (!type.hasTrait<TypeTrait::IsMutable>()) {
|
||||
NumberingDialectWriter writer(*this);
|
||||
if (succeeded(interface->writeType(type, writer)))
|
||||
return;
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=1.2" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_1_2
|
||||
// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=2.0" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_2_0
|
||||
|
||||
func.func @base_test(%arg0 : i32) -> f32 {
|
||||
%0 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32
|
||||
%1 = "test.cast"(%0) : (i32) -> f32
|
||||
return %1 : f32
|
||||
}
|
||||
|
||||
// VERSION_1_2: Overriding IntegerType encoding...
|
||||
// VERSION_1_2: Overriding parsing of IntegerType encoding...
|
||||
|
||||
// VERSION_2_0-NOT: Overriding IntegerType encoding...
|
||||
// VERSION_2_0-NOT: Overriding parsing of IntegerType encoding...
|
||||
@@ -1,18 +0,0 @@
|
||||
// RUN: not mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=5" 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK-NOT: failed to read bytecode
|
||||
func.func @base_test(%arg0 : i32) -> f32 {
|
||||
%0 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32
|
||||
%1 = "test.cast"(%0) : (i32) -> f32
|
||||
return %1 : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: error: unknown attribute code: 99
|
||||
// CHECK: failed to read bytecode
|
||||
func.func @base_test(%arg0 : !test.i32) -> f32 {
|
||||
%0 = "test.addi"(%arg0, %arg0) : (!test.i32, !test.i32) -> !test.i32
|
||||
%1 = "test.cast"(%0) : (!test.i32) -> f32
|
||||
return %1 : f32
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=3" | FileCheck %s --check-prefix=TEST_3
|
||||
// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=4" | FileCheck %s --check-prefix=TEST_4
|
||||
|
||||
"test.versionedC"() <{attribute = #test.attr_params<42, 24>}> : () -> ()
|
||||
|
||||
// TEST_3: Overriding TestAttrParamsAttr encoding...
|
||||
// TEST_3: "test.versionedC"() <{attribute = dense<[42, 24]> : tensor<2xi32>}> : () -> ()
|
||||
|
||||
// -----
|
||||
|
||||
"test.versionedC"() <{attribute = dense<[42, 24]> : tensor<2xi32>}> : () -> ()
|
||||
|
||||
// TEST_4: Overriding parsing of TestAttrParamsAttr encoding...
|
||||
// TEST_4: "test.versionedC"() <{attribute = #test.attr_params<42, 24>}> : () -> ()
|
||||
@@ -1,18 +0,0 @@
|
||||
// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=1" | FileCheck %s --check-prefix=TEST_1
|
||||
// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=2" | FileCheck %s --check-prefix=TEST_2
|
||||
|
||||
func.func @base_test(%arg0: !test.i32, %arg1: f32) {
|
||||
return
|
||||
}
|
||||
|
||||
// TEST_1: Overriding TestI32Type encoding...
|
||||
// TEST_1: func.func @base_test([[ARG0:%.+]]: i32, [[ARG1:%.+]]: f32) {
|
||||
|
||||
// -----
|
||||
|
||||
func.func @base_test(%arg0: i32, %arg1: f32) {
|
||||
return
|
||||
}
|
||||
|
||||
// TEST_2: Overriding parsing of TestI32Type encoding...
|
||||
// TEST_2: func.func @base_test([[ARG0:%.+]]: !test.i32, [[ARG1:%.+]]: f32) {
|
||||
@@ -5,12 +5,12 @@
|
||||
// Index
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
// RUN: not mlir-opt %S/invalid-attr_type_section-index.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=INDEX
|
||||
// RUN: not mlir-opt %S/invalid-attr_type_section-index.mlirbc 2>&1 | FileCheck %s --check-prefix=INDEX
|
||||
// INDEX: invalid Attribute index: 3
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Trailing Data
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
// RUN: not mlir-opt %S/invalid-attr_type_section-trailing_data.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=TRAILING_DATA
|
||||
// RUN: not mlir-opt %S/invalid-attr_type_section-trailing_data.mlirbc 2>&1 | FileCheck %s --check-prefix=TRAILING_DATA
|
||||
// TRAILING_DATA: trailing characters found after Attribute assembly format: trailing
|
||||
|
||||
@@ -14,10 +14,9 @@
|
||||
#ifndef MLIR_TESTDIALECT_H
|
||||
#define MLIR_TESTDIALECT_H
|
||||
|
||||
#include "TestTypes.h"
|
||||
#include "TestAttributes.h"
|
||||
#include "TestInterfaces.h"
|
||||
#include "TestTypes.h"
|
||||
#include "mlir/Bytecode/BytecodeImplementation.h"
|
||||
#include "mlir/Dialect/DLTI/DLTI.h"
|
||||
#include "mlir/Dialect/DLTI/Traits.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
@@ -58,19 +57,6 @@ class RewritePatternSet;
|
||||
#include "TestOpsDialect.h.inc"
|
||||
|
||||
namespace test {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TestDialect version utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct TestDialectVersion : public mlir::DialectVersion {
|
||||
TestDialectVersion() = default;
|
||||
TestDialectVersion(uint32_t _major, uint32_t _minor)
|
||||
: major(_major), minor(_minor){};
|
||||
uint32_t major = 2;
|
||||
uint32_t minor = 0;
|
||||
};
|
||||
|
||||
// Define some classes to exercises the Properties feature.
|
||||
|
||||
struct PropertiesWithCustomPrint {
|
||||
|
||||
@@ -14,6 +14,15 @@
|
||||
using namespace mlir;
|
||||
using namespace test;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TestDialect version utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct TestDialectVersion : public DialectVersion {
|
||||
uint32_t major = 2;
|
||||
uint32_t minor = 0;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TestDialect Interfaces
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -38,7 +47,7 @@ struct TestResourceBlobManagerInterface
|
||||
};
|
||||
|
||||
namespace {
|
||||
enum test_encoding { k_attr_params = 0, k_test_i32 = 99 };
|
||||
enum test_encoding { k_attr_params = 0 };
|
||||
}
|
||||
|
||||
// Test support for interacting with the Bytecode reader/writer.
|
||||
@@ -47,24 +56,6 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
|
||||
TestBytecodeDialectInterface(Dialect *dialect)
|
||||
: BytecodeDialectInterface(dialect) {}
|
||||
|
||||
LogicalResult writeType(Type type,
|
||||
DialectBytecodeWriter &writer) const final {
|
||||
if (auto concreteType = llvm::dyn_cast<TestI32Type>(type)) {
|
||||
writer.writeVarInt(test_encoding::k_test_i32);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
Type readType(DialectBytecodeReader &reader) const final {
|
||||
uint64_t encoding;
|
||||
if (failed(reader.readVarInt(encoding)))
|
||||
return Type();
|
||||
if (encoding == test_encoding::k_test_i32)
|
||||
return TestI32Type::get(getContext());
|
||||
return Type();
|
||||
}
|
||||
|
||||
LogicalResult writeAttribute(Attribute attr,
|
||||
DialectBytecodeWriter &writer) const final {
|
||||
if (auto concreteAttr = llvm::dyn_cast<TestAttrParamsAttr>(attr)) {
|
||||
@@ -76,13 +67,9 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
|
||||
return failure();
|
||||
}
|
||||
|
||||
Attribute readAttribute(DialectBytecodeReader &reader) const final {
|
||||
auto versionOr = reader.getDialectVersion("test");
|
||||
// Assume current version if not available through the reader.
|
||||
const auto version =
|
||||
(succeeded(versionOr))
|
||||
? *reinterpret_cast<const TestDialectVersion *>(*versionOr)
|
||||
: TestDialectVersion();
|
||||
Attribute readAttribute(DialectBytecodeReader &reader,
|
||||
const DialectVersion &version_) const final {
|
||||
const auto &version = static_cast<const TestDialectVersion &>(version_);
|
||||
if (version.major < 2)
|
||||
return readAttrOldEncoding(reader);
|
||||
if (version.major == 2 && version.minor == 0)
|
||||
|
||||
@@ -1258,9 +1258,8 @@ def TestOpWithVariadicResultsAndFolder: TEST_Op<"op_with_variadic_results_and_fo
|
||||
}
|
||||
|
||||
def TestAddIOp : TEST_Op<"addi"> {
|
||||
let arguments = (ins AnyTypeOf<[I32, TestI32]>:$op1,
|
||||
AnyTypeOf<[I32, TestI32]>:$op2);
|
||||
let results = (outs AnyTypeOf<[I32, TestI32]>);
|
||||
let arguments = (ins I32:$op1, I32:$op2);
|
||||
let results = (outs I32);
|
||||
}
|
||||
|
||||
def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> {
|
||||
@@ -2621,12 +2620,6 @@ def TestVersionedOpB : TEST_Op<"versionedB"> {
|
||||
);
|
||||
}
|
||||
|
||||
def TestVersionedOpC : TEST_Op<"versionedC"> {
|
||||
let arguments = (ins AnyAttrOf<[TestAttrParams,
|
||||
I32ElementsAttr]>:$attribute
|
||||
);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Properties
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -369,8 +369,4 @@ def TestTypeElseAnchorStruct : Test_Type<"TestTypeElseAnchorStruct"> {
|
||||
let assemblyFormat = "`<` (`?`) : (struct($a, $b)^)? `>`";
|
||||
}
|
||||
|
||||
def TestI32 : Test_Type<"TestI32"> {
|
||||
let mnemonic = "i32";
|
||||
}
|
||||
|
||||
#endif // TEST_TYPEDEFS
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# Exclude tests from libMLIR.so
|
||||
add_mlir_library(MLIRTestIR
|
||||
TestBytecodeCallbacks.cpp
|
||||
TestBuiltinAttributeInterfaces.cpp
|
||||
TestBuiltinDistinctAttributes.cpp
|
||||
TestClone.cpp
|
||||
|
||||
@@ -1,372 +0,0 @@
|
||||
//===- TestBytecodeCallbacks.cpp - Pass to test bytecode callback hooks --===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "TestDialect.h"
|
||||
#include "mlir/Bytecode/BytecodeReader.h"
|
||||
#include "mlir/Bytecode/BytecodeWriter.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/Parser/Parser.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/MemoryBufferRef.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <list>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace llvm;
|
||||
|
||||
namespace {
|
||||
class TestDialectVersionParser : public cl::parser<test::TestDialectVersion> {
|
||||
public:
|
||||
TestDialectVersionParser(cl::Option &O)
|
||||
: cl::parser<test::TestDialectVersion>(O) {}
|
||||
|
||||
bool parse(cl::Option &O, StringRef /*argName*/, StringRef arg,
|
||||
test::TestDialectVersion &v) {
|
||||
long long major, minor;
|
||||
if (getAsSignedInteger(arg.split(".").first, 10, major))
|
||||
return O.error("Invalid argument '" + arg);
|
||||
if (getAsSignedInteger(arg.split(".").second, 10, minor))
|
||||
return O.error("Invalid argument '" + arg);
|
||||
v = test::TestDialectVersion(major, minor);
|
||||
// Returns true on error.
|
||||
return false;
|
||||
}
|
||||
static void print(raw_ostream &os, const test::TestDialectVersion &v) {
|
||||
os << v.major << "." << v.minor;
|
||||
};
|
||||
};
|
||||
|
||||
/// This is a test pass which uses callbacks to encode attributes and types in a
|
||||
/// custom fashion.
|
||||
struct TestBytecodeCallbackPass
|
||||
: public PassWrapper<TestBytecodeCallbackPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeCallbackPass)
|
||||
|
||||
StringRef getArgument() const final { return "test-bytecode-callback"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test encoding of a dialect type/attributes with a custom callback";
|
||||
}
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<test::TestDialect>();
|
||||
}
|
||||
TestBytecodeCallbackPass() = default;
|
||||
TestBytecodeCallbackPass(const TestBytecodeCallbackPass &) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
switch (testKind) {
|
||||
case (0):
|
||||
return runTest0(getOperation());
|
||||
case (1):
|
||||
return runTest1(getOperation());
|
||||
case (2):
|
||||
return runTest2(getOperation());
|
||||
case (3):
|
||||
return runTest3(getOperation());
|
||||
case (4):
|
||||
return runTest4(getOperation());
|
||||
case (5):
|
||||
return runTest5(getOperation());
|
||||
default:
|
||||
llvm_unreachable("unhandled test kind for TestBytecodeCallbacks pass");
|
||||
}
|
||||
}
|
||||
|
||||
mlir::Pass::Option<test::TestDialectVersion, TestDialectVersionParser>
|
||||
targetVersion{*this, "test-dialect-version",
|
||||
llvm::cl::desc(
|
||||
"Specifies the test dialect version to emit and parse"),
|
||||
cl::init(test::TestDialectVersion())};
|
||||
|
||||
mlir::Pass::Option<int> testKind{
|
||||
*this, "callback-test",
|
||||
llvm::cl::desc("Specifies the test kind to execute"), cl::init(0)};
|
||||
|
||||
private:
|
||||
void doRoundtripWithConfigs(Operation *op,
|
||||
const BytecodeWriterConfig &writeConfig,
|
||||
const ParserConfig &parseConfig) {
|
||||
std::string bytecode;
|
||||
llvm::raw_string_ostream os(bytecode);
|
||||
if (failed(writeBytecodeToFile(op, os, writeConfig))) {
|
||||
op->emitError() << "failed to write bytecode\n";
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
auto newModuleOp = parseSourceString(StringRef(bytecode), parseConfig);
|
||||
if (!newModuleOp.get()) {
|
||||
op->emitError() << "failed to read bytecode\n";
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
// Print the module to the output stream, so that we can filecheck the
|
||||
// result.
|
||||
newModuleOp->print(llvm::outs());
|
||||
return;
|
||||
}
|
||||
|
||||
// Test0: let's assume that versions older than 2.0 were relying on a special
|
||||
// integer attribute of a deprecated dialect called "funky". Assume that its
|
||||
// encoding was made by two varInts, the first was the ID (999) and the second
|
||||
// contained width and signedness info. We can emit it using a callback
|
||||
// writing a custom encoding for the "funky" dialect group, and parse it back
|
||||
// with a custom parser reading the same encoding in the same dialect group.
|
||||
// Note that the ID 999 does not correspond to a valid integer type in the
|
||||
// current encodings of builtin types.
|
||||
void runTest0(Operation *op) {
|
||||
auto newCtx = std::make_shared<MLIRContext>();
|
||||
test::TestDialectVersion targetEmissionVersion = targetVersion;
|
||||
BytecodeWriterConfig writeConfig;
|
||||
writeConfig.attachTypeCallback(
|
||||
[&](Type entryValue, std::optional<StringRef> &dialectGroupName,
|
||||
DialectBytecodeWriter &writer) -> LogicalResult {
|
||||
// Do not override anything if version less than 2.0.
|
||||
if (targetEmissionVersion.major >= 2)
|
||||
return failure();
|
||||
|
||||
// For version less than 2.0, override the encoding of IntegerType.
|
||||
if (auto type = llvm::dyn_cast<IntegerType>(entryValue)) {
|
||||
llvm::outs() << "Overriding IntegerType encoding...\n";
|
||||
dialectGroupName = StringLiteral("funky");
|
||||
writer.writeVarInt(/* IntegerType */ 999);
|
||||
writer.writeVarInt(type.getWidth() << 2 | type.getSignedness());
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
});
|
||||
newCtx->appendDialectRegistry(op->getContext()->getDialectRegistry());
|
||||
newCtx->allowUnregisteredDialects();
|
||||
ParserConfig parseConfig(newCtx.get(), /*verifyAfterParse=*/true);
|
||||
parseConfig.getBytecodeReaderConfig().attachTypeCallback(
|
||||
[&](DialectBytecodeReader &reader, StringRef dialectName,
|
||||
Type &entry) -> LogicalResult {
|
||||
// Get test dialect version from the version map.
|
||||
auto versionOr = reader.getDialectVersion("test");
|
||||
assert(succeeded(versionOr) && "expected reader to be able to access "
|
||||
"the version for test dialect");
|
||||
const auto *version =
|
||||
reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
|
||||
|
||||
// TODO: once back-deployment is formally supported,
|
||||
// `targetEmissionVersion` will be encoded in the bytecode file, and
|
||||
// exposed through the versionMap. Right now though this is not yet
|
||||
// supported. For the purpose of the test, just use
|
||||
// `targetEmissionVersion`.
|
||||
(void)version;
|
||||
if (targetEmissionVersion.major >= 2)
|
||||
return success();
|
||||
|
||||
// `dialectName` is the name of the group we have the opportunity to
|
||||
// override. In this case, override only the dialect group "funky",
|
||||
// for which does not exist in memory.
|
||||
if (dialectName != StringLiteral("funky"))
|
||||
return success();
|
||||
|
||||
uint64_t encoding;
|
||||
if (failed(reader.readVarInt(encoding)) || encoding != 999)
|
||||
return success();
|
||||
llvm::outs() << "Overriding parsing of IntegerType encoding...\n";
|
||||
uint64_t _widthAndSignedness, width;
|
||||
IntegerType::SignednessSemantics signedness;
|
||||
if (succeeded(reader.readVarInt(_widthAndSignedness)) &&
|
||||
((width = _widthAndSignedness >> 2), true) &&
|
||||
((signedness = static_cast<IntegerType::SignednessSemantics>(
|
||||
_widthAndSignedness & 0x3)),
|
||||
true))
|
||||
entry = IntegerType::get(reader.getContext(), width, signedness);
|
||||
// Return nullopt to fall through the rest of the parsing code path.
|
||||
return success();
|
||||
});
|
||||
doRoundtripWithConfigs(op, writeConfig, parseConfig);
|
||||
return;
|
||||
}
|
||||
|
||||
// Test1: When writing bytecode, we override the encoding of TestI32Type with
|
||||
// the encoding of builtin IntegerType. We can natively parse this without
|
||||
// the use of a callback, relying on the existing builtin reader mechanism.
|
||||
void runTest1(Operation *op) {
|
||||
auto builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
|
||||
BytecodeDialectInterface *iface =
|
||||
builtin->getRegisteredInterface<BytecodeDialectInterface>();
|
||||
BytecodeWriterConfig writeConfig;
|
||||
writeConfig.attachTypeCallback(
|
||||
[&](Type entryValue, std::optional<StringRef> &dialectGroupName,
|
||||
DialectBytecodeWriter &writer) -> LogicalResult {
|
||||
// Emit TestIntegerType using the builtin dialect encoding.
|
||||
if (llvm::isa<test::TestI32Type>(entryValue)) {
|
||||
llvm::outs() << "Overriding TestI32Type encoding...\n";
|
||||
auto builtinI32Type =
|
||||
IntegerType::get(op->getContext(), 32,
|
||||
IntegerType::SignednessSemantics::Signless);
|
||||
// Specify that this type will need to be written as part of the
|
||||
// builtin group. This will override the default dialect group of
|
||||
// the attribute (test).
|
||||
dialectGroupName = StringLiteral("builtin");
|
||||
if (succeeded(iface->writeType(builtinI32Type, writer)))
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
});
|
||||
// We natively parse the attribute as a builtin, so no callback needed.
|
||||
ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
|
||||
doRoundtripWithConfigs(op, writeConfig, parseConfig);
|
||||
return;
|
||||
}
|
||||
|
||||
// Test2: When writing bytecode, we write standard builtin IntegerTypes. At
|
||||
// parsing, we use the encoding of IntegerType to intercept all i32. Then,
|
||||
// instead of creating i32s, we assemble TestI32Type and return it.
|
||||
void runTest2(Operation *op) {
|
||||
auto builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
|
||||
BytecodeDialectInterface *iface =
|
||||
builtin->getRegisteredInterface<BytecodeDialectInterface>();
|
||||
BytecodeWriterConfig writeConfig;
|
||||
ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
|
||||
parseConfig.getBytecodeReaderConfig().attachTypeCallback(
|
||||
[&](DialectBytecodeReader &reader, StringRef dialectName,
|
||||
Type &entry) -> LogicalResult {
|
||||
if (dialectName != StringLiteral("builtin"))
|
||||
return success();
|
||||
Type builtinAttr = iface->readType(reader);
|
||||
if (auto integerType =
|
||||
llvm::dyn_cast_or_null<IntegerType>(builtinAttr)) {
|
||||
if (integerType.getWidth() == 32 && integerType.isSignless()) {
|
||||
llvm::outs() << "Overriding parsing of TestI32Type encoding...\n";
|
||||
entry = test::TestI32Type::get(reader.getContext());
|
||||
}
|
||||
}
|
||||
return success();
|
||||
});
|
||||
doRoundtripWithConfigs(op, writeConfig, parseConfig);
|
||||
return;
|
||||
}
|
||||
|
||||
// Test3: When writing bytecode, we override the encoding of
|
||||
// TestAttrParamsAttr with the encoding of builtin DenseIntElementsAttr. We
|
||||
// can natively parse this without the use of a callback, relying on the
|
||||
// existing builtin reader mechanism.
|
||||
void runTest3(Operation *op) {
|
||||
auto builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
|
||||
BytecodeDialectInterface *iface =
|
||||
builtin->getRegisteredInterface<BytecodeDialectInterface>();
|
||||
auto i32Type = IntegerType::get(op->getContext(), 32,
|
||||
IntegerType::SignednessSemantics::Signless);
|
||||
BytecodeWriterConfig writeConfig;
|
||||
writeConfig.attachAttributeCallback(
|
||||
[&](Attribute entryValue, std::optional<StringRef> &dialectGroupName,
|
||||
DialectBytecodeWriter &writer) -> LogicalResult {
|
||||
// Emit TestIntegerType using the builtin dialect encoding.
|
||||
if (auto testParamAttrs =
|
||||
llvm::dyn_cast<test::TestAttrParamsAttr>(entryValue)) {
|
||||
llvm::outs() << "Overriding TestAttrParamsAttr encoding...\n";
|
||||
// Specify that this attribute will need to be written as part of
|
||||
// the builtin group. This will override the default dialect group
|
||||
// of the attribute (test).
|
||||
dialectGroupName = StringLiteral("builtin");
|
||||
auto denseAttr = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({2}, i32Type),
|
||||
{testParamAttrs.getV0(), testParamAttrs.getV1()});
|
||||
if (succeeded(iface->writeAttribute(denseAttr, writer)))
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
});
|
||||
// We natively parse the attribute as a builtin, so no callback needed.
|
||||
ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
|
||||
doRoundtripWithConfigs(op, writeConfig, parseConfig);
|
||||
return;
|
||||
}
|
||||
|
||||
// Test4: When writing bytecode, we write standard builtin
|
||||
// DenseIntElementsAttr. At parsing, we use the encoding of
|
||||
// DenseIntElementsAttr to intercept all ElementsAttr that have shaped type of
|
||||
// <2xi32>. Instead of assembling a DenseIntElementsAttr, we assemble
|
||||
// TestAttrParamsAttr and return it.
|
||||
void runTest4(Operation *op) {
|
||||
auto builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
|
||||
BytecodeDialectInterface *iface =
|
||||
builtin->getRegisteredInterface<BytecodeDialectInterface>();
|
||||
auto i32Type = IntegerType::get(op->getContext(), 32,
|
||||
IntegerType::SignednessSemantics::Signless);
|
||||
BytecodeWriterConfig writeConfig;
|
||||
ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
|
||||
parseConfig.getBytecodeReaderConfig().attachAttributeCallback(
|
||||
[&](DialectBytecodeReader &reader, StringRef dialectName,
|
||||
Attribute &entry) -> LogicalResult {
|
||||
// Override only the case where the return type of the builtin reader
|
||||
// is an i32 and fall through on all the other cases, since we want to
|
||||
// still use TestDialect normal codepath to parse the other types.
|
||||
Attribute builtinAttr = iface->readAttribute(reader);
|
||||
if (auto denseAttr =
|
||||
llvm::dyn_cast_or_null<DenseIntElementsAttr>(builtinAttr)) {
|
||||
if (denseAttr.getType().getShape() == ArrayRef<int64_t>(2) &&
|
||||
denseAttr.getElementType() == i32Type) {
|
||||
llvm::outs()
|
||||
<< "Overriding parsing of TestAttrParamsAttr encoding...\n";
|
||||
int v0 = denseAttr.getValues<IntegerAttr>()[0].getInt();
|
||||
int v1 = denseAttr.getValues<IntegerAttr>()[1].getInt();
|
||||
entry =
|
||||
test::TestAttrParamsAttr::get(reader.getContext(), v0, v1);
|
||||
}
|
||||
}
|
||||
return success();
|
||||
});
|
||||
doRoundtripWithConfigs(op, writeConfig, parseConfig);
|
||||
return;
|
||||
}
|
||||
|
||||
// Test5: When writing bytecode, we want TestDialect to use nothing else than
|
||||
// the builtin types and attributes and take full control of the encoding,
|
||||
// returning failure if any type or attribute is not part of builtin.
|
||||
void runTest5(Operation *op) {
|
||||
auto builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
|
||||
BytecodeDialectInterface *iface =
|
||||
builtin->getRegisteredInterface<BytecodeDialectInterface>();
|
||||
BytecodeWriterConfig writeConfig;
|
||||
writeConfig.attachAttributeCallback(
|
||||
[&](Attribute attr, std::optional<StringRef> &dialectGroupName,
|
||||
DialectBytecodeWriter &writer) -> LogicalResult {
|
||||
return iface->writeAttribute(attr, writer);
|
||||
});
|
||||
writeConfig.attachTypeCallback(
|
||||
[&](Type type, std::optional<StringRef> &dialectGroupName,
|
||||
DialectBytecodeWriter &writer) -> LogicalResult {
|
||||
return iface->writeType(type, writer);
|
||||
});
|
||||
ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
|
||||
parseConfig.getBytecodeReaderConfig().attachAttributeCallback(
|
||||
[&](DialectBytecodeReader &reader, StringRef dialectName,
|
||||
Attribute &entry) -> LogicalResult {
|
||||
Attribute builtinAttr = iface->readAttribute(reader);
|
||||
if (!builtinAttr)
|
||||
return failure();
|
||||
entry = builtinAttr;
|
||||
return success();
|
||||
});
|
||||
parseConfig.getBytecodeReaderConfig().attachTypeCallback(
|
||||
[&](DialectBytecodeReader &reader, StringRef dialectName,
|
||||
Type &entry) -> LogicalResult {
|
||||
Type builtinType = iface->readType(reader);
|
||||
if (!builtinType) {
|
||||
return failure();
|
||||
}
|
||||
entry = builtinType;
|
||||
return success();
|
||||
});
|
||||
doRoundtripWithConfigs(op, writeConfig, parseConfig);
|
||||
return;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
void registerTestBytecodeCallbackPasses() {
|
||||
PassRegistration<TestBytecodeCallbackPass>();
|
||||
}
|
||||
} // namespace mlir
|
||||
@@ -43,7 +43,6 @@ void registerSymbolTestPasses();
|
||||
void registerRegionTestPasses();
|
||||
void registerTestAffineDataCopyPass();
|
||||
void registerTestAffineReifyValueBoundsPass();
|
||||
void registerTestBytecodeCallbackPasses();
|
||||
void registerTestDecomposeAffineOpPass();
|
||||
void registerTestAffineLoopUnswitchingPass();
|
||||
void registerTestAllReduceLoweringPass();
|
||||
@@ -168,7 +167,6 @@ void registerTestPasses() {
|
||||
registerTestDecomposeAffineOpPass();
|
||||
registerTestAffineLoopUnswitchingPass();
|
||||
registerTestAllReduceLoweringPass();
|
||||
registerTestBytecodeCallbackPasses();
|
||||
registerTestFunc();
|
||||
registerTestGpuMemoryPromotionPass();
|
||||
registerTestLoopPermutationPass();
|
||||
|
||||
Reference in New Issue
Block a user